├── llm ├── utils │ ├── __init__.py │ └── generate_utils.py ├── eval │ ├── third_party │ │ ├── __init__.py │ │ └── calibration.py │ ├── __init__.py │ ├── common.py │ └── classifier.py ├── __init__.py ├── distributed │ ├── __init__.py │ └── accelerate.py ├── models │ ├── mpnet.py │ ├── peft │ │ ├── __init__.py │ │ ├── prompt_tuning.py │ │ ├── classifier_head.py │ │ ├── utils.py │ │ ├── lora.py │ │ └── temperature_scaling.py │ ├── __init__.py │ ├── mlp.py │ ├── llm_model_utils.py │ ├── llama3.py │ ├── registry.py │ ├── openai.py │ ├── qwen.py │ ├── mistral.py │ └── llama2.py ├── datasets │ ├── offline │ │ ├── __init__.py │ │ ├── mmlu_pro_offline.py │ │ ├── offline_logits.py │ │ ├── modiste.py │ │ ├── mmlu_offline.py │ │ └── offline.py │ ├── hf │ │ ├── __init__.py │ │ ├── anli.py │ │ ├── gsm8k.py │ │ ├── mmlu_pro.py │ │ ├── wsc.py │ │ ├── truthful_qa.py │ │ ├── obqa.py │ │ ├── commonsense_qa.py │ │ ├── boolq.py │ │ ├── copa.py │ │ ├── hellaswag.py │ │ ├── cosmos_qa.py │ │ ├── story_cloze.py │ │ ├── trec.py │ │ ├── math_qa.py │ │ ├── arc.py │ │ ├── cb.py │ │ ├── snli.py │ │ ├── multirc.py │ │ └── piqa.py │ ├── __init__.py │ ├── utils.py │ ├── registry.py │ └── llm_data_utils.py ├── trainer │ ├── __init__.py │ ├── utils.py │ └── fine_tune.py └── random.py ├── notebooks ├── .gitignore ├── results │ ├── .gitignore │ ├── eval_all_20k_uniform-13b_chat.csv │ └── oe_sampling.csv ├── viz_features.ipynb ├── cleanup_offline.ipynb ├── user-study.ipynb └── mmlu_pro.ipynb ├── requirements-base.txt ├── requirements-dev.txt ├── assets └── explainer_figure.png ├── requirements.txt ├── Dockerfile ├── pyproject.toml ├── experiments ├── publish.py ├── evaluate_logits.py ├── temperature_scale.py ├── fine_tune.py ├── classifier_tune.py ├── embedding_tune.py ├── calibration_tune.py └── train_embedding_only.py └── .gitignore /llm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llm/eval/third_party/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /notebooks/.gitignore: -------------------------------------------------------------------------------- 1 | *_tmp*.ipynb -------------------------------------------------------------------------------- /notebooks/results/.gitignore: -------------------------------------------------------------------------------- 1 | *raw.csv -------------------------------------------------------------------------------- /llm/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | -------------------------------------------------------------------------------- /requirements-base.txt: -------------------------------------------------------------------------------- 1 | fire 2 | numpy 3 | pandas 4 | scikit-learn 5 | scipy 6 | tqdm 7 | wandb 8 | -------------------------------------------------------------------------------- /llm/eval/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import evaluate_dataset 2 | 3 | __all__ = ["evaluate_dataset"] 4 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | black[jupyter] 2 | ipywidgets 3 | jupyterlab 4 | nvitop 5 | seaborn 6 | 
palettable 7 | -------------------------------------------------------------------------------- /assets/explainer_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/activatedgeek/calibration-tuning/HEAD/assets/explainer_figure.png -------------------------------------------------------------------------------- /llm/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | from .accelerate import Accelerator, AcceleratorState 2 | 3 | __all__ = [ 4 | "Accelerator", 5 | "AcceleratorState", 6 | ] 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | bitsandbytes 3 | datasets 4 | openai 5 | peft 6 | sentencepiece 7 | sentence-transformers 8 | tiktoken 9 | torch 10 | torchvision 11 | transformers<=4.42 12 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG MLDEV_VERSION=cu124-py311 2 | 3 | FROM nyumerics/ml:${MLDEV_VERSION} 4 | 5 | ADD . /tmp/code 6 | RUN --mount=type=cache,target=/root/.cache/pip \ 7 | pushd /tmp/code && \ 8 | micromamba-run pip install --no-cache-dir . && \ 9 | micromamba-run pip uninstall -y llm-calibration && \ 10 | rm -r /tmp/code && \ 11 | popd 12 | -------------------------------------------------------------------------------- /llm/models/mpnet.py: -------------------------------------------------------------------------------- 1 | from sentence_transformers import SentenceTransformer 2 | 3 | from .registry import register_model 4 | 5 | 6 | def get_mpnet(**_): 7 | return SentenceTransformer("sentence-transformers/multi-qa-mpnet-base-dot-v1") 8 | 9 | 10 | @register_model 11 | def mpnet_mqa(*args, **kwargs): 12 | return get_mpnet(*args, **kwargs) 13 | -------------------------------------------------------------------------------- /llm/datasets/offline/__init__.py: -------------------------------------------------------------------------------- 1 | def __setup(): 2 | from importlib import import_module 3 | 4 | for n in [ 5 | "combined", 6 | "mmlu_offline", 7 | "mmlu_pro_offline", 8 | "modiste", 9 | "offline_logits", 10 | "offline", 11 | ]: 12 | import_module(f".{n}", __name__) 13 | 14 | 15 | __setup() 16 | -------------------------------------------------------------------------------- /llm/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import WandbConfigUpdateCallback 2 | from .fine_tune import FineTuner 3 | from .calibration_tune import CalibrationTuner 4 | from .classification_tune import ClassificationTuner 5 | from .embedding_tune import EmbeddingTuner 6 | 7 | __all__ = [ 8 | "WandbConfigUpdateCallback", 9 | "FineTuner", 10 | "CalibrationTuner", 11 | "ClassificationTuner", 12 | "EmbeddingTuner", 13 | ] 14 | -------------------------------------------------------------------------------- /llm/trainer/utils.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from transformers.trainer import TrainerCallback 3 | 4 | 5 | class WandbConfigUpdateCallback(TrainerCallback): 6 | def __init__(self, **config): 7 | self._config = config 8 | 9 | def on_train_begin(self, _args, state, _control, **_): 10 | if state.is_world_process_zero: 11 | 
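# Update the W&B run config only on the main process, so each value is recorded once per run.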
wandb.config.update(self._config, allow_val_change=True) 12 | 13 | del self._config 14 | -------------------------------------------------------------------------------- /llm/models/peft/__init__.py: -------------------------------------------------------------------------------- 1 | from .lora import get_lora_model, use_adapter 2 | from .prompt_tuning import get_prompt_tuning_model 3 | from .temperature_scaling import ( 4 | get_temperature_scale_model, 5 | get_temperature_head, 6 | ) 7 | from .classifier_head import get_classifier_head 8 | 9 | __all__ = [ 10 | "get_prompt_tuning_model", 11 | "get_lora_model", 12 | "use_adapter", 13 | "get_temperature_scale_model", 14 | "get_temperature_head", 15 | "get_classifier_head", 16 | ] 17 | -------------------------------------------------------------------------------- /llm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .registry import register_model, get_model, get_model_attrs, list_models 2 | 3 | 4 | __all__ = [ 5 | "register_model", 6 | "get_model", 7 | "get_model_attrs", 8 | "list_models", 9 | ] 10 | 11 | 12 | def __setup(): 13 | from importlib import import_module 14 | 15 | for n in [ 16 | "llama2", 17 | "llama3", 18 | "mlp", 19 | "mistral", 20 | "mpnet", 21 | "openai", 22 | "qwen", 23 | ]: 24 | import_module(f".{n}", __name__) 25 | 26 | 27 | __setup() 28 | -------------------------------------------------------------------------------- /llm/models/peft/prompt_tuning.py: -------------------------------------------------------------------------------- 1 | from peft import TaskType, PromptTuningConfig, PromptTuningInit, get_peft_model 2 | 3 | from .utils import get_peft_model_from_checkpoint 4 | 5 | 6 | def get_prompt_tuning_model(model, peft_dir=None, num_tokens=8): 7 | if peft_dir is not None: 8 | return get_peft_model_from_checkpoint(model, peft_dir) 9 | 10 | peft_config = PromptTuningConfig( 11 | task_type=TaskType.CAUSAL_LM, 12 | prompt_tuning_init=PromptTuningInit.RANDOM, 13 | num_virtual_tokens=num_tokens, 14 | ) 15 | model = get_peft_model(model, peft_config) 16 | 17 | return model 18 | -------------------------------------------------------------------------------- /llm/datasets/hf/__init__.py: -------------------------------------------------------------------------------- 1 | def __setup(): 2 | from importlib import import_module 3 | 4 | for n in [ 5 | "anli", 6 | "arc", 7 | "boolq", 8 | "cb", 9 | "commonsense_qa", 10 | "copa", 11 | "cosmos_qa", 12 | "gsm8k", 13 | "hellaswag", 14 | "math_qa", 15 | "mmlu", 16 | "mmlu_pro", 17 | "multirc", 18 | "obqa", 19 | "piqa", 20 | "sciq", 21 | "siqa", 22 | "snli", 23 | "story_cloze", 24 | "trec", 25 | "truthful_qa", 26 | "winogrande", 27 | "wsc", 28 | ]: 29 | import_module(f".{n}", __name__) 30 | 31 | 32 | __setup() 33 | -------------------------------------------------------------------------------- /llm/models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .registry import register_model 4 | 5 | 6 | def get_classifier(input_size=None, output_size=None, bias=False, **_): 7 | model = nn.Sequential( 8 | nn.Linear(input_size, 256), 9 | nn.ReLU(), 10 | nn.Linear(256, 128), 11 | nn.ReLU(), 12 | nn.Linear(128, 64), 13 | nn.ReLU(), 14 | nn.Linear(64, output_size, bias=bias), 15 | # nn.Linear(input_size, output_size, bias=bias), 16 | ) 17 | 18 | return model 19 | 20 | 21 | @register_model 22 | def mlp_binary(**kwargs): 23 | kwargs.pop("output_size", None) 24 | 
kwargs.pop("bias", None) 25 | return get_classifier(**kwargs, output_size=2, bias=True) 26 | -------------------------------------------------------------------------------- /notebooks/results/eval_all_20k_uniform-13b_chat.csv: -------------------------------------------------------------------------------- 1 | N,acc,unc_acc,unc_auroc,unc_ece,split,seed,model_name,query_peft_dir,prompt_style,mode,log_dir,int8,dataset,ts 2 | 2000,0.7435000538825989,0.8365000486373901,0.782752011913517,0.0421182975769042,validation,137,llama2_13b_chat,/workspace/models/llm-calibration/Llama2-13b_chat-oe/Llama2-13b_chat-all_20k_offline-ct-r02l7ki1,oe,oe_fuzzy_gpt-3.5-turbo-1106,/workspace/logs/deeplearn/llm-calibration/2024-02-01T20-18-10,False,all_20k_uniform,3918.880558860488 3 | 2000,0.7460000514984131,0.6565000414848328,0.6761896782841822,0.0932806276679039,validation,137,llama2_13b_chat,,oe,oe_fuzzy_gpt-3.5-turbo-1106,/workspace/logs/deeplearn/llm-calibration/2024-02-01T20-17-41,False,all_20k_uniform,3356.9550915956497 4 | -------------------------------------------------------------------------------- /llm/models/llm_model_utils.py: -------------------------------------------------------------------------------- 1 | DEFAULT_PAD_TOKEN = "[PAD]" 2 | 3 | 4 | def resize_token_embeddings(tokenizer, model): 5 | extra_token_count = len(tokenizer) - model.get_input_embeddings().weight.data.size( 6 | 0 7 | ) 8 | if extra_token_count > 0: 9 | model.resize_token_embeddings(len(tokenizer)) 10 | 11 | input_embeddings = model.get_input_embeddings().weight.data 12 | 13 | input_embeddings[-extra_token_count:] = input_embeddings[ 14 | :-extra_token_count 15 | ].mean(dim=0, keepdim=True) 16 | 17 | output_embeddings = model.get_output_embeddings().weight.data 18 | 19 | output_embeddings[-extra_token_count:] = output_embeddings[ 20 | :-extra_token_count 21 | ].mean(dim=0, keepdim=True) 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [project] 9 | name = "llm-calibration" 10 | description = "LLM Calibration." 
11 | readme = "README.md" 12 | license = {file = "LICENSE"} 13 | 14 | dynamic = [ 15 | "version", 16 | "dependencies", 17 | "optional-dependencies" 18 | ] 19 | 20 | [tool.setuptools] 21 | include-package-data = false 22 | 23 | [tool.setuptools.dynamic] 24 | version = { attr = "llm.__init__.__version__" } 25 | dependencies = {file = ["requirements-base.txt", "requirements.txt"]} 26 | optional-dependencies.dev = {file = ["requirements-dev.txt"]} 27 | 28 | [tool.setuptools.packages.find] 29 | exclude = [ 30 | "*experiments.*", 31 | "*experiments", 32 | "*notebooks.*", 33 | "*notebooks", 34 | "*scripts.*", 35 | "*scripts" 36 | ] 37 | -------------------------------------------------------------------------------- /llm/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .registry import ( 2 | register_dataset, 3 | get_data_dir, 4 | get_dataset, 5 | get_dataset_attrs, 6 | list_datasets, 7 | ) 8 | from .utils import get_loader, get_num_workers 9 | 10 | from .llm_data_utils import ( 11 | IGNORE_LABEL, 12 | get_token_vec, 13 | LMText, 14 | LabeledStringDataCollator, 15 | ) 16 | from .llm_utils_oe import prepare_uncertainty_query 17 | 18 | 19 | __all__ = [ 20 | "register_dataset", 21 | "get_data_dir", 22 | "get_dataset", 23 | "get_dataset_attrs", 24 | "list_datasets", 25 | "get_loader", 26 | "get_num_workers", 27 | "IGNORE_LABEL", 28 | "LabeledStringDataCollator", 29 | "get_token_vec", 30 | "LMText", 31 | "prepare_uncertainty_query", 32 | ] 33 | 34 | 35 | def __setup(): 36 | from importlib import import_module 37 | 38 | for n in [ 39 | "hf", 40 | "offline", 41 | ]: 42 | import_module(f".{n}", __name__) 43 | 44 | 45 | __setup() 46 | -------------------------------------------------------------------------------- /llm/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, random_split 3 | 4 | 5 | def train_test_split(dataset, test_size=0.2, seed=None): 6 | N = len(dataset) 7 | N_test = int(test_size * N) 8 | N -= N_test 9 | 10 | if seed is not None: 11 | train, test = random_split( 12 | dataset, [N, N_test], generator=torch.Generator().manual_seed(seed) 13 | ) 14 | else: 15 | train, test = random_split(dataset, [N, N_test]) 16 | 17 | return train, test 18 | 19 | 20 | def get_num_workers(num_workers=4): 21 | num_gpus_per_host = torch.cuda.device_count() 22 | if num_gpus_per_host == 0: 23 | return num_workers 24 | return (num_workers + num_gpus_per_host - 1) // num_gpus_per_host 25 | 26 | 27 | def get_loader(dataset, batch_size=128, num_workers=4, accelerator=None, **kwargs): 28 | num_workers = get_num_workers(num_workers=num_workers) 29 | loader = DataLoader( 30 | dataset, batch_size=batch_size, num_workers=num_workers, **kwargs 31 | ) 32 | if accelerator is not None: 33 | loader = accelerator.prepare(loader) 34 | 35 | return loader 36 | -------------------------------------------------------------------------------- /llm/random.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | 6 | class FixedSeed: 7 | def __init__(self, seed): 8 | self.seed = seed 9 | self.np_rng_state = None 10 | self.rand_rng_state = None 11 | self.pt_rng_state = None 12 | self.cuda_rng_state = None 13 | 14 | def __enter__(self): 15 | self.rand_rng_state = random.getstate() 16 | self.np_rng_state = np.random.get_state() 17 | self.pt_rng_state = 
torch.random.get_rng_state() 18 | self.cuda_rng_state = torch.cuda.get_rng_state_all() 19 | 20 | self.seed_all(seed=self.seed) 21 | 22 | def __exit__(self, *_): 23 | random.setstate(self.rand_rng_state) 24 | np.random.set_state(self.np_rng_state) 25 | torch.random.set_rng_state(self.pt_rng_state) 26 | torch.cuda.set_rng_state_all(self.cuda_rng_state) 27 | 28 | @staticmethod 29 | def seed_all(seed=None): 30 | if isinstance(seed, int) and seed >= 0: 31 | random.seed(seed) 32 | np.random.seed(seed) 33 | torch.manual_seed(seed) 34 | torch.cuda.manual_seed_all(seed) 35 | -------------------------------------------------------------------------------- /llm/models/peft/classifier_head.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | from .. import get_model 8 | from .utils import get_last_checkpoint_path 9 | 10 | 11 | def get_classifier_head( 12 | input_size=None, 13 | classifier_model_name="mlp_binary", 14 | checkpoint_dir=None, 15 | is_trainable=False, 16 | weights_name="classifier_model.bin", 17 | ): 18 | classifier_model = get_model( 19 | classifier_model_name, input_size=input_size, output_size=2 20 | ) 21 | 22 | if checkpoint_dir is not None: 23 | checkpoint_dir = get_last_checkpoint_path(checkpoint_dir) 24 | 25 | if os.path.isfile(f"{checkpoint_dir}/{weights_name}"): 26 | classifier_model.load_state_dict( 27 | torch.load(f"{checkpoint_dir}/{weights_name}") 28 | ) 29 | 30 | logging.info(f"Loaded classifier model checkpoint from '{checkpoint_dir}'.") 31 | else: 32 | for module in classifier_model.modules(): 33 | if isinstance(module, nn.Linear): 34 | nn.init.xavier_normal_(module.weight) 35 | 36 | if is_trainable: 37 | classifier_model = classifier_model.train().requires_grad_(True) 38 | else: 39 | classifier_model = classifier_model.eval().requires_grad_(False) 40 | 41 | return classifier_model 42 | -------------------------------------------------------------------------------- /llm/models/peft/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from peft import PeftModel 4 | from transformers.trainer import ( 5 | PREFIX_CHECKPOINT_DIR, 6 | get_last_checkpoint as __get_last_checkpoint, 7 | ) 8 | 9 | 10 | def get_last_checkpoint_path(path): 11 | if PREFIX_CHECKPOINT_DIR not in path: 12 | path = __get_last_checkpoint(path) 13 | 14 | assert path is not None, f"No checkpoint found in '{path}'." 
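# At this point `path` resolves to the latest checkpoint-* directory (or was passed in as one).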
15 | 16 | return path 17 | 18 | 19 | def get_peft_model_from_checkpoint( 20 | model, 21 | peft_id_or_dir, 22 | is_trainable=True, 23 | adapter_name="default", 24 | **config_args, 25 | ): 26 | if os.path.isdir(peft_id_or_dir): 27 | peft_id_or_dir = get_last_checkpoint_path(peft_id_or_dir) 28 | 29 | if isinstance(model, PeftModel): 30 | model.load_adapter( 31 | peft_id_or_dir, 32 | is_trainable=is_trainable, 33 | adapter_name=adapter_name, 34 | **config_args, 35 | ) 36 | else: 37 | model = PeftModel.from_pretrained( 38 | model, 39 | peft_id_or_dir, 40 | is_trainable=is_trainable, 41 | adapter_name=adapter_name, 42 | **config_args, 43 | ) 44 | 45 | logging.info( 46 | f"Loaded PEFT adapter '{adapter_name}' checkpoint from '{peft_id_or_dir}'" 47 | ) 48 | 49 | return model 50 | -------------------------------------------------------------------------------- /llm/models/peft/lora.py: -------------------------------------------------------------------------------- 1 | from peft import TaskType, LoraConfig, get_peft_model 2 | 3 | from .utils import get_peft_model_from_checkpoint 4 | 5 | 6 | class use_adapter: 7 | def __init__(self, model, adapter_name): 8 | self.model = model 9 | self.adapter_name = adapter_name 10 | self.active_adapter = model.active_adapter 11 | 12 | def __enter__(self): 13 | self.model.set_adapter(self.adapter_name) 14 | 15 | def __exit__(self, *_): 16 | self.model.set_adapter(self.active_adapter) 17 | 18 | 19 | def get_lora_model( 20 | model, 21 | peft_id_or_dir=None, 22 | lora_rank=8, 23 | lora_alpha=32, 24 | lora_dropout=0.1, 25 | is_trainable=False, 26 | adapter_name="default", 27 | **config_args, 28 | ): 29 | if peft_id_or_dir is not None: 30 | return get_peft_model_from_checkpoint( 31 | model, 32 | peft_id_or_dir, 33 | is_trainable=is_trainable, 34 | adapter_name=adapter_name, 35 | **config_args, 36 | ) 37 | 38 | peft_config = LoraConfig( 39 | task_type=TaskType.CAUSAL_LM, 40 | bias="none", 41 | r=lora_rank, 42 | lora_alpha=lora_alpha, 43 | lora_dropout=lora_dropout, 44 | inference_mode=not is_trainable, 45 | **config_args, 46 | ) 47 | model = get_peft_model(model, peft_config, adapter_name=adapter_name) 48 | 49 | return model 50 | -------------------------------------------------------------------------------- /llm/datasets/offline/mmlu_pro_offline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from ..registry import register_dataset, DatasetTag 4 | from ..hf.mmlu_pro import __TASKS 5 | from .offline import get_offline 6 | 7 | 8 | @register_dataset(attrs=dict(tasks=__TASKS, tags=[DatasetTag.EVAL_ONLY])) 9 | def mmlu_pro_offline( 10 | root=None, dataset_str=None, prompt_style=None, eval_kshot=5, **kwargs 11 | ): 12 | try: 13 | _, name, task = dataset_str.split(":") 14 | 15 | assert task in __TASKS 16 | except ValueError: 17 | logging.exception( 18 | f'Dataset string should be formatted as "mmlu_pro_offline::" (Got {dataset_str})', 19 | ) 20 | raise 21 | except AssertionError: 22 | logging.exception( 23 | f'Task not found. 
Dataset string should be formatted as "mmlu_pro_offline::" (Got {dataset_str})', 24 | ) 25 | raise 26 | 27 | root = f"{root}/mmlu_pro_offline/{prompt_style}/{name}/{task}" 28 | 29 | return get_offline(root=root, eval_kshot=eval_kshot, **kwargs) 30 | 31 | 32 | @register_dataset(attrs=dict(unlisted=True, collection=True)) 33 | def mmlu_pro_offline_all(dataset_str=None, **_): 34 | try: 35 | _, name = dataset_str.split(":") 36 | except ValueError: 37 | logging.exception( 38 | f'Dataset string should be formatted as "mmlu_pro_offline_all:" (Got {dataset_str})', 39 | ) 40 | raise 41 | 42 | return [f"mmlu_pro_offline:{name}:{task}" for task in __TASKS] 43 | -------------------------------------------------------------------------------- /llm/models/llama3.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .registry import register_model 4 | from .llama2 import create_tokenizer, create_tokenizer_and_model, create_embed_model 5 | 6 | TOKENIZER_ARGS = dict(model_max_length=8192) 7 | 8 | 9 | __HF_MODEL_MAP = { 10 | "8b": "Meta-Llama-3-8B", 11 | "8b-instruct": "Meta-Llama-3-8B-Instruct", 12 | "70b": "Meta-Llama-3-70B", 13 | "70b-instruct": "Meta-Llama-3-70B-Instruct", 14 | } 15 | 16 | 17 | def __get_model_hf_id(model_str): 18 | try: 19 | _, kind = model_str.split(":") 20 | 21 | assert kind in __HF_MODEL_MAP.keys() 22 | except ValueError: 23 | logging.exception( 24 | f'Model string should be formatted as "llama3:" (Got {model_str})', 25 | ) 26 | raise 27 | except AssertionError: 28 | logging.exception( 29 | f'Model not found. Model string should be formatted as "llama3:" (Got {model_str})', 30 | ) 31 | raise 32 | 33 | return __HF_MODEL_MAP[kind] 34 | 35 | 36 | @register_model(**TOKENIZER_ARGS) 37 | def llama3_tokenizer(*, model_str=None, **kwargs): 38 | return create_tokenizer(__get_model_hf_id(model_str), **kwargs) 39 | 40 | 41 | @register_model(tokenizer_args=TOKENIZER_ARGS) 42 | def llama3(*, model_str=None, **kwargs): 43 | return create_tokenizer_and_model(__get_model_hf_id(model_str), **kwargs) 44 | 45 | 46 | @register_model(tokenizer_args=TOKENIZER_ARGS) 47 | def llama3_embed(*, model_str=None, **kwargs): 48 | return create_embed_model(__get_model_hf_id(model_str), **kwargs) 49 | -------------------------------------------------------------------------------- /experiments/publish.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datasets import DatasetDict 3 | 4 | from llm.datasets import get_dataset 5 | from llm.models import get_model 6 | from llm.models.peft import get_lora_model 7 | 8 | 9 | def main( 10 | dataset=None, 11 | model_name=None, 12 | query_peft_dir=None, 13 | hf_hub_id=None, 14 | hf_token=None, 15 | ): 16 | hf_token = hf_token or os.environ.get("HUGGING_FACE_HUB_TOKEN") 17 | if hf_token is None: 18 | raise ValueError( 19 | f"Missing hf_username (env: HUGGING_FACE_USERNAME) and/or hf_token (env: HUGGING_FACE_HUB_TOKEN)" 20 | ) 21 | 22 | if model_name is not None: 23 | _, model = get_model( 24 | model_name, 25 | device_map="cpu", 26 | ) 27 | 28 | model = get_lora_model( 29 | model, 30 | peft_id_or_dir=query_peft_dir, 31 | is_trainable=False, 32 | adapter_name="default", 33 | ) 34 | 35 | print(f'Pushing model "{hf_hub_id}" to HuggingFace Hub.') 36 | 37 | model.push_to_hub(hf_hub_id, private=True, token=hf_token) 38 | elif dataset is not None: 39 | train_data, val_data, _ = get_dataset( 40 | dataset, 41 | num_workers=8, 42 | use_cache=True, 43 | ) 44 | 45 | 
dataset = DatasetDict({"train": train_data, "validation": val_data}) 46 | 47 | print(f'Pushing dataset "{hf_hub_id}" to HuggingFace Hub.') 48 | 49 | dataset.push_to_hub(hf_hub_id, private=True, token=hf_token) 50 | else: 51 | raise ValueError('Missing "model_name" or "dataset"') 52 | 53 | 54 | if __name__ == "__main__": 55 | import fire 56 | 57 | fire.Fire(main) 58 | -------------------------------------------------------------------------------- /llm/models/registry.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import wraps 3 | 4 | 5 | __func_map = dict() 6 | __attr_map = dict() 7 | 8 | 9 | def register_model(function=None, attrs=None, **d_kwargs): 10 | def _decorator(f): 11 | @wraps(f) 12 | def _wrapper(*args, **kwargs): 13 | all_kwargs = {**d_kwargs, **kwargs} 14 | return f(*args, **all_kwargs) 15 | 16 | assert ( 17 | _wrapper.__name__ not in __func_map 18 | ), f'Duplicate registration for "{_wrapper.__name__}"' 19 | 20 | __func_map[_wrapper.__name__] = _wrapper 21 | __attr_map[_wrapper.__name__] = attrs or dict() 22 | return _wrapper 23 | 24 | if function: 25 | return _decorator(function) 26 | return _decorator 27 | 28 | 29 | model_key = lambda m: m.split(":")[0] 30 | 31 | 32 | def get_model_attrs(name): 33 | key = model_key(name) 34 | if key not in __attr_map: 35 | raise ValueError(f'Model "{key}" not found.') 36 | 37 | return __attr_map[key] 38 | 39 | 40 | def get_model_fn(name): 41 | key = model_key(name) 42 | if key not in __func_map: 43 | raise ValueError(f'Model "{key}" not found.') 44 | 45 | return __func_map[key] 46 | 47 | 48 | def get_model(model_name, **kwargs): 49 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 50 | 51 | model_fn = get_model_fn(model_name) 52 | 53 | model = model_fn(model_str=model_name, **kwargs) 54 | 55 | logging.info(f'Loaded "{model_name}".') 56 | 57 | return model 58 | 59 | 60 | def list_models(): 61 | return [ 62 | model_name 63 | for model_name in __func_map.keys() 64 | if not get_model_attrs(model_name).get("unlisted", False) 65 | ] 66 | -------------------------------------------------------------------------------- /llm/models/peft/temperature_scaling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .utils import get_last_checkpoint_path 7 | 8 | 9 | class TemperatureScale(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | self.log_temperature = nn.Parameter(torch.tensor(0.0)) 14 | 15 | def forward(self, inputs): 16 | return inputs / self.log_temperature.exp() 17 | 18 | 19 | def get_temperature_head( 20 | checkpoint_dir=None, is_trainable=False, weights_name="temperature_head.bin" 21 | ): 22 | 23 | temperature_model = TemperatureScale() 24 | 25 | if checkpoint_dir is not None: 26 | checkpoint_dir = get_last_checkpoint_path(checkpoint_dir) 27 | 28 | if os.path.isfile(f"{checkpoint_dir}/{weights_name}"): 29 | temperature_model.load_state_dict( 30 | torch.load(f"{checkpoint_dir}/{weights_name}") 31 | ) 32 | 33 | logging.info( 34 | f"Loaded temperature model checkpoint from '{checkpoint_dir}'." 
35 | ) 36 | 37 | if is_trainable: 38 | temperature_model = temperature_model.train().requires_grad_(True) 39 | else: 40 | temperature_model = temperature_model.eval().requires_grad_(False) 41 | 42 | return temperature_model 43 | 44 | 45 | def get_temperature_scale_model(model, target_module_name="lm_head", **kwargs): 46 | 47 | for key, mod in model.named_modules(): 48 | if key.endswith(target_module_name): 49 | device = [p.device for _, p in mod.named_parameters()][0] 50 | 51 | temperature_model = get_temperature_head(**kwargs).to(device) 52 | 53 | new_module = nn.Sequential( 54 | mod, 55 | temperature_model, 56 | ) 57 | 58 | parent = model.get_submodule(".".join(key.split(".")[:-1])) 59 | target_name = key.split(".")[-1] 60 | setattr(parent, target_name, new_module) 61 | 62 | logging.info(f"Added temperature scaling module to {key}.") 63 | 64 | return model 65 | -------------------------------------------------------------------------------- /experiments/evaluate_logits.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from tqdm.auto import tqdm 3 | import wandb 4 | import pandas as pd 5 | 6 | from llm.datasets import get_dataset_attrs, get_dataset 7 | from llm.eval import evaluate_dataset 8 | from llm.logging import entrypoint 9 | 10 | 11 | @entrypoint(with_accelerator=True) 12 | def main( 13 | accelerator=None, 14 | seed=137, 15 | log_dir=None, 16 | dataset=None, 17 | data_dir=None, 18 | prompt_style=None, 19 | eval_kshot=None, 20 | use_dataset_cache=True, 21 | model_name=None, 22 | scale_temp=True, 23 | mode=None, 24 | batch_size=1, 25 | ): 26 | config = dict( 27 | seed=seed, 28 | log_dir=log_dir, 29 | dataset=dataset, 30 | prompt_style=prompt_style, 31 | eval_kshot=eval_kshot, 32 | use_dataset_cache=use_dataset_cache, 33 | model_name=model_name, 34 | scale_temp=scale_temp, 35 | mode=mode, 36 | batch_size=batch_size, 37 | ) 38 | if accelerator.is_main_process: 39 | wandb.config.update(config, allow_val_change=True) 40 | 41 | if get_dataset_attrs(dataset).get("collection", False): 42 | all_datasets = get_dataset(dataset) 43 | else: 44 | assert dataset is not None, "Missing dataset." 
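# A single (non-collection) dataset goes through the same evaluation loop, as a one-element list.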
45 | all_datasets = [dataset] 46 | 47 | all_metrics = [] 48 | for dataset in tqdm(all_datasets): 49 | metrics = evaluate_dataset( 50 | accelerator, 51 | None, 52 | None, 53 | dataset, 54 | data_dir=data_dir, 55 | prompt_style=prompt_style, 56 | eval_kshot=eval_kshot, 57 | use_cache=use_dataset_cache, 58 | train_data=False, 59 | seed=seed, 60 | batch_size=batch_size, 61 | log_dir=log_dir, 62 | evaluate_fn=mode, 63 | ) 64 | 65 | all_metrics += metrics 66 | logging.info( 67 | {"metrics": wandb.Table(dataframe=pd.DataFrame(all_metrics))}, 68 | extra=dict(metrics=True), 69 | ) 70 | 71 | accelerator.free_memory() 72 | 73 | 74 | if __name__ == "__main__": 75 | import fire 76 | 77 | fire.Fire(main) 78 | -------------------------------------------------------------------------------- /llm/datasets/offline/offline_logits.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | from ..registry import register_dataset 7 | 8 | 9 | class LogitsDataset(Dataset): 10 | def __init__(self, logits, labels): 11 | self.logits = logits 12 | self.labels = labels 13 | 14 | def __len__(self): 15 | return len(self.labels) 16 | 17 | def __getitem__(self, index): 18 | return self.logits[index], self.labels[index] 19 | 20 | 21 | def get_offline_logits(root=None, **_): 22 | data_splits = dict() 23 | for split in ["train", "validation", "test"]: 24 | data_path = Path(root) / split 25 | 26 | if not data_path.is_dir(): 27 | continue 28 | 29 | data = [ 30 | torch.load(path, map_location="cpu")["fuzzy_gpt-3.5-turbo-1106"] 31 | for path in data_path.glob("*.pt") 32 | ] 33 | if data: 34 | data = {k: torch.cat([v[k] for v in data], dim=0) for k in data[0].keys()} 35 | with open(data_path / "logits.bin", "wb") as f: 36 | torch.save(data, f) 37 | 38 | [path.unlink() for path in data_path.glob("*.pt")] 39 | 40 | try: 41 | data_path = next(data_path.glob("*.bin")) 42 | except StopIteration: 43 | logging.exception(f".bin file not found at {data_path}") 44 | raise 45 | 46 | data = torch.load(data_path, map_location="cpu") 47 | 48 | logits = data.pop("q_logits") 49 | labels = data.pop("q_labels").long() 50 | 51 | data_splits[split] = LogitsDataset(logits, labels) 52 | 53 | train_data = data_splits.pop("train", None) 54 | val_data = data_splits.pop("validation", None) 55 | test_data = data_splits.pop("test", None) 56 | 57 | return train_data, val_data, test_data 58 | 59 | 60 | @register_dataset(attrs=dict(unlisted=True)) 61 | def offline_logits(*args, root=None, dataset_str=None, **kwargs): 62 | try: 63 | _, kind = dataset_str.split(":") 64 | except ValueError: 65 | logging.exception(f"Dataset format should be offline_logits:.") 66 | raise 67 | 68 | root = Path(root) / "offline_logits" / kind 69 | 70 | return get_offline_logits(*args, root=root, **kwargs) 71 | -------------------------------------------------------------------------------- /llm/datasets/hf/anli.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, DatasetDict 2 | 3 | from ..registry import register_dataset 4 | from ..llm_data_utils import PromptFormat 5 | from .snli import format_sample, format_sample_prompt 6 | 7 | 8 | def get_anli( 9 | round=None, 10 | prompt_style=None, 11 | with_query_label=False, 12 | train_kshot=0, 13 | eval_kshot=0, 14 | num_workers=8, 15 | seed=None, 16 | use_cache=True, 17 | **_, 18 | ): 19 | format = PromptFormat(prompt_style) 20 | 21 
| dataset = load_dataset("anli") 22 | if not use_cache: 23 | dataset.cleanup_cache_files() 24 | 25 | dataset = DatasetDict( 26 | {k.split("_")[0]: v for k, v in dataset.items() if k.endswith(f"_r{round}")} 27 | ) 28 | 29 | dataset = dataset.filter( 30 | lambda x: x["label"] in [0, 1, 2], num_proc=num_workers 31 | ).map( 32 | lambda sample, idx: format_sample( 33 | sample, format, with_query_label=with_query_label, seed=seed + idx 34 | ).to_pydict(), 35 | with_indices=True, 36 | num_proc=num_workers, 37 | remove_columns=dataset.column_names["test"], 38 | ) 39 | 40 | prompt_data = dataset.pop("dev") 41 | prompt_kshot = { 42 | "train": train_kshot, 43 | "validation": eval_kshot, 44 | "test": eval_kshot, 45 | } 46 | 47 | data_splits = { 48 | split: ds.map( 49 | lambda _, idx: { 50 | "prompt": format_sample_prompt( 51 | prompt_data, format, kshot=prompt_kshot[split], seed=seed + idx 52 | ) 53 | }, 54 | with_indices=True, 55 | num_proc=num_workers, 56 | ) 57 | for split, ds in dataset.items() 58 | } 59 | 60 | train_data = data_splits.pop("train", None) 61 | val_data = data_splits.pop("validation", None) 62 | test_data = data_splits.pop("test", None) 63 | 64 | return train_data, val_data, test_data 65 | 66 | 67 | @register_dataset 68 | def anli_r1(*args, **kwargs): 69 | return get_anli(*args, **kwargs, round=1) 70 | 71 | 72 | @register_dataset 73 | def anli_r2(*args, **kwargs): 74 | return get_anli(*args, **kwargs, round=2) 75 | 76 | 77 | @register_dataset 78 | def anli_r3(*args, **kwargs): 79 | return get_anli(*args, **kwargs, round=3) 80 | -------------------------------------------------------------------------------- /llm/models/openai.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | from functools import partial 4 | from typing import Any 5 | import numpy as np 6 | import tiktoken 7 | from openai import OpenAI, APIError 8 | import torch 9 | 10 | from .registry import register_model 11 | 12 | 13 | class OpenAITokenizer: 14 | def __init__(self, model): 15 | self.encoding = tiktoken.encoding_for_model(model) 16 | 17 | def __call__(self, texts, return_tensors=None, return_length=True, **_) -> Any: 18 | input_ids = self.encoding.encode_batch(texts) 19 | 20 | out = dict(input_ids=input_ids) 21 | if return_length: 22 | out["length"] = [len(ids) for ids in input_ids] 23 | if return_tensors == "pt": 24 | out["length"] = torch.tensor(out["length"]).long() 25 | return out 26 | 27 | 28 | def get_openai_tokenizer(model, **_): 29 | return OpenAITokenizer(model) 30 | 31 | 32 | @register_model 33 | def oai_gpt35t_tokenizer(*_, **kwargs): 34 | return get_openai_tokenizer("gpt-3.5-turbo", **kwargs) 35 | 36 | 37 | @register_model 38 | def oai_gpt4_tokenizer(*_, **kwargs): 39 | return get_openai_tokenizer("gpt-4", **kwargs) 40 | 41 | 42 | class OpenAIEmbeddingModel: 43 | def __init__(self, model, dimension=1536, retries=10): 44 | client = OpenAI() 45 | self.retries = retries 46 | self.d = dimension 47 | self.model = partial( 48 | client.embeddings.create, 49 | model=model, 50 | encoding_format="float", 51 | dimensions=self.d, 52 | ) 53 | 54 | def get_sentence_embedding_dimension(self): 55 | return self.d 56 | 57 | def encode(self, texts, **_): 58 | response = None 59 | 60 | __retries_left = int(self.retries) 61 | while response is None and __retries_left: 62 | try: 63 | response = self.model(input=texts) 64 | except APIError: 65 | logging.exception("OpenAI API Error.", exc_info=True) 66 | time.sleep(1) 67 | 68 | __retries_left -= 1 69 | 
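# Stack the per-text embedding vectors into one numpy array; note that if every retry failed, `response` is still None and this raises.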
70 | embeddings = np.array([d.embedding for d in response.data]) 71 | return embeddings 72 | 73 | def __call__(self, *args, **kwargs): 74 | return self.encode(*args, **kwargs) 75 | 76 | 77 | def get_openai_embedding_model(model, dimension=1536, retries=10, **_): 78 | return OpenAIEmbeddingModel(model, dimension=dimension, retries=retries) 79 | 80 | 81 | @register_model 82 | def oai_small(*_, **kwargs): 83 | return get_openai_embedding_model("text-embedding-3-small", **kwargs) 84 | -------------------------------------------------------------------------------- /llm/datasets/offline/modiste.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import json 3 | from pathlib import Path 4 | from datasets import Dataset, Value 5 | 6 | from ..registry import register_dataset, DatasetTag 7 | from ..llm_data_utils import LMText 8 | 9 | 10 | __TASKS = [ 11 | "elementary_mathematics", 12 | "high_school_biology", 13 | "us_foreign_policy", 14 | "high_school_computer_science", 15 | ] 16 | 17 | 18 | def format_sample(sample): 19 | context = sample["prompt"] 20 | target = sample["label"] 21 | output = sample["llm_answer"] 22 | 23 | return LMText( 24 | context=context, 25 | target=target, 26 | output=output, 27 | ) 28 | 29 | 30 | def get_modiste(root=None, task=None, num_workers=8, use_cache=True, **_): 31 | with open(Path(root) / "mmlu_responses_w_conf.json") as f: 32 | data = json.load(f) 33 | 34 | dataset = Dataset.from_list(data[task]) 35 | if not use_cache: 36 | dataset.cleanup_cache_files() 37 | 38 | dataset = dataset.map( 39 | lambda sample: { 40 | **format_sample(sample).to_pydict(), 41 | ## Keep IDs for mapping later. 42 | "example_idx": sample["example_idx"], 43 | "orig_example_idx": sample["orig_example_idx"], 44 | }, 45 | num_proc=num_workers, 46 | remove_columns=list( 47 | set(dataset.column_names) - set(["example_idx", "orig_example_idx"]) 48 | ), 49 | ) 50 | 51 | types = dataset.features.copy() 52 | types["example_idx"] = Value("int64") 53 | types["orig_example_idx"] = Value("int64") 54 | dataset = dataset.cast(types, num_proc=num_workers) 55 | 56 | return None, None, dataset 57 | 58 | 59 | @register_dataset(attrs=dict(tasks=__TASKS, tags=[DatasetTag.EVAL_ONLY])) 60 | def modiste_mmlu(*args, root=None, dataset_str=None, **kwargs): 61 | root = Path(root) / "modiste" 62 | 63 | try: 64 | _, task = dataset_str.split(":") 65 | 66 | assert task in __TASKS 67 | except ValueError: 68 | logging.exception( 69 | f'Dataset string should be formatted as "modiste_mmlu:" (Got {dataset_str})', 70 | ) 71 | raise 72 | except AssertionError: 73 | logging.exception( 74 | f'Task not found. 
Dataset string should be formatted as "modiste_mmlu:" (Got {dataset_str})', 75 | ) 76 | raise 77 | 78 | return get_modiste(*args, root=root, task=task, **kwargs) 79 | 80 | 81 | @register_dataset(attrs=dict(unlisted=True, collection=True)) 82 | def modiste_mmlu_all(*args, **kwargs): 83 | return [f"modiste_mmlu:{task}" for task in __TASKS] 84 | -------------------------------------------------------------------------------- /llm/distributed/accelerate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import timedelta 3 | import torch 4 | import torch.distributed as torchdist 5 | from torch.distributed.fsdp import ( 6 | ShardingStrategy, 7 | BackwardPrefetch, 8 | StateDictType, 9 | CPUOffload, 10 | ) 11 | from accelerate import ( 12 | Accelerator as __HFAccelerator, 13 | PartialState as __HFAcceleratorState, 14 | DeepSpeedPlugin, 15 | FullyShardedDataParallelPlugin, 16 | InitProcessGroupKwargs, 17 | ) 18 | from accelerate.utils import PrecisionType 19 | 20 | 21 | class Accelerator(__HFAccelerator): 22 | def __init__(self, *args, **kwargs): 23 | deepspeed_plugin = None 24 | if os.getenv("ACCELERATE_USE_DEEPSPEED", "false") == "true": 25 | deepspeed_plugin = DeepSpeedPlugin(zero3_init_flag=True) 26 | 27 | fsdp_plugin = None 28 | if os.getenv("ACCELERATE_USE_FSDP", "false") == "true": 29 | os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true" 30 | 31 | fsdp_plugin = FullyShardedDataParallelPlugin( 32 | sharding_strategy=ShardingStrategy.FULL_SHARD, 33 | backward_prefetch=BackwardPrefetch.BACKWARD_PRE, 34 | sync_module_states=True, 35 | use_orig_params=True, 36 | forward_prefetch=False, 37 | auto_wrap_policy="TRANSFORMER_BASED_WRAP", 38 | cpu_offload=CPUOffload(offload_params=False), 39 | state_dict_type=StateDictType.SHARDED_STATE_DICT, 40 | ) 41 | 42 | super().__init__( 43 | *args, 44 | **kwargs, 45 | mixed_precision=( 46 | PrecisionType.BF16 47 | if torch.cuda.is_bf16_supported() 48 | else PrecisionType.FP16 49 | ), 50 | fsdp_plugin=fsdp_plugin, 51 | deepspeed_plugin=deepspeed_plugin, 52 | kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=600))] 53 | ) 54 | 55 | def sync_object(self, obj): 56 | if self.num_processes == 1: 57 | return obj 58 | 59 | __sync_obj = [None for _ in range(self.num_processes)] 60 | torchdist.all_gather_object(__sync_obj, obj) 61 | obj = __sync_obj[0] 62 | return obj 63 | 64 | 65 | class AcceleratorState(__HFAcceleratorState): 66 | def sync_object(self, obj): 67 | if self.num_processes == 1: 68 | return obj 69 | 70 | __sync_obj = [None for _ in range(self.num_processes)] 71 | torchdist.all_gather_object(__sync_obj, obj) 72 | obj = __sync_obj[0] 73 | return obj 74 | -------------------------------------------------------------------------------- /experiments/temperature_scale.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | import torch 4 | from tqdm.auto import tqdm 5 | 6 | from llm.logging import entrypoint 7 | from llm.datasets import get_dataset, get_loader 8 | from llm.models.peft.temperature_scaling import TemperatureScale 9 | 10 | 11 | @entrypoint(with_accelerator=True) 12 | def main( 13 | accelerator=None, 14 | seed=137, 15 | log_dir=None, 16 | dataset=None, 17 | data_dir=None, 18 | num_workers=4, 19 | batch_size=64, 20 | lr=1e-3, 21 | weight_decay=1e-2, 22 | max_steps=2000, 23 | ): 24 | _, val_data, test_data = get_dataset(dataset, seed=seed, root=data_dir) 25 | if val_data is None: 26 | 
val_data = test_data 27 | 28 | model = TemperatureScale() 29 | 30 | optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) 31 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps) 32 | 33 | loader = get_loader( 34 | val_data, 35 | batch_size=batch_size, 36 | num_workers=num_workers, 37 | pin_memory=True, 38 | accelerator=accelerator, 39 | shuffle=True, 40 | ) 41 | 42 | model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler) 43 | 44 | criterion = torch.nn.CrossEntropyLoss() 45 | 46 | logging_steps = max(1, max_steps // 200) 47 | save_steps = max_steps // 10 48 | 49 | iter_loader = iter(loader) 50 | for step in tqdm(range(max_steps)): 51 | optimizer.zero_grad() 52 | 53 | try: 54 | batch = next(iter_loader) 55 | except StopIteration: 56 | iter_loader = iter(loader) 57 | batch = next(iter_loader) 58 | 59 | logits, labels = batch 60 | logits = model(logits) 61 | 62 | loss = criterion(logits, labels) 63 | 64 | accelerator.backward(loss) 65 | 66 | optimizer.step() 67 | scheduler.step() 68 | 69 | log_metrics = { 70 | "loss": loss.detach().item(), 71 | "log_temperature": model.log_temperature.data.item(), 72 | "temperature": model.log_temperature.data.exp().item(), 73 | } 74 | 75 | if accelerator.is_main_process and (step + 1) % logging_steps == 0: 76 | logging.info(log_metrics) 77 | logging.info(log_metrics, extra=dict(metrics=True)) 78 | 79 | if accelerator.is_main_process and (step + 1) % save_steps == 0: 80 | checkpoint_path = ( 81 | Path(log_dir) / f"checkpoint-{step + 1}" / "temperature_head.bin" 82 | ) 83 | checkpoint_path.parent.mkdir() 84 | 85 | torch.save(accelerator.unwrap_model(model).state_dict(), checkpoint_path) 86 | 87 | 88 | if __name__ == "__main__": 89 | import fire 90 | 91 | fire.Fire(main) 92 | -------------------------------------------------------------------------------- /notebooks/viz_features.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import seaborn as sns\n", 12 | "import pandas as pd\n", 13 | "from sklearn.preprocessing import StandardScaler\n", 14 | "import torch" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "features = torch.load(\"\", map_location=\"cpu\")\n", 24 | "\n", 25 | "query_features = features[\"features\"].float().numpy()\n", 26 | "query_features = StandardScaler().fit_transform(query_features)\n", 27 | "query_labels = features[\"labels\"].long().numpy()\n", 28 | "\n", 29 | "del features\n", 30 | "\n", 31 | "query_features.shape, query_labels.shape" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "np.mean(query_labels)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "import umap\n", 50 | "\n", 51 | "projector = umap.UMAP()\n", 52 | "projected_features = projector.fit_transform(query_features)\n", 53 | "viz_df = pd.DataFrame({\n", 54 | " \"x\": projected_features[:, 0],\n", 55 | " \"y\": projected_features[:, 1],\n", 56 | " \"labels\": query_labels,\n", 57 | "})\n", 58 | "\n", 59 | "viz_df" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | 
"execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "fig, ax = plt.subplots(figsize=(5,5))\n", 69 | "\n", 70 | "sns.scatterplot(ax=ax, data=viz_df, \n", 71 | " x=\"x\", y=\"y\", hue=\"labels\",\n", 72 | " palette=sns.color_palette(\"Set2\", 2))\n", 73 | "\n", 74 | "fig.tight_layout()\n", 75 | "fig.show()" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [] 84 | } 85 | ], 86 | "metadata": { 87 | "kernelspec": { 88 | "display_name": "Python 3", 89 | "language": "python", 90 | "name": "python3" 91 | }, 92 | "language_info": { 93 | "codemirror_mode": { 94 | "name": "ipython", 95 | "version": 3 96 | }, 97 | "file_extension": ".py", 98 | "mimetype": "text/x-python", 99 | "name": "python", 100 | "nbconvert_exporter": "python", 101 | "pygments_lexer": "ipython3", 102 | "version": "3.10.13" 103 | } 104 | }, 105 | "nbformat": 4, 106 | "nbformat_minor": 2 107 | } 108 | -------------------------------------------------------------------------------- /notebooks/cleanup_offline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import glob\n", 10 | "import pandas as pd\n", 11 | "import os\n", 12 | "\n", 13 | "def cleanup(src, suffix, target_suffix):\n", 14 | " df = pd.concat([\n", 15 | " pd.read_csv(p) \n", 16 | " for p in glob.glob(f\"{src}/{suffix}/*.csv\")], \n", 17 | " ignore_index=True)\n", 18 | "\n", 19 | " df[\"output\"] = df[\"output\"].apply(lambda x: str(x).split(\"\\n\")[0])\n", 20 | " df[\"output\"] = df[\"output\"].apply(lambda x: str(x).strip(\"\\n\").strip())\n", 21 | " df = df[df[\"output\"] != \"\"]\n", 22 | " df = df[df[\"output\"] != \"nan\"]\n", 23 | "\n", 24 | " filter_out = [\n", 25 | " \"\",\n", 26 | " \"\",\n", 27 | " \"\",\n", 28 | " \"\",\n", 29 | " \"\",\n", 30 | " \"\",\n", 31 | " \"\",\n", 32 | " \"\",\n", 33 | " \"_______________\",\n", 34 | " \"Note: I'll give you a hint\",\n", 35 | " \"Note: I'll provide the next question after you answer this one\"\n", 36 | " \"Please provide your answer\",\n", 37 | " \"Which of the following is true?\",\n", 38 | " ]\n", 39 | "\n", 40 | " #filter out any outputs that contain any of the above strings\n", 41 | " for fo in filter_out:\n", 42 | " df = df[~df[\"output\"].str.contains(fo)]\n", 43 | "\n", 44 | " os.makedirs(f\"{src}/{target_suffix}\")\n", 45 | "\n", 46 | " df.to_csv(f\"{src}/{target_suffix}/rows_0.csv\", index=False)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "import os\n", 56 | "\n", 57 | "csv_folder = f\"{os.environ.get('DATADIR')}/llm-calibration/generated/____/processed\"\n", 58 | "\n", 59 | "# cleanup(csv_folder, \"raw/train\", \"processed/train\")\n", 60 | "# cleanup(csv_folder, \"raw/validation\", \"processed/validation\")" 61 | ] 62 | } 63 | ], 64 | "metadata": { 65 | "kernelspec": { 66 | "display_name": "Python 3", 67 | "language": "python", 68 | "name": "python3" 69 | }, 70 | "language_info": { 71 | "codemirror_mode": { 72 | "name": "ipython", 73 | "version": 3 74 | }, 75 | "file_extension": ".py", 76 | "mimetype": "text/x-python", 77 | "name": "python", 78 | "nbconvert_exporter": "python", 79 | "pygments_lexer": "ipython3", 80 | "version": "3.10.13" 81 | } 82 | }, 83 | "nbformat": 4, 84 | "nbformat_minor": 2 85 | } 86 | 
-------------------------------------------------------------------------------- /llm/datasets/registry.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import os 3 | import logging 4 | from functools import wraps 5 | from pathlib import Path 6 | 7 | 8 | __func_map = dict() 9 | __attr_map = dict() 10 | 11 | 12 | class DatasetTag(str, Enum): 13 | TRAIN_ONLY = "train_only" 14 | EVAL_ONLY = "eval_only" 15 | 16 | 17 | def register_dataset(function=None, attrs=None, **d_kwargs): 18 | def _decorator(f): 19 | @wraps(f) 20 | def _wrapper(*args, **kwargs): 21 | all_kwargs = {**d_kwargs, **kwargs} 22 | return f(*args, **all_kwargs) 23 | 24 | assert ( 25 | _wrapper.__name__ not in __func_map 26 | ), f'Duplicate registration for "{_wrapper.__name__}"' 27 | 28 | __func_map[_wrapper.__name__] = _wrapper 29 | __attr_map[_wrapper.__name__] = attrs or dict() 30 | return _wrapper 31 | 32 | if function: 33 | return _decorator(function) 34 | return _decorator 35 | 36 | 37 | dataset_key = lambda d: d.split(":")[0] 38 | 39 | 40 | def get_dataset_attrs(name): 41 | key = dataset_key(name) 42 | if key not in __attr_map: 43 | raise ValueError(f'Dataset "{key}" not found.') 44 | 45 | return __attr_map[key] 46 | 47 | 48 | def get_dataset_fn(name): 49 | key = dataset_key(name) 50 | if key not in __func_map: 51 | raise ValueError(f'Dataset "{key}" not found.') 52 | 53 | return __func_map[key] 54 | 55 | 56 | def get_data_dir(data_dir=None): 57 | if data_dir is None: 58 | data_dir = ( 59 | Path(os.environ.get("PROJECT_HOME", Path.home())) 60 | / Path.cwd().name 61 | / "datasets" 62 | ) 63 | else: 64 | data_dir = Path(data_dir) 65 | 66 | data_dir.mkdir(parents=True, exist_ok=True) 67 | 68 | return str(data_dir.resolve()) 69 | 70 | 71 | def get_dataset(dataset, root=None, seed=42, **kwargs): 72 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 73 | 74 | dataset_fn = get_dataset_fn(dataset) 75 | 76 | root = get_data_dir(data_dir=root) 77 | 78 | if get_dataset_attrs(dataset).get("collection", False): 79 | return dataset_fn(root=root, dataset_str=dataset) 80 | 81 | train_data, val_data, test_data = dataset_fn( 82 | root=root, 83 | seed=seed, 84 | dataset_str=dataset, 85 | **kwargs, 86 | ) 87 | 88 | info_str = " / ".join( 89 | [ 90 | f"{s} (N = {len(ds)})" 91 | for ds, s in zip( 92 | (train_data, val_data, test_data), ("train", "validation", "test") 93 | ) 94 | if ds is not None 95 | ] 96 | ) 97 | logging.info(f'Loaded "{dataset}"; {info_str}') 98 | 99 | return train_data, val_data, test_data 100 | 101 | 102 | def list_datasets(): 103 | return [ 104 | dname 105 | for dname in __func_map.keys() 106 | if not get_dataset_attrs(dname).get("unlisted", False) 107 | ] 108 | -------------------------------------------------------------------------------- /llm/datasets/hf/gsm8k.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from datasets import load_dataset 3 | 4 | from ..registry import register_dataset, DatasetTag 5 | from ..llm_data_utils import LMText, PromptFormat 6 | 7 | 8 | def format_sample(sample, format): 9 | target_prompt = "\nAnswer:" 10 | 11 | question = sample["question"] 12 | 13 | if format == PromptFormat.OE: 14 | context = "\n".join([f"Question: {question}"]) 15 | 16 | target = sample["answer"] 17 | else: 18 | raise NotImplementedError(f"Unsupported prompt format {format}.") 19 | 20 | return LMText(context=context, target_prompt=target_prompt, target=target) 21 | 22 | 23 | def 
format_sample_prompt(prompt_dataset, format, kshot=8, seed=None): 24 | if not kshot: 25 | return "" 26 | 27 | samples_idx = ( 28 | np.random.default_rng(seed=seed) 29 | .permutation(len(prompt_dataset))[:kshot] 30 | .tolist() 31 | ) 32 | 33 | fewshot_samples_prompt = [ 34 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 35 | ] 36 | 37 | if format == PromptFormat.OE: 38 | prompt = [ 39 | "The following are math questions (with answers).\n", 40 | *fewshot_samples_prompt, 41 | "Now, answer the next question.\n\n", 42 | ] 43 | else: 44 | raise NotImplementedError(f"Unsupported prompt format {format}.") 45 | 46 | return "\n".join(prompt) 47 | 48 | 49 | def get_gsm8k( 50 | prompt_style=None, 51 | train_kshot=0, 52 | eval_kshot=8, 53 | tokenizer=None, 54 | num_workers=8, 55 | seed=None, 56 | use_cache=True, 57 | **_, 58 | ): 59 | format = PromptFormat(prompt_style) 60 | 61 | dataset = load_dataset("gsm8k", "main") 62 | if not use_cache: 63 | dataset.cleanup_cache_files() 64 | 65 | dataset = dataset.map( 66 | lambda sample: format_sample(sample, format).to_pydict(), 67 | num_proc=num_workers, 68 | remove_columns=dataset.column_names["test"], 69 | ) 70 | 71 | prompt_data = dataset.get("train") 72 | prompt_kshot = { 73 | "train": train_kshot, 74 | "test": eval_kshot, 75 | } 76 | 77 | data_splits = { 78 | split: ds.map( 79 | lambda _, idx: { 80 | "prompt": format_sample_prompt( 81 | prompt_data, 82 | format, 83 | kshot=prompt_kshot[split], 84 | seed=seed + idx, 85 | ) 86 | }, 87 | with_indices=True, 88 | num_proc=num_workers, 89 | ) 90 | for split, ds in dataset.items() 91 | } 92 | 93 | train_data = data_splits.pop("train", None) 94 | val_data = data_splits.pop("validation", None) 95 | test_data = data_splits.pop("test", None) 96 | 97 | return train_data, val_data, test_data 98 | 99 | 100 | @register_dataset(attrs=dict(tags=[DatasetTag.EVAL_ONLY])) 101 | def gsm8k(*args, **kwargs): 102 | return get_gsm8k(*args, **kwargs) 103 | -------------------------------------------------------------------------------- /llm/models/qwen.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from peft import prepare_model_for_kbit_training 3 | import torch 4 | from transformers import AutoTokenizer, AutoModelForCausalLM 5 | 6 | 7 | from .registry import register_model 8 | 9 | 10 | __QWEN_HF_MODEL_MAP = { 11 | "7b-instruct": "Qwen/Qwen2-7B-Instruct", 12 | } 13 | 14 | 15 | def __get_model_hf_id(model_str, model_map): 16 | try: 17 | model_name, kind = model_str.split(":") 18 | 19 | assert kind in model_map.keys() 20 | except ValueError: 21 | logging.exception( 22 | f'Model string should be formatted as "{model_name}:" (Got {model_str})', 23 | ) 24 | raise 25 | except AssertionError: 26 | logging.exception( 27 | f'Model not found. 
Model string should be formatted as "{model_name}:" (Got {model_str})', 28 | ) 29 | raise 30 | 31 | return model_map[kind] 32 | 33 | 34 | def create_tokenizer( 35 | kind, 36 | model_dir=None, 37 | padding_side="left", 38 | model_max_length=131_072, 39 | **kwargs, 40 | ): 41 | tokenizer = AutoTokenizer.from_pretrained( 42 | model_dir or kind, 43 | padding_side=padding_side, 44 | model_max_length=model_max_length, 45 | use_fast=True, 46 | legacy=False, 47 | **kwargs, 48 | ) 49 | 50 | return tokenizer 51 | 52 | 53 | def create_model( 54 | kind, 55 | torch_dtype=None, 56 | model_dir=None, 57 | use_cache=False, 58 | tokenizer=None, 59 | use_int8=False, 60 | use_int4=False, 61 | **kwargs, 62 | ): 63 | quantization_config = None 64 | if use_int4 or use_int8: 65 | from transformers import BitsAndBytesConfig 66 | 67 | quantization_config = BitsAndBytesConfig( 68 | load_in_4bit=use_int4, 69 | load_in_8bit=use_int8, 70 | ) 71 | 72 | model = AutoModelForCausalLM.from_pretrained( 73 | model_dir or kind, 74 | torch_dtype=torch_dtype 75 | or (torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16), 76 | quantization_config=quantization_config, 77 | use_cache=use_cache, 78 | **kwargs, 79 | ) 80 | 81 | model.config.pad_token_id = tokenizer.pad_token_id 82 | 83 | if use_int4 or use_int8: 84 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False) 85 | 86 | return model 87 | 88 | 89 | def create_tokenizer_and_model(kind, tokenizer_args=None, **kwargs): 90 | tokenizer = create_tokenizer(kind, **(tokenizer_args or dict())) 91 | model = create_model(kind, tokenizer=tokenizer, **kwargs) 92 | return tokenizer, model 93 | 94 | 95 | @register_model 96 | def qwen2_tokenizer(*, model_str=None, **kwargs): 97 | return create_tokenizer(__get_model_hf_id(model_str, __QWEN_HF_MODEL_MAP), **kwargs) 98 | 99 | 100 | @register_model 101 | def qwen2(*, model_str=None, **kwargs): 102 | return create_tokenizer_and_model( 103 | __get_model_hf_id(model_str, __QWEN_HF_MODEL_MAP), **kwargs 104 | ) 105 | -------------------------------------------------------------------------------- /notebooks/user-study.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import json\n", 11 | "import seaborn as sns\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "\n", 14 | "sns.set_style(\"whitegrid\")\n", 15 | "sns.set_context(\"notebook\", font_scale=1.25)" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "with open(\"mmlu_user_data_NEW.json\", \"r\") as f:\n", 25 | " saved_data = json.load(f)\n", 26 | "\n", 27 | "df = []\n", 28 | "for user_id, user_data in saved_data.items():\n", 29 | " user_df = pd.read_json(user_data[\"user_df\"])\n", 30 | " # user_df[\"comments\"] = user_data[\"comments\"]\n", 31 | " df.append(user_df)\n", 32 | "df = pd.concat(df, ignore_index=True)\n", 33 | "df[\"didRely\"] = df.apply(\n", 34 | " lambda row: \"Agree\" if row[\"llm_answer\"] == row[\"response\"] else \"Disagree\", axis=1\n", 35 | ")" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "for mode in [\"query\", \"ct\", \"rand\"]:\n", 45 | " plt_df = df[df.variant == f\"justMConf_mistral_{mode}\"]\n", 46 | "\n", 47 | " fig, ax = 
plt.subplots(figsize=(5, 3))\n", 48 | "\n", 49 | " g = sns.histplot(\n", 50 | " plt_df,\n", 51 | " x=\"model_confidence\",\n", 52 | " hue=\"didRely\",\n", 53 | " stat=\"probability\",\n", 54 | " bins=20,\n", 55 | " kde=True,\n", 56 | " ax=ax,\n", 57 | " palette=[sns.color_palette(\"Paired\")[7], sns.color_palette(\"Paired\")[3]],\n", 58 | " hue_order=[\"Disagree\", \"Agree\"],\n", 59 | " )\n", 60 | "\n", 61 | " if mode == \"query\":\n", 62 | " ax.get_legend().set(title=\"\", loc=\"upper left\")\n", 63 | " else:\n", 64 | " ax.get_legend().remove()\n", 65 | "\n", 66 | " g.set(xlabel=r\"Model Confidence ($\\%$)\", ylabel=r\"Proportion ($\\%$)\")\n", 67 | " g.figure.show()\n", 68 | " # g.figure.savefig(f\"user_conf_{mode}.pdf\", bbox_inches=\"tight\")" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [] 77 | } 78 | ], 79 | "metadata": { 80 | "kernelspec": { 81 | "display_name": "Python 3", 82 | "language": "python", 83 | "name": "python3" 84 | }, 85 | "language_info": { 86 | "codemirror_mode": { 87 | "name": "ipython", 88 | "version": 3 89 | }, 90 | "file_extension": ".py", 91 | "mimetype": "text/x-python", 92 | "name": "python", 93 | "nbconvert_exporter": "python", 94 | "pygments_lexer": "ipython3", 95 | "version": "3.11.8" 96 | } 97 | }, 98 | "nbformat": 4, 99 | "nbformat_minor": 2 100 | } 101 | -------------------------------------------------------------------------------- /notebooks/mmlu_pro.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "df = pd.read_csv(\"results/raw.csv\")\n", 19 | "df.scale_temp = df.scale_temp.apply(lambda x: \"base\" if pd.isna(x) else x)\n", 20 | "df = df.sort_values(\"dataset\")\n", 21 | "df.scale_temp.unique()" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# def agg(df):\n", 31 | "# df = df.reset_index()\n", 32 | "\n", 33 | "# total_N = df.N.sum()\n", 34 | "\n", 35 | "# return pd.DataFrame(\n", 36 | "# [\n", 37 | "# {\n", 38 | "# \"N\": total_N,\n", 39 | "# \"unc_ece\": (df.N * df.unc_ece).sum() / total_N,\n", 40 | "# \"unc_auroc\": (df.N * df.unc_auroc).sum() / total_N,\n", 41 | "# }\n", 42 | "# ]\n", 43 | "# )\n", 44 | "\n", 45 | "\n", 46 | "# grouped_data = df.groupby([\"model_name\", \"scale_temp\", \"split\"])\n", 47 | "# grouped_data[[\"N\", \"unc_auroc\", \"unc_ece\"]].apply(\n", 48 | "# agg\n", 49 | "# ) # .reset_index().drop(columns=[\"level_3\"])" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "df.model_name.unique()" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "import numpy as np\n", 68 | "np.mean(df[df.scale_temp == \"base\"].unc_auroc.values <= df[df.scale_temp == \"query\"].unc_auroc.values) #[[\"dataset\", \"unc_ece\"]]" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "import numpy as np\n", 78 | "np.mean(df[df.scale_temp == \"base\"].unc_ece.values >= 
df[df.scale_temp == \"query\"].unc_ece.values)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "df[df.scale_temp == \"query\"][[\"dataset\", \"unc_ece\"]]" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [] 96 | } 97 | ], 98 | "metadata": { 99 | "kernelspec": { 100 | "display_name": "Python 3", 101 | "language": "python", 102 | "name": "python3" 103 | }, 104 | "language_info": { 105 | "codemirror_mode": { 106 | "name": "ipython", 107 | "version": 3 108 | }, 109 | "file_extension": ".py", 110 | "mimetype": "text/x-python", 111 | "name": "python", 112 | "nbconvert_exporter": "python", 113 | "pygments_lexer": "ipython3", 114 | "version": "3.11.9" 115 | } 116 | }, 117 | "nbformat": 4, 118 | "nbformat_minor": 2 119 | } 120 | -------------------------------------------------------------------------------- /llm/datasets/offline/mmlu_offline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from ..registry import register_dataset, DatasetTag 4 | from ..hf.mmlu import __TASK_CATEGORIES 5 | from .offline import get_offline 6 | from .offline_logits import get_offline_logits 7 | 8 | 9 | @register_dataset( 10 | attrs=dict(task_categories=__TASK_CATEGORIES, tags=[DatasetTag.EVAL_ONLY]) 11 | ) 12 | def mmlu_offline( 13 | root=None, dataset_str=None, prompt_style=None, eval_kshot=5, **kwargs 14 | ): 15 | try: 16 | _, name, task = dataset_str.split(":") 17 | 18 | assert task in __TASK_CATEGORIES.keys() 19 | except ValueError: 20 | logging.exception( 21 | f'Dataset string should be formatted as "mmlu_offline::" (Got {dataset_str})', 22 | ) 23 | raise 24 | except AssertionError: 25 | logging.exception( 26 | f'Task not found. 
Dataset string should be formatted as "mmlu_offline::" (Got {dataset_str})', 27 | ) 28 | raise 29 | 30 | root = f"{root}/mmlu_offline/{prompt_style}/{name}/{task}" 31 | 32 | return get_offline(root=root, eval_kshot=eval_kshot, **kwargs) 33 | 34 | 35 | @register_dataset(attrs=dict(unlisted=True, collection=True)) 36 | def mmlu_offline_all(dataset_str=None, **_): 37 | try: 38 | _, name = dataset_str.split(":") 39 | except ValueError: 40 | logging.exception( 41 | f'Dataset string should be formatted as "mmlu_offline_all:" (Got {dataset_str})', 42 | ) 43 | raise 44 | 45 | return [f"mmlu_offline:{name}:{task}" for task in __TASK_CATEGORIES.keys()] 46 | 47 | 48 | @register_dataset( 49 | attrs=dict(task_categories=__TASK_CATEGORIES, tags=[DatasetTag.EVAL_ONLY]) 50 | ) 51 | def mmlu_offline_query_logits(root=None, dataset_str=None, **kwargs): 52 | name, kind, dataset = dataset_str.split(":") 53 | 54 | root = f"{root}/{name}/{kind}/{dataset}" 55 | 56 | return get_offline_logits(root=root, **kwargs) 57 | 58 | 59 | @register_dataset(attrs=dict(unlisted=True, collection=True)) 60 | def mmlu_offline_query_logits_all(dataset_str=None, **_): 61 | try: 62 | _, name = dataset_str.split(":") 63 | except ValueError: 64 | logging.exception( 65 | f'Dataset string should be formatted as "mmlu_offline_query_logits_all:" (Got {dataset_str})', 66 | ) 67 | raise 68 | 69 | return [ 70 | f"mmlu_offline_query_logits:{name}:{task}" for task in __TASK_CATEGORIES.keys() 71 | ] 72 | 73 | 74 | @register_dataset( 75 | attrs=dict(task_categories=__TASK_CATEGORIES, tags=[DatasetTag.EVAL_ONLY]) 76 | ) 77 | def mmlu_offline_ve_logits(root=None, dataset_str=None, **kwargs): 78 | name, kind, dataset = dataset_str.split(":") 79 | 80 | root = f"{root}/{name}/{kind}/{dataset}" 81 | 82 | return get_offline_logits(root=root, **kwargs) 83 | 84 | 85 | @register_dataset(attrs=dict(unlisted=True, collection=True)) 86 | def mmlu_offline_ve_logits_all(dataset_str=None, **_): 87 | try: 88 | _, name = dataset_str.split(":") 89 | except ValueError: 90 | logging.exception( 91 | f'Dataset string should be formatted as "mmlu_offline_ve_logits_all:" (Got {dataset_str})', 92 | ) 93 | raise 94 | 95 | return [ 96 | f"mmlu_offline_ve_logits:{name}:{task}" for task in __TASK_CATEGORIES.keys() 97 | ] 98 | -------------------------------------------------------------------------------- /experiments/fine_tune.py: -------------------------------------------------------------------------------- 1 | from llm.datasets import get_dataset 2 | from llm.distributed import AcceleratorState 3 | from llm.logging import entrypoint 4 | from llm.models import get_model 5 | from llm.models.peft import get_lora_model, get_temperature_scale_model 6 | from llm.trainer import WandbConfigUpdateCallback, FineTuner 7 | 8 | 9 | @entrypoint 10 | def main( 11 | seed=137, 12 | log_dir=None, 13 | dataset=None, 14 | data_dir=None, 15 | prompt_style=None, 16 | max_token_length=None, 17 | num_workers=4, 18 | use_dataset_cache=True, 19 | model_name=None, 20 | int8=True, 21 | lora_rank=8, 22 | lora_alpha=32, 23 | lora_dropout=0.1, 24 | peft_dir=None, 25 | scale_temp=False, 26 | batch_size=1, 27 | lr=1e-4, 28 | warmup_ratio=0.0, 29 | max_steps=1, 30 | ): 31 | accelerator = AcceleratorState() 32 | 33 | trainer_args = FineTuner.Args( 34 | seed=seed, 35 | output_dir=log_dir, 36 | max_steps=max_steps, 37 | eval_steps=max_steps // 10, 38 | save_steps=max_steps // 10, 39 | logging_steps=max(1, max_steps // 200), 40 | dataloader_num_workers=num_workers, 41 | 
per_device_train_batch_size=batch_size, 42 | per_device_eval_batch_size=batch_size, 43 | learning_rate=lr, 44 | warmup_ratio=warmup_ratio, 45 | scale_temp=scale_temp, 46 | ) 47 | 48 | with accelerator.main_process_first(): 49 | train_data, val_data, test_data = get_dataset( 50 | dataset, 51 | root=data_dir, 52 | seed=seed, 53 | prompt_style=prompt_style, 54 | max_token_length=max_token_length, 55 | num_workers=num_workers, 56 | use_cache=use_dataset_cache, 57 | ) 58 | if scale_temp: 59 | train_data, val_data = val_data, test_data or val_data 60 | 61 | tokenizer, model = get_model( 62 | model_name, 63 | device_map={"": accelerator.local_process_index}, 64 | use_int8=int8, 65 | ) 66 | 67 | model = get_lora_model( 68 | model, 69 | peft_id_or_dir=peft_dir, 70 | lora_rank=lora_rank, 71 | lora_alpha=lora_alpha, 72 | lora_dropout=lora_dropout, 73 | is_trainable=not scale_temp, 74 | adapter_name="default", 75 | ) 76 | 77 | if scale_temp: 78 | model = get_temperature_scale_model( 79 | model, 80 | checkpoint_dir=peft_dir, 81 | is_trainable=True, 82 | weights_name=FineTuner.TEMPERATURE_WEIGHTS_NAME, 83 | ) 84 | 85 | trainer = FineTuner( 86 | model=model, 87 | args=trainer_args, 88 | train_dataset=train_data, 89 | eval_dataset=val_data, 90 | tokenizer=tokenizer, 91 | callbacks=[ 92 | WandbConfigUpdateCallback( 93 | dataset=dataset, 94 | prompt_style=prompt_style, 95 | max_token_length=max_token_length, 96 | model_name=model_name, 97 | lora_rank=lora_rank, 98 | lora_alpha=lora_alpha, 99 | lora_dropout=lora_dropout, 100 | peft_dir=peft_dir, 101 | ), 102 | ], 103 | ) 104 | trainer.train() 105 | 106 | 107 | if __name__ == "__main__": 108 | import fire 109 | 110 | fire.Fire(main) 111 | -------------------------------------------------------------------------------- /llm/eval/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from sklearn.metrics import roc_auc_score 4 | import torch 5 | import torch.nn.functional as F 6 | from transformers import GenerationConfig 7 | from peft import PeftModel 8 | 9 | from ..datasets import LabeledStringDataCollator 10 | from ..datasets.llm_utils_oe import sanitize_generations 11 | from .third_party.calibration import calibration 12 | 13 | 14 | DATA_FILE_NAME = "data.bin" 15 | 16 | 17 | def save_metrics_data(data, log_dir=None, filename=DATA_FILE_NAME): 18 | if log_dir is None: 19 | return 20 | 21 | os.makedirs(log_dir, exist_ok=True) 22 | 23 | torch.save(data, f"{log_dir}/{filename}") 24 | 25 | logging.info(f'Metrics data saved to "{log_dir}/{filename}".') 26 | 27 | 28 | def compute_auroc(labels, probs, multi_class="ovr", **kwargs): 29 | one_hot_labels = ( 30 | F.one_hot(labels, num_classes=probs.size(-1)) if labels.ndim == 1 else labels 31 | ) 32 | 33 | try: 34 | auroc = roc_auc_score(one_hot_labels, probs.float(), multi_class=multi_class, **kwargs) 35 | except ValueError: 36 | auroc = float("nan") 37 | logging.exception("AUROC calculation failed.", exc_info=True) 38 | 39 | return auroc 40 | 41 | 42 | def compute_uncertainty_metrics(labels, logits, prefix=""): 43 | """ 44 | Arguments: 45 | labels: Shape (N,) 46 | logits: Shape (N, 2) 47 | """ 48 | p = logits.softmax(dim=-1) 49 | 50 | pred = p.argmax(dim=-1) 51 | acc = (pred == labels).float().mean(dim=0) 52 | 53 | ece, _ = calibration( 54 | labels, 55 | pred, 56 | p[torch.arange(p.size(0)), pred].float(), 57 | ) 58 | 59 | auroc = compute_auroc(labels, p) 60 | 61 | return { 62 | "N": labels.size(0), 63 | f"{prefix}acc": acc.item(), 64 | 
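# ece is computed on the confidence of the predicted class only, while
# auroc is one-vs-rest over the full probability vector (see above).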
f"{prefix}auroc": auroc, 65 | f"{prefix}ece": ece, 66 | } 67 | 68 | 69 | def get_model_generations( 70 | accelerator, 71 | model, 72 | tokenizer, 73 | lmtext_inputs, 74 | max_new_tokens=None, 75 | adapter_name="default", 76 | ): 77 | config = GenerationConfig( 78 | pad_token_id=tokenizer.pad_token_id, 79 | bos_token_id=tokenizer.bos_token_id, 80 | eos_token_id=tokenizer.eos_token_id, 81 | max_new_tokens=max_new_tokens, 82 | do_sample=False, 83 | return_dict_in_generate=True, 84 | output_logits=True, 85 | output_hidden_states=True, 86 | ) 87 | 88 | if max_new_tokens is None: 89 | logging.warning(f"max_new_tokens is None.") 90 | 91 | collate_fn = LabeledStringDataCollator(tokenizer) 92 | 93 | inputs = collate_fn(lmtext_inputs) 94 | inputs = {k: v.to(accelerator.device) for k, v in inputs.items()} 95 | 96 | if isinstance(model, PeftModel): 97 | active_adapter = model.active_adapter 98 | model.set_adapter(adapter_name) 99 | 100 | outputs = model.generate(**inputs, generation_config=config) 101 | 102 | if isinstance(model, PeftModel): 103 | model.set_adapter(active_adapter) 104 | 105 | str_outputs = tokenizer.batch_decode( 106 | outputs.sequences[:, inputs.get("input_ids").size(-1) :], 107 | skip_special_tokens=True, 108 | clean_up_tokenization_spaces=False, 109 | ) 110 | str_outputs = sanitize_generations(str_outputs) 111 | 112 | return str_outputs, outputs 113 | -------------------------------------------------------------------------------- /llm/eval/third_party/calibration.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Uncertainty Baselines Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | # Source: https://github.com/google/uncertainty-baselines/blob/main/baselines/mnist/utils.py 17 | # 18 | 19 | import numpy as np 20 | import torch 21 | import logging 22 | 23 | 24 | def calibration(y, class_pred, conf, num_bins=10): 25 | """Compute the calibration. 26 | 27 | References: 28 | https://arxiv.org/abs/1706.04599 29 | https://arxiv.org/abs/1807.00263 30 | 31 | Args: 32 | y: one-hot encoding of the true classes, size (?, num_classes) 33 | p_mean: numpy array, size (?, num_classes) 34 | containing the mean output predicted probabilities 35 | num_bins: number of bins 36 | 37 | Returns: 38 | ece: Expected Calibration Error 39 | mce: Maximum Calibration Error 40 | """ 41 | if isinstance(y, torch.Tensor): 42 | y = y.cpu().numpy() 43 | class_pred = class_pred.cpu().numpy() 44 | conf = conf.cpu().numpy() 45 | # Compute for every test sample x, the predicted class. 46 | # class_pred = np.argmax(p_mean, axis=1) 47 | # and the confidence (probability) associated with it. 
48 | # conf = np.max(p_mean, axis=1) 49 | # Convert y from one-hot encoding to the number of the class 50 | # y = np.argmax(y, axis=1) 51 | # Storage 52 | acc_tab = np.zeros(num_bins) # empirical (true) confidence 53 | mean_conf = np.zeros(num_bins) # predicted confidence 54 | nb_items_bin = np.zeros(num_bins) # number of items in the bins 55 | tau_tab = np.linspace(0, 1, num_bins + 1) # confidence bins 56 | for i in np.arange(num_bins): # iterate over the bins 57 | # select the items where the predicted max probability falls in the bin 58 | # [tau_tab[i], tau_tab[i + 1)] 59 | sec = (tau_tab[i + 1] > conf) & (conf >= tau_tab[i]) 60 | nb_items_bin[i] = np.sum(sec) # Number of items in the bin 61 | # select the predicted classes, and the true classes 62 | class_pred_sec, y_sec = class_pred[sec], y[sec] 63 | # average of the predicted max probabilities 64 | mean_conf[i] = np.mean(conf[sec]) if nb_items_bin[i] > 0 else np.nan 65 | # compute the empirical confidence 66 | acc_tab[i] = np.mean(class_pred_sec == y_sec) if nb_items_bin[i] > 0 else np.nan 67 | 68 | # Cleaning 69 | mean_conf = mean_conf[nb_items_bin > 0] 70 | acc_tab = acc_tab[nb_items_bin > 0] 71 | nb_items_bin = nb_items_bin[nb_items_bin > 0] 72 | 73 | if len(nb_items_bin) == 0: 74 | logging.warning("ECE computation failed.") 75 | return float("nan"), float("nan") 76 | 77 | # Expected Calibration Error 78 | ece = np.average( 79 | np.absolute(mean_conf - acc_tab), 80 | weights=nb_items_bin.astype(float) / np.sum(nb_items_bin), 81 | ) 82 | # Maximum Calibration Error 83 | mce = np.max(np.absolute(mean_conf - acc_tab)) 84 | return ece, mce 85 | -------------------------------------------------------------------------------- /llm/datasets/hf/mmlu_pro.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from datasets import load_dataset 3 | 4 | from ..registry import register_dataset, DatasetTag 5 | from ..llm_data_utils import LMText, PromptFormat 6 | from .mmlu import format_sample, format_sample_prompt 7 | 8 | 9 | __TASKS = [ 10 | "biology", 11 | "business", 12 | "chemistry", 13 | "computer_science", 14 | "economics", 15 | "engineering", 16 | "health", 17 | "history", 18 | "law", 19 | "math", 20 | "other", 21 | "philosophy", 22 | "physics", 23 | "psychology", 24 | ] 25 | 26 | 27 | def get_mmlu_pro( 28 | task=None, 29 | prompt_style=None, 30 | eval_kshot=5, 31 | num_workers=8, 32 | seed=None, 33 | use_cache=True, 34 | **_, 35 | ): 36 | format = PromptFormat(prompt_style) 37 | task = " ".join(task.split("_")) 38 | 39 | dataset = load_dataset("TIGER-Lab/MMLU-Pro") 40 | if not use_cache: 41 | dataset.cleanup_cache_files() 42 | 43 | dataset = ( 44 | dataset.filter(lambda s: s["category"] == task) 45 | .rename_columns( 46 | { 47 | "options": "choices", 48 | "answer": "answer_choice", 49 | "answer_index": "answer", 50 | } 51 | ) 52 | .remove_columns( 53 | ["question_id", "answer_choice", "cot_content", "category", "src"] 54 | ) 55 | ) 56 | 57 | dataset = dataset.map( 58 | lambda sample: format_sample(sample, format).to_pydict(), 59 | num_proc=num_workers, 60 | remove_columns=dataset.column_names["test"], 61 | ) 62 | 63 | prompt_label = (" ".join(task.split("_"))).capitalize() 64 | prompt_data = dataset.pop("validation") 65 | prompt_kshot = { 66 | "validation": eval_kshot, 67 | "test": eval_kshot, 68 | } 69 | 70 | data_splits = { 71 | split: ds.map( 72 | lambda _, idx: { 73 | "prompt": format_sample_prompt( 74 | prompt_data, 75 | prompt_label, 76 | format, 77 | kshot=prompt_kshot[split], 
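# seed + idx gives every example its own deterministic few-shot prompt draw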
78 | seed=seed + idx, 79 | ) 80 | }, 81 | with_indices=True, 82 | num_proc=num_workers, 83 | ) 84 | for split, ds in dataset.items() 85 | } 86 | 87 | train_data = data_splits.pop("train", None) 88 | val_data = data_splits.pop("validation", None) 89 | test_data = data_splits.pop("test", None) 90 | 91 | return train_data, val_data, test_data 92 | 93 | 94 | @register_dataset(attrs=dict(task_categories=__TASKS, tags=[DatasetTag.EVAL_ONLY])) 95 | def mmlu_pro(*args, dataset_str=None, **kwargs): 96 | try: 97 | _, task = dataset_str.split(":") 98 | 99 | assert task in __TASKS 100 | except ValueError: 101 | logging.exception( 102 | f'Dataset string should be formatted as "mmlu_pro:" (Got {dataset_str})', 103 | ) 104 | raise 105 | except AssertionError: 106 | logging.exception( 107 | f'Task not found. Dataset string should be formatted as "mmlu_pro:" (Got {dataset_str})', 108 | ) 109 | raise 110 | 111 | return get_mmlu_pro(*args, **kwargs, task=task) 112 | 113 | 114 | @register_dataset(attrs=dict(unlisted=True, collection=True)) 115 | def mmlu_pro_all(*args, **kwargs): 116 | return [f"mmlu_pro:{task}" for task in __TASKS] 117 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | ./logs 163 | ./datasets 164 | -------------------------------------------------------------------------------- /experiments/classifier_tune.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from llm.datasets import get_dataset 4 | from llm.distributed import AcceleratorState 5 | from llm.logging import entrypoint 6 | from llm.models import get_model 7 | from llm.models.peft import get_lora_model, get_classifier_head, get_temperature_head 8 | from llm.trainer import WandbConfigUpdateCallback, ClassificationTuner 9 | 10 | 11 | @entrypoint 12 | def main( 13 | seed=137, 14 | log_dir=None, 15 | dataset=None, 16 | data_dir=None, 17 | prompt_style=None, 18 | max_token_length=None, 19 | num_workers=4, 20 | use_dataset_cache=True, 21 | model_name=None, 22 | int8=True, 23 | lora_rank=8, 24 | lora_alpha=32, 25 | lora_dropout=0.1, 26 | peft_dir=None, 27 | with_lora=False, 28 | scale_temp=False, 29 | batch_size=1, 30 | warmup_ratio=0.0, 31 | lr=1e-4, 32 | max_steps=1, 33 | ): 34 | accelerator = AcceleratorState() 35 | 36 | trainer_args = ClassificationTuner.Args( 37 | seed=seed, 38 | output_dir=log_dir, 39 | max_steps=max_steps, 40 | eval_steps=max_steps // 10, 41 | save_steps=max_steps // 10, 42 | logging_steps=max(1, max_steps // 200), 43 | dataloader_num_workers=num_workers, 44 | per_device_train_batch_size=batch_size, 45 | per_device_eval_batch_size=batch_size, 46 | learning_rate=lr, 47 | warmup_ratio=warmup_ratio, 48 | with_lora=with_lora, 49 | ) 50 | 51 | with accelerator.main_process_first(): 52 | train_data, val_data, test_data = get_dataset( 53 | dataset, 54 | root=data_dir, 55 | seed=seed, 56 | prompt_style=prompt_style, 57 | max_token_length=max_token_length, 58 | num_workers=num_workers, 59 | use_cache=use_dataset_cache, 60 | ) 61 | if scale_temp: 62 | train_data, val_data = val_data, test_data or val_data 63 | 64 | 
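# The rest of this script wires the pieces together: the model registry loads
# the base LM and tokenizer, a LoRA adapter is attached (trainable only when
# `with_lora` is set and `scale_temp` is not), and a classifier head sized to
# the model's hidden dimension is trained on top; with `scale_temp` the
# classifier is frozen and a temperature head is stacked on it instead.
# Illustrative invocation via fire (flag values are examples only, not the
# authors' settings):
#   python experiments/classifier_tune.py \
#     --model_name=llama2:7b-chat --dataset=wsc \
#     --max_steps=1000 --batch_size=4 --log_dir=./logs/classifier-tune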
tokenizer, model = get_model( 65 | model_name, 66 | device_map={"": accelerator.local_process_index}, 67 | use_int8=int8, 68 | ) 69 | 70 | model = get_lora_model( 71 | model, 72 | peft_id_or_dir=peft_dir, 73 | lora_rank=lora_rank, 74 | lora_alpha=lora_alpha, 75 | lora_dropout=lora_dropout, 76 | is_trainable=with_lora and not scale_temp, 77 | adapter_name="default", 78 | ) 79 | 80 | classifier_model = get_classifier_head( 81 | input_size=model.config.hidden_size, 82 | checkpoint_dir=peft_dir, 83 | is_trainable=not scale_temp, 84 | weights_name=ClassificationTuner.WEIGHTS_NAME, 85 | ) 86 | 87 | if scale_temp: 88 | temperature_model = get_temperature_head( 89 | checkpoint_dir=peft_dir, 90 | is_trainable=True, 91 | ) 92 | 93 | classifier_model = torch.nn.Sequential( 94 | classifier_model, 95 | temperature_model, 96 | ) 97 | 98 | model.classifier_model = classifier_model.to(model.dtype) 99 | 100 | trainer = ClassificationTuner( 101 | model=model, 102 | classifier_model=classifier_model, 103 | args=trainer_args, 104 | train_dataset=train_data, 105 | eval_dataset=val_data, 106 | tokenizer=tokenizer, 107 | callbacks=[ 108 | WandbConfigUpdateCallback( 109 | dataset=dataset, 110 | prompt_style=prompt_style, 111 | max_token_length=max_token_length, 112 | model_name=model_name, 113 | lora_rank=lora_rank, 114 | lora_alpha=lora_alpha, 115 | lora_dropout=lora_dropout, 116 | peft_dir=peft_dir, 117 | ), 118 | ], 119 | ) 120 | trainer.train() 121 | 122 | 123 | if __name__ == "__main__": 124 | import fire 125 | 126 | fire.Fire(main) 127 | -------------------------------------------------------------------------------- /llm/models/mistral.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from peft import prepare_model_for_kbit_training 3 | import torch 4 | from transformers import AutoTokenizer, MistralForCausalLM 5 | 6 | 7 | from .registry import register_model 8 | from .llm_model_utils import DEFAULT_PAD_TOKEN, resize_token_embeddings 9 | 10 | 11 | __MISTRAL_HF_MODEL_MAP = { 12 | "7b": "Mistral-7B-v0.1", 13 | "7b-instruct": "Mistral-7B-Instruct-v0.2", 14 | } 15 | 16 | __MIXTRAL_HF_MODEL_MAP = { 17 | "8x22b": "Mixtral-8x22B-v0.1", 18 | "8x22b-instruct": "Mixtral-8x22B-Instruct-v0.1", 19 | } 20 | 21 | 22 | def __get_model_hf_id(model_str, model_map): 23 | try: 24 | model_name, kind = model_str.split(":") 25 | 26 | assert kind in model_map.keys() 27 | except ValueError: 28 | logging.exception( 29 | f'Model string should be formatted as "{model_name}:" (Got {model_str})', 30 | ) 31 | raise 32 | except AssertionError: 33 | logging.exception( 34 | f'Model not found. 
Model string should be formatted as "{model_name}:" (Got {model_str})', 35 | ) 36 | raise 37 | 38 | return model_map[kind] 39 | 40 | 41 | def create_tokenizer( 42 | kind, 43 | model_dir=None, 44 | padding_side="left", 45 | model_max_length=8192, 46 | **kwargs, 47 | ): 48 | tokenizer = AutoTokenizer.from_pretrained( 49 | model_dir or f"mistralai/{kind}", 50 | padding_side=padding_side, 51 | model_max_length=model_max_length, 52 | use_fast=True, 53 | legacy=False, 54 | **kwargs, 55 | ) 56 | 57 | tokenizer.add_special_tokens({"pad_token": DEFAULT_PAD_TOKEN}) 58 | 59 | return tokenizer 60 | 61 | 62 | def create_model( 63 | kind, 64 | torch_dtype=None, 65 | model_dir=None, 66 | use_cache=False, 67 | tokenizer=None, 68 | use_int8=False, 69 | use_int4=False, 70 | **kwargs, 71 | ): 72 | quantization_config = None 73 | if use_int4 or use_int8: 74 | from transformers import BitsAndBytesConfig 75 | 76 | quantization_config = BitsAndBytesConfig( 77 | load_in_4bit=use_int4, 78 | load_in_8bit=use_int8, 79 | ) 80 | 81 | model = MistralForCausalLM.from_pretrained( 82 | model_dir or f"mistralai/{kind}", 83 | torch_dtype=torch_dtype 84 | or (torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16), 85 | quantization_config=quantization_config, 86 | use_cache=use_cache, 87 | **kwargs, 88 | ) 89 | 90 | model.config.pad_token_id = tokenizer.pad_token_id 91 | 92 | resize_token_embeddings(tokenizer, model) 93 | 94 | if use_int4 or use_int8: 95 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False) 96 | 97 | return model 98 | 99 | 100 | def create_tokenizer_and_model(kind, tokenizer_args=None, **kwargs): 101 | tokenizer = create_tokenizer(kind, **(tokenizer_args or dict())) 102 | model = create_model(kind, tokenizer=tokenizer, **kwargs) 103 | return tokenizer, model 104 | 105 | 106 | @register_model 107 | def mistral_tokenizer(*, model_str=None, **kwargs): 108 | return create_tokenizer( 109 | __get_model_hf_id(model_str, __MISTRAL_HF_MODEL_MAP), **kwargs 110 | ) 111 | 112 | 113 | @register_model 114 | def mistral(*, model_str=None, **kwargs): 115 | return create_tokenizer_and_model( 116 | __get_model_hf_id(model_str, __MISTRAL_HF_MODEL_MAP), **kwargs 117 | ) 118 | 119 | 120 | @register_model 121 | def mixtral_tokenizer(*, model_str=None, **kwargs): 122 | return create_tokenizer( 123 | __get_model_hf_id(model_str, __MIXTRAL_HF_MODEL_MAP), **kwargs 124 | ) 125 | 126 | 127 | @register_model 128 | def mixtral(*, model_str=None, **kwargs): 129 | return create_tokenizer_and_model( 130 | __get_model_hf_id(model_str, __MIXTRAL_HF_MODEL_MAP), **kwargs 131 | ) 132 | -------------------------------------------------------------------------------- /notebooks/results/oe_sampling.csv: -------------------------------------------------------------------------------- 1 | ,N,substring_acc,substring_ece_counting,substring_ece_likelihood,substring_ece_likelihood_normalized,fuzzy_gpt-3.5-turbo-1106_acc,fuzzy_gpt-3.5-turbo-1106_ece_counting,fuzzy_gpt-3.5-turbo-1106_ece_likelihood,fuzzy_gpt-3.5-turbo-1106_ece_likelihood_normalized,split,seed,model_name,model_dir,peft_dir,query_peft_dir,eval_kshot,prompt_style,mode,output_row_path,dataset,ts 2 | 0,11,0.0,0.2,0.0,0.1873283402639814,0.27272728085517883,0.2545454545454545,0.2,0.2250199733542388,validation,137,llama2_7b,,,,,oe,us_oe,/home/manley/testers/,mmlu:medical_genetics,11037.338716903934 3 | 
1,100,0.029999999329447746,0.2030927835051547,0.0,0.2116412126868303,0.3199999928474426,0.09175257731958765,0.14285714285714285,0.17423638393309368,test,137,llama2_7b,,,,,oe,us_oe,/home/manley/testers/,mmlu:medical_genetics,11037.338716903934 4 | 2,29,0.0,0.07857142857142857,0.0,0.06813938962480555,0.1034482792019844,0.1,0.0625,0.08397354189489278,validation,137,llama2_7b,,,,,oe,us_oe,/home/manley/testers/,mmlu:clinical_knowledge,30631.830435613985 5 | 3,265,0.01886792480945587,0.1520912547528517,0.00819672131147541,0.13109914736112063,0.22264151275157928,0.07832699619771864,0.11475409836065574,0.1125430521074833,test,137,llama2_7b,,,,,oe,us_oe,/home/manley/testers/,mmlu:clinical_knowledge,30631.830435613985 6 | 4,14,0.1428571492433548,0.30714285714285716,0.25,0.2434660787259879,0.3571428656578064,0.20714285714285713,0.25,0.19482945050234537,validation,137,llama2_7b,,,,,oe,us_oe,/home/manley/testers/,mmlu:anatomy,13893.812293368974 7 | 5,135,0.10370370000600815,0.24436090225563914,0.06060606060606061,0.23456608558558878,0.29629629850387573,0.12406015037593986,0.12121212121212122,0.12350300356173066,test,137,llama2_7b,,,,,oe,us_oe,/home/manley/testers/,mmlu:anatomy,13893.812293368974 8 | 6,18,0.0555555559694767,0.18888888888888886,0.0,0.16642036999125373,0.3333333432674408,0.24444444444444444,0.4,0.24434712991924706,validation,137,llama2_7b,,,,,oe,us_oe,/home/manley/testers/,mmlu:logical_fallacies,18227.03303959302 9 | 7,163,0.01840490847826004,0.1404907975460123,0.01282051282051282,0.12809718216709298,0.25153374671936035,0.12331288343558285,0.15384615384615385,0.14026401146323153,test,137,llama2_7b,,,,,oe,us_oe,/home/manley/testers/,mmlu:logical_fallacies,18227.03303959302 10 | 8,23,0.08695652335882187,0.16521739130434782,0.0,0.1495042846446019,0.17391304671764374,0.19130434782608696,0.14285714285714285,0.22144873966050033,validation,137,llama2_7b,,,,,oe,us_oe,/home/manley/testers/,mmlu:human_aging,27336.537950017955 11 | 9,223,0.06726457178592682,0.13472222222222227,0.0,0.13210236516838342,0.25112107396125793,0.12361111111111113,0.15306122448979592,0.17871185758902644,test,137,llama2_7b,,,,,oe,us_oe,/home/manley/testers/,mmlu:human_aging,27336.537950017955 12 | 10,14,0.0,0.2642857142857143,0.0,0.34098556452770823,0.4285714328289032,0.19285714285714292,0.5,0.32436169544234417,validation,137,llama2_7b,,,,,oe,us_oe,/home/manley/testers/,mmlu:formal_logic,15292.068300654064 13 | 11,126,0.0555555559694767,0.176984126984127,0.04081632653061224,0.16382318423895145,0.3650793731212616,0.16587301587301584,0.16326530612244897,0.20013771696480473,test,137,llama2_7b,,,,,oe,us_oe,/home/manley/testers/,mmlu:formal_logic,15292.068300654064 14 | 12,32,0.0,0.22812500000000002,0.0,0.18552079471694874,0.21875,0.121875,0.1,0.21065462283057845,validation,137,llama2_7b,,,,,oe,us_oe,/home/manley/testers/,mmlu:high_school_biology,34040.427589047 15 | 13,310,0.03870967775583267,0.1693811074918567,0.0,0.16153449086924884,0.3354838788509369,0.14397394136807817,0.14018691588785046,0.1563416429677811,test,137,llama2_7b,,,,,oe,us_oe,/home/manley/testers/,mmlu:high_school_biology,34040.427589047 16 | 14,13,0.0,0.05384615384615385,0.0,0.05047538815489245,0.0,0.05384615384615385,0.0,0.05047538815489245,validation,137,llama2_7b,,,,,oe,us_oe,/home/manley/testers/,mmlu:international_law,15147.262139331084 17 | 
15,121,0.0,0.17250000000000004,0.0,0.164681163411256,0.23140496015548706,0.12083333333333333,0.08333333333333333,0.13235762642366,test,137,llama2_7b,,,,,oe,us_oe,/home/manley/testers/,mmlu:international_law,15147.262139331084 -------------------------------------------------------------------------------- /experiments/embedding_tune.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from llm.datasets import get_dataset 4 | from llm.distributed import AcceleratorState 5 | from llm.logging import entrypoint 6 | from llm.models import get_model 7 | from llm.models.peft import get_lora_model, get_classifier_head, get_temperature_head 8 | from llm.trainer import WandbConfigUpdateCallback, EmbeddingTuner 9 | 10 | 11 | @entrypoint 12 | def main( 13 | seed=137, 14 | log_dir=None, 15 | dataset=None, 16 | data_dir=None, 17 | prompt_style=None, 18 | max_token_length=None, 19 | num_workers=4, 20 | use_dataset_cache=True, 21 | embedding_model_name=None, 22 | model_name=None, 23 | int8=True, 24 | lora_rank=8, 25 | lora_alpha=32, 26 | lora_dropout=0.1, 27 | peft_dir=None, 28 | scale_temp=False, 29 | batch_size=1, 30 | warmup_ratio=0.1, 31 | lr=1e-2, 32 | max_steps=1, 33 | ): 34 | accelerator = AcceleratorState() 35 | 36 | trainer_args = EmbeddingTuner.Args( 37 | seed=seed, 38 | output_dir=log_dir, 39 | max_steps=max_steps, 40 | eval_steps=max_steps // 10, 41 | save_steps=max_steps // 10, 42 | logging_steps=max(1, max_steps // 200), 43 | dataloader_num_workers=num_workers, 44 | per_device_train_batch_size=batch_size, 45 | per_device_eval_batch_size=batch_size, 46 | learning_rate=lr, 47 | warmup_ratio=warmup_ratio, 48 | ) 49 | 50 | with accelerator.main_process_first(): 51 | train_data, val_data, test_data = get_dataset( 52 | dataset, 53 | root=data_dir, 54 | seed=seed, 55 | prompt_style=prompt_style, 56 | max_token_length=max_token_length, 57 | num_workers=num_workers, 58 | use_cache=use_dataset_cache, 59 | ) 60 | if scale_temp: 61 | train_data, val_data = val_data, test_data or val_data 62 | 63 | tokenizer, model = get_model( 64 | model_name, 65 | device_map={"": accelerator.local_process_index}, 66 | use_int8=int8, 67 | ) 68 | 69 | model = get_lora_model( 70 | model, 71 | peft_id_or_dir=peft_dir, 72 | lora_rank=lora_rank, 73 | lora_alpha=lora_alpha, 74 | lora_dropout=lora_dropout, 75 | is_trainable=False, 76 | adapter_name="default", 77 | ) 78 | 79 | embedding_model = get_model(embedding_model_name) 80 | 81 | classifier_model = get_classifier_head( 82 | input_size=embedding_model.get_sentence_embedding_dimension(), 83 | checkpoint_dir=peft_dir, 84 | is_trainable=not scale_temp, 85 | weights_name=EmbeddingTuner.WEIGHTS_NAME, 86 | ) 87 | 88 | if scale_temp: 89 | temperature_model = get_temperature_head( 90 | checkpoint_dir=peft_dir, 91 | is_trainable=True, 92 | ) 93 | 94 | classifier_model = torch.nn.Sequential( 95 | classifier_model, 96 | temperature_model, 97 | ) 98 | 99 | model.classifier_model = classifier_model.to(model.dtype) 100 | 101 | trainer = EmbeddingTuner( 102 | model=model, 103 | embedding_model=embedding_model, 104 | classifier_model=classifier_model, 105 | args=trainer_args, 106 | train_dataset=train_data, 107 | eval_dataset=val_data, 108 | tokenizer=tokenizer, 109 | callbacks=[ 110 | WandbConfigUpdateCallback( 111 | dataset=dataset, 112 | prompt_style=prompt_style, 113 | max_token_length=max_token_length, 114 | model_name=model_name, 115 | lora_rank=lora_rank, 116 | lora_alpha=lora_alpha, 117 | lora_dropout=lora_dropout, 118 | 
peft_dir=peft_dir, 119 | ), 120 | ], 121 | ) 122 | trainer.train() 123 | 124 | 125 | if __name__ == "__main__": 126 | import fire 127 | 128 | fire.Fire(main) 129 | -------------------------------------------------------------------------------- /llm/utils/generate_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | import torch.nn.functional as F 4 | from tqdm.auto import tqdm 5 | from peft import PeftModel 6 | 7 | from llm.datasets import LabeledStringDataCollator 8 | 9 | 10 | def wrapped_generate_output(model, tokenizer, generation_inputs, generation_config): 11 | while True: 12 | try: 13 | terminators = [tokenizer.eos_token_id] + ( 14 | [tokenizer.convert_tokens_to_ids("<|eot_id|>")] 15 | if "<|eot_id|>" in tokenizer.vocab 16 | else [] 17 | ) 18 | generation_outputs = model.generate( 19 | **generation_inputs, 20 | eos_token_id=terminators, 21 | generation_config=generation_config, 22 | ) 23 | return generation_outputs 24 | except Exception as e: 25 | generation_outputs = [] 26 | new_bs = max(1, generation_inputs["input_ids"].size(0) // 2) 27 | for i in range(0, generation_inputs["input_ids"].size(0), new_bs): 28 | inputs = {k: v[i : i + new_bs] for k, v in generation_inputs.items()} 29 | _outputs = wrapped_generate_output(model, inputs, generation_config) 30 | generation_outputs.append(_outputs) 31 | return torch.cat(generation_outputs, dim=0) 32 | 33 | 34 | def generate_output( 35 | accelerator, 36 | model, 37 | tokenizer, 38 | loader, 39 | generation_config=None, 40 | generation_config_sampling=None, 41 | n_samples=0, 42 | log_dir=None, 43 | ): 44 | collate_fn = LabeledStringDataCollator(tokenizer) 45 | 46 | all_outputs = [] 47 | 48 | for inputs in tqdm(loader): 49 | inputs = [dict(zip(inputs.keys(), vals)) for vals in zip(*inputs.values())] 50 | targets = [inp.pop("target") for inp in inputs] 51 | 52 | generation_inputs = { 53 | k: v.to(accelerator.device) for k, v in collate_fn(inputs).items() 54 | } 55 | 56 | if isinstance(model, PeftModel): 57 | model.set_adapter("default") 58 | 59 | generation_outputs = wrapped_generate_output( 60 | model, tokenizer, generation_inputs, generation_config 61 | ) 62 | 63 | generations = tokenizer.batch_decode( 64 | generation_outputs[:, generation_inputs.get("input_ids").size(-1) :], 65 | skip_special_tokens=True, 66 | clean_up_tokenization_spaces=False, 67 | ) 68 | 69 | outputs = [ 70 | {**inp, "target": tgt, "output": gen} 71 | for inp, tgt, gen in zip(inputs, targets, generations) 72 | ] 73 | 74 | if n_samples: 75 | assert generation_config_sampling is not None 76 | 77 | # https://github.com/huggingface/transformers/issues/14498#issuecomment-977909651 78 | sampled_outputs = [ 79 | model.generate( 80 | **generation_inputs, 81 | generation_config=generation_config_sampling, 82 | return_dict_in_generate=True, 83 | output_scores=True, 84 | ) 85 | for _ in range(n_samples) 86 | ] 87 | 88 | outputs = [ 89 | { 90 | **o, 91 | "sampled_outputs": tokenizer.batch_decode( 92 | so["sequences"][ 93 | :, generation_inputs.get("input_ids").size(-1) : 94 | ] 95 | ), 96 | "sampled_log_probs": F.log_softmax( 97 | torch.cat(so["scores"], dim=0), dim=-1 98 | ), 99 | } 100 | for o, so in zip(outputs, sampled_outputs) 101 | ] 102 | 103 | all_outputs.extend(outputs) 104 | 105 | if log_dir is not None: 106 | df = pd.DataFrame(all_outputs) 107 | ## NOTE: Avoid spec errors when loading for labeling. 
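## The offline CSV loader (llm/datasets/offline/offline.py) declares
## query_label as an int32 column, so a -1 placeholder is written here and
## presumably overwritten once labels are generated.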
108 | df["query_label"] = -1 109 | 110 | df.to_csv(f"{log_dir}/rows_{accelerator.process_index}.csv", index=False) 111 | -------------------------------------------------------------------------------- /llm/models/llama2.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from peft import prepare_model_for_kbit_training 3 | import torch 4 | from transformers import AutoTokenizer, LlamaForCausalLM 5 | 6 | from ..datasets import LabeledStringDataCollator 7 | from .registry import register_model 8 | from .llm_model_utils import DEFAULT_PAD_TOKEN, resize_token_embeddings 9 | 10 | 11 | __HF_MODEL_MAP = { 12 | "7b": "Llama-2-7b-hf", 13 | "7b-chat": "Llama-2-7b-chat-hf", 14 | "13b": "Llama-2-13b-hf", 15 | "13b-chat": "Llama-2-13b-chat-hf", 16 | "70b": "Llama-2-70b-hf", 17 | "70b-chat": "Llama-2-70b-chat-hf", 18 | } 19 | 20 | 21 | def __get_model_hf_id(model_str): 22 | try: 23 | _, kind = model_str.split(":") 24 | 25 | assert kind in __HF_MODEL_MAP.keys() 26 | except ValueError: 27 | logging.exception( 28 | f'Model string should be formatted as "llama2:" (Got {model_str})', 29 | ) 30 | raise 31 | except AssertionError: 32 | logging.exception( 33 | f'Model not found. Model string should be formatted as "llama2:" (Got {model_str})', 34 | ) 35 | raise 36 | 37 | return __HF_MODEL_MAP[kind] 38 | 39 | 40 | def create_tokenizer( 41 | kind, 42 | model_dir=None, 43 | padding_side="left", 44 | model_max_length=4096, 45 | **kwargs, 46 | ): 47 | tokenizer = AutoTokenizer.from_pretrained( 48 | model_dir or f"meta-llama/{kind}", 49 | padding_side=padding_side, 50 | model_max_length=model_max_length, 51 | use_fast=True, 52 | legacy=False, 53 | **kwargs, 54 | ) 55 | 56 | tokenizer.add_special_tokens({"pad_token": DEFAULT_PAD_TOKEN}) 57 | 58 | return tokenizer 59 | 60 | 61 | def create_model( 62 | kind, 63 | torch_dtype=None, 64 | model_dir=None, 65 | use_cache=False, 66 | tokenizer=None, 67 | use_int8=False, 68 | use_int4=False, 69 | **kwargs, 70 | ): 71 | quantization_config = None 72 | if use_int8 or use_int4: 73 | from transformers import BitsAndBytesConfig 74 | 75 | quantization_config = BitsAndBytesConfig( 76 | load_in_4bit=use_int4, 77 | load_in_8bit=use_int8, 78 | ) 79 | 80 | model = LlamaForCausalLM.from_pretrained( 81 | model_dir or f"meta-llama/{kind}", 82 | torch_dtype=torch_dtype 83 | or (torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16), 84 | quantization_config=quantization_config, 85 | use_cache=use_cache, 86 | **kwargs, 87 | ) 88 | 89 | model.config.pad_token_id = tokenizer.pad_token_id 90 | 91 | resize_token_embeddings(tokenizer, model) 92 | 93 | if use_int8: 94 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False) 95 | 96 | return model 97 | 98 | 99 | def create_tokenizer_and_model(kind, tokenizer_args=None, **kwargs): 100 | tokenizer = create_tokenizer(kind, **(tokenizer_args or dict())) 101 | model = create_model(kind, tokenizer=tokenizer, **kwargs) 102 | return tokenizer, model 103 | 104 | 105 | class LMEmbedModel: 106 | def __init__(self, t, m): 107 | self.tokenizer = t 108 | self.model = m 109 | self.tokenizer_args = LabeledStringDataCollator.get_tokenizer_args( 110 | self.tokenizer 111 | ) 112 | 113 | @torch.inference_mode 114 | def __call__(self, texts): 115 | inputs = self.tokenizer(texts, **self.tokenizer_args) 116 | inputs.pop("length", None) 117 | 118 | inputs = {k: v.to(self.model.device) for k, v in inputs.items()} 119 | 120 | outputs = self.model(**inputs, 
output_hidden_states=True) 121 | embeddings = outputs.hidden_states[-1][..., -1, :] 122 | 123 | return embeddings.clone() 124 | 125 | 126 | def create_embed_model(kind, **kwargs): 127 | return LMEmbedModel(*create_tokenizer_and_model(kind, **kwargs)) 128 | 129 | 130 | @register_model 131 | def llama2_tokenizer(*, model_str=None, **kwargs): 132 | return create_tokenizer(__get_model_hf_id(model_str), **kwargs) 133 | 134 | 135 | @register_model 136 | def llama2(*, model_str=None, **kwargs): 137 | return create_tokenizer_and_model(__get_model_hf_id(model_str), **kwargs) 138 | 139 | 140 | @register_model 141 | def llama2_embed(*, model_str=None, **kwargs): 142 | return create_embed_model(__get_model_hf_id(model_str), **kwargs) 143 | -------------------------------------------------------------------------------- /experiments/calibration_tune.py: -------------------------------------------------------------------------------- 1 | from llm.datasets import get_dataset, LMText, LabeledStringDataCollator 2 | from llm.distributed import AcceleratorState 3 | from llm.logging import entrypoint 4 | from llm.models import get_model 5 | from llm.models.peft import get_lora_model, get_temperature_head 6 | from llm.trainer import WandbConfigUpdateCallback, CalibrationTuner 7 | 8 | 9 | @entrypoint 10 | def main( 11 | seed=137, 12 | log_dir=None, 13 | dataset=None, 14 | data_dir=None, 15 | prompt_style=None, 16 | max_token_length=None, 17 | num_workers=4, 18 | use_dataset_cache=True, 19 | model_name=None, 20 | int8=True, 21 | lora_rank=8, 22 | lora_alpha=32, 23 | lora_dropout=0.1, 24 | peft_dir=None, 25 | ref_peft_dir=None, 26 | scale_temp=False, 27 | batch_size=1, 28 | lr=1e-4, 29 | warmup_ratio=0.0, 30 | kl_decay=0.0, 31 | max_steps=1, 32 | **_, 33 | ): 34 | accelerator = AcceleratorState() 35 | 36 | trainer_args = CalibrationTuner.Args( 37 | seed=seed, 38 | output_dir=log_dir, 39 | max_steps=max_steps, 40 | eval_steps=max_steps // 10, 41 | save_steps=max_steps // 10, 42 | logging_steps=max(1, max_steps // 200), 43 | dataloader_num_workers=num_workers, 44 | per_device_train_batch_size=batch_size, 45 | per_device_eval_batch_size=batch_size, 46 | learning_rate=lr, 47 | warmup_ratio=warmup_ratio, 48 | scale_temp=scale_temp, 49 | kl_decay=kl_decay, 50 | ) 51 | 52 | with accelerator.main_process_first(): 53 | train_data, val_data, test_data = get_dataset( 54 | dataset, 55 | root=data_dir, 56 | seed=seed, 57 | prompt_style=prompt_style, 58 | num_workers=num_workers, 59 | use_cache=use_dataset_cache, 60 | ) 61 | if scale_temp: 62 | train_data, val_data = val_data, test_data or val_data 63 | 64 | tokenizer, model = get_model( 65 | model_name, 66 | device_map={"": accelerator.local_process_index}, 67 | use_int8=int8, 68 | ) 69 | 70 | if max_token_length is not None: 71 | tokenizer_args = LabeledStringDataCollator.get_tokenizer_args(tokenizer) 72 | 73 | def token_length_filter(instance): 74 | f_instance = {k: v for k, v in instance.items() if "embedding" not in k} 75 | inputs = tokenizer( 76 | [str(LMText.from_(f_instance))], 77 | **tokenizer_args, 78 | ) 79 | return inputs.get("input_ids").size(-1) <= max_token_length 80 | 81 | train_data = train_data.filter(token_length_filter, num_proc=num_workers) 82 | val_data = val_data.filter(token_length_filter, num_proc=num_workers) 83 | 84 | model = get_lora_model( 85 | model, 86 | peft_id_or_dir=ref_peft_dir or peft_dir, 87 | lora_rank=lora_rank, 88 | lora_alpha=lora_alpha, 89 | lora_dropout=lora_dropout, 90 | is_trainable=False, 91 | adapter_name="_ref", 92 | ) 93 | 94 | 
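# The frozen "_ref" adapter above presumably serves as the reference model for
# the KL regularizer controlled by `kl_decay`; the "default" adapter loaded
# below is the one that is actually trained (unless `scale_temp` is set).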
model = get_lora_model( 95 | model, 96 | peft_id_or_dir=peft_dir, 97 | lora_rank=lora_rank, 98 | lora_alpha=lora_alpha, 99 | lora_dropout=lora_dropout, 100 | is_trainable=not scale_temp, 101 | adapter_name="default", 102 | ) 103 | 104 | if scale_temp: 105 | model.requires_grad_(False) 106 | 107 | temperature_model = get_temperature_head( 108 | checkpoint_dir=peft_dir, 109 | is_trainable=True, 110 | weights_name=CalibrationTuner.TEMPERATURE_WEIGHTS_NAME, 111 | ).to(accelerator.local_process_index) 112 | 113 | ## HOTFIX: To allow registry with Trainer optimizer. 114 | model.temperature_model = temperature_model 115 | else: 116 | temperature_model = None 117 | 118 | trainer = CalibrationTuner( 119 | model=model, 120 | query_temperature_model=temperature_model, 121 | args=trainer_args, 122 | train_dataset=train_data, 123 | eval_dataset=val_data, 124 | tokenizer=tokenizer, 125 | callbacks=[ 126 | WandbConfigUpdateCallback( 127 | dataset=dataset, 128 | prompt_style=prompt_style, 129 | max_token_length=max_token_length, 130 | model_name=model_name, 131 | lora_rank=lora_rank, 132 | lora_alpha=lora_alpha, 133 | lora_dropout=lora_dropout, 134 | peft_dir=peft_dir, 135 | ), 136 | ], 137 | ) 138 | trainer.train() 139 | 140 | 141 | if __name__ == "__main__": 142 | import fire 143 | 144 | fire.Fire(main) 145 | -------------------------------------------------------------------------------- /llm/datasets/offline/offline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import glob 4 | from enum import Enum 5 | import numpy as np 6 | from datasets import load_dataset, Features, Value, DatasetDict 7 | 8 | from ..registry import register_dataset 9 | from ..llm_utils_oe import sanitize_generations 10 | 11 | 12 | CSV_DATASET_FEATURES = Features( 13 | { 14 | "context": Value("string"), 15 | "target": Value("string"), 16 | "target_prompt": Value("string"), 17 | "prompt": Value("string"), 18 | "output": Value("string"), 19 | "query_label": Value("int32"), 20 | } 21 | ) 22 | 23 | 24 | class DatasetSizeRatio(float, Enum): 25 | XXS = 0.01 26 | XS = 0.1 27 | SM = 0.25 28 | MD = 0.5 29 | 30 | 31 | def get_offline( 32 | seed=None, 33 | root=None, 34 | num_workers=8, 35 | use_cache=True, 36 | data_ratio=None, 37 | train_kshot=0, 38 | eval_kshot=0, 39 | load_embeddings=True, 40 | **_, 41 | ): 42 | data_files = {} 43 | embeddings = {} 44 | for split_name in ["train", "validation", "test"]: 45 | if os.path.isdir(f"{root}/{split_name}"): 46 | data_files[split_name] = glob.glob(f"{root}/{split_name}/*.csv") 47 | 48 | if os.path.isfile(f"{root}/{split_name}/embedding.npy"): 49 | embeddings[split_name] = np.load(f"{root}/{split_name}/embedding.npy") 50 | 51 | dataset = load_dataset("csv", data_files=data_files, features=CSV_DATASET_FEATURES) 52 | if not use_cache: 53 | dataset.cleanup_cache_files() 54 | 55 | dataset = dataset.map( 56 | lambda x: {k: "" if v is None else v for k, v in x.items()}, 57 | num_proc=num_workers, 58 | ) 59 | 60 | if load_embeddings and len(set(dataset.keys()) - set(embeddings.keys())) == 0: 61 | dataset = DatasetDict( 62 | { 63 | split: ds.map( 64 | lambda _, idx: {"embedding": embeddings[split][idx]}, 65 | with_indices=True, 66 | num_proc=num_workers, 67 | ) 68 | for split, ds in dataset.items() 69 | } 70 | ) 71 | dataset = dataset.with_format( 72 | "np", columns=["embedding"], output_all_columns=True 73 | ) 74 | 75 | if data_ratio is not None: 76 | data_ratio = DatasetSizeRatio(data_ratio) 77 | 78 | def sub_sample(data): 79 | N 
= len(data) 80 | idxs = np.random.default_rng(seed=seed).choice( 81 | range(N), int(data_ratio * N) 82 | ) 83 | return data.select(idxs) 84 | 85 | dataset = DatasetDict({split: sub_sample(ds) for split, ds in dataset.items()}) 86 | 87 | prompt_kshot = { 88 | "train": train_kshot, 89 | "validation": eval_kshot, 90 | "test": eval_kshot, 91 | } 92 | data_splits = DatasetDict( 93 | { 94 | split: ( 95 | ds.remove_columns([c for c in ["prompt"] if c in ds.column_names]) 96 | if prompt_kshot[split] == 0 97 | else ds 98 | ) 99 | for split, ds in dataset.items() 100 | } 101 | ).map(lambda s: {"output": sanitize_generations([s["output"].strip()])[0]}) 102 | 103 | train_data = data_splits.pop("train", None) 104 | val_data = data_splits.pop("validation", None) 105 | test_data = data_splits.pop("test", None) 106 | 107 | return train_data, val_data, test_data 108 | 109 | 110 | @register_dataset(attrs=dict(unlisted=True)) 111 | def offline(*args, root=None, dataset_str=None, prompt_style=None, **kwargs): 112 | try: 113 | _, name = dataset_str.split(":") 114 | except ValueError: 115 | logging.exception( 116 | f'Dataset string should be formatted as "offline:" (Got {dataset_str})', 117 | ) 118 | raise 119 | 120 | root = f"{root}/offline/{name}-{prompt_style}" 121 | 122 | return get_offline(*args, root=root, **kwargs) 123 | 124 | 125 | @register_dataset(attrs=dict(unlisted=True)) 126 | def offline_xxs(*args, **kwargs): 127 | kwargs.pop("data_ratio", None) 128 | return offline(*args, data_ratio=DatasetSizeRatio.XXS, **kwargs) 129 | 130 | 131 | @register_dataset(attrs=dict(unlisted=True)) 132 | def offline_xs(*args, **kwargs): 133 | kwargs.pop("data_ratio", None) 134 | return offline(*args, data_ratio=DatasetSizeRatio.XS, **kwargs) 135 | 136 | 137 | @register_dataset(attrs=dict(unlisted=True)) 138 | def offline_sm(*args, **kwargs): 139 | kwargs.pop("data_ratio", None) 140 | return offline(*args, data_ratio=DatasetSizeRatio.SM, **kwargs) 141 | 142 | 143 | @register_dataset(attrs=dict(unlisted=True)) 144 | def offline_md(*args, **kwargs): 145 | kwargs.pop("data_ratio", None) 146 | return offline(*args, data_ratio=DatasetSizeRatio.MD, **kwargs) 147 | -------------------------------------------------------------------------------- /llm/datasets/hf/wsc.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | from datasets import load_dataset 4 | 5 | from ..registry import register_dataset 6 | from ..llm_data_utils import LMText, PromptFormat 7 | 8 | 9 | def format_sample(sample, format, with_query_label=False, seed=None): 10 | target_prompt = "\nAnswer:" 11 | 12 | text = sample["text"] 13 | answer_map = sample["options"] 14 | target_idx = sample["label"] 15 | 16 | output = None 17 | query_label = ( 18 | np.random.default_rng(seed=seed).binomial(1, 0.5) if with_query_label else None 19 | ) 20 | output_idx = ( 21 | target_idx 22 | if query_label == 1 23 | else ( 24 | np.random.default_rng(seed=seed).choice( 25 | list(set(range(len(answer_map))) - set([target_idx])) 26 | ) 27 | if query_label == 0 28 | else None 29 | ) 30 | ) 31 | 32 | if format == PromptFormat.CHOICE: 33 | context = "\n".join( 34 | [ 35 | "Sentence:", 36 | text, 37 | "\nChoices:", 38 | *[ 39 | f" ({n}): {c}" 40 | for n, c in zip( 41 | string.ascii_lowercase[: len(answer_map)], answer_map 42 | ) 43 | ], 44 | ] 45 | ) 46 | 47 | target = string.ascii_lowercase[target_idx] 48 | if output_idx is not None: 49 | output = string.ascii_lowercase[output_idx] 50 | elif format == 
PromptFormat.OE: 51 | context = "\n".join( 52 | [ 53 | "Read the following sentence, and resolve the ambiguity.", 54 | text, 55 | ] 56 | ) 57 | 58 | target = answer_map[target_idx] 59 | if output_idx is not None: 60 | output = answer_map[output_idx] 61 | else: 62 | raise NotImplementedError(f"Unsupported prompt format {format}.") 63 | 64 | return LMText( 65 | context=context, 66 | target_prompt=target_prompt, 67 | target=target, 68 | output=output, 69 | query_label=query_label, 70 | ) 71 | 72 | 73 | def format_sample_prompt(prompt_dataset, format, kshot=1, seed=None): 74 | if not kshot: 75 | return "" 76 | 77 | samples_idx = ( 78 | np.random.default_rng(seed=seed) 79 | .permutation(len(prompt_dataset))[:kshot] 80 | .tolist() 81 | ) 82 | 83 | fewshot_samples_prompt = [ 84 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 85 | ] 86 | 87 | if format == PromptFormat.CHOICE: 88 | prompt = [ 89 | "The following are ambiguous sentences (with answer choices).\n", 90 | *fewshot_samples_prompt, 91 | "Now, resolve the next ambiguity.\n\n", 92 | ] 93 | elif format == PromptFormat.OE: 94 | prompt = [ 95 | "The following are ambiguous sentences (with answers).\n", 96 | *fewshot_samples_prompt, 97 | "Now, resolve the next ambiguity.\n\n", 98 | ] 99 | else: 100 | raise NotImplementedError(f"Unsupported prompt format {format}.") 101 | 102 | return "\n".join(prompt) 103 | 104 | 105 | def get_wsc( 106 | prompt_style=None, 107 | with_query_label=False, 108 | eval_kshot=0, 109 | num_workers=8, 110 | seed=None, 111 | use_cache=True, 112 | **_, 113 | ): 114 | format = PromptFormat(prompt_style) 115 | 116 | dataset = load_dataset("winograd_wsc", "wsc285", trust_remote_code=True) 117 | if not use_cache: 118 | dataset.cleanup_cache_files() 119 | 120 | dataset = dataset.map( 121 | lambda sample, idx: format_sample( 122 | sample, format, with_query_label=with_query_label, seed=seed + idx 123 | ).to_pydict(), 124 | with_indices=True, 125 | num_proc=num_workers, 126 | remove_columns=dataset.column_names["test"], 127 | ) 128 | 129 | prompt_data = dataset.get("test") 130 | prompt_kshot = { 131 | "validation": eval_kshot, 132 | "test": eval_kshot, 133 | } 134 | 135 | data_splits = { 136 | split: ds.map( 137 | lambda _, idx: { 138 | "prompt": format_sample_prompt( 139 | prompt_data, format, kshot=prompt_kshot[split], seed=seed + idx 140 | ) 141 | }, 142 | with_indices=True, 143 | num_proc=num_workers, 144 | ) 145 | for split, ds in dataset.items() 146 | } 147 | 148 | train_data = data_splits.pop("train", None) 149 | val_data = data_splits.pop("validation", None) 150 | test_data = data_splits.pop("test", None) 151 | 152 | return train_data, val_data, test_data 153 | 154 | 155 | @register_dataset 156 | def wsc(*args, **kwargs): 157 | return get_wsc(*args, **kwargs) 158 | -------------------------------------------------------------------------------- /llm/datasets/hf/truthful_qa.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | from datasets import load_dataset 4 | 5 | from ..registry import register_dataset 6 | from ..llm_data_utils import LMText, PromptFormat 7 | 8 | 9 | def format_sample(sample, format, with_query_label=False, seed=None): 10 | target_prompt = "\nAnswer:" 11 | 12 | question = sample["question"] 13 | answer_map = sample["mc1_targets"]["choices"] 14 | target_idx = np.array(sample["mc1_targets"]["labels"]).argmax().item() 15 | 16 | output = None 17 | query_label = ( 18 | 
np.random.default_rng(seed=seed).binomial(1, 0.5) if with_query_label else None 19 | ) 20 | output_idx = ( 21 | target_idx 22 | if query_label == 1 23 | else ( 24 | np.random.default_rng(seed=seed).choice( 25 | list(set(range(len(answer_map))) - set([target_idx])) 26 | ) 27 | if query_label == 0 28 | else None 29 | ) 30 | ) 31 | 32 | if format == PromptFormat.CHOICE: 33 | context = "\n".join( 34 | [ 35 | "Question:", 36 | question, 37 | "\nChoices:", 38 | *[ 39 | f" ({n}): {c}" 40 | for n, c in zip( 41 | string.ascii_lowercase[: len(answer_map)], answer_map 42 | ) 43 | ], 44 | ] 45 | ) 46 | 47 | target = string.ascii_lowercase[target_idx] 48 | if output_idx is not None: 49 | output = string.ascii_lowercase[output_idx] 50 | elif format == PromptFormat.OE: 51 | context = "\n".join( 52 | [ 53 | f"Question: {question}", 54 | ] 55 | ) 56 | 57 | target = answer_map[target_idx] 58 | if output_idx is not None: 59 | output = answer_map[output_idx] 60 | else: 61 | raise NotImplementedError(f"Unsupported prompt format {format}.") 62 | 63 | return LMText( 64 | context=context, 65 | target_prompt=target_prompt, 66 | target=target, 67 | output=output, 68 | query_label=query_label, 69 | ) 70 | 71 | 72 | def format_sample_prompt(prompt_dataset, format, kshot=1, seed=None): 73 | if not kshot: 74 | return "" 75 | 76 | samples_idx = ( 77 | np.random.default_rng(seed=seed) 78 | .permutation(len(prompt_dataset))[:kshot] 79 | .tolist() 80 | ) 81 | 82 | fewshot_samples_prompt = [ 83 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 84 | ] 85 | 86 | if format == PromptFormat.CHOICE: 87 | prompt = [ 88 | "The following are questions with multiple-choice answers.\n", 89 | *fewshot_samples_prompt, 90 | "Now, answer the next question.\n\n", 91 | ] 92 | elif format == PromptFormat.OE: 93 | prompt = [ 94 | "The following are questions with answers.\n", 95 | *fewshot_samples_prompt, 96 | "Now, answer the next question.\n\n", 97 | ] 98 | else: 99 | raise NotImplementedError(f"Unsupported prompt format {format}.") 100 | 101 | return "\n".join(prompt) 102 | 103 | 104 | def get_truthful_qa( 105 | prompt_style=None, 106 | with_query_label=False, 107 | eval_kshot=0, 108 | num_workers=8, 109 | seed=None, 110 | use_cache=True, 111 | **_, 112 | ): 113 | format = PromptFormat(prompt_style) 114 | 115 | dataset = load_dataset("truthful_qa", "multiple_choice") 116 | if not use_cache: 117 | dataset.cleanup_cache_files() 118 | 119 | dataset = dataset.map( 120 | lambda sample, idx: format_sample( 121 | sample, format, with_query_label=with_query_label, seed=seed + idx 122 | ).to_pydict(), 123 | with_indices=True, 124 | num_proc=num_workers, 125 | remove_columns=dataset.column_names["validation"], 126 | ) 127 | 128 | prompt_data = dataset.get("validation") 129 | prompt_kshot = { 130 | "validation": eval_kshot, 131 | "test": eval_kshot, 132 | } 133 | 134 | data_splits = { 135 | split: ds.map( 136 | lambda _, idx: { 137 | "prompt": format_sample_prompt( 138 | prompt_data, format, kshot=prompt_kshot[split], seed=seed + idx 139 | ) 140 | }, 141 | with_indices=True, 142 | num_proc=num_workers, 143 | ) 144 | for split, ds in dataset.items() 145 | } 146 | 147 | train_data = data_splits.pop("train", None) 148 | val_data = data_splits.pop("validation", None) 149 | test_data = data_splits.pop("test", None) 150 | 151 | return train_data, val_data, test_data 152 | 153 | 154 | @register_dataset 155 | def truthful_qa(*args, **kwargs): 156 | return get_truthful_qa(*args, **kwargs) 157 | 
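## --- Illustrative sketch (not a repository file) ---
## truthful_qa above and the loaders below all construct `output` and
## `query_label` the same way: when `with_query_label=True`, a fair coin
## decides whether the recorded `output` is the ground-truth answer
## (query_label == 1) or a uniformly sampled wrong choice (query_label == 0).
## A minimal standalone version of that pattern; the helper name and the toy
## answer list are made up for illustration.
import numpy as np


def sample_query_label(answer_map, target_idx, seed=None):
    rng = np.random.default_rng(seed=seed)
    query_label = int(rng.binomial(1, 0.5))
    if query_label == 1:
        output_idx = target_idx
    else:
        ## pick any index other than the correct one
        output_idx = int(rng.choice(list(set(range(len(answer_map))) - {target_idx})))
    return query_label, answer_map[output_idx]


## The invariant downstream calibration code relies on: query_label records
## whether `output` matches the true answer.
label, output = sample_query_label(["Paris", "Rome", "Berlin"], target_idx=0, seed=0)
assert label == int(output == "Paris")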
-------------------------------------------------------------------------------- /llm/datasets/hf/obqa.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | from datasets import load_dataset 4 | 5 | from ..registry import register_dataset 6 | from ..llm_data_utils import LMText, PromptFormat 7 | 8 | 9 | def format_sample(sample, format, with_query_label=False, seed=None): 10 | target_prompt = "\nAnswer:" 11 | 12 | question = sample["question_stem"] 13 | answer_map = sample["choices"]["text"] 14 | target_idx = string.ascii_lowercase.index(sample["answerKey"].lower()) 15 | 16 | output = None 17 | query_label = ( 18 | np.random.default_rng(seed=seed).binomial(1, 0.5) if with_query_label else None 19 | ) 20 | output_idx = ( 21 | target_idx 22 | if query_label == 1 23 | else ( 24 | np.random.default_rng(seed=seed).choice( 25 | list(set(range(len(answer_map))) - set([target_idx])) 26 | ) 27 | if query_label == 0 28 | else None 29 | ) 30 | ) 31 | 32 | if format == PromptFormat.CHOICE: 33 | context = "\n".join( 34 | [ 35 | "Question:", 36 | question, 37 | "\nChoices:", 38 | *[ 39 | f" ({n}): {c}" 40 | for n, c in zip( 41 | string.ascii_lowercase[: len(answer_map)], answer_map 42 | ) 43 | ], 44 | ] 45 | ) 46 | 47 | target = string.ascii_lowercase[target_idx] 48 | if output_idx is not None: 49 | output = string.ascii_lowercase[output_idx] 50 | elif format == PromptFormat.OE: 51 | context = "\n".join( 52 | [ 53 | "Provide your best answer for the following question.", 54 | f"Question: {question}", 55 | ] 56 | ) 57 | 58 | target = answer_map[target_idx] 59 | if output_idx is not None: 60 | output = answer_map[output_idx] 61 | else: 62 | raise NotImplementedError(f"Unsupported prompt format {format}.") 63 | 64 | return LMText( 65 | context=context, 66 | target_prompt=target_prompt, 67 | target=target, 68 | output=output, 69 | query_label=query_label, 70 | ) 71 | 72 | 73 | def format_sample_prompt(prompt_dataset, format, kshot=1, seed=None): 74 | if not kshot: 75 | return "" 76 | 77 | samples_idx = ( 78 | np.random.default_rng(seed=seed) 79 | .permutation(len(prompt_dataset))[:kshot] 80 | .tolist() 81 | ) 82 | 83 | fewshot_samples_prompt = [ 84 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 85 | ] 86 | 87 | if format == PromptFormat.CHOICE: 88 | prompt = [ 89 | "The following are questions with multiple-choice answers.\n", 90 | *fewshot_samples_prompt, 91 | "Now, answer the next question.\n\n", 92 | ] 93 | elif format == PromptFormat.OE: 94 | prompt = [ 95 | "The following are questions with answers.\n", 96 | *fewshot_samples_prompt, 97 | "Now, answer the next question.\n\n", 98 | ] 99 | else: 100 | raise NotImplementedError(f"Unsupported prompt format {format}.") 101 | 102 | return "\n".join(prompt) 103 | 104 | 105 | def get_openbookqa( 106 | prompt_style=None, 107 | with_query_label=False, 108 | train_kshot=0, 109 | eval_kshot=0, 110 | num_workers=8, 111 | seed=None, 112 | use_cache=True, 113 | **_, 114 | ): 115 | format = PromptFormat(prompt_style) 116 | 117 | dataset = load_dataset("openbookqa") 118 | if not use_cache: 119 | dataset.cleanup_cache_files() 120 | 121 | dataset = dataset.map( 122 | lambda sample, idx: format_sample( 123 | sample, format, with_query_label=with_query_label, seed=seed + idx 124 | ).to_pydict(), 125 | with_indices=True, 126 | num_proc=num_workers, 127 | remove_columns=dataset.column_names["test"], 128 | ) 129 | 130 | prompt_data = dataset.get("train") 131 | prompt_kshot = { 
132 | "train": train_kshot, 133 | "validation": eval_kshot, 134 | "test": eval_kshot, 135 | } 136 | 137 | data_splits = { 138 | split: ds.map( 139 | lambda _, idx: { 140 | "prompt": format_sample_prompt( 141 | prompt_data, format, kshot=prompt_kshot[split], seed=seed + idx 142 | ) 143 | }, 144 | with_indices=True, 145 | num_proc=num_workers, 146 | ) 147 | for split, ds in dataset.items() 148 | } 149 | 150 | train_data = data_splits.pop("train", None) 151 | val_data = data_splits.pop("validation", None) 152 | test_data = data_splits.pop("test", None) 153 | 154 | return train_data, val_data, test_data 155 | 156 | 157 | @register_dataset 158 | def obqa(*args, **kwargs): 159 | return get_openbookqa(*args, **kwargs) 160 | -------------------------------------------------------------------------------- /llm/datasets/hf/commonsense_qa.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | from datasets import load_dataset 4 | 5 | from ..registry import register_dataset 6 | from ..llm_data_utils import LMText, PromptFormat 7 | 8 | 9 | def format_sample(sample, format, with_query_label=False, seed=None): 10 | target_prompt = "\nAnswer:" 11 | 12 | question = sample["question"] 13 | answer_map = sample["choices"]["text"] 14 | target_idx = string.ascii_lowercase.index(sample["answerKey"].lower()) 15 | 16 | output = None 17 | query_label = ( 18 | np.random.default_rng(seed=seed).binomial(1, 0.5) if with_query_label else None 19 | ) 20 | output_idx = ( 21 | target_idx 22 | if query_label == 1 23 | else ( 24 | np.random.default_rng(seed=seed).choice( 25 | list(set(range(len(answer_map))) - set([target_idx])) 26 | ) 27 | if query_label == 0 28 | else None 29 | ) 30 | ) 31 | 32 | if format == PromptFormat.CHOICE: 33 | context = "\n".join( 34 | [ 35 | "Question:", 36 | question, 37 | "\nChoices:", 38 | *[ 39 | f" ({n}): {c}" 40 | for n, c in zip( 41 | string.ascii_lowercase[: len(answer_map)], answer_map 42 | ) 43 | ], 44 | ] 45 | ) 46 | 47 | target = string.ascii_lowercase[target_idx] 48 | if output_idx is not None: 49 | output = string.ascii_lowercase[output_idx] 50 | elif format == PromptFormat.OE: 51 | context = "\n".join( 52 | [ 53 | f"Question: {question}", 54 | ] 55 | ) 56 | 57 | target = answer_map[target_idx] 58 | if output_idx is not None: 59 | output = answer_map[output_idx] 60 | else: 61 | raise NotImplementedError(f"Unsupported prompt format {format}.") 62 | 63 | return LMText( 64 | context=context, 65 | target_prompt=target_prompt, 66 | target=target, 67 | output=output, 68 | query_label=query_label, 69 | ) 70 | 71 | 72 | def format_sample_prompt(prompt_dataset, format, kshot=1, seed=None): 73 | if not kshot: 74 | return "" 75 | 76 | samples_idx = ( 77 | np.random.default_rng(seed=seed) 78 | .permutation(len(prompt_dataset))[:kshot] 79 | .tolist() 80 | ) 81 | 82 | fewshot_samples_prompt = [ 83 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 84 | ] 85 | 86 | if format == PromptFormat.CHOICE: 87 | prompt = [ 88 | "The following are questions (with multiple-choice answers).\n", 89 | *fewshot_samples_prompt, 90 | "Now, answer the following.\n\n", 91 | ] 92 | elif format == PromptFormat.OE: 93 | prompt = [ 94 | "The following are questions (with answers).\n", 95 | *fewshot_samples_prompt, 96 | "Now, answer the following.\n\n", 97 | ] 98 | else: 99 | raise NotImplementedError(f"Unsupported prompt format {format}.") 100 | 101 | return "\n".join(prompt) 102 | 103 | 104 | def get_commonsense_qa( 105 | 
prompt_style=None, 106 | with_query_label=False, 107 | train_kshot=0, 108 | eval_kshot=0, 109 | num_workers=8, 110 | seed=None, 111 | use_cache=True, 112 | **_, 113 | ): 114 | format = PromptFormat(prompt_style) 115 | 116 | dataset = load_dataset("commonsense_qa") 117 | if not use_cache: 118 | dataset.cleanup_cache_files() 119 | 120 | dataset = dataset.filter( 121 | lambda x: len(x["choices"]["text"]) == len(np.unique(x["choices"]["text"])), 122 | num_proc=num_workers, 123 | ).map( 124 | lambda sample, idx: format_sample( 125 | sample, format, with_query_label=with_query_label, seed=seed + idx 126 | ).to_pydict(), 127 | with_indices=True, 128 | num_proc=num_workers, 129 | remove_columns=dataset.column_names["validation"], 130 | ) 131 | 132 | prompt_data = dataset.get("train") 133 | prompt_kshot = { 134 | "train": train_kshot, 135 | "validation": eval_kshot, 136 | "test": eval_kshot, 137 | } 138 | 139 | data_splits = { 140 | split: ds.map( 141 | lambda _, idx: { 142 | "prompt": format_sample_prompt( 143 | prompt_data, format, kshot=prompt_kshot[split], seed=seed + idx 144 | ) 145 | }, 146 | with_indices=True, 147 | num_proc=num_workers, 148 | ) 149 | for split, ds in dataset.items() 150 | } 151 | 152 | train_data = data_splits.pop("train", None) 153 | val_data = data_splits.pop("validation", None) 154 | test_data = data_splits.pop("test", None) 155 | 156 | return train_data, val_data, test_data 157 | 158 | 159 | @register_dataset 160 | def commonsense_qa(*args, **kwargs): 161 | return get_commonsense_qa(*args, **kwargs) 162 | -------------------------------------------------------------------------------- /llm/datasets/llm_data_utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import dataclasses 3 | from dataclasses import dataclass, asdict as dataclassasdict 4 | import torch 5 | import transformers 6 | from datasets.formatting.formatting import LazyRow 7 | 8 | 9 | ## NOTE: HF Convention. See https://huggingface.co/docs/transformers/en/tasks/token_classification#preprocess. 10 | IGNORE_LABEL = -100 11 | 12 | 13 | class PromptFormat(str, Enum): 14 | CHOICE = "choice" 15 | OE = "oe" 16 | 17 | 18 | @dataclass 19 | class LMText: 20 | context: str 21 | prompt: str = "" 22 | target_prompt: str = "" 23 | target: str = "" 24 | 25 | ## Misc. 26 | output: str = None 27 | query_label: int = None 28 | 29 | def __str__(self): 30 | return ( 31 | self.prompt + self.context + self.target_prompt + " " + self.target 32 | ).strip() 33 | 34 | def to_pydict(self): 35 | return {k: v for k, v in dataclassasdict(self).items() if v is not None} 36 | 37 | @staticmethod 38 | def field_names(): 39 | return [f.name for f in dataclasses.fields(LMText)] 40 | 41 | @staticmethod 42 | def from_(instance): 43 | if isinstance(instance, LMText): 44 | return instance 45 | 46 | if isinstance(instance, LazyRow): 47 | instance = {k: v for k, v in zip(instance.keys(), instance.values())} 48 | 49 | assert isinstance( 50 | instance, dict 51 | ), f"Could not convert instance to dict. Found {type(instance)}" 52 | 53 | instance = { 54 | k: v 55 | for k, v in instance.items() 56 | if k in set(f.name for f in dataclasses.fields(LMText)) 57 | } 58 | return LMText(**instance) 59 | 60 | 61 | def get_token_vec(tokenizer, format="roman_choice"): 62 | vocab = tokenizer.get_vocab() 63 | 64 | def _create_vec(raw_list): 65 | for t in raw_list: 66 | assert t in vocab, f"Cannot handle {t} as a single token." 
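        ## Each answer string must already exist as a single token in the
        ## vocabulary; taking the final input id below skips any special
        ## tokens (e.g. BOS) that the tokenizer prepends.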
67 | 68 | return torch.tensor([tokenizer(t).input_ids[-1] for t in raw_list]) 69 | 70 | if format == "bool": 71 | raw_strings = ["no", "yes"] 72 | elif format == "alpha_choice": 73 | raw_strings = ["a", "b"] 74 | elif format == "choice": 75 | raw_strings = ["a", "b", "c", "d"] 76 | elif format == "roman_choice": 77 | raw_strings = ["i", "ii"] 78 | else: 79 | raise NotImplementedError 80 | 81 | return _create_vec(raw_strings) 82 | 83 | 84 | LLAMA_3_SYS_PROMPT = "You are an expert who responds with concise, correct answers. Directly state the answer without phrases like 'the correct answer is'" 85 | 86 | 87 | @dataclass 88 | class LabeledStringDataCollator: 89 | tokenizer: transformers.PreTrainedTokenizer 90 | target_name: str = "target" 91 | 92 | @staticmethod 93 | def get_tokenizer_args(tokenizer): 94 | return dict( 95 | padding=True, 96 | truncation=True, 97 | max_length=( 98 | tokenizer.model_max_length 99 | if hasattr(tokenizer, "model_max_length") 100 | else None 101 | ), 102 | return_tensors="pt", 103 | return_length=True, 104 | ) 105 | 106 | def __call__(self, instances): 107 | tokenizer_args = self.get_tokenizer_args(self.tokenizer) 108 | 109 | prompts = [str(LMText.from_(instance)) for instance in instances] 110 | 111 | if ( 112 | self.tokenizer.name_or_path 113 | and ("Llama-3" in self.tokenizer.name_or_path) 114 | and ("Instruct" in self.tokenizer.name_or_path) 115 | ): 116 | msgs = [ 117 | [ 118 | {"role": "system", "content": LLAMA_3_SYS_PROMPT}, 119 | {"role": "user", "content": p}, 120 | ] 121 | for p in prompts 122 | ] 123 | 124 | prompts = [ 125 | self.tokenizer.apply_chat_template( 126 | m, tokenize=False, add_generation_prompt=True 127 | ) 128 | for m in msgs 129 | ] 130 | 131 | inputs = self.tokenizer(prompts, **tokenizer_args) 132 | input_lengths = inputs.pop("length") 133 | 134 | if self.target_name in instances[0]: 135 | ## inputs without targets for labeling lengths. 
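            ## Re-tokenizing the same instances without `target` gives the
            ## prompt-only length; the difference locates the target tokens,
            ## and everything before them is masked to IGNORE_LABEL below.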
136 | un_inputs = self.tokenizer( 137 | [ 138 | str( 139 | LMText.from_( 140 | {k: v for k, v in instance.items() if k != self.target_name} 141 | ) 142 | ) 143 | for instance in instances 144 | ], 145 | **tokenizer_args, 146 | ) 147 | un_input_lengths = un_inputs.pop("length") 148 | 149 | labels = inputs.get("input_ids").clone() 150 | for i, l in enumerate(input_lengths - un_input_lengths): 151 | labels[i, :-l] = IGNORE_LABEL 152 | inputs["labels"] = labels 153 | 154 | return inputs 155 | -------------------------------------------------------------------------------- /llm/datasets/hf/boolq.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | from datasets import load_dataset 4 | 5 | from ..registry import register_dataset 6 | from ..llm_data_utils import LMText, PromptFormat 7 | 8 | 9 | def format_sample(sample, format, with_query_label=False, seed=None): 10 | target_prompt = "\nAnswer:" 11 | 12 | passage = sample["passage"] 13 | question = sample["question"] 14 | answer_map = ["False", "True"] 15 | target_idx = int(bool(sample["answer"])) 16 | 17 | output = None 18 | query_label = ( 19 | np.random.default_rng(seed=seed).binomial(1, 0.5) if with_query_label else None 20 | ) 21 | output_idx = ( 22 | target_idx 23 | if query_label == 1 24 | else ( 25 | np.random.default_rng(seed=seed).choice( 26 | list(set(range(len(answer_map))) - set([target_idx])) 27 | ) 28 | if query_label == 0 29 | else None 30 | ) 31 | ) 32 | 33 | if format == PromptFormat.CHOICE: 34 | context = "\n".join( 35 | [ 36 | "Passage:", 37 | passage, 38 | "\nQuestion:", 39 | question, 40 | "\nChoices:", 41 | *[ 42 | f" ({n}): {c}" 43 | for n, c in zip( 44 | string.ascii_lowercase[: len(answer_map)], answer_map 45 | ) 46 | ], 47 | ] 48 | ) 49 | 50 | target = string.ascii_lowercase[target_idx] 51 | if output_idx is not None: 52 | output = string.ascii_lowercase[output_idx] 53 | elif format == PromptFormat.OE: 54 | context = "\n".join( 55 | [ 56 | 'Read the following passage and answer the question. 
Respond with only "True" or "False" and no additional text.', 57 | f"Passage: {passage}", 58 | f"Question: {question}?", 59 | ] 60 | ) 61 | 62 | target = answer_map[target_idx] 63 | if output_idx is not None: 64 | output = answer_map[output_idx] 65 | else: 66 | raise NotImplementedError(f"Unsupported prompt format {format}.") 67 | 68 | return LMText( 69 | context=context, 70 | target_prompt=target_prompt, 71 | target=target, 72 | output=output, 73 | query_label=query_label, 74 | ) 75 | 76 | 77 | def format_sample_prompt(prompt_dataset, format, kshot=1, seed=None): 78 | if not kshot: 79 | return "" 80 | 81 | samples_idx = ( 82 | np.random.default_rng(seed=seed) 83 | .permutation(len(prompt_dataset))[:kshot] 84 | .tolist() 85 | ) 86 | 87 | fewshot_samples_prompt = [ 88 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 89 | ] 90 | 91 | if format == PromptFormat.CHOICE: 92 | prompt = [ 93 | "The following are comprehension passages (with multiple-choice answers).\n", 94 | *fewshot_samples_prompt, 95 | "Now, answer the following.\n\n", 96 | ] 97 | elif format == PromptFormat.OE: 98 | prompt = [ 99 | "The following are comprehension passages (with answers).\n", 100 | *fewshot_samples_prompt, 101 | "Now, answer the following.\n\n", 102 | ] 103 | else: 104 | raise NotImplementedError(f"Unsupported prompt format {format}.") 105 | 106 | return "\n".join(prompt) 107 | 108 | 109 | def get_boolq( 110 | prompt_style=None, 111 | with_query_label=False, 112 | train_kshot=0, 113 | eval_kshot=0, 114 | num_workers=8, 115 | seed=None, 116 | use_cache=True, 117 | **_, 118 | ): 119 | format = PromptFormat(prompt_style) 120 | 121 | dataset = load_dataset("boolq") 122 | if not use_cache: 123 | dataset.cleanup_cache_files() 124 | 125 | dataset = dataset.map( 126 | lambda sample, idx: format_sample( 127 | sample, format, with_query_label=with_query_label, seed=seed + idx 128 | ).to_pydict(), 129 | with_indices=True, 130 | num_proc=num_workers, 131 | remove_columns=dataset.column_names["validation"], 132 | ) 133 | 134 | prompt_data = dataset.get("train") 135 | prompt_kshot = { 136 | "train": train_kshot, 137 | "validation": eval_kshot, 138 | "test": eval_kshot, 139 | } 140 | 141 | data_splits = { 142 | split: ds.map( 143 | lambda _, idx: { 144 | "prompt": format_sample_prompt( 145 | prompt_data, format, kshot=prompt_kshot[split], seed=seed + idx 146 | ) 147 | }, 148 | with_indices=True, 149 | num_proc=num_workers, 150 | ) 151 | for split, ds in dataset.items() 152 | } 153 | 154 | train_data = data_splits.pop("train", None) 155 | val_data = data_splits.pop("validation", None) 156 | test_data = data_splits.pop("test", None) 157 | 158 | return train_data, val_data, test_data 159 | 160 | 161 | @register_dataset 162 | def boolq(*args, **kwargs): 163 | return get_boolq(*args, **kwargs) 164 | -------------------------------------------------------------------------------- /llm/datasets/hf/copa.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | from datasets import load_dataset 4 | 5 | from ..registry import register_dataset 6 | from ..llm_data_utils import LMText, PromptFormat 7 | 8 | 9 | def format_sample(sample, format, with_query_label=False, seed=None): 10 | target_prompt = "\nAnswer:" 11 | 12 | premise = sample["premise"] 13 | answer_map = [sample["choice1"], sample["choice2"]] 14 | target_idx = sample["label"] 15 | 16 | output = None 17 | query_label = ( 18 | np.random.default_rng(seed=seed).binomial(1, 0.5) if 
with_query_label else None 19 | ) 20 | output_idx = ( 21 | target_idx 22 | if query_label == 1 23 | else ( 24 | np.random.default_rng(seed=seed).choice( 25 | list(set(range(len(answer_map))) - set([target_idx])) 26 | ) 27 | if query_label == 0 28 | else None 29 | ) 30 | ) 31 | 32 | if format == PromptFormat.CHOICE: 33 | context = "\n".join( 34 | [ 35 | "Premise:", 36 | premise, 37 | "\nChoices:", 38 | *[ 39 | f" ({n}): {c}" 40 | for n, c in zip( 41 | string.ascii_lowercase[: len(answer_map)], answer_map 42 | ) 43 | ], 44 | ] 45 | ) 46 | 47 | target = string.ascii_lowercase[target_idx] 48 | if output_idx is not None: 49 | output = string.ascii_lowercase[output_idx] 50 | elif format == PromptFormat.OE: 51 | context = "\n".join( 52 | [ 53 | 'Read the following premise and pick the correct choice. Respond only with "1" or "2" and no additional text.', 54 | f"Premise: {premise}", 55 | f" Choice 1: {answer_map[0]}", 56 | f" Choice 2: {answer_map[1]}", 57 | ] 58 | ) 59 | 60 | target = str(target_idx + 1) 61 | if output_idx is not None: 62 | output = str(output_idx + 1) 63 | else: 64 | raise NotImplementedError(f"Unsupported prompt format {format}.") 65 | 66 | return LMText( 67 | context=context, 68 | target_prompt=target_prompt, 69 | target=target, 70 | output=output, 71 | query_label=query_label, 72 | ) 73 | 74 | 75 | def format_sample_prompt(prompt_dataset, format, kshot=1, seed=None): 76 | if not kshot: 77 | return "" 78 | 79 | samples_idx = ( 80 | np.random.default_rng(seed=seed) 81 | .permutation(len(prompt_dataset))[:kshot] 82 | .tolist() 83 | ) 84 | 85 | fewshot_samples_prompt = [ 86 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 87 | ] 88 | 89 | if format == PromptFormat.CHOICE: 90 | prompt = [ 91 | "The following are questions (with multiple-choice answers).\n", 92 | *fewshot_samples_prompt, 93 | "Now, answer the following.\n\n", 94 | ] 95 | elif format == PromptFormat.OE: 96 | prompt = [ 97 | "The following are questions (with answers).\n", 98 | *fewshot_samples_prompt, 99 | "Now, answer the following.\n\n", 100 | ] 101 | else: 102 | raise NotImplementedError(f"Unsupported prompt format {format}.") 103 | 104 | return "\n".join(prompt) 105 | 106 | 107 | def get_copa( 108 | prompt_style=None, 109 | with_query_label=False, 110 | train_kshot=0, 111 | eval_kshot=0, 112 | num_workers=8, 113 | seed=None, 114 | use_cache=True, 115 | **_, 116 | ): 117 | format = PromptFormat(prompt_style) 118 | 119 | dataset = load_dataset("super_glue", "copa", trust_remote_code=True) 120 | if not use_cache: 121 | dataset.cleanup_cache_files() 122 | dataset.pop("test", None) ## NOTE: Test has no labels. 
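    ## With the unlabeled test split dropped, only "train" and "validation"
    ## remain; the "test" entry in `prompt_kshot` below goes unused and the
    ## validation split is the effective held-out set.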
123 | 124 | dataset = dataset.map( 125 | lambda sample, idx: format_sample( 126 | sample, format, with_query_label=with_query_label, seed=seed + idx 127 | ).to_pydict(), 128 | with_indices=True, 129 | num_proc=num_workers, 130 | remove_columns=dataset.column_names["validation"], 131 | ) 132 | 133 | prompt_data = dataset.get("train") 134 | prompt_kshot = { 135 | "train": train_kshot, 136 | "validation": eval_kshot, 137 | "test": eval_kshot, 138 | } 139 | 140 | data_splits = { 141 | split: ds.map( 142 | lambda _, idx: { 143 | "prompt": format_sample_prompt( 144 | prompt_data, format, kshot=prompt_kshot[split], seed=seed + idx 145 | ) 146 | }, 147 | with_indices=True, 148 | num_proc=num_workers, 149 | ) 150 | for split, ds in dataset.items() 151 | } 152 | 153 | train_data = data_splits.pop("train", None) 154 | val_data = data_splits.pop("validation", None) 155 | test_data = data_splits.pop("test", None) 156 | 157 | return train_data, val_data, test_data 158 | 159 | 160 | @register_dataset 161 | def copa(*args, **kwargs): 162 | return get_copa(*args, **kwargs) 163 | -------------------------------------------------------------------------------- /llm/datasets/hf/hellaswag.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | from datasets import load_dataset 4 | 5 | from ..registry import register_dataset 6 | from ..llm_data_utils import LMText, PromptFormat 7 | 8 | 9 | def format_sample(sample, format, with_query_label=False, seed=None): 10 | context = " ".join([sample["ctx"], sample["ctx_a"], sample["ctx_b"]]) 11 | answer_map = sample["endings"] 12 | target_idx = int(sample["label"]) 13 | 14 | output = None 15 | query_label = ( 16 | np.random.default_rng(seed=seed).binomial(1, 0.5) if with_query_label else None 17 | ) 18 | output_idx = ( 19 | target_idx 20 | if query_label == 1 21 | else ( 22 | np.random.default_rng(seed=seed).choice( 23 | list(set(range(len(answer_map))) - set([target_idx])) 24 | ) 25 | if query_label == 0 26 | else None 27 | ) 28 | ) 29 | 30 | if format == PromptFormat.CHOICE: 31 | context = "\n".join( 32 | [ 33 | "Context:", 34 | context, 35 | "\nChoices:", 36 | *[ 37 | f" ({n}): {c}" 38 | for n, c in zip( 39 | string.ascii_lowercase[: len(answer_map)], answer_map 40 | ) 41 | ], 42 | ] 43 | ) 44 | 45 | target_prompt = "\nAnswer:" 46 | target = string.ascii_lowercase[target_idx] 47 | if output_idx is not None: 48 | output = string.ascii_lowercase[output_idx] 49 | elif format == PromptFormat.OE: 50 | context = "\n".join( 51 | [ 52 | "Complete the ending for the following paragraph.", 53 | context, 54 | ] 55 | ) 56 | 57 | target_prompt = "\nEnding:" 58 | target = answer_map[target_idx] 59 | if output_idx is not None: 60 | output = answer_map[output_idx] 61 | else: 62 | raise NotImplementedError(f"Unsupported prompt format {format}.") 63 | 64 | if with_query_label: 65 | assert int(target == output) == query_label 66 | 67 | return LMText( 68 | context=context, 69 | target_prompt=target_prompt, 70 | target=target, 71 | output=output, 72 | query_label=query_label, 73 | ) 74 | 75 | 76 | def format_sample_prompt(prompt_dataset, format, kshot=1, seed=None): 77 | if not kshot: 78 | return "" 79 | 80 | samples_idx = ( 81 | np.random.default_rng(seed=seed) 82 | .permutation(len(prompt_dataset))[:kshot] 83 | .tolist() 84 | ) 85 | 86 | fewshot_samples_prompt = [ 87 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 88 | ] 89 | 90 | if format == PromptFormat.CHOICE: 91 | prompt = [ 92 | "The 
following are some context completions (with multiple-choice answers).\n", 93 | *fewshot_samples_prompt, 94 | "Now, answer the following.\n\n", 95 | ] 96 | elif format == PromptFormat.OE: 97 | prompt = [ 98 | "The following are some context completions (with answers).\n", 99 | *fewshot_samples_prompt, 100 | "Now, answer the following.\n\n", 101 | ] 102 | else: 103 | raise NotImplementedError(f"Unsupported prompt format {format}.") 104 | 105 | return "\n".join(prompt) 106 | 107 | 108 | def get_hellaswag( 109 | prompt_style=None, 110 | with_query_label=False, 111 | train_kshot=0, 112 | eval_kshot=0, 113 | num_workers=8, 114 | seed=None, 115 | use_cache=True, 116 | **_, 117 | ): 118 | format = PromptFormat(prompt_style) 119 | 120 | dataset = load_dataset("hellaswag", trust_remote_code=True) 121 | if not use_cache: 122 | dataset.cleanup_cache_files() 123 | dataset.pop("test", None) ## NOTE: Test has no labels. 124 | 125 | dataset = dataset.map( 126 | lambda sample, idx: format_sample( 127 | sample, format, with_query_label=with_query_label, seed=seed + idx 128 | ).to_pydict(), 129 | with_indices=True, 130 | num_proc=num_workers, 131 | remove_columns=dataset.column_names["validation"], 132 | ) 133 | 134 | prompt_data = dataset.get("train") 135 | prompt_kshot = { 136 | "train": train_kshot, 137 | "validation": eval_kshot, 138 | "test": eval_kshot, 139 | } 140 | 141 | data_splits = { 142 | split: ds.map( 143 | lambda _, idx: { 144 | "prompt": format_sample_prompt( 145 | prompt_data, format, kshot=prompt_kshot[split], seed=seed + idx 146 | ) 147 | }, 148 | with_indices=True, 149 | num_proc=num_workers, 150 | ) 151 | for split, ds in dataset.items() 152 | } 153 | 154 | train_data = data_splits.pop("train", None) 155 | val_data = data_splits.pop("validation", None) 156 | test_data = data_splits.pop("test", None) 157 | 158 | return train_data, val_data, test_data 159 | 160 | 161 | @register_dataset 162 | def hellaswag(*args, **kwargs): 163 | return get_hellaswag(*args, **kwargs) 164 | -------------------------------------------------------------------------------- /llm/datasets/hf/cosmos_qa.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | from datasets import load_dataset 4 | 5 | from ..registry import register_dataset 6 | from ..llm_data_utils import LMText, PromptFormat 7 | 8 | 9 | def format_sample(sample, format, with_query_label=False, seed=None): 10 | target_prompt = "\nAnswer:" 11 | 12 | context = sample["context"] 13 | question = sample["question"] 14 | answer_map = [sample[f"answer{i}"] for i in range(4)] 15 | target_idx = int(sample["label"]) 16 | 17 | output = None 18 | query_label = ( 19 | np.random.default_rng(seed=seed).binomial(1, 0.5) if with_query_label else None 20 | ) 21 | output_idx = ( 22 | target_idx 23 | if query_label == 1 24 | else ( 25 | np.random.default_rng(seed=seed).choice( 26 | list(set(range(len(answer_map))) - set([target_idx])) 27 | ) 28 | if query_label == 0 29 | else None 30 | ) 31 | ) 32 | 33 | if format == PromptFormat.CHOICE: 34 | context = "\n".join( 35 | [ 36 | "Context:", 37 | context, 38 | "\nQuestion:", 39 | question, 40 | "\nChoices:", 41 | *[ 42 | f" ({n}): {c}" 43 | for n, c in zip( 44 | string.ascii_lowercase[: len(answer_map)], answer_map 45 | ) 46 | ], 47 | ] 48 | ) 49 | 50 | target = string.ascii_lowercase[target_idx] 51 | if output_idx is not None: 52 | output = string.ascii_lowercase[output_idx] 53 | elif format == PromptFormat.OE: 54 | context = "\n".join( 55 | [ 56 | 
"Read the following paragraph and answer the question.", 57 | f"Paragraph: {context}", 58 | f"Question: {question}", 59 | ] 60 | ) 61 | 62 | target = answer_map[target_idx] 63 | if output_idx is not None: 64 | output = answer_map[output_idx] 65 | else: 66 | raise NotImplementedError(f"Unsupported prompt format {format}.") 67 | 68 | return LMText( 69 | context=context, 70 | target_prompt=target_prompt, 71 | target=target, 72 | output=output, 73 | query_label=query_label, 74 | ) 75 | 76 | 77 | def format_sample_prompt(prompt_dataset, format, kshot=1, seed=None): 78 | if not kshot: 79 | return "" 80 | 81 | samples_idx = ( 82 | np.random.default_rng(seed=seed) 83 | .permutation(len(prompt_dataset))[:kshot] 84 | .tolist() 85 | ) 86 | 87 | fewshot_samples_prompt = [ 88 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 89 | ] 90 | 91 | if format == PromptFormat.CHOICE: 92 | prompt = [ 93 | "The following are some contexts and questions (with multiple-choice answers).\n", 94 | *fewshot_samples_prompt, 95 | "Now, answer the following.\n\n", 96 | ] 97 | elif format == PromptFormat.OE: 98 | prompt = [ 99 | "The following are some contexts and questions (with answers).\n", 100 | *fewshot_samples_prompt, 101 | "Now, answer the following.\n\n", 102 | ] 103 | else: 104 | raise NotImplementedError(f"Unsupported prompt format {format}.") 105 | 106 | return "\n".join(prompt) 107 | 108 | 109 | def get_cosmos_qa( 110 | prompt_style=None, 111 | with_query_label=False, 112 | train_kshot=0, 113 | eval_kshot=0, 114 | num_workers=8, 115 | seed=None, 116 | use_cache=True, 117 | **_, 118 | ): 119 | format = PromptFormat(prompt_style) 120 | 121 | dataset = load_dataset("cosmos_qa", trust_remote_code=True) 122 | if not use_cache: 123 | dataset.cleanup_cache_files() 124 | dataset.pop("test", None) ## NOTE: Test has no labels. 
125 | 126 | dataset = dataset.map( 127 | lambda sample, idx: format_sample( 128 | sample, format, with_query_label=with_query_label, seed=seed + idx 129 | ).to_pydict(), 130 | with_indices=True, 131 | num_proc=num_workers, 132 | remove_columns=dataset.column_names["validation"], 133 | ) 134 | 135 | prompt_data = dataset.get("train") 136 | prompt_kshot = { 137 | "train": train_kshot, 138 | "validation": eval_kshot, 139 | "test": eval_kshot, 140 | } 141 | 142 | data_splits = { 143 | split: ds.map( 144 | lambda _, idx: { 145 | "prompt": format_sample_prompt( 146 | prompt_data, format, kshot=prompt_kshot[split], seed=seed + idx 147 | ) 148 | }, 149 | with_indices=True, 150 | num_proc=num_workers, 151 | ) 152 | for split, ds in dataset.items() 153 | } 154 | 155 | train_data = data_splits.pop("train", None) 156 | val_data = data_splits.pop("validation", None) 157 | test_data = data_splits.pop("test", None) 158 | 159 | return train_data, val_data, test_data 160 | 161 | 162 | @register_dataset 163 | def cosmos_qa(*args, **kwargs): 164 | return get_cosmos_qa(*args, **kwargs) 165 | -------------------------------------------------------------------------------- /llm/datasets/hf/story_cloze.py: -------------------------------------------------------------------------------- 1 | import os 2 | import string 3 | from pathlib import Path 4 | import numpy as np 5 | from datasets import load_dataset 6 | 7 | from ..registry import register_dataset 8 | from ..llm_data_utils import LMText, PromptFormat 9 | 10 | 11 | def format_sample(sample, format, with_query_label=False, seed=None): 12 | story = " ".join([sample[f"input_sentence_{i}"] for i in range(1, 5)]) 13 | answer_map = [sample["sentence_quiz1"], sample["sentence_quiz2"]] 14 | target_idx = sample["answer_right_ending"] - 1 15 | 16 | output = None 17 | query_label = ( 18 | np.random.default_rng(seed=seed).binomial(1, 0.5) if with_query_label else None 19 | ) 20 | output_idx = ( 21 | target_idx 22 | if query_label == 1 23 | else ( 24 | np.random.default_rng(seed=seed).choice( 25 | list(set(range(len(answer_map))) - set([target_idx])) 26 | ) 27 | if query_label == 0 28 | else None 29 | ) 30 | ) 31 | 32 | if format == PromptFormat.CHOICE: 33 | context = "\n".join( 34 | [ 35 | "Story:", 36 | story, 37 | "\nChoices:", 38 | *[ 39 | f" ({n}): {c}" 40 | for n, c in zip( 41 | string.ascii_lowercase[: len(answer_map)], answer_map 42 | ) 43 | ], 44 | ] 45 | ) 46 | 47 | target_prompt = "\nAnswer:" 48 | target = string.ascii_lowercase[target_idx] 49 | if output_idx is not None: 50 | output = string.ascii_lowercase[output_idx] 51 | elif format == PromptFormat.OE: 52 | context = "\n".join( 53 | [ 54 | "Complete the ending of the following story.", 55 | story, 56 | ] 57 | ) 58 | 59 | target_prompt = "\nEnding:" 60 | target = answer_map[target_idx] 61 | if output_idx is not None: 62 | output = answer_map[output_idx] 63 | else: 64 | raise NotImplementedError(f"Unsupported prompt format {format}.") 65 | 66 | return LMText( 67 | context=context, 68 | target_prompt=target_prompt, 69 | target=target, 70 | output=output, 71 | query_label=query_label, 72 | ) 73 | 74 | 75 | def format_sample_prompt(prompt_dataset, format, kshot=1, seed=None): 76 | if not kshot: 77 | return "" 78 | 79 | samples_idx = ( 80 | np.random.default_rng(seed=seed) 81 | .permutation(len(prompt_dataset))[:kshot] 82 | .tolist() 83 | ) 84 | 85 | fewshot_samples_prompt = [ 86 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 87 | ] 88 | 89 | if format == PromptFormat.CHOICE: 90 | prompt 
= [ 91 | "The following are stories with endings (multiple-choice).\n", 92 | *fewshot_samples_prompt, 93 | "Now, complete the next story.\n\n", 94 | ] 95 | elif format == PromptFormat.OE: 96 | prompt = [ 97 | "The following are stories (with endings).\n", 98 | *fewshot_samples_prompt, 99 | "Now, complete the next story.\n\n", 100 | ] 101 | else: 102 | raise NotImplementedError(f"Unsupported prompt format {format}.") 103 | 104 | return "\n".join(prompt) 105 | 106 | 107 | def get_story_cloze( 108 | prompt_style=None, 109 | with_query_label=False, 110 | eval_kshot=0, 111 | num_workers=8, 112 | seed=None, 113 | use_cache=True, 114 | **_, 115 | ): 116 | format = PromptFormat(prompt_style) 117 | 118 | dataset = load_dataset( 119 | "story_cloze", 120 | "2018", 121 | ## NOTE: manually place the CSV in data_dir. 122 | data_dir=f"{os.environ.get('HF_HOME', Path.home() / 'huggingface')}/datasets/story_cloze/2018", 123 | trust_remote_code=True, 124 | ) 125 | if not use_cache: 126 | dataset.cleanup_cache_files() 127 | 128 | dataset = dataset.map( 129 | lambda sample, idx: format_sample( 130 | sample, format, with_query_label=with_query_label, seed=seed + idx 131 | ).to_pydict(), 132 | with_indices=True, 133 | num_proc=num_workers, 134 | remove_columns=dataset.column_names["validation"], 135 | ) 136 | prompt_data = dataset.get("validation") 137 | prompt_kshot = { 138 | "validation": eval_kshot, 139 | } 140 | 141 | data_splits = { 142 | split: ds.map( 143 | lambda _, idx: { 144 | "prompt": format_sample_prompt( 145 | prompt_data, format, kshot=prompt_kshot[split], seed=seed + idx 146 | ) 147 | }, 148 | with_indices=True, 149 | num_proc=num_workers, 150 | ) 151 | for split, ds in dataset.items() 152 | } 153 | 154 | train_data = data_splits.pop("train", None) 155 | val_data = data_splits.pop("validation", None) 156 | test_data = data_splits.pop("test", None) 157 | 158 | return train_data, val_data, test_data 159 | 160 | 161 | @register_dataset 162 | def story_cloze(*args, **kwargs): 163 | return get_story_cloze(*args, **kwargs) 164 | -------------------------------------------------------------------------------- /llm/datasets/hf/trec.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | from datasets import load_dataset 4 | 5 | from ..registry import register_dataset 6 | from ..llm_data_utils import LMText, PromptFormat 7 | 8 | 9 | def format_sample(sample, format, with_query_label=False, seed=None): 10 | target_prompt = "\nAnswer:" 11 | 12 | question = sample["text"] 13 | answer_map = [ 14 | "Abbreviation", 15 | "Entity", 16 | "Description and abstract concept", 17 | "Human being", 18 | "Location", 19 | "Numeric value", 20 | ] 21 | target_idx = sample["coarse_label"] 22 | 23 | output = None 24 | query_label = ( 25 | np.random.default_rng(seed=seed).binomial(1, 0.5) if with_query_label else None 26 | ) 27 | output_idx = ( 28 | target_idx 29 | if query_label == 1 30 | else ( 31 | np.random.default_rng(seed=seed).choice( 32 | list(set(range(len(answer_map))) - set([target_idx])) 33 | ) 34 | if query_label == 0 35 | else None 36 | ) 37 | ) 38 | 39 | if format == PromptFormat.CHOICE: 40 | context = "\n".join( 41 | [ 42 | "Question:", 43 | question, 44 | "\nChoices:", 45 | *[ 46 | f" ({n}): {c}" 47 | for n, c in zip( 48 | string.ascii_lowercase[: len(answer_map)], answer_map 49 | ) 50 | ], 51 | ] 52 | ) 53 | 54 | target = string.ascii_lowercase[target_idx] 55 | if output_idx is not None: 56 | output = string.ascii_lowercase[output_idx] 
57 | elif format == PromptFormat.OE: 58 | context = "\n".join( 59 | [ 60 | "Read the following question and then pick a category that describes the question.", 61 | f"Question: {question}", 62 | f"What category is the question in? Choose one from {', '.join(answer_map)}. Respond with only the category and no additional text.", 63 | ] 64 | ) 65 | 66 | target = answer_map[target_idx] 67 | if output_idx is not None: 68 | output = answer_map[output_idx] 69 | else: 70 | raise NotImplementedError(f"Unsupported prompt format {format}.") 71 | 72 | return LMText( 73 | context=context, 74 | target_prompt=target_prompt, 75 | target=target, 76 | output=output, 77 | query_label=query_label, 78 | ) 79 | 80 | 81 | def format_sample_prompt(prompt_dataset, format, kshot=1, seed=None): 82 | if not kshot: 83 | return "" 84 | 85 | samples_idx = ( 86 | np.random.default_rng(seed=seed) 87 | .permutation(len(prompt_dataset))[:kshot] 88 | .tolist() 89 | ) 90 | 91 | fewshot_samples_prompt = [ 92 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 93 | ] 94 | 95 | if format == PromptFormat.CHOICE: 96 | prompt = [ 97 | "The following are questions with multiple-choice answers.\n", 98 | *fewshot_samples_prompt, 99 | "Now, answer the next question.\n\n", 100 | ] 101 | elif format == PromptFormat.OE: 102 | prompt = [ 103 | "The following are questions with answers.\n", 104 | *fewshot_samples_prompt, 105 | "Now, answer the next question.", 106 | ] 107 | else: 108 | raise NotImplementedError(f"Unsupported prompt format {format}.") 109 | 110 | return "\n".join(prompt) 111 | 112 | 113 | def get_trec( 114 | prompt_style=None, 115 | with_query_label=False, 116 | train_kshot=0, 117 | eval_kshot=0, 118 | num_workers=8, 119 | seed=None, 120 | use_cache=True, 121 | **_, 122 | ): 123 | format = PromptFormat(prompt_style) 124 | 125 | dataset = load_dataset("trec", trust_remote_code=True) 126 | if not use_cache: 127 | dataset.cleanup_cache_files() 128 | 129 | dataset = dataset.map( 130 | lambda sample, idx: format_sample( 131 | sample, format, with_query_label=with_query_label, seed=seed + idx 132 | ).to_pydict(), 133 | with_indices=True, 134 | num_proc=num_workers, 135 | remove_columns=dataset.column_names["test"], 136 | ) 137 | 138 | prompt_data = dataset.get("train") 139 | prompt_kshot = { 140 | "train": train_kshot, 141 | "validation": eval_kshot, 142 | "test": eval_kshot, 143 | } 144 | 145 | data_splits = { 146 | split: ds.map( 147 | lambda _, idx: { 148 | "prompt": format_sample_prompt( 149 | prompt_data, format, kshot=prompt_kshot[split], seed=seed + idx 150 | ) 151 | }, 152 | with_indices=True, 153 | num_proc=num_workers, 154 | ) 155 | for split, ds in dataset.items() 156 | } 157 | 158 | train_data = data_splits.pop("train", None) 159 | val_data = data_splits.pop("validation", None) 160 | test_data = data_splits.pop("test", None) 161 | 162 | return train_data, val_data, test_data 163 | 164 | 165 | @register_dataset 166 | def trec(*args, **kwargs): 167 | return get_trec(*args, **kwargs) 168 | -------------------------------------------------------------------------------- /llm/datasets/hf/math_qa.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | from datasets import load_dataset 4 | 5 | from ..registry import register_dataset 6 | from ..llm_data_utils import LMText, PromptFormat 7 | 8 | 9 | def format_sample(sample, style, with_query_label=False, seed=None): 10 | target_prompt = "\nAnswer:" 11 | 12 | problem = sample["Problem"] 
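    ## MathQA's `options` field is a single string such as
    ## "a ) 38 , b ) 27 , c ) 30 , d ) 28 , e ) none of these"; the
    ## comprehension below keeps only the text after each ")" as the choices.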
13 | answer_map = [opt.split(")")[-1].strip() for opt in sample["options"].split(",")] 14 | target_idx = string.ascii_lowercase.index(sample["correct"]) 15 | 16 | output = None 17 | query_label = ( 18 | np.random.default_rng(seed=seed).binomial(1, 0.5) if with_query_label else None 19 | ) 20 | output_idx = ( 21 | target_idx 22 | if query_label == 1 23 | else ( 24 | np.random.default_rng(seed=seed).choice( 25 | list(set(range(len(answer_map))) - set([target_idx])) 26 | ) 27 | if query_label == 0 28 | else None 29 | ) 30 | ) 31 | 32 | if style == PromptFormat.CHOICE: 33 | context = "\n".join( 34 | [ 35 | "Problem:", 36 | problem, 37 | "\nChoices:", 38 | *[ 39 | f" ({n}): {c}" 40 | for n, c in zip( 41 | string.ascii_lowercase[: len(answer_map)], answer_map 42 | ) 43 | ], 44 | ] 45 | ) 46 | 47 | target = string.ascii_lowercase[target_idx] 48 | if output_idx is not None: 49 | output = string.ascii_lowercase[output_idx] 50 | elif style == "oe": 51 | context = "\n".join( 52 | [ 53 | "Provide your best answer for the following math problem.", 54 | f"Problem: {problem}", 55 | ] 56 | ) 57 | 58 | target = answer_map[target_idx] 59 | if output_idx is not None: 60 | output = answer_map[output_idx] 61 | else: 62 | raise NotImplementedError(f"Unsupported prompt format {format}.") 63 | 64 | return LMText( 65 | context=context, 66 | target_prompt=target_prompt, 67 | target=target, 68 | output=output, 69 | query_label=query_label, 70 | ) 71 | 72 | 73 | def format_sample_prompt(prompt_dataset, format, kshot=1, seed=None): 74 | if not kshot: 75 | return "" 76 | 77 | samples_idx = ( 78 | np.random.default_rng(seed=seed) 79 | .permutation(len(prompt_dataset))[:kshot] 80 | .tolist() 81 | ) 82 | 83 | fewshot_samples_prompt = [ 84 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 85 | ] 86 | 87 | if format == PromptFormat.CHOICE: 88 | prompt = [ 89 | "The following are math questions (with multiple-choice answers).\n", 90 | *fewshot_samples_prompt, 91 | "Now, answer the next question.\n\n", 92 | ] 93 | elif format == PromptFormat.OE: 94 | prompt = [ 95 | "The following are math questions (with answers).\n", 96 | *fewshot_samples_prompt, 97 | "Now, answer the next question.\n\n", 98 | ] 99 | else: 100 | raise NotImplementedError(f"Unsupported prompt format {format}.") 101 | 102 | return "\n".join(prompt) 103 | 104 | 105 | def get_math_qa( 106 | prompt_style=None, 107 | with_query_label=False, 108 | train_kshot=0, 109 | eval_kshot=0, 110 | num_workers=8, 111 | seed=None, 112 | use_cache=True, 113 | **_, 114 | ): 115 | format = PromptFormat(prompt_style) 116 | 117 | dataset = load_dataset("math_qa", trust_remote_code=True) 118 | if not use_cache: 119 | dataset.cleanup_cache_files() 120 | 121 | dataset = dataset.filter( 122 | lambda x: len([opt.split(")")[-1].strip() for opt in x["options"].split(",")]) 123 | == len( 124 | np.unique([opt.split(")")[-1].strip() for opt in x["options"].split(",")]) 125 | ), 126 | num_proc=num_workers, 127 | ).map( 128 | lambda sample, idx: format_sample( 129 | sample, format, with_query_label=with_query_label, seed=seed + idx 130 | ).to_pydict(), 131 | with_indices=True, 132 | num_proc=num_workers, 133 | remove_columns=dataset.column_names["test"], 134 | ) 135 | 136 | prompt_data = dataset.get("train") 137 | prompt_kshot = { 138 | "train": train_kshot, 139 | "validation": eval_kshot, 140 | "test": eval_kshot, 141 | } 142 | 143 | data_splits = { 144 | split: ds.map( 145 | lambda _, idx: { 146 | "prompt": format_sample_prompt( 147 | prompt_data, format, 
kshot=prompt_kshot[split], seed=seed + idx 148 | ) 149 | }, 150 | with_indices=True, 151 | num_proc=num_workers, 152 | ) 153 | for split, ds in dataset.items() 154 | } 155 | 156 | train_data = data_splits.pop("train", None) 157 | val_data = data_splits.pop("validation", None) 158 | test_data = data_splits.pop("test", None) 159 | 160 | return train_data, val_data, test_data 161 | 162 | 163 | @register_dataset 164 | def math_qa(*args, **kwargs): 165 | return get_math_qa(*args, **kwargs) 166 | -------------------------------------------------------------------------------- /llm/datasets/hf/arc.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | from datasets import load_dataset 4 | 5 | from ..registry import register_dataset 6 | from ..llm_data_utils import LMText, PromptFormat 7 | 8 | 9 | def format_sample(sample, format, with_query_label=False, seed=None): 10 | sample["answerKey"] = sample["answerKey"].lower() 11 | sample["answerKey"] = ( 12 | string.ascii_lowercase.index(sample["answerKey"]) 13 | if sample["answerKey"] in string.ascii_lowercase 14 | else int(sample["answerKey"]) - 1 15 | ) 16 | 17 | target_prompt = "\nAnswer:" 18 | 19 | question = sample["question"] 20 | answer_map = sample["choices"]["text"] 21 | target_idx = sample["answerKey"] 22 | 23 | output = None 24 | query_label = ( 25 | np.random.default_rng(seed=seed).binomial(1, 0.5) if with_query_label else None 26 | ) 27 | output_idx = ( 28 | target_idx 29 | if query_label == 1 30 | else ( 31 | np.random.default_rng(seed=seed).choice( 32 | list(set(range(len(answer_map))) - set([target_idx])) 33 | ) 34 | if query_label == 0 35 | else None 36 | ) 37 | ) 38 | 39 | if format == PromptFormat.CHOICE: 40 | context = "\n".join( 41 | [ 42 | "Question:", 43 | question, 44 | "\nChoices:", 45 | *[ 46 | f" ({n}): {c}" 47 | for n, c in zip( 48 | string.ascii_lowercase[: len(answer_map)], answer_map 49 | ) 50 | ], 51 | ] 52 | ) 53 | 54 | target = string.ascii_lowercase[target_idx] 55 | if output_idx is not None: 56 | output = string.ascii_lowercase[output_idx] 57 | elif format == PromptFormat.OE: 58 | context = "\n".join( 59 | [ 60 | f"The question is: {question}", 61 | ] 62 | ) 63 | 64 | target = answer_map[target_idx] 65 | if output_idx is not None: 66 | output = answer_map[output_idx] 67 | else: 68 | raise NotImplementedError(f"Unsupported prompt format {format}.") 69 | 70 | return LMText( 71 | context=context, 72 | target_prompt=target_prompt, 73 | target=target, 74 | output=output, 75 | query_label=query_label, 76 | ) 77 | 78 | 79 | def format_sample_prompt(prompt_dataset, format, kshot=1, seed=None): 80 | if not kshot: 81 | return "" 82 | 83 | samples_idx = ( 84 | np.random.default_rng(seed=seed) 85 | .permutation(len(prompt_dataset))[:kshot] 86 | .tolist() 87 | ) 88 | 89 | fewshot_samples_prompt = [ 90 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 91 | ] 92 | 93 | if format == PromptFormat.CHOICE: 94 | prompt = [ 95 | "The following are questions (with multiple-choice answers).\n", 96 | *fewshot_samples_prompt, 97 | "Now, answer the following.\n\n", 98 | ] 99 | elif format == PromptFormat.OE: 100 | prompt = [ 101 | "The following are questions (with answers).\n", 102 | *fewshot_samples_prompt, 103 | "Now, answer the following.\n\n", 104 | ] 105 | else: 106 | raise NotImplementedError(f"Unsupported prompt format {format}.") 107 | 108 | return "\n".join(prompt) 109 | 110 | 111 | def get_arc( 112 | subset=None, 113 | prompt_style=None, 
114 | with_query_label=False, 115 | train_kshot=0, 116 | eval_kshot=0, 117 | num_workers=8, 118 | seed=None, 119 | use_cache=True, 120 | **_, 121 | ): 122 | format = PromptFormat(prompt_style) 123 | 124 | dataset = load_dataset("ai2_arc", subset) 125 | if not use_cache: 126 | dataset.cleanup_cache_files() 127 | 128 | dataset = dataset.map( 129 | lambda sample, idx: format_sample( 130 | sample, format, with_query_label=with_query_label, seed=seed + idx 131 | ).to_pydict(), 132 | with_indices=True, 133 | num_proc=num_workers, 134 | remove_columns=dataset.column_names["test"], 135 | ) 136 | 137 | prompt_data = dataset.get("train") 138 | prompt_kshot = { 139 | "train": train_kshot, 140 | "validation": eval_kshot, 141 | "test": eval_kshot, 142 | } 143 | 144 | data_splits = { 145 | split: ds.map( 146 | lambda _, idx: { 147 | "prompt": format_sample_prompt( 148 | prompt_data, format, kshot=prompt_kshot[split], seed=seed + idx 149 | ) 150 | }, 151 | with_indices=True, 152 | num_proc=num_workers, 153 | ) 154 | for split, ds in dataset.items() 155 | } 156 | 157 | train_data = data_splits.pop("train", None) 158 | val_data = data_splits.pop("validation", None) 159 | test_data = data_splits.pop("test", None) 160 | 161 | return train_data, val_data, test_data 162 | 163 | 164 | @register_dataset 165 | def arc(*args, **kwargs): 166 | return get_arc(*args, **kwargs, subset="ARC-Easy") 167 | 168 | 169 | @register_dataset 170 | def arc_challenge(*args, **kwargs): 171 | return get_arc(*args, **kwargs, subset="ARC-Challenge") 172 | -------------------------------------------------------------------------------- /llm/datasets/hf/cb.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | from datasets import load_dataset 4 | 5 | from ..registry import register_dataset 6 | from ..llm_data_utils import LMText, PromptFormat 7 | 8 | 9 | def format_sample(sample, format, with_query_label=False, seed=None): 10 | target_prompt = "\nAnswer:" 11 | 12 | premise = sample["premise"] 13 | hypothesis = sample["hypothesis"] 14 | answer_map = ["Yes", "No", "It's impossible to say"] 15 | target_idx = sample["label"] 16 | 17 | output = None 18 | query_label = ( 19 | np.random.default_rng(seed=seed).binomial(1, 0.5) if with_query_label else None 20 | ) 21 | output_idx = ( 22 | target_idx 23 | if query_label == 1 24 | else ( 25 | np.random.default_rng(seed=seed).choice( 26 | list(set(range(len(answer_map))) - set([target_idx])) 27 | ) 28 | if query_label == 0 29 | else None 30 | ) 31 | ) 32 | 33 | if format == PromptFormat.CHOICE: 34 | context = "\n".join( 35 | [ 36 | "Premise:", 37 | premise, 38 | "\nHypothesis:", 39 | hypothesis, 40 | "\nChoices:", 41 | *[ 42 | f" ({n}): {c}" 43 | for n, c in zip( 44 | string.ascii_lowercase[: len(answer_map)], answer_map 45 | ) 46 | ], 47 | ] 48 | ) 49 | 50 | target = string.ascii_lowercase[target_idx] 51 | if output_idx is not None: 52 | output = string.ascii_lowercase[output_idx] 53 | elif format == PromptFormat.OE: 54 | context = "\n".join( 55 | [ 56 | 'Read the following premise and answer if the hypothesis is true. 
Respond only with "Yes", "No", or "It\'s impossible to say" and no additional text.', 57 | premise, 58 | f"Hypothesis: {hypothesis}.", 59 | ] 60 | ) 61 | 62 | target = answer_map[target_idx] 63 | if output_idx is not None: 64 | output = answer_map[output_idx] 65 | else: 66 | raise NotImplementedError(f"Unsupported prompt format {format}.") 67 | 68 | return LMText( 69 | context=context, 70 | target_prompt=target_prompt, 71 | target=target, 72 | output=output, 73 | query_label=query_label, 74 | ) 75 | 76 | 77 | def format_sample_prompt(prompt_dataset, format, kshot=1, seed=None): 78 | if not kshot: 79 | return "" 80 | 81 | samples_idx = ( 82 | np.random.default_rng(seed=seed) 83 | .permutation(len(prompt_dataset))[:kshot] 84 | .tolist() 85 | ) 86 | 87 | fewshot_samples_prompt = [ 88 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 89 | ] 90 | 91 | if format == PromptFormat.CHOICE: 92 | prompt = [ 93 | "The following are questions (with premise, hypothesis, and answers) about entailment.\n", 94 | *fewshot_samples_prompt, 95 | "Now, answer the following.\n\n", 96 | ] 97 | elif format == PromptFormat.OE: 98 | prompt = [ 99 | "The following are questions (with premise, hypothesis, and answers) about entailment.\n", 100 | *fewshot_samples_prompt, 101 | "Now, answer the following.\n\n", 102 | ] 103 | else: 104 | raise NotImplementedError(f"Unsupported prompt format {format}.") 105 | 106 | return "\n".join(prompt) 107 | 108 | 109 | def get_cb( 110 | prompt_style=None, 111 | with_query_label=False, 112 | train_kshot=0, 113 | eval_kshot=0, 114 | num_workers=8, 115 | seed=None, 116 | use_cache=True, 117 | **_, 118 | ): 119 | format = PromptFormat(prompt_style) 120 | 121 | dataset = load_dataset("super_glue", "cb", trust_remote_code=True) 122 | if not use_cache: 123 | dataset.cleanup_cache_files() 124 | dataset.pop("test", None) ## NOTE: Test does not have labels. 
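## (The SuperGLUE "cb" test split ships with -1 placeholder labels on the
## Hugging Face hub, so it is dropped here and only the train/validation
## splits are carried forward.)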
125 | 126 | dataset = dataset.map( 127 | lambda sample, idx: format_sample( 128 | sample, format, with_query_label=with_query_label, seed=seed + idx 129 | ).to_pydict(), 130 | with_indices=True, 131 | num_proc=num_workers, 132 | remove_columns=dataset.column_names["validation"], 133 | ) 134 | 135 | prompt_data = dataset.get("train") 136 | prompt_kshot = { 137 | "train": train_kshot, 138 | "validation": eval_kshot, 139 | "test": eval_kshot, 140 | } 141 | 142 | data_splits = { 143 | split: ds.map( 144 | lambda _, idx: { 145 | "prompt": format_sample_prompt( 146 | prompt_data, format, kshot=prompt_kshot[split], seed=seed + idx 147 | ) 148 | }, 149 | with_indices=True, 150 | num_proc=num_workers, 151 | ) 152 | for split, ds in dataset.items() 153 | } 154 | 155 | train_data = data_splits.pop("train", None) 156 | val_data = data_splits.pop("validation", None) 157 | test_data = data_splits.pop("test", None) 158 | 159 | return train_data, val_data, test_data 160 | 161 | 162 | @register_dataset 163 | def cb(*args, **kwargs): 164 | return get_cb(*args, **kwargs) 165 | -------------------------------------------------------------------------------- /llm/datasets/hf/snli.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | from datasets import load_dataset 4 | 5 | from ..registry import register_dataset 6 | from ..llm_data_utils import LMText, PromptFormat 7 | 8 | 9 | def format_sample(sample, format, with_query_label=False, seed=None): 10 | target_prompt = "\nAnswer:" 11 | 12 | premise = sample["premise"] 13 | hypothesis = sample["hypothesis"] 14 | answer_map = ["Yes", "It's impossible to say", "No"] 15 | target_idx = sample["label"] 16 | 17 | output = None 18 | query_label = ( 19 | np.random.default_rng(seed=seed).binomial(1, 0.5) if with_query_label else None 20 | ) 21 | output_idx = ( 22 | target_idx 23 | if query_label == 1 24 | else ( 25 | np.random.default_rng(seed=seed).choice( 26 | list(set(range(len(answer_map))) - set([target_idx])) 27 | ) 28 | if query_label == 0 29 | else None 30 | ) 31 | ) 32 | 33 | if format == PromptFormat.CHOICE: 34 | context = "\n".join( 35 | [ 36 | "Premise:", 37 | premise, 38 | "\nHypothesis:", 39 | hypothesis, 40 | "\nChoices:", 41 | *[ 42 | f" ({n}): {c}" 43 | for n, c in zip( 44 | string.ascii_lowercase[: len(answer_map)], answer_map 45 | ) 46 | ], 47 | ] 48 | ) 49 | 50 | target = string.ascii_lowercase[target_idx] 51 | if output_idx is not None: 52 | output = string.ascii_lowercase[output_idx] 53 | elif format == PromptFormat.OE: 54 | context = "\n".join( 55 | [ 56 | 'Read the following premise and answer if the hypothesis is true. 
Respond with only "Yes", "No", or "It\'s impossible to say" answer and no additional text.', 57 | premise, 58 | f"Hypothesis: {hypothesis}", 59 | ] 60 | ) 61 | 62 | target = answer_map[target_idx] 63 | if output_idx is not None: 64 | output = answer_map[output_idx] 65 | else: 66 | raise NotImplementedError(f"Unsupported prompt format {format}.") 67 | 68 | return LMText( 69 | context=context, 70 | target_prompt=target_prompt, 71 | target=target, 72 | output=output, 73 | query_label=query_label, 74 | ) 75 | 76 | 77 | def format_sample_prompt(prompt_dataset, format, kshot=1, seed=None): 78 | if not kshot: 79 | return "" 80 | 81 | samples_idx = ( 82 | np.random.default_rng(seed=seed) 83 | .permutation(len(prompt_dataset))[:kshot] 84 | .tolist() 85 | ) 86 | 87 | fewshot_samples_prompt = [ 88 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 89 | ] 90 | 91 | if format == PromptFormat.CHOICE: 92 | prompt = [ 93 | "The following are questions (with multiple-choice answers) about entailment.\n", 94 | *fewshot_samples_prompt, 95 | "Now, answer the following.\n\n", 96 | ] 97 | elif format == PromptFormat.OE: 98 | prompt = [ 99 | "The following are questions (with premise, hypothesis, and answers) about entailment.\n", 100 | *fewshot_samples_prompt, 101 | "Now, answer the following.\n\n", 102 | ] 103 | else: 104 | raise NotImplementedError(f"Unsupported prompt format {format}.") 105 | 106 | return "\n".join(prompt) 107 | 108 | 109 | def get_snli( 110 | prompt_style=None, 111 | with_query_label=False, 112 | train_kshot=0, 113 | eval_kshot=0, 114 | num_workers=8, 115 | seed=None, 116 | use_cache=True, 117 | **_, 118 | ): 119 | format = PromptFormat(prompt_style) 120 | 121 | dataset = load_dataset("snli") 122 | if not use_cache: 123 | dataset.cleanup_cache_files() 124 | 125 | dataset = dataset.filter( 126 | lambda x: x["label"] in [0, 1, 2], num_proc=num_workers 127 | ).map( 128 | lambda sample, idx: format_sample( 129 | sample, 130 | format, 131 | with_query_label=with_query_label, 132 | seed=seed + idx, 133 | ).to_pydict(), 134 | with_indices=True, 135 | num_proc=num_workers, 136 | remove_columns=dataset.column_names["test"], 137 | ) 138 | 139 | prompt_data = dataset.get("train") 140 | prompt_kshot = { 141 | "train": train_kshot, 142 | "validation": eval_kshot, 143 | "test": eval_kshot, 144 | } 145 | 146 | data_splits = { 147 | split: ds.map( 148 | lambda _, idx: { 149 | "prompt": format_sample_prompt( 150 | prompt_data, format, kshot=prompt_kshot[split], seed=seed + idx 151 | ) 152 | }, 153 | with_indices=True, 154 | num_proc=num_workers, 155 | ) 156 | for split, ds in dataset.items() 157 | } 158 | 159 | train_data = data_splits.pop("train", None) 160 | val_data = data_splits.pop("validation", None) 161 | test_data = data_splits.pop("test", None) 162 | 163 | return train_data, val_data, test_data 164 | 165 | 166 | @register_dataset 167 | def snli(*args, **kwargs): 168 | return get_snli(*args, **kwargs) 169 | -------------------------------------------------------------------------------- /llm/datasets/hf/multirc.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | from datasets import load_dataset 4 | 5 | from ..registry import register_dataset 6 | from ..llm_data_utils import LMText, PromptFormat 7 | 8 | 9 | def format_sample(sample, format, with_query_label=False, seed=None): 10 | paragraph = sample["paragraph"] 11 | question = sample["question"] 12 | answer = sample["answer"] 13 | answer_map = ["No", 
"Yes"] 14 | target_idx = sample["label"] 15 | 16 | output = None 17 | query_label = ( 18 | np.random.default_rng(seed=seed).binomial(1, 0.5) if with_query_label else None 19 | ) 20 | output_idx = ( 21 | target_idx 22 | if query_label == 1 23 | else ( 24 | np.random.default_rng(seed=seed).choice( 25 | list(set(range(len(answer_map))) - set([target_idx])) 26 | ) 27 | if query_label == 0 28 | else None 29 | ) 30 | ) 31 | 32 | if format == PromptFormat.CHOICE: 33 | context = "\n".join( 34 | [ 35 | "Paragraph:", 36 | paragraph, 37 | f"\nQ: {question}", 38 | f"\nA: {answer}" "\n\nChoices:", 39 | *[ 40 | f" ({n}): {c}" 41 | for n, c in zip( 42 | string.ascii_lowercase[: len(answer_map)], answer_map 43 | ) 44 | ], 45 | ] 46 | ) 47 | 48 | target_prompt = "\nAnswer:" 49 | target = string.ascii_lowercase[target_idx] 50 | if output_idx is not None: 51 | output = string.ascii_lowercase[output_idx] 52 | elif format == PromptFormat.OE: 53 | context = "\n".join( 54 | [ 55 | 'Read the following paragraph along with the question and answer. Then, respond with whether the answer is correct. Respond with only "Yes" or "No" and no additional text.', 56 | f"Passage: {paragraph}", 57 | f"Question: {question}", 58 | f"Answer: {answer}. Is the answer correct?", 59 | ] 60 | ) 61 | 62 | target_prompt = "\nResponse:" 63 | target = answer_map[target_idx] 64 | if output_idx is not None: 65 | output = answer_map[output_idx] 66 | else: 67 | raise NotImplementedError(f"Unsupported prompt format {format}.") 68 | 69 | return LMText( 70 | context=context, 71 | target_prompt=target_prompt, 72 | target=target, 73 | output=output, 74 | query_label=query_label, 75 | ) 76 | 77 | 78 | def format_sample_prompt(prompt_dataset, format, kshot=1, seed=None): 79 | if not kshot: 80 | return "" 81 | 82 | samples_idx = ( 83 | np.random.default_rng(seed=seed) 84 | .permutation(len(prompt_dataset))[:kshot] 85 | .tolist() 86 | ) 87 | 88 | fewshot_samples_prompt = [ 89 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 90 | ] 91 | 92 | if format == PromptFormat.CHOICE: 93 | prompt = [ 94 | "The following are reading comprehensions (with multiple-choice answers).\n", 95 | *fewshot_samples_prompt, 96 | "Now, answer the following.\n\n", 97 | ] 98 | elif format == PromptFormat.OE: 99 | prompt = [ 100 | "The following are reading comprehensions (with answers).\n", 101 | *fewshot_samples_prompt, 102 | "Now, answer the following.\n\n", 103 | ] 104 | else: 105 | raise NotImplementedError(f"Unsupported prompt format {format}.") 106 | 107 | return "\n".join(prompt) 108 | 109 | 110 | def get_multirc( 111 | prompt_style=None, 112 | with_query_label=False, 113 | train_kshot=0, 114 | eval_kshot=0, 115 | num_workers=8, 116 | seed=None, 117 | use_cache=True, 118 | **_, 119 | ): 120 | format = PromptFormat(prompt_style) 121 | 122 | dataset = load_dataset("super_glue", "multirc", trust_remote_code=True) 123 | if not use_cache: 124 | dataset.cleanup_cache_files() 125 | dataset.pop("test", None) ## NOTE: Test has no labels. 
126 | 127 | dataset = dataset.map( 128 | lambda sample, idx: format_sample( 129 | sample, format, with_query_label=with_query_label, seed=seed + idx 130 | ).to_pydict(), 131 | with_indices=True, 132 | num_proc=num_workers, 133 | remove_columns=dataset.column_names["validation"], 134 | ) 135 | 136 | prompt_data = dataset.get("train") 137 | prompt_kshot = { 138 | "train": train_kshot, 139 | "validation": eval_kshot, 140 | "test": eval_kshot, 141 | } 142 | 143 | data_splits = { 144 | split: ds.map( 145 | lambda _, idx: { 146 | "prompt": format_sample_prompt( 147 | prompt_data, format, kshot=prompt_kshot[split], seed=seed + idx 148 | ) 149 | }, 150 | with_indices=True, 151 | num_proc=num_workers, 152 | ) 153 | for split, ds in dataset.items() 154 | } 155 | 156 | train_data = data_splits.pop("train", None) 157 | val_data = data_splits.pop("validation", None) 158 | test_data = data_splits.pop("test", None) 159 | 160 | return train_data, val_data, test_data 161 | 162 | 163 | @register_dataset 164 | def multirc(*args, **kwargs): 165 | return get_multirc(*args, **kwargs) 166 | -------------------------------------------------------------------------------- /llm/datasets/hf/piqa.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | from datasets import load_dataset 4 | 5 | from ..registry import register_dataset 6 | from ..llm_data_utils import LMText, PromptFormat 7 | 8 | 9 | def format_sample(sample, format, with_query_label=False, seed=None): 10 | target_prompt = "\nAnswer:" 11 | 12 | goal = sample["goal"] 13 | answer_map = [sample["sol1"], sample["sol2"]] 14 | target_idx = sample["label"] 15 | 16 | output = None 17 | query_label = ( 18 | np.random.default_rng(seed=seed).binomial(1, 0.5) if with_query_label else None 19 | ) 20 | output_idx = ( 21 | target_idx 22 | if query_label == 1 23 | else ( 24 | np.random.default_rng(seed=seed).choice( 25 | list(set(range(len(answer_map))) - set([target_idx])) 26 | ) 27 | if query_label == 0 28 | else None 29 | ) 30 | ) 31 | 32 | if format == PromptFormat.CHOICE: 33 | context = "\n".join( 34 | [ 35 | "Goal:", 36 | goal, 37 | "\nChoices:", 38 | *[ 39 | f" ({n}): {c}" 40 | for n, c in zip( 41 | string.ascii_lowercase[: len(answer_map)], answer_map 42 | ) 43 | ], 44 | ] 45 | ) 46 | 47 | target = string.ascii_lowercase[target_idx] 48 | if output_idx is not None: 49 | output = string.ascii_lowercase[output_idx] 50 | elif format == PromptFormat.OE: 51 | context = "\n".join( 52 | [ 53 | "Provide advice on how to accomplish the following goal.", 54 | f"Goal: {goal}", 55 | ] 56 | ) 57 | 58 | target = answer_map[target_idx] 59 | if output_idx is not None: 60 | output = answer_map[output_idx] 61 | else: 62 | raise NotImplementedError(f"Unsupported prompt format {format}.") 63 | 64 | return LMText( 65 | context=context, 66 | target_prompt=target_prompt, 67 | target=target, 68 | output=output, 69 | query_label=query_label, 70 | ) 71 | 72 | 73 | def format_sample_prompt(prompt_dataset, format, kshot=1, seed=None): 74 | if not kshot: 75 | if format == PromptFormat.OE: 76 | return "\n".join( 77 | [ 78 | "Give ONLY the advice, no other words or explanation.", 79 | "For example:", 80 | "Answer: .", 81 | ] 82 | ) 83 | 84 | return "" 85 | 86 | samples_idx = ( 87 | np.random.default_rng(seed=seed) 88 | .permutation(len(prompt_dataset))[:kshot] 89 | .tolist() 90 | ) 91 | 92 | fewshot_samples_prompt = [ 93 | str(LMText.from_(prompt_dataset[idx])) + "\n" for idx in samples_idx 94 | ] 95 | 96 | if format == 
PromptFormat.CHOICE: 97 | prompt = [ 98 | "The following are questions with multiple-choice answers.\n", 99 | *fewshot_samples_prompt, 100 | "Now, answer the next question.\n\n", 101 | ] 102 | elif format == PromptFormat.OE: 103 | prompt = [ 104 | "The following are questions with answers.\n", 105 | *fewshot_samples_prompt, 106 | "Now, answer the next question.\n\n", 107 | ] 108 | else: 109 | raise NotImplementedError(f"Unsupported prompt format {format}.") 110 | 111 | return "\n".join(prompt) 112 | 113 | 114 | def get_piqa( 115 | prompt_style=None, 116 | with_query_label=False, 117 | train_kshot=0, 118 | eval_kshot=0, 119 | num_workers=8, 120 | seed=None, 121 | use_cache=True, 122 | **_, 123 | ): 124 | format = PromptFormat(prompt_style) 125 | 126 | dataset = load_dataset("piqa", trust_remote_code=True) 127 | if not use_cache: 128 | dataset.cleanup_cache_files() 129 | dataset.pop("test", None) ## NOTE: "test" split has no labels. 130 | 131 | dataset = dataset.filter( 132 | lambda x: x["label"] in [0, 1, 2], num_proc=num_workers 133 | ).map( 134 | lambda sample, idx: format_sample( 135 | sample, format, with_query_label=with_query_label, seed=seed + idx 136 | ).to_pydict(), 137 | with_indices=True, 138 | num_proc=num_workers, 139 | remove_columns=dataset.column_names["validation"], 140 | ) 141 | 142 | prompt_data = dataset.get("train") 143 | prompt_kshot = { 144 | "train": train_kshot, 145 | "validation": eval_kshot, 146 | "test": eval_kshot, 147 | } 148 | 149 | data_splits = { 150 | split: ds.map( 151 | lambda _, idx: { 152 | "prompt": format_sample_prompt( 153 | prompt_data, format, kshot=prompt_kshot[split], seed=seed + idx 154 | ) 155 | }, 156 | with_indices=True, 157 | num_proc=num_workers, 158 | ) 159 | for split, ds in dataset.items() 160 | } 161 | 162 | train_data = data_splits.pop("train", None) 163 | val_data = data_splits.pop("validation", None) 164 | test_data = data_splits.pop("test", None) 165 | 166 | return train_data, val_data, test_data 167 | 168 | 169 | @register_dataset 170 | def piqa(*args, **kwargs): 171 | return get_piqa(*args, **kwargs) 172 | -------------------------------------------------------------------------------- /llm/eval/classifier.py: -------------------------------------------------------------------------------- 1 | from tqdm.auto import tqdm 2 | from collections import OrderedDict 3 | import torch 4 | from peft import PeftModel 5 | 6 | from ..datasets import LMText, LabeledStringDataCollator, prepare_uncertainty_query 7 | from .common import ( 8 | get_model_generations, 9 | save_metrics_data, 10 | compute_uncertainty_metrics, 11 | ) 12 | 13 | 14 | def get_classifier_inputs_via_embedding_model(model, lmtext_inputs, outputs): 15 | class_inputs = [ 16 | str(LMText.from_({**inp, "target": t})) 17 | for inp, t in zip(lmtext_inputs, outputs) 18 | ] 19 | class_inputs = model.embedding_model.encode( 20 | class_inputs, convert_to_tensor=True, show_progress_bar=False 21 | ) 22 | 23 | return class_inputs 24 | 25 | 26 | def get_classifier_inputs( 27 | accelerator, model, tokenizer, lmtext_inputs, outputs, adapter_name="query" 28 | ): 29 | collate_fn = LabeledStringDataCollator(tokenizer) 30 | 31 | inputs = [{**inp, "target": t} for inp, t in zip(lmtext_inputs, outputs)] 32 | inputs = {k: v.to(accelerator.device) for k, v in collate_fn(inputs).items()} 33 | 34 | if isinstance(model, PeftModel): 35 | active_adapter = model.active_adapter 36 | if adapter_name in model.peft_config: 37 | model.set_adapter(adapter_name) 38 | 39 | class_inputs = model(**inputs, 
output_hidden_states=True) 40 | 41 | if isinstance(model, PeftModel): 42 | model.set_adapter(active_adapter) 43 | 44 | target_layer = getattr(model.classifier_model, "target_layer", -1) 45 | class_inputs = class_inputs.hidden_states[target_layer][..., -1, :] 46 | 47 | return class_inputs 48 | 49 | 50 | @torch.inference_mode() 51 | def evaluate_classifier( 52 | accelerator, 53 | model, 54 | tokenizer, 55 | loader, 56 | log_dir=None, 57 | max_new_tokens=None, 58 | grade_strategy=None, 59 | **_, 60 | ): 61 | eval_data = OrderedDict([("logits", []), ("labels", [])]) 62 | 63 | for inputs in tqdm(loader): 64 | extra_inputs = { 65 | k: v for k, v in inputs.items() if k not in LMText.field_names() 66 | } 67 | inputs = {k: v for k, v in inputs.items() if k in LMText.field_names()} 68 | 69 | class_inputs = extra_inputs.pop("embedding", None) 70 | class_labels = inputs.pop("query_label", None) 71 | outputs = inputs.pop("output", None) 72 | targets = inputs.pop("target", None) 73 | 74 | inputs = [dict(zip(inputs.keys(), vals)) for vals in zip(*inputs.values())] 75 | 76 | if outputs is None: 77 | outputs, _ = get_model_generations( 78 | accelerator, model, tokenizer, inputs, max_new_tokens=max_new_tokens 79 | ) 80 | 81 | _, class_labels, _ = prepare_uncertainty_query( 82 | tokenizer, 83 | inputs, 84 | targets, 85 | outputs, 86 | strategy=grade_strategy, 87 | query_labels=class_labels, 88 | ) 89 | class_labels = class_labels.to(accelerator.device) 90 | 91 | if hasattr(model, "embedding_model"): 92 | if class_inputs is None: 93 | class_inputs = get_classifier_inputs_via_embedding_model( 94 | model, inputs, outputs 95 | ) 96 | else: 97 | class_inputs = get_classifier_inputs( 98 | accelerator, model, tokenizer, inputs, outputs 99 | ) 100 | class_inputs = class_inputs.to(model.dtype) 101 | 102 | class_logits = model.classifier_model(class_inputs) 103 | 104 | [ 105 | eval_data[k].append(v.cpu()) 106 | for k, v in zip( 107 | eval_data.keys(), 108 | accelerator.gather_for_metrics((class_logits, class_labels)), 109 | ) 110 | ] 111 | 112 | eval_data = OrderedDict({k: torch.cat(v, dim=0) for k, v in eval_data.items()}) 113 | 114 | all_metrics = compute_uncertainty_metrics( 115 | eval_data.get("labels"), 116 | eval_data.get("logits"), 117 | prefix="unc_", 118 | ) 119 | all_metrics["acc"] = eval_data.get("labels").float().mean(dim=0).item() 120 | save_metrics_data(eval_data, log_dir=log_dir, filename="classifier_data.bin") 121 | 122 | return all_metrics 123 | 124 | 125 | @torch.inference_mode() 126 | def evaluate_classifier_logits( 127 | accelerator, 128 | model, 129 | tokenizer, 130 | loader, 131 | log_dir=None, 132 | **_, 133 | ): 134 | eval_data = OrderedDict([("logits", []), ("labels", [])]) 135 | 136 | for inputs in tqdm(loader): 137 | class_inputs = inputs.pop("embedding", None) 138 | class_labels = inputs.pop("query_label", None) 139 | 140 | class_logits = model(class_inputs) 141 | 142 | [ 143 | eval_data[k].append(v.cpu()) 144 | for k, v in zip( 145 | eval_data.keys(), 146 | accelerator.gather_for_metrics((class_logits, class_labels)), 147 | ) 148 | ] 149 | 150 | eval_data = OrderedDict({k: torch.cat(v, dim=0) for k, v in eval_data.items()}) 151 | 152 | all_metrics = compute_uncertainty_metrics( 153 | eval_data.get("labels"), 154 | eval_data.get("logits"), 155 | prefix="unc_", 156 | ) 157 | all_metrics["acc"] = eval_data.get("labels").float().mean(dim=0).item() 158 | save_metrics_data(eval_data, log_dir=log_dir, filename="classifier_data.bin") 159 | 160 | return all_metrics 161 | 
-------------------------------------------------------------------------------- /llm/trainer/fine_tune.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from tqdm.auto import tqdm 4 | import torch 5 | from torch.utils.data import default_collate 6 | from transformers.trainer import ( 7 | Trainer, 8 | logger, 9 | TRAINING_ARGS_NAME, 10 | TrainingArguments, 11 | ) 12 | 13 | from ..datasets import LabeledStringDataCollator 14 | 15 | 16 | class FineTuner(Trainer): 17 | TEMPERATURE_WEIGHTS_NAME = "temperature_head.bin" 18 | 19 | @dataclass 20 | class Args(TrainingArguments): 21 | accelerator_config: dict = field( 22 | default_factory=lambda: dict(use_configured_state=True) 23 | ) 24 | fp16: bool = field(default=not torch.cuda.is_bf16_supported()) 25 | bf16: bool = field(default=torch.cuda.is_bf16_supported()) 26 | ddp_find_unused_parameters: bool = field(default=False) 27 | log_on_each_node: bool = field(default=False) 28 | eval_strategy: str = field(default="steps") 29 | dataloader_num_workers: int = field(default=4) 30 | optim: str = field(default="adamw_torch") 31 | lr: float = field(default=1e-4) 32 | lr_scheduler_type: str = field(default="cosine") 33 | weight_decay: float = field(default=0.0) 34 | warmup_ratio: float = field(default=0.0) 35 | gradient_accumulation_steps: int = field(default=1) 36 | report_to: str = field(default="wandb") 37 | ## Custom Args. 38 | scale_temp: bool = field(default=False) 39 | 40 | def __init__(self, args=None, train_dataset=None, tokenizer=None, **kwargs): 41 | args.label_names = train_dataset.column_names 42 | 43 | self._collate_fn = LabeledStringDataCollator(tokenizer) 44 | 45 | super().__init__( 46 | **kwargs, 47 | args=args, 48 | tokenizer=tokenizer, 49 | train_dataset=train_dataset, 50 | data_collator=default_collate, 51 | ) 52 | 53 | def compute_loss(self, model, inputs, **kwargs): 54 | inputs = [dict(zip(inputs.keys(), vals)) for vals in zip(*inputs.values())] 55 | targets = [inp.pop("target") for inp in inputs] 56 | 57 | loss_inputs = { 58 | k: v.to(self.accelerator.device) 59 | for k, v in self._collate_fn( 60 | [{**inp, "target": t} for inp, t in zip(inputs, targets)] 61 | ).items() 62 | } 63 | 64 | return super().compute_loss(model, loss_inputs, **kwargs) 65 | 66 | def evaluate(self, eval_dataset=None, metric_key_prefix="eval", **_): 67 | eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset 68 | 69 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 70 | 71 | all_metrics = {"loss": []} 72 | 73 | for inputs in tqdm(eval_dataloader, leave=False): 74 | inputs = [dict(zip(inputs.keys(), vals)) for vals in zip(*inputs.values())] 75 | targets = [inp.pop("target") for inp in inputs] 76 | B = len(inputs) 77 | 78 | loss_inputs = { 79 | k: v.to(self.accelerator.device) 80 | for k, v in self._collate_fn( 81 | [{**inp, "target": t} for inp, t in zip(inputs, targets)] 82 | ).items() 83 | } 84 | 85 | with torch.inference_mode(): 86 | loss = super().compute_loss( 87 | self.model, loss_inputs, return_outputs=False 88 | ) 89 | 90 | ## De-mean for distributed only. 
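## (The scalar mean loss is expanded into a length-B vector whose first entry
## holds the batch total, i.e. mean loss * B, with all remaining entries zero;
## `gather_for_metrics` then collects these vectors across processes, and
## summing the gathered entries before dividing by the total example count N
## below recovers the global mean even when batch sizes differ across ranks.)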
91 | loss = ( 92 | torch.zeros(B) 93 | .index_fill_(0, torch.tensor([0]).long(), loss * B) 94 | .to(loss.device) 95 | ) 96 | [ 97 | all_metrics[l].append(v) 98 | for l, v in zip( 99 | ("loss",), 100 | self.accelerator.gather_for_metrics((loss,)), 101 | ) 102 | ] 103 | 104 | all_metrics = {k: torch.cat(v, dim=0) for k, v in all_metrics.items()} 105 | N = all_metrics["loss"].size(0) 106 | 107 | all_metrics = { 108 | f"{metric_key_prefix}_{k}": (v[v.nonzero().squeeze(-1)].sum() / N).item() 109 | for k, v in all_metrics.items() 110 | } 111 | 112 | self.log(all_metrics) 113 | 114 | self.control = self.callback_handler.on_evaluate( 115 | self.args, self.state, self.control, all_metrics 116 | ) 117 | 118 | return all_metrics 119 | 120 | def _save(self, output_dir=None, state_dict=None): 121 | output_dir = output_dir if output_dir is not None else self.args.output_dir 122 | os.makedirs(output_dir, exist_ok=True) 123 | logger.info(f"Saving model checkpoint to {output_dir}") 124 | 125 | self.model.save_pretrained( 126 | output_dir, 127 | state_dict=state_dict, 128 | safe_serialization=self.args.save_safetensors, 129 | selected_adapters=["default"], 130 | save_embedding_layers=False, 131 | ) 132 | 133 | if self.tokenizer is not None: 134 | self.tokenizer.save_pretrained(output_dir) 135 | 136 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 137 | 138 | if self.args.scale_temp: 139 | torch.save( 140 | self.accelerator.unwrap_model(self.model).lm_head[-1].state_dict(), 141 | os.path.join(output_dir, self.TEMPERATURE_WEIGHTS_NAME), 142 | ) 143 | -------------------------------------------------------------------------------- /experiments/train_embedding_only.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | from pathlib import Path 4 | import torch 5 | from tqdm.auto import tqdm 6 | import wandb 7 | 8 | from llm.logging import entrypoint 9 | from llm.datasets import get_dataset, get_loader 10 | from llm.models.peft import get_classifier_head, get_temperature_head 11 | from llm.trainer import ClassificationTuner 12 | 13 | 14 | @torch.inference_mode 15 | def compute_metrics(accelerator, data, model, batch_size=64, num_workers=8, prefix=""): 16 | model.eval() 17 | 18 | loader = get_loader( 19 | data, 20 | batch_size=batch_size, 21 | num_workers=num_workers, 22 | pin_memory=True, 23 | accelerator=accelerator, 24 | ) 25 | 26 | criterion = torch.nn.CrossEntropyLoss(reduction="none") 27 | 28 | all_data = OrderedDict([("loss", []), ("acc", [])]) 29 | 30 | for inputs in tqdm(loader, leave=False): 31 | embeddings = inputs.get("embedding") 32 | labels = inputs.get("query_label") 33 | 34 | logits = model(embeddings) 35 | 36 | loss = criterion(logits, labels) 37 | 38 | preds = labels == logits.argmax(dim=-1) 39 | 40 | [ 41 | all_data[k].append(v.cpu()) 42 | for k, v in zip( 43 | all_data.keys(), accelerator.gather_for_metrics((loss, preds)) 44 | ) 45 | ] 46 | 47 | all_data = { 48 | f"{prefix}{k}": torch.cat(v, dim=0).float().mean().item() 49 | for k, v in all_data.items() 50 | } 51 | 52 | return all_data 53 | 54 | 55 | @entrypoint(with_accelerator=True) 56 | def main( 57 | accelerator=None, 58 | seed=137, 59 | log_dir=None, 60 | dataset=None, 61 | prompt_style=None, 62 | data_dir=None, 63 | num_workers=4, 64 | model_dir=None, 65 | scale_temp=False, 66 | batch_size=64, 67 | lr=1e-3, 68 | weight_decay=1e-2, 69 | max_steps=2000, 70 | ): 71 | config = dict( 72 | seed=seed, 73 | log_dir=log_dir, 74 | 
dataset=dataset, 75 | prompt_style=prompt_style, 76 | model_dir=model_dir, 77 | scale_temp=scale_temp, 78 | batch_size=batch_size, 79 | lr=lr, 80 | weight_decay=weight_decay, 81 | max_steps=max_steps, 82 | ) 83 | if accelerator.is_main_process: 84 | wandb.config.update(config, allow_val_change=True) 85 | 86 | train_data, val_data, test_data = get_dataset( 87 | dataset, 88 | root=data_dir, 89 | seed=seed, 90 | prompt_style=prompt_style, 91 | num_workers=num_workers, 92 | ) 93 | if scale_temp: 94 | train_data, val_data = val_data, test_data 95 | 96 | model = get_classifier_head( 97 | input_size=train_data[0]["embedding"].shape[0], 98 | checkpoint_dir=model_dir, 99 | is_trainable=not scale_temp, 100 | weights_name=ClassificationTuner.WEIGHTS_NAME, 101 | ) 102 | 103 | if scale_temp: 104 | temperature_model = get_temperature_head(is_trainable=True) 105 | 106 | model = torch.nn.Sequential( 107 | model, 108 | temperature_model, 109 | ) 110 | 111 | optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) 112 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps) 113 | 114 | loader = get_loader( 115 | train_data, 116 | batch_size=batch_size, 117 | num_workers=num_workers, 118 | pin_memory=True, 119 | accelerator=accelerator, 120 | shuffle=True, 121 | ) 122 | 123 | model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler) 124 | 125 | criterion = torch.nn.CrossEntropyLoss() 126 | 127 | logging_steps = max(1, max_steps // 200) 128 | save_steps = max_steps // 10 129 | 130 | iter_loader = iter(loader) 131 | for step in tqdm(range(max_steps)): 132 | model.train() 133 | 134 | optimizer.zero_grad() 135 | 136 | try: 137 | batch = next(iter_loader) 138 | except StopIteration: 139 | iter_loader = iter(loader) 140 | batch = next(iter_loader) 141 | 142 | embeddings = batch.get("embedding") 143 | labels = batch.get("query_label") 144 | 145 | logits = model(embeddings) 146 | 147 | loss = criterion(logits, labels) 148 | 149 | accelerator.backward(loss) 150 | 151 | optimizer.step() 152 | scheduler.step() 153 | 154 | train_metrics = { 155 | "train/loss": loss.detach().item(), 156 | } 157 | 158 | if (step + 1) % logging_steps == 0: 159 | if val_data is not None: 160 | val_metrics = compute_metrics( 161 | accelerator, 162 | val_data, 163 | model, 164 | batch_size=batch_size, 165 | num_workers=num_workers, 166 | prefix="eval/", 167 | ) 168 | logging.info(val_metrics, extra=dict(metrics=True)) 169 | logging.debug(val_metrics) 170 | 171 | logging.info(train_metrics, extra=dict(metrics=True)) 172 | logging.debug(train_metrics) 173 | 174 | if accelerator.is_main_process and (step + 1) % save_steps == 0: 175 | checkpoint_path = ( 176 | Path(log_dir) 177 | / f"checkpoint-{step + 1}" 178 | / ClassificationTuner.WEIGHTS_NAME 179 | ) 180 | checkpoint_path.parent.mkdir() 181 | 182 | torch.save(accelerator.unwrap_model(model).state_dict(), checkpoint_path) 183 | 184 | 185 | if __name__ == "__main__": 186 | import fire 187 | 188 | fire.Fire(main) 189 | --------------------------------------------------------------------------------