├── examples ├── utils │ ├── __init__.py │ ├── loggers.py │ ├── setup_llm.py │ └── dsets.py ├── configs │ ├── llm │ │ ├── peft │ │ │ ├── adalora.yaml │ │ │ ├── lora.yaml │ │ │ └── high_r_lora.yaml │ │ ├── quantization │ │ │ ├── none.yaml │ │ │ ├── 8bit.yaml │ │ │ └── 4bit.yaml │ │ ├── gpt2.yaml │ │ ├── llama2.yaml │ │ ├── roberta.yaml │ │ ├── zephyr.yaml │ │ └── phi.yaml │ ├── hydra │ │ ├── dev.yaml │ │ └── default.yaml │ ├── opt │ │ ├── adam.yaml │ │ └── adamw.yaml │ ├── dset │ │ ├── arc.yaml │ │ ├── cqa.yaml │ │ ├── qqp.yaml │ │ ├── boolq.yaml │ │ ├── cola.yaml │ │ ├── mrpc.yaml │ │ ├── obqa.yaml │ │ ├── qnli.yaml │ │ ├── rte.yaml │ │ ├── sst2.yaml │ │ ├── wnli.yaml │ │ ├── mnli.yaml │ │ └── winogrande.yaml │ ├── accelerate │ │ ├── simple.yaml │ │ ├── local_with_ds.yaml │ │ └── fsdp.yaml │ ├── base_config.yaml │ ├── paths │ │ └── default.yaml │ └── example_usage.yaml └── example_usage.py ├── bayesian_lora ├── __init__.py ├── main.py └── kfac.py ├── documentation ├── writedocs.sh └── source │ ├── example_usage.rst │ ├── bayesian_lora.rst │ ├── kfac.rst │ ├── index.rst │ ├── conf.py │ └── _static │ └── block_diagonal.svg ├── tests ├── test_example.py ├── conftest.py └── test_kfac.py ├── CITATION.cff ├── setup.py ├── notebooks ├── test_notebook.py └── roberta_dev.py ├── Makefile ├── pyproject.toml ├── .gitignore ├── README.md └── LICENSE /examples/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/configs/llm/peft/adalora.yaml: -------------------------------------------------------------------------------- 1 | _target_: peft.AdaLoraConfig 2 | lora_alpha: 8 3 | task_type: "CAUSAL_LM" 4 | inference_mode: false 5 | -------------------------------------------------------------------------------- /bayesian_lora/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bayesian low-rank adapters 3 | """ 4 | __version__ = "0.0.6" 5 | 6 | from .kfac import * 7 | from .main import * 8 | -------------------------------------------------------------------------------- /examples/configs/hydra/dev.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | dir: ${paths.root_dir}/outputs/${task_name}/dev_run 3 | 4 | defaults: 5 | - default.yaml 6 | - _self_ 7 | -------------------------------------------------------------------------------- /examples/configs/opt/adam.yaml: -------------------------------------------------------------------------------- 1 | module: torch.optim 2 | classname: Adam 3 | lr: 0.00005 4 | betas: [0.9, 0.999] 5 | eps: 0.00001 # 1e-5 6 | weight_decay: 0.1 7 | -------------------------------------------------------------------------------- /examples/configs/opt/adamw.yaml: -------------------------------------------------------------------------------- 1 | module: torch.optim 2 | classname: AdamW 3 | lr: 0.00005 4 | betas: [0.9, 0.999] 5 | eps: 0.00001 # 1e-5 6 | weight_decay: 0.1 7 | -------------------------------------------------------------------------------- /examples/configs/llm/quantization/none.yaml: -------------------------------------------------------------------------------- 1 | _target_: transformers.utils.quantization_config.BitsAndBytesConfig 2 | 3 | load_in_4bit: false 4 | load_in_8bit: false 5 | -------------------------------------------------------------------------------- /documentation/writedocs.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | sphinx-autobuild documentation/source documentation/build/html --open-browser --port 0 --watch $(dirname "$(pwd)"/bayesian_lora) 4 | -------------------------------------------------------------------------------- /examples/configs/dset/arc.yaml: -------------------------------------------------------------------------------- 1 | name: arc 2 | max_epochs: 5 3 | train_bs: 32 4 | train_split: train 5 | eval_bs: 32 6 | eval_split: validation 7 | eval_subset: 1500 8 | max_length: 256 9 | n_labels: 5 10 | -------------------------------------------------------------------------------- /examples/configs/dset/cqa.yaml: -------------------------------------------------------------------------------- 1 | name: cqa 2 | max_epochs: 4 3 | train_bs: 16 4 | train_split: train 5 | eval_bs: 32 6 | eval_split: validation 7 | eval_subset: -1 8 | max_length: 256 9 | n_labels: 5 10 | -------------------------------------------------------------------------------- /examples/configs/dset/qqp.yaml: -------------------------------------------------------------------------------- 1 | name: qqp 2 | max_epochs: 5 3 | train_bs: 32 4 | train_split: train 5 | eval_bs: 32 6 | eval_split: validation 7 | eval_subset: 1500 8 | max_length: 256 9 | n_labels: 2 10 | -------------------------------------------------------------------------------- /examples/configs/dset/boolq.yaml: -------------------------------------------------------------------------------- 1 | name: boolq 2 | max_epochs: 5 3 | train_bs: 32 4 | train_split: train 5 | eval_bs: 32 6 | eval_split: validation 7 | eval_subset: 1500 8 | max_length: 256 9 | n_labels: 2 10 | -------------------------------------------------------------------------------- /examples/configs/dset/cola.yaml: -------------------------------------------------------------------------------- 1 | name: cola 2 | max_epochs: 3 3 | train_bs: 8 4 | train_split: train 5 | eval_bs: 32 6 | eval_split: validation 7 | eval_subset: 500 8 | max_length: 1024 9 | n_labels: 2 10 | -------------------------------------------------------------------------------- /examples/configs/dset/mrpc.yaml: -------------------------------------------------------------------------------- 1 | name: mrpc 2 | max_epochs: 3 3 | train_bs: 16 4 | train_split: train 5 | eval_bs: 32 6 | eval_split: validation 7 | eval_subset: 1500 8 | max_length: 256 9 | n_labels: 2 10 | -------------------------------------------------------------------------------- /examples/configs/dset/obqa.yaml: -------------------------------------------------------------------------------- 1 | name: obqa 2 | max_epochs: 5 3 | train_bs: 32 4 | train_split: train 5 | eval_bs: 32 6 | eval_split: validation 7 | eval_subset: 1500 8 | max_length: 256 9 | n_labels: 4 10 | -------------------------------------------------------------------------------- /examples/configs/dset/qnli.yaml: -------------------------------------------------------------------------------- 1 | name: qnli 2 | max_epochs: 3 3 | train_bs: 16 4 | train_split: train 5 | eval_bs: 32 6 | eval_split: validation 7 | eval_subset: 1500 8 | max_length: 256 9 | n_labels: 2 10 | -------------------------------------------------------------------------------- /examples/configs/dset/rte.yaml: -------------------------------------------------------------------------------- 1 | name: rte 2 | max_epochs: 10 3 | train_bs: 16 4 | train_split: train 5 | eval_bs: 32 6 | eval_split: validation 7 | 
eval_subset: 1500 8 | max_length: 256 9 | n_labels: 2 10 | -------------------------------------------------------------------------------- /examples/configs/dset/sst2.yaml: -------------------------------------------------------------------------------- 1 | name: sst2 2 | max_epochs: 10 3 | train_bs: 32 4 | train_split: train 5 | eval_bs: 32 6 | eval_split: validation 7 | eval_subset: 1500 8 | max_length: 256 9 | n_labels: 2 10 | -------------------------------------------------------------------------------- /examples/configs/dset/wnli.yaml: -------------------------------------------------------------------------------- 1 | name: wnli 2 | max_epochs: 10 3 | train_bs: 32 4 | train_split: train 5 | eval_bs: 32 6 | eval_split: validation 7 | eval_subset: 1500 8 | max_length: 256 9 | n_labels: 2 10 | -------------------------------------------------------------------------------- /examples/configs/dset/mnli.yaml: -------------------------------------------------------------------------------- 1 | name: mnli 2 | max_epochs: 3 3 | train_bs: 16 4 | train_split: train 5 | eval_bs: 32 6 | eval_split: validation_matched 7 | eval_subset: 500 8 | max_length: 1024 9 | n_labels: 3 10 | -------------------------------------------------------------------------------- /examples/configs/dset/winogrande.yaml: -------------------------------------------------------------------------------- 1 | name: winogrande 2 | max_epochs: 5 3 | train_bs: 32 4 | train_split: train 5 | eval_bs: 32 6 | eval_split: validation 7 | eval_subset: 1500 8 | max_length: 256 9 | n_labels: 2 10 | -------------------------------------------------------------------------------- /examples/configs/llm/quantization/8bit.yaml: -------------------------------------------------------------------------------- 1 | _target_: transformers.utils.quantization_config.BitsAndBytesConfig 2 | 3 | load_in_4bit: false 4 | load_in_8bit: true 5 | llm_int8_threshold: 6.0 6 | llm_int8_has_fp16_weight: false 7 | -------------------------------------------------------------------------------- /examples/configs/llm/peft/lora.yaml: -------------------------------------------------------------------------------- 1 | _target_: peft.LoraConfig 2 | r: 8 # 16 3 | lora_alpha: 8 4 | lora_dropout: 0.05 5 | task_type: "CAUSAL_LM" 6 | inference_mode: false 7 | bias: lora_only 8 | # target_modules: ["c_attn", "c_proj", "c_fc", "c_proj", "lm_head"] 9 | -------------------------------------------------------------------------------- /tests/test_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple example unit test. 3 | """ 4 | 5 | import pytest 6 | 7 | 8 | def test_example(): 9 | assert 2 + 2 == 4 10 | 11 | 12 | @pytest.mark.slow 13 | def test_example_slow(): 14 | for i in range(int(1e8)): 15 | assert i == i 16 | -------------------------------------------------------------------------------- /examples/configs/llm/peft/high_r_lora.yaml: -------------------------------------------------------------------------------- 1 | _target_: peft.LoraConfig 2 | r: 256 3 | bias: none 4 | lora_alpha: 128 5 | lora_dropout: 0.05 6 | task_type: "CAUSAL_LM" 7 | inference_mode: false 8 | # target_modules: ["c_attn", "c_proj", "c_fc", "c_proj", "lm_head"] 9 | # target_modules: all-linear 10 | -------------------------------------------------------------------------------- /documentation/source/example_usage.rst: -------------------------------------------------------------------------------- 1 | .. 
_example_usage: 2 | 3 | Example Usage 4 | ============= 5 | 6 | .. note:: Unfinished. For now, please read through the `example 7 | `_ 8 | and comments therein. 9 | 10 | -------------------------------------------------------------------------------- /examples/configs/llm/quantization/4bit.yaml: -------------------------------------------------------------------------------- 1 | _target_: transformers.utils.quantization_config.BitsAndBytesConfig 2 | 3 | load_in_4bit: true 4 | load_in_8bit: false 5 | llm_int8_threshold: 6.0 6 | llm_int8_has_fp16_weight: false 7 | bnb_4bit_quant_type: nf4 8 | bnb_4bit_compute_dtype: bfloat16 9 | bnb_4bit_use_double_quant: true 10 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.1.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: Robeyns 5 | given-names: Maxime 6 | orcid: https://orcid.org/0000-0001-9802-9597 7 | title: "Bayesian LoRA" 8 | version: 0.0.1 9 | date-released: 2024-01-31 10 | repository-code: "https://github.com/MaximeRobeyns/bayesian_lora" 11 | -------------------------------------------------------------------------------- /examples/configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # The output directory is generated dynamically on each run based on the task 2 | # name. 3 | run: 4 | dir: ${paths.root_dir}/outputs/${task_name}/${now:%Y-%m-%d_%H-%M-%S} 5 | 6 | help: 7 | app_name: "Bayesian LoRA" 8 | header: "\n\nHelp Information" 9 | footer: "Use --hydra-help to view Hydra specific help information" 10 | -------------------------------------------------------------------------------- /examples/configs/accelerate/simple.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: "NO" 4 | downcast_bf16: "no" 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: bf16 9 | num_machines: 1 10 | num_processes: 1 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def pytest_addoption(parser): 5 | parser.addoption( 6 | "--run-slow", action="store_true", default=False, help="run slow tests" 7 | ) 8 | 9 | 10 | def pytest_collection_modifyitems(config, items): 11 | if config.getoption("--run-slow"): 12 | # --run-slow is provided, so don't skip any tests 13 | return 14 | skip_slow = pytest.mark.skip(reason="need --run-slow option to run") 15 | for item in items: 16 | if "slow" in item.keywords and not config.getoption("--run-slow"): 17 | item.add_marker(skip_slow) 18 | -------------------------------------------------------------------------------- /examples/configs/accelerate/local_with_ds.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_accumulation_steps: 1 4 | offload_optimizer_device: cpu 5 | offload_param_device: cpu 6 | zero3_init_flag: false 7 | zero_stage: 2 8 | distributed_type: DEEPSPEED 9 | downcast_bf16: "no" 10 | dynamo_backend: "NO"
11 | fsdp_config: {} 12 | machine_rank: 0 13 | main_training_function: do_finetuning 14 | mixed_precision: "no" 15 | num_machines: 1 16 | num_processes: 1 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from setuptools import setup 3 | 4 | CWD = Path(__file__).absolute().parent 5 | 6 | 7 | def get_version(): 8 | """Gets the project version""" 9 | path = CWD / "bayesian_lora" / "__init__.py" 10 | content = path.read_text() 11 | for line in content.splitlines(): 12 | if line.startswith("__version__"): 13 | return line.strip().split()[-1].strip().strip('"') 14 | raise RuntimeError("bad version data in __init__.py") 15 | 16 | 17 | if __name__ == "__main__": 18 | print(f"Version: {get_version()}") 19 | setup(name="bayesian_lora", version=get_version()) 20 | -------------------------------------------------------------------------------- /examples/configs/llm/gpt2.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2 2 | model_name_or_path: gpt2 3 | 4 | config_class: AutoConfig 5 | config_kwargs: {} 6 | 7 | tokenizer_class: AutoTokenizer 8 | tokenizer_kwargs: 9 | use_fast: true 10 | 11 | tokenizer_special_tokens: {} 12 | 13 | model_class: AutoModelForCausalLM 14 | model_kwargs: 15 | torch_dtype: bfloat16 # auto 16 | 17 | # Global HF generation configurations 18 | global_gen_kwargs: {} 19 | 20 | add_space: true 21 | is_sc: false 22 | 23 | use_peft: false 24 | peft: 25 | target_modules: ["c_attn", "c_proj", "c_fc", "lm_head"] 26 | 27 | use_quant: false 28 | 29 | defaults: 30 | - quantization: none 31 | - peft: lora 32 | - _self_ 33 | -------------------------------------------------------------------------------- /examples/configs/base_config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # Baseline / default configuration to extend from for all other experiments. 3 | # (Include it as the first item in the defaults list; see example_usage.yaml.) 4 | 5 | notes: |- 6 | Write some notes here to describe the run. 7 | 8 | # tip: will be used for paths (avoid spces and special characters) 9 | task_name: example 10 | 11 | # The logging level; one of CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET 12 | log_level: INFO 13 | 14 | # optional random seed for deterministic runs 15 | seed: null 16 | 17 | # Whether to print out the configuration at the start of the run 18 | print_config: True 19 | 20 | defaults: 21 | - hydra: default.yaml 22 | - paths: default.yaml 23 | - _self_ 24 | -------------------------------------------------------------------------------- /examples/configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # Path to the root of the repo. We just set this to the current directory, 2 | # since we will mostly be invoking scripts from there (Makefile, submission 3 | # script, manually on the CLI etc.) 4 | root_dir: . 
5 | 6 | # Path to the data directory (useful when large datasets are stored on 7 | # different volumes) 8 | data_dir: ${paths.root_dir}/data/ 9 | 10 | # Directory to use for logs (tensorboard, csv, etc) 11 | log_dir: ${paths.root_dir}/logs/ 12 | 13 | # Can be used as a place to store any artifacts generated during the run. 14 | # Again, can be used to put heavy files on a separate volume or storage media. 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # Path to the working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | -------------------------------------------------------------------------------- /examples/configs/llm/llama2.yaml: -------------------------------------------------------------------------------- 1 | name: llama2 2 | model_name_or_path: meta-llama/Llama-2-7b-hf 3 | 4 | config_class: AutoConfig 5 | config_kwargs: 6 | trust_remote_code: True 7 | 8 | tokenizer_class: AutoTokenizer 9 | tokenizer_kwargs: 10 | use_fast: true 11 | 12 | tokenizer_special_tokens: {} 13 | 14 | model_class: AutoModelForCausalLM 15 | model_kwargs: 16 | torch_dtype: bfloat16 17 | low_cpu_mem_usage: true 18 | attn_implementation: "flash_attention_2" 19 | 20 | # Global HF generation configurations 21 | global_gen_kwargs: {} 22 | 23 | add_space: false 24 | is_sc: false 25 | 26 | use_peft: false 27 | peft: 28 | target_modules: ["q_proj", "v_proj", "lm_head"] 29 | 30 | use_quant: false 31 | 32 | defaults: 33 | - quantization: none 34 | - peft: lora 35 | - _self_ 36 | -------------------------------------------------------------------------------- /examples/configs/llm/roberta.yaml: -------------------------------------------------------------------------------- 1 | name: RoBERTa 2 | model_name_or_path: FacebookAI/roberta-base 3 | 4 | config_class: AutoConfig 5 | config_kwargs: {} 6 | 7 | tokenizer_class: AutoTokenizer 8 | tokenizer_kwargs: 9 | use_fast: true 10 | 11 | tokenizer_special_tokens: {} 12 | 13 | model_class: AutoModelForSequenceClassification 14 | model_kwargs: 15 | torch_dtype: bfloat16 # auto 16 | attn_implementation: "flash_attention_2" 17 | problem_type: multi_label_classification 18 | 19 | # Global HF generation configurations 20 | global_gen_kwargs: {} 21 | 22 | add_space: true 23 | is_sc: true 24 | 25 | use_peft: false 26 | peft: 27 | target_modules: ["query", "value", "dense"] 28 | 29 | use_quant: false 30 | 31 | defaults: 32 | - quantization: none 33 | - peft: lora 34 | - _self_ 35 | -------------------------------------------------------------------------------- /examples/configs/llm/zephyr.yaml: -------------------------------------------------------------------------------- 1 | name: llama2 2 | model_name_or_path: HuggingFaceH4/zephyr-7b-beta 3 | 4 | config_class: AutoConfig 5 | config_kwargs: {} 6 | 7 | tokenizer_class: AutoTokenizer 8 | tokenizer_kwargs: 9 | use_fast: true 10 | 11 | tokenizer_special_tokens: {} 12 | 13 | model_class: AutoModelForCausalLM 14 | model_kwargs: 15 | torch_dtype: bfloat16 16 | 17 | # Global HF generation configurations 18 | global_gen_kwargs: {} 19 | 20 | add_space: false 21 | is_sc: false 22 | 23 | use_peft: false 24 | peft: 25 | r: 8 26 | target_modules: ["q_proj", "v_proj", "lm_head"] 27 | # target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head"] 28 | bias: lora_only 29 | 30 | use_quant: false 31 | 32 | defaults: 33 | - quantization: none 34 | - peft: lora 35 | - _self_ 36 | -------------------------------------------------------------------------------- /examples/configs/accelerate/fsdp.yaml: 
-------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: "no" 5 | fsdp_config: 6 | fsdp_auto_wrap_policy: SIZE_BASED_WRAP # TRANSFORMER_BASED_WRAP 7 | fsdp_backward_prefetch_policy: BACKWARD_POST 8 | fsdp_cpu_ram_efficient_loading: true 9 | fsdp_forward_prefetch: false 10 | fsdp_offload_params: true 11 | fsdp_sharding_strategy: 1 12 | fsdp_state_dict_type: SHARDED_STATE_DICT 13 | fsdp_sync_module_states: true 14 | fsdp_use_orig_params: false 15 | machine_rank: 0 16 | main_training_function: main 17 | mixed_precision: no # fp16 18 | num_machines: 1 19 | num_processes: 4 20 | # num_cpu_threads_per_process: 4 # TODO: update this 21 | rdzv_backend: static 22 | same_network: true 23 | tpu_env: [] 24 | tpu_use_cluster: false 25 | tpu_use_sudo: false 26 | use_cpu: true 27 | -------------------------------------------------------------------------------- /examples/configs/llm/phi.yaml: -------------------------------------------------------------------------------- 1 | name: phi 2 | model_name_or_path: microsoft/phi-2 3 | 4 | config_class: AutoConfig 5 | config_kwargs: 6 | trust_remote_code: true 7 | 8 | tokenizer_class: AutoTokenizer 9 | tokenizer_kwargs: 10 | use_fast: true 11 | 12 | tokenizer_special_tokens: {} 13 | 14 | model_class: AutoModelForCausalLM 15 | model_kwargs: 16 | torch_dtype: auto 17 | trust_remote_code: true 18 | # attn_implementation: sdpa 19 | attn_implementation: "flash_attention_2" 20 | 21 | # Global HF generation configurations 22 | global_gen_kwargs: {} 23 | 24 | add_space: false 25 | is_sc: false 26 | 27 | use_peft: false 28 | peft: 29 | # List of all available modules to target for phi 30 | # target_modules: ["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2", "lm_head"] 31 | target_modules: ["q_proj", "v_proj", "lm_head"] 32 | 33 | use_quant: false 34 | 35 | defaults: 36 | - quantization: none 37 | - peft: lora 38 | - _self_ 39 | -------------------------------------------------------------------------------- /examples/utils/loggers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import logging 4 | 5 | from omegaconf import DictConfig, OmegaConf 6 | 7 | # from lightning.fabric.loggers.csv_logs import CSVLogger 8 | # from lightning.fabric.loggers.tensorboard import TensorBoardLogger 9 | 10 | 11 | def clean_dir(dir_path: str) -> None: 12 | """Empties a directory by deleting the directory and creating a new empty 13 | directory in its place. 14 | 15 | Args: 16 | dir_path: path to directory to clean. 17 | """ 18 | shutil.rmtree(dir_path) 19 | os.mkdir(dir_path) 20 | 21 | 22 | def setup_loggers(cfg: DictConfig): 23 | """ 24 | Sets up loggers for the run based on the provided configurations. 
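Args: cfg: the full run configuration; this function reads cfg.log_level, cfg.print_config, and cfg.paths.output_dir.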
25 | """ 26 | 27 | logging.getLogger().setLevel(getattr(logging, cfg.log_level.upper(), "INFO")) 28 | 29 | if cfg.print_config: 30 | print(OmegaConf.to_yaml(cfg)) 31 | 32 | if cfg.paths.output_dir.split("/")[-1] == "dev_run": 33 | logging.info("Cleaning development log directory") 34 | clean_dir(cfg.paths.output_dir) 35 | 36 | # Save the configuration values in a file in the outout directory for later 37 | # reference 38 | with open(os.path.join(cfg.paths.output_dir, "config.yaml"), "w") as f: 39 | f.write(OmegaConf.to_yaml(cfg)) 40 | 41 | # Setup TensorBoard and CSV loggers 42 | # op = cfg.paths.output_dir.split("/") 43 | # tb_logger = TensorBoardLogger("/".join(op[:-2]), op[-2], op[-1]) 44 | # csv_logger = CSVLogger( 45 | # "/".join(op[:-2]), op[-2], op[-1], flush_logs_every_n_steps=1 46 | # ) 47 | # return tb_logger, csv_logger 48 | -------------------------------------------------------------------------------- /notebooks/test_notebook.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # Incremental SVD 5 | 6 | # In[1]: 7 | 8 | 9 | import torch as t 10 | from torch import Tensor 11 | from jaxtyping import Float 12 | 13 | 14 | # In[2]: 15 | 16 | 17 | device = "cuda:0" 18 | dtype = t.float32 19 | 20 | 21 | # In[9]: 22 | 23 | 24 | d, n_kfac, batch = 1024, 1000, 16 25 | 26 | 27 | # ## Testing the Incrementatl SVD 28 | 29 | # Import our library method 30 | 31 | # In[10]: 32 | 33 | 34 | from bayesian_lora.kfac import incremental_svd 35 | 36 | 37 | # Create a ground-truth, full-rank matrix 38 | 39 | # In[11]: 40 | 41 | 42 | true_B = t.randn(d, d).to(device, dtype) 43 | true_B = true_B@true_B.T 44 | 45 | 46 | # Calculate the full SVD to get a low-rank factor 47 | 48 | # In[12]: 49 | 50 | 51 | U_full, S_full, _ = t.linalg.svd(true_B, full_matrices=False) 52 | full_B = U_full[:, :n_kfac]@t.diag(S_full[:n_kfac]) 53 | 54 | 55 | # In[13]: 56 | 57 | 58 | t.norm(true_B - (full_B@full_B.T)) 59 | 60 | 61 | # In[7]: 62 | 63 | 64 | t.norm(true_B - (full_B@full_B.T)) 65 | 66 | 67 | # In[8]: 68 | 69 | 70 | assert t.allclose(true_B, full_B@full_B.T) 71 | 72 | 73 | # ## Using Eigendecomposition 74 | 75 | # In[25]: 76 | 77 | 78 | # Compute eigenvalues and eigenvectors 79 | eigenvalues, eigenvectors = t.linalg.eigh(true_B) 80 | 81 | 82 | # In[34]: 83 | 84 | 85 | # Choose the rank for the approximation 86 | N = 1000 87 | 88 | # Select the top N eigenvalues and eigenvectors 89 | top_eigenvalues = eigenvalues[-N:] 90 | top_eigenvectors = eigenvectors[:, -N:] 91 | 92 | # Reconstruct the low-rank approximation of the matrix 93 | B_approx = top_eigenvectors @ t.diag(top_eigenvalues) @ top_eigenvectors.T 94 | 95 | 96 | # In[35]: 97 | 98 | 99 | t.norm(true_B - B_approx) 100 | 101 | 102 | # #### Dribs and Drabs 103 | 104 | # In[5]: 105 | 106 | 107 | A = t.randn(d, n_kfac).to(device, dtype) 108 | a = t.randn(d, batch).to(device, dtype) 109 | 110 | U, S, _ = t.linalg.svd(A, full_matrices=False) 111 | 112 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: install install_all example test test_all mypy lab kernel docs help 2 | .PHONY: publish release pypi 3 | 4 | install: ## Install this package in the current environment 5 | pip install -e . 
6 | 7 | install_all: ## Install everything 8 | pip install -e ".[all]" 9 | 10 | # Examples and tests =========================================================== 11 | # Run: pip install -e ".[examples]" 12 | 13 | example: ## Run through the example script 14 | python ./examples/example_usage.py 15 | 16 | test: ## Run the unit tests 17 | @python -m pytest -s tests --typeguard-packages=bayesian_lora -k "not test_example" 18 | 19 | test_all: ## Run all the tests (including the slow ones) 20 | @python -m pytest -s tests --run-slow 21 | 22 | # Development ================================================================== 23 | # Run: pip install -e ".[dev]" 24 | 25 | mypy: ## Run static type checking 26 | @mypy 27 | 28 | lab: ## To start a Jupyter Lab server 29 | @python -m jupyter lab --notebook-dir=notebooks 30 | 31 | kernel: ## To set up a Jupyter kernel to run notebooks in the project's virtual env 32 | python -m ipykernel install --user --name bayesian_lora \ 33 | --display-name "bayesian_lora" 34 | 35 | pypi: ## Creates a source distribution and wheel, and uploads to PyPI 36 | python3 -m pip install --upgrade build 37 | python3 -m build 38 | python3 -m twine upload dist/* 39 | 40 | release: ## Create release for GitHub 41 | $(eval VERSION := $(shell python -c "import bayesian_lora; print(bayesian_lora.__version__)")) 42 | git checkout master 43 | git pull origin master 44 | git tag -a $(VERSION) -m "Release version $(VERSION)" 45 | git push origin $(VERSION) 46 | 47 | publish: release pypi ## Publish a new release and PyPI package 48 | 49 | # Documentation ================================================================ 50 | # Run: pip install -e ".[docs]" 51 | 52 | docs: ## Compile the documentation and start watcher 53 | @./documentation/writedocs.sh 54 | 55 | help: 56 | @grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 57 | -------------------------------------------------------------------------------- /examples/configs/example_usage.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # 1. Load up the default component configurations 4 | defaults: 5 | - base_config 6 | - opt: adamw 7 | - dset: arc 8 | - llm: phi 9 | - override llm/peft: lora # unused if use_peft: None 10 | - override llm/quantization: 4bit 11 | - _self_ 12 | 13 | notes: |- 14 | Demonstration usage for the Laplace-LoRA method with a decoder-only language 15 | model. 16 | 17 | log_level: INFO # DEBUG | INFO | WARN | ERROR | CRITICAL 18 | 19 | # Whether to (re)-run each of the steps (MAP training, log likelihood 20 | # calculation, Kronecker factor calculation, marginal likelihood tuning and 21 | # linearized prediction), or use existing checkpoints if present.
22 | run_every_step: true 23 | 24 | # Print this configuration at the start 25 | print_config: false 26 | 27 | # Show progress bars using tqdm 28 | use_tqdm: true 29 | 30 | # MAP training parameters 31 | map_lr: 5e-5 32 | train_steps: 1000 33 | 34 | # The rank used for the large Kronecker factors 35 | n_kfac: 10 36 | 37 | # The Kronecker factor edge size over which to use the low-rank approximation 38 | lr_threshold: 100 39 | 40 | # Initial variance of the prior variance (before marginal likelihood-based 41 | # optimisation) 42 | prior_var: 0.1 43 | 44 | # Used to create a unique result directory for (dataset, model) combinations 45 | # ('tasks') 46 | task_name: ${dset.name}_${llm.name} 47 | 48 | # to recover from a given checkpoint, paste the directory path here: 49 | hydra: 50 | run: 51 | dir: ${paths.root_dir}/example_outputs/${task_name} 52 | 53 | # Random seed for reproducible runs 54 | seed: null 55 | 56 | # Dataset configuration overrides 57 | dset: 58 | train_bs: 6 # 12 59 | eval_bs: 4 60 | train_subset: -1 # 250 61 | 62 | # keyword arguments passed to the tokenizer's `__call__` method 63 | tokenizer_run_kwargs: 64 | padding: true 65 | truncation: true 66 | return_tensors: "pt" 67 | max_length: 512 68 | 69 | # LLM configuration overrides 70 | llm: 71 | use_peft: true 72 | use_quant: true 73 | 74 | model_kwargs: 75 | # torch_dtype: auto # float32 76 | torch_dtype: bfloat16 # Try using float32 77 | low_cpu_mem_usage: true 78 | 79 | # kwargs used when setting up the tokenizer 80 | tokenizer_kwargs: 81 | padding_side: left 82 | 83 | tokenizer_special_tokens: 84 | pad_token: tokenizer.bos_token 85 | 86 | peft: 87 | r: 8 # 'n_lora' the LoRA rank 88 | bias: "none" # "lora_only" 89 | 90 | out_dir: null 91 | # out_dir: outputs/llama2_cqa 92 | -------------------------------------------------------------------------------- /documentation/source/bayesian_lora.rst: -------------------------------------------------------------------------------- 1 | .. _bayesian_lora: 2 | 3 | Bayesian Lora 4 | ============= 5 | 6 | This file contains the main methods relating to the `Bayesian Low-Rank 7 | Adaptation for Large Language Models 8 | `_ paper. Namely, calculating the 9 | model evidence for tuning prior and network hyperparameters and calculating the 10 | posterior predictive parameters for making (linearised) predictions. 11 | 12 | Model Evidence 13 | -------------- 14 | 15 | The model evidence, or marginal likelihood, is a scalar value that indicates the 16 | evidence provided by the data for a particular model. A model with a higher 17 | marginal likelihood is considered more supported by the data under the given 18 | prior. 19 | 20 | .. autofunction:: bayesian_lora.main.model_evidence 21 | 22 | 23 | Posterior Predictive 24 | -------------------- 25 | 26 | This involves two steps, calculating the mean and the variance. 27 | 28 | For the first, we invoke the (admittedly, awkwardly named) ``jacobian_mean`` 29 | function, which returns the Jacobian, and the mean, respectively. 30 | 31 | .. autofunction:: bayesian_lora.main.jacobian_mean 32 | 33 | As you can see, there are two ways of calling this function, which determine how 34 | we'll handle the outputs from the wrapped network call. 35 | 36 | 1. **Directly, with parameters** Here, we assume that a model is either a 37 | sequence-to-sequence model or not (defaults to ``False``), and that we may 38 | optionally want to pick out some specific logits from the model's full 39 | vocabulary: 40 | 41 | 42 | .. 
code-block:: py 43 | 44 | jacobian, f_mu = jacobian_mean( 45 | model, batch_inputs, target_ids=dset.target_ids, is_sc=False 46 | ) 47 | 48 | 2. **Custom output callback** Here, we allow the user to provide a callback 49 | function, taking in the result of the model's ``forward`` call, and returning 50 | the logits of interest, with arbitrary post-processing in between. 51 | 52 | .. code-block:: py 53 | 54 | def default_output_callback(outputs: ModelOutput) -> Tensor: 55 | logits = outputs.logits if cfg.llm.is_sc else outputs.logits[:, -1] 56 | target_logits = logits[:, dset.target_ids] 57 | return target_logits 58 | 59 | jacobian, f_mu = jacobian_mean( 60 | model, batch_inputs, output_callback=output_callback 61 | ) 62 | 63 | For the second step, we calculate the output logits' covariance matrix. 64 | 65 | .. autofunction:: bayesian_lora.main.variance 66 | -------------------------------------------------------------------------------- /documentation/source/kfac.rst: -------------------------------------------------------------------------------- 1 | .. _kfac: 2 | 3 | K-FAC Methods 4 | ============= 5 | 6 | The :mod:`bayesian_lora.kfac` module provides functions for calculating 7 | an approximate Fisher information matrix (or GGN) using Kronecker-factored 8 | approximate curvature. 9 | 10 | Recall that K-FAC first finds a block-diagonal approximation to the full Fisher 11 | / GGN. If we had a simple 4-layer network, then this would be: 12 | 13 | .. figure:: _static/block_diagonal.svg 14 | :align: center 15 | :width: 70% 16 | :alt: Block-diagonal approximation 17 | 18 | Each of these blocks (:math:`\mathbf{G}_{\ell \ell}`) is further 19 | approximated as the product of two Kronecker factors, one corresponding to the 20 | input *activations*, :math:`\mathbf{A}_{\ell-1}`, and another to the *output 21 | gradients*, :math:`\mathbf{S}_{\ell}`. That is, for a particular layer / 22 | ``nn.Module`` indexed by :math:`\ell`, we approximate its block of the full 23 | Fisher as 24 | 25 | .. math:: 26 | :label: kfacblock 27 | 28 | \mathbf{G}_{\ell\ell} \approx \mathbf{A}_{\ell-1} \otimes \mathbf{S}_{\ell}. 29 | 30 | These factors (curvature information around the network's current parameters) 31 | are calculated over some dataset :math:`\mathcal{D}`, and this is what the 32 | :func:`bayesian_lora.calculate_kronecker_factors` function below calculates. 33 | 34 | Rather than using numerical indices :math:`\ell \in \{1, 2, \ldots, L\}`, we use 35 | the ``nn.Module``'s name to identify the different blocks, and return the 36 | factors in dictionaries of type ``dict[str, t.Tensor]``. 37 | 38 | Full-Rank K-FAC 39 | --------------- 40 | 41 | The simplest variant is a *full-rank* Kronecker factorisation, meaning that we 42 | store the :math:`\mathbf{A}` and :math:`\mathbf{S}` matrices exactly. 43 | 44 | .. autofunction:: bayesian_lora.calculate_kronecker_factors 45 | 46 | Notice how these Kronecker factors can themselves be approximated as low-rank, 47 | which is particularly useful for LLMs, where the factors may be :math:`4096 48 | \times 4096` for each layer in a transformer. 49 | 50 | Internal Functions 51 | ------------------ 52 | 53 | The above is the main way to use the K-FAC functionality from this library. 54 | It calls a number of internal functions, which we document here for re-use and 55 | completeness. 56 | 57 | .. autofunction:: bayesian_lora.kfac.register_hooks 58 | 59 | .. autofunction:: bayesian_lora.kfac.remove_hooks 60 | 61 | ..
autofunction:: bayesian_lora.kfac.save_input_hook 62 | 63 | .. autofunction:: bayesian_lora.kfac.save_output_grad_hook 64 | -------------------------------------------------------------------------------- /documentation/source/index.rst: -------------------------------------------------------------------------------- 1 | Bayesian LoRA 2 | ============= 3 | 4 | This repository contains: 5 | 6 | - an implementation of K-FAC 7 | - Bayesian LoRA for language models. 8 | 9 | .. .. todo:: Add a diagram of the components 10 | .. .. figure:: _static/ml_overview.png 11 | 12 | Installation Guide 13 | ------------------ 14 | 15 | The simplest way to use the library is to simply pip install it:: 16 | 17 | pip install bayesian-lora 18 | 19 | Editable Installation 20 | ^^^^^^^^^^^^^^^^^^^^^ 21 | 22 | If you would like to modify the library or build upon it, while keeping it as a 23 | separate library, then you can clone the repo and run an editable installation:: 24 | 25 | git clone https://github.com/MaximeRobeyns/bayesian_lora 26 | cd bayesian_lora 27 | pip install -e . 28 | 29 | Hackable Installation 30 | ^^^^^^^^^^^^^^^^^^^^^ 31 | 32 | The library is currently very small, and has three core dependencies, ``torch`` 33 | ``tqdm``, and ``jaxtyping``; and two main files. 34 | 35 | To this end, feel free to directly copy the file you need into your own project 36 | and start hacking on it. 37 | 38 | Installation with Examples 39 | ^^^^^^^^^^^^^^^^^^^^^^^^^^ 40 | 41 | There are some examples included with the GitHub repository. Before running the 42 | code in these files, you must install some additional dependencies which are 43 | omitted from the main library to keep it small. To do this, after cloning the 44 | repo, from the root simply run:: 45 | 46 | pip install -e ".[examples]" 47 | 48 | Development Installation 49 | ^^^^^^^^^^^^^^^^^^^^^^^^ 50 | 51 | If you plan on developing on this library, you may wish to install some 52 | development-related packages with ``pip install -e ".[dev]"``. To write 53 | documentation, install the requirements with ``pip install -e ".[docs]"``. 54 | 55 | For simplicity, you can also just run:: 56 | 57 | pip install -e ".[all]" 58 | 59 | Jupyter Notebooks 60 | ````````````````` 61 | 62 | To test functions from a Jupyter notebook, make sure that you have installed the 63 | project with the ``dev`` dependencies. You then need to run the following 64 | command once to set up the iPython kernel:: 65 | 66 | make kernel 67 | 68 | You only need to do this once. After you do so, you will see a new 69 | ``bayesian_lora`` kernel inside jupyterlab. To launch jupyterlab, we include a 70 | convenience target:: 71 | 72 | make lab 73 | 74 | Contents 75 | -------- 76 | 77 | .. toctree:: 78 | :maxdepth: 2 79 | :glob: 80 | :caption: Contents: 81 | 82 | kfac 83 | bayesian_lora 84 | example_usage 85 | 86 | .. 
87 | Indices and tables 88 | ------------------ 89 | 90 | * :ref:`genindex` 91 | * :ref:`modindex` 92 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Package ===================================================================== 2 | 3 | [build-system] 4 | requires = ["setuptools >= 61.0.0"] 5 | build-backend = "setuptools.build_meta" 6 | 7 | [project] 8 | name = "bayesian_lora" 9 | dynamic = ["version"] # version number is inferred in ./setup.py 10 | description = "Bayesian LoRA adapters for Language Models" 11 | authors = [ 12 | { name = "Maxime Robeyns", email = "dev@maximerobeyns.com" }, 13 | ] 14 | license = { text = "Apache-2.0" } 15 | readme = "README.md" 16 | requires-python = ">=3.8" 17 | keywords = ["Bayes", "LLM", "LoRA", "machine learning", "uncertainty"] 18 | classifiers = [ 19 | "Development Status :: 3 - Alpha", 20 | "Environment :: GPU :: NVIDIA CUDA", 21 | "Operating System :: OS Independent", 22 | "Programming Language :: Python :: 3", 23 | "License :: OSI Approved :: Apache Software License", 24 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 25 | ] 26 | # Minimal dependencies are intentionally left without version requirements, 27 | # since this has made past Laplace projects hard to work with... 28 | dependencies = [ 29 | "jaxtyping>=0.2.25", 30 | "torch", 31 | "tqdm", 32 | ] 33 | 34 | [project.optional-dependencies] 35 | # Dependencies for running the examples and tests 36 | examples = [ 37 | "datasets>=2.16.1", 38 | "hydra-core>=1.2.0, <2.0", 39 | "omegaconf>=2.3.0", 40 | "peft>=0.5.0", 41 | "torchmetrics>=1.2.0", 42 | "transformers>=4.37.2", 43 | "pytest>=7.2.0", 44 | "bitsandbytes", 45 | ] 46 | # Other miscellaneous dev tools 47 | dev = [ 48 | "ipywidgets>=8.0.4", 49 | "jupyterlab>=3.5, <3.6", 50 | "jupyterlab-vim", 51 | "jupyterlab-vimrc", 52 | "mypy>=0.990,<=1.0", 53 | "tensorboard>=2.11.2, <3.0", 54 | ] 55 | # Doc writing 56 | docs = [ 57 | "furo>=2022.9.29", 58 | "sphinx-autobuild>=2021.3.14", 59 | "sphinx-copybutton>=0.5.1", 60 | "sphinxext-opengraph>=0.7.2", 61 | ] 62 | all = ["bayesian_lora[examples]", "bayesian_lora[dev]", "bayesian_lora[docs]"] 63 | 64 | [project.urls] 65 | Homepage = "https://github.com/MaximeRobeyns/bayesian_lora" 66 | Repository = "https://github.com/MaximeRobeyns/bayesian_lora" 67 | Documentation = "https://maximerobeyns.github.io/bayesian_lora/" 68 | 69 | [tool.setuptools] 70 | include-package-data = true 71 | 72 | [tool.setuptools.packages.find] 73 | include = ["bayesian_lora", "bayesian_lora/*"] 74 | 75 | [tool.setuptools.package-data] 76 | # include any package data as a list of paths here 77 | bayesian_lora = [ ] 78 | 79 | [tool.mypy] 80 | python_version = "3.11" 81 | ignore_missing_imports = true 82 | files = "bayesian_lora/**/*.py" 83 | 84 | [tool.pytest.ini_options] 85 | # --ff for previously failed first 86 | # -l for print state on failure 87 | # -x for stop on first failure 88 | # -s for show stdout while testing 89 | # -v for verbose (e.g. 
show test names) 90 | # -n for n threadsafe parallel workers 91 | addopts = "-l -x --ff -s -v" 92 | testpaths = ["tests"] 93 | filterwarnings = ["ignore::DeprecationWarning"] 94 | markers = [ 95 | "slow: marks tests as slow (run with '--run-slow')", 96 | ] 97 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | ### VisualStudioCode 131 | .vscode/* 132 | !.vscode/settings.json 133 | !.vscode/tasks.json 134 | !.vscode/launch.json 135 | !.vscode/extensions.json 136 | *.code-workspace 137 | **/.vscode 138 | 139 | # JetBrains 140 | .idea/ 141 | 142 | # Data & Models 143 | *.h5 144 | *.tar 145 | *.tar.gz 146 | 147 | # Lightning-Hydra-Template 148 | configs/local/default.yaml 149 | /data/ 150 | /logs/ 151 | .env 152 | 153 | # Aim logging 154 | .aim 155 | 156 | planning.org 157 | logs 158 | 159 | .envrc 160 | reference 161 | notebooks 162 | 163 | .tabs 164 | .pdfs 165 | outputs 166 | third_party 167 | sync 168 | *.ipynb 169 | */.ipynb_checkpoints/* 170 | .nbcache 171 | example_outputs 172 | ./notebooks/test_notebook.py 173 | ./notebooks/roberta_dev.py 174 | -------------------------------------------------------------------------------- /examples/utils/setup_llm.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023-24 Maxime Robeyns 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Boilerplate for setting up a HuggingFace LLM. 
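The main entry point is `setup_llm`, which returns a (model, tokenizer, generation config) tuple. In the examples it is typically driven by the Hydra `llm` config group, e.g. `model, tokenizer, gen_cfg = setup_llm(**cfg.llm)` (an illustrative call; unused config keys are absorbed by `**_kwargs`).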
16 | """ 17 | 18 | import logging 19 | import torch as t 20 | import transformers 21 | 22 | from typing import Optional 23 | from omegaconf import OmegaConf 24 | from hydra.utils import instantiate 25 | from transformers import BitsAndBytesConfig, GenerationConfig 26 | from transformers.utils import is_flash_attn_2_available 27 | 28 | # Avoid importing this globally for systems where peft is not installed 29 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training 30 | 31 | 32 | def str_to_torch_dtype(name: str) -> t.dtype: 33 | dt = t.__dict__[name] 34 | assert isinstance(dt, t.dtype) 35 | return dt 36 | 37 | 38 | def setup_model_kwargs( 39 | model_kwargs: dict = dict(), 40 | use_quant: bool = False, 41 | quantization: Optional[BitsAndBytesConfig] = None, 42 | ) -> dict: 43 | """ 44 | - Gets the config from hydra 45 | - Converts dtype strings to torch dtypes 46 | - Adds quantization configurations from hydra 47 | """ 48 | try: 49 | model_kwargs = OmegaConf.to_object(model_kwargs) 50 | except Exception: 51 | pass 52 | assert isinstance(model_kwargs, dict) 53 | for k, v in model_kwargs.items(): 54 | if "dtype" in k.lower() and v != "auto": 55 | model_kwargs[k] = str_to_torch_dtype(v) 56 | if "attn_implementation" in k.lower(): 57 | if v == "flash_attention_2": 58 | model_kwargs[k] = ( 59 | "flash_attention_2" if is_flash_attn_2_available() else "sdpa" 60 | ) 61 | if use_quant and quantization is not None: 62 | model_kwargs["quantization_config"] = instantiate(quantization) 63 | return model_kwargs 64 | 65 | 66 | def setup_llm( 67 | model_name_or_path: str, 68 | config_class: str = "AutoConfig", 69 | config_kwargs: dict = dict(), 70 | tokenizer_class: str = "AutoTokenizer", 71 | tokenizer_kwargs: dict = dict(), 72 | tokenizer_special_tokens: dict = dict(), 73 | model_class: str = "AutoModelForCausalLM", 74 | model_kwargs: dict = dict(), 75 | global_gen_kwargs: dict = dict(), 76 | use_peft: bool = False, 77 | peft: Optional[LoraConfig] = None, 78 | use_quant: bool = False, 79 | quantization: Optional[BitsAndBytesConfig] = None, 80 | **_kwargs, 81 | ): 82 | """ 83 | A simple function to wrap all the HuggingFace boilerplate. 84 | This loads the model configuration, the model itself, apply any 85 | BitsAndBytes quantization configuration, and PEFT configuration, and return 86 | the prepared model, tokenizer and generation config. 87 | """ 88 | # Load the HF model config 89 | config_cls = getattr(transformers, config_class) 90 | if not isinstance(config_kwargs, dict): 91 | config_kwargs = OmegaConf.to_object(config_kwargs) 92 | model_config = config_cls.from_pretrained(model_name_or_path, **config_kwargs) 93 | 94 | # Load the HF model 95 | model_cls = getattr(transformers, model_class) 96 | model_kwargs = setup_model_kwargs(model_kwargs, use_quant, quantization) 97 | model = model_cls.from_pretrained( 98 | model_name_or_path, config=model_config, **model_kwargs 99 | ) 100 | if use_quant and quantization is not None: 101 | model = prepare_model_for_kbit_training(model) 102 | 103 | # Configure PEFT if required 104 | if use_peft and peft is not None: 105 | logging.info("Setting up PEFT") 106 | peft_cfg = instantiate(peft) 107 | peft_cfg.target_modules = OmegaConf.to_object(peft_cfg.target_modules) 108 | 109 | # model.add_adapter(peft_cfg) 110 | # model.enable_adapters() 111 | # model.train() 112 | 113 | # NOTE: this manner of setting up the configuration seems to cause 114 | # issues when saving with `save_pretrained`... Investigate. 
115 | # # peft_cfg = OmegaConf.to_container(peft_cfg, resolve=True) 116 | model = get_peft_model(model, peft_cfg) 117 | 118 | # Load the HF tokenizer 119 | tokenizer_cls = getattr(transformers, tokenizer_class) 120 | if not isinstance(tokenizer_kwargs, dict): 121 | tokenizer_kwargs = OmegaConf.to_object(tokenizer_kwargs) 122 | tokenizer = tokenizer_cls.from_pretrained(model_name_or_path, **tokenizer_kwargs) 123 | tokenizer_special_tokens = { 124 | k: ( 125 | getattr(tokenizer, v.split(".")[-1]) 126 | if isinstance(v, str) and v.startswith("tokenizer") 127 | else v 128 | ) 129 | for k, v in tokenizer_special_tokens.items() 130 | } 131 | if len(tokenizer_special_tokens) > 0: 132 | tokenizer.add_special_tokens(tokenizer_special_tokens) 133 | if tokenizer.pad_token is None: 134 | tokenizer.pad_token = tokenizer.eos_token 135 | 136 | # Load the global genration config 137 | if not isinstance(global_gen_kwargs, dict): 138 | global_gen_kwargs = OmegaConf.to_object(global_gen_kwargs) 139 | gen_cfg = GenerationConfig.from_pretrained(model_name_or_path, **global_gen_kwargs) 140 | 141 | return model, tokenizer, gen_cfg 142 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bayesian LoRA 2 | 3 | Code for the paper [Bayesian Low-Rank Adaptation for Large Language Models](https://openreview.net/forum?id=FJiUyzOF1m). 4 | 5 | See the explanatory [blog post](https://maximerobeyns.com/bayesian_lora) and [documentation](https://maximerobeyns.github.io/bayesian_lora/) for more information. 6 | 7 | ## Installation 8 | 9 | ```bash 10 | pip install bayesian-lora 11 | ``` 12 | 13 | # Example 14 | 15 | We provide a comprehensive example in `examples/example_usage.py`, running 16 | through the main methods using Phi-2 on ARC-E. 17 | 18 | Note that running this requires a local installation with a few extra 19 | dependencies. Run: 20 | ```bash 21 | git clone https://github.com/MaximeRobeyns/bayesian_lora 22 | cd bayesian_lora 23 | pip install -e ".[examples]" 24 | ``` 25 | and then 26 | ```bash 27 | python ./examples/example_usage.py 28 | ``` 29 | 30 | The main functions this library provides are for calculating Kronecker factors, 31 | the marginal likelihood, and the posterior predictive distribution. We show how 32 | to use these in the examples below. 33 | 34 | ## Calculating (low-rank) Kronecker factors 35 | 36 | First, wrap your model call in a function that takes a batch from your data 37 | loader, and returns the relevant logits. 
For a CausalLM from HuggingFace: 38 | 39 | ```python 40 | def fwd_call(model: nn.Module, batch_prompts: Any) -> t.Tensor: 41 | inputs = tokenizer(batch_prompts).to(device) 42 | outputs = model(**inputs) 43 | logits = outputs.logits[:, -1] # Get the last token logits 44 | return logits 45 | ``` 46 | You can now call our `calculate_kronecker_factors` function: 47 | ```python 48 | from bayesian_lora import calculate_kronecker_factors 49 | 50 | factors = calculate_kronecker_factors( 51 | model, # Your model (not necessarily PEFT) 52 | fwd_call, # Model call wrapper, defined above 53 | train_loader, # Your training data loader 54 | cfg.n_kfac, # (Optional) rank to use 55 | cfg.lr_threshold, # (Optional) threshold for low-rank approximation 56 | ["lora"], # (Optional) modules to target; defaults to all modules 57 | use_tqdm=True, # (Optional) use tqdm for progress bar 58 | ) 59 | ``` 60 | In the above, the `["lora"]` argument contains a case-insensitive list of 61 | keywords to identify modules to target. Since we're working with a LoRA model, 62 | we choose `"lora"` to target LoRA modules, for instance 63 | `layers.0.q_proj.lora_A`. 64 | 65 | The `factors` are a dictionary with keys being the full names of the targeted 66 | modules, and a tuple of two tensors as the values: the first being the 67 | (possibly low-rank) Kronecker factor corresponding to the input activations, 68 | and the second being the (possibly low-rank) factor corresponding to the output 69 | gradients. 70 | 71 | See [the K-FAC docs](https://maximerobeyns.github.io/bayesian_lora/kfac.html) 72 | for more detail. 73 | 74 | ## Model Evidence 75 | 76 | We provide a function called `model_evidence`, which returns the evidence / 77 | marginal likelihood. 78 | 79 | ```python 80 | from bayesian_lora import model_evidence 81 | 82 | evidence = model_evidence( 83 | model, # Your model 84 | log_likelihood, # A Tensor with model's log likelihood on some eval dataset 85 | factors, # Kronecker factors, as calculated above 86 | n_lora, # rank used in the LoRA adapters 87 | n_kfac, # rank used in the Kronecker factors 88 | prior_var, # prior variance hyperparameter, as a tensor 89 | ) 90 | ``` 91 | 92 | You can then use `evidence` as the loss in a normal training loop, presuming 93 | your parameters (e.g. `prior_var`) have gradients. 94 | 95 | ## Posterior Predictive Distribution 96 | 97 | To get the parameters of the Gaussian over the logits, use 98 | the `jacobian_mean` and `variance` functions. 99 | 100 | ```python 101 | with t.no_grad(): 102 | for batch in validation_loader: 103 | prompts, classes = batch 104 | 105 | batch_inputs = tokenizer(prompts) 106 | 107 | # Predict the output logit locations 108 | # target_ids is a tensor containing the indices of the target tokens 109 | # e.g. [354, 355, 356]. 110 | jacobian, f_mu = jacobian_mean( 111 | model, batch_inputs, target_ids 112 | ) 113 | 114 | # Predict the output logit variances 115 | f_var = variance( 116 | batch_inputs, # inputs 117 | jacobian, # the Jacobian dictionary, obtained above 118 | factors, # Kronecker factors, as calculated above 119 | prior_var, # prior variance hyperparameter, as a tensor 120 | classes.size(-1), # number of classes to predict 121 | n_lora, # rank of the LoRA adapters 122 | n_kfac, # rank of the Kronecker factors 123 | device, # device to use 124 | ) 125 | 126 | # Now use the parameters to e.g.
sample logits from the Gaussian 127 | # predictive, parametrised by f_mu, f_var 128 | L = t.linalg.cholesky(f_var) 129 | samples = 100_000 130 | f_mu = f_mu.expand(samples, *f_mu.shape) 131 | L = L.expand(samples, *L.shape) 132 | eps = t.randn_like(f_mu) 133 | logits = (f_mu + L @ eps).squeeze(-1).mean(0) 134 | ``` 135 | 136 | The above is a minimal example; see [this 137 | section](https://maximerobeyns.github.io/bayesian_lora/bayesian_lora.html#posterior-predictive) 138 | of the documentation for more detail. 139 | 140 | # Development 141 | 142 | This library is intentionally very small and hackable. It has two main files, 143 | and three dependencies (`torch`, `tqdm` and `jaxtyping`.) 144 | 145 | - `main.py` contains methods specific to [the paper](https://openreview.net/forum?id=FJiUyzOF1m), 146 | - `kfac.py` contains relatively portable K-FAC methods 147 | 148 | Feel free to directly copy the code into your projects and hack on it. 149 | -------------------------------------------------------------------------------- /documentation/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | import bayesian_lora 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | # import os 16 | # import sys 17 | # sys.path.insert(0, os.path.abspath('.')) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = "bayesian_lora" 23 | copyright = "2024 Bayesian LoRA" 24 | author = "Maxime Robeyns" 25 | 26 | # The full version, including alpha/beta/rc tags 27 | 28 | release = bayesian_lora.__version__ 29 | # release = "0.0.1" 30 | 31 | 32 | # -- General configuration --------------------------------------------------- 33 | 34 | # Add any Sphinx extension module names here, as strings. They can be 35 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 36 | # ones. 37 | extensions = [ 38 | "sphinx.ext.autodoc", 39 | "sphinx.ext.napoleon", 40 | "sphinx.ext.todo", 41 | "sphinx.ext.coverage", 42 | "sphinx.ext.mathjax", 43 | "sphinx.ext.autosummary", 44 | "sphinx.ext.autosectionlabel", 45 | "sphinx_copybutton", 46 | "sphinxext.opengraph", 47 | ] 48 | 49 | autoclass_content = "both" 50 | # autodoc_mock_imports = ["config"] 51 | autodoc_default_flags = ["members", "inherited-members"] 52 | autosummary_generate = True 53 | 54 | # Add any paths that contain templates here, relative to this directory. 55 | templates_path = ["_templates"] 56 | 57 | # The suffix(es) of source filenames. 58 | # You can specify multiple suffix as a list of string: 59 | # 60 | # source_suffix = ['.rst', '.md'] 61 | source_suffix = ".rst" 62 | 63 | smartquotes = True 64 | 65 | # The master toctree document. 66 | master_doc = "index" 67 | 68 | # The language for content autogenerated by Sphinx. Refer to documentation 69 | # for a list of supported languages. 70 | # 71 | # This is also used if you do content translation via gettext catalogs. 
72 | # Usually you set "language" from the command line for these cases. 73 | language = "en" 74 | 75 | # List of patterns, relative to source directory, that match files and 76 | # directories to ignore when looking for source files. 77 | # This pattern also affects html_static_path and html_extra_path. 78 | exclude_patterns = [] 79 | 80 | # The name of the Pygments (syntax highlighting) style to use. 81 | # pygments_style = "sphinx" 82 | 83 | show_authors = True 84 | 85 | 86 | # -- Options for HTML output ------------------------------------------------- 87 | 88 | # The theme to use for HTML and HTML Help pages. See the documentation for 89 | # a list of builtin themes. 90 | # 91 | html_theme = "furo" 92 | html_title = "Bayesian LoRA Documentation" 93 | 94 | math_eqref_format = "Equation {number}" 95 | 96 | # Theme options are theme-specific and customize the look and feel of a theme 97 | # further. For a list of options available for each theme, see the 98 | # documentation. 99 | # 100 | html_theme_options = { 101 | # "top_of_page_button": "None", 102 | } 103 | 104 | # Add any paths that contain custom static files (such as style sheets) here, 105 | # relative to this directory. They are copied after the builtin static files, 106 | # so a file named "default.css" will overwrite the builtin "default.css". 107 | html_static_path = ["_static"] 108 | 109 | # -- Options for HTMLHelp output --------------------------------------------- 110 | 111 | # Output file base name for HTML help builder. 112 | htmlhelp_basename = "bayesianLoraDoc" 113 | 114 | # -- Options for LaTeX output ------------------------------------------------ 115 | 116 | latex_engine = "pdflatex" 117 | 118 | latex_elements = { 119 | # The paper size ('letterpaper' or 'a4paper'). 120 | # 121 | # 'papersize': 'letterpaper', 122 | # The font size ('10pt', '11pt' or '12pt'). 123 | # 124 | # 'pointsize': '10pt', 125 | # Additional stuff for the LaTeX preamble. 126 | # 127 | # 'preamble': '', 128 | # Latex figure (float) alignment 129 | # 130 | # 'figure_align': 'htbp', 131 | } 132 | 133 | # Grouping the document tree into LaTeX files. List of tuples 134 | # (source start file, target name, title, 135 | # author, documentclass [howto, manual, or own class], toctree_only). 136 | latex_documents = [ 137 | ( 138 | "index", 139 | "bayesian_lora_docs.tex", 140 | "Bayesian LoRA Documentation", 141 | "Miscellaneous", 142 | "manual", 143 | False, 144 | ), 145 | ] 146 | 147 | latex_show_pagerefs = True 148 | 149 | # -- Options for manual page output ------------------------------------------ 150 | 151 | # One entry per manual page. List of tuples 152 | # (source start file, name, description, authors, manual section). 153 | man_pages = [ 154 | ( 155 | master_doc, 156 | "bayesian_lora", 157 | "Bayesian LoRA Documentation", 158 | [author], 159 | 1, 160 | ) 161 | ] 162 | 163 | 164 | # -- Options for Texinfo output ---------------------------------------------- 165 | 166 | # Grouping the document tree into Texinfo files. 
List of tuples 167 | # (source start file, target name, title, author, 168 | # dir menu entry, description, category) 169 | texinfo_documents = [ 170 | ( 171 | master_doc, 172 | "bayesian-lora-docs", 173 | "Bayesian LoRA Documentation", 174 | author, 175 | "bayesian_lora", 176 | "Documentation for the Bayesian LoRA package.", 177 | "Miscellaneous", 178 | False, 179 | ), 180 | ] 181 | 182 | 183 | # -- Extension configuration ------------------------------------------------- 184 | 185 | # -- Options for todo extension ---------------------------------------------- 186 | 187 | # If true, `todo` and `todoList` produce output, else they produce nothing. 188 | todo_include_todos = True 189 | -------------------------------------------------------------------------------- /tests/test_kfac.py: -------------------------------------------------------------------------------- 1 | """ 2 | K-FAC tests 3 | """ 4 | 5 | import pytest 6 | import torch as t 7 | import torch.nn as nn 8 | 9 | from torch import Tensor 10 | from typing import Any 11 | from jaxtyping import Float 12 | from torch.linalg import LinAlgError 13 | from torch.utils.data import TensorDataset, DataLoader 14 | 15 | from bayesian_lora.kfac import ( 16 | stable_cholesky, 17 | incremental_svd, 18 | calculate_kronecker_factors, 19 | ) 20 | 21 | 22 | def test_ill_conditioned_matrix(): 23 | """Test with an ill-conditioned matrix.""" 24 | # Create an ill-conditioned matrix 25 | ill_cond_matrix = t.rand((10, 10)) 26 | ill_cond_matrix = ill_cond_matrix @ ill_cond_matrix.T # Make it symmetric 27 | ill_cond_matrix[0, 0] = 1e-8 # Introduce ill-conditioning 28 | 29 | # Verify that the matrix really is ill-conditioned 30 | with pytest.raises(LinAlgError): 31 | L = t.linalg.cholesky(ill_cond_matrix) 32 | 33 | # Test if stable Cholesky decomposition succeeds 34 | L = stable_cholesky(ill_cond_matrix) 35 | assert isinstance(L, t.Tensor) 36 | assert not t.isnan(L).any() 37 | 38 | 39 | def test_well_conditioned_matrix(): 40 | """Test with a well-conditioned matrix.""" 41 | well_cond_matrix = t.rand((10, 10)) 42 | well_cond_matrix = well_cond_matrix @ well_cond_matrix.T + t.eye(10) 43 | 44 | # Test if Cholesky decomposition succeeds 45 | L = stable_cholesky(well_cond_matrix) 46 | assert isinstance(L, t.Tensor) 47 | assert not t.isnan(L).any() 48 | 49 | 50 | def test_non_square_matrix(): 51 | """Test with a non-square matrix.""" 52 | non_square_matrix = t.rand((10, 9)) 53 | 54 | with pytest.raises(Exception): 55 | stable_cholesky(non_square_matrix) 56 | 57 | 58 | def test_zero_matrix(): 59 | """Test with a zero matrix.""" 60 | zero_matrix = t.zeros((10, 10)) 61 | 62 | L = stable_cholesky(zero_matrix) 63 | assert isinstance(L, t.Tensor) 64 | assert not t.isnan(L).any() 65 | 66 | 67 | def test_incremental_svd(): 68 | d, n_kfac, batch = 1024, 10, 16 69 | A = t.randn(d, n_kfac) 70 | a = t.randn(batch, d) 71 | B = incremental_svd(A, a) 72 | assert A.shape == B.shape 73 | assert not t.isnan(B).any() 74 | 75 | 76 | class _TestingModel(nn.Module): 77 | def __init__(self, features: list[int], bias: bool = False): 78 | super().__init__() 79 | self.net = nn.Sequential() 80 | for i, (j, k) in enumerate(zip(features[:-1], features[1:])): 81 | self.net.add_module(name=f"FC{i}", module=nn.Linear(j, k, bias=bias)) 82 | if i < len(features) - 2: 83 | self.net.add_module(name=f"A{i}", module=nn.ReLU()) 84 | self.net.add_module(name=f"LN{i}", module=nn.LayerNorm(k)) 85 | else: 86 | self.net.add_module(name=f"SM{i}", module=nn.Softmax(-1)) 87 | 88 | def forward(self, x: 
Float[Tensor, "b n"]) -> Float[Tensor, "b m"]: 89 | return self.net(x).softmax(-1) 90 | 91 | 92 | def fwd_call(model: nn.Module, batch: Any) -> Float[Tensor, "batch out_params"]: 93 | xs, _ = batch 94 | logits = model(xs) 95 | logits = logits[:, -1] # emulate selecting the last token 96 | return logits 97 | 98 | 99 | def test_full_rank_kfac(): 100 | N, S, bs = 100, 8, 16 101 | features = [10, 20, 5] 102 | tmp_model = _TestingModel(features) 103 | xs, ys = t.randn(N, S, features[0]), t.randn(N, S, features[-1]) 104 | loader = DataLoader(TensorDataset(xs, ys), batch_size=bs) 105 | 106 | # Sanity check test setup 107 | for b in loader: 108 | xs, ys = b 109 | assert xs.shape == (bs, S, features[0]) 110 | assert ys.shape == (bs, S, features[-1]) 111 | out = fwd_call(tmp_model, b) 112 | assert out.shape == (bs, features[-1]) 113 | break 114 | 115 | factors = calculate_kronecker_factors( 116 | tmp_model, fwd_call, loader, target_module_keywords=["FC"] 117 | ) 118 | 119 | assert factors is not None 120 | assert len(factors) == len(features) - 1 121 | for i, (k, (A, S)) in enumerate(factors.items()): 122 | n, m = features[i], features[i + 1] 123 | assert A.shape == (n, n), f"Unexpected shape for {k}:A" 124 | assert S.shape == (m, m), f"Unexpected shape for {k}:S" 125 | 126 | 127 | def test_low_rank_kfac(): 128 | N, S, bs = 100, 8, 16 129 | n_kfac, lr_threshold = 4, 128 130 | features = [256, 256, 10] 131 | tmp_model = _TestingModel(features) 132 | xs, ys = t.randn(N, S, features[0]), t.randn(N, S, features[-1]) 133 | loader = DataLoader(TensorDataset(xs, ys), batch_size=bs) 134 | 135 | factors = calculate_kronecker_factors( 136 | tmp_model, 137 | fwd_call, 138 | loader, 139 | n_kfac=n_kfac, 140 | lr_threshold=128, 141 | target_module_keywords=["FC"], 142 | ) 143 | 144 | assert factors is not None 145 | assert len(factors) == len(features) - 1 146 | for i, (k, (A, S)) in enumerate(factors.items()): 147 | n, m = features[i], features[i + 1] 148 | if n < lr_threshold: 149 | assert A.shape == (n, n), f"Unexpected shape for {k}:A" 150 | else: 151 | assert A.shape == (n, n_kfac), f"Unexpected shape for {k}:A" 152 | if m < lr_threshold: 153 | assert S.shape == (m, m), f"Unexpected shape for {k}:S" 154 | else: 155 | assert S.shape == (m, n_kfac), f"Unexpected shape for {k}:S" 156 | 157 | 158 | def test_low_rank_kfac_lora_like(): 159 | """ 160 | LoRA-like alternating feature shapes 161 | """ 162 | N, S, bs = 100, 8, 16 163 | n_kfac, lr_threshold = 4, 128 164 | features = [256, 32, 256, 32, 512] 165 | tmp_model = _TestingModel(features) 166 | xs, ys = t.randn(N, S, features[0]), t.randn(N, S, features[-1]) 167 | loader = DataLoader(TensorDataset(xs, ys), batch_size=bs) 168 | 169 | factors = calculate_kronecker_factors( 170 | tmp_model, 171 | fwd_call, 172 | loader, 173 | n_kfac=n_kfac, 174 | lr_threshold=128, 175 | target_module_keywords=["FC"], 176 | ) 177 | 178 | assert factors is not None 179 | assert len(factors) == len(features) - 1 180 | for i, (k, (A, S)) in enumerate(factors.items()): 181 | n, m = features[i], features[i + 1] 182 | if n < lr_threshold: 183 | assert A.shape == (n, n), f"Unexpected shape for {k}:A" 184 | else: 185 | assert A.shape == (n, n_kfac), f"Unexpected shape for {k}:A" 186 | if m < lr_threshold: 187 | assert S.shape == (m, m), f"Unexpected shape for {k}:S" 188 | else: 189 | assert S.shape == (m, n_kfac), f"Unexpected shape for {k}:S" 190 | -------------------------------------------------------------------------------- /notebooks/roberta_dev.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # RoBERTa Testing 5 | 6 | # In[1]: 7 | 8 | 9 | model_id = "FacebookAI/roberta-base" 10 | # model_id = "gpt2" 11 | 12 | 13 | # In[2]: 14 | 15 | 16 | import os 17 | import torch as t 18 | import bayesian_lora 19 | try: 20 | assert(_SETUP) 21 | except NameError: 22 | os.chdir(os.path.split(bayesian_lora.__path__[0])[0]) 23 | device = t.device("cuda") if t.cuda.is_available() else t.device("cpu") 24 | _SETUP = True 25 | 26 | 27 | # In[27]: 28 | 29 | 30 | import torch 31 | import matplotlib.pyplot as plt 32 | import torch.nn.functional as F 33 | 34 | from tqdm.notebook import tqdm 35 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 36 | 37 | 38 | # In[4]: 39 | 40 | 41 | get_ipython().run_line_magic('load_ext', 'autoreload') 42 | from examples.utils import dsets 43 | get_ipython().run_line_magic('autoreload', '2') 44 | 45 | 46 | # In[5]: 47 | 48 | 49 | tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") 50 | 51 | 52 | # Example dataset usage 53 | 54 | # In[6]: 55 | 56 | 57 | dset_class: dsets.ClassificationDataset = getattr(dsets, "boolq") 58 | dset = dset_class(tokenizer, add_space=True, max_len=50) 59 | 60 | 61 | # In[7]: 62 | 63 | 64 | print(f"The dataset has {dset.n_labels} labels") 65 | 66 | 67 | # # Single-label classification example 68 | 69 | # In[ ]: 70 | 71 | 72 | import torch 73 | from transformers import AutoTokenizer, RobertaForSequenceClassification 74 | 75 | tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-emotion") 76 | model = RobertaForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-emotion") 77 | 78 | inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") 79 | 80 | with torch.no_grad(): 81 | logits = model(**inputs).logits 82 | 83 | predicted_class_id = logits.argmax().item() 84 | model.config.id2label[predicted_class_id] 85 | 86 | # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)` 87 | num_labels = len(model.config.id2label) 88 | model = RobertaForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-emotion", num_labels=num_labels) 89 | 90 | labels = torch.tensor([1]) 91 | loss = model(**inputs, labels=labels).loss 92 | round(loss.item(), 2) 93 | 94 | 95 | # In[ ]: 96 | 97 | 98 | import torch 99 | from transformers import AutoTokenizer, BertForSequenceClassification 100 | 101 | model_id = "google/bert-base-uncased" 102 | tokenizer = AutoTokenizer.from_pretrained(model_id) 103 | model = BertForSequenceClassification.from_pretrained(model_id) 104 | 105 | inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") 106 | 107 | with torch.no_grad(): 108 | logits = model(**inputs).logits 109 | 110 | predicted_class_id = logits.argmax().item() 111 | model.config.id2label[predicted_class_id] 112 | 113 | # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)` 114 | num_labels = len(model.config.id2label) 115 | model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-yelp-polarity", num_labels=num_labels) 116 | 117 | labels = torch.tensor([1]) 118 | loss = model(**inputs, labels=labels).loss 119 | round(loss.item(), 2) 120 | 121 | 122 | # ## Sequence Classification (single label) 123 | 124 | # In[8]: 125 | 126 | 127 | model = AutoModelForSequenceClassification.from_pretrained( 128 | 
model_id, 129 | # low_cpu_mem_usage=True, 130 | torch_dtype=t.bfloat16, 131 | num_labels=dset.n_labels, 132 | ) 133 | 134 | 135 | # In[24]: 136 | 137 | 138 | model = model.to(0).train() 139 | opt = t.optim.AdamW(model.parameters(), lr=5e-4) 140 | 141 | 142 | # In[25]: 143 | 144 | 145 | loader = dset.loader(is_sc=True) 146 | 147 | 148 | # In[34]: 149 | 150 | 151 | def class_to_label(classes, num_labels): 152 | # dset.n_labels 153 | problem_type="multi_label_classification" 154 | 155 | labels = t.sum( 156 | F.one_hot(classes[:, None], num_classes=num_labels), dim=1 157 | ).to(torch.float) 158 | return labels 159 | 160 | 161 | # In[ ]: 162 | 163 | 164 | loss = model(**inputs, labels=labels).loss 165 | 166 | 167 | # In[40]: 168 | 169 | 170 | losses = [] 171 | for epoch in range(1): 172 | for batch in tqdm(loader): 173 | prompts, classes, _ = batch 174 | inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(0) 175 | # labels = class_to_label(classes, dset.n_labels).to(0) 176 | outputs = model(**inputs, labels=classes) 177 | opt.zero_grad() 178 | outputs.loss.backward() 179 | opt.step() 180 | break 181 | 182 | 183 | # In[ ]: 184 | 185 | 186 | test_loader = dset.loader(is_sc=True, split="test") 187 | for batch in test_loader: 188 | prompts, classes, _ = batch 189 | inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(0) 190 | outputs = model(**inputs, labels=classes.to(0)) 191 | predicted_class_ids = outputs.logits.argmax(0).item() 192 | print(predicted_class_ids) 193 | model.config.id2label[predicted_class_ids] 194 | break 195 | 196 | 197 | # In[ ]: 198 | 199 | 200 | inputs = tokenizer("Hello, this is my dog, Java. Woof.", return_tensors="pt") 201 | 202 | 203 | # In[36]: 204 | 205 | 206 | with torch.no_grad(): 207 | logits = model(**inputs).logits 208 | 209 | 210 | # In[37]: 211 | 212 | 213 | predicted_class_id = logits.argmax().item() 214 | model.config.id2label[predicted_class_id] 215 | 216 | 217 | # In[38]: 218 | 219 | 220 | # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)` 221 | num_labels = len(model.config.id2label) 222 | model = RobertaForSequenceClassification.from_pretrained(model_id, num_labels=num_labels) 223 | 224 | labels = torch.tensor([1]) 225 | loss = model(**inputs, labels=labels).loss 226 | round(loss.item(), 2) 227 | 228 | 229 | # ## Sequence Classification (multi label) 230 | 231 | # In[48]: 232 | 233 | 234 | tokenizer = AutoTokenizer.from_pretrained(model_id) 235 | model = AutoModelForSequenceClassification.from_pretrained( 236 | model_id, problem_type="multi_label_classification", num_labels=3 237 | ) 238 | 239 | 240 | # In[49]: 241 | 242 | 243 | inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") 244 | 245 | 246 | # In[51]: 247 | 248 | 249 | with torch.no_grad(): 250 | logits = model(**inputs).logits 251 | 252 | predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze(dim=0) > 0.5] 253 | 254 | 255 | # In[43]: 256 | 257 | 258 | predicted_class_ids 259 | 260 | 261 | # In[ ]: 262 | 263 | 264 | # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)` 265 | num_labels = len(model.config.id2label) 266 | model = RobertaForSequenceClassification.from_pretrained( 267 | "cardiffnlp/twitter-roberta-base-emotion", num_labels=num_labels, problem_type="multi_label_classification" 268 | ) 269 | 270 | labels = torch.sum( 271 | torch.nn.functional.one_hot(predicted_class_ids[None, :].clone(), num_classes=num_labels), 
dim=1 272 | ).to(torch.float) 273 | loss = model(**inputs, labels=labels).loss 274 | 275 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 
175 | 176 | END OF TERMS AND CONDITIONS 177 | -------------------------------------------------------------------------------- /bayesian_lora/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023-24 Maxime Robeyns 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Bayesian low-rank adaptation. 16 | """ 17 | 18 | import logging 19 | import torch as t 20 | import torch.nn as nn 21 | 22 | from torch import Tensor 23 | from typing import Callable 24 | from jaxtyping import Float 25 | from torch.func import jacrev, functional_call 26 | from transformers import BatchEncoding 27 | from transformers.modeling_outputs import ModelOutput 28 | 29 | from .kfac import stable_cholesky, KFAC_t, activation_t, outgrad_t 30 | 31 | __all__ = ["model_evidence", "variance", "cholesky_decompose_small_factors"] 32 | 33 | 34 | def calc_M( 35 | activations: activation_t, 36 | output_grads: outgrad_t, 37 | n_lora: int, 38 | n_kfac: int, 39 | s2: t.Tensor, 40 | return_LB: bool = False, 41 | ) -> ( 42 | t.Tensor 43 | | tuple[ 44 | Float[Tensor, "n_lora_x_n_kfac n_lora_x_n_kfac"], 45 | tuple[Float[Tensor, "n_lora n_lora"], Float[Tensor, "d n_kfac"]] | None, 46 | ] 47 | ): 48 | """ 49 | Calculates the `M` matrix in Eq. 32 of https://openreview.net/forum?id=FJiUyzOF1m 50 | 51 | Most conventional uses of this library should not need to call this 52 | function 'externally'. 53 | 54 | Args: 55 | activations: matrix of uncentred input activation covariances 56 | output_grads: matrix of uncentred output gradient covariances 57 | n_lora: LoRA rank 58 | n_kfac: low rank to use with Kronecker factors 59 | s2: prior variance 60 | return_LB: whether to return the `L` and `B` matrices; where 61 | - `L` is the e.g. Cholesky factorisation of small Kronecker factor with 62 | shape (n_lora, n_lora), and 63 | - `B` is the low-rank factorization of the large Kronecker factor with 64 | shape (d, n_kfac) 65 | 66 | Returns: 67 | The `M` matrix, and optionally the `L` and `B` matrices too. 68 | """ 69 | if activations.shape[-2:] == (n_lora, n_lora): 70 | L, B = (activations, output_grads) 71 | else: 72 | B, L = activations, output_grads 73 | assert L.shape[-2:] == (n_lora, n_lora) 74 | assert B.shape[-1:] == (n_kfac,) 75 | 76 | M_size = n_lora * n_kfac 77 | I = t.eye(M_size, device=L.device, dtype=L.dtype) 78 | M = I + s2 * t.kron(B.mT @ B, L.mT @ L) 79 | assert M.shape == (M_size, M_size) 80 | 81 | if return_LB: 82 | return M, (L, B) 83 | return M 84 | 85 | 86 | def cholesky_decompose_small_factors( 87 | factors: KFAC_t, lr_threshold: int, device: str, dtype: t.dtype 88 | ) -> KFAC_t: 89 | """ 90 | Compute the Cholesky factors for the full-rank (smaller) Kronecker 91 | factors 92 | 93 | Args: 94 | factors (dict[str, tuple[t.Tensor, t.Tensor]]): the Kronecker factors 95 | lr_threshold: the threshold beyond which a Kronecker factor is 96 | considered large and a low-rank approximation is applied. 
97 | device: device to use 98 | dtype: datatype to store factors in on disk 99 | Returns: 100 | Kronecker factors, with small factors Cholesky decomposed. 101 | """ 102 | for name, (A, S) in factors.items(): 103 | if A.size(0) < lr_threshold: 104 | A = stable_cholesky(A.to(dtype=t.float64)) 105 | if S.size(0) < lr_threshold: 106 | S = stable_cholesky(S.to(dtype=t.float64)) 107 | factors[name] = (A.to(device, dtype), S.to(device, dtype)) 108 | return factors 109 | 110 | 111 | def model_evidence( 112 | model: nn.Module, 113 | LL: t.Tensor, 114 | factors: KFAC_t, 115 | n_lora: int, 116 | n_kfac: int, 117 | s2: Float[Tensor, "1"], 118 | ) -> Float[Tensor, "1"]: 119 | """ 120 | Use this function to calculate the marginal likelihood / model evidence; 121 | for instance to tune the value of s2 (prior variance). 122 | 123 | Args: 124 | model: your model 125 | LL: the log likelihood on a dataset of interest 126 | factors: dictionary of Kronecker factors 127 | n_lora: LoRA rank 128 | n_kfac: K-FAC rank 129 | s2: prior variance 130 | 131 | Returns: 132 | model evidence 133 | """ 134 | logdet = t.tensor(0.0) 135 | d = 1 136 | 137 | for A, S in factors.values(): 138 | d = max(A.shape + S.shape) 139 | 140 | M = calc_M(A, S, n_lora, n_kfac, s2) 141 | assert isinstance(M, t.Tensor) 142 | M = M.to(dtype=t.float64) 143 | _, slogdet = t.slogdet(M) 144 | logdet = logdet.to(dtype=A.dtype) + slogdet.to(dtype=A.dtype) 145 | logdet += -n_lora * d * t.log(s2) 146 | 147 | map_norms = 0.0 148 | # TODO: is this a reliable way of identifying the LoRA parameters? 149 | lora_params = { 150 | k: v 151 | for k, v in dict(model.named_parameters()).items() 152 | if "lora" in k.lower() and v.requires_grad 153 | } 154 | for param in lora_params.values(): 155 | map_norms += t.linalg.norm(param) 156 | model_evidence = LL + 1 / s2 * map_norms + 0.5 * logdet 157 | return model_evidence 158 | 159 | 160 | def default_output_callback(outputs: ModelOutput) -> Tensor: 161 | """Post process model outputs. 162 | 163 | This function will be passed the results of model(**batch_inputs), and 164 | should return the relevant logits. For multiple-choice tasks, this is 165 | the class logits, but for full next-token prediction, this would just 166 | be all the logits. 167 | """ 168 | # Get the last token for CausalLM 169 | logits = outputs.logits if cfg.llm.is_sc else outputs.logits[:, -1] 170 | # Select the logits corresponding to our target classes 171 | target_logits = logits[:, dset.target_ids] 172 | return target_logits 173 | 174 | 175 | def jacobian_mean( 176 | model: nn.Module, 177 | batch_inputs: BatchEncoding, 178 | target_ids: Tensor | None = None, 179 | is_sc: bool = False, 180 | output_callback: Callable[[ModelOutput], Tensor] | None = None, 181 | ) -> tuple[dict[str, Tensor], Tensor]: 182 | """Calculates the Jacobian and logit means 183 | 184 | Args: 185 | model: the LoRA LLM from which to make predictions 186 | batch_inputs: the batch inputs, exactly as you would pass them into 187 | your model with ``model(**inputs)``. 188 | target_ids: selects specific model outputs. Leave this as None if 189 | either a) you wish to consider all model outputs or b) you are 190 | providing an output_callback to post-process the model output. 191 | is_sc: whether this is a sequence classification model. 
Can omit if
192 | providing an output_callback
193 | output_callback: a function that takes the results of
194 | ``model(**batch_inputs)`` and returns the logits of interest
195 | Returns:
196 | The Jacobian (a dictionary of module keys and Jacobian Tensors) and the
197 | logit mean predictions.
198 | """
199 | 
200 | if output_callback is None:
201 | 
202 | def ocb(outputs: ModelOutput) -> Tensor:
203 | logits = outputs.logits if is_sc else outputs.logits[:, -1]
204 | if target_ids is not None:
205 | logits = logits[:, target_ids]
206 | return logits
207 | 
208 | output_callback = ocb
209 | 
210 | def f(
211 | model: nn.Module, lora_params: dict[str, Tensor], batch_inputs: BatchEncoding
212 | ):
213 | outputs = functional_call(model, lora_params, args=(), kwargs=batch_inputs)
214 | target_logits = output_callback(outputs)
215 | return target_logits, target_logits
216 | 
217 | # Get the LoRA parameters
218 | # TODO: ensure that these are the same LoRA adapters as applied to the
219 | # modules targeted in ``calculate_kronecker_factors``.
220 | lora_params = {
221 | k: v for k, v in dict(model.named_parameters()).items() if v.requires_grad
222 | }
223 | # Sanity check
224 | for k in lora_params.keys():
225 | assert "lora" in k.lower()
226 | 
227 | # Calculate the Jacobian of each LoRA layer (and mean predictions)
228 | jacobian, f_mu = jacrev(f, argnums=1, has_aux=True)(
229 | model, lora_params, batch_inputs
230 | )
231 | return jacobian, f_mu
232 | 
233 | 
234 | def variance(
235 | inputs,
236 | jacobian,
237 | factors: KFAC_t,
238 | s2: t.Tensor,
239 | n_logits: int,
240 | n_lora: int,
241 | n_kfac: int,
242 | device: str,
243 | ):
244 | """
245 | Calculates the variance matrix for performing (linearised) prediction.
246 | 
247 | Args:
248 | inputs (dict): tokenized batch of inputs (returned from a HF Tokenizer)
249 | jacobian (dict): a dictionary of first derivatives for each of the
250 | target module's parameters
251 | factors: dictionary of Kronecker factors
252 | s2: prior variance (scalar valued tensor)
253 | n_logits: the number of logits to predict (e.g. the number of classes
254 | in your Categorical likelihood)
255 | n_lora: rank used in the LoRA adapters
256 | n_kfac: rank used for the low-rank approximation of large Kronecker
257 | factors
258 | device: device on which to accumulate the variance matrix
259 | """
260 | jac_keys = jacobian.keys()
261 | 
262 | batch_size = inputs.input_ids.size(0)
263 | 
264 | # initialise a matrix to accumulate the result
265 | var_matrix = t.zeros((batch_size, n_logits, n_logits), device=device)
266 | 
267 | # Iterate over the layers; `k` is the layer name / key, `A` is the input
268 | # activations and `S` is the output gradients.
269 | for k, (A, S) in factors.items():
270 | # Jacobian term
271 | # TODO: make this less brittle ----------------------------------------
272 | # g_key = "base_model.model."
+ k + ".weight" 273 | # g_key = k + ".weight" 274 | g_key = None 275 | for jac_key in jac_keys: 276 | if k in jac_key: 277 | g_key = jac_key 278 | break 279 | assert ( 280 | g_key is not None 281 | ), f"Could not find weight corresponding to kronecker factor {k}" 282 | # --------------------------------------------------------------------- 283 | 284 | G = jacobian.get(g_key).squeeze().to(device) 285 | # Ensure that G is [batch, n_logits, d, n_lora] sized at all times 286 | if G.shape[-1] != n_lora: 287 | G = G.mT 288 | assert G.shape[-1] == n_lora 289 | 290 | # Flatten the last 2 dimensions; giving [batch, n_logits, d * n_lora] 291 | G_vec = G.flatten(-2) 292 | term_1 = s2 * G_vec @ G_vec.mT 293 | assert term_1.shape == (batch_size, n_logits, n_logits) 294 | 295 | M, LB = calc_M(A, S, n_lora, n_kfac, s2, return_LB=True) 296 | assert LB is not None 297 | L, B = LB 298 | M_size = n_kfac * n_lora 299 | assert M.shape == (M_size, M_size) 300 | M = M.to(dtype=t.float64) 301 | 302 | B_expanded = B.mT[None, None, :] # [1, 1, n_kfc, d] 303 | L_expanded = L[None, None, :] # [1, 1, n_lora, n_lora] 304 | BGL = B_expanded @ G.to(dtype=B.dtype) @ L_expanded 305 | BGL_vec = BGL.flatten(-2).to(dtype=t.float64) # [batch, n_logits, M_size] 306 | term_2 = s2.pow(2.0) * BGL_vec @ t.linalg.inv(M) @ BGL_vec.mT 307 | assert term_2.shape == (batch_size, n_logits, n_logits) 308 | 309 | var_matrix += term_1 - term_2.to(var_matrix.dtype) 310 | 311 | logging.debug(f"After layer {k}, variance is {var_matrix}") 312 | return var_matrix 313 | -------------------------------------------------------------------------------- /examples/example_usage.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023-24 Maxime Robeyns 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Example usage for Bayesian LoRA. 16 | """ 17 | 18 | import os 19 | import sys 20 | import peft 21 | import hydra 22 | import logging 23 | import importlib 24 | import torch as t 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | from tqdm import tqdm 29 | from torch import Tensor 30 | from typing import Any 31 | from omegaconf import DictConfig 32 | from torch.func import jacrev, functional_call 33 | from torchmetrics import Accuracy, CalibrationError 34 | from transformers.modeling_outputs import ModelOutput 35 | 36 | from bayesian_lora import ( 37 | calculate_kronecker_factors, 38 | cholesky_decompose_small_factors, 39 | model_evidence, 40 | variance, 41 | stable_cholesky, 42 | ) 43 | from utils import dsets 44 | from utils.loggers import setup_loggers 45 | from utils.setup_llm import setup_llm 46 | from bayesian_lora.main import jacobian_mean 47 | 48 | 49 | @hydra.main( 50 | version_base="1.3", 51 | config_path="configs", 52 | config_name="example_usage", 53 | ) 54 | def main(cfg: DictConfig): 55 | # 56 | # 1. 
Load configuration from Hydra 57 | # 58 | device = "cuda:0" 59 | setup_loggers(cfg) 60 | os.makedirs(cfg.paths.output_dir, exist_ok=True) 61 | 62 | # 63 | # 2. Load PEFT model and dataset 64 | # 65 | model, tokenizer, gen_cfg = setup_llm(**cfg.llm) 66 | # model = model.to(device) 67 | dset_class: dsets.ClassificationDataset = getattr(dsets, cfg.dset.name) 68 | dset = dset_class(tokenizer, add_space=cfg.llm.add_space) 69 | 70 | # 71 | # 3. Do MAP training 72 | # 73 | train_loader = dset.loader( 74 | is_sc=cfg.llm.is_sc, # sequence to sequence model? 75 | batch_size=cfg.dset.train_bs, # training batch size 76 | split=cfg.dset.train_split, # training split name in dset 77 | subset_size=cfg.dset.train_subset, # train on subset? (-1 = no subset) 78 | ) 79 | map_param_path = f"{cfg.paths.output_dir}/MAP_params.pth" 80 | grad_steps, epoch = 0, 0 81 | if not os.path.exists(map_param_path) or cfg.run_every_step: 82 | # setup optimiser 83 | opt_cfg = dict(cfg.opt) 84 | # add prior / regularization for MAP objective: 85 | opt_cfg |= {"weight_decay": 1 / cfg.prior_var} 86 | optclass = getattr( 87 | importlib.import_module(opt_cfg.pop("module")), 88 | opt_cfg.pop("classname"), 89 | ) 90 | opt = optclass(model.parameters(), **opt_cfg) 91 | logging.info("Training MAP parameters") 92 | while grad_steps < cfg.train_steps: 93 | epoch += 1 94 | logging.info( 95 | f"Beginning epoch {epoch} (step {grad_steps} of {cfg.train_steps})" 96 | ) 97 | for batch in tqdm(train_loader, disable=not cfg.use_tqdm, file=sys.stdout): 98 | opt.zero_grad() 99 | prompts, classes, _ = batch 100 | inputs = tokenizer(prompts, **cfg.tokenizer_run_kwargs).to(device) 101 | logits = model(**inputs).logits[:, -1, dset.target_ids.squeeze(-1)] 102 | # loss = F.cross_entropy(logits[:, -1], targets.to(device)) 103 | loss = F.cross_entropy(logits, classes.to(device)) 104 | assert not t.isnan(loss).any(), "NaN in loss for MAP training." 105 | loss.backward() 106 | opt.step() 107 | grad_steps += 1 108 | if not grad_steps < cfg.train_steps: 109 | break 110 | logging.info(f"Saving MAP parameters after finetuning to {map_param_path}") 111 | model.save_pretrained(map_param_path) 112 | else: 113 | logging.info(f"Loading MAP parameters from {map_param_path}") 114 | del model 115 | llm_params = dict(cfg.llm) | {"use_peft": False} 116 | model, _, _ = setup_llm(**llm_params) 117 | model = peft.PeftModel.from_pretrained(model, map_param_path, is_trainable=True) 118 | model = model.to(device) 119 | 120 | # 121 | # 4. Evaluate the log likelihood 122 | # 123 | ll_path = f"{cfg.paths.output_dir}/ll.pth" 124 | if not os.path.exists(ll_path) or cfg.run_every_step: 125 | logging.info("Evaluating the MAP log likelihood") 126 | val_loader = dset.loader( 127 | is_sc=cfg.llm.is_sc, 128 | batch_size=cfg.dset.eval_bs, 129 | split=cfg.dset.eval_split, 130 | subset_size=cfg.dset.eval_subset, 131 | ) 132 | LL = 0.0 133 | with t.no_grad(), t.inference_mode(): 134 | for batch in tqdm(val_loader, disable=not cfg.use_tqdm, file=sys.stdout): 135 | prompts, classes, _ = batch 136 | inputs = tokenizer(prompts, **cfg.tokenizer_run_kwargs).to(device) 137 | logits = model(**inputs).logits[:, -1, dset.target_ids.squeeze(-1)] 138 | probs = logits.softmax(-1) 139 | LL += probs.gather(1, classes[:, None].to(device)).log().sum() 140 | t.save(LL, ll_path) 141 | else: 142 | logging.info(f"Loading LL from {ll_path}") 143 | LL = t.load(ll_path) 144 | 145 | # 146 | # 5. 
Calculate the (low-rank) Kronecker factors 147 | # 148 | def fwd_call(model: nn.Module, batch: Any) -> t.Tensor: 149 | prompts, _, _ = batch 150 | tok_kwargs = dict(cfg.tokenizer_run_kwargs) | { 151 | "padding": True, 152 | "return_tensors": "pt", 153 | } 154 | inputs = tokenizer(prompts, **tok_kwargs).to(device) 155 | outputs = model(**inputs) 156 | logits = ( 157 | outputs.logits[:, dset.target_ids.squeeze(-1)] 158 | if cfg.llm.is_sc 159 | else outputs.logits[:, -1, dset.target_ids.squeeze(-1)] 160 | ) 161 | logits = logits.softmax(-1) 162 | return logits 163 | 164 | kfac_path = f"{cfg.paths.output_dir}/kronecker_factors.pth" 165 | if not os.path.exists(kfac_path) or cfg.run_every_step: 166 | logging.info("Computing the low-rank Kronecker factors") 167 | factors = calculate_kronecker_factors( 168 | model, 169 | fwd_call, 170 | train_loader, 171 | cfg.n_kfac, 172 | cfg.lr_threshold, 173 | ["lora"], 174 | use_tqdm=cfg.use_tqdm, 175 | ) 176 | # Calculate Cholesky decomposition of the smaller factors 177 | factors = cholesky_decompose_small_factors( 178 | factors, cfg.lr_threshold, device, t.float32 179 | ) 180 | t.save({"factors": factors}, kfac_path) 181 | else: 182 | logging.info(f"Loading low-rank Kronecker factors from {kfac_path}") 183 | kfactors = t.load(kfac_path) 184 | factors = kfactors["factors"] 185 | 186 | # 187 | # 6. Use the marginal likelihood to optimise the prior variance 188 | # 189 | prior_path = f"{cfg.paths.output_dir}/prior_params.pth" 190 | if not os.path.exists(prior_path) or cfg.run_every_step: 191 | logging.info("Optimising priors using marginal likelihood") 192 | s2 = t.tensor(cfg.prior_var, requires_grad=True) 193 | opt = t.optim.AdamW([s2], lr=1e-2) 194 | 195 | for _ in range(200): 196 | opt.zero_grad() 197 | loss = model_evidence( 198 | model, LL, factors, cfg.llm.peft.r, cfg.n_kfac, s2 199 | ).log() 200 | loss.backward() 201 | t.nn.utils.clip_grad_norm_(s2, 1.0) 202 | opt.step() 203 | t.save({"s2": s2}, prior_path) 204 | logging.info(f"prior variance is: {s2.item()}") 205 | else: 206 | logging.info("Loading prior parameters (optimised using marginal likelihood)") 207 | priors = t.load(prior_path) 208 | s2 = priors["s2"] 209 | 210 | # 211 | # 7. Make linearized predictions 212 | # 213 | # NOTE: we need to re-load the model without using BitsAndBytes (our 214 | # gradient calculations sadly don't currently work with 4/8-bit 215 | # quantization) 216 | del model 217 | t.cuda.empty_cache() 218 | logging.info("Doing linearized prediction") 219 | 220 | cfg.llm.use_quant = False # because our gradient calcs don't support bnb 221 | cfg.llm.use_peft = False # due to the quirk in loading PEFT models 222 | # cfg.llm.model_kwargs.attn_implementation = "sdpa" 223 | model, tokenizer, gen_cfg = setup_llm(**cfg.llm) 224 | model = peft.PeftModel.from_pretrained(model, map_param_path, is_trainable=True) 225 | model = model.to(device) 226 | 227 | val_loader = dset.loader( 228 | is_sc=cfg.llm.is_sc, 229 | batch_size=cfg.dset.eval_bs, 230 | split=cfg.dset.eval_split, 231 | subset_size=cfg.dset.eval_subset, 232 | ) 233 | 234 | pred_mu = [] 235 | pred_var = [] 236 | pred_logits = [] 237 | 238 | total_loss = 0 239 | metric_kwargs = {"task": "multiclass", "num_classes": dset.n_labels} 240 | acc_metric = Accuracy(**metric_kwargs).to(device) 241 | ece_metric = CalibrationError(**metric_kwargs).to(device) 242 | 243 | def output_callback(outputs: ModelOutput) -> Tensor: 244 | """Post process model outputs. 
245 | 246 | This function will be passed the results of model(**batch_inputs), and 247 | should return the relevant logits. For multiple-choice tasks, this is 248 | the class logits, but for full next-token prediction, this would just 249 | be all the logits. 250 | """ 251 | # Get the last token for CausalLM 252 | logits = outputs.logits if cfg.llm.is_sc else outputs.logits[:, -1] 253 | # Select the logits corresponding to our target classes 254 | target_logits = logits[:, dset.target_ids.squeeze(-1)] 255 | return target_logits 256 | 257 | with t.no_grad(): 258 | for batch in tqdm(val_loader, disable=not cfg.use_tqdm, file=sys.stdout): 259 | prompts, classes, _ = batch 260 | classes = classes.to(device) 261 | 262 | batch_inputs = tokenizer(prompts, **cfg.tokenizer_run_kwargs).to(device) 263 | 264 | # Predict the output logit locations 265 | jacobian, f_mu = jacobian_mean( 266 | model, batch_inputs, output_callback=output_callback 267 | ) 268 | pred_mu.append(f_mu.clone().cpu()) 269 | 270 | # Predict the output logit variances 271 | f_var = variance( 272 | batch_inputs, 273 | jacobian, 274 | factors, 275 | s2, 276 | dset.n_labels, 277 | cfg.llm.peft.r, 278 | cfg.n_kfac, 279 | device, 280 | ) 281 | pred_var.append(f_var.clone().cpu()) 282 | 283 | # Sample logits from a Gaussian parametrised by f_mu, f_var 284 | L = stable_cholesky(f_var) 285 | samples = 100_000 286 | f_mu = f_mu.expand(samples, *f_mu.shape) 287 | L = L.expand(samples, *L.shape) 288 | eps = t.randn_like(f_mu).unsqueeze(-1) 289 | logits = f_mu[..., None] + L @ eps 290 | logits = logits.squeeze(-1).mean(0) 291 | 292 | pred_logits.append(logits.cpu()) 293 | total_loss += F.cross_entropy(logits, classes).item() 294 | acc_metric(logits, classes) 295 | ece_metric(logits, classes) 296 | 297 | loss = total_loss / len(val_loader) 298 | acc = acc_metric.compute().item() 299 | ece = ece_metric.compute().item() 300 | 301 | logging.info(f"NLL: {loss:.5f}, ACC: {acc:.5f}, ECE: {ece:.5f}") 302 | 303 | output_path = f"{cfg.paths.output_dir}/predicted_logits.pth" 304 | t.save( 305 | {"pred_mu": pred_mu, "pred_var": pred_var, "pred_logits": pred_logits}, 306 | output_path, 307 | ) 308 | 309 | logging.info("Successfully finished.") 310 | 311 | 312 | if __name__ == "__main__": 313 | main() 314 | -------------------------------------------------------------------------------- /bayesian_lora/kfac.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023-24 Maxime Robeyns 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Kronecker-factored approximate curvature methods. 
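Informally, for each targeted linear layer the Fisher / GGN block is
approximated by a Kronecker product

    F ~= A (x) S,    A = sum_i a_i a_i^T,    S = sum_i g_i g_i^T,

where the a_i are the layer's input activations and the g_i are the gradients
of the loss with respect to the layer's outputs (both taken at the last token
position, as in the hooks below). Factors whose side length exceeds
``lr_threshold`` are optionally kept in low-rank ``(d, n_kfac)`` form via an
incremental SVD rather than being materialised in full.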
16 | """ 17 | 18 | import sys 19 | import logging 20 | import torch as t 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | 24 | from tqdm import tqdm 25 | from torch import Tensor 26 | from typing import Any, Callable 27 | from jaxtyping import Float 28 | from contextlib import contextmanager 29 | from torch.linalg import LinAlgError 30 | from torch.utils.data import DataLoader 31 | from torch.utils.hooks import RemovableHandle 32 | 33 | 34 | __all__ = [ 35 | "stable_cholesky", 36 | "calculate_kronecker_factors", 37 | "activation_t", 38 | "outgrad_t", 39 | "KFAC_t", 40 | ] 41 | 42 | 43 | # Utility functions =========================================================== 44 | 45 | 46 | def stabilise( 47 | K: Float[Tensor, "... d d"], mult_eps: float, abs_eps: float 48 | ) -> Float[Tensor, "... d d"]: 49 | """Multiply and add the stabilisation terms `mult_eps` and `abs_eps`""" 50 | eye = t.eye(K.shape[-1], dtype=K.dtype, device=K.device) 51 | return K * (1.0 + mult_eps * eye) + abs_eps * eye 52 | 53 | 54 | def stable_cholesky( 55 | K: Float[Tensor, "... d d"], 56 | mult_eps: float = 1e-8, 57 | abs_eps: float = 1e-8, 58 | max_tries: int = 1000, 59 | ) -> Float[Tensor, "... d d"]: 60 | for i in range(max_tries): 61 | try: 62 | scaled_mult_eps = mult_eps * (1.1**i) 63 | scaled_abs_eps = abs_eps * (1.1**i) 64 | # L = t.linalg.cholesky(stabilise(K, i * mult_eps, i * abs_eps)) 65 | L = t.linalg.cholesky(stabilise(K, scaled_mult_eps, scaled_abs_eps)) 66 | return L 67 | except LinAlgError: 68 | logging.debug(f"Chokesky decomposition failed ({i})") 69 | continue 70 | raise ValueError(f"Could not calculate Cholesky decomposition of {K}") 71 | 72 | 73 | def incremental_svd( 74 | A: Float[Tensor, "d r"], 75 | a: Float[Tensor, "batch d"], 76 | dtype: t.dtype = t.float64, 77 | n_kfac: int | None = None, 78 | ) -> Float[Tensor, "d n_kfac"]: 79 | """Calculate a low-rank estimate of a big [d, d] tensor, without 80 | materialising this full matrix. 81 | 82 | Args: 83 | A: The accumulated low-rank factor 84 | a: a new batch of points 85 | dtype: the datatype to use for the svd 86 | n_kfac: (optional) specify the rank of the resulting factor. If 87 | omitted, we use `r` from the `A` argument. 88 | You may choose to set this higher than the final rank during the 89 | accumulation. 90 | """ 91 | if n_kfac is None: 92 | n_kfac = A.size(-1) 93 | a = a.to(dtype=dtype) 94 | A_prime = t.hstack((A, a.T)) # [d, r+batch] 95 | U, S, _ = t.linalg.svd(A_prime, full_matrices=False) 96 | return U[:, :n_kfac] @ t.diag(S[:n_kfac]) 97 | 98 | 99 | # K-FAC Section =============================================================== 100 | 101 | # Datatype for Kronecker factors. l_in, l_out refers to the number of input and 102 | # output features of layer l in the network, respectively. Note, if the layer 103 | # has a bias, then l_in will in fact be l_in + 1. 104 | activation_t = Float[Tensor, "l_in l_in_or_n_kfac"] 105 | outgrad_t = Float[Tensor, "l_out l_out_or_n_kfac"] 106 | KFAC_t = dict[str, tuple[activation_t, outgrad_t]] 107 | 108 | # We add hooks to the nn.Module (e.g. AutoModelForCausalLM, PeftModel, etc) to 109 | # keep track of each layer's input activations and output gradients. 110 | # The following context managers let you enable / disable these hooks without 111 | # removing them. 
112 | _hooks_enabled: bool = True 113 | _input_hooks_disabled: bool = False 114 | 115 | 116 | @contextmanager 117 | def hooks_disabled(): 118 | """ 119 | Allows the hooks for both the activations and output gradients to be 120 | temporarily disabled within a context. 121 | 122 | Example: 123 | >>> with hooks_disabled(): 124 | >>> output = model(**inputs) 125 | >>> output.loss.backward() 126 | """ 127 | global _hooks_enabled 128 | orig_state = _hooks_enabled 129 | _hooks_enabled = False 130 | try: 131 | yield 132 | finally: 133 | _hooks_enabled = orig_state 134 | 135 | 136 | @contextmanager 137 | def disable_input_hooks(): 138 | """ 139 | Disables just the input activation hooks but keeps the output gradient 140 | hooks. Useful when calculating a 'pullback' metric. 141 | 142 | Example: 143 | >>> with disable_input_hooks(): 144 | >>> loss.backward() 145 | """ 146 | global _input_hooks_disabled 147 | orig_state = _input_hooks_disabled 148 | _input_hooks_disabled = True 149 | try: 150 | yield 151 | finally: 152 | _input_hooks_disabled = orig_state 153 | 154 | 155 | def save_input_hook( 156 | module_name: str, 157 | activations: dict[str, tuple[activation_t, bool]], 158 | n_kfac: int | None, 159 | lr_threshold: int, 160 | has_bias: bool = False, 161 | svd_dtype: t.dtype = t.float64, 162 | ): 163 | """A closure which returns a new hook to capture a layer's input 164 | activations. 165 | 166 | Args: 167 | module_name: name used as a key for the 'activations' dictionary. While 168 | torch modules themselves can be hashed, using a string here makes 169 | the Kronecker factors more portable. 170 | activations: a mapping from layer / module name to input activation 171 | Kronecker factor, and a boolean flag indicating whether it is 172 | low-rank 173 | n_kfac: the rank we use if we're using a low rank appproximation to 174 | this Kronecker factor 175 | lr_threshold: if the side length `l_in+1` exceeds this threshold, and 176 | n_kfac is not None, then treat the factor as low-rank 177 | has_bias: does this layer have a bias? 178 | svd_dtype: dtype to cast tensors to for SVD calculations 179 | """ 180 | 181 | def input_hook(_module: nn.Module, pos_args: tuple[t.Tensor]) -> None: 182 | if not _hooks_enabled or _input_hooks_disabled: 183 | return 184 | # Select the first positional argument given to this layer (the input 185 | # activation), then the last token in the token sequence [:, -1]. `a` 186 | # should be a [batch, l_in] tensor. 
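        # In the full-rank branch below, this factor is accumulated as the
        # batch-summed outer product A = Σ_b a_b a_bᵀ (equivalently
        # t.einsum("bi,bj->ij", a, a)); the low-rank branch instead keeps a
        # rank-`n_kfac` approximation via `incremental_svd`.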
187 | a: Float[Tensor, "batch l_in"] = pos_args[0].detach().clone()[:, -1] 188 | if has_bias: 189 | a = t.hstack((a, t.ones_like(a[:, :1]))) 190 | assert a.dim() == 2 191 | if a.size(-1) < lr_threshold or n_kfac is None: 192 | # We're not using a low-rank approximation for this factor; just do 193 | # the outer product of the activations for all the elements in the 194 | # batch, then sum along batch dim: 195 | A = (a[..., None] @ a[:, None]).sum(0) 196 | if module_name not in activations.keys(): 197 | activations[module_name] = A, False 198 | else: 199 | A_tmp = activations[module_name][0] 200 | activations[module_name] = A_tmp + A, False 201 | else: 202 | if module_name not in activations.keys(): 203 | # Initialise a correctly sized matrix of 0s 204 | activations[module_name] = ( 205 | t.zeros(a.size(-1), n_kfac, device=a.device, dtype=svd_dtype), 206 | True, 207 | ) 208 | A = incremental_svd(activations[module_name][0], a, svd_dtype, n_kfac) 209 | activations[module_name] = A, True 210 | 211 | return input_hook 212 | 213 | 214 | def save_output_grad_hook( 215 | module_name: str, 216 | output_grads: dict[str, tuple[outgrad_t, bool]], 217 | n_kfac: int | None, 218 | lr_threshold: int, 219 | svd_dtype: t.dtype = t.float64, 220 | ): 221 | """A closure which returns a new hook to capture a layer's output 222 | gradients. 223 | 224 | Args: 225 | module_name: name used as a key for the 'output_grads' dictionary. 226 | While modules themselves can be hashed, this makes the Kronecker 227 | factors more portable. 228 | output_grads: mapping from layer / module name to the output gradient 229 | Kronecker factor, and a flag indicating whether it is low-rank. 230 | n_kfac: the rank we use if we're using a low rank appproximation to 231 | this Kronecker factor 232 | lr_threshold: if the side length `l_in+1` exceeds this threshold, and 233 | n_kfac is not none, treat the factor as low-rank 234 | svd_dtype: dtype to cast tensors to for SVD calculations 235 | """ 236 | 237 | def output_grad_hook(_module: nn.Module, _, out_pos_grad: tuple[Tensor]) -> None: 238 | if not _hooks_enabled: 239 | return 240 | # Select the gradient of the first positional output of this layer, 241 | # then the last token in the token sequence [:, -1]. `s` should be a 242 | # [batch, l_out] tensor. 243 | s: Float[Tensor, "batch l_out"] = out_pos_grad[0].detach().clone()[:, -1] 244 | if s.size(-1) < lr_threshold or n_kfac is None: 245 | # We're not using a low-rank approximation for this factor; just do 246 | # the outer product of the output gradients for all elements in the 247 | # batch, then sum along the batch dimension; giving an [l_out, 248 | # l_out] tensor. 
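            # Equivalently, S = Σ_b s_b s_bᵀ, i.e. t.einsum("bi,bj->ij", s, s).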
249 | S = (s[..., None] @ s[:, None]).sum(0) 250 | if module_name not in output_grads.keys(): 251 | output_grads[module_name] = S, False 252 | else: 253 | S_tmp = output_grads[module_name][0] 254 | output_grads[module_name] = S_tmp + S, False 255 | else: 256 | # Never reach this branch if n_kfac is None 257 | if module_name not in output_grads.keys(): 258 | # Initialise a correctly sized matrix of 0s 259 | output_grads[module_name] = ( 260 | t.zeros(s.size(-1), n_kfac, device=s.device, dtype=s.dtype), 261 | True, 262 | ) 263 | S = incremental_svd(output_grads[module_name][0], s, svd_dtype, n_kfac) 264 | output_grads[module_name] = S, True 265 | 266 | return output_grad_hook 267 | 268 | 269 | def register_hooks( 270 | model: nn.Module, 271 | activations: dict[str, tuple[activation_t, bool]], 272 | output_grads: dict[str, tuple[outgrad_t, bool]], 273 | target_module_keywords: list[str], 274 | n_kfac: int | None = 10, 275 | lr_threshold: int = 100, 276 | exclude_bias: bool = False, 277 | ) -> list[RemovableHandle]: 278 | """Registers the activation and output gradient hooks. 279 | 280 | Args: 281 | model: the ``nn.Module`` on which to attach the hooks (usually the full 282 | model) 283 | activations: dictionary in which to store the parameter activations and 284 | flag indicating whether this factor is low-rank. 285 | The side length is ``l_in`` (i.e. equal to the number of input 286 | features in layer ``l``), or ``l_in + 1`` if there is a bias. The 287 | last dimension is ``n_kfac`` if ``l_in >= lr_threshold``. 288 | output_grads: dictionary in which to store the output gradients and a 289 | flag indicating whether this factor is low-rank. The side length 290 | ``l_out`` is equal to the number of output features of layer ``l`` 291 | (regardless of the presence of a bias; unlike the activations). The 292 | last dimension is ``n_kfac`` if ``l_out >= lr_threshold`` 293 | target_module_keywords: a list of the network modules to include in the 294 | GGN. Note, only nn.Linear layers are currently supported. 295 | n_kfac: the rank we use to approximate large Kronecker factors. If set 296 | to None, we treat all factors as full rank (turns off the lr 297 | approximation). 298 | lr_threshold: threshold beyond which to consider a layer's input to be 299 | wide (to decide whether to approximate a Kronecker factor as low 300 | rank). LoRA layers with a wide input (e.g. LoRA-A) will have a 301 | low-rank approximation of their activation Kronecker factor, A, 302 | while LoRA layers with a narrow input (e.g. LoRA-B) will have a 303 | low-rank approximation of their output-gradient Kronecker factor, 304 | S. 
305 |         exclude_bias: whether to ignore bias terms (just consider the weights)
306 | 
307 |     Returns:
308 |         - a list of hooks (for later removal)
309 |     """
310 |     hooks: list[RemovableHandle] = []
311 |     for name, module in model.named_modules():
312 |         if any([kw in name for kw in target_module_keywords]) and (
313 |             isinstance(module, nn.Linear)
314 |         ):
315 |             logging.debug(f"Registering hook for module {name}")
316 |             if name in activations.keys() or name in output_grads.keys():
317 |                 raise Exception(f"Module of same name {name} already registered")
318 |             has_bias = hasattr(module, "bias") and module.bias is not None
319 |             if exclude_bias:
320 |                 # NOTE: this is a hack that should be removed
321 |                 has_bias = False
322 |             fwd_hook = module.register_forward_pre_hook(
323 |                 save_input_hook(name, activations, n_kfac, lr_threshold, has_bias)
324 |             )
325 |             bwd_hook = module.register_full_backward_hook(
326 |                 save_output_grad_hook(name, output_grads, n_kfac, lr_threshold)
327 |             )
328 |             hooks.extend((fwd_hook, bwd_hook))
329 | 
330 |     return hooks
331 | 
332 | 
333 | def remove_hooks(hooks: list[RemovableHandle]) -> None:
334 |     """Remove the hooks from the module.
335 | 
336 |     Args:
337 |         hooks: list of hooks, returned from `register_hooks`
338 |     """
339 |     while len(hooks):
340 |         hooks.pop().remove()
341 | 
342 | 
343 | def calculate_kronecker_factors(
344 |     model: nn.Module,
345 |     forward_call: Callable[[nn.Module, Any], Float[Tensor, "batch n_classes"]],
346 |     loader: DataLoader,
347 |     n_kfac: int | None = None,
348 |     lr_threshold: int = 512,
349 |     target_module_keywords: list[str] = [""],
350 |     exclude_bias: bool = False,
351 |     use_tqdm: bool = False,
352 | ) -> KFAC_t:
353 |     """
354 |     Calculate the Kronecker factors, (A, S) for the likelihood, used to
355 |     approximate the GGN / Fisher.
356 | 
357 |     Args:
358 |         model: the model for which we are calculating the Kronecker factors.
359 |             Note that it needn't have LoRA adapters.
360 |         forward_call: A function which accepts a batch from the provided data
361 |             loader, and returns the parameters of the model's predictive
362 |             distribution, as a ``Tensor``. Usually this contains the logits
363 |             over each class label.
364 |         loader: a data loader for the dataset with which to calculate the
365 |             curvature / Kronecker factors.
366 |         n_kfac: an optional integer rank to use for a low-rank approximation of
367 |             large Kronecker factors. If this is ``None``, then no low-rank
368 |             approximations are used.
369 |         lr_threshold: the threshold beyond which the side length of a Kronecker
370 |             factor is considered large and a low-rank approximation is applied.
371 |         target_module_keywords: a list of keywords which identify the network
372 |             modules whose parameters we want to include in the Hessian
373 |             calculation. This is particularly useful when working with LoRA
374 |             adapters. By default, this is ``[""]``; targeting every module.
375 |         exclude_bias: whether to ignore bias terms (NOTE: this is a hack and
376 |             should not be used)
377 |         use_tqdm: whether to show progress with ``tqdm``.
378 | 
379 |     Warning:
380 |         This function has only been implemented for nn.Linear. Models
381 |         implemented using Conv1D (e.g. GPT2) will sadly not work for now.
382 | 
383 |     Warning:
384 |         Your data loader should not have a partial final batch, since this will
385 |         result in an incorrect expectation. You can drop the final batch with
386 |         `drop_last=True` in a standard PyTorch DataLoader.
387 | 
388 |     Examples:
389 | 
390 |         Full-rank Kronecker factor calculation.
391 | 392 | >>> factors = calculate_kronecker_factors( 393 | >>> model, fwd_call, loader 394 | >>> ) 395 | 396 | Low-rank Kronecker factors on LoRA adaptors with inputs 397 | 398 | >>> factors = calculate_kronecker_factors( 399 | >>> model, fwd_call, loader, n_kfac=10, 400 | >>> lr_threshold=512, target_module_keywords=["lora"], 401 | >>> ) 402 | 403 | 404 | Returns: 405 | A dictionary containing the Kronecker factors; keyed by module name, 406 | containing a tuple (A, S) with the activation factor (A) as the first 407 | element, and the output gradient factor (S) as the second element. 408 | """ 409 | model = model.train() 410 | 411 | activations: dict[str, tuple[t.Tensor, bool]] = dict() 412 | output_grads: dict[str, tuple[t.Tensor, bool]] = dict() 413 | 414 | hooks = register_hooks( 415 | model, 416 | activations, 417 | output_grads, 418 | target_module_keywords, 419 | n_kfac, 420 | lr_threshold, 421 | exclude_bias=exclude_bias, 422 | ) 423 | 424 | for batch in tqdm(loader, disable=not use_tqdm, file=sys.stdout): 425 | model.zero_grad() 426 | logits = forward_call(model, batch) 427 | assert logits.dim() == 2 428 | 429 | # TODO: support other model distributions. 430 | # We use the mean reduction here in the losses to keep the magnitude of 431 | # the output gradients invariant to the batch size used when 432 | # calculating the Kronecker factors. 433 | if logits.size(-1) == 1: 434 | # We are dealing with binary outputs 435 | with t.no_grad(): 436 | sampled_ys = t.bernoulli(logits.sigmoid()).view(-1) 437 | pullback_loss = F.binary_cross_entropy_with_logits( 438 | logits.squeeze(-1), sampled_ys, reduction="mean" 439 | ) 440 | else: 441 | with t.no_grad(): 442 | sampled_ys = t.multinomial(logits.softmax(-1), 1).view(-1) 443 | pullback_loss = F.cross_entropy(logits, sampled_ys, reduction="mean") 444 | 445 | with disable_input_hooks(): 446 | pullback_loss.backward() 447 | 448 | t.cuda.empty_cache() 449 | 450 | remove_hooks(hooks) 451 | factors: KFAC_t = dict() 452 | for k, (A, A_lr) in activations.items(): 453 | S, S_lr = output_grads[k] 454 | # Average only the non low-rank factors. 
455 |         if not S_lr:
456 |             S /= len(loader)
457 |         if not A_lr:
458 |             A /= len(loader)
459 |         factors[k] = A, S
460 | 
461 |     return factors
462 | 
--------------------------------------------------------------------------------
/documentation/source/_static/block_diagonal.svg:
--------------------------------------------------------------------------------
(SVG markup omitted: "block_diagonal" figure by Maxime Robeyns.)
--------------------------------------------------------------------------------
/examples/utils/dsets.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023-24 Maxime Robeyns
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Convenience wrappers around classification datasets
16 | """
17 | import torch as t
18 | 
19 | from abc import abstractmethod
20 | from enum import Enum
21 | from datasets import load_dataset
22 | from transformers import AutoTokenizer
23 | from collections import OrderedDict
24 | from torch.utils.data import DataLoader, Dataset
25 | 
26 | # List of datasets available in this module
27 | dsets = [
28 |     "boolq",
29 |     "obqa",
30 |     "arc",
31 |     "winogrande",
32 |     "cqa",
33 |     "cola",
34 |     "mnli",
35 |     "mrpc",
36 |     "qnli",
37 |     "qqp",
38 |     "rte",
39 |     "sst2",
40 |     "wnli",
41 | ]
42 | 
43 | 
44 | class ClassificationDataset:
45 |     """
46 |     An abstract base dataset for sequence classification problems. Multiple
47 |     choice QA problems could also be made a subclass of this class with an
48 |     appropriate collation / formatting.
49 |     """
50 | 
51 |     def __init__(
52 |         self,
53 |         dset,
54 |         tokenizer,
55 |         n_labels: int,
56 |         preamble: str = "",
57 |         add_space: bool = False,
58 |         numerical: bool = True,
59 |         boolean: bool = False,
60 |         few_shot: bool = False,
61 |         max_len: int = 1024,
62 |     ):
63 |         """
64 |         Args:
65 |             dset: The loaded Dataset
66 |             tokenizer: The model tokenizer
67 |             n_labels: The number of labels / classes for each question
68 |             preamble: Preamble for general pre-trained / 'CausalLM' models
69 |             add_space: Add an explicit space suffix between preamble and answer tokens.
70 |             numerical: whether labels are numerical (0, 1, etc.) or alphabetical (A, B, etc.)
71 | boolean: whether the labels are boolean (0, 1) 72 | few_shot: whether to use few-shot prompting (if available) 73 | max_len: the matximum length of the prompt. 74 | """ 75 | self.dset = dset 76 | self.n_labels = n_labels 77 | self.preamble = preamble 78 | self.add_space = add_space 79 | self.tokenizer = tokenizer 80 | self.numerical = numerical 81 | self.few_shot = few_shot 82 | self.max_len = max_len 83 | 84 | spc = " " if self.add_space else "" 85 | 86 | # 1. Build up the token IDS of the class labels. 87 | if numerical and boolean: 88 | raise ValueError("Question type cannot be both numerical and boolean") 89 | if boolean: 90 | labels = [f"{spc}True", f"{spc}False"] 91 | elif numerical: 92 | labels = [f"{spc}{i}" for i in range(self.n_labels)] 93 | else: # alphabetical 94 | labels = [f"{spc}{chr(ord('A')+i)}" for i in range(self.n_labels)] 95 | self.target_ids = tokenizer( 96 | labels, return_tensors="pt", add_special_tokens=False 97 | ).input_ids[:, -1:] 98 | assert ( 99 | self.target_ids.unique().numel() == self.target_ids.numel() 100 | ), "Target label IDS are not unique! Try changing add_space or numerical." 101 | 102 | # 2. Get a mapping from the label indices (e.g. 0, 1, 2, etc.) to the 103 | # target token ids from above (e.g. 345, 346, etc.). 104 | # That is; {(0, 345), (1, 346), etc} 105 | self.label_idx2target_id = OrderedDict( 106 | [(i, self.target_ids[i]) for i in range(n_labels)] 107 | ) 108 | self.target_id2label_idx = OrderedDict( 109 | [(self.target_ids[i], i) for i in range(n_labels)] 110 | ) 111 | 112 | @abstractmethod 113 | def sc_collate_fn(self, batch): 114 | """Collate function for sequence classification models""" 115 | raise NotImplementedError 116 | 117 | def sc_loader(self, dset: Dataset, *args, **kwargs) -> DataLoader: 118 | """Returns the dataloader for sequence classification models""" 119 | return t.utils.data.DataLoader( 120 | dset, collate_fn=self.sc_collate_fn, *args, **kwargs 121 | ) 122 | 123 | @abstractmethod 124 | def clm_collate_fn(self, batch): 125 | """Collate function for causal language models""" 126 | raise NotImplementedError 127 | 128 | def clm_loader(self, dset: Dataset, *args, **kwargs) -> DataLoader: 129 | """Returns the dataloader for causal language models""" 130 | return t.utils.data.DataLoader( 131 | dset, collate_fn=self.clm_collate_fn, *args, **kwargs 132 | ) 133 | 134 | def loader( 135 | self, 136 | *args, 137 | is_sc: bool = False, 138 | split: str = "train", 139 | subset_size: int = -1, 140 | subset_seed: int | None = 42, 141 | grad_acc_steps: int = 1, 142 | drop_last: bool = True, 143 | **kwargs, 144 | ): 145 | if subset_size > 0: 146 | subset_size = ( 147 | len(self.dset[split]) 148 | if len(self.dset[split]) < subset_size 149 | else subset_size 150 | ) 151 | dset = self.dset[split].shuffle(seed=subset_seed).select(range(subset_size)) 152 | else: 153 | dset = self.dset[split] 154 | 155 | kwargs = {"batch_size": 32, "drop_last": drop_last} | kwargs 156 | assert ( 157 | kwargs["batch_size"] % grad_acc_steps == 0 158 | ), "batch size must be divisible by gradient accumulation steps" 159 | kwargs["batch_size"] = kwargs["batch_size"] // grad_acc_steps 160 | 161 | if is_sc: 162 | return self.sc_loader(dset, *args, **kwargs) 163 | else: 164 | return self.clm_loader(dset, *args, **kwargs) 165 | 166 | 167 | class BoolQDataset(ClassificationDataset): 168 | def __init__( 169 | self, 170 | tokenizer: AutoTokenizer, 171 | add_space: bool = True, 172 | few_shot: bool = False, 173 | max_len: int = 256, 174 | ): 175 | dset = 
load_dataset("boolq") 176 | 177 | prompt = """Read the passage below and answer the question with the words 'True' or 'False'. 178 | 179 | Passage: {passage} 180 | Question: {question} 181 | Answer (True or False):""" 182 | 183 | super().__init__( 184 | dset, 185 | tokenizer, 186 | 2, 187 | prompt, 188 | add_space, 189 | numerical=False, 190 | boolean=True, 191 | few_shot=few_shot, 192 | max_len=max_len, 193 | ) 194 | 195 | def clm_collate_fn(self, batch): 196 | prompts = [ 197 | self.preamble.format( 198 | passage=e["passage"][-self.max_len :], question=e["question"] 199 | ) 200 | for e in batch 201 | ] 202 | classes = t.tensor([int(e["answer"]) for e in batch]) 203 | targets = t.cat([self.label_idx2target_id[c.item()] for c in classes]) 204 | return prompts, classes, targets 205 | 206 | def sc_collate_fn(self, batch): 207 | prompts = [ 208 | self.preamble.format( 209 | passage=e["passage"][-self.max_len :], question=e["question"] 210 | ) 211 | for e in batch 212 | ] 213 | classes = t.tensor([int(e["answer"]) for e in batch]) 214 | return prompts, classes, None 215 | 216 | 217 | boolq = BoolQDataset 218 | 219 | 220 | class OBQADataset(ClassificationDataset): 221 | def __init__( 222 | self, 223 | tokenizer: AutoTokenizer, 224 | add_space: bool = True, 225 | few_shot: bool = False, 226 | max_len: int = 1024, 227 | ): 228 | dset = load_dataset("openbookqa", "main") 229 | prompt = self.few_shot_preamble if few_shot else self.zero_shot_preamble 230 | super().__init__( 231 | dset, 232 | tokenizer, 233 | 4, 234 | prompt, 235 | add_space, 236 | numerical=False, 237 | few_shot=few_shot, 238 | max_len=max_len, 239 | ) 240 | 241 | few_shot_preamble = """Return the label of the correct answer for each question below. 242 | 243 | The sun is responsible for 244 | Choices: 245 | A) puppies learning new tricks 246 | B) children growing up and getting old 247 | C) flowers wilting in a vase 248 | D) plants sprouting, blooming and wilting 249 | Answer: D 250 | 251 | What doesn't eliminate waste? 252 | A) plants 253 | B) robots 254 | C) mushrooms 255 | D) bacteria 256 | Answer: B 257 | 258 | {question} 259 | Choices: 260 | {choices} 261 | Answer:""" 262 | 263 | zero_shot_preamble = """Return the label of the correct answer for the question below. 
264 | 
265 | Question: {question}
266 | Choices:
267 | {choices}
268 | Answer:"""
269 | 
270 |     def _format_prompts(self, batch):
271 |         prompts = []
272 |         for e in batch:
273 |             choices = "\n".join(
274 |                 [
275 |                     f"{l}) {c}"
276 |                     for c, l in zip(e["choices"]["text"], e["choices"]["label"])
277 |                 ]
278 |             )
279 |             prompts.append(
280 |                 self.preamble.format(question=e["question_stem"], choices=choices)
281 |             )
282 |         return prompts
283 | 
284 |     def clm_collate_fn(self, batch):
285 |         prompts = self._format_prompts(batch)
286 |         # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt")
287 |         # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()}
288 |         classes = t.tensor([ord(e["answerKey"]) - ord("A") for e in batch])
289 |         targets = t.cat([self.label_idx2target_id[c.item()] for c in classes])
290 |         return prompts, classes, targets
291 | 
292 |     def sc_collate_fn(self, batch):
293 |         prompts = self._format_prompts(batch)
294 |         # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt")
295 |         # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()}
296 |         classes = t.tensor([ord(e["answerKey"]) - ord("A") for e in batch])
297 |         return prompts, classes, None
298 | 
299 | 
300 | obqa = OBQADataset
301 | 
302 | 
303 | class ArcSplit(Enum):
304 |     C = "ARC-Challenge"
305 |     E = "ARC-Easy"
306 | 
307 | 
308 | class ARCDataset(ClassificationDataset):
309 |     def __init__(
310 |         self,
311 |         tokenizer: AutoTokenizer,
312 |         name: ArcSplit = ArcSplit.E,
313 |         add_space: bool = True,
314 |         few_shot: bool = False,
315 |         max_len: int = 4096,
316 |     ):
317 |         dset = load_dataset("ai2_arc", name.value)
318 |         prompt = self.few_shot_preamble if few_shot else self.zero_shot_preamble
319 |         super().__init__(
320 |             dset,
321 |             tokenizer,
322 |             5,
323 |             prompt,
324 |             add_space,
325 |             numerical=False,
326 |             few_shot=few_shot,
327 |             max_len=max_len,
328 |         )
329 | 
330 |     few_shot_preamble = """Return the label of the correct answer for each question below.
331 | 
332 | Which two body systems are directly involved in movement?
333 | Choices:
334 | A) muscular and skeletal
335 | B) digestive and muscular
336 | C) skeletal and respiratory
337 | D) respiratory and digestive
338 | Answer: A
339 | 
340 | {question}
341 | Choices:
342 | {choices}
343 | Answer:"""
344 | 
345 |     zero_shot_preamble = """Return the label of the correct answer for the question below.
346 | 347 | Question: {question} 348 | Choices: 349 | {choices} 350 | Answer:""" 351 | 352 | def _format_prompts(self, batch): 353 | prompts = [] 354 | for e in batch: 355 | choices = "\n".join( 356 | [ 357 | f"{l}) {c}" 358 | for c, l in zip(e["choices"]["text"], e["choices"]["label"]) 359 | ] 360 | ) 361 | prompts.append( 362 | self.preamble.format(question=e["question"], choices=choices) 363 | ) 364 | return prompts 365 | 366 | def clm_collate_fn(self, batch): 367 | prompts = self._format_prompts(batch) 368 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 369 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 370 | classes_alpha = t.tensor([ord(e["answerKey"]) - ord("A") for e in batch]) 371 | classes_num = [] 372 | for e in batch: 373 | try: 374 | classes_num.append(int(e["answerKey"]) - 1) 375 | except: 376 | classes_num.append(-1) 377 | # classes_num = t.tensor([int(e["answerKey"]) - 1 for e in batch]) 378 | classes = t.where(classes_alpha < 0, t.tensor(classes_num), classes_alpha) 379 | targets = t.cat([self.label_idx2target_id[c.item()] for c in classes]) 380 | return prompts, classes, targets 381 | 382 | def sc_collate_fn(self, batch): 383 | prompts = self._format_prompts(batch) 384 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 385 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 386 | classes = t.tensor([ord(e["answerKey"]) - ord("A") for e in batch]) 387 | return prompts, classes, None 388 | 389 | 390 | arc = ARCDataset 391 | 392 | 393 | class WinograndeSplit(Enum): 394 | XS = "winogrande_xs" 395 | S = "winogrande_s" 396 | M = "winogrande_m" 397 | L = "winogrande_l" 398 | XL = "winogrande_xl" 399 | 400 | 401 | class WinograndeDataset(ClassificationDataset): 402 | def __init__( 403 | self, 404 | tokenizer: AutoTokenizer, 405 | name: WinograndeSplit = WinograndeSplit.S, 406 | add_space: bool = True, 407 | few_shot: bool = False, 408 | max_len: int = 4096, 409 | ): 410 | dset = load_dataset("winogrande", name.value) 411 | prompt = self.few_shot_preamble if few_shot else self.zero_shot_preamble 412 | super().__init__( 413 | dset, 414 | tokenizer, 415 | 2, 416 | prompt, 417 | add_space, 418 | numerical=False, 419 | few_shot=few_shot, 420 | max_len=max_len, 421 | ) 422 | 423 | few_shot_preamble = """Return the label of the correct answer for each question below. 424 | 425 | Adam put handwash only clothes in the washer but Aaron washed them by hand as _ was lazy. 426 | Choices: 427 | A) Adam 428 | B) Aaron 429 | Answer: A 430 | 431 | Steven proudly showed Michael the mangoes he grew himself all this summer. _ is astonished. 432 | Choices: 433 | A) Stephen 434 | B) Michael 435 | Answer: B 436 | 437 | {question} 438 | Choices: 439 | {choices} 440 | Answer:""" 441 | 442 | zero_shot_preamble = """Return the label of the correct answer for the question below. 
443 | 444 | Question: {question} 445 | Choices: 446 | {choices} 447 | Answer:""" 448 | 449 | def _format_prompts(self, batch): 450 | prompts = [] 451 | for e in batch: 452 | choices = f"A) {e['option1']}\nB) {e['option2']}" 453 | prompts.append( 454 | self.preamble.format(question=e["sentence"], choices=choices) 455 | ) 456 | return prompts 457 | 458 | def clm_collate_fn(self, batch): 459 | prompts = self._format_prompts(batch) 460 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 461 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 462 | classes = t.tensor([int(e["answer"]) - 1 for e in batch]) 463 | targets = t.cat([self.label_idx2target_id[c.item()] for c in classes]) 464 | return prompts, classes, targets 465 | 466 | def sc_collate_fn(self, batch): 467 | prompts = self._format_prompts(batch) 468 | # prompts = [e["sentence"] for e in batch] 469 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 470 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 471 | classes = t.tensor([int(e["answer"]) - 1 for e in batch]) 472 | return prompts, classes, None 473 | 474 | 475 | winogrande = WinograndeDataset 476 | 477 | 478 | class CommonsenseQADataset(ClassificationDataset): 479 | def __init__( 480 | self, 481 | tokenizer: AutoTokenizer, 482 | add_space: bool = True, 483 | few_shot: bool = True, 484 | max_len=4096, 485 | ): 486 | dset = load_dataset("commonsense_qa") 487 | super().__init__( 488 | dset, 489 | tokenizer, 490 | 5, 491 | self.few_shot_preamble if few_shot else self.zero_shot_preamble, 492 | add_space, 493 | numerical=False, 494 | few_shot=few_shot, 495 | max_len=max_len, 496 | ) 497 | 498 | # few-shot preamble 499 | few_shot_preamble = """Answer the questions below correctly. 500 | 501 | Question: What do people aim to do at work? 502 | Choices: 503 | A) complete job 504 | B) learn from each other 505 | C) kill animals 506 | D) wear hats 507 | E) talk to each other 508 | Answer: A 509 | 510 | Question: Where do adults use glue sticks? 
511 | Choices: 512 | A) classroom 513 | B) desk drawer 514 | C) at school 515 | D) office 516 | E) kitchen draw 517 | Answer: D 518 | 519 | Question: {question} 520 | Choices: 521 | {choices} 522 | Answer:""" 523 | 524 | zero_shot_preamble = """Answer the multiple choice question below by returning the answer label (A to E) 525 | 526 | Question: {question} 527 | Choices: 528 | {choices} 529 | Answer:""" 530 | 531 | def _format_prompts(self, batch): 532 | prompts = [] 533 | for e in batch: 534 | choices = "\n".join( 535 | [ 536 | f"{l}) {c}" 537 | for l, c in zip(e["choices"]["label"], e["choices"]["text"]) 538 | ] 539 | ) 540 | prompts.append( 541 | self.preamble.format(question=e["question"], choices=choices) 542 | ) 543 | return prompts 544 | 545 | def clm_collate_fn(self, batch): 546 | prompts = self._format_prompts(batch) 547 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 548 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 549 | classes = t.tensor([ord(e["answerKey"]) - ord("A") for e in batch]) 550 | targets = t.cat([self.label_idx2target_id[c.item()] for c in classes]) 551 | return prompts, classes, targets 552 | 553 | def sc_collate_fn(self, batch): 554 | prompts = self._format_prompts(batch) 555 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 556 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 557 | classes = t.tensor([ord(e["answerKey"]) - ord("A") for e in batch]) 558 | return prompts, classes, None 559 | 560 | 561 | cqa = CommonsenseQADataset 562 | 563 | 564 | class CoLADataset(ClassificationDataset): 565 | def __init__( 566 | self, 567 | tokenizer: AutoTokenizer, 568 | add_space: bool = True, 569 | few_shot: bool = False, 570 | max_len: int = 4096, 571 | ): 572 | dset = load_dataset("glue", "cola") 573 | prompt = self.few_shot_preamble if few_shot else self.zero_shot_preamble 574 | super().__init__( 575 | dset, 576 | tokenizer, 577 | 2, 578 | prompt, 579 | add_space, 580 | numerical=True, 581 | few_shot=few_shot, 582 | max_len=max_len, 583 | ) 584 | 585 | few_shot_preamble = """For each sentence below, indicate whether it is grammatically acceptable (1) or unacceptable (0). 586 | 587 | Sentence: If you had eaten more, you would want less. 588 | Answer: 1 589 | 590 | Sentence: As you eat the most, you want the least. 591 | Answer: 0 592 | 593 | Sentence: {sentence} 594 | Answer:""" 595 | 596 | zero_shot_preamble = """For each sentence below, indicate whether it is grammatically acceptable (1) or unacceptable (0). 
597 | 598 | Sentence: {sentence} 599 | Answer:""" 600 | 601 | def clm_collate_fn(self, batch): 602 | # No need to use self.add_space here since we add it to the target tokens 603 | prompts = [self.preamble.format(sentence=e["sentence"]) for e in batch] 604 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 605 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 606 | classes = t.tensor([e["label"] for e in batch]) 607 | targets = t.cat([self.label_idx2target_id[e["label"]] for e in batch]) 608 | return prompts, classes, targets 609 | 610 | def sc_collate_fn(self, batch): 611 | prompts = [e["sentence"] for e in batch] 612 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 613 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 614 | classes = t.tensor([e["label"] for e in batch]) 615 | return prompts, classes, None 616 | 617 | 618 | cola = CoLADataset 619 | 620 | 621 | class MNLIDataset(ClassificationDataset): 622 | def __init__( 623 | self, 624 | tokenizer: AutoTokenizer, 625 | add_space: bool = True, 626 | few_shot: bool = False, 627 | max_len: int = 4096, 628 | ): 629 | dset = load_dataset("glue", "mnli") 630 | prompt = self.few_shot_preamble if few_shot else self.zero_shot_preamble 631 | super().__init__( 632 | dset, 633 | tokenizer, 634 | 3, 635 | prompt, 636 | add_space, 637 | numerical=True, 638 | few_shot=few_shot, 639 | max_len=max_len, 640 | ) 641 | 642 | few_shot_preamble = """For each premise below, indicate whether the hypothesis entails (0), is neutral towards (1) or contradicts (2) the premise. 643 | 644 | Hypothesis: Buffet and a la carte available. 645 | Premise: It has a buffet. 646 | Answer: 0 647 | 648 | Hypothesis: He had never felt better. 649 | Premise: The medicine he had taken had worked well. 650 | Answer: 1 651 | 652 | Hypothesis: Oh, what a fool I feel! 653 | Premise: I am beyond proud 654 | Answer: 2 655 | 656 | Hypothesis: {hypothesis} 657 | Premise: {premise} 658 | Answer:""" 659 | 660 | zero_shot_preamble = """For each premise below, indicate whether the hypothesis entails (0), is neutral towards (1) or contradicts (2) the premise. 661 | 662 | Hypothesis: Buffet and a la carte available. 663 | Premise: It has a buffet. 664 | Answer: 0 665 | 666 | Hypothesis: He had never felt better. 667 | Premise: The medicine he had taken had worked well. 668 | Answer: 1 669 | 670 | Hypothesis: Oh, what a fool I feel! 
671 | Premise: I am beyond proud
672 | Answer: 2
673 | 
674 | Hypothesis: {hypothesis}
675 | Premise: {premise}
676 | Answer:"""
677 | 
678 |     def clm_collate_fn(self, batch):
679 |         # No need to use self.add_space here since we add it to the target tokens
680 |         prompts = [
681 |             self.preamble.format(hypothesis=e["hypothesis"], premise=e["premise"])
682 |             for e in batch
683 |         ]
684 |         # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt")
685 |         # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()}
686 |         classes = t.tensor([e["label"] for e in batch])
687 |         targets = t.cat([self.label_idx2target_id[e["label"]] for e in batch])
688 |         return prompts, classes, targets
689 | 
690 |     def sc_collate_fn(self, batch):
691 |         prompts = [e["hypothesis"] + " " + e["premise"] for e in batch]
692 |         # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt")
693 |         # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()}
694 |         classes = t.tensor([e["label"] for e in batch])
695 |         return prompts, classes, None
696 | 
697 | 
698 | mnli = MNLIDataset
699 | 
700 | 
701 | class MRPCDataset(ClassificationDataset):
702 |     def __init__(
703 |         self,
704 |         tokenizer: AutoTokenizer,
705 |         add_space: bool = True,
706 |         few_shot: bool = False,
707 |         max_len: int = 4096,
708 |     ):
709 |         dset = load_dataset("glue", "mrpc")
710 |         prompt = self.few_shot_preamble if few_shot else self.zero_shot_preamble
711 |         super().__init__(
712 |             dset,
713 |             tokenizer,
714 |             2,
715 |             prompt,
716 |             add_space,
717 |             numerical=True,
718 |             few_shot=few_shot,
719 |             max_len=max_len,
720 |         )
721 | 
722 |     few_shot_preamble = """For each pair of sentences below, indicate whether Sentence 1 is equivalent (1) or not equivalent (0) to Sentence 2.
723 | 
724 | Sentence 1: Yucaipa owned Dominick's before selling the chain to Safeway in 1998 for $2.5 billion.
725 | Sentence 2: Yucaipa bought Dominick's in 1995 for $693 million and sold it to Safeway for $1.8 billion in 1998.
726 | Answer: 0
727 | 
728 | Sentence 1: Amrozi accused his brother, whom he called "the witness", of deliberately distorting his evidence.
729 | Sentence 2: Referring to him as only "the witness", Amrozi accused his brother of deliberately distorting his evidence.
730 | Answer: 1
731 | 
732 | Sentence 1: {sentence_1}
733 | Sentence 2: {sentence_2}
734 | Answer:"""
735 | 
736 |     zero_shot_preamble = """For each pair of sentences below, indicate whether Sentence 1 is equivalent (1) or not equivalent (0) to Sentence 2.
737 | 738 | Sentence 1: {sentence_1} 739 | Sentence 2: {sentence_2} 740 | Answer:""" 741 | 742 | def clm_collate_fn(self, batch): 743 | # No need to use self.add_space here since we add it to the target tokens 744 | prompts = [ 745 | self.preamble.format(sentence_1=e["sentence1"], sentence_2=e["sentence2"]) 746 | for e in batch 747 | ] 748 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 749 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 750 | classes = t.tensor([e["label"] for e in batch]) 751 | targets = t.cat([self.label_idx2target_id[e["label"]] for e in batch]) 752 | return prompts, classes, targets 753 | 754 | def sc_collate_fn(self, batch): 755 | prompts = [e["sentence1"] + " " + e["sentence2"] for e in batch] 756 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 757 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 758 | classes = t.tensor([e["label"] for e in batch]) 759 | return prompts, classes, None 760 | 761 | 762 | mrpc = MRPCDataset 763 | 764 | 765 | class QNLIDataset(ClassificationDataset): 766 | def __init__( 767 | self, 768 | tokenizer: AutoTokenizer, 769 | add_space: bool = True, 770 | few_shot: bool = False, 771 | max_len: int = 4096, 772 | ): 773 | dset = load_dataset("glue", "qnli") 774 | prompt = self.few_shot_preamble if few_shot else self.zero_shot_preamble 775 | super().__init__( 776 | dset, 777 | tokenizer, 778 | 2, 779 | prompt, 780 | add_space, 781 | numerical=True, 782 | few_shot=few_shot, 783 | max_len=max_len, 784 | ) 785 | 786 | few_shot_preamble = """For each sentence below, indicate whether it entails (0) or does not entail (1) the associated question. 787 | 788 | Question: Which collection of minor poems are sometimes attributed to Virgil? 789 | Sentence: A number of minor poems, collected in the Appendix Vergiliana, are sometimes attributed to him. 790 | Answer: 0 791 | 792 | Question: What was the highest order of species n land? 793 | Sentence: The climate was much more humid than the Triassic, and as a result, the world was very tropical. 794 | Answer: 1 795 | 796 | Question: {question} 797 | Sentence: {sentence} 798 | Answer:""" 799 | 800 | zero_shot_preamble = """For each sentence below, indicate whether it entails (0) or does not entail (1) the associated question. 
801 | 
802 | Question: {question}
803 | Sentence: {sentence}
804 | Answer:"""
805 | 
806 |     def clm_collate_fn(self, batch):
807 |         # No need to use self.add_space here since we add it to the target tokens
808 |         prompts = [
809 |             self.preamble.format(question=e["question"], sentence=e["sentence"])
810 |             for e in batch
811 |         ]
812 |         # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt")
813 |         # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()}
814 |         classes = t.tensor([e["label"] for e in batch])
815 |         targets = t.cat([self.label_idx2target_id[e["label"]] for e in batch])
816 |         return prompts, classes, targets
817 | 
818 |     def sc_collate_fn(self, batch):
819 |         prompts = [e["question"] + " " + e["sentence"] for e in batch]
820 |         # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt")
821 |         # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()}
822 |         classes = t.tensor([e["label"] for e in batch])
823 |         return prompts, classes, None
824 | 
825 | 
826 | qnli = QNLIDataset
827 | 
828 | 
829 | class QQPDataset(ClassificationDataset):
830 |     def __init__(
831 |         self,
832 |         tokenizer: AutoTokenizer,
833 |         add_space: bool = True,
834 |         few_shot: bool = False,
835 |         max_len: int = 4096,
836 |     ):
837 |         dset = load_dataset("glue", "qqp")
838 |         prompt = self.few_shot_preamble if few_shot else self.zero_shot_preamble
839 |         super().__init__(
840 |             dset,
841 |             tokenizer,
842 |             2,
843 |             prompt,
844 |             add_space,
845 |             numerical=True,
846 |             few_shot=few_shot,
847 |             max_len=max_len,
848 |         )
849 | 
850 |     few_shot_preamble = """For each pair of questions below, indicate whether the first is a duplicate (1) or not a duplicate (0) of the second.
851 | 
852 | Question 1: How is air traffic controlled?
853 | Question 2: How do you become an air traffic controller?
854 | Answer: 0
855 | 
856 | Question 1: What are the coolest Android hacks and tricks you know?
857 | Question 2: What are some cool hacks for Android phones?
858 | Answer: 1
859 | 
860 | Question 1: {question_1}
861 | Question 2: {question_2}
862 | Answer:"""
863 | 
864 |     zero_shot_preamble = """For each pair of questions below, indicate whether the first is a duplicate (1) or not a duplicate (0) of the second.
865 | 866 | Question 1: {question_1} 867 | Question 2: {question_2} 868 | Answer:""" 869 | 870 | def clm_collate_fn(self, batch): 871 | # No need to use self.add_space here since we add it to the target tokens 872 | prompts = [ 873 | self.preamble.format(question_1=e["question1"], question_2=e["question2"]) 874 | for e in batch 875 | ] 876 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 877 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 878 | classes = t.tensor([e["label"] for e in batch]) 879 | targets = t.cat([self.label_idx2target_id[e["label"]] for e in batch]) 880 | return prompts, classes, targets 881 | 882 | def sc_collate_fn(self, batch): 883 | prompts = [e["question1"] + " " + e["question2"] for e in batch] 884 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 885 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 886 | classes = t.tensor([e["label"] for e in batch]) 887 | return prompts, classes, None 888 | 889 | 890 | qqp = QQPDataset 891 | 892 | 893 | class RTEDataset(ClassificationDataset): 894 | def __init__( 895 | self, 896 | tokenizer: AutoTokenizer, 897 | add_space: bool = True, 898 | few_shot: bool = False, 899 | max_len: int = 4096, 900 | ): 901 | dset = load_dataset("glue", "rte") 902 | prompt = self.few_shot_preamble if few_shot else self.zero_shot_preamble 903 | super().__init__( 904 | dset, 905 | tokenizer, 906 | 2, 907 | prompt, 908 | add_space, 909 | numerical=True, 910 | few_shot=few_shot, 911 | max_len=max_len, 912 | ) 913 | 914 | few_shot_preamble = """For each pair of sentences below, indicate whether the second entails (0) or does not entail (1) the first. 915 | 916 | Sentence 1: Edward VIII became King in January of 1936 and abdicated in December. 917 | Sentence 2: King Edward VIII abdicated in December 1936. 918 | Answer: 0 919 | 920 | Sentence 1: No Weapons of Mass Destruction Found in Iraq Yet. 921 | Sentence 2: Weapons of Mass Destruction Found in Iraq. 922 | Answer: 1 923 | 924 | Sentence 1: {sentence_1} 925 | Sentence 2: {sentence_2} 926 | Answer:""" 927 | 928 | zero_shot_preamble = """For each pair of sentences below, indicate whether the second entails (0) or does not entail (1) the first. 
929 | 930 | Sentence 1: {sentence_1} 931 | Sentence 2: {sentence_2} 932 | Answer:""" 933 | 934 | def clm_collate_fn(self, batch): 935 | # No need to use self.add_space here since we add it to the target tokens 936 | prompts = [ 937 | self.preamble.format(sentence_1=e["sentence1"], sentence_2=e["sentence2"]) 938 | for e in batch 939 | ] 940 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 941 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 942 | classes = t.tensor([e["label"] for e in batch]) 943 | targets = t.cat([self.label_idx2target_id[e["label"]] for e in batch]) 944 | return prompts, classes, targets 945 | 946 | def sc_collate_fn(self, batch): 947 | prompts = [e["sentence1"] + " " + e["sentence2"] for e in batch] 948 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 949 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 950 | classes = t.tensor([e["label"] for e in batch]) 951 | return prompts, classes, None 952 | 953 | 954 | rte = RTEDataset 955 | 956 | 957 | class SST2Dataset(ClassificationDataset): 958 | def __init__( 959 | self, 960 | tokenizer: AutoTokenizer, 961 | add_space: bool = True, 962 | few_shot: bool = False, 963 | max_len: int = 4096, 964 | ): 965 | dset = load_dataset("glue", "sst2") 966 | prompt = self.few_shot_preamble if few_shot else self.zero_shot_preamble 967 | super().__init__( 968 | dset, 969 | tokenizer, 970 | 2, 971 | prompt, 972 | add_space, 973 | numerical=True, 974 | few_shot=few_shot, 975 | max_len=max_len, 976 | ) 977 | 978 | few_shot_preamble = """For each sentence below, indicate whether the sentiment is negative (0) or positive (1). 979 | 980 | Sentence: a depressed fifteen-year-old 's suicidal poetry 981 | Answer: 0 982 | 983 | Sentence: the greatest musicians 984 | Answer: 1 985 | 986 | Sentence: {sentence} 987 | Answer:""" 988 | 989 | zero_shot_preamble = """For each sentence below, indicate whether the sentiment is negative (0) or positive (1). 
990 | 991 | Sentence: {sentence} 992 | Answer:""" 993 | 994 | def clm_collate_fn(self, batch): 995 | # No need to use self.add_space here since we add it to the target tokens 996 | prompts = [self.preamble.format(sentence=e["sentence"]) for e in batch] 997 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 998 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 999 | classes = t.tensor([e["label"] for e in batch]) 1000 | targets = t.cat([self.label_idx2target_id[e["label"]] for e in batch]) 1001 | return prompts, classes, targets 1002 | 1003 | def sc_collate_fn(self, batch): 1004 | prompts = [e["sentence"] for e in batch] 1005 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 1006 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 1007 | classes = t.tensor([e["label"] for e in batch]) 1008 | return prompts, classes, None 1009 | 1010 | 1011 | sst2 = SST2Dataset 1012 | 1013 | 1014 | class WNLIDataset(ClassificationDataset): 1015 | def __init__( 1016 | self, 1017 | tokenizer: AutoTokenizer, 1018 | add_space: bool = True, 1019 | few_shot: bool = False, 1020 | max_len: int = 4096, 1021 | ): 1022 | dset = load_dataset("glue", "wnli") 1023 | prompt = self.few_shot_preamble if few_shot else self.zero_shot_preamble 1024 | super().__init__( 1025 | dset, 1026 | tokenizer, 1027 | 2, 1028 | prompt, 1029 | add_space, 1030 | numerical=False, 1031 | few_shot=few_shot, 1032 | max_len=max_len, 1033 | ) 1034 | 1035 | few_shot_preamble = """For each pair of sentences below, indicate whether the second entails (1) or does not entail (0) the first. 1036 | 1037 | Sentence 1: Steve follows Fred's example in everything. He influences him hugely. 1038 | Sentence 2: Steve influences him hugely. 1039 | Answer: 0 1040 | 1041 | Sentence 1: The police arrested all of the gang members. They were trying to stop the drug trade in the neighborhood. 1042 | Sentence 2: The police were trying to stop the drug trade in the neighborhood. 1043 | Answer: 1 1044 | 1045 | Sentence 1: {sentence_1} 1046 | Sentence 2: {sentence_2} 1047 | Answer:""" 1048 | 1049 | zero_shot_preamble = """For each pair of sentences below, indicate whether the second entails (1) or does not entail (0) the first. 1050 | 1051 | Sentence 1: {sentence_1} 1052 | Sentence 2: {sentence_2} 1053 | Answer:""" 1054 | 1055 | def clm_collate_fn(self, batch): 1056 | # No need to use self.add_space here since we add it to the target tokens 1057 | prompts = [ 1058 | self.preamble.format(sentence_1=e["sentence1"], sentence_2=e["sentence2"]) 1059 | for e in batch 1060 | ] 1061 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 1062 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 1063 | classes = t.tensor([e["label"] for e in batch]) 1064 | targets = t.cat([self.label_idx2target_id[e["label"]] for e in batch]) 1065 | return prompts, classes, targets 1066 | 1067 | def sc_collate_fn(self, batch): 1068 | prompts = [e["sentence1"] + " " + e["sentence2"] for e in batch] 1069 | # prompts = self.tokenizer(prompts, padding=True, return_tensors="pt") 1070 | # prompts = {k: v[:, -self.max_len :] for k, v in prompts.items()} 1071 | classes = t.tensor([e["label"] for e in batch]) 1072 | return prompts, classes, None 1073 | 1074 | 1075 | wnli = WNLIDataset 1076 | --------------------------------------------------------------------------------