├── scripts ├── config │ ├── hw_simulator │ │ ├── io_speed │ │ │ ├── paper_flash1.yaml │ │ │ ├── paper_flash2.yaml │ │ │ └── paper_flash0.5.yaml │ │ ├── processor │ │ │ ├── A18_APL1V08.yaml │ │ │ ├── A4_APL0398.yaml │ │ │ ├── A5X_APL5498.yaml │ │ │ ├── A5_APL7498.yaml │ │ │ ├── A6X_APL5598.yaml │ │ │ ├── A6_APL0598.yaml │ │ │ ├── A7_APL5698.yaml │ │ │ ├── A8X_APL1021.yaml │ │ │ ├── A8_APL1011.yaml │ │ │ ├── A9X_APL1021.yaml │ │ │ ├── A9_APL1022.yaml │ │ │ ├── A10_Fusion_APL1W24.yaml │ │ │ ├── A11_Bionic_APL1W72.yaml │ │ │ ├── A12_Bionic_APL1W81.yaml │ │ │ ├── A13_Bionic_APL1W85.yaml │ │ │ ├── A14_Bionic_APL1W01.yaml │ │ │ ├── A17_Pro_APL1V02.yaml │ │ │ ├── A18_Pro_APL1V07.yaml │ │ │ ├── A10X_Fusion_APL1071.yaml │ │ │ ├── A12X_Bionic_APL1083.yaml │ │ │ ├── A12Z_Bionic_APL1083.yaml │ │ │ ├── A15_Bionic_APL1W07_.yaml │ │ │ └── A16_Bionic_APL1W10_.yaml │ │ └── default.yaml │ ├── __init__.py │ ├── predictor │ │ ├── loss │ │ │ ├── abstopk_ce.yaml │ │ │ └── absthreshold_ce.yaml │ │ └── base.yaml │ ├── cache_hooks │ │ ├── write_only.yaml │ │ ├── weighting_current_cache.yaml │ │ └── approximate_caching.yaml │ ├── masking_hooks │ │ ├── dip.yaml │ │ ├── glu_pruning.yaml │ │ ├── turbosparse_wrap.yaml │ │ ├── cats.yaml │ │ ├── dip_free_params.yaml │ │ ├── predictor.yaml │ │ ├── up_pruning.yaml │ │ └── gate_pruning.yaml │ ├── evaluation │ │ ├── mmlu.yaml │ │ ├── arc_easy.yaml │ │ ├── perplexity.yaml │ │ └── ppl_mmlu.yaml │ ├── dense_model │ │ ├── dummy.yaml │ │ ├── opt-350M.yaml │ │ ├── phi-3-mini.yaml │ │ ├── llama-v3-8B.yaml │ │ ├── phi-3-medium.yaml │ │ ├── mistral-v01-7B.yaml │ │ └── turbosparse-mistral.yaml │ ├── data_preprocessing │ │ └── default.yaml │ ├── experiment │ │ ├── evaluate_llm.yaml │ │ └── store_activations.yaml │ ├── data │ │ ├── dummy.yaml │ │ └── wikitext.yaml │ ├── adapter │ │ └── lora.yaml │ └── config.yaml ├── __init__.py └── run_experiment.py ├── tests ├── .DS_Store ├── __init__.py ├── scripts │ ├── __init__.py │ └── test_evaluate_llm.py └── contextual_sparsity │ ├── __init__.py │ ├── test_masking_hooks.py │ ├── test_cache_hooks.py │ ├── test_evaluation.py │ ├── test_sparse_linear.py │ └── hw_simulator │ ├── test_cache.py │ └── test_simulator.py ├── figures ├── methods.png └── results+dip_ca.png ├── contextual_sparsity ├── hw_simulator │ ├── __init__.py │ ├── simulator_hooks.py │ ├── constants.py │ └── cache.py ├── __init__.py ├── mask │ └── __init__.py ├── utils │ ├── __init__.py │ ├── submodule.py │ ├── logging.py │ ├── lr_scheduler.py │ ├── turbosparse.py │ ├── phi.py │ ├── sparsify.py │ ├── misc.py │ ├── tokenizers.py │ └── layer_names.py ├── nn │ ├── sparse │ │ ├── __init__.py │ │ └── linear.py │ ├── __init__.py │ ├── utils.py │ └── binarization.py ├── dense_models │ ├── __init__.py │ ├── dummy.py │ └── llm.py ├── scripts │ ├── __init__.py │ ├── llm_evaluation.py │ └── compute_activations.py ├── adapters │ ├── __init__.py │ ├── lora.py │ └── base.py ├── masking_hooks │ ├── trained │ │ ├── __init__.py │ │ ├── model.py │ │ ├── turbosparse_wrap.py │ │ ├── optimization.py │ │ ├── hook.py │ │ └── loss.py │ ├── __init__.py │ ├── glu_pruning.py │ ├── partial_glu_pruning.py │ └── dip.py ├── evaluation │ ├── hooks │ │ ├── __init__.py │ │ ├── perplexity.py │ │ ├── base.py │ │ └── memory.py │ ├── __init__.py │ ├── lm_eval.py │ ├── perplexity.py │ └── predictor.py └── data │ ├── __init__.py │ ├── data_processing.py │ ├── activations.py │ ├── dummy.py │ ├── hf.py │ └── slimpajama.py ├── .gitignore ├── requirements.txt └── LICENSE /scripts/config/hw_simulator/io_speed/paper_flash1.yaml: 
-------------------------------------------------------------------------------- 1 | dram: 60e9 # GB/s 2 | flash: 1e9 # GB/s 3 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/io_speed/paper_flash2.yaml: -------------------------------------------------------------------------------- 1 | dram: 60e9 # GB/s 2 | flash: 2e9 # GB/s 3 | -------------------------------------------------------------------------------- /tests/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qualcomm-AI-research/dynamic-sparsity/HEAD/tests/.DS_Store -------------------------------------------------------------------------------- /scripts/config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | -------------------------------------------------------------------------------- /figures/methods.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qualcomm-AI-research/dynamic-sparsity/HEAD/figures/methods.png -------------------------------------------------------------------------------- /scripts/config/hw_simulator/io_speed/paper_flash0.5.yaml: -------------------------------------------------------------------------------- 1 | 2 | dram: 60e9 # GB/s 3 | flash: 0.5e9 # GB/s 4 | -------------------------------------------------------------------------------- /contextual_sparsity/hw_simulator/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | -------------------------------------------------------------------------------- /figures/results+dip_ca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qualcomm-AI-research/dynamic-sparsity/HEAD/figures/results+dip_ca.png -------------------------------------------------------------------------------- /contextual_sparsity/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | """The contextual_sparsity module.""" 5 | -------------------------------------------------------------------------------- /contextual_sparsity/mask/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from .hooks import MaskingHook 5 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | """Contains modules that can be executed from the CLI.""" 5 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | """Submodule containing unit and integration tests.""" 5 | -------------------------------------------------------------------------------- /contextual_sparsity/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from .sparsify import sparsify_linear 5 | -------------------------------------------------------------------------------- /scripts/config/predictor/loss/abstopk_ce.yaml: -------------------------------------------------------------------------------- 1 | _target_: contextual_sparsity.masking_hooks.trained.loss.build_abstopk_cross_entropy_loss 2 | 3 | k: null 4 | keep: 0.1 5 | -------------------------------------------------------------------------------- /tests/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | """Unit and integration tests for the code in `scripts`.""" 5 | -------------------------------------------------------------------------------- /contextual_sparsity/nn/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from .linear import SimulatedSparseLinear, SparseLinear 5 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A18_APL1V08.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 8.0e9 5 | io_speed: 6 | dram: 60.0e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A4_APL0398.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 0.512e9 5 | io_speed: 6 | dram: 3.2e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A5X_APL5498.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 1.0e9 5 | io_speed: 6 | dram: 12.8e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A5_APL7498.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 0.512e9 5 | io_speed: 6 | dram: 6.4e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A6X_APL5598.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 1.0e9 5 | io_speed: 6 | dram: 17.0e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A6_APL0598.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 1.0e9 5 | io_speed: 6 | dram: 8.5e9 # GB/s 7 | flash: 1e9 # GB/s 8 | 
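The io_speed presets (paper_flash0.5/1/2) and the processor presets above each pin down only the DRAM capacity and the DRAM/flash bandwidths; the rest of the hardware model lives in hw_simulator/default.yaml further below. As a rough illustration of what these bandwidths imply for per-token weight loading, here is a minimal back-of-the-envelope sketch in Python (not the repository's HardwareSimulator; the model size, sparsity level, and all names below are illustrative assumptions):

def per_token_load_time(bytes_to_load: float, bandwidth_bytes_per_s: float) -> float:
    # Seconds needed to move bytes_to_load over a link with the given bandwidth.
    return bytes_to_load / bandwidth_bytes_per_s

# Hypothetical 7B-parameter model, ~70% of weights in MLPs at 16 bit,
# with only 10% of MLP neurons touched per generated token under contextual sparsity.
mlp_bytes = 7e9 * 0.7 * 2          # ~9.8 GB of MLP weights
touched_bytes = 0.1 * mlp_bytes    # ~0.98 GB read per generated token
print(per_token_load_time(touched_bytes, 1e9))   # ~0.98 s/token from flash (paper_flash1)
print(per_token_load_time(touched_bytes, 60e9))  # ~0.016 s/token if served from DRAM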
-------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A7_APL5698.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 1.0e9 5 | io_speed: 6 | dram: 12.8e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A8X_APL1021.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 2.0e9 5 | io_speed: 6 | dram: 25.6e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A8_APL1011.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 2.0e9 5 | io_speed: 6 | dram: 12.8e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A9X_APL1021.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 4.0e9 5 | io_speed: 6 | dram: 51.2e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A9_APL1022.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 2.0e9 5 | io_speed: 6 | dram: 25.6e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A10_Fusion_APL1W24.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 3.0e9 5 | io_speed: 6 | dram: 25.6e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A11_Bionic_APL1W72.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 3.0e9 5 | io_speed: 6 | dram: 34.1e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A12_Bionic_APL1W81.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 4.0e9 5 | io_speed: 6 | dram: 34.1e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A13_Bionic_APL1W85.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 4.0e9 5 | io_speed: 6 | dram: 34.1e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A14_Bionic_APL1W01.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 4.0e9 5 | io_speed: 6 | dram: 34.1e9 # GB/s 7 | flash: 1e9 # GB/s 8 | 
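Every processor preset carries the "# @package hw_simulator" header, which re-homes its contents under the hw_simulator config node, so selecting a processor overrides the default DRAM capacity and I/O speeds. Below is a minimal sketch of that merge using plain OmegaConf, as an approximation of Hydra's composition rather than the project's actual defaults handling (base values taken from hw_simulator/default.yaml shown later, preset values from A14_Bionic_APL1W01.yaml above):

from omegaconf import OmegaConf

base = OmegaConf.create(
    {"hw_simulator": {"dram": {"capacity": 4e9}, "io_speed": {"dram": 60.0e9, "flash": 1e9}}}
)
preset = OmegaConf.create(  # values copied from A14_Bionic_APL1W01.yaml
    {"hw_simulator": {"dram": {"capacity": 4.0e9}, "io_speed": {"dram": 34.1e9, "flash": 1e9}}}
)
merged = OmegaConf.merge(base, preset)
print(merged.hw_simulator.io_speed.dram)  # 34100000000.0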
-------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A17_Pro_APL1V02.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 8.0e9 5 | io_speed: 6 | dram: 51.2e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A18_Pro_APL1V07.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 8.0e9 5 | io_speed: 6 | dram: 60.0e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/predictor/loss/absthreshold_ce.yaml: -------------------------------------------------------------------------------- 1 | _target_: contextual_sparsity.masking_hooks.trained.loss.build_absthreshold_cross_entropy_loss 2 | 3 | threshold: null 4 | keep: null 5 | -------------------------------------------------------------------------------- /contextual_sparsity/dense_models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from .dummy import DummyModel 5 | from .llm import load_hf_model 6 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A10X_Fusion_APL1071.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 4.0e9 5 | io_speed: 6 | dram: 51.2e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A12X_Bionic_APL1083.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 6.0e9 5 | io_speed: 6 | dram: 68.2e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A12Z_Bionic_APL1083.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 16.0e9 5 | io_speed: 6 | dram: 68.2e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A15_Bionic_APL1W07_.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 6.0e9 5 | io_speed: 6 | dram: 34.1e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/processor/A16_Bionic_APL1W10_.yaml: -------------------------------------------------------------------------------- 1 | # @package hw_simulator 2 | 3 | dram: 4 | capacity: 6.0e9 5 | io_speed: 6 | dram: 51.2e9 # GB/s 7 | flash: 1e9 # GB/s 8 | -------------------------------------------------------------------------------- /tests/contextual_sparsity/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | """Unit and integration tests for the code in `contextual_sparsity`.""" 5 | -------------------------------------------------------------------------------- /scripts/config/cache_hooks/write_only.yaml: -------------------------------------------------------------------------------- 1 | _target_: contextual_sparsity.mask.cache_hooks.link_masking_hooks_with_cache 2 | method: null # only write to cache, masking is not cache-aware 3 | kwargs: null 4 | -------------------------------------------------------------------------------- /scripts/config/masking_hooks/dip.yaml: -------------------------------------------------------------------------------- 1 | _target_: contextual_sparsity.masking_hooks.build_optimized_dip_masking_hooks 2 | model_id: ${model_id} 3 | data_id: ${data_id} 4 | layers_to_sparsify: all 5 | keep: ??? -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Pycharm 2 | .idea 3 | 4 | # VSCode 5 | .vscode 6 | 7 | # pytest 8 | .coverage 9 | .pytest 10 | .pytest_cache 11 | 12 | # Python 13 | *__pycache__* 14 | *.py[cod] 15 | *.cpython-36.pyc 16 | *.egg-info -------------------------------------------------------------------------------- /contextual_sparsity/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from .compute_activations import store_activations_main 5 | from .llm_evaluation import evaluate_llm_main 6 | -------------------------------------------------------------------------------- /contextual_sparsity/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from .base import Adapter, add_adapters 5 | from .lora import LoRA 6 | from .training import load_adapters, train_adapters 7 | -------------------------------------------------------------------------------- /contextual_sparsity/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | from .binarization import RandomKMask, RandomMask, StaticMask, ThresholdMask, TopKMask 5 | from .sparse import * 6 | from .utils import * 7 | -------------------------------------------------------------------------------- /scripts/config/evaluation/mmlu.yaml: -------------------------------------------------------------------------------- 1 | mmlu: 2 | _target_: contextual_sparsity.evaluation.run_lm_eval 3 | tasks: 4 | - mmlu 5 | store_full_output: false 6 | num_fewshot: 5 7 | batch_size: null 8 | limit: null 9 | device: ${hardware.device} 10 | -------------------------------------------------------------------------------- /scripts/config/evaluation/arc_easy.yaml: -------------------------------------------------------------------------------- 1 | arc_easy: 2 | _target_: contextual_sparsity.evaluation.run_lm_eval 3 | tasks: 4 | - arc_easy 5 | store_full_output: false 6 | num_fewshot: 5 7 | batch_size: null 8 | limit: null 9 | device: ${hardware.device} 10 | 11 | -------------------------------------------------------------------------------- /contextual_sparsity/masking_hooks/trained/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from .hook import build_predictor_masking_hooks 5 | from .model import SimplePredictor 6 | from .turbosparse_wrap import build_original_turbosparse_hooks 7 | -------------------------------------------------------------------------------- /contextual_sparsity/evaluation/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from .base import CollectHooksOutput, EvaluationHook 5 | from .memory import MEMORY, MLP_DENSITY, MLP_MEMORY, WEIGHT_DENSITY, Memory 6 | from .perplexity import CROSS_ENTROPY, PERPLEXITY, Perplexity 7 | -------------------------------------------------------------------------------- /scripts/config/masking_hooks/glu_pruning.yaml: -------------------------------------------------------------------------------- 1 | _target_: contextual_sparsity.masking_hooks.build_glu_pruning_masking_hooks 2 | model_id: ${model_id} 3 | data_id: ${data_id} 4 | preprocess_batch: ${move_dict_to_device} 5 | layers_to_sparsify: all 6 | binarization_type: topk 7 | threshold: null 8 | k: null 9 | keep: null 10 | 11 | -------------------------------------------------------------------------------- /contextual_sparsity/nn/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class Abs(nn.Module): 9 | """ 10 | Absolute value as an nn.Module layer. 11 | """ 12 | 13 | def forward(self, x): 14 | return torch.abs(x) 15 | -------------------------------------------------------------------------------- /contextual_sparsity/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | from .data_processing import ( 5 | move_dict_to_device, 6 | separate_prompt, 7 | sequential_preprocessing, 8 | ) 9 | from .hf import get_dataloader 10 | from .slimpajama import get_slimpajama_dataloader 11 | -------------------------------------------------------------------------------- /contextual_sparsity/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from .hooks import ( 5 | CROSS_ENTROPY, 6 | MEMORY, 7 | MLP_DENSITY, 8 | MLP_MEMORY, 9 | PERPLEXITY, 10 | Memory, 11 | Perplexity, 12 | ) 13 | from .lm_eval import run_lm_eval 14 | from .perplexity import evaluate_sparse_perplexity 15 | -------------------------------------------------------------------------------- /scripts/config/masking_hooks/turbosparse_wrap.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | masking_hooks: 3 | _target_: contextual_sparsity.masking_hooks.build_original_turbosparse_hooks 4 | preprocess_batch: ${move_dict_to_device} 5 | model_id: ${model_id} 6 | data_id: ${data_id} 7 | layers_to_sparsify: all 8 | k: null 9 | keep: null 10 | threshold: null 11 | 12 | dense_model: 13 | remove_predictors: false -------------------------------------------------------------------------------- /scripts/config/dense_model/dummy.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_id: dummy 3 | 4 | precision: 5 | embedding: 8 6 | lm_head: 8 7 | attention: 16 8 | mlp: 16 9 | activations: 16 10 | kv_cache: 8 11 | predictors: 16 12 | 13 | dense_model: 14 | _target_: contextual_sparsity.dense_models.dummy.DummyModel 15 | model_id: ${model_id} 16 | device: ${hardware.device} 17 | 18 | tokenizer: null 19 | -------------------------------------------------------------------------------- /contextual_sparsity/masking_hooks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | from .dip import build_dip_masking_hooks, build_optimized_dip_masking_hooks 5 | from .glu_pruning import build_glu_pruning_masking_hooks 6 | from .partial_glu_pruning import build_partial_glu_pruning_masking_hooks 7 | from .trained import build_original_turbosparse_hooks, build_predictor_masking_hooks 8 | -------------------------------------------------------------------------------- /scripts/config/data_preprocessing/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | move_dict_to_device: 4 | _target_: contextual_sparsity.data.move_dict_to_device 5 | _partial_: true 6 | device: ${hardware.device} 7 | 8 | separate_prompt: 9 | _target_: contextual_sparsity.data.separate_prompt 10 | _partial_: true 11 | sequence_length: ${data.test.sequence_length} 12 | prompt_length: ${data.test.prompt_length} 13 | -------------------------------------------------------------------------------- /scripts/config/masking_hooks/cats.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | precision: 3 | predictors: ${precision.mlp} 4 | 5 | masking_hooks: 6 | _target_: contextual_sparsity.masking_hooks.build_partial_glu_pruning_masking_hooks 7 | model_id: ${model_id} 8 | data_id: ${data_id} 9 | preprocess_batch: ${move_dict_to_device} 10 | layers_to_sparsify: all 11 | binarization_type: threshold 12 | predictor_type: gate 13 | keep: ??? -------------------------------------------------------------------------------- /scripts/config/masking_hooks/dip_free_params.yaml: -------------------------------------------------------------------------------- 1 | _target_: contextual_sparsity.masking_hooks.build_dip_masking_hooks 2 | model_id: ${model_id} 3 | data_id: ${data_id} 4 | preprocess_batch: ${move_dict_to_device} 5 | layers_to_sparsify: all 6 | up_k: null 7 | down_k: null 8 | gate_k: null 9 | up_keep: null 10 | down_keep: null 11 | gate_keep: null 12 | up_threshold: null 13 | down_threshold: null 14 | gate_threshold: null 15 | -------------------------------------------------------------------------------- /scripts/config/evaluation/perplexity.yaml: -------------------------------------------------------------------------------- 1 | perplexity: 2 | _target_: contextual_sparsity.evaluation.evaluate_sparse_perplexity 3 | preprocess_batch: ${move_dict_to_device} 4 | test_data: ${data.test} 5 | evaluation_hooks: 6 | - _target_: contextual_sparsity.evaluation.hooks.Perplexity 7 | - _target_: contextual_sparsity.evaluation.hooks.Memory 8 | model_id: ${model_id} 9 | precision: ${precision} 10 | sequence_length: 2048 11 | -------------------------------------------------------------------------------- /scripts/config/masking_hooks/predictor.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /predictor: base 4 | 5 | predictor: 6 | layer_to_mask: null 7 | 8 | masking_hooks: 9 | _target_: contextual_sparsity.masking_hooks.build_predictor_masking_hooks 10 | model_id: ${model_id} 11 | layers_to_sparsify: all 12 | force_retrain: false 13 | predictor_cache_dir: ${hardware.paths.cache}/predictors/${model_id} 14 | k: null 15 | keep: null 16 | 17 | 18 | -------------------------------------------------------------------------------- /scripts/config/masking_hooks/up_pruning.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | precision: 3 | predictors: 
${precision.mlp} 4 | 5 | masking_hooks: 6 | _target_: contextual_sparsity.masking_hooks.build_partial_glu_pruning_masking_hooks 7 | model_id: ${model_id} 8 | data_id: ${data_id} 9 | preprocess_batch: ${move_dict_to_device} 10 | layers_to_sparsify: all 11 | binarization_type: topk 12 | predictor_type: up 13 | threshold: null 14 | k: null 15 | keep: null -------------------------------------------------------------------------------- /scripts/config/masking_hooks/gate_pruning.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | precision: 3 | predictors: ${precision.mlp} 4 | 5 | masking_hooks: 6 | _target_: contextual_sparsity.masking_hooks.build_partial_glu_pruning_masking_hooks 7 | model_id: ${model_id} 8 | data_id: ${data_id} 9 | preprocess_batch: ${move_dict_to_device} 10 | layers_to_sparsify: all 11 | binarization_type: topk 12 | predictor_type: gate 13 | threshold: null 14 | k: null 15 | keep: null -------------------------------------------------------------------------------- /scripts/config/cache_hooks/weighting_current_cache.yaml: -------------------------------------------------------------------------------- 1 | _target_: contextual_sparsity.mask.cache_hooks.link_masking_hooks_with_cache 2 | method: weighting_current_cache # re-weight mask based on current cache content 3 | kwargs: 4 | gamma: 0 # [0, 1] 5 | fixed_top_n: 0 # whether to avoid deweighting the top_n neurons 6 | use_alpha_blend: False # whether to use alpha for alpha-blending instead of deweighting 7 | warm_up: False # whether to warm up the value of alpha from zero in the first tokens 8 | -------------------------------------------------------------------------------- /scripts/config/experiment/evaluate_llm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /dense_model: phi-3-mini 4 | - /masking_hooks: dip 5 | - /data: wikitext 6 | - /evaluation: perplexity 7 | - /cache_hooks: write_only 8 | 9 | experiment: 10 | # Function that determines the code to run 11 | script: 12 | _target_: contextual_sparsity.scripts.evaluate_llm_main 13 | # Run Identifier (by default datetime) 14 | run_id: ${now:%Y-%m-%d_%H-%M-%S} 15 | save_dir: ${hardware.paths.log}/${model_id} 16 | 17 | hw_simulator: null 18 | -------------------------------------------------------------------------------- /scripts/config/data/dummy.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data_id: dummy 3 | 4 | data: 5 | train: 6 | _target_: contextual_sparsity.data.dummy.get_dummy_dataloader 7 | 8 | valid: 9 | _target_: contextual_sparsity.data.dummy.get_dummy_dataloader 10 | 11 | test: 12 | _target_: contextual_sparsity.data.dummy.get_dummy_dataloader 13 | 14 | calibration: 15 | _target_: contextual_sparsity.data.dummy.get_dummy_dataloader 16 | 17 | tiny: 18 | _target_: contextual_sparsity.data.dummy.get_dummy_dataloader 19 | 20 | preprocess_batch: null -------------------------------------------------------------------------------- /scripts/config/cache_hooks/approximate_caching.yaml: -------------------------------------------------------------------------------- 1 | _target_: contextual_sparsity.mask.cache_hooks.link_masking_hooks_with_cache 2 | method: approximate_caching # allows fallback to top-M neurons if they are in cache. 
3 | kwargs: 4 | fixed_top_n: 0 # whether to always keep the top_n neurons (n < k) 5 | top_k: ${masking_hooks.k} # this is the number of neurons to be chosen for the mask. Only compatible with TopK masking approaches. 6 | top_m: 0 # scores in top_m are allowed to be picked if they are in cache. (m > k) 7 | gamma: 0 # if gamma > 0, use it to reweight scores based on cache slot_count statistic (e.g.: normalized LFU) 8 | -------------------------------------------------------------------------------- /scripts/config/dense_model/opt-350M.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_id: opt-350M 3 | 4 | precision: 5 | embedding: 8 6 | lm_head: 8 7 | attention: 16 8 | mlp: 16 9 | activations: 16 10 | kv_cache: 8 11 | predictors: 16 12 | 13 | dense_model: 14 | _target_: contextual_sparsity.dense_models.load_hf_model 15 | pretrained_model_path: ${hardware.paths.models.opt-350M} 16 | model_id: ${model_id} 17 | device: ${hardware.device} 18 | 19 | tokenizer: 20 | _target_: contextual_sparsity.utils.tokenizers.load_tokenizer 21 | pretrained_model_path: ${hardware.paths.models.opt-350M} 22 | use_fast_tokenizer: false 23 | -------------------------------------------------------------------------------- /scripts/config/dense_model/phi-3-mini.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_id: phi-3-mini 3 | 4 | precision: 5 | embedding: 8 6 | lm_head: 8 7 | attention: 16 8 | mlp: 16 9 | activations: 16 10 | kv_cache: 8 11 | predictors: 16 12 | 13 | dense_model: 14 | _target_: contextual_sparsity.dense_models.load_hf_model 15 | pretrained_model_path: ${hardware.paths.models.phi-3-mini} 16 | model_id: ${model_id} 17 | device: ${hardware.device} 18 | dtype: float16 19 | 20 | tokenizer: 21 | _target_: contextual_sparsity.utils.tokenizers.load_tokenizer 22 | pretrained_model_path: ${hardware.paths.models.phi-3-mini} 23 | -------------------------------------------------------------------------------- /scripts/config/dense_model/llama-v3-8B.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_id: llama-v3-8B 3 | 4 | precision: 5 | embedding: 8 6 | lm_head: 8 7 | attention: 16 8 | mlp: 16 9 | activations: 16 10 | kv_cache: 8 11 | predictors: 16 12 | 13 | dense_model: 14 | _target_: contextual_sparsity.dense_models.load_hf_model 15 | pretrained_model_path: ${hardware.paths.models.llama-v3-8B} 16 | model_id: ${model_id} 17 | device: ${hardware.device} 18 | dtype: float16 19 | 20 | tokenizer: 21 | _target_: contextual_sparsity.utils.tokenizers.load_tokenizer 22 | pretrained_model_path: ${hardware.paths.models.llama-v3-8B} 23 | -------------------------------------------------------------------------------- /scripts/config/dense_model/phi-3-medium.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_id: phi-3-medium 3 | 4 | precision: 5 | embedding: 8 6 | lm_head: 8 7 | attention: 16 8 | mlp: 16 9 | activations: 16 10 | kv_cache: 8 11 | predictors: 16 12 | 13 | dense_model: 14 | _target_: contextual_sparsity.dense_models.load_hf_model 15 | pretrained_model_path: ${hardware.paths.models.phi-3-medium} 16 | model_id: ${model_id} 17 | device: ${hardware.device} 18 | dtype: float16 19 | 20 | tokenizer: 21 | _target_: contextual_sparsity.utils.tokenizers.load_tokenizer 22 | pretrained_model_path: ${hardware.paths.models.phi-3-medium} 23 | 
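The dense_model configs above all follow the same Hydra pattern: a _target_ callable plus its keyword arguments, with ${...} interpolations resolved against the composed config. A minimal sketch of how such a node gets instantiated, using a stand-in _target_ so it runs without model weights or the hardware.paths tree (the real entries point at contextual_sparsity.dense_models.load_hf_model):

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "model_id": "phi-3-mini",
        "dense_model": {"_target_": "torch.nn.Linear", "in_features": 4, "out_features": 4},
    }
)
model = instantiate(cfg.dense_model)  # imports and calls the object named by _target_
print(type(model))  # <class 'torch.nn.Linear'>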
-------------------------------------------------------------------------------- /scripts/config/dense_model/mistral-v01-7B.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model_id: mistral-v01-7B 3 | 4 | precision: 5 | embedding: 8 6 | lm_head: 8 7 | attention: 16 8 | mlp: 16 9 | activations: 16 10 | kv_cache: 8 11 | predictors: 16 12 | 13 | dense_model: 14 | _target_: contextual_sparsity.dense_models.load_hf_model 15 | pretrained_model_path: ${hardware.paths.models.mistral-v01-7B} 16 | model_id: ${model_id} 17 | device: ${hardware.device} 18 | dtype: float16 19 | 20 | tokenizer: 21 | _target_: contextual_sparsity.utils.tokenizers.load_tokenizer 22 | pretrained_model_path: ${hardware.paths.models.mistral-v01-7B} 23 | -------------------------------------------------------------------------------- /scripts/config/evaluation/ppl_mmlu.yaml: -------------------------------------------------------------------------------- 1 | mmlu: 2 | _target_: contextual_sparsity.evaluation.run_lm_eval 3 | tasks: 4 | - mmlu 5 | store_full_output: false 6 | num_fewshot: 5 7 | batch_size: null 8 | limit: null 9 | device: ${hardware.device} 10 | perplexity: 11 | _target_: contextual_sparsity.evaluation.evaluate_sparse_perplexity 12 | preprocess_batch: ${move_dict_to_device} 13 | test_data: ${data.test} 14 | evaluation_hooks: 15 | - _target_: contextual_sparsity.evaluation.hooks.Perplexity 16 | - _target_: contextual_sparsity.evaluation.hooks.Memory 17 | model_id: ${model_id} 18 | precision: ${precision} 19 | sequence_length: 2048 20 | -------------------------------------------------------------------------------- /scripts/config/experiment/store_activations.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /data: wikitext 4 | - /dense_model: ??? 5 | - _self_ 6 | 7 | experiment: 8 | # Function that determines the code to run 9 | script: 10 | _target_: contextual_sparsity.scripts.store_activations_main 11 | # Run Identifier 12 | run_id: ${model_id}/${data_id}/${activations.split}/${activations.dtype} 13 | save_dir: ${hardware.paths.cache}/activations 14 | 15 | device: ${hardware.device} 16 | 17 | activations: 18 | split: test 19 | dtype: float16 20 | # E.g. [net.decoder.layers.2.fc1.input,net.decoder.layers.4.activation_fn.output]. Default none = compute all 21 | activation_ids: ??? 
22 | 23 | preprocess_batch: ${move_dict_to_device} -------------------------------------------------------------------------------- /scripts/config/dense_model/turbosparse-mistral.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # HF URL: https://huggingface.co/PowerInfer/TurboSparse-Mistral-Instruct 3 | model_id: turbosparse-mistral 4 | 5 | precision: 6 | embedding: 8 7 | lm_head: 8 8 | attention: 16 9 | mlp: 16 10 | activations: 16 11 | kv_cache: 8 12 | predictors: 16 13 | 14 | dense_model: 15 | _target_: contextual_sparsity.dense_models.load_hf_model 16 | pretrained_model_path: ${hardware.paths.models.turbosparse-mistral} 17 | model_id: ${model_id} 18 | device: ${hardware.device} 19 | dtype: float16 20 | local_files_only: false 21 | 22 | tokenizer: 23 | _target_: contextual_sparsity.utils.tokenizers.load_tokenizer 24 | pretrained_model_path: ${hardware.paths.models.turbosparse-mistral} 25 | -------------------------------------------------------------------------------- /contextual_sparsity/utils/submodule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from torch import nn 5 | 6 | 7 | def set_submodule(module: nn.Module, submodule_name: str, submodule: nn.Module) -> None: 8 | """ 9 | Set a submodule of a given nn.Module with the specified key (module[submodule_name] = submodule). 10 | """ 11 | subkeys = submodule_name.split(".") 12 | if len(subkeys) == 1: 13 | setattr(module, subkeys[0], submodule) 14 | else: 15 | if subkeys[0].isnumeric(): 16 | set_submodule(module[int(subkeys[0])], ".".join(subkeys[1:]), submodule) 17 | else: 18 | if not hasattr(module, subkeys[0]): 19 | set_submodule(module, subkeys[0], nn.Module()) 20 | set_submodule(module.get_submodule(subkeys[0]), ".".join(subkeys[1:]), submodule) 21 | -------------------------------------------------------------------------------- /contextual_sparsity/evaluation/hooks/perplexity.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from typing import Any, Dict 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from contextual_sparsity.evaluation.hooks.base import EvaluationHook 10 | 11 | CROSS_ENTROPY = "cross_entropy" 12 | PERPLEXITY = "perplexity" 13 | 14 | 15 | # Perplexity and cross-entropy evaluation 16 | class Perplexity(EvaluationHook): 17 | """ 18 | Computes the perplexity metric. 19 | """ 20 | 21 | metric_dims = {CROSS_ENTROPY: 1, PERPLEXITY: 1} 22 | 23 | def collect_results(self, module, input, kwargs, output): 24 | valid_tokens = torch.greater_equal(kwargs["labels"].view(-1), 0).int() 25 | ntok = torch.sum(valid_tokens) 26 | return {CROSS_ENTROPY: output.loss.unsqueeze(0).repeat(ntok).unsqueeze(-1)} 27 | 28 | def attach_to(self, model: nn.Module): 29 | self._attach_to(model) 30 | 31 | def finalize(self, stats: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: 32 | stats["."][PERPLEXITY] = {"mean": torch.exp(stats["."][CROSS_ENTROPY]["mean"])} 33 | return stats 34 | -------------------------------------------------------------------------------- /contextual_sparsity/utils/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | import os 5 | import pickle 6 | from typing import Any, Dict, List 7 | 8 | import pandas as pd 9 | 10 | 11 | def save_pickle(x: Any, filepath: str, expect_dir=False, override=False): 12 | dirpath = os.path.dirname(filepath) 13 | if expect_dir: 14 | assert os.path.isdir(dirpath), dirpath 15 | else: 16 | os.makedirs(dirpath, exist_ok=True) 17 | assert override or not os.path.exists(filepath) 18 | 19 | with open(filepath, "wb") as f: 20 | pickle.dump(x, f) 21 | 22 | 23 | def load_pickle(filepath: str): 24 | assert os.path.exists(filepath), filepath 25 | with open(filepath, "rb") as f: 26 | x = pickle.load(f) 27 | return x 28 | 29 | 30 | class CSVLogger: 31 | def __init__(self, csv_filepath: str): 32 | self.csv_filepath = csv_filepath 33 | self._log: List[Dict[str, Any]] = [] 34 | 35 | def log(self, **values): 36 | self._log.append(values) 37 | pd.DataFrame(self._log).to_csv(self.csv_filepath, index=False) 38 | 39 | def reset(self): 40 | self._log = [] 41 | -------------------------------------------------------------------------------- /scripts/run_experiment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import logging 5 | import os 6 | import random 7 | from typing import Any 8 | 9 | import hydra 10 | import numpy as np 11 | import torch 12 | from hydra.utils import instantiate 13 | from omegaconf import DictConfig, OmegaConf 14 | 15 | log = logging.getLogger(__name__) 16 | 17 | 18 | @hydra.main(config_path="pkg://scripts/config", config_name="config.yaml", version_base="1.3") 19 | def parse(conf: DictConfig) -> Any: 20 | # Set environment variable if specified. 21 | # This is broken in hydra because of issues with variable interpolations. 
22 | # See: https://github.com/facebookresearch/hydra/issues/2800 23 | if "env" in conf: 24 | for name, value in conf.env.items(): 25 | os.environ[name] = value 26 | 27 | # Set up the random seeds 28 | torch.manual_seed(conf.experiment.seed) 29 | random.seed(conf.experiment.seed) 30 | np.random.seed(conf.experiment.seed) 31 | 32 | log.info(f"Running experiment {conf.experiment.run_id}") 33 | 34 | run = instantiate(conf.experiment.script, _partial_=True) 35 | return_values = run(conf) 36 | 37 | return return_values 38 | 39 | 40 | if __name__ == "__main__": 41 | OmegaConf.register_new_resolver("eval", eval) 42 | parse() 43 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Default dependencies 2 | # (specifying the version with '~=' allows newer "patch" versions, without upgrading the "major" or "minor" version) 3 | black~=24.4.1 4 | isort~=5.13.2 5 | mypy==1.10.0 6 | pydocstyle[toml]==6.3.0 7 | pylint==3.1.0 8 | pytest~=8.1.1 9 | pytest-cov~=5.0.0 10 | setuptools 11 | setuptools_scm[toml]>=6.2 12 | twine~=5.0.0 13 | wheel 14 | xenon==0.9.1 15 | 16 | # Machine Learning 17 | datasets~=2.19.1 18 | lm-eval~=0.4.3 19 | numpy~=1.24.4 # 1.26.4 is not found (or compatible with other libraries) 20 | pandas~=2.0.3 21 | ptflops~=0.7.3 22 | scikit-learn~=1.3.2 23 | torch~=2.3.0 # Note: 'flash_attn' whl relies on having torch version 2.3.* 24 | torchmetrics~=1.4.0 25 | transformers[sentencepiece]~=4.43.0 # sentencepiece is an optional dependency for some tokenizers in transformers 26 | 27 | # Optimization 28 | optuna~=3.6.1 29 | optuna-dashboard~=0.15.1 30 | 31 | # CLI 32 | fire~=0.6.0 33 | hydra-core~=1.3.2 34 | 35 | # Plotting 36 | matplotlib~=3.7.5 37 | protobuf~=4.25 # needed for tensorboard compatibility 38 | seaborn~=0.13.2 39 | tensorboard~=2.14.0 40 | 41 | # Jupyter 42 | ipywidgets~=8.1.2 43 | jupyter~=1.0.0 # includes notebook, IPython 44 | IPython~=8.12.3 45 | jupyterlab~=4.2.1 46 | notebook~=7.2.0 47 | 48 | # Utils 49 | h5py~=3.11.0 50 | mock~=5.1.0 51 | tqdm~=4.66.4 52 | wget~=3.2 53 | gitpython~=3.1.43 54 | -------------------------------------------------------------------------------- /contextual_sparsity/hw_simulator/simulator_hooks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from typing import Callable 5 | 6 | import torch.nn as nn 7 | 8 | 9 | class SimulatorResetHook: 10 | def __init__(self, hw_simulator_reset_fn: Callable, model: nn.Module, active: bool = True): 11 | super().__init__() 12 | self._input_handle = model.register_forward_hook(self._hook) 13 | self.hw_simulator_reset_fn = hw_simulator_reset_fn 14 | self.active = active 15 | 16 | def _hook(self, *args, **kwargs): 17 | if self.active: 18 | self.hw_simulator_reset_fn() 19 | 20 | def set_inactive(self): 21 | """ 22 | The hook can be disabled when the token generation for the whole sequence does not happen at once with a 23 | single model forward. This happens either with sequential generation (auto-regressive, without teacher forcing) 24 | or when part of the sequence (the prompt) is encoded first with a dense model. 
25 | """ 26 | self.active = False 27 | 28 | def set_active(self): 29 | self.active = True 30 | 31 | def is_attached(self) -> bool: 32 | return self._input_handle is not None 33 | 34 | def remove(self): 35 | if self.is_attached(): 36 | self._input_handle.remove() 37 | self._input_handle = None 38 | -------------------------------------------------------------------------------- /scripts/config/hw_simulator/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - io_speed: paper_flash1 3 | 4 | 5 | _target_: contextual_sparsity.hw_simulator.simulator.HardwareSimulator 6 | model_id: ${model_id} 7 | sequence_length: ${oc.select:data.test.sequence_length,null} 8 | prompt_length: ${oc.select:data.test.prompt_length,null} 9 | device: ${hardware.device} 10 | precision: ${precision} 11 | 12 | dram: 13 | capacity: 4e9 14 | layers_static: [lm_head,attention,kv_cache,predictors] # layers always in DRAM 15 | layers_dynamic: [mlp] # layers dynamically loaded to DRAM depending on masking/predictor 16 | layers_streamed_at_prompt_encoding: [mlp,embedding] # layers streamed once (e.g.: because not fitting in DRAM and not needed during token generation) 17 | layers_streamed_at_token_generation: [] # layers streamed for each token (e.g.: because not fitting in DRAM) 18 | concurrent_dram_flash_io: True # enable if io for a single layer can be parallelized between Flash and DRAM. 19 | cache_strategy: lfu 20 | simulate_glu_pruning: False # Used to simulate that Up and Gate layers are dense, and only Down layers are sparsified (with TopK or other masking methods) 21 | allow_static_layers_streaming: True # For cases where static layers already do not fit in DRAM (e.g.: ablation at low DRAM capacities), this option allows loading the static layers from Flash to the processing unit each time, without passing through the DRAM (if it is already full) 22 | verbose: False 23 | -------------------------------------------------------------------------------- /contextual_sparsity/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from torch.optim import Optimizer 5 | from torch.optim.lr_scheduler import LinearLR, SequentialLR 6 | 7 | 8 | class LinearWarmup(SequentialLR): 9 | def __init__( 10 | self, 11 | optimizer: Optimizer, 12 | total_iterations: int, 13 | min_factor: float = 1e-3, 14 | max_factor: float = 1.0, 15 | warmup_percentage: float = 0.1, 16 | ): 17 | """ 18 | Sequential learning rate scheduler with an initial warm-up and a cool-down phase. 19 | """ 20 | warmup_iterations = int(total_iterations * warmup_percentage) 21 | decay_iterations = total_iterations - warmup_iterations 22 | super(LinearWarmup, self).__init__( 23 | optimizer=optimizer, 24 | schedulers=[ 25 | LinearLR( 26 | optimizer=optimizer, 27 | start_factor=min_factor, 28 | end_factor=max_factor, 29 | total_iters=warmup_iterations, 30 | ), 31 | LinearLR( 32 | optimizer=optimizer, 33 | start_factor=max_factor, 34 | end_factor=min_factor, 35 | total_iters=decay_iterations, 36 | ), 37 | ], 38 | milestones=[warmup_iterations], 39 | ) 40 | -------------------------------------------------------------------------------- /contextual_sparsity/data/data_processing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 5 | 6 | import torch 7 | 8 | 9 | def separate_prompt( 10 | batch: Dict[str, torch.Tensor], prompt_length: int, sequence_length: int 11 | ) -> Tuple[Optional[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]: 12 | """ 13 | Separate batch in prompt and rest (generated part of the sentence). 14 | If prompt_length == 0, return None for the prompt. 15 | """ 16 | assert 0 <= prompt_length < sequence_length, (prompt_length, sequence_length) 17 | 18 | if prompt_length == 0: 19 | assert batch["input_ids"].shape[1] == sequence_length 20 | return None, batch 21 | 22 | prompt, rest = dict(), dict() 23 | for k, v in batch.items(): 24 | if k == "attention_mask": 25 | continue 26 | assert v.shape[1] == sequence_length 27 | prompt[k] = v[:, :prompt_length] 28 | rest[k] = v[:, prompt_length:] 29 | return prompt, rest 30 | 31 | 32 | def move_dict_to_device( 33 | batch: Dict[str, torch.Tensor], device: Union[str, torch.device] 34 | ) -> Dict[str, torch.Tensor]: 35 | return {k: v.to(device) for k, v in batch.items()} 36 | 37 | 38 | def sequential_preprocessing(batch: Any, functions: List[Callable]) -> Any: 39 | # Applies a sequence of preprocessing functions to the input batch 40 | for fn in functions: 41 | batch = fn(batch) 42 | return batch 43 | -------------------------------------------------------------------------------- /scripts/config/adapter/lora.yaml: -------------------------------------------------------------------------------- 1 | apply_to: 'all' 2 | 3 | model: 4 | _target_: contextual_sparsity.adapters.LoRA.from_module 5 | _partial_: true 6 | rank: 32 7 | dropout_rate: 0.1 8 | 9 | training: 10 | cache_dir: ${hardware.paths.cache}/adapters/${model_id} 11 | masking_hooks: ${masking_hooks} 12 | sparse_model: ${sparse_model} 13 | 14 | ############################# 15 | # Train and Validation data # 16 | ############################# 17 | data: 18 | train_on: 19 | _target_: contextual_sparsity.data.get_slimpajama_dataloader 20 | tokenized_dataset_path: ${hardware.paths.data.tokenized_slimpajama} 21 | sequence_length: 1024 22 | batch_size: 1 23 | num_workers: ${hardware.cpu_cores} 24 | shuffle: true 25 | model_id: ${dense_model.model_id} 26 | device: ${hardware.device} 27 | valid_on: ${data.valid} 28 | preprocess_batch: ${move_dict_to_device} 29 | 30 | ########################## 31 | # Optimization Procedure # 32 | ########################## 33 | optimization: 34 | gradient_accumulation_steps: 32 35 | optimizer: 36 | _target_: torch.optim.AdamW 37 | _partial_: true 38 | lr: 1e-3 39 | weight_decay: 1e-3 40 | lr_scheduler: 41 | _target_: contextual_sparsity.utils.lr_scheduler.LinearWarmup 42 | _partial_: true 43 | optimizer: ??? 44 | total_iterations: ??? 45 | min_factor: 1e-3 46 | max_factor: 1.0 47 | n_epochs: 2 48 | validate_every: 500 # optimization steps 49 | patience: 5 50 | device: ${hardware.device} 51 | 52 | loss: kd 53 | 54 | 55 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2025 Qualcomm Technologies, Inc. 2 | 3 | All rights reserved. 
4 | 5 | Redistribution and use in source and binary forms, with or without modification, are permitted (subject to the limitations in the disclaimer below) provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer: 8 | 9 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 10 | 11 | * Neither the name of Qualcomm Technologies, Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 12 | 13 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 14 | -------------------------------------------------------------------------------- /contextual_sparsity/data/activations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from typing import Dict, Union 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class ActivationDictDataset(Dataset): 12 | """ 13 | Wrapper class used to make a dataset out of the model activations. 14 | Each element is a dictionary {activation_id: activation_values} 15 | """ 16 | 17 | def __init__( 18 | self, 19 | activations: Dict[str, Union[np.ndarray, torch.Tensor]], 20 | flatten: bool = False, 21 | **keys, 22 | ): 23 | super().__init__() 24 | self.activations = activations 25 | self.flatten = flatten 26 | 27 | for k, activation_name in keys.items(): 28 | if activation_name not in activations: 29 | raise KeyError(f"Key {k} not found in activations") 30 | self.keys = keys 31 | self.batch_shape = self.activations[next(iter(self.keys.values()))].shape[:-1] 32 | 33 | def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: 34 | if self.flatten: 35 | multi_index = [] 36 | 37 | for dim in self.batch_shape[::-1]: 38 | multi_index.append(index % dim) 39 | index //= dim 40 | 41 | index = multi_index[::-1] 42 | else: 43 | index = [index] 44 | 45 | return {k: self.activations[layer][tuple(index)] for k, layer in self.keys.items()} 46 | 47 | def __len__(self) -> int: 48 | if self.flatten: 49 | n_datapoints = np.prod(self.batch_shape) 50 | else: 51 | n_datapoints = self.batch_shape[0] 52 | return n_datapoints 53 | -------------------------------------------------------------------------------- /contextual_sparsity/utils/turbosparse.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 
2 | # All Rights Reserved. 3 | 4 | from typing import List 5 | 6 | from torch import nn 7 | 8 | from contextual_sparsity.utils.layer_names import MLP, get_layer_ids 9 | from contextual_sparsity.utils.submodule import set_submodule 10 | 11 | 12 | class BambooMLPWithoutPredictor(nn.Module): 13 | def __init__(self, mlp: nn.Module): 14 | """ 15 | Takes as input a BambooMLP module as defined in 16 | https://huggingface.co/PowerInfer/TurboSparse-Mistral-Instruct/blob/main/modeling_bamboo.py 17 | """ 18 | super().__init__() 19 | self.config = mlp.config 20 | self.hidden_size = mlp.hidden_size 21 | self.intermediate_size = mlp.intermediate_size 22 | self.layer_id = mlp.layer_id 23 | self.gate_proj = mlp.gate_proj 24 | self.up_proj = mlp.up_proj 25 | self.down_proj = mlp.down_proj 26 | self.act_fn = mlp.act_fn 27 | 28 | def forward(self, x, before_norm): 29 | return self.down_proj(self.act_fn(self.gate_proj(x)) * self.act_fn(self.up_proj(x))) 30 | 31 | 32 | def remove_turbosparse_predictors(model: nn.Module, model_id: str) -> List[nn.Module]: 33 | """ 34 | Replaces all existing MLPs in TurboSparse models with MLP copies that do not use a predictor. 35 | """ 36 | layer_names = get_layer_ids(model_id=model_id, layer_type=MLP, layer_names="all") 37 | 38 | predictors = [] 39 | for layer_name in layer_names: 40 | mlp_with_predictor = model.get_submodule(layer_name) 41 | predictors.append(mlp_with_predictor.predictor) 42 | mlp_without_predictor = BambooMLPWithoutPredictor(mlp_with_predictor) 43 | set_submodule(model, layer_name, mlp_without_predictor) 44 | 45 | return predictors 46 | -------------------------------------------------------------------------------- /contextual_sparsity/utils/phi.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from contextual_sparsity.utils.layer_names import ( 8 | FC_GATE, 9 | FC_UP, 10 | MLP, 11 | MODEL_MAPS, 12 | get_layer_ids, 13 | ) 14 | from contextual_sparsity.utils.misc import split_gate_up_layer 15 | from contextual_sparsity.utils.submodule import set_submodule 16 | 17 | 18 | class Phi3SplitMLP(nn.Module): 19 | """ 20 | Wrapper for Phi architecture with separate up and gate linear layer for consistency with the other LLMs. 21 | """ 22 | 23 | def __init__(self, upgate_mlp: nn.Module): 24 | super().__init__() 25 | 26 | # Split the up and gate 27 | gate_proj, up_proj = split_gate_up_layer(upgate_mlp.gate_up_proj) 28 | 29 | self.gate_proj = gate_proj 30 | self.up_proj = up_proj 31 | self.down_proj = upgate_mlp.down_proj 32 | self.activation_fn = upgate_mlp.activation_fn 33 | 34 | def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: 35 | gate = self.gate_proj(hidden_states) 36 | up_states = self.up_proj(hidden_states) 37 | up_states = up_states * self.activation_fn(gate) 38 | 39 | return self.down_proj(up_states) 40 | 41 | 42 | def split_upgate(model: nn.Module, model_id: str): 43 | """ 44 | Replaces all existing MLPs in Phi models with MLP copies that have separate up and gate layers. 
45 | """ 46 | layer_names = get_layer_ids(model_id=model_id, layer_type=MLP, layer_names="all") 47 | 48 | for layer_name in layer_names: 49 | upgate_mlp = model.get_submodule(layer_name) 50 | mlp = Phi3SplitMLP(upgate_mlp) 51 | set_submodule(model, layer_name, mlp) 52 | 53 | MODEL_MAPS[model_id][FC_GATE] = ".".join([MODEL_MAPS[model_id][MLP], "gate_proj"]) 54 | MODEL_MAPS[model_id][FC_UP] = ".".join([MODEL_MAPS[model_id][MLP], "up_proj"]) 55 | -------------------------------------------------------------------------------- /scripts/config/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - data_preprocessing: default 4 | # The experiment specific info. Contains the definition of models, data, etc... 5 | - experiment: evaluate_llm 6 | 7 | 8 | # Hardware specific values including 9 | hardware: 10 | # Device on which the experiment runs (cpu/cuda) 11 | device: cuda 12 | # Number of CPU cores (used for data loading) 13 | cpu_cores: 8 14 | # Paths of stored LLMs (each path must contain the stored model and tokenizer) 15 | paths: 16 | models: 17 | opt-350M: ??? 18 | llama-v3-8B: ??? 19 | phi-3-medium: ??? 20 | phi-3-mini: ??? 21 | turbosparse-mistral: ??? 22 | mistral-v01-7B: ??? 23 | data: 24 | wikitext: ??? 25 | tokenized_slimpajama: ??? 26 | log: ./logs 27 | cache: /tmp/cache 28 | 29 | 30 | # Experiment details 31 | experiment: 32 | 33 | script: ??? 34 | 35 | # Return value/values (optional) 36 | # This can be either a list of [minimize/maximize, ...] of the same length of the script return 37 | # Or a list of dictionaries {name: minimize/maximize} in case the scripts returns a dictionary 38 | return: maximize 39 | 40 | 41 | 42 | # Path in which logs are saved (root) 43 | save_dir: ??? 44 | 45 | # Name of the study for multiruns 46 | sweep: ??? 47 | 48 | # Experiment seed 49 | seed: 42 50 | 51 | # Override existing experiments 52 | overwrite: false 53 | 54 | # Set the required environment variables 55 | env: 56 | HF_HOME: ${hardware.paths.cache}/.hf_home 57 | TRANSFORMERS_CACHE: ${hardware.paths.cache}/.transformers_cache 58 | HF_DATASETS_CACHE: ${hardware.paths.cache}/.hf_datasets_cache 59 | LM_HARNESS_CACHE_PATH: ${hardware.paths.cache}/.lm_harness_cache_path 60 | 61 | hydra: 62 | job: 63 | chdir: true 64 | 65 | # Specifying the save path for single runs 66 | run: 67 | dir: ${experiment.save_dir}/runs/${experiment.run_id} 68 | 69 | # And for sweeps 70 | sweep: 71 | dir: ${experiment.save_dir}/sweeps/${experiment.sweep} 72 | subdir: ${hydra.job.num} 73 | -------------------------------------------------------------------------------- /contextual_sparsity/data/dummy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import torch 5 | from torch.utils.data import DataLoader, Dataset 6 | 7 | N_FEATURES = 101 8 | SEQUENCE_LENGTH = 50 9 | N_SEQUENCES = 1000 10 | BATCH_SIZE = 64 11 | 12 | 13 | def compute_labels(x: torch.Tensor) -> torch.Tensor: 14 | """ 15 | Make the label for each dummy input tensor 16 | """ 17 | return (torch.floor(x).long() % 2).int() 18 | 19 | 20 | class DummyDataset(Dataset): 21 | """ 22 | Mock dataset for testing. 23 | The dataset consists of N_SEQUENCE sequences that differ by a small additive constant. 24 | Each sequence consists of SEQUENCE_LENGTH vectors of numbers 0 to DATA_SIZE shifted by i 25 | with i being the index withing the sequence. 
26 | The labels corresponds to the parity of the closest integer for each sequence. 27 | """ 28 | 29 | def __init__(self, n_features: int, sequence_length: int, n_sequences: int): 30 | dummy_sequence = torch.cat( 31 | [torch.roll(torch.arange(n_features), i).unsqueeze(0) for i in range(sequence_length)], 32 | 0, 33 | ) 34 | self.data = torch.cat( 35 | [ 36 | dummy_sequence.unsqueeze(0) + float(i) / (n_sequences + 1) 37 | for i in range(1, n_sequences + 1) 38 | ], 39 | 0, 40 | ) 41 | 42 | def __getitem__(self, index): 43 | return {"x": self.data[index], "labels": compute_labels(self.data[index])} 44 | 45 | def __len__(self): 46 | return len(self.data) 47 | 48 | 49 | def get_dummy_dataloader( 50 | n_features: int = N_FEATURES, 51 | sequence_length: int = SEQUENCE_LENGTH, 52 | n_sequences: int = N_SEQUENCES, 53 | batch_size: int = BATCH_SIZE, 54 | ) -> DataLoader: 55 | dataset = DummyDataset( 56 | n_features=n_features, sequence_length=sequence_length, n_sequences=n_sequences 57 | ) 58 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) 59 | return dataloader 60 | -------------------------------------------------------------------------------- /contextual_sparsity/masking_hooks/trained/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from typing import List, Optional, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def make_mlp( 11 | input_dim: int, 12 | output_dim: int, 13 | hidden_dims: Union[List[int], int], 14 | normalization: Optional[str] = None, 15 | ): 16 | """ 17 | Utility to build a simple MLP model 18 | """ 19 | layers: List[nn.Module] = [] 20 | if isinstance(hidden_dims, int): 21 | hidden_dims = [hidden_dims] 22 | else: 23 | hidden_dims = list(hidden_dims) 24 | layer_sizes = [input_dim] + hidden_dims 25 | 26 | for i in range(len(layer_sizes) - 1): 27 | layers.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1])) 28 | if normalization == "batchnorm": 29 | layers.append(nn.BatchNorm1d(layer_sizes[i + 1])) 30 | elif normalization == "layernorm": 31 | layers.append(nn.LayerNorm(layer_sizes[i + 1])) 32 | layers.append(nn.ReLU()) 33 | 34 | layers.append(nn.Linear(layer_sizes[-1], output_dim)) 35 | return nn.Sequential(*layers) 36 | 37 | 38 | class Predictor(nn.Module): 39 | """ 40 | Abstract predictor class 41 | """ 42 | 43 | def __init__(self, input_dim: int, output_dim: int): 44 | super().__init__() 45 | self.input_dim = input_dim 46 | self.output_dim = output_dim 47 | 48 | 49 | class SimplePredictor(Predictor): 50 | """ 51 | Simple predictor class consisting of an MLP. 52 | """ 53 | 54 | def __init__( 55 | self, 56 | input_dim: int, 57 | hidden_dims: Union[List[int]], 58 | output_dim: int, 59 | ): 60 | super(SimplePredictor, self).__init__(input_dim=input_dim, output_dim=output_dim) 61 | self.hidden_dims = hidden_dims 62 | self.net = make_mlp(input_dim=input_dim, hidden_dims=hidden_dims, output_dim=output_dim) 63 | 64 | def forward(self, x): 65 | original_dtype = x.dtype 66 | x = x.type(torch.float32) 67 | return self.net(x).to(original_dtype) 68 | -------------------------------------------------------------------------------- /scripts/config/predictor/base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - loss: abstopk_ce 3 | 4 | # layer_to_mask can be either a int (layer index) or a string referring to layer block e.g. 
net.decoder.layers.1 5 | layer_to_mask: ??? 6 | 7 | # input_activation can be 8 | # int: Use the input to up_layer for the layer with index input_activation 9 | # str: Use the activation with identifier corresponding to the specified string. 10 | # E.g. model.decoder.layers.1.fc1.input to use the input of the fc1 submodule of layer 1 11 | input_activation: ${predictor.layer_to_mask} 12 | model_id: ${model_id} 13 | 14 | 15 | ########################## 16 | # Predictor Architecture # 17 | ########################## 18 | model: 19 | _target_: contextual_sparsity.masking_hooks.trained.model.SimplePredictor 20 | hidden_dims: 1000 21 | 22 | ############################# 23 | # Train and Validation data # 24 | ############################# 25 | data: 26 | train: 27 | activation_data: 28 | dataloader: 29 | _target_: contextual_sparsity.data.get_slimpajama_dataloader 30 | sequence_length: 2048 31 | tokenized_dataset_path: ${hardware.paths.data.tokenized_slimpajama} 32 | batch_size: 1 33 | num_workers: ${hardware.cpu_cores} 34 | shuffle: true 35 | model_id: ${dense_model.model_id} 36 | device: ${hardware.device} 37 | take_n: 2000 38 | preprocess_batch: ${move_dict_to_device} 39 | dtype: float16 40 | memory: ${hardware.device} 41 | flatten_activations: true 42 | batch_size: 1024 43 | num_workers: 0 44 | shuffle: true 45 | valid: 46 | activation_data: 47 | dataloader: ${data.valid} 48 | preprocess_batch: ${move_dict_to_device} 49 | dtype: float16 50 | memory: ${hardware.device} 51 | flatten_activations: true 52 | batch_size: 1024 53 | shuffle: false 54 | num_workers: 0 55 | 56 | 57 | ########################## 58 | # Optimization Procedure # 59 | ########################## 60 | optimization: 61 | optimizer: 62 | _target_: torch.optim.AdamW 63 | _partial_: true 64 | lr: 5e-4 65 | weight_decay: 5e-5 66 | n_epochs: 20 67 | patience: 5 68 | device: ${hardware.device} 69 | -------------------------------------------------------------------------------- /contextual_sparsity/dense_models/dummy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from copy import deepcopy 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from contextual_sparsity.data.dummy import N_FEATURES, compute_labels 10 | from contextual_sparsity.utils.layer_names import MODEL_MAPS, N_LAYERS 11 | 12 | 13 | class Identity(nn.Module): 14 | """ 15 | Mock identity module 16 | """ 17 | 18 | def forward(self, x: torch.Tensor) -> torch.Tensor: 19 | return x 20 | 21 | 22 | class DummyOutput: 23 | """ 24 | Mock dummy output for consistency with LLMs. 25 | """ 26 | 27 | def __init__(self, loss): 28 | self.loss = loss 29 | 30 | 31 | class DummyBlock(nn.Module): 32 | """ 33 | Mock dummy block for consistency with LLMs. 
34 | """ 35 | 36 | def __init__(self, n_features: int): 37 | super().__init__() 38 | 39 | identity_layer = nn.Linear(n_features, n_features) 40 | identity_layer.weight.data = torch.eye(n_features) 41 | identity_layer.bias.data = torch.zeros(n_features) 42 | 43 | self.up = deepcopy(identity_layer) 44 | self.activation_fn = Identity() 45 | self.down = deepcopy(identity_layer) 46 | 47 | def forward(self, x: torch.Tensor) -> torch.Tensor: 48 | x = self.up(x) 49 | x = self.activation_fn(x) 50 | x = self.down(x) 51 | return x 52 | 53 | 54 | def compute_loss(output, labels): 55 | prediction = compute_labels(output) 56 | return torch.pow(prediction - labels, 2).float().sum(-1).mean() 57 | 58 | 59 | class DummyModel(nn.Module): 60 | def __init__( 61 | self, 62 | n_features: int = N_FEATURES, 63 | n_layers: int = MODEL_MAPS["dummy"][N_LAYERS], 64 | model_id=None, 65 | device="cpu", 66 | ): 67 | super().__init__() 68 | self.layers = nn.Sequential(*[DummyBlock(n_features) for _ in range(n_layers)]) 69 | self.to(device) 70 | 71 | def forward(self, x, labels): 72 | output = self.layers(x) 73 | loss = compute_loss(output, labels) 74 | 75 | # We wrap the output into a dummy container to add the attribute .loss 76 | return DummyOutput(loss) 77 | -------------------------------------------------------------------------------- /contextual_sparsity/adapters/lora.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from typing import Any, Dict, List 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from contextual_sparsity.adapters.base import Adapter 10 | from contextual_sparsity.nn import SparseLinear 11 | 12 | 13 | class LoRA(Adapter): 14 | """ 15 | Simple LoRA Adapter for a Linear layer. 
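The adapter is registered as a forward hook on the wrapped layer and adds the low-rank update x @ A @ B (with dropout applied after the rank-r projection) to the layer output. When attached to a SparseLinear layer, the current column mask is applied to the input first, so the adapter sees the same sparsity pattern as the layer itself.

Illustrative sketch (shapes and rank are examples only):

    linear = nn.Linear(4096, 14336)
    adapter = LoRA.from_module(linear, rank=16)
    adapter.attach_to(linear)
    out = linear(torch.randn(2, 4096))  # output now includes the LoRA update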
16 | """ 17 | 18 | def __init__(self, W: torch.Tensor, rank: int, dropout_rate: float = 0.0): 19 | super().__init__() 20 | dtype = W.dtype 21 | 22 | # Usual LoRA initialization 23 | A = torch.zeros(W.shape[1], rank).to(W.device) 24 | B = torch.zeros(rank, W.shape[0]).to(W.device) 25 | A.normal_(0, 1.0 / A.shape[1]) 26 | 27 | A = A.to(dtype) 28 | B = B.to(dtype) 29 | self.A = nn.Parameter(A) 30 | self.B = nn.Parameter(B) 31 | self.dropout = nn.Dropout(dropout_rate) 32 | self.dropout_rate = dropout_rate 33 | self.rank = rank 34 | 35 | @staticmethod 36 | def from_module(linear: nn.Module, **kwargs) -> Adapter: 37 | if isinstance(linear, SparseLinear): 38 | W = linear._weight.detach() 39 | else: 40 | W = linear.weight.detach() 41 | 42 | adapter = LoRA(W, **kwargs) 43 | return adapter 44 | 45 | def _hook( 46 | self, 47 | module: nn.Module, 48 | args: List[Any], 49 | kwargs: Dict[str, Any], 50 | out: torch.Tensor, 51 | ) -> torch.Tensor: 52 | x = args[0] 53 | 54 | # If the Adapter is applied to a Sparse Linear Layer 55 | if isinstance(module, SparseLinear): 56 | # Apply the same mask to the input before using the adapter 57 | if module._col_mask is not None: 58 | mask = module._col_mask.to(x.device).to(x.dtype).detach() 59 | x = mask * x 60 | 61 | out = out + self.forward(x) 62 | return out 63 | 64 | def forward(self, x: torch.Tensor) -> torch.Tensor: 65 | z = x @ self.A 66 | z = self.dropout(z) 67 | return z @ self.B 68 | 69 | def __repr__(self): 70 | return f"LoRA(rank={self.rank}, dropout={self.dropout_rate})" 71 | -------------------------------------------------------------------------------- /contextual_sparsity/masking_hooks/glu_pruning.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from typing import Callable, List, Optional, Union 5 | 6 | from torch import nn 7 | from torch.utils.data import DataLoader 8 | 9 | from contextual_sparsity.mask import MaskingHook 10 | from contextual_sparsity.masking_hooks.binarization import ( 11 | BinarizationType, 12 | build_binarization, 13 | ) 14 | from contextual_sparsity.nn import Abs 15 | from contextual_sparsity.utils.layer_names import FC_DOWN, get_layer_ids 16 | 17 | 18 | def build_glu_pruning_masking_hooks( 19 | model_id: str, 20 | dense_model: nn.Module, 21 | layers_to_sparsify: Union[str, List[int], int], 22 | binarization_type: str = BinarizationType.topk.value, 23 | data_id: Optional[str] = None, 24 | calibration_data: Optional[DataLoader] = None, 25 | preprocess_batch: Optional[Callable] = None, 26 | keep: Optional[Union[int, List[int]]] = None, 27 | k: Optional[Union[int, List[int]]] = None, 28 | threshold: Optional[Union[int, List[int]]] = None, 29 | ) -> List[MaskingHook]: 30 | """ 31 | Factory function for building GLU pruning masking hooks. 
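For every selected block the hook reads the input to the down projection, applies Abs() followed by the configured binarization (top-k by default, or a fixed threshold), and masks the columns of that down projection accordingly.

Illustrative sketch, mirroring how the tests wire it up (the dummy model and the value of k are examples only):

    hooks = build_glu_pruning_masking_hooks(
        model_id="dummy", dense_model=model, layers_to_sparsify="all", k=16
    )
    sparse_model = build_sparse_model(dense_model=model, masking_hooks=hooks)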
32 | """ 33 | 34 | down_layer_ids = get_layer_ids( 35 | model_id=model_id, layer_type=FC_DOWN, layer_names=layers_to_sparsify 36 | ) 37 | down_activation_ids = [".".join([down_layer_id, "input"]) for down_layer_id in down_layer_ids] 38 | 39 | # Build the layer responsible for making the activations binary 40 | activation_binarization = build_binarization( 41 | activation_ids=down_activation_ids, 42 | model_id=model_id, 43 | dense_model=dense_model, 44 | data_id=data_id, 45 | calibration_data=calibration_data, 46 | binarization_type=binarization_type, 47 | preprocess_batch=preprocess_batch, 48 | keep=keep, 49 | k=k, 50 | threshold=threshold, 51 | ) 52 | 53 | masking_hooks = [] 54 | for i, layer_id in enumerate(down_layer_ids): 55 | # Make the masking function 56 | activation_id = ".".join([layer_id, "input"]) 57 | 58 | masking_hook = MaskingHook( 59 | masking_func=nn.Sequential( 60 | Abs(), 61 | activation_binarization[activation_id], 62 | ), 63 | input_from=layer_id, 64 | mask_rows_of=[], 65 | mask_cols_of=[layer_id], 66 | ) 67 | 68 | masking_hooks.append(masking_hook) 69 | 70 | return masking_hooks 71 | -------------------------------------------------------------------------------- /scripts/config/data/wikitext.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data_id: wikitext 3 | 4 | data: 5 | train: 6 | _target_: contextual_sparsity.data.get_dataloader 7 | dataset_id: wikitext 8 | name: wikitext-2-raw-v1 9 | dataset_path: ${hardware.paths.data.wikitext} 10 | take_n_tokens: null 11 | take_n_sequences: null 12 | sequence_length: 2048 13 | prompt_length: 0 14 | num_workers: ${hardware.cpu_cores} 15 | split: train 16 | shuffle: true 17 | model_id: ${dense_model.model_id} 18 | 19 | valid: 20 | _target_: contextual_sparsity.data.get_dataloader 21 | dataset_id: wikitext 22 | name: wikitext-2-raw-v1 23 | dataset_path: ${hardware.paths.data.wikitext} 24 | take_n_tokens: null 25 | take_n_sequences: null 26 | sequence_length: 2048 27 | prompt_length: 0 28 | num_workers: ${hardware.cpu_cores} 29 | split: validation 30 | shuffle: false 31 | model_id: ${dense_model.model_id} 32 | 33 | test: 34 | _target_: contextual_sparsity.data.get_dataloader 35 | dataset_id: wikitext 36 | name: wikitext-2-raw-v1 37 | dataset_path: ${hardware.paths.data.wikitext} 38 | take_n_tokens: null 39 | take_n_sequences: null 40 | sequence_length: 2048 41 | prompt_length: 0 42 | num_workers: ${hardware.cpu_cores} 43 | split: test 44 | shuffle: false 45 | model_id: ${dense_model.model_id} 46 | 47 | calibration: 48 | _target_: contextual_sparsity.data.get_dataloader 49 | dataset_id: wikitext 50 | name: wikitext-2-raw-v1 51 | dataset_path: ${hardware.paths.data.wikitext} 52 | take_n_tokens: null 53 | take_n_sequences: 10 54 | sequence_length: 2048 55 | prompt_length: 0 56 | num_workers: ${hardware.cpu_cores} 57 | split: train 58 | shuffle: true 59 | model_id: ${dense_model.model_id} 60 | 61 | slimpajama_calibration: 62 | _target_: contextual_sparsity.get_slimpajama_dataloader 63 | sequence_length: 2048 64 | batch_size: 1 65 | num_workers: ${hardware.cpu_cores} 66 | shuffle: true 67 | model_id: ${dense_model.model_id} 68 | device: ${hardware.device} 69 | take_n: 10 70 | 71 | tiny: 72 | _target_: contextual_sparsity.data.get_dataloader 73 | dataset_id: wikitext 74 | name: wikitext-2-raw-v1 75 | dataset_path: ${hardware.paths.data.wikitext} 76 | take_n_tokens: 100 77 | take_n_sequences: null 78 | sequence_length: 2048 79 | prompt_length: 0 80 | num_workers: 
${hardware.cpu_cores} 81 | split: test 82 | shuffle: false 83 | model_id: ${dense_model.model_id} -------------------------------------------------------------------------------- /contextual_sparsity/dense_models/llm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import logging 5 | from typing import Optional, Type, Union 6 | 7 | import torch 8 | from transformers import AutoModelForCausalLM, PreTrainedModel 9 | 10 | from contextual_sparsity.utils.layer_names import LAYERS_CONTAINER, MODEL_MAPS, N_LAYERS 11 | from contextual_sparsity.utils.misc import parse_dtype 12 | from contextual_sparsity.utils.phi import split_upgate 13 | from contextual_sparsity.utils.sparsify import set_submodule 14 | from contextual_sparsity.utils.turbosparse import remove_turbosparse_predictors 15 | 16 | # A logger for this file 17 | log = logging.getLogger(__name__) 18 | 19 | 20 | def trim_layers(model: PreTrainedModel, model_id: str) -> PreTrainedModel: 21 | """ 22 | Remove all layers but the first one from the model 23 | """ 24 | layer_container = MODEL_MAPS[model_id][LAYERS_CONTAINER] 25 | layers = model.get_submodule(layer_container) 26 | set_submodule(model, layer_container, layers[:1]) 27 | MODEL_MAPS[model_id][N_LAYERS] = 1 28 | 29 | return model 30 | 31 | 32 | def load_hf_model( 33 | pretrained_model_path: str, 34 | model_id: str, 35 | dtype: Optional[Union[str, torch.dtype]] = None, 36 | model_type: Type[PreTrainedModel] = AutoModelForCausalLM, 37 | test_mode: bool = False, 38 | device: Union[str, torch.device] = "cpu", 39 | local_files_only: bool = True, 40 | remove_predictors: bool = True, 41 | ) -> torch.nn.Module: 42 | """ 43 | Load a pre-trained model from Huggingface given a specified path. 44 | """ 45 | # Parse the data type 46 | torch_dtype = parse_dtype(dtype) 47 | 48 | if model_id not in MODEL_MAPS: 49 | raise ValueError(f"Model {model_id} not found in {MODEL_MAPS.keys()}") 50 | 51 | # The net is stored locally 52 | log.info( 53 | f"Loading the {model_id} pretrained model from {pretrained_model_path} using {model_type}" 54 | ) 55 | model = model_type.from_pretrained( 56 | pretrained_model_name_or_path=pretrained_model_path, 57 | local_files_only=local_files_only, 58 | torch_dtype=torch_dtype, 59 | device_map="cpu", 60 | trust_remote_code=True, 61 | ) 62 | 63 | if model_id == "turbosparse-mistral" and remove_predictors: 64 | remove_turbosparse_predictors(model, model_id) 65 | 66 | # Split up and gate into separate matrices 67 | if "phi-3" in model_id: 68 | split_upgate(model, model_id) 69 | 70 | # Trim all layers but the first if in test mode 71 | if test_mode: 72 | model = trim_layers(model, model_id) 73 | 74 | return model.to(device) 75 | -------------------------------------------------------------------------------- /contextual_sparsity/data/hf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
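# Usage sketch for the get_dataloader helper defined below (illustrative; mirrors the
# wikitext config under scripts/config/data, with tokenizer and local_path as placeholders):
#   loader = get_dataloader(
#       dataset_id="wikitext", tokenizer=tokenizer, dataset_path=local_path,
#       model_id="opt-350M", name="wikitext-2-raw-v1", split="test",
#       sequence_length=2048, batch_size=1,
#   )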
3 | 4 | import logging 5 | import os 6 | from typing import Optional 7 | 8 | from datasets import load_from_disk 9 | from torch.utils.data import DataLoader 10 | from transformers import PreTrainedTokenizer 11 | 12 | from contextual_sparsity.hw_simulator.constants import MODEL_ID_TO_DIMS 13 | from contextual_sparsity.utils.tokenizers import tokenize_for_language_modeling 14 | 15 | # A logger for this file 16 | log = logging.getLogger(__name__) 17 | 18 | 19 | def get_dataloader( 20 | dataset_id: str, 21 | tokenizer: PreTrainedTokenizer, 22 | dataset_path: str, 23 | model_id: str, 24 | name: Optional[str], 25 | split: Optional[str], 26 | take_n_tokens: Optional[int] = None, 27 | take_n_sequences: Optional[int] = None, 28 | num_workers: Optional[int] = None, 29 | sequence_length: int = 2048, 30 | prompt_length: Optional[int] = None, 31 | batch_size: int = 1, 32 | shuffle: bool = True, 33 | keep_in_memory: bool = True, 34 | ) -> DataLoader: 35 | """ 36 | Builds a dataloader for a huggingface dataset. Note that dataset_path corresponds to a local path in which 37 | the dataset is stored. 38 | """ 39 | log.info(f"Loading {split} {dataset_id} dataset from {dataset_path}") 40 | assert sequence_length <= MODEL_ID_TO_DIMS[model_id]["max_position_embeddings"], ( 41 | f'Selected sequence_length "{sequence_length}" is larger than the context_length ' 42 | f'"{MODEL_ID_TO_DIMS[model_id]["max_position_embeddings"]}" for this model! ' 43 | ) 44 | 45 | if name is not None: 46 | dataset_path = os.path.join(dataset_path, name) 47 | if split is not None: 48 | dataset_path = os.path.join(dataset_path, split) 49 | 50 | # Load the dataset 51 | dataset = load_from_disk(dataset_path=dataset_path) 52 | log.info(f"Loaded {split} {dataset_id} dataset from {dataset_path}. Size: {len(dataset)}") 53 | 54 | # How many tokens to subsample 55 | if take_n_tokens is not None: 56 | dataset = dataset.take(take_n_tokens) 57 | 58 | # Tokenize and make sequences of sequence length 59 | log.info(f"Tokenizing the {split} {dataset_id} dataset.") 60 | dataset = tokenize_for_language_modeling( 61 | tokenizer, 62 | dataset, 63 | sequence_length=sequence_length, 64 | keep_in_memory=keep_in_memory, 65 | ) 66 | 67 | # How many sequences/datapoints to subsample 68 | if take_n_sequences is not None: 69 | dataset = dataset.take(take_n_sequences) 70 | 71 | dataloader = DataLoader( 72 | dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle 73 | ) 74 | 75 | return dataloader 76 | -------------------------------------------------------------------------------- /contextual_sparsity/nn/binarization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class ThresholdMask(nn.Module): 9 | def __init__(self, threshold: float): 10 | """ 11 | Binarization layer based on the k largest elements on a fixed threshold 12 | """ 13 | super().__init__() 14 | self.threshold = threshold 15 | 16 | def forward(self, x: torch.Tensor) -> torch.BoolTensor: 17 | return (x > self.threshold).squeeze(-1) 18 | 19 | def extra_repr(self) -> str: 20 | return f"threshold={self.threshold}" 21 | 22 | 23 | class TopKMask(nn.Module): 24 | def __init__(self, k: int): 25 | """ 26 | Binarization layer based on the k largest elements on the last dimension 27 | """ 28 | super().__init__() 29 | self.k = k 30 | 31 | def forward(self, x: torch.Tensor) -> torch.BoolTensor: 32 | # Make sure we have at least 2 dims 33 | if x.ndim == 1: 34 | x = x.unsqueeze(0) 35 | 36 | assert x.ndim >= 2 37 | 38 | if self.k > x.size(-1): 39 | mask = torch.ones_like(x).bool() 40 | else: 41 | mask = torch.zeros_like(x) 42 | 43 | _, top_idx = torch.topk(x, self.k, dim=-1) 44 | mask.scatter_(-1, top_idx, 1) 45 | 46 | return mask.bool() 47 | 48 | def extra_repr(self) -> str: 49 | return f"k={self.k}" 50 | 51 | 52 | class RandomMask(nn.Module): 53 | """ 54 | Binarization layer based on random mask 55 | """ 56 | 57 | def forward(self, x: torch.Tensor) -> torch.BoolTensor: 58 | return torch.rand_like(x) >= self.p 59 | 60 | def extra_repr(self) -> str: 61 | return f"p={self.p}" 62 | 63 | 64 | class RandomKMask(TopKMask): 65 | """ 66 | Binarization layer with k entries set to 1 67 | """ 68 | 69 | def forward(self, x: torch.Tensor) -> torch.BoolTensor: 70 | return super().forward(torch.rand_like(x)) 71 | 72 | 73 | class StaticMask(nn.Module): 74 | def __init__(self, mask: torch.BoolTensor): 75 | """ 76 | Static binarization based on a specified mask 77 | """ 78 | super(StaticMask, self).__init__() 79 | if mask.ndim == 1: 80 | mask = mask.unsqueeze(0) 81 | assert mask.ndim == 2 82 | self.register_buffer("mask", mask) 83 | 84 | def forward(self, x): 85 | self.mask = self.mask.to(x.device) 86 | assert x.shape[-1] == self.mask.shape[-1] 87 | if x.ndim == 1: 88 | return self.mask.squeeze(0) 89 | elif x.ndim == 2: 90 | return self.mask.repeat(x.shape[0], 1) 91 | elif x.ndim == 3: 92 | return self.mask.unsqueeze(0).repeat(x.shape[0], x.shape[1], 1) 93 | -------------------------------------------------------------------------------- /tests/scripts/test_evaluate_llm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | import os 5 | 6 | import pytest 7 | from hydra import compose, initialize 8 | 9 | from scripts.run_experiment import parse 10 | 11 | HARDWARE_ID = "ci_node" 12 | BASE_CONFIG_PATH = "scripts/config" 13 | if not os.path.exists(BASE_CONFIG_PATH): 14 | BASE_CONFIG_PATH = os.path.join("..", BASE_CONFIG_PATH) 15 | BASE_CONFIG_PATH = os.path.abspath(BASE_CONFIG_PATH) 16 | 17 | 18 | @pytest.mark.parametrize("use_simulator", [False, True]) 19 | def test_evaluate_llm_perplexity(tmpdir, use_simulator): 20 | log_dir = os.path.join(tmpdir, "log") 21 | cache_dir = os.path.join(tmpdir, "cache") 22 | overrides = [ 23 | "dense_model=opt-350M", 24 | "experiment=evaluate_llm", 25 | "evaluation=perplexity", 26 | "data=wikitext", 27 | "data.test.take_n_sequences=1", 28 | "data.test.sequence_length=5", 29 | "data.test.prompt_length=0", 30 | "masking_hooks=glu_pruning", 31 | "masking_hooks.layers_to_sparsify=all", 32 | "+masking_hooks.k=128", 33 | f"hardware.paths.log={log_dir}", 34 | f"hardware.paths.cache={cache_dir}", 35 | ] 36 | if use_simulator: 37 | overrides += [ 38 | "+hw_simulator=default", 39 | "cache_hooks=write_only", 40 | ] 41 | 42 | with initialize(version_base="1.3", config_path="pkg://scripts/config"): 43 | cfg = compose("config.yaml", overrides=overrides) 44 | 45 | parse(cfg) 46 | 47 | assert os.path.exists("results.csv") 48 | if use_simulator: 49 | assert os.path.exists("results_hwsim.csv") 50 | 51 | 52 | @pytest.mark.parametrize("use_simulator", [False, True]) 53 | def test_evaluate_llm_lmeval(tmpdir, use_simulator): 54 | log_dir = os.path.join(tmpdir, "log") 55 | cache_dir = os.path.join(tmpdir, "cache") 56 | overrides = [ 57 | "dense_model=opt-350M", 58 | "experiment=evaluate_llm", 59 | "evaluation=arc_easy", 60 | "evaluation.arc_easy.limit=1", 61 | "masking_hooks=glu_pruning", 62 | "masking_hooks.layers_to_sparsify=all", 63 | "+masking_hooks.k=128", 64 | f"hardware.paths.log={log_dir}", 65 | f"hardware.paths.cache={cache_dir}", 66 | ] 67 | if use_simulator: 68 | overrides += [ 69 | "+hw_simulator=default", 70 | "cache_hooks=write_only", 71 | "hw_simulator.sequence_length=5", 72 | "hw_simulator.prompt_length=0", 73 | ] 74 | 75 | with initialize(version_base="1.3", config_path="pkg://scripts/config"): 76 | cfg = compose("config.yaml", overrides=overrides) 77 | 78 | parse(cfg) 79 | 80 | assert os.path.exists("results.csv") 81 | if use_simulator: 82 | assert os.path.exists("results_hwsim.csv") 83 | -------------------------------------------------------------------------------- /contextual_sparsity/hw_simulator/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | 5 | # Dict keys chosen to be compatible with model.config. 
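# Illustrative example: for "llama-v3-8B" the three MLP projections of a single block
# account for 3 * hidden_size * intermediate_size = 3 * 4096 * 14336 ≈ 176M weights,
# which is the kind of quantity derived from these entries downstream.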
6 | MODEL_ID_TO_DIMS = { 7 | "opt-350M": { 8 | "num_hidden_layers": 24, # number of blocks in NN 9 | "vocab_size": 50272, 10 | "hidden_size": 1024, 11 | "intermediate_size": 4096, 12 | "max_position_embeddings": 2048, # context length used at training size 13 | "has_gate_proj": False, 14 | "num_attention_heads": 16, 15 | "num_key_value_heads": 16, # Grouped Query Attention ratio = num_key_value_heads / num_attention_heads 16 | }, # This model actually has a linear projection from 512 (embedding dimension) to 1024 (d_model) 17 | "llama-v3-8B": { 18 | "num_hidden_layers": 32, 19 | "vocab_size": 128256, 20 | "hidden_size": 4096, 21 | "intermediate_size": 14336, 22 | "max_position_embeddings": 8192, 23 | "has_gate_proj": True, 24 | "num_attention_heads": 32, 25 | "num_key_value_heads": 8, 26 | }, 27 | "phi-3-medium": { 28 | "num_hidden_layers": 40, 29 | "vocab_size": 32064, 30 | "hidden_size": 5120, 31 | "intermediate_size": 17920, 32 | "max_position_embeddings": 4096, 33 | "has_gate_proj": True, 34 | "num_attention_heads": 40, 35 | "num_key_value_heads": 10, 36 | }, 37 | "phi-3-mini": { 38 | "num_hidden_layers": 32, 39 | "vocab_size": 32064, 40 | "hidden_size": 3072, 41 | "intermediate_size": 8192, 42 | "max_position_embeddings": 4096, 43 | "has_gate_proj": True, 44 | "num_attention_heads": 32, 45 | "num_key_value_heads": 32, 46 | }, 47 | "mistral-v01-7B": { 48 | "num_hidden_layers": 32, 49 | "vocab_size": 32000, 50 | "hidden_size": 4096, 51 | "intermediate_size": 14336, 52 | "max_position_embeddings": 32768, 53 | "has_gate_proj": True, 54 | "num_attention_heads": 32, 55 | "num_key_value_heads": 8, 56 | }, 57 | "turbosparse-mistral": { 58 | "num_hidden_layers": 32, 59 | "vocab_size": 32064, 60 | "hidden_size": 4096, 61 | "intermediate_size": 14336, 62 | "max_position_embeddings": 4096, # the original max context length is 32k, but in Turbosparse finetuning 4k was used instead 63 | "has_gate_proj": True, 64 | "num_attention_heads": 32, 65 | "num_key_value_heads": 8, 66 | }, 67 | "dummy": { 68 | "num_hidden_layers": 2, 69 | "vocab_size": 101, 70 | "hidden_size": 101, 71 | "intermediate_size": 101, 72 | "max_position_embeddings": 2048, 73 | "has_gate_proj": False, 74 | "num_attention_heads": 0, 75 | "num_key_value_heads": 0, 76 | }, 77 | } 78 | -------------------------------------------------------------------------------- /tests/contextual_sparsity/test_masking_hooks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | import os 5 | 6 | import pytest 7 | import torch 8 | from hydra import compose, initialize 9 | from hydra.utils import instantiate 10 | 11 | from contextual_sparsity.data.data_processing import move_dict_to_device 12 | from contextual_sparsity.data.dummy import get_dummy_dataloader 13 | from contextual_sparsity.dense_models import DummyModel 14 | from contextual_sparsity.utils.sparsify import build_sparse_model 15 | 16 | N_SEQUENCES = 2 17 | SEQUENCE_LENGTH = 100 18 | PROMPT_LENGTH = 20 19 | KEEP = 0.5 20 | 21 | 22 | @pytest.mark.parametrize("masking_hook", ["glu_pruning", "dip_free_params"]) 23 | def test_masking_hooks(tmpdir, masking_hook): 24 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 25 | overrides = [ 26 | "experiment=evaluate_llm", 27 | "dense_model=dummy", 28 | f"masking_hooks={masking_hook}", 29 | "data=wikitext", 30 | f"data.test.sequence_length={SEQUENCE_LENGTH}", 31 | f"data.test.prompt_length={PROMPT_LENGTH}", 32 | "+hw_simulator=default", 33 | "hw_simulator.cache_strategy=lfu", 34 | "cache_hooks=write_only", 35 | f"hardware.device={device}", 36 | f"hardware.paths.log={tmpdir}", 37 | ] 38 | if masking_hook == "glu_pruning": 39 | overrides.append(f"masking_hooks.keep={KEEP}") 40 | if masking_hook == "dip_free_params": 41 | overrides.append(f"masking_hooks.up_keep={KEEP}") 42 | overrides.append(f"masking_hooks.down_keep={KEEP}") 43 | 44 | with initialize(version_base="1.3", config_path="pkg://scripts/config"): 45 | conf = compose("config.yaml", overrides) 46 | os.chdir( 47 | tmpdir 48 | ) # Calling compose to set Hydra config does not have the same side effects as @hydra.main. 49 | 50 | # Initialize model with hooks and simulator 51 | model = DummyModel(device=device) 52 | dataloader = get_dummy_dataloader( 53 | sequence_length=SEQUENCE_LENGTH - PROMPT_LENGTH, 54 | n_sequences=N_SEQUENCES, 55 | batch_size=1, 56 | ) 57 | masking_hooks = instantiate(conf.masking_hooks, dense_model=model) 58 | hardware = instantiate(conf.hw_simulator, model=model, masking_hooks=masking_hooks) 59 | masking_hooks = instantiate( 60 | conf.cache_hooks, masking_hooks=masking_hooks, hw_simulator=hardware 61 | ) 62 | model = build_sparse_model( 63 | masking_hooks=masking_hooks, 64 | dense_model=model, 65 | ) 66 | model.eval() 67 | 68 | # Smoke test inference with sample data 69 | for hook in masking_hooks: 70 | hook.set_sparse() 71 | 72 | hardware.reset_hook.set_active() 73 | for batch in dataloader: 74 | batch = move_dict_to_device(batch, device=device) 75 | model(**batch) 76 | -------------------------------------------------------------------------------- /contextual_sparsity/utils/sparsify.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from copy import deepcopy 5 | from typing import List, Optional, Union 6 | 7 | from torch import nn 8 | 9 | from contextual_sparsity.mask import MaskingHook 10 | from contextual_sparsity.nn.sparse.linear import SimulatedSparseLinear, SparseLinear 11 | from contextual_sparsity.utils.submodule import set_submodule 12 | 13 | 14 | def sparsify_linear( 15 | linear: nn.Linear, 16 | simulated: bool = False, 17 | layer_name: Optional[str] = None, 18 | **kwargs, 19 | ) -> SparseLinear: 20 | """ 21 | Sparsify a linear layer, returning the corresponding wrapped Sparse Linear layer. 
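Only simulated sparsity is currently implemented: the returned SimulatedSparseLinear keeps the dense weight and zeroes out masked rows/columns instead of running a true sparse kernel.

Illustrative sketch:

    sparse = sparsify_linear(nn.Linear(1024, 4096), simulated=True)
    # behaves like the dense layer until a MaskingHook sets its row/column masks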
22 | """ 23 | # We are interested in simulating a sparse block, we can just zero out the entries 24 | if simulated: 25 | SparseLinearClass = SimulatedSparseLinear 26 | else: 27 | # Actual sparse forward can be implemented by creating a subclass of SparseLinear and adding it here 28 | raise NotImplementedError() 29 | 30 | return SparseLinearClass( 31 | weight=linear.weight, bias=linear.bias, layer_name=layer_name, **kwargs 32 | ) 33 | 34 | 35 | def sparsify_model( 36 | masking_hooks: List[MaskingHook], 37 | dense_model: nn.Module, 38 | simulated: bool = True, 39 | ) -> nn.Module: 40 | """ 41 | Sparsify the MLP layers corresponding to the masking hooks. 42 | """ 43 | 44 | # Attach all the hooks 45 | for masking_hook in masking_hooks: 46 | for layer_id in masking_hook.mask_rows_of: 47 | linear_module = dense_model.get_submodule(layer_id) 48 | if isinstance(linear_module, nn.Linear): 49 | sparse_linear_module = sparsify_linear( 50 | linear_module, 51 | layer_name=masking_hook.mask_cols_of, 52 | simulated=simulated, 53 | ) 54 | set_submodule(dense_model, layer_id, sparse_linear_module) 55 | 56 | for layer_id in masking_hook.mask_cols_of: 57 | linear_module = dense_model.get_submodule(layer_id) 58 | if isinstance(linear_module, nn.Linear): 59 | sparse_linear_module = sparsify_linear( 60 | linear_module, 61 | layer_name=masking_hook.mask_cols_of, 62 | simulated=simulated, 63 | ) 64 | else: 65 | sparse_linear_module = linear_module 66 | set_submodule(dense_model, layer_id, sparse_linear_module) 67 | 68 | masking_hook.attach_to(dense_model) 69 | 70 | return dense_model 71 | 72 | 73 | def build_sparse_model( 74 | dense_model: nn.Module, 75 | masking_hooks: Union[MaskingHook, List[MaskingHook]], 76 | simulated: bool = True, 77 | inplace: bool = True, 78 | ) -> nn.Module: 79 | """ 80 | Apply all the masking hooks to a dense model, making it sparse. 81 | """ 82 | if not inplace: 83 | dense_model = deepcopy(dense_model) 84 | 85 | masking_hooks = list(masking_hooks) 86 | 87 | return sparsify_model( 88 | dense_model=dense_model, 89 | masking_hooks=masking_hooks, 90 | simulated=simulated, 91 | ) 92 | -------------------------------------------------------------------------------- /tests/contextual_sparsity/test_cache_hooks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | import os 5 | 6 | import pytest 7 | import torch 8 | from hydra import compose, initialize 9 | from hydra.utils import instantiate 10 | 11 | from contextual_sparsity.data.data_processing import move_dict_to_device 12 | from contextual_sparsity.data.dummy import N_FEATURES, get_dummy_dataloader 13 | from contextual_sparsity.dense_models import DummyModel 14 | from contextual_sparsity.utils.sparsify import build_sparse_model 15 | 16 | N_SEQUENCES = 2 17 | SEQUENCE_LENGTH = 100 18 | PROMPT_LENGTH = 20 19 | TOPK = N_FEATURES // 10 20 | 21 | 22 | @pytest.mark.parametrize( 23 | "cache_hook", ["write_only", "weighting_current_cache", "approximate_caching"] 24 | ) 25 | def test_cache_hooks(tmpdir, cache_hook): 26 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 27 | overrides = [ 28 | "experiment=evaluate_llm", 29 | "dense_model=dummy", 30 | "masking_hooks=glu_pruning", 31 | f"masking_hooks.k={TOPK}", 32 | "data=wikitext", 33 | f"data.test.sequence_length={SEQUENCE_LENGTH}", 34 | f"data.test.prompt_length={PROMPT_LENGTH}", 35 | "+hw_simulator=default", 36 | "hw_simulator.cache_strategy=lfu", 37 | f"cache_hooks={cache_hook}", 38 | f"hardware.device={device}", 39 | f"hardware.paths.log={tmpdir}", 40 | "+logging_overrides=stdout_only", 41 | ] 42 | if "weighting_" in cache_hook: 43 | overrides.append("cache_hooks.kwargs.gamma=0.1") 44 | overrides.append(f"cache_hooks.kwargs.fixed_top_n={TOPK // 2}") 45 | if cache_hook == "weighting_current_cache": 46 | overrides.append("cache_hooks.kwargs.warm_up=True") 47 | if cache_hook == "approximate_caching": 48 | overrides.append(f"cache_hooks.kwargs.fixed_top_n={TOPK // 2}") 49 | overrides.append(f"cache_hooks.kwargs.top_m={TOPK * 2}") 50 | overrides.append("cache_hooks.kwargs.gamma=0.1") 51 | 52 | with initialize(version_base="1.3", config_path="pkg://scripts/config"): 53 | conf = compose("config.yaml", overrides) 54 | os.chdir( 55 | tmpdir 56 | ) # Calling compose to set Hydra config does not have the same side effects as @hydra.main. 57 | 58 | # Initialize model with hooks and simulator 59 | model = DummyModel(device=device) 60 | dataloader = get_dummy_dataloader( 61 | sequence_length=SEQUENCE_LENGTH - PROMPT_LENGTH, 62 | n_sequences=N_SEQUENCES, 63 | batch_size=1, 64 | ) 65 | masking_hooks = instantiate(conf.masking_hooks, dense_model=model) 66 | hardware = instantiate(conf.hw_simulator, model=model, masking_hooks=masking_hooks) 67 | masking_hooks = instantiate( 68 | conf.cache_hooks, masking_hooks=masking_hooks, hw_simulator=hardware 69 | ) 70 | model = build_sparse_model( 71 | masking_hooks=masking_hooks, 72 | dense_model=model, 73 | ) 74 | model.eval() 75 | 76 | # Smoke test inference with sample data 77 | for hook in masking_hooks: 78 | hook.set_sparse() 79 | 80 | hardware.reset_hook.set_active() 81 | for batch in dataloader: 82 | batch = move_dict_to_device(batch, device=device) 83 | model(**batch) 84 | -------------------------------------------------------------------------------- /contextual_sparsity/evaluation/hooks/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | 5 | from functools import partial 6 | from typing import Any, Callable, Dict, List, Optional 7 | 8 | from torch import nn 9 | 10 | 11 | class EvaluationHook: 12 | """ 13 | Base class for all evaluation hooks. 
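Subclasses implement collect_results() and attach_to(); the hook is registered as a forward hook, accumulates per-module results in batch_output, and is cleared with reset() after each batch by CollectHooksOutput.

Illustrative sketch, mirroring tests/contextual_sparsity/test_evaluation.py (model and data names are placeholders):

    evaluation_hooks = [Perplexity()]
    evaluate_sparse_perplexity(
        model=sparse_model, test_data=dataloader, evaluation_hooks=evaluation_hooks
    )
    # summary statistics are written to results.csv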
14 | """ 15 | 16 | metric_dims: Dict[str, int] = {} 17 | 18 | def __init__(self): 19 | self._batch_output = {} 20 | self._handles = [] 21 | 22 | def collect_results(self, module, inputs, kwargs, outputs): 23 | raise NotImplementedError() 24 | 25 | def attach_to(self, model: nn.Module): 26 | raise NotImplementedError() 27 | 28 | def __call__(self, module, inputs, kwargs, outputs, attached_to: Optional[str]): 29 | output = self.collect_results(module, inputs, kwargs, outputs) 30 | if attached_to is None: 31 | attached_to = "." 32 | for quantity, value in output.items(): 33 | assert ( 34 | value.shape[-1] == self.metric_dims[quantity] 35 | ), f"{quantity}: {value.shape} != {self.metric_dims[quantity]}" 36 | 37 | output = {attached_to: output} 38 | 39 | self._batch_output.update(output) 40 | 41 | def _attach_to(self, module: nn.Module, attached_to: Optional[str] = None): 42 | hook_with_name = partial(self, attached_to=attached_to) 43 | 44 | self._handles.append(module.register_forward_hook(hook_with_name, with_kwargs=True)) 45 | 46 | def remove(self): 47 | for handle in self._handles: 48 | handle.remove() 49 | 50 | def finalize(self, stats: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: 51 | return stats 52 | 53 | def _post_process_batch( 54 | self, batch_stats: Dict[str, Dict[str, Any]] 55 | ) -> Dict[str, Dict[str, Any]]: 56 | return batch_stats 57 | 58 | @property 59 | def batch_output(self): 60 | batch_output = self._post_process_batch(self._batch_output) 61 | return batch_output 62 | 63 | def reset(self): 64 | self._batch_output = {} 65 | 66 | 67 | # Utilities to merge dictionaries with overlapping keys 68 | def merge_dicts(orig_dict: Dict[str, Any], add_dict: Dict[str, Any]) -> Dict[str, Any]: 69 | for k, v in add_dict.items(): 70 | if k in orig_dict: 71 | orig_dict[k] = merge_dicts(orig_dict[k], v) 72 | else: 73 | orig_dict[k] = v 74 | 75 | return orig_dict 76 | 77 | 78 | # Utility to collect the internal state of all evaluation hooks 79 | class CollectHooksOutput: 80 | def __init__( 81 | self, 82 | model: Callable, 83 | hooks: List[EvaluationHook], 84 | preprocess_batch: Optional[Callable], 85 | ): 86 | self.hooks = hooks 87 | self.model = model 88 | self.preprocess_batch = preprocess_batch 89 | 90 | def __call__(self, batch: Any): 91 | if self.preprocess_batch is not None: 92 | batch = self.preprocess_batch(batch) 93 | 94 | if isinstance(batch, dict): 95 | self.model(**batch) 96 | else: 97 | self.model(batch) 98 | 99 | outputs: Dict[str, Any] = {} 100 | for hook in self.hooks: 101 | outputs = merge_dicts(outputs, hook.batch_output) 102 | hook.reset() 103 | 104 | return outputs 105 | -------------------------------------------------------------------------------- /tests/contextual_sparsity/test_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | import numpy as np 5 | import pandas as pd 6 | import pytest 7 | import torch 8 | 9 | from contextual_sparsity.data.dummy import get_dummy_dataloader 10 | from contextual_sparsity.dense_models import DummyModel 11 | from contextual_sparsity.evaluation import ( 12 | CROSS_ENTROPY, 13 | Memory, 14 | Perplexity, 15 | evaluate_sparse_perplexity, 16 | ) 17 | from contextual_sparsity.mask import MaskingHook 18 | from contextual_sparsity.nn import ThresholdMask 19 | from contextual_sparsity.utils.sparsify import sparsify_model 20 | 21 | N_FEATURES = 101 22 | SEQUENCE_LENGTH = 50 23 | N_SEQUENCES = 129 24 | BATCH_SIZE = 64 25 | 26 | 27 | @pytest.mark.parametrize("threshold", [-1, 1, N_FEATURES // 2, N_FEATURES]) 28 | def test_evaluation(threshold: float): 29 | # Define a dummy architecture with two identity transformations 30 | dense_model = DummyModel() 31 | dataloader = get_dummy_dataloader( 32 | n_features=N_FEATURES, 33 | n_sequences=N_SEQUENCES, 34 | sequence_length=SEQUENCE_LENGTH, 35 | batch_size=BATCH_SIZE, 36 | ) 37 | dataset = dataloader.dataset 38 | 39 | # Masking hooks 40 | masking_hooks = [ 41 | MaskingHook( 42 | masking_func=ThresholdMask(threshold=threshold), 43 | input_from="layers.0.up", 44 | mask_rows_of=["layers.0.up"], 45 | mask_cols_of=["layers.0.down"], 46 | ), 47 | MaskingHook( 48 | masking_func=ThresholdMask(threshold=threshold), 49 | input_from="layers.1.up", 50 | mask_rows_of=["layers.1.up"], 51 | mask_cols_of=["layers.1.down"], 52 | ), 53 | ] 54 | 55 | sparse_model = sparsify_model(dense_model=dense_model, masking_hooks=masking_hooks) 56 | 57 | evaluation_hooks = [ 58 | Perplexity(), 59 | Memory( 60 | model_id="dummy", 61 | precision={ 62 | "embedding": 8, 63 | "lm_head": 8, 64 | "attention": 4, 65 | "mlp": 4, 66 | "activations": 16, 67 | "kv_cache": 8, 68 | "predictors": 16, 69 | }, 70 | sequence_length=2048, 71 | ), 72 | ] 73 | 74 | evaluate_sparse_perplexity( 75 | model=sparse_model, test_data=dataloader, evaluation_hooks=evaluation_hooks 76 | ) 77 | 78 | results = pd.read_csv("results.csv") 79 | 80 | ####################### 81 | # Check Cross-entropy # 82 | ####################### 83 | true_cross_entropy = torch.cat( 84 | [sparse_model(**dataset[i]).loss.unsqueeze(0) for i in range(len(dataset))], 0 85 | ).view(-1) 86 | 87 | measured_cross_entropy = results[ 88 | (results["computed_at"] == ".") & (results["quantity"] == CROSS_ENTROPY) 89 | ].pivot_table(columns="stat", values="value") 90 | assert len(measured_cross_entropy) == 1 91 | measured_cross_entropy = measured_cross_entropy.iloc[0] 92 | 93 | assert np.isclose( 94 | measured_cross_entropy["mean"], true_cross_entropy.mean(), atol=1e-6 95 | ), f"{measured_cross_entropy['mean']} != {true_cross_entropy.mean()}" 96 | # The standard deviation can't be checked properly since the LLM loss function returns one value per batch instead 97 | # Of one per element. As a result, the standard deviation depends on the batches created by the DataLoader. 98 | -------------------------------------------------------------------------------- /contextual_sparsity/masking_hooks/trained/turbosparse_wrap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | import logging 5 | from typing import List, Optional, Union 6 | 7 | from torch import nn 8 | 9 | from contextual_sparsity.mask import MaskingHook 10 | from contextual_sparsity.masking_hooks.binarization import ( 11 | BinarizationType, 12 | build_binarization, 13 | ) 14 | from contextual_sparsity.utils.layer_names import ( 15 | FC_DOWN, 16 | LAYERS_CONTAINER, 17 | MODEL_MAPS, 18 | block_id_to_layer_ids, 19 | block_number_from_id, 20 | get_block_ids, 21 | get_layer_ids, 22 | ) 23 | from contextual_sparsity.utils.turbosparse import remove_turbosparse_predictors 24 | 25 | log = logging.getLogger(__name__) 26 | 27 | 28 | def build_original_turbosparse_hooks( 29 | model_id: str, 30 | dense_model: nn.Module, 31 | layers_to_sparsify: Union[str, List[int], int], 32 | k: Optional[Union[int, List[int]]] = None, 33 | keep: Optional[Union[float, List[float]]] = None, 34 | threshold: Optional[Union[float, List[float]]] = None, 35 | ) -> List[MaskingHook]: 36 | """ 37 | Factory function for wrapping the original pre-trained TurboSparse predictors into MaskingHooks 38 | """ 39 | 40 | assert model_id == "turbosparse-mistral" 41 | block_ids = get_block_ids( 42 | model_id=model_id, 43 | block_names=layers_to_sparsify, 44 | ) 45 | down_layer_ids = get_layer_ids( 46 | model_id=model_id, layer_type=FC_DOWN, layer_names=layers_to_sparsify 47 | ) 48 | 49 | down_activation_ids = [".".join([layer_id, "input"]) for layer_id in down_layer_ids] 50 | 51 | # Build the layer responsible for making the activations binary 52 | activation_binarization = build_binarization( 53 | activation_ids=down_activation_ids, 54 | model_id=model_id, 55 | dense_model=dense_model, 56 | data_id=None, 57 | calibration_data=None, 58 | binarization_type=( 59 | BinarizationType.topk if threshold is None else BinarizationType.threshold 60 | ), 61 | preprocess_batch=None, 62 | keep=keep, 63 | k=k, 64 | threshold=threshold, 65 | ) 66 | 67 | predictors = remove_turbosparse_predictors(dense_model, model_id) 68 | assert len(predictors) == len(dense_model.model.layers) 69 | 70 | masking_hooks: List[MaskingHook] = [] 71 | for i, block_id in enumerate(block_ids): 72 | up_layer_id, down_layer_id, gate_layer_id = block_id_to_layer_ids( 73 | block_id=block_id, model_id=model_id 74 | ) 75 | 76 | # Determine which predictor to use 77 | block_number = block_number_from_id(model_id=model_id, submodule_id=block_id) 78 | predictor = predictors[block_number] 79 | 80 | # The input comes from before the post_attention_layernom module 81 | input_from = ".".join( 82 | [ 83 | MODEL_MAPS[model_id][LAYERS_CONTAINER], 84 | str(block_number), 85 | "post_attention_layernorm", 86 | ] 87 | ) 88 | 89 | masking_hook = MaskingHook( 90 | masking_func=nn.Sequential( 91 | predictor, 92 | activation_binarization[".".join([down_layer_id, "input"])], 93 | ), 94 | mask_rows_of=[up_layer_id, gate_layer_id], 95 | mask_cols_of=[down_layer_id], 96 | input_from=input_from, 97 | ) 98 | masking_hooks.append(masking_hook) 99 | 100 | return masking_hooks 101 | -------------------------------------------------------------------------------- /contextual_sparsity/adapters/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | from abc import abstractmethod 5 | from typing import Any, Dict, List, Optional 6 | 7 | import torch 8 | from hydra.utils import instantiate 9 | from omegaconf import DictConfig 10 | from torch import nn 11 | 12 | from contextual_sparsity.utils.submodule import set_submodule 13 | 14 | 15 | class Adapter(nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | self._handles = [] 19 | self._enabled = True 20 | self._attached_to = [] 21 | 22 | @staticmethod 23 | @abstractmethod 24 | def from_module(linear: nn.Module, *args, **kwargs) -> "Adapter": 25 | """ 26 | Creates an adapter from a Linear module 27 | """ 28 | pass 29 | 30 | def enable(self): 31 | """ 32 | Enables the Adapter 33 | """ 34 | self._enabled = True 35 | 36 | def disable(self): 37 | """ 38 | Disables the Adapter 39 | """ 40 | self._enabled = False 41 | 42 | def attach_to(self, module: nn.Module) -> None: 43 | """ 44 | Attaches the adapter to a module by registering it as a forward hook 45 | """ 46 | self._handles.append( 47 | module.register_forward_hook(self.hook, with_kwargs=True, prepend=True) 48 | ) 49 | self._attached_to.append(module) 50 | 51 | def hook( 52 | self, module: nn.Module, args: List[Any], kwargs: Dict[str, Any], out 53 | ) -> Optional[torch.Tensor]: 54 | """ 55 | Hook function that can be registered as a forward hook to a linear module. 56 | See https://pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_hook.html 57 | for more details on the hook function and its arguments. 58 | """ 59 | if self._enabled: 60 | return self._hook(module, args, kwargs, out) 61 | else: 62 | return None 63 | 64 | @abstractmethod 65 | def _hook( 66 | self, 67 | module: nn.Module, 68 | args: List[Any], 69 | kwargs: Dict[str, Any], 70 | out: torch.Tensor, 71 | ) -> torch.Tensor: 72 | pass 73 | 74 | def remove(self): 75 | """ 76 | Removes the Adapter from the list of handles 77 | """ 78 | for handle in self._handles: 79 | handle.remove() 80 | self._attached_to = [] 81 | 82 | 83 | def add_adapters(sparse_model: nn.Module, adapter_conf: DictConfig): 84 | """ 85 | Adds adapters to a specified sparse model using the provided adapter_conf. 86 | 87 | Args: 88 | sparse_model (nn.Module): the sparse model to be modified 89 | adapter_conf (DictConfig): the configuration of the adapters 90 | """ 91 | if not hasattr(sparse_model, "adapters"): 92 | set_submodule(sparse_model, "adapters", nn.ModuleDict()) 93 | 94 | device = next(sparse_model.parameters()).device 95 | make_adapter = instantiate(adapter_conf.model) 96 | 97 | # Add all the adapters 98 | for masking_hook in sparse_model.masking_hooks: 99 | for layer_id in masking_hook.mask_rows_of: 100 | linear_module = sparse_model.get_submodule(layer_id) 101 | adapter = make_adapter(linear_module).to(device) 102 | adapter.attach_to(linear_module) 103 | sparse_model.adapters[layer_id.replace(".", "_")] = adapter 104 | 105 | for layer_id in masking_hook.mask_cols_of: 106 | linear_module = sparse_model.get_submodule(layer_id) 107 | adapter = make_adapter(linear_module).to(device) 108 | adapter.attach_to(linear_module) 109 | sparse_model.adapters[layer_id.replace(".", "_")] = adapter 110 | -------------------------------------------------------------------------------- /contextual_sparsity/evaluation/lm_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | import logging 5 | import os 6 | from typing import Any, Dict, List, Optional 7 | 8 | import pandas as pd 9 | import yaml 10 | from lm_eval import evaluator 11 | from lm_eval.models.huggingface import HFLM 12 | from lm_eval.utils import make_table 13 | from transformers import PreTrainedModel, PreTrainedTokenizer 14 | 15 | from contextual_sparsity.hw_simulator.simulator import HardwareSimulator 16 | 17 | log = logging.getLogger(__name__) 18 | 19 | 20 | def make_dataframe(out: Dict[str, Any]) -> pd.DataFrame: 21 | """ 22 | Make a dataframe form the lm_eval output 23 | """ 24 | pd_results = [] 25 | for metric_name, stat_values in out["results"].items(): 26 | version = out["versions"].get(metric_name, "N/A") 27 | n = str(out["n-shot"].get(metric_name, "N/A")) 28 | 29 | if "alias" in stat_values: 30 | metric_name = stat_values.pop("alias") 31 | 32 | for stat, value in stat_values.items(): 33 | stat, _, f = stat.partition(",") 34 | 35 | if stat == "acc": 36 | stat = "mean" 37 | if stat == "acc_stderr": 38 | stat = "std" 39 | if stat == "acc_norm": 40 | stat = "norm" 41 | if stat == "acc_norm_stderr": 42 | stat = "norm_std" 43 | 44 | pd_results.append( 45 | { 46 | "quantity": metric_name, 47 | "stat": stat, 48 | "value": value, 49 | "version": version, 50 | "filter": f, 51 | "n_shots": n, 52 | } 53 | ) 54 | return pd.DataFrame(pd_results) 55 | 56 | 57 | def run_lm_eval( 58 | model: PreTrainedModel, 59 | tokenizer: PreTrainedTokenizer, 60 | tasks: List[Any], 61 | store_full_output: bool = False, 62 | hw_simulator: Optional[HardwareSimulator] = None, 63 | **kwargs, 64 | ): 65 | """ 66 | Run the lm_eval evaluation on a specified model 67 | """ 68 | if not isinstance(tasks, list): 69 | tasks = list(tasks) 70 | 71 | if "limit" in kwargs: 72 | log.warning( 73 | f"The evaluation is running with limit={kwargs['limit']}. Use this only while debugging!" 
74 | ) 75 | 76 | # Wrap it into a HFLM object for evaluation 77 | lm_model = HFLM(model, tokenizer=tokenizer) 78 | 79 | # Run the simple_eval method from lm_eval 80 | out = evaluator.simple_evaluate(lm_model, tasks=tasks, **kwargs) 81 | log.info("\n" + make_table(out)) 82 | 83 | # Store the full output if specified 84 | if store_full_output: 85 | out_path = os.path.abspath(os.path.join("", "lm_eval_full.yaml")) 86 | # Delete the configuration since it is not serializable 87 | del out["config"] 88 | with open(out_path, "w") as f: 89 | yaml.dump(out, f) 90 | log.info(f"Results stored in {out_path}") 91 | 92 | # Convert the results to a dataframe and store it 93 | pd_results = make_dataframe(out) 94 | summary_path = os.path.abspath(os.path.join("", "lm_eval_results.csv")) 95 | log.info(f"Summary results stored in {summary_path}") 96 | pd_results.to_csv(summary_path, index=False) 97 | 98 | # Get and store results from HW simulator 99 | if hw_simulator is not None: 100 | results_hwsim = hw_simulator.get_stats_df() 101 | log.info(results_hwsim) 102 | results_hwsim.to_csv("results_hwsim.csv", index=False) 103 | log.info( 104 | f"Results for HW simulator saved in {os.path.abspath(os.path.join('', 'results_hwsim.csv'))}" 105 | ) 106 | 107 | average_accuracy = pd_results[pd_results["stat"] == "mean"]["value"].mean() 108 | log.info(f"Average Accuracy: {average_accuracy}") 109 | 110 | return {"average_accuracy": average_accuracy} 111 | -------------------------------------------------------------------------------- /contextual_sparsity/data/slimpajama.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import logging 5 | from functools import partial 6 | from glob import glob 7 | from typing import Optional 8 | 9 | import torch 10 | from datasets import Dataset, concatenate_datasets 11 | from torch.utils.data import DataLoader 12 | from transformers import PreTrainedTokenizer 13 | 14 | log = logging.getLogger(__name__) 15 | 16 | 17 | def data_collator_with_truncation(features, seq_length, bos_token): 18 | batch = {} 19 | 20 | first = features[0] 21 | for k in first.keys(): 22 | values = [f[k][:seq_length] for f in features] 23 | if bos_token is not None and k == "input_ids": 24 | values = [[bos_token] + v[:-1] for v in values] 25 | batch[k] = torch.tensor(values) 26 | 27 | if "labels" not in batch: 28 | batch["labels"] = batch["input_ids"].clone() 29 | 30 | return batch 31 | 32 | 33 | def get_slimpajama_dataloader( 34 | tokenized_dataset_path: str, 35 | sequence_length: int, 36 | batch_size: int = 1, 37 | shuffle: bool = False, 38 | tokenizer: Optional[PreTrainedTokenizer] = None, 39 | model_id: Optional[str] = None, 40 | bos_token: Optional[str] = None, 41 | num_workers: int = 0, 42 | device: str = "cpu", 43 | take_n: Optional[int] = None, 44 | ): 45 | """ 46 | Creates a dataloader object for a pre-tokenized SlimPajama dataset consisting of multiple .arrow files. 47 | The tokenized dataset is assumed to also come with sequences of length >= sequence_length that will be sliced 48 | if needed. 49 | 50 | Args: 51 | tokenized_dataset_path: The path to the tokenized .arrow files. Use '*" to match all the .arrow files. 52 | sequence_length: The desired sequence length for each batch produced by the dataloader. 53 | take_n: Slice the SlimPajama dataset at a specified number of sequences. 54 | batch_size: The batch size for the dataloader. 
55 | model_id: The id of the model to load (this is used only to maintain a consistent interface). 56 | shuffle: Flag to enable shuffling. 57 | tokenizer: The tokenizer used to process the data. 58 | This argument is ignored in this function since the data is pre-tokenized. 59 | bos_token: Beginning of Sequence token (if any). 60 | device: The device to use for computation (this is used only to maintain a consistent interface). 61 | num_workers: Number of workers used for the dataloader. 62 | Returns: 63 | DataLoader: A dataloader object for the SlimPajama dataset. 64 | """ 65 | 66 | # Determine all the files that match the specified path 67 | files = glob(tokenized_dataset_path) 68 | 69 | if not files: 70 | raise RuntimeError("SlimPajama dataset not found.") 71 | 72 | log.info(f"Loading the tokenized SlimPajama dataset from {tokenized_dataset_path}") 73 | # Concatenate all the .arrow fragments 74 | tokenized_dataset = concatenate_datasets([Dataset.from_file(fpath) for fpath in sorted(files)]) 75 | # Define a data collator that makes sequences of the specified length 76 | custom_collate_fn = partial( 77 | data_collator_with_truncation, seq_length=sequence_length, bos_token=bos_token 78 | ) 79 | 80 | # If specified, slice the dataset to keep only a subset 81 | if take_n is not None: 82 | log.info(f"Taking a subset of {take_n} samples from SlimPajama") 83 | tokenized_dataset = tokenized_dataset.select(range(take_n)) 84 | 85 | # Make a dataloader 86 | dataloader = DataLoader( 87 | tokenized_dataset, 88 | collate_fn=custom_collate_fn, 89 | batch_size=batch_size, 90 | shuffle=shuffle, 91 | num_workers=num_workers, 92 | ) 93 | 94 | batch = next(iter(dataloader)) 95 | dataset_seq_length = batch["input_ids"].shape[1] 96 | assert dataset_seq_length == sequence_length, ( 97 | f"Dataset does not support the requested seq_length of {sequence_length}. " 98 | f"Sequences up to {dataset_seq_length} are supported." 99 | ) 100 | dataloader.seq_length = sequence_length 101 | 102 | return dataloader 103 | -------------------------------------------------------------------------------- /contextual_sparsity/masking_hooks/trained/optimization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import logging 5 | from copy import deepcopy 6 | from typing import Any, Callable, Optional, Union 7 | 8 | import pandas as pd 9 | import torch 10 | from torch.optim import Optimizer 11 | from torch.utils.data import DataLoader 12 | from tqdm.auto import tqdm 13 | 14 | from contextual_sparsity.masking_hooks.trained.loss import PredictorLoss 15 | from contextual_sparsity.utils.misc import get_batch_size, move_to_device 16 | 17 | log = logging.getLogger(__name__) 18 | 19 | LOSS = "loss" 20 | 21 | 22 | def optimize( 23 | train_loader: DataLoader, 24 | valid_loader: DataLoader, 25 | device: Union[str, torch.device], 26 | n_epochs: int, 27 | patience: Optional[int], 28 | optimizer: Callable[[Any], Optimizer], 29 | loss_func: PredictorLoss, 30 | ) -> pd.DataFrame: 31 | """ 32 | Train the parameters of a given loss function on a given train_loader. Validate on valid_loader to determine 33 | early stopping; otherwise train for n_epochs. The function returns a dataframe containing the log of training 34 | and validation losses.
35 | """ 36 | 37 | opt = optimizer([param for param in loss_func.parameters() if param.requires_grad]) 38 | train_log = [] 39 | iteration = 0 40 | best_loss = float("inf") 41 | best_weights = None 42 | if patience is None: 43 | patience = n_epochs 44 | original_patience = patience 45 | loss_func = loss_func.to(device) 46 | 47 | grad_enabled = torch.is_grad_enabled() 48 | torch.set_grad_enabled(True) 49 | 50 | # Progress bars 51 | epochs_pbar = tqdm(total=n_epochs) 52 | 53 | for epoch in range(n_epochs): 54 | # Train iteration 55 | loss_func.train() 56 | for batch in train_loader: 57 | batch = move_to_device(batch, device) 58 | 59 | batch_out = loss_func(**batch) 60 | 61 | if not isinstance(batch_out, dict): 62 | batch_out = {LOSS: batch_out} 63 | else: 64 | assert LOSS in batch_out 65 | loss = batch_out[LOSS] 66 | 67 | opt.zero_grad() 68 | loss.backward() 69 | opt.step() 70 | 71 | log_entry = { 72 | k: v.item() if torch.is_tensor(v) else v.detach() for k, v in batch_out.items() 73 | } 74 | log_entry.update({"epoch": epoch, "iteration": iteration, "split": "train"}) 75 | train_log.append(log_entry) 76 | iteration += 1 77 | 78 | # Validation 79 | val_loss = 0.0 80 | n = 0.0 81 | loss_func.eval() 82 | for batch in valid_loader: 83 | batch = move_to_device(batch, device) 84 | batch_size = get_batch_size(batch) 85 | 86 | batch_out = loss_func(**batch) 87 | if not isinstance(batch_out, dict): 88 | batch_out = {LOSS: batch_out} 89 | else: 90 | assert LOSS in batch_out 91 | loss = batch_out[LOSS] 92 | 93 | val_loss += loss.item() * batch_size 94 | n += batch_size 95 | log_entry = { 96 | k: v.item() if torch.is_tensor(v) else v.detach() for k, v in batch_out.items() 97 | } 98 | log_entry.update({"epoch": epoch, "iteration": iteration, "split": "valid"}) 99 | train_log.append(log_entry) 100 | 101 | val_loss /= n 102 | epochs_pbar.set_postfix({"valid_loss": val_loss}) 103 | epochs_pbar.update(1) 104 | if val_loss < best_loss: 105 | best_loss = val_loss 106 | best_weights = deepcopy(loss_func.predictor.state_dict()) 107 | patience = original_patience 108 | log.info(f"Best validation loss: {val_loss}") 109 | else: 110 | patience -= 1 111 | 112 | if patience == 0: 113 | log.info(f"Early stopping at epoch {epoch}") 114 | break 115 | 116 | torch.set_grad_enabled(grad_enabled) 117 | if best_weights is not None: 118 | loss_func.predictor.load_state_dict(best_weights) 119 | return pd.DataFrame(train_log) 120 | -------------------------------------------------------------------------------- /contextual_sparsity/masking_hooks/partial_glu_pruning.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | from enum import Enum 5 | from typing import Callable, List, Optional, Union 6 | 7 | from torch import nn 8 | from torch.utils.data import DataLoader 9 | 10 | from contextual_sparsity.mask import MaskingHook 11 | from contextual_sparsity.masking_hooks.binarization import ( 12 | BinarizationType, 13 | build_binarization, 14 | ) 15 | from contextual_sparsity.nn import Abs 16 | from contextual_sparsity.utils.layer_names import ( 17 | FC_ACT, 18 | FC_UP, 19 | MODEL_MAPS, 20 | block_id_to_layer_ids, 21 | block_id_to_mlp_id, 22 | get_block_ids, 23 | get_layer_ids, 24 | has_gate, 25 | ) 26 | 27 | 28 | class PredictorType(Enum): 29 | up = "up" 30 | gate = "gate" 31 | 32 | 33 | def build_partial_glu_pruning_masking_hooks( 34 | model_id: str, 35 | dense_model: nn.Module, 36 | layers_to_sparsify: Union[str, List[int], int], 37 | predictor_type: str, 38 | binarization_type: str = BinarizationType.topk.value, 39 | data_id: Optional[str] = None, 40 | calibration_data: Optional[DataLoader] = None, 41 | preprocess_batch: Optional[Callable] = None, 42 | keep: Optional[Union[int, List[int]]] = None, 43 | k: Optional[Union[int, List[int]]] = None, 44 | threshold: Optional[Union[int, List[int]]] = None, 45 | ) -> List[MaskingHook]: 46 | """ 47 | Factory function for building either UP or GATE masking hooks based on the specified predictor type 48 | (either up or gate). 49 | """ 50 | 51 | predictor_type = PredictorType[predictor_type] 52 | 53 | if predictor_type == PredictorType.gate: 54 | assert has_gate(model_id) 55 | 56 | block_ids = get_block_ids(model_id=model_id, block_names=layers_to_sparsify) 57 | 58 | if predictor_type == PredictorType.up: 59 | layer_ids = get_layer_ids( 60 | model_id=model_id, layer_type=FC_UP, layer_names=layers_to_sparsify 61 | ) 62 | else: 63 | layer_ids = get_layer_ids( 64 | model_id=model_id, layer_type=FC_ACT, layer_names=layers_to_sparsify 65 | ) 66 | activation_ids = [".".join([layer_id, "output"]) for layer_id in layer_ids] 67 | 68 | # Build the layer responsible for making the activations binary 69 | activation_binarization = build_binarization( 70 | activation_ids=activation_ids, 71 | model_id=model_id, 72 | dense_model=dense_model, 73 | data_id=data_id, 74 | calibration_data=calibration_data, 75 | binarization_type=binarization_type, 76 | threshold=threshold, 77 | preprocess_batch=preprocess_batch, 78 | keep=keep, 79 | k=k, 80 | ) 81 | 82 | masking_hooks = [] 83 | for i, block_id in enumerate(block_ids): 84 | up_layer_id, down_layer_id, gate_layer_id = block_id_to_layer_ids( 85 | model_id=model_id, block_id=block_id 86 | ) 87 | mlp_layer_id = block_id_to_mlp_id(model_id=model_id, block_id=block_id) 88 | act_id = ".".join([block_id, MODEL_MAPS[model_id][FC_ACT]]) 89 | 90 | mask_rows_of = [] 91 | if predictor_type == PredictorType.up: 92 | predictor = dense_model.get_submodule(up_layer_id) 93 | if gate_layer_id is not None: 94 | mask_rows_of.append(gate_layer_id) 95 | binarization_id = ".".join([up_layer_id, "output"]) 96 | else: 97 | predictor = dense_model.get_submodule(gate_layer_id) 98 | act_fn = dense_model.get_submodule(act_id) 99 | predictor = nn.Sequential(predictor, act_fn) 100 | mask_rows_of.append(up_layer_id) 101 | binarization_id = ".".join([act_id, "output"]) 102 | 103 | # Predict the activations (using gate or up), apply the absolute value, and then make them binary 104 | masking_func = nn.Sequential( 105 | predictor, 106 | Abs(), 107 | activation_binarization[binarization_id], 108 | ) 109 | 110 | # Wrap it into a masking hook, which contains 
references to where it should be attached 111 | masking_hook = MaskingHook( 112 | masking_func=masking_func, 113 | input_from=mlp_layer_id, 114 | mask_rows_of=mask_rows_of, 115 | mask_cols_of=[down_layer_id], 116 | ) 117 | masking_hooks.append(masking_hook) 118 | 119 | return masking_hooks 120 | -------------------------------------------------------------------------------- /contextual_sparsity/nn/sparse/linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | 11 | class SparseLinear(nn.Module): 12 | allow_different_mask_size: bool = False 13 | 14 | def __init__( 15 | self, 16 | weight: torch.Tensor, 17 | bias: Optional[torch.Tensor] = None, 18 | layer_name: Optional[str] = None, 19 | ): 20 | super().__init__() 21 | 22 | self.store_weights(weight, bias) 23 | self.layer_name = layer_name 24 | 25 | # No masking is active 26 | self._row_mask = None 27 | self._col_mask = None 28 | 29 | self._output_correction = None 30 | 31 | def store_weights(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): 32 | raise NotImplementedError() 33 | 34 | def load_weight_rows(self, mask: torch.BoolTensor) -> torch.Tensor: 35 | raise NotImplementedError() 36 | 37 | def load_weight_cols(self, mask: torch.BoolTensor) -> torch.Tensor: 38 | raise NotImplementedError() 39 | 40 | def load_weight_rows_and_cols( 41 | self, row_mask: torch.BoolTensor, col_mask: torch.BoolTensor 42 | ) -> torch.Tensor: 43 | raise NotImplementedError() 44 | 45 | def check_valid_mask(self, mask: torch.BoolTensor): 46 | # Check the mask has the correct shape 47 | if mask.ndim >= 2: 48 | # Here we are working with a batch_row_mask of shape [BATCH_SIZE, M] 49 | # Check we are selecting the same number of columns for each element of the batch 50 | n_selected_entries = mask.sum(-1) 51 | if (n_selected_entries == n_selected_entries[0]).sum() != len( 52 | n_selected_entries 53 | ) and not self.allow_different_mask_size: 54 | raise NotImplementedError( 55 | "The batch of masks contains different number of entries per batch elements. 
This is not supported" 56 | ) 57 | 58 | def check_valid_row_mask(self, row_mask: torch.BoolTensor): 59 | return self.check_valid_mask(row_mask) 60 | 61 | def check_valid_col_mask(self, col_mask: torch.BoolTensor): 62 | return self.check_valid_mask(col_mask) 63 | 64 | def set_active( 65 | self, 66 | row_mask: Optional[torch.BoolTensor] = None, 67 | col_mask: Optional[torch.BoolTensor] = None, 68 | ): 69 | # Basic shape checks 70 | if row_mask is not None: 71 | self.check_valid_row_mask(row_mask) 72 | 73 | if col_mask is not None: 74 | self.check_valid_col_mask(col_mask) 75 | 76 | # Update the masks 77 | self._row_mask = row_mask 78 | self._col_mask = col_mask 79 | 80 | def reset_active(self): 81 | self._row_mask = None 82 | self._col_mask = None 83 | 84 | def set_output_correction(self, correction: nn.Module): 85 | self._output_correction = correction 86 | 87 | def forward(self, x: torch.Tensor) -> torch.Tensor: 88 | y = self._forward(x) 89 | if self._output_correction is not None: 90 | y = self._output_correction(y) 91 | return y 92 | 93 | 94 | class SimulatedSparseLinear(SparseLinear): 95 | allow_different_mask_size: bool = True 96 | 97 | def store_weights(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): 98 | self.register_buffer("_weight", weight) 99 | self.register_buffer("_bias", bias) 100 | 101 | def _forward(self, x: torch.Tensor) -> torch.Tensor: 102 | end_squeeze = False 103 | if x.ndim == 1: 104 | x = x.unsqueeze(0) 105 | end_squeeze = True 106 | 107 | # Apply the column mask by zeroing out the inputs 108 | if self._col_mask is not None: 109 | # One mask per batch 110 | if self._col_mask.ndim == 1: 111 | x = x * self._col_mask.to(x.device).unsqueeze(0) 112 | elif self._col_mask.ndim >= 2: 113 | x = x * self._col_mask.to(x.device) 114 | 115 | y = F.linear(x, self._weight, self._bias) 116 | 117 | # squeeze the batch dimension if it was inflated artificially 118 | if end_squeeze: 119 | y = y.squeeze(0) 120 | 121 | return y 122 | 123 | def extra_repr(self): 124 | return f"in_features={self._weight.shape[-1]}, out_features={self._weight.shape[-2]}, bias={self._bias is not None}" 125 | -------------------------------------------------------------------------------- /tests/contextual_sparsity/test_sparse_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | from typing import Callable 5 | 6 | import pytest 7 | import torch 8 | from torch import nn 9 | 10 | from contextual_sparsity.nn.sparse.linear import SimulatedSparseLinear, SparseLinear 11 | from contextual_sparsity.utils import sparsify_linear 12 | 13 | input_dim = 3 14 | output_dim = 4 15 | 16 | device = "cpu" 17 | secondary_device = "cpu" 18 | 19 | 20 | @pytest.mark.parametrize( 21 | "sparsification_func", 22 | [ 23 | lambda x: sparsify_linear(x, simulated=True), 24 | ], 25 | ids=["simulated"], 26 | ) 27 | def test_sparse_linear_shapes(sparsification_func: Callable[[nn.Linear], SparseLinear]): 28 | # Create a DynamicSparseLinear layer 29 | dense_layer = nn.Linear(input_dim, output_dim, device=device) 30 | sparse_layer = sparsification_func(dense_layer) 31 | 32 | ########### 33 | # No Mask # 34 | ########### 35 | 36 | # Check the output is consistent 37 | # Without setting the active rows or columns, these should perform the same 38 | x = torch.randn((10, input_dim)).to(device) 39 | hat_y = sparse_layer(x) 40 | y = dense_layer(x) 41 | 42 | # Check the two layers perform the same operation 43 | assert torch.pow(y - hat_y, 2).sum() == 0, "Inconsistent output when nothing is sparsified" 44 | 45 | ############ 46 | # Col Mask # 47 | ############ 48 | 49 | # Set some columns as active 50 | col_mask = torch.tensor([0, 1, 1]).bool() 51 | sparse_layer.set_active(col_mask=col_mask) 52 | if isinstance(sparse_layer, SimulatedSparseLinear): 53 | masked_x = x 54 | else: 55 | masked_x = x[:, col_mask] 56 | hat_y = sparse_layer(masked_x) 57 | 58 | # Check devices 59 | assert str(x.device).startswith(device) 60 | 61 | # And shapes 62 | if not isinstance(sparse_layer, SimulatedSparseLinear): 63 | assert col_mask.sum() == sparse_layer.weight.shape[1] 64 | assert sparse_layer.bias.shape == dense_layer.bias.shape 65 | 66 | assert dense_layer.weight.shape == sparse_layer._weight.shape 67 | assert dense_layer.bias.shape == sparse_layer._bias.shape 68 | assert hat_y.shape[0] == x.shape[0] 69 | assert hat_y.shape[1] == y.shape[1] 70 | 71 | assert torch.pow(y - hat_y, 2).sum() > 0, "The two outputs should be different when masked" 72 | 73 | ################### 74 | # No Mask (Reset) # 75 | ################### 76 | sparse_layer.reset_active() 77 | hat_y = sparse_layer(x) 78 | 79 | # Check the two layers perform the same operation 80 | assert ( 81 | torch.pow(y - hat_y, 2).sum() == 0 82 | ), "Inconsistent output when nothing is sparsified (reset does not work)." 
83 | 84 | ############ 85 | # Row Mask # 86 | ############ 87 | 88 | row_mask = torch.tensor([1, 0, 0, 1]).bool() 89 | sparse_layer.set_active(row_mask=row_mask) 90 | print(sparse_layer._row_mask.shape) 91 | hat_y = sparse_layer(x) 92 | 93 | # Check devices 94 | assert str(x.device).startswith(device) 95 | 96 | # And shapes 97 | if not isinstance(sparse_layer, SimulatedSparseLinear): 98 | assert row_mask.sum() == sparse_layer.weight.shape[0] 99 | assert sparse_layer.bias.shape[0] == row_mask.sum() 100 | assert hat_y.shape[1] == row_mask.sum() 101 | 102 | assert dense_layer.weight.shape == sparse_layer._weight.shape 103 | assert dense_layer.bias.shape == sparse_layer._bias.shape 104 | assert hat_y.shape[0] == x.shape[0] 105 | 106 | #################### 107 | # Row and Col Mask # 108 | #################### 109 | 110 | sparse_layer.reset_active() 111 | sparse_layer.set_active(row_mask=row_mask, col_mask=col_mask) 112 | 113 | if isinstance(sparse_layer, SimulatedSparseLinear): 114 | masked_x = x 115 | else: 116 | masked_x = x[:, col_mask] 117 | 118 | hat_y = sparse_layer(masked_x) 119 | 120 | # Check devices 121 | assert str(x.device).startswith(device) 122 | 123 | # And shapes 124 | if not isinstance(sparse_layer, SimulatedSparseLinear): 125 | assert row_mask.sum() == sparse_layer.weight.shape[0] 126 | assert col_mask.sum() == sparse_layer.weight.shape[1] 127 | assert sparse_layer.bias.shape[0] == row_mask.sum() 128 | assert hat_y.shape[1] == row_mask.sum() 129 | 130 | assert dense_layer.weight.shape == sparse_layer._weight.shape 131 | assert sparse_layer._bias.shape == dense_layer.bias.shape 132 | assert hat_y.shape[0] == x.shape[0] 133 | -------------------------------------------------------------------------------- /contextual_sparsity/evaluation/perplexity.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import os 5 | from logging import getLogger 6 | from typing import Callable, Dict, List, Optional 7 | 8 | import pandas as pd 9 | import torch 10 | from torch import nn 11 | from torch.utils.data import DataLoader 12 | 13 | from contextual_sparsity.evaluation.hooks import ( 14 | MLP_DENSITY, 15 | PERPLEXITY, 16 | CollectHooksOutput, 17 | EvaluationHook, 18 | ) 19 | from contextual_sparsity.hw_simulator.simulator import HardwareSimulator 20 | from contextual_sparsity.utils.stats import MEAN, compute_func_stats 21 | 22 | log = getLogger(__name__) 23 | MAX_PERPLEXITY = 1000 24 | 25 | 26 | def get_perplexity(results: pd.DataFrame) -> float: 27 | """ 28 | Extract perplexity from a result dataframe. 29 | """ 30 | perplexity = results[ 31 | (results["quantity"] == PERPLEXITY) 32 | & (results["stat"] == MEAN) 33 | & (results["computed_at"] == ".") 34 | ]["value"].values 35 | assert len(perplexity) == 1 36 | perplexity = perplexity[0] 37 | if perplexity != perplexity or perplexity > MAX_PERPLEXITY: 38 | perplexity = MAX_PERPLEXITY 39 | return perplexity 40 | 41 | 42 | def get_mlp_density(results: pd.DataFrame) -> float: 43 | """ 44 | Extract mlp density from a result dataframe. 
45 | """ 46 | weight_density = results[ 47 | (results["quantity"] == MLP_DENSITY) 48 | & (results["stat"] == MEAN) 49 | & (results["computed_at"] == ".") 50 | ]["value"].values 51 | assert len(weight_density) == 1 52 | weight_density = weight_density[0] 53 | return weight_density 54 | 55 | 56 | def evaluate_sparse_perplexity( 57 | model: nn.Module, 58 | test_data: DataLoader, 59 | evaluation_hooks: List[EvaluationHook], 60 | preprocess_batch: Optional[Callable] = None, 61 | hw_simulator: Optional[HardwareSimulator] = None, 62 | ) -> Dict[str, float]: 63 | """ 64 | Evaluate the perplexity of a given sparse model 65 | """ 66 | 67 | # Attach all the evaluation hooks to the model 68 | for evaluation_hook in evaluation_hooks: 69 | evaluation_hook.attach_to(model) 70 | 71 | # Define a global output collection hook. 72 | # This is essentially a function that given a model input returns the output stored by all the hooks 73 | collect_outputs = CollectHooksOutput( 74 | model=model, hooks=evaluation_hooks, preprocess_batch=preprocess_batch 75 | ) 76 | 77 | model.eval() 78 | 79 | # Compute the statistics for all the values returned by the output hooks over the whole test_data 80 | stats = compute_func_stats( 81 | dataloader=test_data, 82 | func=collect_outputs, 83 | ) 84 | 85 | # Remove all the hooks 86 | for evaluation_hook in evaluation_hooks: 87 | stats = evaluation_hook.finalize(stats) 88 | evaluation_hook.remove() 89 | 90 | # The statistics are in a nested dictionary [computed_at][quantity][stat_name] = value 91 | # Convert it to a pandas dataframe for convenience 92 | results = [] 93 | for computed_at, quantity_dict in stats.items(): 94 | for quantity, stat_dict in quantity_dict.items(): 95 | for stat_name, stat_val in stat_dict.items(): 96 | if torch.is_tensor(stat_val): 97 | stat_val = stat_val.to("cpu").numpy() 98 | if stat_val.ndim == 1: 99 | stat_val = stat_val[0] 100 | 101 | results.append( 102 | { 103 | "computed_at": computed_at, 104 | "quantity": quantity, 105 | "stat": stat_name, 106 | "value": stat_val, 107 | } 108 | ) 109 | results = pd.DataFrame(results) 110 | 111 | # Store the results 112 | results.to_csv("results.csv", index=False) 113 | log.info(f"Results saved in {os.path.join(os.getcwd(), 'results.csv')}") 114 | 115 | # Get and store results from HW simulator 116 | if hw_simulator is not None: 117 | results_hwsim = hw_simulator.get_stats_df() 118 | log.info(results_hwsim) 119 | results_hwsim.to_csv("results_hwsim.csv", index=False) 120 | log.info( 121 | f"Results for HW simulator saved in {os.path.join(os.getcwd(), 'results_hwsim.csv')}" 122 | ) 123 | 124 | # Summarize the results into scalars for Optuna by reading the average weight density and perplexity 125 | summary = { 126 | MLP_DENSITY: get_mlp_density(results), 127 | "perplexity": get_perplexity(results), 128 | } 129 | 130 | return summary 131 | -------------------------------------------------------------------------------- /contextual_sparsity/scripts/llm_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
3 | 4 | import importlib 5 | import inspect 6 | import logging 7 | from typing import Dict, List 8 | 9 | import torch 10 | from hydra.utils import instantiate 11 | from omegaconf import DictConfig, open_dict 12 | 13 | from contextual_sparsity.utils.sparsify import build_sparse_model 14 | 15 | log = logging.getLogger(__name__) 16 | 17 | 18 | def get_parameters(class_path: str) -> List[str]: 19 | """ 20 | Determine the parameters of the given function 21 | """ 22 | module_name = ".".join(class_path.split(".")[:-1]) 23 | function_name = class_path.split(".")[-1] 24 | func_builder = getattr(importlib.import_module(module_name), function_name) 25 | func_params = inspect.signature(func_builder).parameters 26 | return list(func_params) 27 | 28 | 29 | def evaluate_llm_main(conf: DictConfig) -> Dict[str, float]: 30 | """ 31 | Evaluate the LLM model according to the given configuration 32 | """ 33 | torch.set_grad_enabled(False) 34 | 35 | log.info("Instantiating the tokenizer") 36 | tokenizer = instantiate(conf.tokenizer) 37 | 38 | log.info("Instantiating the dense model") 39 | dense_model = instantiate(conf.dense_model) 40 | 41 | # Determine if dense_model, calibration_data, and tokenizer are required to build the masking hooks 42 | # If so, pass them 43 | masking_hooks_parameters = get_parameters(conf.masking_hooks._target_) 44 | kwargs = {} 45 | 46 | # Instantiate the calibration 47 | calibration_data = instantiate(conf.data.calibration, tokenizer=tokenizer) 48 | 49 | if "calibration_data" in masking_hooks_parameters: 50 | kwargs["calibration_data"] = calibration_data 51 | if "dense_model" in masking_hooks_parameters: 52 | kwargs["dense_model"] = dense_model 53 | if "tokenizer" in masking_hooks_parameters: 54 | kwargs["tokenizer"] = tokenizer 55 | 56 | log.info("Instantiating the masking functions") 57 | masking_hooks = instantiate(conf.masking_hooks, **kwargs) 58 | 59 | # Create an instance of the hardware simulator, add cache hooks to masking hooks. 
60 | hardware = None 61 | if conf.hw_simulator is not None: 62 | assert conf.cache_hooks is not None 63 | log.info("Instantiating Hardware Simulator") 64 | hardware = instantiate(conf.hw_simulator, model=dense_model, masking_hooks=masking_hooks) 65 | masking_hooks = instantiate( 66 | conf.cache_hooks, masking_hooks=masking_hooks, hw_simulator=hardware 67 | ) 68 | 69 | log.info("Building the sparse model") 70 | sparse_model = build_sparse_model( 71 | masking_hooks=masking_hooks, 72 | dense_model=dense_model, 73 | ) 74 | 75 | # Add the LoRA adapter if specified 76 | if "adapter" in conf: 77 | tokenizer = tokenizer 78 | from contextual_sparsity.adapters import add_adapters 79 | 80 | add_adapters( 81 | sparse_model=sparse_model, 82 | adapter_conf=conf.adapter, 83 | ) 84 | 85 | if "load" in conf.adapter: 86 | from contextual_sparsity.adapters import load_adapters 87 | 88 | sparse_model = load_adapters( 89 | sparse_model=sparse_model, 90 | adapter_conf=conf.adapter, 91 | adapter_path=conf.adapter.load, 92 | ) 93 | elif "training" in conf.adapter: 94 | from contextual_sparsity.adapters import train_adapters 95 | 96 | train_adapters( 97 | sparse_model=sparse_model, 98 | tokenizer=tokenizer, 99 | adapter_conf=conf.adapter, 100 | ) 101 | 102 | log.info(sparse_model) 103 | 104 | # Perform all the specified evaluations 105 | merged_results = {} 106 | for name, evaluation in conf.evaluation.items(): 107 | # If the tokenizer is required, pass it to the evaluation function 108 | kwargs = {} 109 | 110 | eval_parameters = get_parameters(evaluation._target_) 111 | if "tokenizer" in eval_parameters: 112 | kwargs["tokenizer"] = tokenizer 113 | if "test_data" in eval_parameters: 114 | # Replace the configuration with its instance 115 | kwargs["test_data"] = instantiate(evaluation.test_data, tokenizer=tokenizer) 116 | with open_dict(evaluation): 117 | del evaluation["test_data"] 118 | 119 | log.info(f"Running the {name} evaluation") 120 | results = instantiate(evaluation, model=sparse_model, hw_simulator=hardware, **kwargs) 121 | merged_results.update(results) 122 | 123 | # Print all the return values 124 | for quantity, value in merged_results.items(): 125 | log.info(f"{quantity}: {value}") 126 | 127 | return merged_results 128 | -------------------------------------------------------------------------------- /contextual_sparsity/masking_hooks/trained/hook.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
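Before the predictor-hook factory, a tiny hedged sketch of the get_parameters helper from llm_evaluation.py above. The target string simply reuses a function from this repository; any importable callable would work the same way.

from contextual_sparsity.scripts.llm_evaluation import get_parameters

# Inspect which arguments a Hydra `_target_` accepts, so that optional objects
# (tokenizer, dense_model, calibration_data, test_data) are only passed to
# factories that actually declare them.
params = get_parameters("contextual_sparsity.evaluation.perplexity.get_perplexity")
print(params)  # ['results']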
3 | 4 | import logging 5 | import os 6 | from copy import deepcopy 7 | from typing import List, Optional, Union 8 | 9 | import torch 10 | from omegaconf import DictConfig, OmegaConf, open_dict 11 | from torch import nn 12 | from transformers import PreTrainedTokenizer 13 | 14 | from contextual_sparsity.mask import MaskingHook 15 | from contextual_sparsity.masking_hooks.binarization import ( 16 | BinarizationType, 17 | build_binarization, 18 | ) 19 | from contextual_sparsity.utils.layer_names import ( 20 | FC_DOWN, 21 | block_id_to_layer_ids, 22 | get_block_ids, 23 | get_layer_ids, 24 | ) 25 | 26 | log = logging.getLogger(__name__) 27 | 28 | 29 | def build_predictor_masking_hooks( 30 | layers_to_sparsify: Union[str, List[int], int], 31 | model_id: str, 32 | dense_model: torch.nn.Module, 33 | tokenizer: PreTrainedTokenizer, 34 | predictor_cache_dir: Optional[str] = None, 35 | k: Optional[int] = None, 36 | keep: Optional[float] = None, 37 | threshold: Optional[float] = None, 38 | force_retrain: bool = False, 39 | base_predictor_conf: Optional[DictConfig] = None, 40 | ) -> List[MaskingHook]: 41 | """ 42 | Factory function for building predictive masking hooks (based on a trained masking function). 43 | This includes methods based on DejaVU 44 | """ 45 | from contextual_sparsity.masking_hooks.trained.train_predictor import get_predictor 46 | 47 | # Determine the name of the layers to mask 48 | block_ids = get_block_ids(model_id=model_id, block_names=layers_to_sparsify) 49 | down_layer_ids = get_layer_ids( 50 | model_id=model_id, layer_type=FC_DOWN, layer_names=layers_to_sparsify 51 | ) 52 | down_activation_ids = [".".join([layer_id, "input"]) for layer_id in down_layer_ids] 53 | 54 | # Build the layer responsible for making the activations binary 55 | activation_binarization = build_binarization( 56 | activation_ids=down_activation_ids, 57 | model_id=model_id, 58 | dense_model=dense_model, 59 | data_id=None, 60 | calibration_data=None, 61 | binarization_type=( 62 | BinarizationType.topk.value if threshold is None else BinarizationType.threshold.value 63 | ), 64 | preprocess_batch=None, 65 | keep=keep, 66 | k=k, 67 | threshold=threshold, 68 | ) 69 | 70 | # Load the predictor conf if not provided 71 | if base_predictor_conf is None: 72 | config_file = os.path.join(".hydra", "config.yaml") 73 | if not os.path.isfile(config_file): 74 | raise FileNotFoundError( 75 | f"No config file found at {config_file}. 
Please specify predictor_conf" 76 | ) 77 | base_predictor_conf = OmegaConf.load(config_file).predictor 78 | 79 | OmegaConf.resolve(base_predictor_conf) 80 | 81 | # Determine which device to use 82 | device = next(dense_model.parameters()).device 83 | 84 | # Wrap each one in a corresponding masking hook 85 | masking_hooks = [] 86 | for i, block_id in enumerate(block_ids): 87 | up_layer_id, down_layer_id, gate_layer_id = block_id_to_layer_ids( 88 | block_id=block_id, model_id=model_id 89 | ) 90 | 91 | log.info(f"Preparing the predictor masking hook for layer {block_id}") 92 | mask_cols_of = [down_layer_id] 93 | mask_rows_of = [up_layer_id] 94 | if gate_layer_id is not None: 95 | mask_rows_of.append(gate_layer_id) 96 | input_from = up_layer_id 97 | input_activation_id = ".".join([input_from, "input"]) 98 | 99 | # Make a layer-specific predictor configuration from the base and the layer ids 100 | predictor_conf = deepcopy(base_predictor_conf) 101 | with open_dict(predictor_conf): 102 | predictor_conf.layer_to_mask = block_id 103 | predictor_conf.input_activation = input_activation_id 104 | 105 | masking_func = get_predictor( 106 | predictor_conf=predictor_conf, 107 | predictor_cache_dir=predictor_cache_dir, 108 | force_retrain=force_retrain, 109 | model_id=model_id, 110 | dense_model=dense_model, 111 | tokenizer=tokenizer, 112 | ).to(device) 113 | 114 | masking_hook = MaskingHook( 115 | masking_func=nn.Sequential( 116 | masking_func, 117 | activation_binarization[".".join([down_layer_id, "input"])], 118 | ), 119 | input_from=input_activation_id.replace(".input", ""), 120 | mask_cols_of=mask_cols_of, 121 | mask_rows_of=mask_rows_of, 122 | ) 123 | 124 | # When training sequentially, 125 | # add the hook to the model with the new predictor before computing the new activations 126 | 127 | masking_hooks.append(masking_hook) 128 | 129 | return masking_hooks 130 | -------------------------------------------------------------------------------- /contextual_sparsity/evaluation/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from typing import Callable, Optional 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | from torch import nn 10 | from torch.utils.data import DataLoader 11 | 12 | from contextual_sparsity.utils.misc import move_to_device 13 | from contextual_sparsity.utils.stats import compute_func_stats 14 | 15 | MASK_OVERLAP_PERC = "mask_overlap" 16 | PRESERVED_NORM_PERC = "preserved_norm" 17 | MSE = "mse" 18 | RSE = "rse" 19 | MSE_DIFF = "mse_diff" 20 | RSE_DIFF = "rse_diff" 21 | ACT_NORM = "act_norm" 22 | 23 | 24 | def evaluate_predictor( 25 | predictor: nn.Module, 26 | dataloader: DataLoader, 27 | down_layer: nn.Module, 28 | up_layer: nn.Module, 29 | act_fn: Callable, 30 | gate_layer: Optional[nn.Module] = None, 31 | k_spacing: int = 64, 32 | ) -> pd.DataFrame: 33 | """ 34 | Evaluate a (sparsity) predictor on a given dataloader by comparing the predictions with gt activations. 
35 | """ 36 | device = next(predictor.parameters()).device 37 | 38 | # Set the correct devices and evaluation mode 39 | predictor.eval() 40 | down_layer.eval() 41 | predictor = predictor.to(device) 42 | down_layer = down_layer.to(device) 43 | up_layer = up_layer.to(device) 44 | if gate_layer is not None: 45 | gate_layer = gate_layer.to(device) 46 | 47 | # Determine the activation size 48 | first_batch = next(iter(dataloader))["x"][:2].to(device) 49 | K = predictor(first_batch).shape[-1] 50 | 51 | all_ks = np.arange(0, K)[::k_spacing] 52 | dtype = down_layer.weight.dtype 53 | 54 | # Function that is called every batch to compute the relevant metrics 55 | def compute_metrics(batch): 56 | nonlocal predictor, down_layer, all_ks 57 | 58 | batch = move_to_device(batch, device) 59 | 60 | logits = predictor(batch["x"].to(torch.float32)) 61 | 62 | # Compute the true activations 63 | # Get the input to the layer to sparsify 64 | if "x_last" in batch: 65 | x_last = batch["x_last"] 66 | else: 67 | x_last = batch["x"] 68 | up = up_layer(x_last.to(up_layer.weight.dtype)) 69 | if gate_layer is not None: 70 | gate = act_fn(gate_layer(x_last)) 71 | act = up * gate 72 | else: 73 | act = act_fn(up) 74 | 75 | ordinal_targets = torch.argsort(torch.abs(act), dim=-1, descending=True) 76 | ordinal_predicted = torch.argsort(logits, dim=-1, descending=True) 77 | 78 | results = { 79 | MASK_OVERLAP_PERC: [], 80 | PRESERVED_NORM_PERC: [], 81 | MSE: [], 82 | MSE_DIFF: [], 83 | } 84 | # Start with all-zeros masks 85 | target_mask = (act * 0).type(dtype) 86 | predicted_mask = (act * 0).type(dtype) 87 | 88 | prev_k = 0 89 | for k in all_ks: 90 | # Set the k-th largest entry to 1 91 | target_mask.scatter_(-1, ordinal_targets[:, prev_k:k], 1) 92 | predicted_mask.scatter_(-1, ordinal_predicted[:, prev_k:k], 1) 93 | 94 | # Mask overlap 95 | accuracy = (target_mask * predicted_mask).float().sum(-1) / (k + 1) 96 | results[MASK_OVERLAP_PERC].append(accuracy.unsqueeze(-1)) 97 | 98 | # Preserved Norm Percentage 99 | masked_target = (target_mask * act).type(dtype) 100 | masked_predicted = (predicted_mask * act).type(dtype) 101 | 102 | target_norm = torch.norm(masked_target, dim=-1).unsqueeze(-1) 103 | predicted_norm = torch.norm(masked_predicted, dim=-1).unsqueeze(-1) 104 | results[PRESERVED_NORM_PERC].append(predicted_norm / target_norm) 105 | 106 | # MSE 107 | topk_out = down_layer(masked_target) 108 | predictor_out = down_layer(masked_predicted) 109 | out = down_layer(act) 110 | 111 | topk_mse = (out - topk_out).pow(2).mean(-1) 112 | predictor_mse = (out - predictor_out).pow(2).mean(-1) 113 | results[MSE].append(predictor_mse.unsqueeze(-1)) 114 | results[MSE_DIFF].append(predictor_mse.unsqueeze(-1) - topk_mse.unsqueeze(-1)) 115 | 116 | prev_k = k 117 | 118 | results = {ky: torch.cat(v, -1) for ky, v in results.items()} 119 | results[ACT_NORM] = torch.norm(act, 2, -1).unsqueeze(-1) 120 | 121 | return results 122 | 123 | # Compute the function statistics over the dataloader 124 | results = compute_func_stats( 125 | dataloader=dataloader, 126 | func=compute_metrics, 127 | ) 128 | 129 | # convert to a dataframe for convenience 130 | pd_results = [] 131 | for metric in [MASK_OVERLAP_PERC, PRESERVED_NORM_PERC, MSE, MSE_DIFF]: 132 | for i, k in enumerate(all_ks): 133 | pd_results.append( 134 | { 135 | "k": k, 136 | "metric": metric, 137 | "mean": results[metric]["mean"][i].item(), 138 | "std": results[metric]["std"][i].item(), 139 | } 140 | ) 141 | 142 | # Relative squared error computation 143 | act_var = 
results[ACT_NORM]["std"].item() ** 2 144 | for i, k in enumerate(all_ks): 145 | pd_results.append( 146 | { 147 | "k": k, 148 | "metric": RSE, 149 | "mean": results[MSE]["mean"][i].item() / act_var, 150 | "std": results[MSE]["std"][i].item() / act_var, 151 | } 152 | ) 153 | pd_results.append( 154 | { 155 | "k": k, 156 | "metric": RSE_DIFF, 157 | "mean": results[MSE_DIFF]["mean"][i].item() / act_var, 158 | "std": results[MSE_DIFF]["std"][i].item() / act_var, 159 | } 160 | ) 161 | 162 | return pd.DataFrame(pd_results) 163 | -------------------------------------------------------------------------------- /contextual_sparsity/masking_hooks/trained/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from typing import Any, Callable, Dict, Optional, Union 5 | 6 | import torch 7 | from torch import nn 8 | from torch.distributions import Bernoulli, Independent 9 | 10 | from contextual_sparsity.nn import Abs, ThresholdMask, TopKMask 11 | 12 | 13 | def binary_crossentropy(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 14 | """ 15 | Binary crossentropy loss function 16 | """ 17 | return -Independent(Bernoulli(logits=pred), 1).log_prob(target.float()).mean() 18 | 19 | 20 | class PredictorLoss(nn.Module): 21 | def __init__( 22 | self, 23 | predictor: Callable, 24 | up_layer: Optional[nn.Module] = None, 25 | down_layer: Optional[nn.Module] = None, 26 | gate_layer: Optional[nn.Module] = None, 27 | act_fn: Optional[nn.Module] = None, 28 | double_gating: bool = False, 29 | ): 30 | super().__init__() 31 | self.predictor = predictor 32 | self.double_gating = double_gating 33 | 34 | self.up_layer = up_layer 35 | if up_layer is not None: 36 | for param in up_layer.parameters(): 37 | param.requires_grad = False 38 | self.gate_layer = gate_layer 39 | if gate_layer is not None: 40 | for param in gate_layer.parameters(): 41 | param.requires_grad = False 42 | self.down_layer = down_layer 43 | if down_layer is not None: 44 | for param in down_layer.parameters(): 45 | param.requires_grad = False 46 | 47 | self.act_fn = act_fn 48 | self.model_dtype = self.up_layer.weight.dtype 49 | 50 | def forward( 51 | self, x: torch.Tensor, x_last: Optional[torch.Tensor] = None, **kwargs 52 | ) -> Union[torch.Tensor, Dict[str, Any]]: 53 | if "act" not in kwargs: 54 | if x_last is None: 55 | x_last = x 56 | 57 | x_last = x_last.to(self.model_dtype) 58 | 59 | up = self.up_layer(x_last.to(self.model_dtype)) 60 | if self.double_gating: 61 | up = self.act_fn(up) 62 | kwargs["up"] = up.detach() 63 | 64 | if self.gate_layer is None: 65 | act = self.act_fn(up) 66 | else: 67 | gate = self.act_fn(self.gate_layer(x_last)) 68 | kwargs["gate"] = gate.detach() 69 | act = up * gate 70 | 71 | kwargs["act"] = act.detach() 72 | 73 | return self._loss(x, **kwargs) 74 | 75 | def _loss( 76 | self, 77 | x: torch.Tensor, 78 | act: torch.Tensor, 79 | up: Optional[torch.Tensor] = None, 80 | gate: Optional[torch.Tensor] = None, 81 | ) -> Union[torch.Tensor, Dict[str, Any]]: 82 | raise NotImplementedError() 83 | 84 | 85 | class ActivationLoss(PredictorLoss): 86 | def __init__( 87 | self, 88 | predictor: Callable, 89 | loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], 90 | act_transform: Optional[Callable] = None, 91 | target: str = "act", 92 | **kwargs, 93 | ): 94 | super().__init__(predictor=predictor, **kwargs) 95 | self.loss = loss 96 | self.act_transform = act_transform 97 | if target not in ["act", "up", 
"gate"]: 98 | raise ValueError(f'Target must be either "act" or "up" or "gate", got {target}') 99 | self.target = target 100 | 101 | def _loss(self, x: torch.Tensor, **kwargs) -> Union[torch.Tensor, Dict[str, Any]]: 102 | act = kwargs[self.target] 103 | 104 | if self.act_transform is not None: 105 | act = self.act_transform(act) 106 | 107 | act = act.to(torch.float32) 108 | pred_out = self.predictor(x.type(torch.float32)) 109 | 110 | return self.loss(pred_out, act) 111 | 112 | def extra_repr(self) -> str: 113 | return f"loss={self.loss}, act_transform={self.act_transform}" 114 | 115 | 116 | def build_abstopk_cross_entropy_loss( 117 | predictor: nn.Module, 118 | down_layer: nn.Linear, 119 | up_layer: Optional[nn.Linear] = None, 120 | gate_layer: Optional[nn.Linear] = None, 121 | act_fn: Optional[Callable] = None, 122 | target: str = "act", 123 | k: Optional[int] = None, 124 | keep: Optional[float] = None, 125 | ) -> PredictorLoss: 126 | """ 127 | Factory function for building absolute cross-entropy loss functions based on the k (or keep%) largest activations. 128 | """ 129 | if k is None and keep is None: 130 | raise ValueError("k or keep must be specified.") 131 | 132 | if k is None: 133 | k = int(keep * down_layer.weight.shape[1]) 134 | 135 | binarize = nn.Sequential( 136 | Abs(), 137 | TopKMask(k=k), 138 | ) 139 | 140 | return ActivationLoss( 141 | predictor=predictor, 142 | act_transform=binarize, 143 | loss=binary_crossentropy, 144 | up_layer=up_layer, 145 | gate_layer=gate_layer, 146 | act_fn=act_fn, 147 | target=target, 148 | ) 149 | 150 | 151 | def build_absthreshold_cross_entropy_loss( 152 | predictor: nn.Module, 153 | threshold: float, 154 | up_layer: Optional[nn.Linear] = None, 155 | gate_layer: Optional[nn.Linear] = None, 156 | down_layer: Optional[nn.Linear] = None, 157 | act_fn: Optional[Callable] = None, 158 | target: str = "act", 159 | ) -> PredictorLoss: 160 | """ 161 | Factory function for building absolute cross-entropy loss functions based on a fixed threshold. 162 | """ 163 | binarize = nn.Sequential( 164 | Abs(), 165 | ThresholdMask(threshold=threshold), 166 | ) 167 | 168 | return ActivationLoss( 169 | predictor=predictor, 170 | act_transform=binarize, 171 | loss=binary_crossentropy, 172 | up_layer=up_layer, 173 | gate_layer=gate_layer, 174 | act_fn=act_fn, 175 | down_layer=down_layer, 176 | target=target, 177 | ) 178 | -------------------------------------------------------------------------------- /contextual_sparsity/evaluation/hooks/memory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from typing import Any, Dict 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from contextual_sparsity.evaluation.hooks.base import EvaluationHook 10 | from contextual_sparsity.hw_simulator.utils import ( 11 | calculate_footprint, 12 | get_dimensions_from_model, 13 | precision_to_bytes, 14 | ) 15 | from contextual_sparsity.utils.layer_names import ( 16 | block_id_to_layer_ids, 17 | get_block_ids, 18 | has_gate, 19 | layer_id_to_block_id, 20 | ) 21 | 22 | WEIGHT_DENSITY = "weight_density" 23 | MLP_DENSITY = "mlp_density" 24 | MEMORY = "memory" 25 | MLP_MEMORY = "mlp_memory" 26 | MB = 1000.0**2 27 | 28 | 29 | class Memory(EvaluationHook): 30 | """ 31 | Evaluate the memory usage based on the activation density. 
32 | """ 33 | 34 | metric_dims = {WEIGHT_DENSITY: 1, MEMORY: 1, MLP_MEMORY: 1, MLP_DENSITY: 1} 35 | 36 | def __init__(self, model_id: str, precision: Dict[str, int], sequence_length: int): 37 | super().__init__() 38 | self.model_id = model_id 39 | self.precision = precision_to_bytes(precision) 40 | self.dimensions = get_dimensions_from_model(model_id) 41 | self.sequence_length = sequence_length 42 | self.block_ids = get_block_ids(model_id, block_names="all") 43 | 44 | # Static memory in Bytes 45 | self.head_memory = calculate_footprint( 46 | layer_types=["lm_head"], 47 | precision=self.precision, 48 | dimensions=self.dimensions, 49 | seq_len=self.sequence_length, 50 | verbose=False, 51 | ) 52 | 53 | self.attention_memory = calculate_footprint( 54 | layer_types=["attention", "kv_cache"], 55 | precision=self.precision, 56 | dimensions=self.dimensions, 57 | seq_len=self.sequence_length, 58 | verbose=False, 59 | ) 60 | 61 | self.W_params = self.dimensions["intermediate_size"] * self.dimensions["hidden_size"] 62 | if has_gate(model_id): 63 | mlp_params = self.W_params * 3 64 | else: 65 | mlp_params = self.W_params * 2 66 | self.mlp_n_params = mlp_params 67 | self.mlp_base_memory = mlp_params * self.precision["mlp"] 68 | 69 | self.predictor_parameters = {} 70 | 71 | def collect_results(self, module, input, kwargs, output): 72 | return {WEIGHT_DENSITY: output.float().mean(-1).unsqueeze(-1)} 73 | 74 | def _post_process_batch( 75 | self, batch_stats: Dict[str, Dict[str, Any]] 76 | ) -> Dict[str, Dict[str, Any]]: 77 | # Determine the shape and device 78 | batch = next(iter(batch_stats.values()))[WEIGHT_DENSITY] 79 | batch_size = batch.shape 80 | device = batch.device 81 | 82 | total_mlp_memory = 0 83 | total_mlp_parameters = 0 84 | 85 | for block_id in self.block_ids: 86 | up_layer_id, down_layer_id, gate_layer_id = block_id_to_layer_ids( 87 | model_id=self.model_id, block_id=block_id 88 | ) 89 | 90 | # If the up, down or gate layers do not have a density value, we assume density=0 91 | if down_layer_id not in batch_stats: 92 | batch_stats[down_layer_id] = {WEIGHT_DENSITY: torch.zeros(batch_size).to(device)} 93 | if up_layer_id not in batch_stats: 94 | batch_stats[up_layer_id] = {WEIGHT_DENSITY: torch.zeros(batch_size).to(device)} 95 | if gate_layer_id is not None: 96 | if gate_layer_id not in batch_stats: 97 | batch_stats[gate_layer_id] = { 98 | WEIGHT_DENSITY: torch.zeros(batch_size).to(device) 99 | } 100 | 101 | # Compute the Memory usage for the active MLP weights 102 | mlp_memory = ( 103 | batch_stats[down_layer_id][WEIGHT_DENSITY] 104 | + batch_stats[up_layer_id][WEIGHT_DENSITY] 105 | ) 106 | if gate_layer_id is not None: 107 | mlp_memory = mlp_memory + batch_stats[gate_layer_id][WEIGHT_DENSITY] 108 | mlp_active_params = mlp_memory * self.W_params 109 | mlp_memory = mlp_active_params * self.precision["mlp"] 110 | 111 | # Consider the memory for the predictor 112 | predictor_active_params = self.predictor_parameters[block_id] 113 | predictor_memory = predictor_active_params * self.precision["predictors"] 114 | 115 | # Compute the memory percentage and total MLP memory and add them to the stats 116 | mlp_memory = mlp_memory + predictor_memory 117 | mlp_param = predictor_active_params + mlp_active_params 118 | 119 | mlp_param_percentage = mlp_param / self.mlp_n_params 120 | batch_stats[block_id] = { 121 | MLP_DENSITY: mlp_param_percentage, 122 | MLP_MEMORY: mlp_memory, 123 | } 124 | 125 | total_mlp_memory = total_mlp_memory + mlp_memory 126 | total_mlp_parameters = total_mlp_parameters + 
mlp_param 127 | 128 | # Compute the average MLP density 129 | mlp_density = total_mlp_parameters / self.mlp_n_params / len(self.block_ids) 130 | total_memory = total_mlp_memory + self.head_memory + self.attention_memory 131 | 132 | batch_stats["."] = { 133 | MLP_DENSITY: mlp_density, 134 | MLP_MEMORY: total_mlp_memory / MB, 135 | MEMORY: total_memory / MB, 136 | } 137 | return batch_stats 138 | 139 | def attach_to(self, model: nn.Module): 140 | if hasattr(model, "masking_hooks"): 141 | for masking_hook in model.masking_hooks: 142 | for layer_id in masking_hook.mask_cols_of: 143 | self._attach_to(masking_hook.masking_func, attached_to=layer_id) 144 | for layer_id in masking_hook.mask_rows_of: 145 | self._attach_to(masking_hook.masking_func, attached_to=layer_id) 146 | 147 | # Compute the memory usage for the predictor 148 | n_params = sum([param.numel() for param in masking_hook.parameters()]) 149 | block_id = layer_id_to_block_id( 150 | model_id=self.model_id, submodule_id=masking_hook.input_from 151 | ) 152 | self.predictor_parameters[block_id] = n_params 153 | -------------------------------------------------------------------------------- /contextual_sparsity/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import logging 5 | import os.path 6 | from typing import Any, List, Optional, Set, Tuple, Type, Union 7 | 8 | import torch 9 | from datasets import load_dataset 10 | from torch import nn 11 | from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel 12 | 13 | # A logger for this file 14 | log = logging.getLogger(__name__) 15 | 16 | 17 | def pairwise_disjoint(sets: List[Set]) -> bool: 18 | """ 19 | Checks if a collection of sets is pairwise disjoint. 20 | A collection of sets is pairwise disjoint if any two sets in the collection are disjoint. 21 | """ 22 | union = set().union(*sets) 23 | return len(union) == sum(map(len, sets)) 24 | 25 | 26 | def parse_dtype(dtype: Optional[Union[str, torch.dtype]]) -> Optional[torch.dtype]: 27 | """ 28 | Parse a dtype string to a torch.dtype if necessary. 
29 | """ 30 | if dtype is None: 31 | return None 32 | 33 | if isinstance(dtype, torch.dtype): 34 | torch_dtype = dtype 35 | elif isinstance(dtype, str): 36 | torch_dtype = getattr(torch, dtype) 37 | else: 38 | raise ValueError(f"Unsupported dtype {dtype}") 39 | 40 | assert isinstance(torch_dtype, torch.dtype) 41 | return torch_dtype 42 | 43 | 44 | def download_tokenizer( 45 | repo_id: str, 46 | download_dir: str, 47 | ): 48 | """ 49 | Download pre-trained tokenizer 50 | """ 51 | log.info(f"Loading a pretrained {repo_id} tokenizer from huggingface") 52 | tokenizer = AutoTokenizer.from_pretrained(repo_id, force_download=True) 53 | 54 | log.info(f"Tokenizer downloaded and stored to {download_dir}") 55 | tokenizer.save_pretrained(download_dir) 56 | 57 | 58 | def download_hf_model( 59 | repo_id: str, 60 | download_dir: str, 61 | max_num_downloads: int = 5, 62 | model_type: Type[PreTrainedModel] = AutoModelForCausalLM, 63 | ): 64 | """ 65 | Download pre-trained HuggingFace model 66 | """ 67 | model = None 68 | 69 | while max_num_downloads > 0: 70 | try: 71 | log.info(f"Loading a pretrained {repo_id} from huggingface with type {model_type}") 72 | model = model_type.from_pretrained( 73 | pretrained_model_name_or_path=repo_id, 74 | resume_download=True, 75 | trust_remote_code=True, 76 | low_cpu_mem_usage=True, 77 | device_map="cpu", 78 | torch_dtype="auto", 79 | ) 80 | max_num_downloads = 0 81 | except OSError: 82 | max_num_downloads -= 1 83 | if model is None: 84 | raise NotImplementedError(f"Could not download the model {repo_id}") 85 | 86 | log.info(f"Model {repo_id} loaded and stored in {download_dir}") 87 | model.save_pretrained(download_dir) 88 | 89 | 90 | def download_hf_model_and_tokenizer( 91 | repo_id: str, 92 | download_dir: str, 93 | max_num_downloads: int = 5, 94 | model_type: Optional[Type[PreTrainedModel]] = None, 95 | ): 96 | """ 97 | Download pre-trained HuggingFace model and tokenizer 98 | """ 99 | download_hf_model(repo_id, download_dir, max_num_downloads, model_type) 100 | download_tokenizer(repo_id, download_dir) 101 | 102 | 103 | def download_dataset( 104 | dataset_id: str, 105 | download_dir: str, 106 | name: Optional[str], 107 | split: Optional[str] = None, 108 | cache_dir: Optional[str] = None, 109 | ): 110 | """ 111 | Download dataset from huggingface 112 | """ 113 | log.info(f"Downloading the {dataset_id} dataset from huggingface") 114 | dataset = load_dataset(path=dataset_id, name=name, split=split, cache_dir=cache_dir) 115 | 116 | if name is not None: 117 | download_dir = os.path.join(download_dir, name) 118 | if split is not None: 119 | download_dir = os.path.join(download_dir, split) 120 | 121 | log.info(f"Dataset downloaded and stored in {download_dir}") 122 | dataset.save_to_disk(download_dir) 123 | 124 | 125 | def move_to_device(batch: Any, device: Union[str, torch.device]) -> Any: 126 | """ 127 | Move a batch to a specified device 128 | """ 129 | if isinstance(batch, dict): 130 | batch = {k: v.to(device) for k, v in batch.items()} 131 | elif isinstance(batch, torch.Tensor): 132 | batch = batch.to(device) 133 | elif isinstance(batch, list): 134 | batch = [v.to(device) for v in batch] 135 | else: 136 | raise NotImplementedError() 137 | return batch 138 | 139 | 140 | def cast_to(batch: Any, dtype: torch.dtype) -> Any: 141 | """ 142 | Cast a batch to a specified dtype 143 | """ 144 | if isinstance(batch, dict): 145 | batch = {k: v.type(dtype) for k, v in batch.items()} 146 | elif isinstance(batch, torch.Tensor): 147 | batch = batch.type(dtype) 148 | elif 
isinstance(batch, list): 149 | batch = [v.type(dtype) for v in batch] 150 | else: 151 | raise NotImplementedError() 152 | return batch 153 | 154 | 155 | def get_batch_size(batch: Any) -> int: 156 | """ 157 | Determine the batch size of a batch 158 | """ 159 | if isinstance(batch, dict): 160 | batch_size = next(iter(batch.values())).shape[0] 161 | elif isinstance(batch, torch.Tensor): 162 | batch_size = batch.shape[0] 163 | elif isinstance(batch, list): 164 | batch_size = next(iter(batch)).shape[0] 165 | else: 166 | raise NotImplementedError() 167 | return batch_size 168 | 169 | 170 | def split_gate_up_layer(gate_up_layer: nn.Linear) -> Tuple[nn.Linear, nn.Linear]: 171 | """ 172 | Split a gate-up linear layer into two linear layers 173 | """ 174 | # Make the two linear layers out of the one up_gate_layer 175 | w_gate, w_up = torch.chunk(gate_up_layer.weight, 2, dim=0) 176 | if gate_up_layer.bias is not None: 177 | b_gate, b_up = torch.chunk(gate_up_layer.bias, 2, dim=0) 178 | else: 179 | b_gate, b_up = None, None 180 | 181 | fc_gate = nn.Linear(w_gate.shape[1], w_gate.shape[0], bias=b_gate is not None) 182 | fc_gate.weight.data = w_gate.data 183 | if b_gate is not None: 184 | fc_gate.bias.data = b_gate.data 185 | 186 | fc_up = nn.Linear(w_up.shape[1], w_up.shape[0], bias=b_up is not None) 187 | fc_up.weight.data = w_up.data 188 | if b_up is not None: 189 | fc_up.bias.data = b_up.data 190 | 191 | return fc_gate, fc_up 192 | -------------------------------------------------------------------------------- /tests/contextual_sparsity/hw_simulator/test_cache.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import pytest 5 | import torch 6 | 7 | from contextual_sparsity.hw_simulator.cache import ( 8 | BeladyCache, 9 | LFUMaskFirstCache, 10 | LRUCache, 11 | ) 12 | 13 | 14 | @pytest.mark.parametrize( 15 | "cache, slot_counts, mask, expected_cache, expected_slot_counts, expectation", 16 | [ 17 | ( 18 | torch.tensor([1, 1, 1, 0, 0], dtype=torch.bool), 19 | torch.zeros(5, dtype=torch.int), 20 | torch.tensor([0, 0, 1, 1, 1], dtype=torch.bool), 21 | torch.tensor([0, 0, 1, 1, 1], dtype=torch.bool), 22 | torch.tensor([0, 0, 0, 0, 0], dtype=torch.int), 23 | None, 24 | ), # cache new ones and keep overlap 25 | ( 26 | torch.tensor([1, 1, 1, 0, 0], dtype=torch.bool), 27 | torch.tensor([0, 1, 2, 0, 0], dtype=torch.int), 28 | torch.tensor([0, 0, 1, 1, 0], dtype=torch.bool), 29 | torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool), 30 | torch.tensor([1, 0, 0, 0, 0], dtype=torch.int), 31 | None, 32 | ), # keep overlap even if not used recently, empty cache based on recent usage 33 | ( 34 | torch.zeros(5, dtype=torch.bool), 35 | torch.zeros(5, dtype=torch.int), 36 | torch.ones(6, dtype=torch.bool), 37 | None, 38 | None, 39 | "error", 40 | ), # mask larger than layer cache 41 | ], 42 | ) 43 | def test_cache_logic_lru( 44 | cache, slot_counts, mask, expected_cache, expected_slot_counts, expectation 45 | ): 46 | allow_mlp_streaming = False if expectation == "error" else True 47 | hw_cache = LRUCache( 48 | size_per_idx=1, 49 | precision=1, 50 | max_cache_size=3, 51 | max_index=len(cache), 52 | device="cpu", 53 | allow_mlp_streaming=allow_mlp_streaming, 54 | ) 55 | hw_cache.cache = cache 56 | hw_cache.slot_counts = slot_counts 57 | 58 | if expectation == "error": 59 | with pytest.raises(AssertionError): 60 | hw_cache.update(mask) 61 | else: 62 | hw_cache.update(mask) 63 | assert 
torch.equal(hw_cache.cache, expected_cache) 64 | assert torch.equal(hw_cache.slot_counts, expected_slot_counts) 65 | 66 | 67 | @pytest.mark.parametrize( 68 | "cache, slot_counts, mask, expected_cache, expected_slot_counts, expectation", 69 | [ 70 | ( 71 | torch.tensor([1, 1, 1, 0, 0], dtype=torch.bool), 72 | torch.ones(5, dtype=torch.int), 73 | torch.tensor([0, 0, 1, 1, 1], dtype=torch.bool), 74 | torch.tensor([0, 0, 1, 1, 1], dtype=torch.bool), 75 | torch.tensor([1, 1, 2, 2, 2], dtype=torch.int), 76 | None, 77 | ), # in equal hit-rate, new mask should be saved 78 | ( 79 | torch.tensor([1, 1, 1, 0, 0], dtype=torch.bool), 80 | torch.tensor([2, 3, 1, 1, 1], dtype=torch.int), 81 | torch.tensor([0, 0, 1, 1, 0], dtype=torch.bool), 82 | torch.tensor([0, 1, 1, 1, 0], dtype=torch.bool), 83 | torch.tensor([2, 3, 2, 2, 1], dtype=torch.int), 84 | None, 85 | ), # prioritize mask and then high hit-rates 86 | ( 87 | torch.zeros(5, dtype=torch.bool), 88 | torch.zeros(5, dtype=torch.int), 89 | torch.ones(6, dtype=torch.bool), 90 | None, 91 | None, 92 | "error", 93 | ), # mask larger than layer cache 94 | ], 95 | ) 96 | def test_cache_logic_mask_first_lfu( 97 | cache, slot_counts, mask, expected_cache, expected_slot_counts, expectation 98 | ): 99 | allow_mlp_streaming = False if expectation == "error" else True 100 | hw_cache = LFUMaskFirstCache( 101 | size_per_idx=1, 102 | precision=1, 103 | max_cache_size=3, 104 | max_index=len(cache), 105 | device="cpu", 106 | allow_mlp_streaming=allow_mlp_streaming, 107 | ) 108 | hw_cache.cache = cache.clone() 109 | hw_cache.slot_counts = slot_counts.clone() 110 | 111 | if expectation == "error": 112 | with pytest.raises(AssertionError): 113 | hw_cache.update(mask) 114 | else: 115 | hw_cache.update(mask) 116 | assert torch.equal(hw_cache.cache, expected_cache), ( 117 | hw_cache.cache, 118 | hw_cache.slot_counts, 119 | ) 120 | assert torch.equal(hw_cache.slot_counts, expected_slot_counts), ( 121 | hw_cache.cache, 122 | hw_cache.slot_counts, 123 | ) 124 | 125 | 126 | @pytest.mark.parametrize( 127 | "cache, prev_slot_counts, mask, expected_cache, next_slot_counts, expectation", 128 | [ 129 | ( 130 | torch.tensor([1, 1, 1, 0, 0], dtype=torch.bool), 131 | torch.ones(5, dtype=torch.int), 132 | torch.tensor([0, 0, 1, 1, 1], dtype=torch.bool), 133 | torch.tensor([0, 0, 1, 1, 1], dtype=torch.bool), 134 | torch.tensor([1, 1, 1, 1, 1], dtype=torch.int), 135 | None, 136 | ), # prioritize mask 137 | ( 138 | torch.tensor([1, 1, 1, 0, 0], dtype=torch.bool), 139 | torch.tensor([2, 3, 1, 1, 1], dtype=torch.int), 140 | torch.tensor([0, 0, 1, 1, 0], dtype=torch.bool), 141 | torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool), 142 | torch.tensor([2, 3, 1, 1, 1], dtype=torch.int), 143 | None, 144 | ), # prioritize mask and then short horizon 145 | ( 146 | torch.zeros(5, dtype=torch.bool), 147 | torch.zeros(5, dtype=torch.int), 148 | torch.ones(6, dtype=torch.bool), 149 | None, 150 | None, 151 | "error", 152 | ), # mask larger than layer cache 153 | ], 154 | ) 155 | def test_cache_logic_mask_first_belady( 156 | cache, prev_slot_counts, mask, expected_cache, next_slot_counts, expectation 157 | ): 158 | allow_mlp_streaming = False if expectation == "error" else True 159 | hw_cache = BeladyCache( 160 | size_per_idx=1, 161 | precision=1, 162 | max_cache_size=3, 163 | max_index=len(cache), 164 | device="cpu", 165 | allow_mlp_streaming=allow_mlp_streaming, 166 | ) 167 | hw_cache.cache = cache.clone() 168 | hw_cache.slot_counts = prev_slot_counts.clone() 169 | 170 | if expectation == "error": 171 | 
with pytest.raises(AssertionError): 172 | hw_cache.update(mask, next_slot_counts) 173 | else: 174 | hw_cache.update(mask, next_slot_counts) 175 | assert torch.equal(hw_cache.cache, expected_cache), ( 176 | hw_cache.cache, 177 | hw_cache.slot_counts, 178 | ) 179 | assert torch.equal(hw_cache.slot_counts, next_slot_counts), ( 180 | hw_cache.cache, 181 | hw_cache.slot_counts, 182 | ) 183 | -------------------------------------------------------------------------------- /tests/contextual_sparsity/hw_simulator/test_simulator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | import os 5 | 6 | import pytest 7 | import torch 8 | from hydra import compose, initialize 9 | from hydra.utils import instantiate 10 | 11 | from contextual_sparsity.hw_simulator.cache import BeladyCache 12 | from contextual_sparsity.hw_simulator.constants import MODEL_ID_TO_DIMS 13 | from contextual_sparsity.utils.layer_names import FC_DOWN, get_layer_id 14 | 15 | 16 | def test_dynamic_token_generation(tmpdir): 17 | device = f'{"cuda:0" if torch.cuda.is_available() else "cpu"}' 18 | 19 | model_id = "opt-350M" 20 | model_dims = MODEL_ID_TO_DIMS[model_id] 21 | keep = 0.5 22 | cache_strategy = "lru" 23 | with initialize(version_base="1.3", config_path="pkg://scripts/config"): 24 | overrides = [ 25 | "experiment=evaluate_llm", 26 | "+hw_simulator=default", 27 | "data=dummy", 28 | "masking_hooks=glu_pruning", 29 | f"masking_hooks.keep={keep}", 30 | f"dense_model={model_id}", 31 | "hw_simulator.sequence_length=2", 32 | "hw_simulator.prompt_length=1", 33 | f"hw_simulator.device={device}", 34 | f"hw_simulator.cache_strategy={cache_strategy}", 35 | ] 36 | conf = compose("config.yaml", overrides) 37 | os.chdir(tmpdir) 38 | 39 | dense_model = instantiate(conf.dense_model) 40 | masking_hooks = instantiate(conf.masking_hooks, dense_model=dense_model) 41 | 42 | hardware = instantiate(conf.hw_simulator, model=dense_model, masking_hooks=masking_hooks) 43 | 44 | layer_key = get_layer_id(model_id=model_id, layer_type=FC_DOWN, layer_name=1) 45 | mask = torch.zeros(model_dims["intermediate_size"], dtype=torch.bool, device=device) 46 | k = int(keep * model_dims["intermediate_size"]) 47 | mask[:k] = True 48 | 49 | # for generation of the first token, all necessary weights will be read from Flash, no matter cache strategy 50 | transfer_footprint = k * model_dims["hidden_size"] * hardware.precision["mlp"] 51 | first_token_time = transfer_footprint / hardware.flash_io_speed 52 | hardware.write_to_memory(layer_key=layer_key, cur_mask=mask) 53 | 54 | # first token generation time should be equal to moving selected weights of the selected layer from Flash 55 | torch.testing.assert_close(first_token_time, hardware.current_token_generation_dynamic) 56 | 57 | hardware.write_to_memory(layer_key=layer_key, cur_mask=mask) 58 | # For the second token, all weights are in DRAM already 59 | second_token_time = transfer_footprint / hardware.dram_io_speed 60 | torch.testing.assert_close( 61 | first_token_time + second_token_time, 62 | hardware.current_token_generation_dynamic, 63 | ) 64 | 65 | hardware._reset() 66 | assert ( 67 | hardware.current_token_generation_dynamic == 0 68 | ), "token generation time should be 0 after reset" 69 | 70 | 71 | def test_static_elapsed_time(tmpdir): 72 | device = f'{"cuda:0" if torch.cuda.is_available() else "cpu"}' 73 | 74 | model_id = "opt-350M" 75 | with initialize(version_base="1.3", 
config_path="pkg://scripts/config"): 76 | overrides = [ 77 | "experiment=evaluate_llm", 78 | "+hw_simulator=default", 79 | "data=dummy", 80 | "masking_hooks=glu_pruning", 81 | "masking_hooks.keep=1.0", 82 | f"dense_model={model_id}", 83 | "hw_simulator.sequence_length=2", 84 | "hw_simulator.prompt_length=1", 85 | f"hw_simulator.device={device}", 86 | "hw_simulator.dram.layers_dynamic=[]", 87 | ] 88 | conf = compose("config.yaml", overrides) 89 | os.chdir(tmpdir) 90 | 91 | dense_model = instantiate(conf.dense_model) 92 | masking_hooks = instantiate(conf.masking_hooks, dense_model=dense_model) 93 | 94 | hardware = instantiate(conf.hw_simulator, model=dense_model, masking_hooks=masking_hooks) 95 | 96 | # prompt encoding and static token generation should increase corresponding elapsed time and NOT reset (currently it is fixed) 97 | assert hardware.current_prompt_encoding > 0 98 | assert hardware.current_token_generation_fixed > 0 99 | hardware._reset() 100 | assert hardware.current_prompt_encoding > 0 101 | assert hardware.current_token_generation_fixed > 0 102 | 103 | 104 | @pytest.mark.parametrize( 105 | "seq_mask, expected_cache", 106 | [ 107 | ( 108 | [ 109 | [0, 0, 1, 1, 0], 110 | [1, 1, 0, 0, 0], 111 | [0, 1, 0, 0, 1], 112 | [0, 1, 1, 0, 0], 113 | [0, 1, 1, 0, 0], 114 | [1, 1, 1, 1, 1], 115 | ], 116 | [0, 1, 1, 0, 1], 117 | ), 118 | ( 119 | [ 120 | [0, 0, 1, 1, 0], 121 | [0, 1, 0, 0, 1], 122 | [1, 1, 0, 0, 0], 123 | [1, 0, 0, 0, 1], 124 | [0, 0, 0, 1, 1], 125 | [1, 1, 1, 1, 1], 126 | ], 127 | [0, 1, 0, 1, 1], 128 | ), 129 | ], 130 | ) 131 | def test_playback_seq_masking(tmpdir, seq_mask, expected_cache): 132 | device = f'{"cuda:0" if torch.cuda.is_available() else "cpu"}' 133 | 134 | model_id = "opt-350M" 135 | with initialize(version_base="1.3", config_path="pkg://scripts/config"): 136 | overrides = [ 137 | "experiment=evaluate_llm", 138 | "+hw_simulator=default", 139 | "data=dummy", 140 | "masking_hooks=glu_pruning", 141 | "masking_hooks.keep=1.0", 142 | f"dense_model={model_id}", 143 | f"hw_simulator.sequence_length={len(seq_mask)}", 144 | "hw_simulator.prompt_length=1", 145 | f"hw_simulator.device={device}", 146 | "hw_simulator.cache_strategy=belady", 147 | ] 148 | conf = compose("config.yaml", overrides) 149 | os.chdir(tmpdir) 150 | 151 | dense_model = instantiate(conf.dense_model) 152 | masking_hooks = instantiate(conf.masking_hooks, dense_model=dense_model) 153 | 154 | hardware = instantiate(conf.hw_simulator, model=dense_model, masking_hooks=masking_hooks) 155 | 156 | hw_cache = BeladyCache( 157 | size_per_idx=1, precision=1, max_cache_size=3, max_index=5, device=device 158 | ) 159 | 160 | layer_key = get_layer_id(model_id=model_id, layer_type=FC_DOWN, layer_name=1) 161 | 162 | hardware.seq_mask[layer_key] = torch.tensor(seq_mask, dtype=torch.bool, device=device) 163 | hardware.caches[layer_key] = hw_cache 164 | hardware.layer_call_counter[layer_key] = 5 165 | hardware._counter_forward_calls[layer_key] = 5 166 | 167 | hardware._reset() 168 | 169 | assert all(torch.eq(hw_cache.cache, torch.tensor(expected_cache, device=device))), ( 170 | hw_cache.cache, 171 | expected_cache, 172 | ) 173 | -------------------------------------------------------------------------------- /contextual_sparsity/utils/tokenizers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
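# Illustrative usage sketch (the model path is a placeholder and the dataset choice is only
# an example):
#
#   from datasets import load_dataset
#   tokenizer = load_tokenizer("path/to/local/model")
#   raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
#   blocks = tokenize_for_language_modeling(tokenizer, raw, sequence_length=1024)
#   # each element of `blocks` provides `input_ids`, `attention_mask` and `labels`
#   # of length `sequence_length`, with overlapping tokens masked out in `labels`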
3 | 
4 | import logging
5 | from typing import Any, Dict, List, Optional, Type, Union
6 | 
7 | import torch
8 | from datasets import Dataset
9 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase
10 | 
11 | from contextual_sparsity.utils.misc import parse_dtype
12 | 
13 | # A logger for this file
14 | log = logging.getLogger(__name__)
15 | 
16 | 
17 | def tokenize_for_language_modeling(
18 |     tokenizer: PreTrainedTokenizer,
19 |     data: Dataset,
20 |     sequence_length: int = 1024,
21 |     sliding_window_length: int = -1,
22 |     data_key: str = "text",
23 |     batch_size: Optional[int] = None,
24 |     num_workers: int = 1,
25 |     separator: str = "\n\n",
26 |     keep_in_memory: bool = True,
27 | ) -> Dataset:
28 |     """Tokenization for perplexity evaluation.
29 |     Tokenize all sentences in ``data`` and concatenate them to form blocks
30 |     of size `sequence_length`. Optionally, the concatenation can be done in a sliding
31 |     window approach such that overlapping tokens are masked in the ground-truth labels.
32 |     This is inspired by the Hugging Face guide on computing the perplexity of fixed-length models:
33 |     https://huggingface.co/docs/transformers/perplexity
34 | 
35 |     Args:
36 |         tokenizer: Pretrained tokenizer
37 |         data: Huggingface dataset
38 |         sequence_length: Length of each block in the final dataset
39 |         sliding_window_length: Length of the sliding window. If negative, this will be set to
40 |             `sequence_length`; i.e. blocks will be non-overlapping
41 |         batch_size: Batch size for processing. If None, this will be set to the
42 |             length of the dataset, i.e. proper concatenation. This should be preferred
43 |             for small datasets.
44 |         num_workers: Number of processes to use in map functions.
45 |         data_key: Key that indexes the input sentences in ``data``
46 |         separator: String separator used to join/concat the input sentences
47 |         keep_in_memory: Keep the dataset in memory instead of writing it to a cache file.
48 | Returns: 49 | Dataset: A dataset containing: 50 | * `input_ids`: A block of token of size `sequence_length`, possibly overlapping 51 | with neighboring blocks when `striding_window_length` < `sequence_length` 52 | * `attention_mask`: Corresponding attention mask 53 | * `labels`: Corresponding ground-truth labels (*not* shifted) where tokens 54 | overlapping between windows are masked 55 | """ 56 | # Tokenize the dataset 57 | try: 58 | tokenized_data = data.map( 59 | lambda samples: tokenizer(separator.join(samples[data_key])), 60 | remove_columns=[data_key], 61 | batched=True, 62 | batch_size=batch_size, 63 | num_proc=num_workers, 64 | keep_in_memory=keep_in_memory, 65 | ) 66 | except: 67 | tokenized_data = data.map( 68 | lambda samples: tokenizer(separator.join(samples[data_key])), 69 | remove_columns=[data_key], 70 | batched=True, 71 | batch_size=batch_size, 72 | keep_in_memory=keep_in_memory, 73 | ) 74 | tokenized_data.set_format(type="torch", columns=["input_ids", "attention_mask"]) 75 | 76 | # Concat to get windows of size context length 77 | if sliding_window_length < 0: 78 | sliding_window_length = sequence_length 79 | assert 1 <= sliding_window_length <= sequence_length 80 | 81 | def __get_sliding_windows__(tokenized_samples: Dict[str, Any]) -> Dict[str, Any]: 82 | out: Dict[str, List[Any]] = { 83 | "input_ids": [], 84 | "attention_mask": [], 85 | "labels": [], 86 | } 87 | for start in range(0, len(tokenized_samples["input_ids"]), sliding_window_length): 88 | # if the last window is too small, all the tokens are already 89 | # covered by previous windows, and we can ignore this sample 90 | if ( 91 | len(tokenized_samples["input_ids"][start:]) 92 | <= sliding_window_length 93 | < sequence_length 94 | ): 95 | continue 96 | 97 | # collect a window of size context length 98 | for key in ["input_ids", "attention_mask"]: 99 | out[key].append(tokenized_samples[key][start : start + sequence_length]) 100 | out["labels"].append(out["input_ids"][-1].clone()) 101 | 102 | # pad if smaller than context length 103 | if len(out["input_ids"][-1]) < sequence_length: 104 | padding = sequence_length - len(out["input_ids"][-1]) 105 | out["input_ids"][-1] = torch.nn.functional.pad( 106 | out["input_ids"][-1], (0, padding), value=0 107 | ) 108 | out["attention_mask"][-1] = torch.nn.functional.pad( 109 | out["attention_mask"][-1], (0, padding), value=0 110 | ) 111 | out["labels"][-1] = torch.nn.functional.pad( 112 | out["labels"][-1], (0, padding), value=-100 113 | ) 114 | 115 | # ignore loss for tokens overlapping in the previous window 116 | if start > 0: 117 | out["labels"][-1][:-sliding_window_length] = -100 118 | return out 119 | 120 | assert ( 121 | batch_size is None or batch_size >= sequence_length 122 | ), "batch size too small, all generated sequences will have trailing zero-padding!" 
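    # Worked example (illustrative): with sequence_length=4 and sliding_window_length=2,
    # six tokens [t0 .. t5] produce two windows:
    #   input_ids [t0 t1 t2 t3] -> labels [  t0   t1 t2 t3]
    #   input_ids [t2 t3 t4 t5] -> labels [-100 -100 t4 t5]
    # (the window starting at t4 is skipped because its tokens are already covered),
    # so every token contributes to the loss exactly once.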
123 | 124 | # gather 125 | try: 126 | tokenized_data = tokenized_data.map( 127 | __get_sliding_windows__, 128 | batched=True, 129 | batch_size=batch_size or len(tokenized_data), 130 | num_proc=num_workers, 131 | ) 132 | except: 133 | tokenized_data = tokenized_data.map( 134 | __get_sliding_windows__, 135 | batched=True, 136 | batch_size=batch_size or len(tokenized_data), 137 | ) 138 | assert tokenized_data[0]["input_ids"].shape[0] == sequence_length, ( 139 | tokenized_data[0]["input_ids"].shape[0], 140 | sequence_length, 141 | ) 142 | 143 | return tokenized_data 144 | 145 | 146 | def load_tokenizer( 147 | pretrained_model_path: str, 148 | use_fast_tokenizer: bool = True, 149 | dtype: Optional[Union[str, torch.dtype]] = None, 150 | tokenizer_type: Type[PreTrainedTokenizerBase] = AutoTokenizer, 151 | ) -> PreTrainedTokenizerBase: 152 | """ 153 | Load the specified pretrained tokenizer. 154 | """ 155 | 156 | if dtype is None: 157 | torch_dtype = torch.float16 158 | else: 159 | # Parse the data type 160 | torch_dtype = parse_dtype(dtype) 161 | 162 | tokenizer = tokenizer_type.from_pretrained( 163 | pretrained_model_path, 164 | use_fast=use_fast_tokenizer, 165 | torch_dtype=torch_dtype, 166 | local_files_only=True, 167 | ) 168 | if tokenizer.pad_token_id is None: 169 | tokenizer.pad_token_id = tokenizer.vocab_size - 1 170 | 171 | return tokenizer 172 | -------------------------------------------------------------------------------- /contextual_sparsity/utils/layer_names.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from typing import Dict, List, Optional, Tuple, Union 5 | 6 | LAYERS_CONTAINER = "layer_container" 7 | N_LAYERS = "n_layers" 8 | 9 | FC_UP = "fc_up" 10 | FC_DOWN = "fc_down" 11 | FC_ACT = "fc_act" 12 | FC_GATE = "fc_gate" 13 | MLP = "mlp" 14 | 15 | LAYER_TYPES = [FC_UP, FC_DOWN, FC_ACT, FC_GATE, MLP] 16 | 17 | MODEL_MAPS: Dict[str, Dict[str, Union[int, str]]] = { 18 | "dummy": { 19 | N_LAYERS: 2, 20 | LAYERS_CONTAINER: "layers", 21 | FC_UP: "up", 22 | FC_ACT: "activation_fn", 23 | FC_DOWN: "down", 24 | MLP: "", 25 | }, 26 | "opt-350M": { 27 | N_LAYERS: 24, 28 | LAYERS_CONTAINER: "model.decoder.layers", 29 | FC_UP: "fc1", 30 | FC_ACT: "activation_fn", 31 | FC_DOWN: "fc2", 32 | MLP: "", 33 | }, 34 | "llama-v3-8B": { 35 | N_LAYERS: 32, 36 | LAYERS_CONTAINER: "model.layers", 37 | FC_UP: "mlp.up_proj", 38 | FC_GATE: "mlp.gate_proj", 39 | FC_ACT: "mlp.act_fn", 40 | FC_DOWN: "mlp.down_proj", 41 | MLP: "mlp", 42 | }, 43 | "phi-3-medium": { 44 | N_LAYERS: 40, 45 | LAYERS_CONTAINER: "model.layers", 46 | FC_UP: "mlp.gate_up_proj", 47 | FC_GATE: "mlp.gate_up_proj", 48 | FC_ACT: "mlp.activation_fn", 49 | FC_DOWN: "mlp.down_proj", 50 | MLP: "mlp", 51 | }, 52 | "phi-3-mini": { 53 | N_LAYERS: 32, 54 | LAYERS_CONTAINER: "model.layers", 55 | FC_UP: "mlp.gate_up_proj", 56 | FC_GATE: "mlp.gate_up_proj", 57 | FC_ACT: "mlp.activation_fn", 58 | FC_DOWN: "mlp.down_proj", 59 | MLP: "mlp", 60 | }, 61 | "turbosparse-mistral": { 62 | N_LAYERS: 32, 63 | LAYERS_CONTAINER: "model.layers", 64 | FC_UP: "mlp.up_proj", 65 | FC_GATE: "mlp.gate_proj", 66 | FC_ACT: "mlp.act_fn", 67 | FC_DOWN: "mlp.down_proj", 68 | MLP: "mlp", 69 | }, 70 | "mistral-v01-7B": { 71 | N_LAYERS: 32, 72 | LAYERS_CONTAINER: "model.layers", 73 | FC_UP: "mlp.up_proj", 74 | FC_GATE: "mlp.gate_proj", 75 | FC_ACT: "mlp.act_fn", 76 | FC_DOWN: "mlp.down_proj", 77 | MLP: "mlp", 78 | }, 79 | } 80 | 81 | 82 | def 
block_number_from_id(model_id: str, submodule_id: str) -> int: 83 | """ 84 | Convert a string corresponding to a block ID to a block number. 85 | """ 86 | prefix = MODEL_MAPS[model_id][LAYERS_CONTAINER] 87 | return int(submodule_id.replace(prefix, "").split(".")[1]) 88 | 89 | 90 | def layer_id_to_block_id(model_id: str, submodule_id: str) -> str: 91 | """ 92 | Convert a string corresponding to a layer ID to a block number. 93 | """ 94 | block_number = block_number_from_id(model_id, submodule_id) 95 | return get_block_id(model_id=model_id, block_name=block_number) 96 | 97 | 98 | def cast_layer_names_to_int_list(model_id: str, layer_names: Union[int, List[int]]) -> List[int]: 99 | """ 100 | Convert a list of block numbers, or the keyword 'none' or 'all' into a list of integers 101 | representing the block numbers. 102 | """ 103 | if isinstance(layer_names, str): 104 | if layer_names == "all": 105 | layer_names = range(MODEL_MAPS[model_id][N_LAYERS]) 106 | elif layer_names == "none": 107 | layer_names = [] 108 | else: 109 | raise ValueError( 110 | "'layers_to_sparsify' must be one of 'all', 'none', int or list of ints" 111 | ) 112 | elif isinstance(layer_names, int): 113 | layer_names = [layer_names] 114 | return layer_names 115 | 116 | 117 | def get_block_id(model_id: str, block_name: int) -> str: 118 | """ 119 | Convert a block number to a block ID. 120 | """ 121 | 122 | block_id = ".".join( 123 | [ 124 | MODEL_MAPS[model_id][LAYERS_CONTAINER], 125 | str(block_name), 126 | ] 127 | ) 128 | 129 | return block_id 130 | 131 | 132 | def block_id_to_mlp_id(model_id: str, block_id: str) -> str: 133 | """ 134 | Convert a block number to a MLP ID. 135 | """ 136 | mlp_id = MODEL_MAPS[model_id][MLP] 137 | return block_id if mlp_id == "" else ".".join([block_id, mlp_id]) 138 | 139 | 140 | def block_id_to_layer_ids(model_id: str, block_id: str) -> Tuple[str, str, Optional[str]]: 141 | """ 142 | Convert a block number to a layer ID. 143 | """ 144 | down_name = MODEL_MAPS[model_id][FC_DOWN] 145 | up_name = MODEL_MAPS[model_id][FC_UP] 146 | 147 | down_id = ".".join([block_id, down_name]) 148 | up_id = ".".join([block_id, up_name]) 149 | gate_id = None 150 | 151 | if has_gate(model_id): 152 | gate_name = MODEL_MAPS[model_id][FC_GATE] 153 | gate_id = ".".join([block_id, gate_name]) 154 | 155 | return up_id, down_id, gate_id 156 | 157 | 158 | def has_gate(model_id: str): 159 | """ 160 | Check if a model has a gate layer in the MLPs 161 | """ 162 | return FC_GATE in MODEL_MAPS[model_id] 163 | 164 | 165 | def get_block_ids(model_id: str, block_names: Union[int, List[int], str]) -> List[str]: 166 | """ 167 | Convert a list of block numbers, or the keyword 'none' or 'all' into a list of block IDs 168 | """ 169 | if isinstance(block_names, str): 170 | if block_names == "all": 171 | n_layers: int = MODEL_MAPS[model_id][N_LAYERS] 172 | block_names = range(n_layers) 173 | else: 174 | raise ValueError(f"'{block_names}' must be one of 'all', int or list of ints") 175 | block_names = cast_layer_names_to_int_list(model_id, block_names) 176 | return [get_block_id(model_id=model_id, block_name=block_name) for block_name in block_names] 177 | 178 | 179 | def get_layer_id(model_id: str, layer_type: str, layer_name: int) -> str: 180 | """ 181 | Convert a layer number to a layer ID. 
182 | """ 183 | if layer_type not in LAYER_TYPES or layer_type is None: 184 | raise ValueError(f"Layer {layer_type} not found in {LAYER_TYPES}") 185 | if model_id not in MODEL_MAPS: 186 | raise ValueError(f"Model {model_id} not found in {MODEL_MAPS.keys()}") 187 | if layer_type not in MODEL_MAPS[model_id]: 188 | raise ValueError(f"Layer {layer_type} not specified in {MODEL_MAPS[model_id].keys()}") 189 | if layer_name > MODEL_MAPS[model_id][N_LAYERS] or layer_name < 0: 190 | raise ValueError( 191 | f"Layer {layer_name} out of bounds for net {model_id}, maximum is {MODEL_MAPS[model_id][N_LAYERS]}" 192 | ) 193 | 194 | layer_name = ".".join([get_block_id(model_id, layer_name), MODEL_MAPS[model_id][layer_type]]) 195 | return layer_name 196 | 197 | 198 | def get_layer_ids( 199 | model_id: str, layer_type: str, layer_names: Union[str, List[int], int] 200 | ) -> List[str]: 201 | """ 202 | Convert a list of block numbers, or the keyword 'none' or 'all' into a list of layer IDs 203 | """ 204 | layer_names = cast_layer_names_to_int_list(model_id, layer_names) 205 | 206 | return [ 207 | get_layer_id(model_id=model_id, layer_type=layer_type, layer_name=layer_names) 208 | for layer_names in layer_names 209 | ] 210 | -------------------------------------------------------------------------------- /contextual_sparsity/hw_simulator/cache.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from abc import ABC, abstractmethod 5 | from typing import Optional, Tuple 6 | 7 | import torch 8 | 9 | 10 | class HardwareCache(ABC): 11 | """ 12 | This class is used to simulate read from memory, write to memory and memory management behavior of the hardware. 13 | """ 14 | 15 | def __init__( 16 | self, 17 | size_per_idx, 18 | precision, 19 | max_cache_size, 20 | max_index, 21 | device, 22 | allow_mlp_streaming=True, 23 | ): 24 | self.size_per_idx = size_per_idx 25 | self.precision = precision 26 | self.max_cache_size = max_cache_size 27 | self.max_index = max_index 28 | self.device = device 29 | self.allow_mlp_streaming = allow_mlp_streaming 30 | self.cache = torch.zeros(self.max_index, dtype=torch.bool, device=self.device) 31 | self.slot_counts = torch.zeros(self.max_index, dtype=torch.int, device=self.device) 32 | 33 | def update(self, mask: torch.Tensor, slot_count: Optional[torch.Tensor] = None): 34 | n_to_evict, old_inactive, old_active, new_active = self.get_n_to_evict(mask=mask) 35 | # evict from cache 36 | mask = self.evict( 37 | mask=mask, 38 | n_to_evict=n_to_evict, 39 | old_inactive=old_inactive, 40 | slot_count=slot_count, 41 | ) 42 | # load to cache 43 | if mask.sum() <= self.max_cache_size: 44 | self.cache[mask] = 1 45 | else: 46 | # There are more neurons to load than space in DRAM. 
Randomly select a subset to load to DRAM 47 | # (to avoid bias in the cache eviction strategy towards certain neurons) 48 | assert self.allow_mlp_streaming, (mask.sum(), self.max_cache_size) 49 | 50 | # Select random subset of neurons among the one in mask and not in cache 51 | n_to_load = self.max_cache_size - old_active.sum() 52 | randperm = torch.randperm(new_active.sum(dim=-1)) 53 | idx_new_active = torch.where(new_active)[0] 54 | idx_selected = idx_new_active[randperm[:n_to_load]] 55 | 56 | # Check cache integrity and load selected neurons 57 | assert torch.all(self.cache[old_active]) 58 | assert torch.all(self.cache[~old_active] == False) 59 | self.cache[idx_selected] = True 60 | 61 | @abstractmethod 62 | def evict(self, *args, **kwargs): 63 | raise NotImplementedError("Child classes should have an eviction method.") 64 | 65 | def get_n_to_evict( 66 | self, mask: torch.Tensor 67 | ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: 68 | assert ( 69 | self.allow_mlp_streaming or mask.sum() <= self.max_cache_size 70 | ), f"cache capacity exceeded! {mask.sum()}/{self.max_cache_size}" 71 | old_inactive = ~mask & self.cache 72 | old_active = mask & self.cache 73 | new_active = mask & ~self.cache 74 | n_to_evict = min(old_inactive.sum(), mask.sum() + old_inactive.sum() - self.max_cache_size) 75 | return n_to_evict, old_inactive, old_active, new_active 76 | 77 | def get_usage_in_bytes(self) -> float: 78 | usage_in_bytes = self.cache.sum().item() * self.size_per_idx * self.precision 79 | return usage_in_bytes 80 | 81 | def get_current_io_division(self, mask: torch.Tensor) -> Tuple[float, float]: 82 | flash_io = ( 83 | (mask & ~self.cache).sum().item() * self.size_per_idx * self.precision 84 | ) # reading from Flash 85 | dram_io = ( 86 | (mask & self.cache).sum().item() * self.size_per_idx * self.precision 87 | ) # reading from DRAM 88 | 89 | return flash_io, dram_io 90 | 91 | def get_cache_hit_rate(self, mask: torch.Tensor) -> float: 92 | assert mask.ndim == 1, mask.shape 93 | if mask.sum().item() == 0: 94 | return 1.0 95 | return (mask & self.cache).sum().item() / mask.sum().item() 96 | 97 | 98 | class NotCache(HardwareCache): 99 | # evict least recently used cache slots 100 | 101 | def __init__(self, *args, **kwargs): 102 | super().__init__(*args, **kwargs) 103 | 104 | def update(self, *args, **kwargs): 105 | pass 106 | 107 | def evict(self, *args, **kwargs): 108 | pass 109 | 110 | def get_current_io_division(self, mask: torch.Tensor) -> Tuple[float, float]: 111 | dram_io = 0 112 | flash_io = mask.sum().item() * self.size_per_idx * self.precision 113 | return flash_io, dram_io 114 | 115 | 116 | class LRUCache(HardwareCache): 117 | # evict least recently used cache slots 118 | 119 | def __init__(self, *args, **kwargs): 120 | super().__init__(*args, **kwargs) 121 | 122 | def evict( 123 | self, mask: torch.Tensor, n_to_evict: int, old_inactive: torch.Tensor, **kwargs 124 | ) -> torch.Tensor: 125 | # update slot count 126 | self.slot_counts[mask] = 0 127 | self.slot_counts[old_inactive] += 1 128 | 129 | # evict from cache 130 | if n_to_evict > 0: 131 | if n_to_evict == old_inactive.sum(): 132 | self.cache[old_inactive] = 0 133 | self.slot_counts[old_inactive] = 0 134 | else: 135 | idx = torch.topk(self.slot_counts, k=n_to_evict)[1] 136 | self.cache[idx] = 0 137 | self.slot_counts[idx] = 0 138 | 139 | return mask 140 | 141 | 142 | class LFUMaskFirstCache(HardwareCache): 143 | # prioritize caching the current mask, evict the rest based on frequency of usage 144 | 145 | def __init__(self, 
*args, **kwargs): 146 | super().__init__(*args, **kwargs) 147 | 148 | def evict( 149 | self, mask: torch.Tensor, n_to_evict: int, old_inactive: torch.Tensor, **kwargs 150 | ) -> torch.Tensor: 151 | # update slot count 152 | self.slot_counts[mask] += 1 153 | 154 | # evict from cache 155 | if n_to_evict > 0: 156 | if n_to_evict == old_inactive.sum(): 157 | self.cache[old_inactive] = 0 158 | else: 159 | idx = torch.topk(-self.slot_counts[old_inactive], k=n_to_evict)[1] 160 | orig_idx = torch.nonzero(old_inactive)[idx] 161 | self.cache[orig_idx] = 0 162 | 163 | return mask 164 | 165 | 166 | class BeladyCache(HardwareCache): 167 | # Oracle algorithm which is the performance upperbound for LRU and LFU 168 | 169 | def __init__(self, *args, **kwargs): 170 | super().__init__(*args, **kwargs) 171 | 172 | def evict( 173 | self, 174 | mask: torch.Tensor, 175 | n_to_evict: int, 176 | old_inactive: torch.Tensor, 177 | slot_count: torch.Tensor, 178 | **kwargs, 179 | ) -> torch.Tensor: 180 | # evict from cache 181 | if n_to_evict > 0: 182 | if n_to_evict == old_inactive.sum(): 183 | self.cache[old_inactive] = 0 184 | else: 185 | idx = torch.topk(self.slot_counts[old_inactive], k=n_to_evict)[1] 186 | orig_idx = torch.nonzero(old_inactive)[idx] 187 | self.cache[orig_idx] = 0 188 | 189 | # update slot count 190 | self.slot_counts = slot_count 191 | 192 | return mask 193 | 194 | 195 | cache_strategy_to_class = { 196 | "no_cache": NotCache, 197 | "lru": LRUCache, 198 | "lfu": LFUMaskFirstCache, 199 | "belady": BeladyCache, 200 | } 201 | -------------------------------------------------------------------------------- /contextual_sparsity/scripts/compute_activations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 
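# Illustrative usage sketch (the layer path is an example for opt-350M; the dataloader,
# model and preprocess function are placeholders):
#
#   acts = compute_activations(
#       dataloader=dataloader,
#       dtype="float16",
#       preprocess_batch=None,
#       dense_model=dense_model,
#       activation_ids=["model.decoder.layers.0.fc2.input"],
#       memory="cpu",
#   )
#   # acts["model.decoder.layers.0.fc2.input"] then holds one activation tensor per sample,
#   # shaped roughly (dataset_size, sequence_length, hidden_size)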
3 | 4 | import logging 5 | import os 6 | from functools import partial 7 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 8 | 9 | import h5py 10 | import numpy as np 11 | import torch 12 | from hydra.utils import instantiate 13 | from omegaconf import DictConfig 14 | from torch.utils.data import DataLoader 15 | from tqdm.auto import tqdm 16 | 17 | from contextual_sparsity.utils.layer_names import FC_DOWN, FC_UP, get_layer_ids 18 | from contextual_sparsity.utils.misc import parse_dtype 19 | from contextual_sparsity.utils.stats import StopComputation, collect_activations 20 | 21 | log = logging.getLogger(__name__) 22 | DATAFILE = "activations.h5" 23 | MEMORY_OPTIONS = ["cpu", "cuda", "disk"] 24 | 25 | 26 | def collect_input_and_output( 27 | module, 28 | args, 29 | kwargs, 30 | out, 31 | layer_name, 32 | data, 33 | store_input, 34 | store_output, 35 | dtype, 36 | stop_computation=False, 37 | ): 38 | """ 39 | Hook function used to store input/outputs of each layer 40 | """ 41 | batch_size = out.shape[0] 42 | 43 | # Collect the inputs of the layer 44 | if store_input: 45 | i = args[0].cpu() 46 | i = i.type(dtype) 47 | i = i.view(batch_size, -1, i.shape[-1]) 48 | data[f"{layer_name}.input"] = i 49 | 50 | # Collect the outputs of the layer 51 | if store_output: 52 | o = out.cpu() 53 | o = o.type(dtype) 54 | o = o.view(batch_size, -1, o.shape[-1]) 55 | data[f"{layer_name}.output"] = o 56 | 57 | if stop_computation: 58 | # skip all computation once the last value is saved 59 | raise StopComputation() 60 | 61 | 62 | def output_to_next_input( 63 | out: Any, in_args: List[Any], in_kwargs: Dict[str, Any] 64 | ) -> Tuple[List[Any], Dict[str, Any]]: 65 | return [out[0]], in_kwargs 66 | 67 | 68 | def compute_activations( 69 | dataloader: DataLoader, 70 | dtype: Union[str, torch.dtype], 71 | preprocess_batch: Optional[Callable], 72 | dense_model: torch.nn.Module, 73 | activation_ids: List[str], 74 | memory: str, 75 | stop_computation: bool = True, 76 | ) -> Dict[str, Union[np.ndarray, torch.Tensor]]: 77 | """ 78 | compute all the specified activations for a given model and dataloader. Activations are cast to the specified dtype 79 | and stored to disk, gpu, or cpu memory. 80 | """ 81 | assert memory in MEMORY_OPTIONS 82 | 83 | batch_size = dataloader.batch_size 84 | # Parse the data type 85 | dtype = parse_dtype(dtype) 86 | layers = {} 87 | layer_id = "" 88 | 89 | # Consider all the activations ids and determine for each layer if we collect only inputs, outputs or both 90 | for activation_id in activation_ids: 91 | layer_id = ".".join(activation_id.split(".")[:-1]) 92 | io = activation_id.split(".")[-1] 93 | assert dense_model.get_submodule(layer_id) is not None, f"{layer_id} not found in net." 94 | assert io in [ 95 | "input", 96 | "output", 97 | ], f"{activation_id} is not a valid activation id." 
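        # e.g. for opt-350M, activation_id "model.decoder.layers.0.fc2.input" splits into
        # layer_id "model.decoder.layers.0.fc2" and io "input"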
98 | if layer_id not in layers: 99 | layers[layer_id] = {"store_input": False, "store_output": False} 100 | layers[layer_id][f"store_{io}"] = True 101 | layers[layer_id]["dtype"] = dtype 102 | 103 | if len(layers) > 0: 104 | layers[layer_id]["stop_computation"] = stop_computation 105 | 106 | # Consider an appropriate hook based on if we are interested in storing only input/output or both 107 | collection_funcs = {} 108 | for activation_id in activation_ids: 109 | layer_id = ".".join(activation_id.split(".")[:-1]) 110 | collection_funcs[layer_id] = partial(collect_input_and_output, **layers[layer_id]) 111 | 112 | dense_model.eval() 113 | 114 | if memory == "disk": 115 | log.info(f"Storing the activations in {os.getcwd()}") 116 | activations = h5py.File(DATAFILE, "w") 117 | else: 118 | activations = {} 119 | 120 | log.info("Computing the activations") 121 | # Store all the activations (in memory or h5py file) 122 | last_idx = 0 123 | dataset_size = len(dataloader.dataset) 124 | with torch.no_grad(): 125 | with collect_activations( 126 | collection_funcs=collection_funcs, 127 | dense_model=dense_model, 128 | preprocess_batch=preprocess_batch, 129 | ) as collect_act: 130 | for batch in tqdm(dataloader): 131 | acts = collect_act(batch) 132 | for act_name, act_value in acts.items(): 133 | # For model that use shape [seq_len, batch_size, features], we transpose the first two dimensions 134 | if act_value.shape[0] != dataloader.batch_size: 135 | act_value = act_value.transpose(1, 0) 136 | 137 | assert act_value.shape[0] == dataloader.batch_size 138 | 139 | if act_name not in activations: 140 | # Determine the shape of the whole dataset and allocate it 141 | shape = [dataset_size] + list(act_value.shape[1:]) 142 | if memory == "disk": 143 | activations.create_dataset( 144 | act_name, 145 | shape=shape, 146 | dtype=act_value.data.numpy().dtype, 147 | ) 148 | else: 149 | activations[act_name] = torch.zeros( 150 | shape, dtype=act_value.dtype, device=memory 151 | ) 152 | # Set the values for each entry 153 | for i in range(act_value.shape[0]): 154 | activations[act_name][last_idx + i] = act_value[i].detach() 155 | last_idx += batch_size 156 | 157 | # Save the data to file if required 158 | if memory == "disk": 159 | activations.flush() 160 | log.info(f"Storing the activations in {os.path.join(os.getcwd(), DATAFILE)}") 161 | 162 | return activations 163 | 164 | 165 | def store_activations_main( 166 | conf: DictConfig, 167 | ) -> Dict[str, Union[np.ndarray, torch.Tensor]]: 168 | """ 169 | Function to store the activations to disk given the specified configuration. 
This function is called when 170 | specifying experiment=store_activations from CLI 171 | """ 172 | 173 | log.info("Instantiating the original Dataset") 174 | split = conf.activations.split 175 | dataloader = instantiate(conf.data[split], tokenizer=conf.tokenizer, shuffle=False) 176 | preprocess_batch = instantiate(conf.preprocess_batch) 177 | 178 | log.info("Instantiating the Model") 179 | dense_model = instantiate(conf.dense_model).to(conf.device) 180 | 181 | activation_ids = conf.activations.activation_ids 182 | 183 | # If "none" store input and output of down layers and input of up layers 184 | if activation_ids is None: 185 | activations_ids = get_layer_ids( 186 | model_id=conf.model_id, 187 | layer_type=FC_DOWN, 188 | layer_names=conf.activations.layer_ids, 189 | ) + get_layer_ids( 190 | model_id=conf.model_id, 191 | layer_type=FC_UP, 192 | layer_names=conf.activations.layer_ids, 193 | ) 194 | else: 195 | activations_ids = activation_ids 196 | 197 | if isinstance(activations_ids, str): 198 | activations_ids = [activations_ids] 199 | 200 | compute_activations( 201 | dataloader=dataloader, 202 | dtype=conf.activations.dtype, 203 | preprocess_batch=preprocess_batch, 204 | dense_model=dense_model, 205 | activation_ids=activations_ids, 206 | memory="disk", 207 | ) 208 | -------------------------------------------------------------------------------- /contextual_sparsity/masking_hooks/dip.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Qualcomm Technologies, Inc. 2 | # All Rights Reserved. 3 | 4 | from typing import Callable, List, Optional, Union 5 | 6 | import numpy as np 7 | from scipy.special import expit, logit 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | 11 | from contextual_sparsity.mask import MaskingHook 12 | from contextual_sparsity.masking_hooks.binarization import ( 13 | BinarizationType, 14 | build_binarization, 15 | ) 16 | from contextual_sparsity.nn import Abs 17 | from contextual_sparsity.utils.layer_names import ( 18 | FC_DOWN, 19 | FC_UP, 20 | block_id_to_layer_ids, 21 | block_id_to_mlp_id, 22 | get_block_ids, 23 | get_layer_ids, 24 | has_gate, 25 | ) 26 | 27 | 28 | def build_dip_masking_hooks( 29 | model_id: str, 30 | dense_model: nn.Module, 31 | layers_to_sparsify: Union[str, List[int], int], 32 | data_id: Optional[str] = None, 33 | calibration_data: Optional[DataLoader] = None, 34 | preprocess_batch: Optional[Callable] = None, 35 | binarization_type: str = BinarizationType.topk.value, 36 | down_k: Optional[Union[int, List[int]]] = None, 37 | up_k: Optional[Union[int, List[int]]] = None, 38 | gate_k: Optional[Union[int, List[int]]] = None, 39 | down_keep: Optional[Union[float, List[float]]] = None, 40 | up_keep: Optional[Union[float, List[float]]] = None, 41 | gate_keep: Optional[Union[float, List[float]]] = None, 42 | down_threshold: Optional[Union[float, List[float]]] = None, 43 | up_threshold: Optional[Union[float, List[float]]] = None, 44 | gate_threshold: Optional[Union[float, List[float]]] = None, 45 | ) -> List[MaskingHook]: 46 | """ 47 | Factory function for building DIP masking hooks 48 | """ 49 | 50 | block_ids = get_block_ids(model_id=model_id, block_names=layers_to_sparsify) 51 | masking_hooks = [] 52 | # If not specified, the sparsity of up and gate are the same 53 | if gate_k is None and gate_keep is None and gate_threshold is None: 54 | gate_keep = up_keep 55 | gate_k = up_k 56 | gate_threshold = up_threshold 57 | 58 | down_activation_ids = [ 59 | 
".".join([layer_id, "input"]) 60 | for layer_id in get_layer_ids( 61 | model_id=model_id, layer_type=FC_DOWN, layer_names=layers_to_sparsify 62 | ) 63 | ] 64 | 65 | # Create a binarization function for all the intermediate activations (inputs to down) 66 | activation_binarization = build_binarization( 67 | activation_ids=down_activation_ids, 68 | model_id=model_id, 69 | dense_model=dense_model, 70 | data_id=data_id, 71 | calibration_data=calibration_data, 72 | preprocess_batch=preprocess_batch, 73 | binarization_type=binarization_type, 74 | threshold=down_threshold, 75 | keep=down_keep, 76 | k=down_k, 77 | ) 78 | 79 | # Do the same for all the inputs to the up layers 80 | up_activation_ids = [ 81 | ".".join([layer_id, "input"]) 82 | for layer_id in get_layer_ids( 83 | model_id=model_id, layer_type=FC_UP, layer_names=layers_to_sparsify 84 | ) 85 | ] 86 | activation_binarization.update( 87 | build_binarization( 88 | activation_ids=up_activation_ids, 89 | model_id=model_id, 90 | dense_model=dense_model, 91 | data_id=data_id, 92 | calibration_data=calibration_data, 93 | preprocess_batch=preprocess_batch, 94 | binarization_type=binarization_type, 95 | threshold=up_threshold, 96 | keep=up_keep, 97 | k=up_k, 98 | ) 99 | ) 100 | 101 | # If the model has gate and a different binarization policy for the gate component, also compute them 102 | if ( 103 | up_keep != gate_keep 104 | or up_k != gate_keep 105 | or up_threshold != gate_threshold 106 | and has_gate(model_id) 107 | ): 108 | activation_binarization.update( 109 | build_binarization( 110 | activation_ids=up_activation_ids, 111 | model_id=model_id, 112 | dense_model=dense_model, 113 | data_id=data_id, 114 | calibration_data=calibration_data, 115 | preprocess_batch=preprocess_batch, 116 | binarization_type=binarization_type, 117 | threshold=gate_threshold, 118 | keep=gate_keep, 119 | k=gate_k, 120 | ) 121 | ) 122 | 123 | for i, block_id in enumerate(block_ids): 124 | up_layer_id, down_layer_id, gate_layer_id = block_id_to_layer_ids( 125 | model_id=model_id, block_id=block_id 126 | ) 127 | mlp_layer_id = block_id_to_mlp_id(model_id=model_id, block_id=block_id) 128 | 129 | ##################### 130 | # Down Masking Hook # 131 | ##################### 132 | down_masking_hook = MaskingHook( 133 | masking_func=nn.Sequential( 134 | Abs(), 135 | activation_binarization[".".join([down_layer_id, "input"])], 136 | ), 137 | input_from=down_layer_id, 138 | mask_cols_of=[down_layer_id], 139 | mask_rows_of=[], 140 | ) 141 | masking_hooks.append(down_masking_hook) 142 | 143 | ################### 144 | # Up Masking Hook # 145 | ################### 146 | up_masking_hook = MaskingHook( 147 | masking_func=nn.Sequential( 148 | Abs(), 149 | activation_binarization[".".join([up_layer_id, "input"])], 150 | ), 151 | input_from=mlp_layer_id, 152 | mask_cols_of=[up_layer_id], 153 | mask_rows_of=[], 154 | ) 155 | # Set the input to the input of the MLP block 156 | masking_hooks.append(up_masking_hook) 157 | 158 | if gate_layer_id is not None: 159 | gate_activation_id = ".".join([gate_layer_id, "input"]) 160 | 161 | # Case 1: gate does not have a dedicated binarization 162 | # The selected columns are the same as up. We apply the same up_masking_hook to the columns of gate 163 | if gate_activation_id not in activation_binarization: 164 | up_masking_hook.mask_cols_of.append(gate_layer_id) 165 | 166 | # Case 2: gate has a different binarization function. 167 | # The selected columns are different. 
Make a new masking hook for gate 168 | else: 169 | gate_masking_hook = MaskingHook( 170 | masking_func=nn.Sequential( 171 | Abs(), 172 | activation_binarization[gate_activation_id], 173 | ), 174 | input_from=mlp_layer_id, 175 | mask_cols_of=[gate_layer_id], 176 | mask_rows_of=[], 177 | ) 178 | masking_hooks.append(gate_masking_hook) 179 | 180 | return masking_hooks 181 | 182 | 183 | def optimal_up_keep_from_keep( 184 | keep: Union[float, List[float], np.ndarray], 185 | ) -> Union[float, List[float], np.ndarray]: 186 | # Best linear fit in logit space 187 | m, b = (0.9849930711017189, 0.29976477589716316) 188 | up_keep = expit(logit(keep) * m + b) 189 | return up_keep 190 | 191 | 192 | def build_optimized_dip_masking_hooks( 193 | model_id: str, 194 | dense_model: nn.Module, 195 | data_id: str, 196 | layers_to_sparsify: Union[str, List[int], int], 197 | keep: Union[float, List[float]], 198 | ) -> List[MaskingHook]: 199 | if not has_gate(model_id): 200 | raise NotImplementedError("The optimized values are computed for LLMs with gating") 201 | 202 | keep = np.array(keep) 203 | up_keep = optimal_up_keep_from_keep(keep) 204 | down_keep = 3 * (keep - 2 / 3.0 * up_keep) 205 | 206 | return build_dip_masking_hooks( 207 | model_id=model_id, 208 | dense_model=dense_model, 209 | data_id=data_id, 210 | layers_to_sparsify=layers_to_sparsify, 211 | down_keep=down_keep, 212 | up_keep=up_keep, 213 | ) 214 | --------------------------------------------------------------------------------
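A quick sanity check of the keep arithmetic in `build_optimized_dip_masking_hooks`: since gate shares the up keep ratio and the three MLP projections (up, gate, down) have equally many parameters, setting `down_keep = 3 * (keep - 2/3 * up_keep)` makes the average keep ratio across the three projections equal the requested `keep`. The snippet below is an illustrative check of that identity, not code from the repository:

    import numpy as np

    from contextual_sparsity.masking_hooks.dip import optimal_up_keep_from_keep

    keep = np.array([0.3, 0.5, 0.7])
    up_keep = optimal_up_keep_from_keep(keep)       # logit-space linear fit
    down_keep = 3 * (keep - 2 / 3.0 * up_keep)      # as in build_optimized_dip_masking_hooks
    # average keep over up, gate (== up) and down recovers the target keep
    assert np.allclose((2 * up_keep + down_keep) / 3, keep)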