├── .gitignore ├── LICENSE ├── README.md ├── assets └── mom_fig1.png ├── benchmarks ├── benchmark_generation.py ├── benchmark_training_throughput.py ├── modules │ ├── benchmark_cross_entropy.py │ └── benchmark_layernorm.py └── ops │ ├── benchmark.py │ ├── benchmark_abc.py │ ├── benchmark_based.py │ ├── benchmark_delta_rule.py │ ├── benchmark_fla.py │ ├── benchmark_gla.py │ ├── benchmark_gsa.py │ ├── benchmark_hgrn.py │ ├── benchmark_retention.py │ ├── benchmark_rwkv6.py │ └── benchmark_simple_gla_vs_mamba2.py ├── evals ├── harness.py └── ppl.py ├── lm-eval-harness ├── CITATION.bib ├── CODEOWNERS ├── LICENSE.md ├── README.md ├── docs │ ├── CONTRIBUTING.md │ ├── README.md │ ├── decontamination.md │ ├── img │ │ └── fewshot_example_gpt3.png │ ├── interface.md │ ├── model_guide.md │ ├── new_task_guide.md │ └── task_guide.md ├── ignore.txt ├── launch.py ├── launch_hf.py ├── launch_jrt.py ├── launch_local.py ├── lm_eval │ ├── __init__.py │ ├── __main__.py │ ├── api │ │ ├── __init__.py │ │ ├── filter.py │ │ ├── instance.py │ │ ├── metrics.py │ │ ├── model.py │ │ ├── registry.py │ │ ├── samplers.py │ │ └── task.py │ ├── decontamination │ │ ├── __init__.py │ │ ├── archiver.py │ │ ├── decontaminate.py │ │ └── janitor.py │ ├── evaluator.py │ ├── filters │ │ ├── __init__.py │ │ ├── decontamination.py │ │ ├── extraction.py │ │ ├── selection.py │ │ └── transformation.py │ ├── models │ │ ├── __init__.py │ │ ├── anthropic_llms.py │ │ ├── based_lm.py │ │ ├── dummy.py │ │ ├── gguf.py │ │ ├── huggingface.py │ │ ├── jrt_lm.py │ │ ├── local_lm.py │ │ ├── local_utils │ │ │ ├── jrt_utils.py │ │ │ └── loading.py │ │ ├── mamba_lm.py │ │ ├── neuron_optimum.py │ │ ├── openai_completions.py │ │ ├── optimum_lm.py │ │ ├── textsynth.py │ │ ├── utils.py │ │ └── vllm_causallms.py │ ├── prompts │ │ └── __init__.py │ ├── tasks │ │ ├── README.md │ │ ├── __init__.py │ │ ├── based_drop │ │ │ └── task.py │ │ ├── based_fda │ │ │ ├── README.md │ │ │ └── task.py │ │ ├── based_nq │ │ │ └── task.py │ │ ├── based_squadv2 │ │ │ ├── README.md │ │ │ └── task.py │ │ ├── based_swde │ │ │ ├── README.md │ │ │ └── task.py │ │ ├── based_triviaqa │ │ │ └── task.py │ │ ├── scrolls │ │ │ ├── README.md │ │ │ └── task.py │ │ └── super_glue │ │ │ ├── README.md │ │ │ ├── cb │ │ │ ├── aggregate.py │ │ │ └── t5_utils.py │ │ │ ├── copa │ │ │ └── utils.py │ │ │ ├── multirc │ │ │ └── t5_utils.py │ │ │ ├── record │ │ │ ├── t5_utils.py │ │ │ └── util.py │ │ │ └── wsc │ │ │ ├── preprocess_wsc.py │ │ │ └── t5_utils.py │ └── utils.py ├── prompt_scripts │ ├── collect_results.py │ ├── run_jrt_prompt_hazy.sh │ └── run_jrt_prompt_hf.sh ├── pyproject.toml ├── requirements.txt ├── setup.py └── templates │ └── new_yaml_task │ └── README.md ├── mom ├── __init__.py ├── layers │ ├── __init__.py │ ├── mom_gated_deltanet.py │ ├── mom_gla.py │ ├── mom_gsa.py │ └── mom_linear_attn.py └── models │ ├── __init__.py │ ├── mom_gated_deltanet │ ├── __init__.py │ ├── configuration_mom_gated_deltanet.py │ └── modeling_mom_gated_deltanet.py │ ├── mom_gla │ ├── __init__.py │ ├── configuration_mom_gla.py │ └── modeling_mom_gla.py │ ├── mom_gsa │ ├── __init__.py │ ├── configuration_mom_gsa.py │ └── modeling_mom_gsa.py │ └── mom_linear_attn │ ├── __init__.py │ ├── configuration_mom_linear_attn.py │ └── modeling_mom_linear_attn.py ├── setup.py └── training ├── README.md ├── configs ├── mom_1.3B.json └── mom_340M.json ├── flame ├── __init__.py ├── data.py ├── logging.py └── parser.py ├── preprocess.py ├── run.py └── train.sh /.gitignore: 
-------------------------------------------------------------------------------- 1 | # test file 2 | test.py 3 | mingpt 4 | 5 | # data files 6 | data 7 | 8 | # # bash scripts 9 | # *.sh 10 | 11 | # docs 12 | docs/_build 13 | 14 | # intermediate files 15 | build 16 | dist 17 | *.egg-info 18 | 19 | # experimental results 20 | exp 21 | fineweb 22 | SlimPajama 23 | slimpajama 24 | results 25 | wandb 26 | *.csv 27 | *.html 28 | eval_results 29 | 30 | # log and config files 31 | log.* 32 | *.log 33 | *.cfg 34 | *.ini 35 | *.yml 36 | *.yaml 37 | 38 | # pycache 39 | __pycache__ 40 | 41 | # saved model 42 | *.pkl 43 | *.pt 44 | 45 | # hidden files 46 | .* 47 | 48 | # vscode 49 | .vscode 50 | 51 | # macOS 52 | .DS_Store 53 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # MoM: Mixture-of-Memories 4 | [![arXiv](https://img.shields.io/badge/Arxiv-2502.13685-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2502.13685) 5 | [![huggingface weights](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Weights-ffc107?color=ffc107&logoColor=white)](https://huggingface.co/linear-moe-hub) 6 | [![zhihu](https://img.shields.io/badge/Zhihu-Intro-blue?logo=zhihu)](https://zhuanlan.zhihu.com/p/25066090353) 7 | [![stars](https://img.shields.io/github/stars/OpenSparseLLMs/MoM)](https://github.com/OpenSparseLLMs/MoM/stargazers) 8 | 9 |
10 | 11 | Welcome to MoM! This repository provides the implementation of [MoM: Linear Sequence Modeling with Mixture-of-Memories](https://arxiv.org/abs/2502.13685), built on the Hugging Face ecosystem. MoM is compatible with all kinds of linear sequence modeling methods, such as linear attention, SSMs, and linear RNNs. **Here is an introductory article about MoM (in Chinese) on [Zhihu](https://zhuanlan.zhihu.com/p/25066090353)**. 12 | 13 |

14 | <div align="center">
15 |   <img src="assets/mom_fig1.png" width="90%">
16 | </div>
17 | <p align="center">MoM Architecture</p>
18 |
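Once the package is installed (see the Installation section below), MoM models plug into the standard `transformers` Auto classes: importing `mom` registers the model classes, which is the same pattern used by `benchmarks/benchmark_generation.py` and `evals/harness.py`. Below is a minimal usage sketch; the checkpoint path is a placeholder (it reuses the example path from the Evaluation section), so point it at your own trained or downloaded checkpoint.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import mom  # noqa: importing mom makes the MoM configs/models visible to the Auto classes

path = "training/SlimPajama/mom-15B/checkpoint-30720"  # placeholder checkpoint path
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).cuda().eval()

prompt = "Mixture-of-Memories maintains multiple memory states so that"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.inference_mode():
    outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```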
19 | 20 | ## Installation 21 | 22 | The following requirements should be satisfied: 23 | - [PyTorch](https://pytorch.org/) >=2.5 24 | - [Triton](https://github.com/openai/triton) >=3.0 25 | - [einops](https://einops.rocks/) 26 | - [transformers](https://github.com/huggingface/transformers) >=4.45.0 27 | - [datasets](https://github.com/huggingface/datasets) >=3.3.0 28 | - [causal-conv1d](https://github.com/Dao-AILab/causal-conv1d) >=1.4.0 29 | 30 | Install the package from source: 31 | ```bash 32 | pip install -e . 33 | ``` 34 | 35 | ## Getting Started 36 | 37 | ### Data Preparation 38 | Before training, make sure to preprocess your data by following the steps outlined in [training/README.md](training/README.md). 39 | 40 | ### Training From Scratch 41 | 42 | To start training with the default setup, simply run: 43 | ```bash 44 | cd training 45 | 46 | bash train.sh \ 47 | nodes=4 \ 48 | gpus=8 \ 49 | type=mom \ 50 | lr=3e-4 \ 51 | steps=30720 \ 52 | batch=8 \ 53 | update=1 \ 54 | warmup=1024 \ 55 | context=2048 \ 56 | path=SlimPajama/mom-15B \ 57 | project=SlimPajama \ 58 | model=configs/mom_340M.json \ 59 | tokenizer=fla-hub/gla-1.3B-100B \ 60 | data=SlimPajama-627B \ 61 | cache=data/chunk1/train 62 | ``` 63 | 64 | You can also: 65 | - Modify the script to adjust the modeling and training settings. 66 | - e.g., modify [training/configs/mom_340M.json](training/configs/mom_340M.json) to adjust the MoM model structure. 67 | 68 | ### Evaluation 69 | 70 | To evaluate model checkpoints on **commonsense reasoning benchmarks**, we recommend running: 71 | ```bash 72 | MODEL_PATH=training/SlimPajama/mom-15B/checkpoint-30720 73 | 74 | accelerate launch --multi_gpu evals/harness.py --model hf \ 75 | --model_args pretrained=$MODEL_PATH,dtype=bfloat16 \ 76 | --tasks arc_easy,arc_challenge,hellaswag,lambada_standard,piqa,winogrande,wikitext \ 77 | --output_path eval_results \ 78 | --batch_size 32 \ 79 | --device cuda 80 | ``` 81 | 82 | To evaluate model checkpoints on **recall-intensive tasks**, we recommend running: 83 | 1. Install `lm_eval`: 84 | ```bash 85 | cd lm-eval-harness 86 | pip install -e . 87 | ``` 88 | 2. Run the script: 89 | ```bash 90 | MODEL_PATH=../training/SlimPajama/mom-15B/checkpoint-30720 91 | 92 | CUDA_VISIBLE_DEVICES=0,1,2,3,4 python launch_local.py \ 93 | --batch-size 32 \ 94 | -t based_squad \ 95 | -t based_swde \ 96 | -t based_fda \ 97 | -t based_drop \ 98 | -t based_triviaqa \ 99 | -t based_nq_2048 \ 100 | -m $MODEL_PATH \ 101 | --context_length 2048 \ 102 | --answer_length 48 \ 103 | --cutting_context \ 104 | --limit -1 \ 105 | -p 106 | ``` 107 | 108 | ## Acknowledgement 109 | This repo builds upon the open-source [flash-linear-attention](https://github.com/fla-org/flash-linear-attention), and the evaluation code is based on [prefix-linear-attention](https://github.com/HazyResearch/prefix-linear-attention). Happy experimenting!
🔥🚀🔥 110 | 111 | ## Citation 112 | If you find this repo useful, please consider citing our paper: 113 | ```bib 114 | @article{du2025mom, 115 | title={MoM: Linear Sequence Modeling with Mixture-of-Memories}, 116 | author={Du, Jusen and Sun, Weigao and Lan, Disen and Hu, Jiaxi and Cheng, Yu}, 117 | journal={arXiv preprint arXiv:2502.13685}, 118 | year={2025} 119 | } 120 | ``` -------------------------------------------------------------------------------- /assets/mom_fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSparseLLMs/MoM/2c3ae59f5c0b749f916189433b8dad5a3415dc08/assets/mom_fig1.png -------------------------------------------------------------------------------- /benchmarks/benchmark_generation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2024, Songlin Yang, Yu Zhang. 3 | 4 | import argparse 5 | import time 6 | 7 | import torch 8 | from datasets import load_dataset 9 | from transformers import AutoModelForCausalLM, AutoTokenizer 10 | 11 | import fla # noqa 12 | import mom # noqa 13 | 14 | 15 | def sizeof_fmt(num, suffix='B'): 16 | for unit in ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi'): 17 | if abs(num) < 1024.0: 18 | return f'{num:3.1f}{unit}{suffix}' 19 | num /= 1024.0 20 | return f'{num:.1f}Yi{suffix}' 21 | 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser(description="Generation benchmarking") 25 | parser.add_argument("--path", type=str, default="fla-hub/transformer-1.3B-100B") 26 | parser.add_argument("--data", type=str, default="fla-hub/pg19") 27 | parser.add_argument("--length", type=int, default=128) 28 | parser.add_argument("--maxlen", type=int, default=128) 29 | parser.add_argument("--no-cache", action='store_true') 30 | parser.add_argument("--temperature", type=float, default=0.5) 31 | parser.add_argument("--topp", type=float, default=0.2) 32 | parser.add_argument("--repetition_penalty", type=float, default=1.1) 33 | args = parser.parse_args() 34 | 35 | device = "cuda" 36 | dtype = torch.bfloat16 37 | torch.manual_seed(0) 38 | 39 | print(f"Loading {args.path}") 40 | tokenizer = AutoTokenizer.from_pretrained( 41 | args.path, 42 | trust_remote_code=True, 43 | add_eos_token=False 44 | ) 45 | tokenizer.pad_token_id = tokenizer.eos_token_id 46 | print(f"{tokenizer}") 47 | 48 | model = AutoModelForCausalLM.from_pretrained( 49 | args.path, 50 | device_map={"": device}, 51 | torch_dtype=dtype, 52 | use_cache=not args.no_cache 53 | ) 54 | model.eval() 55 | print(f"{model.config}\n{model}\nNumber of parameters: {model.num_parameters()} ({sizeof_fmt(model.num_parameters())})\n") 56 | 57 | print(f"Loading {args.data}") 58 | dataset = load_dataset(args.data, split='train', trust_remote_code=True) 59 | print(f"{dataset}") 60 | 61 | prompt = dataset[0]['text'] 62 | tokens = tokenizer(prompt, return_tensors="pt") 63 | input_ids = tokens.input_ids.to(device=device)[:, :args.length].contiguous() 64 | max_length = input_ids.shape[1] + args.maxlen 65 | 66 | torch.cuda.synchronize() 67 | start = time.time() 68 | with torch.inference_mode(): 69 | text = model.generate( 70 | input_ids=input_ids, 71 | use_cache=not args.no_cache, 72 | max_length=max_length, 73 | pad_token_id=tokenizer.eos_token_id, 74 | eos_token_id=tokenizer.bos_token_id, 75 | do_sample=True, 76 | temperature=args.temperature, 77 | top_p=args.topp, 78 | repetition_penalty=args.repetition_penalty 79 | ) 80 | torch.cuda.synchronize() 
81 | elapsed = time.time() - start 82 | print(f"Prompt:\n{tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0].strip()}\n") 83 | print(f"Generated:\n{tokenizer.batch_decode(text, skip_special_tokens=True)[0].strip()}\n") 84 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(text[0]) - len(input_ids[0])}") 85 | print(f"Total prompt processing + decoding time: {elapsed * 1000:.0f}ms") 86 | print(f"Max memory used: {sizeof_fmt(torch.cuda.max_memory_allocated())}") 87 | -------------------------------------------------------------------------------- /benchmarks/benchmark_training_throughput.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import argparse 4 | import time 5 | from typing import Optional, Tuple 6 | 7 | import torch 8 | from accelerate import Accelerator 9 | from torch.cuda import max_memory_allocated, memory_allocated 10 | from torch.optim import AdamW 11 | from tqdm import trange 12 | from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig 13 | from transformers.optimization import get_cosine_schedule_with_warmup 14 | 15 | import fla 16 | import mom 17 | 18 | classes1 = [getattr(mom.models, i) for i in mom.models.__all__] 19 | classes2 = [getattr(fla.models, i) for i in fla.models.__all__] 20 | classes = classes1 + classes2 21 | configs = {i.model_type: i() for i in classes if issubclass(i, PretrainedConfig)} 22 | 23 | 24 | def sizeof_fmt(num, suffix='B'): 25 | for unit in ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi'): 26 | if abs(num) < 1024.0: 27 | return f'{num:.2f}{unit}{suffix}' 28 | num /= 1024.0 29 | return f'{num:.2f}Yi{suffix}' 30 | 31 | 32 | def prepare_inputs( 33 | batch_size: int, 34 | seq_len: int, 35 | varlen: bool, 36 | vocab_size: int, 37 | device: torch.device 38 | ): 39 | if varlen: 40 | tokens = torch.randint(high=vocab_size, size=(1, batch_size * seq_len), device=device) 41 | offsets = torch.cat([ 42 | torch.tensor([0], dtype=torch.long, device=device), 43 | torch.randperm(batch_size * seq_len - 16, device=device)[:batch_size-1] + 16, 44 | torch.tensor([batch_size * seq_len], dtype=torch.long, device=device) 45 | ], 0).sort()[0] 46 | else: 47 | tokens = torch.randint(high=vocab_size, size=(batch_size, seq_len), device=device) 48 | offsets = None 49 | return tokens, offsets 50 | 51 | 52 | def profile( 53 | name: str, 54 | batch_size: int = 8, 55 | seq_len: int = 2048, 56 | varlen: bool = False, 57 | warmup_steps: int = 16, 58 | steps: int = 32, 59 | total_steps: int = 1024, 60 | lr: float = 3e-4, 61 | betas: Tuple[float] = (0.9, 0.95), 62 | weight_decay: float = 0.1, 63 | dtype: Optional[torch.dtype] = torch.bfloat16, 64 | mixed_precision: str = 'bf16' 65 | ): 66 | device = torch.device('cuda') 67 | config = configs[name] if name in configs else AutoConfig.from_pretrained(name) 68 | model = AutoModelForCausalLM.from_config(config).cuda().to(dtype) 69 | num_parameters = model.num_parameters() 70 | print(f"Initializing {name} model from the config:\n{config}\n{model}") 71 | print(f"Number of parameters in total: {num_parameters} ({sizeof_fmt(num_parameters)})") 72 | print(f"Allocated memory after initialization: {sizeof_fmt(memory_allocated(device))}") 73 | 74 | accelerator = Accelerator(mixed_precision=mixed_precision) 75 | optimizer = AdamW( 76 | model.parameters(), 77 | lr=lr, 78 | betas=betas, 79 | weight_decay=weight_decay, 80 | fused=True 81 | ) 82 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, total_steps) 83 | 84 | bar = 
trange(warmup_steps) 85 | 86 | model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler) 87 | torch.cuda.synchronize(device) 88 | for _ in bar: 89 | # forward pass 90 | tokens, offsets = prepare_inputs( 91 | batch_size=batch_size, 92 | seq_len=seq_len, 93 | varlen=varlen, 94 | vocab_size=config.vocab_size, 95 | device=device 96 | ) 97 | outputs = model(tokens, labels=tokens, offsets=offsets) 98 | # backward pass 99 | accelerator.backward(outputs.loss) 100 | optimizer.step() 101 | scheduler.step() 102 | optimizer.zero_grad() 103 | bar.set_description_str(f"Max memory allocated: {sizeof_fmt(max_memory_allocated(device))}") 104 | 105 | start, total_tokens = time.time(), 0 106 | bar = trange(steps) 107 | torch.cuda.synchronize(device) 108 | for _ in bar: 109 | # forward pass 110 | tokens, offsets = prepare_inputs( 111 | batch_size=batch_size, 112 | seq_len=seq_len, 113 | varlen=varlen, 114 | vocab_size=config.vocab_size, 115 | device=device 116 | ) 117 | outputs = model(tokens, labels=tokens, offsets=offsets) 118 | # backward pass 119 | accelerator.backward(outputs.loss) 120 | optimizer.step() 121 | optimizer.zero_grad() 122 | 123 | total_tokens += batch_size * seq_len 124 | torch.cuda.synchronize(device) 125 | duration = time.time() - start 126 | bar.set_description_str(f"Throughput: {total_tokens / duration:10.2f} tokens/s") 127 | 128 | 129 | if __name__ == "__main__": 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument("--name", default='retnet') 132 | parser.add_argument("--batch_size", default=8, type=int) 133 | parser.add_argument("--seq_len", default=2048, type=int) 134 | parser.add_argument("--varlen", action='store_true') 135 | parser.add_argument("--warmup_steps", default=16, type=int) 136 | parser.add_argument("--steps", default=32, type=int) 137 | args = parser.parse_args() 138 | profile( 139 | name=args.name, 140 | batch_size=args.batch_size, 141 | seq_len=args.seq_len, 142 | varlen=args.varlen, 143 | warmup_steps=args.warmup_steps, 144 | steps=args.steps 145 | ) 146 | -------------------------------------------------------------------------------- /benchmarks/modules/benchmark_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import triton 7 | 8 | from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss 9 | 10 | 11 | @triton.testing.perf_report( 12 | triton.testing.Benchmark( 13 | # argument names to use as an x-axis for the plot 14 | x_names=['T'], 15 | # different possible values for `x_name` 16 | x_vals=[128 * 2 ** i for i in range(0, 8)], 17 | # argument name whose value corresponds to a different line in the plot 18 | line_arg='provider', 19 | # possible values for `line_arg`` 20 | line_vals=['naive', 'fused', 'fused_linear', 'naive_bwd', 'fused_bwd', 'fused_linear_bwd'], 21 | # label name for the lines 22 | line_names=['naive', 'fused', 'fused_linear', 'naive_bwd', 'fused_bwd', 'fused_linear_bwd'], 23 | # line styles 24 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'), 25 | ('cyan', ':'), ('yellow', 'dotted'), ('cyan', '--'), ('cyan', '-'), ('black', ':')], 26 | ylabel="Execution Time (ms)", # label name for the y-axis 27 | # name for the plot. Used also as a file name for saving the plot.
28 | plot_name="Performance", 29 | args={}, 30 | ) 31 | ) 32 | def benchmark(T, provider): 33 | device = 'cuda' 34 | dtype = torch.bfloat16 35 | requires_grad = True 36 | B, H, V = 4, 4096, 120000 37 | 38 | x = torch.randn(B * T, H, device=device, requires_grad=requires_grad, dtype=dtype) 39 | target = torch.randint(0, V, (B * T,), device=device, dtype=torch.int64) 40 | w = torch.randn(V, H, device=device, requires_grad=requires_grad, dtype=dtype) 41 | b = torch.randn(V, device=device, requires_grad=requires_grad, dtype=dtype) 42 | 43 | quantiles = [0.5, 0.2, 0.8] 44 | results = 0, 0, 0 45 | if provider == 'naive': 46 | criterion = nn.CrossEntropyLoss() 47 | results = triton.testing.do_bench(lambda: criterion(F.linear(x, w, b), target), quantiles=quantiles) 48 | elif provider == 'naive_bwd': 49 | criterion = nn.CrossEntropyLoss() 50 | results = triton.testing.do_bench(lambda: criterion(F.linear(x, w, b), target).backward(), quantiles=quantiles) 51 | elif provider == 'fused': 52 | criterion = FusedCrossEntropyLoss() 53 | results = triton.testing.do_bench(lambda: criterion(F.linear(x, w, b), target), quantiles=quantiles) 54 | elif provider == 'fused_bwd': 55 | criterion = FusedCrossEntropyLoss() 56 | results = triton.testing.do_bench(lambda: criterion(F.linear(x, w, b), target).backward(), quantiles=quantiles) 57 | elif provider == 'fused_linear': 58 | criterion = FusedLinearCrossEntropyLoss() 59 | results = triton.testing.do_bench(lambda: criterion(x, target, w, b), quantiles=quantiles) 60 | elif provider == 'fused_linear_bwd': 61 | criterion = FusedLinearCrossEntropyLoss() 62 | results = triton.testing.do_bench(lambda: criterion(x, target, w, b).backward(), quantiles=quantiles) 63 | return results 64 | 65 | 66 | if __name__ == '__main__': 67 | benchmark.run(print_data=True) 68 | -------------------------------------------------------------------------------- /benchmarks/modules/benchmark_layernorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import triton 6 | 7 | from fla.modules import GroupNorm, LayerNorm 8 | 9 | 10 | @triton.testing.perf_report( 11 | triton.testing.Benchmark( 12 | # argument names to use as an x-axis for the plot 13 | x_names=['T'], 14 | # different possible values for `x_name` 15 | x_vals=[128 * 2 ** i for i in range(0, 8)], 16 | # argument name whose value corresponds to a different line in the plot 17 | line_arg='provider', 18 | # possible values for `line_arg`` 19 | line_vals=['naive_ln', 'fused_ln', 'naive_gn', 'fused_gn', 20 | 'naive_ln_bwd', 'fused_ln_bwd', 'naive_gn_bwd', 'fused_gn_bwd'], 21 | # label name for the lines 22 | line_names=['naive_ln', 'fused_ln', 'naive_gn', 'fused_gn', 23 | 'naive_ln_bwd', 'fused_ln_bwd', 'naive_gn_bwd', 'fused_gn_bwd'], 24 | # line styles 25 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'), 26 | ('cyan', ':'), ('yellow', 'dotted'), ('cyan', '--'), ('cyan', '-'), ('black', ':')], 27 | ylabel="Execution Time (ms)", # label name for the y-axis 28 | # name for the plot. Used also as a file name for saving the plot. 
29 | plot_name="Performance", 30 | args={}, 31 | ) 32 | ) 33 | def benchmark(T, provider): 34 | device = 'cuda' 35 | dtype = torch.bfloat16 36 | requires_grad = True 37 | B, D = 16, 1024 38 | 39 | x = torch.randn(B * T, D, device=device, requires_grad=requires_grad, dtype=dtype) 40 | 41 | quantiles = [0.5, 0.2, 0.8] 42 | results = 0, 0, 0 43 | if provider.startswith('naive_ln'): 44 | norm = nn.LayerNorm(D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype) 45 | results = triton.testing.do_bench(lambda: norm(x), quantiles=quantiles) 46 | if provider.startswith('fused_ln'): 47 | norm = LayerNorm(D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype) 48 | results = triton.testing.do_bench(lambda: norm(x), quantiles=quantiles) 49 | if provider.startswith('naive_gn'): 50 | norm = nn.GroupNorm(4, D).to(device=device, dtype=dtype) 51 | results = triton.testing.do_bench(lambda: norm(x), quantiles=quantiles) 52 | if provider.startswith('fused_gn'): 53 | norm = GroupNorm(4, D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype) 54 | results = triton.testing.do_bench(lambda: norm(x), quantiles=quantiles) 55 | if provider.startswith('naive_ln_bwd'): 56 | norm = nn.LayerNorm(D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype) 57 | results = triton.testing.do_bench(lambda: norm(x).backward(x), quantiles=quantiles) 58 | if provider.startswith('fused_ln_bwd'): 59 | norm = LayerNorm(D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype) 60 | results = triton.testing.do_bench(lambda: norm(x).backward(x), quantiles=quantiles) 61 | if provider.startswith('naive_gn_bwd'): 62 | norm = nn.GroupNorm(4, D).to(device=device, dtype=dtype) 63 | results = triton.testing.do_bench(lambda: norm(x).backward(x), quantiles=quantiles) 64 | if provider.startswith('fused_gn_bwd'): 65 | norm = GroupNorm(4, D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype) 66 | results = triton.testing.do_bench(lambda: norm(x).backward(x), quantiles=quantiles) 67 | return results 68 | 69 | 70 | if __name__ == '__main__': 71 | benchmark.run(print_data=True) 72 | -------------------------------------------------------------------------------- /benchmarks/ops/benchmark_abc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import triton 5 | from torch.nn import functional as F 6 | 7 | from fla.ops.abc import chunk_abc 8 | from fla.ops.gla import chunk_gla 9 | from fla.ops.retention import chunk_retention 10 | 11 | try: 12 | from flash_attn import flash_attn_func 13 | HAS_FLASH = True 14 | except BaseException: 15 | HAS_FLASH = False 16 | 17 | 18 | @triton.testing.perf_report( 19 | triton.testing.Benchmark( 20 | # argument names to use as an x-axis for the plot 21 | x_names=['T'], 22 | # different possible values for `x_name` 23 | x_vals=[128 * 2 ** i for i in range(0, 8)], 24 | # argument name whose value corresponds to a different line in the plot 25 | line_arg='provider', 26 | # possible values for `line_arg`` 27 | line_vals=['abc', 'gla', 'abc_bwd', 'gla_bwd', 'retention_bwd', 'flash_bwd'], 28 | # label name for the lines 29 | line_names=['abc', 'gla', 'abc_bwd', 'gla_bwd', 'retention_bwd', 'flash_bwd'], 30 | # line styles 31 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'), 32 | ('cyan', ':'), ('yellow', 'dotted'), ('black', ':')], 33 | ylabel="Execution Time (ms)", # label name for the y-axis 34 | # name for the plot. 
Used also as a file name for saving the plot. 35 | plot_name="Performance", 36 | args={}, 37 | ) 38 | ) 39 | def benchmark(T, provider): 40 | device = 'cuda' 41 | dtype = torch.bfloat16 42 | requires_grad = True 43 | B, H, D, M = 16, 4, 128, 64 44 | 45 | q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 46 | k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 47 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 48 | if provider.startswith('flash'): 49 | q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 50 | k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 51 | v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 52 | if provider.startswith('gla'): 53 | g = F.logsigmoid(torch.randn(B, H, T, D, device=device, dtype=dtype)) 54 | g = g.clamp_min(-5).requires_grad_(requires_grad) 55 | if provider.startswith('abc'): 56 | s = torch.randn(B, H, T, M, device=device, requires_grad=requires_grad, dtype=dtype) 57 | 58 | do = torch.ones_like(v, dtype=dtype) 59 | 60 | quantiles = [0.5, 0.2, 0.8] 61 | if provider == 'abc': 62 | results = triton.testing.do_bench(lambda: chunk_abc(q, k, v, s), quantiles=quantiles) 63 | elif provider == 'gla': 64 | results = triton.testing.do_bench(lambda: chunk_gla(q, k, v, g), quantiles=quantiles) 65 | elif provider == 'abc_bwd': 66 | results = triton.testing.do_bench(lambda: chunk_abc(q, k, v, s)[0].backward(do), quantiles=quantiles) 67 | elif provider == 'gla_bwd': 68 | results = triton.testing.do_bench(lambda: chunk_gla(q, k, v, g)[0].backward(do), quantiles=quantiles) 69 | elif provider == 'retention_bwd': 70 | results = triton.testing.do_bench(lambda: chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles) 71 | elif provider == 'flash_bwd': 72 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True).backward(do), quantiles=quantiles) 73 | return results 74 | 75 | 76 | if __name__ == '__main__': 77 | benchmark.run(print_data=True) 78 | -------------------------------------------------------------------------------- /benchmarks/ops/benchmark_based.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import triton 5 | 6 | from fla.ops.based import fused_chunk_based, parallel_based 7 | from fla.ops.based.naive import naive_chunk_based, naive_parallel_based 8 | 9 | try: 10 | from flash_attn import flash_attn_func 11 | HAS_FLASH = True 12 | except Exception: 13 | HAS_FLASH = False 14 | 15 | 16 | @triton.testing.perf_report( 17 | triton.testing.Benchmark( 18 | # argument names to use as an x-axis for the plot 19 | x_names=['T'], 20 | # different possible values for `x_name` 21 | x_vals=[128 * 2 ** i for i in range(3, 8)], 22 | # argument name whose value corresponds to a different line in the plot 23 | line_arg='provider', 24 | line_vals=['fused_chunk', 'torch', 'parallel', 'parallel_chunk', 'fused_chunk_bwd', 'torch_bwd', 25 | 'parallel_bwd', 'parallel_chunk_bwd'] + (['flash', 'flash_bwd'] if HAS_FLASH else []), 26 | # label name for the lines 27 | line_names=['fused_chunk_fwd', 'torch_fwd', 'parallel_fwd', 'parallel_chunk_fwd', 28 | 'fused_chunk_fwdbwd', 'torch_fwdbwd', 'parallel_fwdbwd', 29 | 'parallel_chunk_fwdbwd'] + (['flash_fwd', 'flash_fwdbwd'] if HAS_FLASH else []), 30 | 31 | # line styles 32 | styles=[('green', '-'), ('blue', '-'), ('red', '-'), 
('green', 'dotted'), ('blue', 'dotted'), 33 | ('red', 'dotted'), ('red', '--'), ('red', ':')] + ([('cyan', '-'), ('cyan', 'dotted')] if HAS_FLASH else []), 34 | ylabel="Execution Time (ms)", # label name for the y-axis 35 | # name for the plot. Used also as a file name for saving the plot. 36 | plot_name="Performance", 37 | args={}, 38 | ) 39 | ) 40 | def benchmark(T, provider): 41 | device = 'cuda' 42 | dtype = torch.bfloat16 43 | requires_grad = True 44 | B, H, D = 8, 16, 128 45 | 46 | if provider == 'flash' or provider == 'flash_bwd': 47 | q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 48 | k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 49 | v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 50 | else: 51 | q = torch.randn(B, H, T, 16, device=device, requires_grad=requires_grad, dtype=dtype) 52 | k = torch.randn(B, H, T, 16, device=device, requires_grad=requires_grad, dtype=dtype) 53 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 54 | do = torch.ones_like(v, dtype=dtype) 55 | 56 | quantiles = [0.5, 0.2, 0.8] 57 | results = 0, 0, 0 58 | if provider == 'torch': 59 | if T > 1024: 60 | return results 61 | results = triton.testing.do_bench(lambda: naive_parallel_based(q, k, v), quantiles=quantiles) 62 | elif provider == 'fused_chunk': 63 | results = triton.testing.do_bench(lambda: fused_chunk_based(q, k, v), quantiles=quantiles) 64 | elif provider == 'parallel': 65 | results = triton.testing.do_bench(lambda: parallel_based(q, k, v), quantiles=quantiles) 66 | elif provider == 'parallel_chunk': 67 | results = triton.testing.do_bench(lambda: naive_chunk_based(q, k, v), quantiles=quantiles) 68 | elif provider == 'torch_bwd': 69 | if T > 1024: 70 | return results 71 | results = triton.testing.do_bench(lambda: naive_parallel_based(q, k, v).backward(do), quantiles=quantiles) 72 | elif provider == 'fused_chunk_bwd': 73 | results = triton.testing.do_bench(lambda: fused_chunk_based(q, k, v).backward(do), quantiles=quantiles) 74 | elif provider == 'parallel_bwd': 75 | results = triton.testing.do_bench(lambda: parallel_based(q, k, v).backward(do), quantiles=quantiles) 76 | elif provider == 'flash': 77 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True), quantiles=quantiles) 78 | elif provider == 'flash_bwd': 79 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True).backward(do), quantiles=quantiles) 80 | elif provider == 'parallel_chunk_bwd': 81 | results = triton.testing.do_bench(lambda: naive_chunk_based(q, k, v).backward(do), quantiles=quantiles) 82 | return results 83 | 84 | 85 | if __name__ == '__main__': 86 | benchmark.run(print_data=True, show_plots=True) 87 | -------------------------------------------------------------------------------- /benchmarks/ops/benchmark_delta_rule.py: -------------------------------------------------------------------------------- 1 | # Install the newest triton version with 2 | # pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python" 3 | 4 | import torch 5 | from benchmark import benchmark_combined, benchmark_forward, benchmark_backward 6 | 7 | from fla.ops.delta_rule import (chunk_delta_rule, 8 | fused_recurrent_delta_rule, fused_chunk_delta_rule) 9 | from fla.ops.retention import fused_chunk_retention 10 | # from flash_attn import flash_attn_func 11 | 12 | 13 | def time_fwd(func, *args, **kwargs): 14 | time_fb = 
benchmark_forward(func, *args, **kwargs) 15 | return time_fb[1].mean 16 | 17 | 18 | def time_fwd_bwd(func, *args, **kwargs): 19 | time_fb = benchmark_combined(func, *args, **kwargs) 20 | return time_fb[1].mean 21 | 22 | def time_bwd(func, *args, **kwargs): 23 | time_fb = benchmark_backward(func, *args, **kwargs) 24 | return time_fb[1].mean 25 | 26 | 27 | repeats = 256 28 | device = 'cuda' 29 | dtype = torch.bfloat16 30 | 31 | 32 | bs_seqlen_vals = [(8, 2048), (4, 4096), (2, 8192)] 33 | causal_vals = [True] 34 | headdim_vals = [64, 128, 256] 35 | dim = 2048 36 | dropout_p = 0.0 37 | 38 | 39 | methods = (["chunk_delta_rule", "fused_chunk_delta_rule"]) 40 | time_f = {} 41 | time_b = {} 42 | time_f_b = {} 43 | speed_f = {} 44 | speed_b = {} 45 | speed_f_b = {} 46 | for causal in causal_vals: 47 | for headdim in headdim_vals: 48 | for B, seqlen in bs_seqlen_vals: 49 | config = (causal, headdim, B, seqlen) 50 | H = dim // headdim 51 | q = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype) 52 | k = torch.nn.functional.normalize(torch.randn(B, H, seqlen, headdim, device=device, dtype=dtype), p=2, dim=-1).requires_grad_(True) 53 | v = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype) 54 | beta = torch.rand(B, H, seqlen, device=device, dtype=dtype).sigmoid().requires_grad_(True) 55 | o1, _ = chunk_delta_rule(q, k, v, beta) 56 | o1.sum().backward(retain_graph=True) 57 | f_b = time_fwd_bwd( 58 | chunk_delta_rule, q, k, v, beta, verbose=False 59 | ) 60 | time_f_b[config, "chunk_delta_rule"] = f_b 61 | 62 | # q = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype) 63 | # k = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype) 64 | # v = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype) 65 | f_b = time_fwd_bwd( 66 | fused_chunk_delta_rule, q, k, v, beta, verbose=False 67 | ) 68 | time_f_b[config, "fused_chunk_delta_rule"] = f_b 69 | 70 | 71 | print(f"### causal={causal}, headdim={headdim}, B={B}, seqlen={seqlen} ###") 72 | for method in methods: 73 | # time_f_b[config, method] = time_f[config, method] + time_b[config, method] 74 | print(f"{method:>50} fwd + bwd:\t {time_f_b[config, method]*1000:>6.4f} ms ") 75 | 76 | # speed_f[config, method] = efficiency( 77 | # flops(B, seqlen, headdim, H, causal, mode="fwd"), 78 | # time_f[config, method] 79 | # ) 80 | # speed_b[config, method] = efficiency( 81 | # flops(B, seqlen, headdim, H, causal, mode="bwd"), 82 | # time_b[config, method] 83 | # ) 84 | # speed_f_b[config, method] = efficiency( 85 | # flops(B, seqlen, headdim, H, causal, mode="fwd_bwd"), 86 | # time_f_b[config, method] 87 | # ) 88 | # print( 89 | # f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, " 90 | # f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, " 91 | # f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s" 92 | # ) 93 | 94 | 95 | # with open('flash2_attn_time.plk', 'wb') as fp: 96 | # pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL) 97 | -------------------------------------------------------------------------------- /benchmarks/ops/benchmark_fla.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import triton 5 | 6 | from fla.ops.based import parallel_based 7 | from fla.ops.gla import fused_chunk_gla 8 | from fla.ops.retention import fused_chunk_retention, parallel_retention 9 | 10 | try: 11 | from 
flash_attn import flash_attn_func 12 | HAS_FLASH = True 13 | except ImportError: 14 | HAS_FLASH = False 15 | 16 | 17 | @triton.testing.perf_report( 18 | triton.testing.Benchmark( 19 | # argument names to use as an x-axis for the plot 20 | x_names=['T'], 21 | # different possible values for `x_name` 22 | x_vals=[128 * 2 ** i for i in range(0, 8)], 23 | # argument name whose value corresponds to a different line in the plot 24 | line_arg='provider', 25 | # possible values for `line_arg`` 26 | line_vals=['retention_parallel', 'retention_fused_chunk', 27 | 'gla_fused_chunk', 'based_parallel'] + (['flash'] if HAS_FLASH else []), 28 | # label name for the lines 29 | line_names=['retention_parallel_fwdbwd', 'retention_fused_chunk_fwdbwd', 30 | 'gla_fused_chunk_fwdbwd', 'based_parallel_fwdbwd'] + (['flash_fwdbwd'] if HAS_FLASH else []), 31 | # line styles 32 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':')] + \ 33 | ([('yellow', 'dotted')] if HAS_FLASH else []), 34 | ylabel="Execution Time (ms)", # label name for the y-axis 35 | # name for the plot. Used also as a file name for saving the plot. 36 | plot_name="Performance", 37 | args={}, 38 | ) 39 | ) 40 | def benchmark(T, provider): 41 | device = 'cuda' 42 | dtype = torch.bfloat16 43 | requires_grad = True 44 | B, H, D = 16, 8, 128 45 | 46 | if provider == 'flash': 47 | q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 48 | k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 49 | v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 50 | elif "based" in provider: 51 | q = torch.randn(B, H, T, 16, device=device, requires_grad=requires_grad, dtype=dtype) 52 | k = torch.randn(B, H, T, 16, device=device, requires_grad=requires_grad, dtype=dtype) 53 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 54 | elif "gla" in provider: 55 | q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 56 | k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 57 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 58 | g = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 59 | else: 60 | q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 61 | k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 62 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 63 | 64 | do = torch.rand_like(v, dtype=dtype) 65 | 66 | quantiles = [0.5, 0.2, 0.8] 67 | results = 0, 0, 0 68 | if provider == 'flash': 69 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v).backward(do), quantiles=quantiles) 70 | elif provider == 'retention_parallel': 71 | results = triton.testing.do_bench(lambda: parallel_retention(q, k, v).backward(do), quantiles=quantiles) 72 | elif provider == 'retention_fused_chunk': 73 | results = triton.testing.do_bench(lambda: fused_chunk_retention(q, k, v).backward(do), quantiles=quantiles) 74 | elif provider == 'based_parallel': 75 | results = triton.testing.do_bench(lambda: parallel_based(q, k, v).backward(do), quantiles=quantiles) 76 | elif provider == 'gla_fused_chunk': 77 | results = triton.testing.do_bench(lambda: fused_chunk_gla(q, k, v, g).backward(do), quantiles=quantiles) 78 | 79 | return results 80 | 81 | 82 | if __name__ == '__main__': 83 | benchmark.run(print_data=True, 
show_plots=True) 84 | -------------------------------------------------------------------------------- /benchmarks/ops/benchmark_gla.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import triton 5 | from torch.nn import functional as F 6 | 7 | from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla 8 | from fla.ops.retention import chunk_retention, parallel_retention 9 | from fla.ops.retention.naive import naive_retention 10 | 11 | 12 | @triton.testing.perf_report( 13 | triton.testing.Benchmark( 14 | # argument names to use as an x-axis for the plot 15 | x_names=['T'], 16 | # different possible values for `x_name` 17 | x_vals=[128 * 2 ** i for i in range(0, 8)], 18 | # argument name whose value corresponds to a different line in the plot 19 | line_arg='provider', 20 | # possible values for `line_arg`` 21 | line_vals=['fused_chunk_gla', 'recurrent_gla', 'chunk_gla', 'chunk_retention', 22 | 'fused_chunk_gla_bwd', 'recurrent_gla_bwd', 'chunk_gla_bwd', 'chunk_retention_bwd'], 23 | # label name for the lines 24 | line_names=['fused_chunk_gla', 'recurrent_gla', 'chunk_gla', 'chunk_retention', 25 | 'fused_chunk_gla_bwd', 'recurrent_gla_bwd', 'chunk_gla_bwd', 'chunk_retention_bwd'], 26 | # line styles 27 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'), 28 | ('cyan', ':'), ('yellow', 'dotted'), ('cyan', '--'), ('cyan', '-'), ('black', ':')], 29 | ylabel="Execution Time (ms)", # label name for the y-axis 30 | # name for the plot. Used also as a file name for saving the plot. 31 | plot_name="Performance", 32 | args={}, 33 | ) 34 | ) 35 | def benchmark(T, provider): 36 | device = 'cuda' 37 | dtype = torch.bfloat16 38 | requires_grad = True 39 | B, H, D = 16, 8, 128 40 | 41 | q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 42 | k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 43 | g = F.logsigmoid(torch.randn(B, H, T, D, device=device, dtype=dtype)).clamp_min(-5).requires_grad_(requires_grad) 44 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 45 | 46 | do = torch.ones_like(q, dtype=dtype) 47 | 48 | quantiles = [0.5, 0.2, 0.8] 49 | results = 0, 0, 0 50 | if provider == 'torch': 51 | if T > 2048: 52 | return results 53 | results = triton.testing.do_bench(lambda: naive_retention(q, k, v), quantiles=quantiles) 54 | elif provider == 'recurrent_gla': 55 | results = triton.testing.do_bench(lambda: fused_recurrent_gla(q, k, v, g), quantiles=quantiles) 56 | elif provider == 'fused_chunk_gla': 57 | results = triton.testing.do_bench(lambda: fused_chunk_gla(q, k, v, g), quantiles=quantiles) 58 | elif provider == 'chunk_retention': 59 | results = triton.testing.do_bench(lambda: chunk_retention(q, k, v), quantiles=quantiles) 60 | elif provider == 'chunk_gla': 61 | results = triton.testing.do_bench(lambda: chunk_gla(q, k, v, g), quantiles=quantiles) 62 | elif provider == 'parallel': 63 | results = triton.testing.do_bench(lambda: parallel_retention(q, k, v), quantiles=quantiles) 64 | elif provider == 'torch_bwd': 65 | if T > 2048: 66 | return results 67 | elif provider == 'chunk_retention_bwd': 68 | results = triton.testing.do_bench(lambda: chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles) 69 | elif provider == 'recurrent_gla_bwd': 70 | results = triton.testing.do_bench(lambda: fused_recurrent_gla(q, k, v, g)[0].backward(do), quantiles=quantiles) 71 | elif provider == 
'fused_chunk_gla_bwd': 72 | results = triton.testing.do_bench(lambda: fused_chunk_gla(q, k, v, g)[0].backward(do), quantiles=quantiles) 73 | elif provider == 'chunk_gla_bwd': 74 | results = triton.testing.do_bench(lambda: chunk_gla(q, k, v, g)[0].backward(do), quantiles=quantiles) 75 | elif provider == 'parallel_bwd': 76 | results = triton.testing.do_bench(lambda: parallel_retention(q, k, v)[0].backward(do), quantiles=quantiles) 77 | return results 78 | 79 | 80 | if __name__ == '__main__': 81 | benchmark.run(print_data=True) 82 | -------------------------------------------------------------------------------- /benchmarks/ops/benchmark_gsa.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import triton 5 | from torch.nn import functional as F 6 | 7 | from fla.ops.gla import chunk_gla 8 | from fla.ops.gsa import chunk_gsa, fused_recurrent_gsa 9 | from fla.ops.retention import chunk_retention 10 | 11 | try: 12 | from flash_attn import flash_attn_func 13 | HAS_FLASH = True 14 | except BaseException: 15 | HAS_FLASH = False 16 | 17 | 18 | @triton.testing.perf_report( 19 | triton.testing.Benchmark( 20 | # argument names to use as an x-axis for the plot 21 | x_names=['T'], 22 | # different possible values for `x_name` 23 | x_vals=[128 * 2 ** i for i in range(0, 8)], 24 | # argument name whose value corresponds to a different line in the plot 25 | line_arg='provider', 26 | # possible values for `line_arg`` 27 | line_vals=['gsa_recurrent', 'gsa_chunk', 'gla', 28 | 'gsa_recurrent_bwd', 'gsa_chunk_bwd', 'gla_bwd', 'retention_bwd', 'flash_bwd'], 29 | # label name for the lines 30 | line_names=['gsa_recurrent', 'gsa_chunk', 'gla', 31 | 'gsa_recurrent_bwd', 'gsa_chunk_bwd', 'gla_bwd', 'retention_bwd', 'flash_bwd'], 32 | # line styles 33 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'), 34 | ('cyan', ':'), ('yellow', 'dotted'), ('black', ':'), ('green', ':'), ('green', 'dotted'), ('green', ':')], 35 | ylabel="Execution Time (ms)", # label name for the y-axis 36 | # name for the plot. Used also as a file name for saving the plot. 
37 | plot_name="Performance", 38 | args={}, 39 | ) 40 | ) 41 | def benchmark(T, provider): 42 | device = 'cuda' 43 | dtype = torch.bfloat16 44 | requires_grad = True 45 | B, H, D, M = 16, 4, 128, 64 46 | 47 | q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 48 | k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 49 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 50 | if provider.startswith('gsa'): 51 | f = F.logsigmoid(torch.randn(B, H, T, M, device=device, dtype=dtype)) 52 | s = (1 - f.exp()).to(f.dtype) 53 | if provider.startswith('gla'): 54 | g = F.logsigmoid(torch.randn(B, H, T, D, device=device, dtype=dtype)) 55 | g = g.clamp_min(-5).requires_grad_(requires_grad) 56 | if provider.startswith('flash'): 57 | q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 58 | k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 59 | v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 60 | do = torch.ones_like(v, dtype=dtype) 61 | 62 | quantiles = [0.5, 0.2, 0.8] 63 | if provider == 'gsa_recurrent': 64 | return triton.testing.do_bench(lambda: fused_recurrent_gsa(q, k, v, s, f), quantiles=quantiles) 65 | if provider == 'gsa_chunk': 66 | return triton.testing.do_bench(lambda: chunk_gsa(q, k, v, s, f), quantiles=quantiles) 67 | elif provider == 'gla': 68 | return triton.testing.do_bench(lambda: chunk_gla(q, k, v, g), quantiles=quantiles) 69 | elif provider == 'gsa_recurrent_bwd': 70 | return triton.testing.do_bench(lambda: fused_recurrent_gsa(q, k, v, s, f)[0].backward(do), quantiles=quantiles) 71 | elif provider == 'gsa_chunk_bwd': 72 | return triton.testing.do_bench(lambda: chunk_gsa(q, k, v, s, f)[0].backward(do), quantiles=quantiles) 73 | elif provider == 'gla_bwd': 74 | return triton.testing.do_bench(lambda: chunk_gla(q, k, v, g)[0].backward(do), quantiles=quantiles) 75 | elif provider == 'retention_bwd': 76 | return triton.testing.do_bench(lambda: chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles) 77 | elif provider == 'flash_bwd': 78 | return triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True).backward(do), quantiles=quantiles) 79 | 80 | 81 | if __name__ == '__main__': 82 | benchmark.run(print_data=True) 83 | -------------------------------------------------------------------------------- /benchmarks/ops/benchmark_hgrn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import triton 5 | 6 | from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn 7 | 8 | 9 | @triton.testing.perf_report( 10 | triton.testing.Benchmark( 11 | # argument names to use as an x-axis for the plot 12 | x_names=['T'], 13 | # different possible values for `x_name` 14 | x_vals=[128 * 2 ** i for i in range(0, 8)], 15 | # argument name whose value corresponds to a different line in the plot 16 | line_arg='provider', 17 | # possible values for `line_arg`` 18 | line_vals=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'], 19 | # label name for the lines 20 | line_names=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'], 21 | # line styles 22 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':'), ('yellow', 'dotted'), ('black', 'dashed')], 23 | ylabel="Execution Time (ms)", # label name for the y-axis 24 | # name for the plot. Used also as a file name for saving the plot. 
25 | plot_name="Performance", 26 | args={}, 27 | ) 28 | ) 29 | def benchmark(T, provider): 30 | dtype = torch.bfloat16 31 | B, D = 16, 512 32 | 33 | x = torch.randn((B, T, D), dtype=dtype, device='cuda') 34 | g = torch.randn((B, T, D), dtype=dtype, device='cuda').sigmoid() 35 | x = (1 - g) * x 36 | x, g = (i.detach().clone().to(dtype).requires_grad_() for i in (x, g)) 37 | do = torch.randn_like(x, dtype=dtype) 38 | quantiles = [0.5, 0.2, 0.8] 39 | results = 0, 0, 0 40 | if provider == 'chunk': 41 | results = triton.testing.do_bench(lambda: chunk_hgrn(x, g), quantiles=quantiles) 42 | if provider == 'recurrent': 43 | results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g), quantiles=quantiles) 44 | if provider == 'chunk_bwd': 45 | results = triton.testing.do_bench(lambda: chunk_hgrn(x, g)[0].backward(do), quantiles=quantiles) 46 | if provider == 'recurrent_bwd': 47 | results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g)[0].backward(do), quantiles=quantiles) 48 | return results 49 | 50 | 51 | if __name__ == '__main__': 52 | benchmark.run(print_data=True) 53 | -------------------------------------------------------------------------------- /benchmarks/ops/benchmark_retention.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import torch 6 | import triton 7 | 8 | from fla.ops.retention import (chunk_retention, fused_recurrent_retention, 9 | parallel_retention) 10 | from fla.ops.retention.naive import naive_retention 11 | 12 | try: 13 | from flash_attn import flash_attn_func 14 | HAS_FLASH = True 15 | except BaseException: 16 | HAS_FLASH = False 17 | 18 | 19 | @triton.testing.perf_report( 20 | triton.testing.Benchmark( 21 | # argument names to use as an x-axis for the plot 22 | x_names=['T'], 23 | # different possible values for `x_name` 24 | x_vals=[128 * 2 ** i for i in range(0, 8)], 25 | # argument name whose value corresponds to a different line in the plot 26 | line_arg='provider', 27 | # possible values for `line_arg`` 28 | line_vals=['fused_chunk', 'chunk', 'parallel', 29 | 'chunk_bwd', 'parallel_bwd'] + (['flash', 'flash_bwd'] if HAS_FLASH else []), 30 | # label name for the lines 31 | line_names=['fused_chunk_fwd', 'chunk_fwd', 'parallel_fwd', 32 | 'chunk_fwdbwd', 'parallel_fwdbwd'] + (['flash_fwd', 'flash_fwdbwd'] if HAS_FLASH else []), 33 | # line styles 34 | styles=[('green', '-'), ('blue', '-'), ('red', '-'), ('green', 'dotted'), ('blue', 'dotted'), 35 | ('red', 'dotted')] + ([('cyan', '-'), ('cyan', 'dotted')] if HAS_FLASH else []), 36 | ylabel="Execution Time (ms)", # label name for the y-axis 37 | # name for the plot. Used also as a file name for saving the plot. 
38 | plot_name="Performance", 39 | args={}, 40 | ) 41 | ) 42 | def benchmark(T, provider): 43 | device = 'cuda' 44 | dtype = torch.bfloat16 45 | requires_grad = True 46 | B, H, D = 4, 8, 256 47 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 48 | 49 | if provider == 'flash' or provider == 'flash_bwd': 50 | q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 51 | k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 52 | v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 53 | else: 54 | q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 55 | k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 56 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 57 | do = torch.ones_like(q, dtype=dtype) 58 | 59 | quantiles = [0.5, 0.2, 0.8] 60 | results = 0, 0, 0 61 | if provider == 'torch': 62 | if T > 2048: 63 | return results 64 | results = triton.testing.do_bench(lambda: naive_retention(q, k, v), quantiles=quantiles) 65 | elif provider == 'recurrent': 66 | results = triton.testing.do_bench(lambda: fused_recurrent_retention(q, k, v), quantiles=quantiles) 67 | elif provider == 'chunk': 68 | results = triton.testing.do_bench(lambda: chunk_retention(q, k, v), quantiles=quantiles) 69 | elif provider == 'parallel': 70 | results = triton.testing.do_bench(lambda: parallel_retention(q, k, v), quantiles=quantiles) 71 | elif provider == 'torch_bwd': 72 | if T > 2048: 73 | return results 74 | results = triton.testing.do_bench(lambda: naive_retention(q, k, v).backward(do), quantiles=quantiles) 75 | elif provider == 'recurrent_bwd': 76 | results = triton.testing.do_bench(lambda: fused_recurrent_retention(q, k, v)[0].backward(do), quantiles=quantiles) 77 | elif provider == 'chunk_bwd': 78 | results = triton.testing.do_bench(lambda: chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles) 79 | elif provider == 'parallel_bwd': 80 | results = triton.testing.do_bench(lambda: parallel_retention(q, k, v)[0].backward(do), quantiles=quantiles) 81 | elif provider == 'flash': 82 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True), quantiles=quantiles) 83 | elif provider == 'flash_bwd': 84 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True).backward(do), quantiles=quantiles) 85 | return results 86 | 87 | 88 | if __name__ == '__main__': 89 | benchmark.run(print_data=True) 90 | -------------------------------------------------------------------------------- /benchmarks/ops/benchmark_rwkv6.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import triton 5 | from torch.nn import functional as F 6 | 7 | from fla.ops.gla import chunk_gla 8 | from fla.ops.retention import chunk_retention 9 | from fla.ops.rwkv6 import chunk_rwkv6 10 | 11 | try: 12 | from flash_attn import flash_attn_func 13 | HAS_FLASH = True 14 | except BaseException: 15 | HAS_FLASH = False 16 | 17 | 18 | @triton.testing.perf_report( 19 | triton.testing.Benchmark( 20 | # argument names to use as an x-axis for the plot 21 | x_names=['T'], 22 | # different possible values for `x_name` 23 | x_vals=[128 * 2 ** i for i in range(0, 8)], 24 | # argument name whose value corresponds to a different line in the plot 25 | line_arg='provider', 26 | # possible values for `line_arg`` 27 | line_vals=['rwkv6', 'gla', 'rwkv6_bwd', 'gla_bwd', 
'retention_bwd', 'flash_bwd'], 28 | # label name for the lines 29 | line_names=['rwkv6', 'gla', 'rwkv6_bwd', 'gla_bwd', 'retention_bwd', 'flash_bwd'], 30 | # line styles 31 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'), 32 | ('cyan', ':'), ('yellow', 'dotted'), ('black', ':')], 33 | ylabel="Execution Time (ms)", # label name for the y-axis 34 | # name for the plot. Used also as a file name for saving the plot. 35 | plot_name="Performance", 36 | args={}, 37 | ) 38 | ) 39 | def benchmark(T, provider): 40 | device = 'cuda' 41 | dtype = torch.bfloat16 42 | requires_grad = True 43 | B, H, D = 16, 8, 128 44 | 45 | q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 46 | k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 47 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 48 | if provider.startswith('flash'): 49 | q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 50 | k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 51 | v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 52 | if provider.startswith('gla'): 53 | g = F.logsigmoid(torch.randn(B, H, T, D, device=device, dtype=dtype)) 54 | g = g.clamp_min(-5).requires_grad_(requires_grad) 55 | if provider.startswith('rwkv6'): 56 | w = F.logsigmoid(torch.randn(B, H, T, D, device=device, dtype=dtype)).requires_grad_(True) 57 | u = torch.randn(H, D, device=device, dtype=dtype).requires_grad_(True) 58 | 59 | do = torch.ones_like(v, dtype=dtype) 60 | 61 | quantiles = [0.5, 0.2, 0.8] 62 | if provider == 'rwkv6': 63 | results = triton.testing.do_bench(lambda: chunk_rwkv6(q, k, v, w, u), quantiles=quantiles) 64 | elif provider == 'gla': 65 | results = triton.testing.do_bench(lambda: chunk_gla(q, k, v, g), quantiles=quantiles) 66 | elif provider == 'rwkv6_bwd': 67 | results = triton.testing.do_bench(lambda: chunk_rwkv6(q, k, v, w, u)[0].backward(do), quantiles=quantiles) 68 | elif provider == 'gla_bwd': 69 | results = triton.testing.do_bench(lambda: chunk_gla(q, k, v, g)[0].backward(do), quantiles=quantiles) 70 | elif provider == 'retention_bwd': 71 | results = triton.testing.do_bench(lambda: chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles) 72 | elif provider == 'flash_bwd': 73 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True).backward(do), quantiles=quantiles) 74 | return results 75 | 76 | 77 | if __name__ == '__main__': 78 | benchmark.run(print_data=True) 79 | -------------------------------------------------------------------------------- /benchmarks/ops/benchmark_simple_gla_vs_mamba2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dependencies: 3 | $ pip install mamba-ssm==2.2.2 triton==2.3.1 4 | 5 | For correctness check, see: 6 | https://github.com/sustcsonglin/flash-linear-attention/pull/49 7 | """ 8 | 9 | import torch 10 | import triton 11 | 12 | from fla.ops.simple_gla import chunk_simple_gla 13 | 14 | from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined 15 | 16 | 17 | @triton.testing.perf_report( 18 | triton.testing.Benchmark( 19 | # argument names to use as an x-axis for the plot 20 | x_names=['T'], 21 | # different possible values for `x_name` 22 | x_vals=[64] + [128 * 2 ** i for i in range(0, 8)], 23 | # argument name whose value corresponds to a different line in the plot 24 | line_arg='provider', 25 | # possible 
values for `line_arg`` 26 | line_vals=["chunk_simple_gla", "mamba2_ssd"], 27 | # label name for the lines 28 | line_names=["chunk_simple_gla", "mamba2_ssd"], 29 | # line styles 30 | styles=[('blue', '-'), ('red', '-')], 31 | ylabel="Execution Time (ms)", # label name for the y-axis 32 | # name for the plot. Used also as a file name for saving the plot. 33 | plot_name="Performance", 34 | args={}, 35 | ) 36 | ) 37 | def benchmark(T, provider): 38 | # TODO: also add bwd pass benchmark 39 | device = 'cuda' 40 | dtype = torch.bfloat16 41 | B, H, D = 16, 8, 128 42 | # TODO: test more shapes 43 | # TODO: different values for D_V and D_QK 44 | # TODO: different values for H_Q and H_KV 45 | final_state = False # does not impact performance 46 | 47 | # initialize Mamba2-format inputs 48 | X_mamba = 0.1 * torch.randn(B, T, H, D, dtype=dtype, device=device) 49 | dt_mamba = torch.ones(B, T, H, dtype=dtype, device=device) 50 | A_mamba = -0.1 * torch.rand(H, dtype=dtype, device=device) 51 | B_mamba = 0.1 * torch.randn(B, T, H, D, dtype=dtype, device=device) 52 | C_mamba = 0.1 * torch.randn(B, T, H, D, dtype=dtype, device=device) 53 | 54 | quantiles = [0.5, 0.2, 0.8] 55 | if provider == 'chunk_simple_gla': 56 | # mapping inputs Mamba2 -> FLA 57 | # C, B, X: [B, T, H, D] -> [B, H, T, D] 58 | # g: [B, T, H] -> [B, H, T] 59 | q = C_mamba.transpose(1, 2).contiguous() 60 | k = B_mamba.transpose(1, 2).contiguous() 61 | v = X_mamba.transpose(1, 2).contiguous() 62 | g = (A_mamba * dt_mamba).transpose(1, 2).contiguous() 63 | # NOTE: whether to include the memory-copy cost of `contiguous()`? 64 | # this depends on the memory layout used by surrounding non-SSM layers 65 | 66 | results = triton.testing.do_bench( 67 | lambda: chunk_simple_gla( 68 | q, k, v, g, scale=1.0, output_final_state=final_state 69 | ), quantiles=quantiles 70 | ) 71 | 72 | elif provider == 'mamba2_ssd': 73 | # NOTE: `chunk_size` is configurable in mamba2 kernel 74 | # here sets to the same hard-coded `BT = 64` as in simple_gla kernel 75 | # TODO: benchmark different chunk sizes 76 | results = triton.testing.do_bench( 77 | lambda: mamba_chunk_scan_combined( 78 | X_mamba, dt_mamba, A_mamba, B_mamba, C_mamba, 79 | chunk_size=64, D=None, return_final_states=final_state 80 | ), 81 | quantiles=quantiles 82 | ) 83 | return results 84 | 85 | if __name__ == '__main__': 86 | benchmark.run(print_data=True, save_path='.') 87 | -------------------------------------------------------------------------------- /evals/harness.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | import sys 6 | import os 7 | import mom # noqa 8 | import fla 9 | from lm_eval.__main__ import cli_evaluate 10 | from lm_eval.api.registry import register_model 11 | from lm_eval.models.huggingface import HFLM 12 | 13 | 14 | @register_model('fla') 15 | class FlashLinearAttentionLMWrapper(HFLM): 16 | def __init__(self, **kwargs) -> FlashLinearAttentionLMWrapper: 17 | 18 | # TODO: provide options for doing inference with different kernels 19 | 20 | super().__init__(**kwargs) 21 | 22 | 23 | if __name__ == "__main__": 24 | cli_evaluate() 25 | -------------------------------------------------------------------------------- /lm-eval-harness/CITATION.bib: -------------------------------------------------------------------------------- 1 | @misc{eval-harness, 2 | author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, 
Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy}, 3 | title = {A framework for few-shot language model evaluation}, 4 | month = 12, 5 | year = 2023, 6 | publisher = {Zenodo}, 7 | version = {v0.4.0}, 8 | doi = {10.5281/zenodo.10256836}, 9 | url = {https://zenodo.org/records/10256836} 10 | } 11 | -------------------------------------------------------------------------------- /lm-eval-harness/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @haileyschoelkopf @lintangsutawika 2 | -------------------------------------------------------------------------------- /lm-eval-harness/LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 EleutherAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /lm-eval-harness/docs/README.md: -------------------------------------------------------------------------------- 1 | # Eval Harness Documentation 2 | 3 | Welcome to the docs for the LM Evaluation Harness! 4 | 5 | ## Table of Contents 6 | 7 | * To learn about the public interface of the library, as well as how to evaluate via the commandline or as integrated into an external library, see the [Interface](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/interface.md) 8 | * To learn how to add a new library, API, or model type to the library, as well as a quick explainer on the types of ways to evaluate an LM, see the [Model Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/model_guide.md). 9 | * For a crash course on adding new tasks to the library, see our [New Task Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/new_task_guide.md). 10 | * To learn more about pushing the limits of task configuration that the Eval Harness supports, see the [Task Configuration Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/task_guide.md). 
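As a quick orientation before diving into those guides, a typical command-line run looks roughly like the sketch below; the model, task, and output names are placeholders, and the flags mirror the launcher scripts shipped in this repository:

```bash
lm_eval \
    --model hf-auto \
    --model_args checkpoint_name=gpt2 \
    --tasks sciq \
    --device cuda:0 \
    --batch_size 8 \
    --log_samples \
    --output_path output/gpt2/sciq
```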
11 | -------------------------------------------------------------------------------- /lm-eval-harness/docs/decontamination.md: -------------------------------------------------------------------------------- 1 | # Decontamination 2 | 3 | ## Usage 4 | 5 | Simply add a "--decontamination_ngrams_path" when running \__main\__.py. The provided directory should contain 6 | the ngram files and info.json produced in "Pile Ngram Generation" further down. 7 | 8 | ```bash 9 | python -m lm_eval \ 10 | --model gpt2 \ 11 | --device 0 \ 12 | --tasks sciq \ 13 | --decontamination_ngrams_path path/containing/training/set/ngrams 14 | ``` 15 | 16 | ## Background 17 | Downstream evaluations test model generalization, and are less useful when test set data also exists in the training set, referred to as leakage or contamination. 18 | 19 | Filtering your training set against the test set is a good first step, however this isn't always possible, as in the case of a new benchmark or one that wasn't considered prior to model training. When training set filtering isn't possible, it is useful to measure the impact of test set leakage by detecting the contaminated test examples and producing a clean version of the benchmark. 20 | 21 | The basis for our decontamination procedure can be found in Appendix C of "Language Models are Few-Shot Learners". OpenAI defined a test document as contaminated if any N-gram overlap existed with any training document. They used a range of N values between 8 and 13 depending on dataset, while we just used 13 for simplicity. 22 | 23 | ## Implementation 24 | Contamination detection can be found in `lm_eval/decontaminate.py` with supporting code in `lm_eval/decontamination/`. 25 | 26 | decontaminate.py does the following: 27 | 1. Build dictionaries of all ngrams and their corresponding evaluation/document ids. 28 | 2. Scan through sorted files containing training set n-grams. 29 | 3. If a match is found, the corresponding evaluation/document combinations are marked as contaminated. 30 | 31 | `lm_eval/evaluator.py` can then produce a clean version of the benchmark by excluding the results of contaminated documents. For each metric, a clean version will be shown in the results with a "decontaminate" suffix. 32 | 33 | This is disabled by default for new tasks, to support decontamination on a task override the "should_decontaminate" and "doc_to_decontamination_query" methods. For more details see the [task guide](task_guide.md). 34 | 35 | ## Pile Ngram Generation 36 | The relevant scripts can be found in `scripts/clean_training_data`, which also import from 37 | `lm_eval/decontamination/` 38 | 39 | 1. git clone https://github.com/EleutherAI/lm-evaluation-harness.git 40 | 2. pip install -r requirements.txt 41 | 3. Download The Pile from [The Eye](https://the-eye.eu/public/AI/pile/train/) 42 | 4. Place pile files in "pile" directory under "lm-evaluation-harness" (or create a symlink) 43 | 5. Run generate_13_grams. 44 | 45 | ```bash 46 | export PYTHONHASHSEED=0 47 | python -m scripts/clean_training_data/generate_13_grams \ 48 | -dir path/to/working/directory \ 49 | -n 13 \ 50 | -buckets 500 51 | ``` 52 | 53 | Took approximately 4 days for us. We had the time to wait, but this could be scaled out by doing partial pile scans on multiple instances of this script and merging the relevant buckets. We fixed PYTHONHASHSEED to ensure reproducibility of bucket hashing in case you need to stop and start. 54 | 55 | 6. Sort the generated 13-grams. 
56 | ```bash 57 | python -m scripts/clean_training_data/sort_13_gram_buckets \ 58 | -dir path/to/working/directory/output 59 | ``` 60 | 61 | Took approximately 5 days for us. You could speed this up by spreading the files around to different machines and running the sort script before gathering them together. 62 | 63 | 7. Compress the sorted 13 grams files and place them together with info.json. 64 | 65 | This step only takes a few hours. 66 | 67 | ```bash 68 | python -m scripts/clean_training_data/compress_and_package \ 69 | -dir path/to/working/directory \ 70 | -output path/to/final/directory \ 71 | -procs 8 72 | ``` 73 | 74 | Congratulations, the final directory can now be passed to lm-evaulation-harness with the "--decontamination_ngrams_path" argument. 75 | -------------------------------------------------------------------------------- /lm-eval-harness/docs/img/fewshot_example_gpt3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSparseLLMs/MoM/2c3ae59f5c0b749f916189433b8dad5a3415dc08/lm-eval-harness/docs/img/fewshot_example_gpt3.png -------------------------------------------------------------------------------- /lm-eval-harness/ignore.txt: -------------------------------------------------------------------------------- 1 | ROUGE 2 | rouge 3 | nin 4 | maka 5 | mor 6 | te 7 | ond 8 | extraversion 9 | -------------------------------------------------------------------------------- /lm-eval-harness/launch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import List, Optional 4 | 5 | from lm_eval.__main__ import cli_evaluate 6 | 7 | 8 | from datetime import datetime 9 | import os 10 | import importlib.util 11 | 12 | import click 13 | from tqdm import tqdm 14 | 15 | 16 | MAX_WORKERS_PER_GPU = 1 17 | 18 | 19 | def execute_config( 20 | model: str, 21 | task: str, 22 | batch_size: int, 23 | limit: int, 24 | output_dir: str, 25 | num_fewshot: int, 26 | context_length: int = 1000, 27 | answer_length: int = 50, 28 | cutting_context: bool = False, 29 | decode_mode: str = "default", 30 | ): 31 | # Save the original standard output 32 | import subprocess 33 | 34 | output_dir = os.path.join(output_dir, model, task) 35 | 36 | args = [ 37 | "lm_eval", 38 | "--model", "based_lm", 39 | "--model_args", f"checkpoint_name={model}", 40 | "--tasks", task, 41 | "--device", "cuda:0", 42 | "--batch_size", str(batch_size), 43 | "--log_samples", 44 | "--output_path", output_dir, 45 | "--decode_mode", decode_mode, 46 | "--num_fewshot", str(num_fewshot), 47 | # , 48 | 49 | ] 50 | 51 | if cutting_context: 52 | args.extend(["--cutting_context"]) 53 | args.extend(["--context_length", str(context_length)]) 54 | args.extend(["--answer_length", str(answer_length)]) 55 | args.extend(["--context_key", "text"]) 56 | 57 | if 'squad' not in task: 58 | args.extend(["--answer_key", "key", "value"]) 59 | else: 60 | args.extend(["--answer_key", "value"]) 61 | 62 | if limit is not None: 63 | args.extend(["--limit", str(limit)]) 64 | 65 | subprocess.run(args) 66 | 67 | print(f"Decoded with mode: {decode_mode}") 68 | 69 | 70 | @click.command() 71 | @click.option("-m", "--model", type=str, multiple=True) 72 | @click.option("-t", "--task", type=str, multiple=True) 73 | @click.option("-p", "--parallelize", is_flag=True) 74 | @click.option("--gpus", default=None, type=str) 75 | @click.option("--batch-size", default=8, type=int) 76 | @click.option("--limit", default=None, type=int) 
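# note: `--limit` defaults to None, but main() below compares `limit < 0` directly,
# so in practice a numeric value is expected (a negative number means "no limit");
# launch_local.py guards the same check with `limit is not None`.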
77 | @click.option("--num_fewshot", default=0, type=int) 78 | @click.option("--context_length", default=1000, type=int) 79 | @click.option("--answer_length", default=50, type=int) 80 | @click.option("--output_dir", default="output", type=str) 81 | @click.option("--cutting_context", is_flag=True) 82 | @click.option("--decode_mode", default="default", type=str) 83 | def main( 84 | model: List[str], 85 | task: List[str], 86 | batch_size: int, 87 | limit: Optional[int], 88 | parallelize: bool, 89 | gpus: str, 90 | num_fewshot: int = 0, 91 | output_dir: str = "output", 92 | context_length: int = 1000, 93 | answer_length: int = 50, 94 | cutting_context: bool = False, 95 | decode_mode: str = 'default' 96 | ): 97 | 98 | if limit < 0: limit = None 99 | 100 | if gpus is not None: 101 | os.environ["CUDA_VISIBLE_DEVICES"] = gpus 102 | 103 | # Load the given Python file as a module 104 | configs = [ 105 | {"model": m, "task": t} for m in model for t in task 106 | ] 107 | 108 | use_ray = parallelize and len(configs) > 0 109 | if use_ray: 110 | import ray 111 | # ray was killing workers due to OOM, but it didn't seem to be necessary 112 | os.environ["RAY_memory_monitor_refresh_ms"] = "0" 113 | ray.init(ignore_reinit_error=True, log_to_driver=True) 114 | 115 | print(f"Running sweep with {len(configs)} configs") 116 | 117 | output_dir = f"{output_dir}/{datetime.now().strftime('%y-%m-%d_%H-%M')}" 118 | 119 | # Run each script in parallel using Ray 120 | if not use_ray: 121 | for config in configs: 122 | execute_config( 123 | **config, 124 | batch_size=batch_size, 125 | limit=limit, 126 | output_dir=output_dir, 127 | num_fewshot=num_fewshot, 128 | context_length=context_length, 129 | answer_length=answer_length, 130 | cutting_context=cutting_context, 131 | decode_mode=decode_mode, 132 | ) 133 | else: 134 | completed = 0 135 | total = len(configs) 136 | print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}") 137 | 138 | remote = ray.remote(num_gpus=(1 // MAX_WORKERS_PER_GPU))(execute_config) 139 | futures = [remote.remote( 140 | **config, batch_size=batch_size, limit=limit, output_dir=output_dir, num_fewshot=num_fewshot, 141 | context_length=context_length, answer_length=answer_length, cutting_context=cutting_context, 142 | decode_mode=decode_mode, 143 | ) for config in configs] 144 | 145 | while futures: 146 | complete, futures = ray.wait(futures) 147 | completed += len(complete) 148 | print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}") 149 | 150 | ray.shutdown() 151 | 152 | if __name__ == "__main__": 153 | main() 154 | -------------------------------------------------------------------------------- /lm-eval-harness/launch_hf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import List, Optional 4 | 5 | from lm_eval.__main__ import cli_evaluate 6 | 7 | 8 | from datetime import datetime 9 | import os 10 | import importlib.util 11 | 12 | import click 13 | from tqdm import tqdm 14 | 15 | 16 | MAX_WORKERS_PER_GPU = 1 17 | 18 | 19 | def execute_config( 20 | model: str, 21 | task: str, 22 | batch_size: int, 23 | limit: int, 24 | output_dir: str, 25 | context_length: int = 1000, 26 | cutting_context: bool = False, 27 | answer_length: int=50, 28 | ): 29 | # Save the original standard output 30 | import subprocess 31 | 32 | output_dir = os.path.join(output_dir, model, task) 33 | 34 | if 'mamba' in model.lower() and 'rw' not in model.lower(): model_name = "mamba_ssm" 35 | else: model_name = 
'hf-auto' 36 | 37 | args = [ 38 | "lm_eval", 39 | "--model", f"{model_name}", 40 | "--model_args", f"checkpoint_name={model}", 41 | "--tasks", task, 42 | "--device", "cuda:0", 43 | "--batch_size", str(batch_size), 44 | "--log_samples", 45 | "--output_path", output_dir 46 | ] 47 | 48 | if cutting_context: 49 | args.extend(["--cutting_context"]) 50 | args.extend(["--context_length", str(context_length)]) 51 | args.extend(["--answer_length", str(answer_length)]) 52 | args.extend(["--context_key", "text"]) 53 | 54 | if limit is not None: 55 | args.extend(["--limit", str(limit)]) 56 | 57 | subprocess.run(args) 58 | 59 | 60 | 61 | @click.command() 62 | @click.option("-m", "--model", type=str, multiple=True) 63 | @click.option("-t", "--task", type=str, multiple=True) 64 | @click.option("-p", "--parallelize", is_flag=True) 65 | @click.option("--gpus", default=None, type=str) 66 | @click.option("--batch-size", default=8, type=int) 67 | @click.option("--limit", default=None, type=int) 68 | @click.option("--context_length", default=1000, type=int) 69 | @click.option("--answer_length", default=50, type=int) 70 | @click.option("--cutting_context", is_flag=True) 71 | @click.option("--output_dir", default="output", type=str) 72 | def main( 73 | model: List[str], 74 | task: List[str], 75 | batch_size: int, 76 | limit: Optional[int], 77 | parallelize: bool, 78 | gpus: str, 79 | context_length: int, 80 | cutting_context: bool, 81 | answer_length: int, 82 | output_dir: str, 83 | ): 84 | if limit < 0: limit = None 85 | 86 | if gpus is not None: 87 | os.environ["CUDA_VISIBLE_DEVICES"] = gpus 88 | 89 | # Load the given Python file as a module 90 | configs = [ 91 | {"model": m, "task": t} for m in model for t in task 92 | ] 93 | 94 | use_ray = parallelize and len(configs) > 0 95 | if use_ray: 96 | import ray 97 | # ray was killing workers due to OOM, but it didn't seem to be necessary 98 | os.environ["RAY_memory_monitor_refresh_ms"] = "0" 99 | ray.init(ignore_reinit_error=True, log_to_driver=True) 100 | 101 | print(f"Running sweep with {len(configs)} configs") 102 | 103 | output_dir = f"{output_dir}/{datetime.now().strftime('%y-%m-%d_%H-%M')}" 104 | 105 | # Run each script in parallel using Ray 106 | if not use_ray: 107 | for config in configs: 108 | execute_config( 109 | **config, 110 | batch_size=batch_size, 111 | limit=limit, 112 | output_dir=output_dir, 113 | context_length=context_length, 114 | answer_length=answer_length, 115 | cutting_context=cutting_context 116 | ) 117 | else: 118 | completed = 0 119 | total = len(configs) 120 | print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}") 121 | 122 | remote = ray.remote(num_gpus=(1 // MAX_WORKERS_PER_GPU))(execute_config) 123 | futures = [remote.remote( 124 | **config, batch_size=batch_size, limit=limit, output_dir=output_dir, 125 | answer_length=answer_length, cutting_context=cutting_context, context_length=context_length, 126 | ) for config in configs] 127 | 128 | while futures: 129 | complete, futures = ray.wait(futures) 130 | completed += len(complete) 131 | print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}") 132 | 133 | ray.shutdown() 134 | 135 | 136 | 137 | if __name__ == "__main__": 138 | main() 139 | 140 | -------------------------------------------------------------------------------- /lm-eval-harness/launch_jrt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import List, Optional 4 | 5 | from lm_eval.__main__ import 
cli_evaluate 6 | 7 | 8 | from datetime import datetime 9 | import os 10 | import importlib.util 11 | 12 | import click 13 | from tqdm import tqdm 14 | 15 | 16 | MAX_WORKERS_PER_GPU = 1 17 | 18 | 19 | def execute_config( 20 | model: str, 21 | task: str, 22 | batch_size: int, 23 | limit: int, 24 | output_dir: str, 25 | num_fewshot: int, 26 | context_length: int = 1000, 27 | answer_length: int = 50, 28 | cutting_context: bool = False, 29 | decode_mode: str = "default", 30 | ): 31 | # Save the original standard output 32 | import subprocess 33 | 34 | output_dir = os.path.join(output_dir, model, task) 35 | 36 | args = [ 37 | "lm_eval", 38 | "--model", "jrt_lm", 39 | "--model_args", f"checkpoint_name={model}", 40 | "--tasks", task, 41 | "--device", "cuda:0", 42 | "--batch_size", str(batch_size), 43 | "--log_samples", 44 | "--output_path", output_dir, 45 | "--decode_mode", decode_mode, 46 | "--num_fewshot", str(num_fewshot), 47 | ] 48 | 49 | if cutting_context: 50 | args.extend(["--cutting_context"]) 51 | args.extend(["--context_length", str(context_length)]) 52 | args.extend(["--answer_length", str(answer_length)]) 53 | args.extend(["--context_key", "text"]) 54 | 55 | if 'squad' not in task: 56 | args.extend(["--answer_key", "key", "value"]) 57 | else: 58 | args.extend(["--answer_key", "value"]) 59 | 60 | if limit is not None: 61 | args.extend(["--limit", str(limit)]) 62 | 63 | subprocess.run(args) 64 | 65 | print(f"Decoded with mode: {decode_mode}") 66 | 67 | 68 | @click.command() 69 | @click.option("-m", "--model", type=str, multiple=True) 70 | @click.option("-t", "--task", type=str, multiple=True) 71 | @click.option("-p", "--parallelize", is_flag=True) 72 | @click.option("--gpus", default=None, type=str) 73 | @click.option("--batch-size", default=8, type=int) 74 | @click.option("--limit", default=None, type=int) 75 | @click.option("--num_fewshot", default=0, type=int) 76 | @click.option("--context_length", default=1000, type=int) 77 | @click.option("--answer_length", default=50, type=int) 78 | @click.option("--output_dir", default="output", type=str) 79 | @click.option("--cutting_context", is_flag=True) 80 | @click.option("--decode_mode", default="default", type=str) 81 | def main( 82 | model: List[str], 83 | task: List[str], 84 | batch_size: int, 85 | limit: Optional[int], 86 | parallelize: bool, 87 | gpus: str, 88 | num_fewshot: int = 0, 89 | output_dir: str = "output", 90 | context_length: int = 1000, 91 | answer_length: int = 50, 92 | cutting_context: bool = False, 93 | decode_mode: str = 'default' 94 | ): 95 | 96 | if limit < 0: limit = None 97 | 98 | if gpus is not None: 99 | os.environ["CUDA_VISIBLE_DEVICES"] = gpus 100 | 101 | # Load the given Python file as a module 102 | configs = [ 103 | {"model": m, "task": t} for m in model for t in task 104 | ] 105 | 106 | use_ray = parallelize and len(configs) > 0 107 | if use_ray: 108 | import ray 109 | # ray was killing workers due to OOM, but it didn't seem to be necessary 110 | os.environ["RAY_memory_monitor_refresh_ms"] = "0" 111 | ray.init(ignore_reinit_error=True, log_to_driver=True) 112 | 113 | print(f"Running sweep with {len(configs)} configs") 114 | 115 | output_dir = f"{output_dir}/{datetime.now().strftime('%y-%m-%d_%H-%M')}" 116 | 117 | # Run each script in parallel using Ray 118 | if not use_ray: 119 | for config in configs: 120 | execute_config( 121 | **config, 122 | batch_size=batch_size, 123 | limit=limit, 124 | output_dir=output_dir, 125 | num_fewshot=num_fewshot, 126 | context_length=context_length, 127 | 
answer_length=answer_length, 128 | cutting_context=cutting_context, 129 | decode_mode=decode_mode, 130 | ) 131 | else: 132 | completed = 0 133 | total = len(configs) 134 | print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}") 135 | 136 | remote = ray.remote(num_gpus=(1 // MAX_WORKERS_PER_GPU))(execute_config) 137 | futures = [remote.remote( 138 | **config, batch_size=batch_size, limit=limit, output_dir=output_dir, num_fewshot=num_fewshot, 139 | context_length=context_length, answer_length=answer_length, cutting_context=cutting_context, 140 | decode_mode=decode_mode, 141 | ) for config in configs] 142 | 143 | while futures: 144 | complete, futures = ray.wait(futures) 145 | completed += len(complete) 146 | print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}") 147 | 148 | ray.shutdown() 149 | 150 | if __name__ == "__main__": 151 | main() 152 | -------------------------------------------------------------------------------- /lm-eval-harness/launch_local.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import click 5 | 6 | from typing import List, Optional 7 | 8 | from lm_eval.__main__ import cli_evaluate 9 | from datetime import datetime 10 | import os 11 | import importlib.util 12 | 13 | from tqdm import tqdm 14 | 15 | MAX_WORKERS_PER_GPU = 1 16 | DEVICE = "cuda" 17 | torch.random.manual_seed(0) 18 | 19 | 20 | def execute_config( 21 | model: str, 22 | task: str, 23 | batch_size: int, 24 | limit: int, 25 | output_dir: str, 26 | num_fewshot: int, 27 | 28 | context_length: int = 1000, 29 | answer_length: int = 50, 30 | cutting_context: bool = False, 31 | decode_mode: str = 'default', 32 | ): 33 | # Save the original standard output 34 | import subprocess 35 | 36 | output_dir = os.path.join(output_dir, model, task) 37 | 38 | # pass flags to cli_evaluate() to override the defaults in argparse 39 | args = [ 40 | "lm_eval", 41 | "--model", "lm_eval_model", 42 | "--model_args", f" checkpoint_name={model}", 43 | "--task", task, 44 | "--device", "cuda:0", 45 | "--batch_size", str(batch_size), 46 | "--output_path", output_dir, 47 | "--num_fewshot", str(num_fewshot), 48 | "--decode_mode", decode_mode, 49 | 50 | "--log_samples", 51 | "--write_out", 52 | # "--output", f"{run_name}/", 53 | ] 54 | 55 | if cutting_context: 56 | args.extend(["--cutting_context"]) 57 | args.extend(["--context_length", str(context_length)]) 58 | args.extend(["--answer_length", str(answer_length)]) 59 | args.extend(["--context_key", "text"]) 60 | 61 | if limit is not None: 62 | args.extend(["--limit", str(limit)]) 63 | 64 | subprocess.run(args) 65 | 66 | print(f"Decoded with mode: {decode_mode}") 67 | 68 | 69 | @click.command() 70 | @click.option("-m", "--model", type=str, multiple=True) 71 | @click.option("-t", "--task", type=str, multiple=True) 72 | @click.option("-p", "--parallelize", is_flag=True) 73 | @click.option("--gpus", default=None, type=str) 74 | @click.option("--batch-size", default=8, type=int) 75 | @click.option("--limit", default=None, type=int) 76 | @click.option("--num_fewshot", default=0, type=int) 77 | @click.option("--context_length", default=1000, type=int) 78 | @click.option("--answer_length", default=50, type=int) 79 | @click.option("--output_dir", default="output", type=str) 80 | @click.option("--cutting_context", is_flag=True) 81 | @click.option("--decode_mode", default="default", type=str) 82 | def main( 83 | model: List[str], 84 | task: List[str], 85 | batch_size: int, 86 | limit: 
Optional[int], 87 | parallelize: bool, 88 | gpus: str, 89 | num_fewshot: int = 0, 90 | context_length: int = 1000, 91 | answer_length: int = 50, 92 | cutting_context: bool = False, 93 | output_dir: str = "output", 94 | decode_mode: str = 'default' 95 | ): 96 | 97 | if limit is not None and limit < 0: limit = None 98 | 99 | if gpus is not None: 100 | os.environ["CUDA_VISIBLE_DEVICES"] = gpus 101 | 102 | # Load the given Python file as a module 103 | configs = [ 104 | {"model": m, "task": t} for m in model for t in task 105 | ] 106 | 107 | use_ray = parallelize and len(configs) > 0 108 | if use_ray: 109 | import ray 110 | # ray was killing workers due to OOM, but it didn't seem to be necessary 111 | os.environ["RAY_memory_monitor_refresh_ms"] = "0" 112 | ray.init(ignore_reinit_error=True, log_to_driver=True) 113 | 114 | print(f"Running sweep with {len(configs)} configs") 115 | 116 | output_dir = f"{output_dir}/{datetime.now().strftime('%y-%m-%d_%H-%M')}" 117 | 118 | # Run each script in parallel using Ray 119 | if not use_ray: 120 | for config in configs: 121 | execute_config( 122 | **config, 123 | batch_size=batch_size, 124 | limit=limit, 125 | output_dir=output_dir, 126 | num_fewshot=num_fewshot, 127 | context_length=context_length, 128 | answer_length=answer_length, 129 | cutting_context=cutting_context, 130 | decode_mode=decode_mode 131 | ) 132 | else: 133 | completed = 0 134 | total = len(configs) 135 | print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}") 136 | 137 | remote = ray.remote(num_gpus=(1 // MAX_WORKERS_PER_GPU))(execute_config) 138 | futures = [remote.remote( 139 | **config, batch_size=batch_size, limit=limit, output_dir=output_dir, 140 | num_fewshot=num_fewshot, context_length=context_length, answer_length=answer_length, cutting_context=cutting_context, 141 | decode_mode=decode_mode 142 | ) for config in configs] 143 | 144 | while futures: 145 | complete, futures = ray.wait(futures) 146 | completed += len(complete) 147 | print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}") 148 | 149 | ray.shutdown() 150 | 151 | 152 | if __name__ == "__main__": 153 | main() 154 | 155 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluator import evaluate, simple_evaluate 2 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSparseLLMs/MoM/2c3ae59f5c0b749f916189433b8dad5a3415dc08/lm-eval-harness/lm_eval/api/__init__.py -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/api/filter.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Callable, Iterable, List, Union 4 | 5 | from lm_eval.api.instance import Instance 6 | 7 | 8 | class Filter(ABC): 9 | """ 10 | Filter classes operate on a per-task level. 11 | They take all model outputs (`instance.resps` for all `task.instances`) 12 | across all instances of a task, and perform operations. 13 | In a single run, one can configure any number of separate filters or lists of filters. 
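    A minimal sketch of how such a pipeline is typically assembled (the ensemble
    name, regex, and `task_instances` below are illustrative, not a prescribed
    configuration):

        from lm_eval.filters import build_filter_ensemble

        ensemble = build_filter_ensemble(
            "answer-extraction",
            [["regex", {"regex_pattern": r"#### (\-?[0-9\.\,]+)"}], ["take_first", None]],
        )
        # populates inst.filtered_resps["answer-extraction"] on each Instance
        ensemble.apply(task_instances)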
14 | 15 | """ 16 | 17 | def __init__(self, **kwargs) -> None: 18 | """ 19 | Can define custom behavior here, if an individual instantiation of a Filter class should have state. 20 | """ 21 | 22 | @abstractmethod 23 | def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable: 24 | """ 25 | Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects. 26 | Should return the list of (filtered) response lists *in the same order as they were input*, e.g. 27 | if pass in [, ] should return 28 | [, ] 29 | """ 30 | return resps 31 | 32 | 33 | @dataclass 34 | class FilterEnsemble: 35 | """ 36 | FilterEnsemble creates a pipeline applying multiple filters. 37 | Its intended usage is to stack multiple post-processing steps in order. 38 | `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each 39 | pipeline separately. 40 | """ 41 | 42 | name: str 43 | filters: List[Callable[[], Filter]] 44 | 45 | def apply(self, instances: List[Instance]) -> None: 46 | resps, docs = zip(*((inst.resps, inst.doc) for inst in instances)) 47 | resps, docs = list(resps), list(docs) 48 | 49 | for f in self.filters: 50 | # apply filters in sequence 51 | resps = f().apply(resps, docs) 52 | 53 | # add the end results after filtering to filtered_requests of their respective source instances. 54 | # has key `self.name`: each FilterEnsemble applied in a given run should use a different name. 55 | for inst, resp in zip(instances, resps): 56 | inst.filtered_resps[self.name] = resp 57 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/api/instance.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Literal, Tuple 3 | 4 | 5 | @dataclass 6 | class Instance: 7 | request_type: Literal[ 8 | "loglikelihood", 9 | "loglikelihood_rolling", 10 | "generate_until", 11 | "multiple_choice", 12 | ] 13 | doc: dict 14 | arguments: tuple 15 | idx: int 16 | metadata: Tuple[str, int, int] = field( 17 | default_factory=lambda: (None, None, None) 18 | ) # TODO: better typehints here 19 | resps: list = field(default_factory=list) 20 | filtered_resps: dict = field(default_factory=dict) 21 | 22 | # initialized after init 23 | task_name: str = None 24 | doc_id: str = None 25 | repeats: str = None 26 | 27 | def __post_init__(self) -> None: 28 | # unpack metadata field 29 | self.task_name, self.doc_id, self.repeats = self.metadata 30 | 31 | @property 32 | def args(self): 33 | """ 34 | Returns (string,) where `string` is the string to calculate loglikelihood over 35 | """ 36 | return ( 37 | self.arguments if isinstance(self.arguments, tuple) else (self.arguments,) 38 | ) 39 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/api/samplers.py: -------------------------------------------------------------------------------- 1 | class ContextSampler: 2 | def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None: 3 | self.rnd = rnd 4 | assert self.rnd, "must pass rnd to FewShotSampler!" 
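        # `rnd` is expected to be a random.Random (or compatible) instance, typically
        # seeded by the harness; `sample()` below draws few-shot examples via
        # `rnd.sample(self.docs, n)`, so passing it in keeps selection reproducible.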
5 | 6 | self.task = task 7 | self.config = task._config 8 | 9 | self.target_delimiter = self.config.target_delimiter 10 | self.fewshot_delimiter = self.config.fewshot_delimiter 11 | 12 | self.doc_to_text = self.task.doc_to_text 13 | self.doc_to_target = self.task.doc_to_target 14 | self.doc_to_choice = self.task.doc_to_choice 15 | 16 | self.docs = docs # HF dataset split, provided by task._fewshot_docs() 17 | if fewshot_indices: # subset few-shot docs from 18 | self.docs = self.docs.select(fewshot_indices) 19 | 20 | def get_context(self, doc, num_fewshot): 21 | # draw an extra fewshot sample if using same split as evaluating on 22 | n_samples = ( 23 | num_fewshot + 1 24 | if self.config.fewshot_split == self.config.test_split 25 | else num_fewshot 26 | ) 27 | 28 | # draw `n_samples` docs from fewshot_docs 29 | fewshotex = self.sample(n_samples) 30 | 31 | # get rid of the doc that's the one we're evaluating, if it's in the fewshot 32 | # TODO: should we just stop people from using fewshot from same split as evaluating? 33 | selected_docs = [x for x in fewshotex if x != doc][:num_fewshot] 34 | 35 | labeled_examples = ( 36 | self.fewshot_delimiter.join( 37 | [ 38 | # TODO: is separating doc_to_text and doc_to_target by one space always desired? 39 | ( 40 | self.doc_to_text(doc) 41 | if ( 42 | self.config.doc_to_choice is None 43 | or isinstance(self.doc_to_text(doc), str) 44 | ) 45 | else self.doc_to_choice(doc)[self.doc_to_text(doc)] 46 | ) 47 | + self.target_delimiter 48 | + ( 49 | str(self.doc_to_target(doc)[0]) 50 | if isinstance(self.doc_to_target(doc), list) 51 | else self.doc_to_target(doc) 52 | if ( 53 | self.config.doc_to_choice is None 54 | or isinstance(self.doc_to_target(doc), str) 55 | ) 56 | else str(self.doc_to_choice(doc)[self.doc_to_target(doc)]) 57 | ) 58 | for doc in selected_docs 59 | ] 60 | ) 61 | + self.fewshot_delimiter 62 | ) 63 | 64 | return labeled_examples 65 | 66 | def sample(self, n): 67 | """ 68 | Draw `n` samples from our fewshot docs. This method should be overridden by subclasses. 69 | """ 70 | 71 | return self.rnd.sample(self.docs, n) 72 | 73 | 74 | class FirstNSampler(ContextSampler): 75 | def sample(self, n) -> None: 76 | """ 77 | Draw the first `n` samples in order from the specified split. 78 | Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU. 79 | """ 80 | assert ( 81 | n <= len(self.docs) 82 | ), f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available." 83 | return self.docs[:n] 84 | 85 | 86 | class BalancedSampler(ContextSampler): 87 | def sample(self, n) -> None: 88 | """ 89 | TODO: this should return approximately class-balanced samples from our fewshot examples. 90 | TODO: what order should they be in? maybe random? 91 | """ 92 | 93 | pass 94 | 95 | 96 | class ManualSampler(ContextSampler): 97 | def sample(self, n) -> None: 98 | """ """ 99 | pass 100 | 101 | 102 | SAMPLER_REGISTRY = { 103 | "default": ContextSampler, 104 | "first_n": FirstNSampler, 105 | } 106 | 107 | 108 | def get_sampler(name): 109 | try: 110 | return SAMPLER_REGISTRY[name] 111 | except KeyError: 112 | raise ValueError( 113 | f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! 
Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}" 114 | ) 115 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/decontamination/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSparseLLMs/MoM/2c3ae59f5c0b749f916189433b8dad5a3415dc08/lm-eval-harness/lm_eval/decontamination/__init__.py -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/filters/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | from functools import partial 3 | 4 | from lm_eval.api.filter import FilterEnsemble 5 | from . import selection 6 | from . import extraction 7 | from . import transformation 8 | 9 | 10 | FILTER_REGISTRY = { 11 | "take_first": selection.TakeFirstFilter, 12 | "regex": extraction.RegexFilter, 13 | "majority_vote": selection.MajorityVoteFilter, 14 | "take_first_k": selection.TakeKFilter, 15 | "remove_whitespace": extraction.WhitespaceFilter, 16 | "lowercase": transformation.LowercaseFilter, 17 | "uppercase": transformation.UppercaseFilter, 18 | "map": transformation.MapFilter, 19 | # TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function 20 | # that takes an input and returns a scalar and then should select the max reward, 21 | # or should implement different filters for different ways of handling a reward model's inference. 22 | # "arg_max": selection.ArgMaxFilter, 23 | } 24 | 25 | 26 | def get_filter(filter_name: str) -> Union[type, str]: 27 | if filter_name in FILTER_REGISTRY: 28 | return FILTER_REGISTRY[filter_name] 29 | else: 30 | return filter_name 31 | 32 | 33 | def build_filter_ensemble( 34 | filter_name: str, components: List[List[str]] 35 | ) -> FilterEnsemble: 36 | """ 37 | Create a filtering pipeline. 38 | """ 39 | filters = [] 40 | for function, kwargs in components: 41 | if kwargs is None: 42 | kwargs = {} 43 | # create a filter given its name in the registry 44 | f = partial(get_filter(function), **kwargs) 45 | # add the filter as a pipeline step 46 | filters.append(f) 47 | 48 | return FilterEnsemble(name=filter_name, filters=filters) 49 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/filters/decontamination.py: -------------------------------------------------------------------------------- 1 | from lm_eval.api.filter import Filter 2 | 3 | 4 | class DecontaminationFilter(Filter): 5 | 6 | """ 7 | A filter which evaluates 8 | """ 9 | 10 | name = "track_decontamination" 11 | 12 | def __init__(self, path) -> None: 13 | """ 14 | 15 | TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path"). 
16 | should further cache result on a given (task_name, doc_id) 17 | """ 18 | self._decontam_results = None 19 | 20 | def apply(self, resps, docs) -> None: 21 | """ 22 | Return {"no_contamination", "only_contamination"} keys for the 2 different subsets 23 | """ 24 | pass 25 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/filters/extraction.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from lm_eval.api.filter import Filter 4 | 5 | 6 | class RegexFilter(Filter): 7 | """ """ 8 | 9 | def __init__( 10 | self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", fallback: str = "[invalid]" 11 | ) -> None: 12 | """ 13 | pass a string `regex` to run `re.compile(r"regex")` on. 14 | `fallback` defines the output returned if no matches for the regex are located. 15 | """ 16 | self.regex_pattern = regex_pattern 17 | self.regex = re.compile(regex_pattern) 18 | self.fallback = fallback 19 | 20 | def apply(self, resps, docs): 21 | # here, we assume we have a list, in which each element is 22 | # a list of model responses for some particular input/target pair. 23 | # so we process each of these (same input/target response sets) 24 | # independently (and keep them a list.) 25 | def filter_set(inst): 26 | filtered = [] 27 | for resp in inst: 28 | match = self.regex.search(resp) 29 | if match: 30 | match = match.group(1).strip() 31 | else: 32 | match = self.fallback 33 | filtered.append(match) 34 | return filtered 35 | 36 | # print(resps) 37 | filtered_resps = list(map(lambda x: filter_set(x), resps)) 38 | # print(filtered_resps) 39 | 40 | return filtered_resps 41 | 42 | 43 | class WhitespaceFilter(Filter): 44 | """ """ 45 | 46 | def __init__(self) -> None: 47 | pass 48 | 49 | def apply(self, resps, docs): 50 | def filter_set(inst): 51 | filtered_resp = [] 52 | for resp in inst: 53 | if resp.startswith(" "): 54 | resp = resp[1:] 55 | 56 | filtered_resp.append(resp) 57 | 58 | return filtered_resp 59 | 60 | filtered_resps = [filter_set(resp) for resp in resps] 61 | 62 | return filtered_resps 63 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/filters/selection.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | from lm_eval.api.filter import Filter 4 | 5 | 6 | class TakeFirstFilter(Filter): 7 | def __init__(self) -> None: 8 | """ 9 | Can define custom behavior here, if an individual instantiation of a Filter class should have state. 10 | """ 11 | 12 | def apply(self, resps, docs): 13 | """ 14 | Assuming each entry of `resps` is a list of model responses, we discard all but the first response. 15 | """ 16 | return map(lambda r: r[0], resps) 17 | 18 | 19 | class TakeKFilter(Filter): 20 | def __init__(self, **kwargs) -> None: 21 | self.k = kwargs.pop("k") 22 | 23 | super().__init__(**kwargs) 24 | 25 | def apply(self, resps, docs): 26 | # need resp to be subscriptable to check below 27 | resps = list(resps) 28 | # check we have at least k responses per doc, else we can't take the first k 29 | assert ( 30 | len(resps[0]) >= self.k 31 | ), f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ." 
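        # keep only the first k responses for each document, preserving their original order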
32 | return map(lambda r: r[: self.k], resps) 33 | 34 | 35 | class MajorityVoteFilter(Filter): 36 | def __init__(self) -> None: 37 | """ 38 | Can define custom behavior here, if an individual instantiation of a Filter class should have state. 39 | """ 40 | 41 | def apply(self, resps, docs): 42 | """ 43 | Each entry of `resps` is a list of model responses. 44 | We select the response that occurs most frequently in each entry of `resps`. 45 | """ 46 | 47 | def select_majority(resp): 48 | counts = Counter(resp) 49 | vote = counts.most_common(1)[0][0] 50 | return vote 51 | 52 | return map(lambda r: [select_majority(r)], resps) 53 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/filters/transformation.py: -------------------------------------------------------------------------------- 1 | from lm_eval.api.filter import Filter 2 | 3 | 4 | class LowercaseFilter(Filter): 5 | def __init__(self) -> None: 6 | pass 7 | 8 | def apply(self, resps, docs): 9 | def filter_set(inst): 10 | return [resp.lower() for resp in inst] 11 | 12 | return [filter_set(resp) for resp in resps] 13 | 14 | 15 | class UppercaseFilter(Filter): 16 | def __init__(self) -> None: 17 | pass 18 | 19 | def apply(self, resps, docs): 20 | def filter_set(inst): 21 | return [resp.upper() for resp in inst] 22 | 23 | return [filter_set(resp) for resp in resps] 24 | 25 | 26 | class MapFilter(Filter): 27 | def __init__(self, mapping_dict: dict = None, default_value=None) -> None: 28 | """ 29 | Initializes the MapFilter with a given mapping dictionary and default value. 30 | 31 | Args: 32 | - mapping_dict (dict): A dictionary containing the key-value mappings. 33 | Default is an empty dictionary. 34 | - default_value (Any): The value to be returned when a key is not found in the mapping_dict. 35 | Default is None. 36 | 37 | Example: 38 | mapper = MapFilter({'A': 1, 'B': 2}, default_value=0) 39 | """ 40 | if mapping_dict is None: 41 | mapping_dict = {} 42 | assert isinstance( 43 | mapping_dict, dict 44 | ), "Provided mapping_dict is not a dictionary" 45 | self.mapping_dict = mapping_dict 46 | self.default_value = default_value 47 | 48 | def apply(self, resps, docs): 49 | def filter_set(inst): 50 | return [self.mapping_dict.get(resp, self.default_value) for resp in inst] 51 | 52 | return [filter_set(resp) for resp in resps] 53 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import huggingface 2 | from . import openai_completions 3 | from . import textsynth 4 | from . import dummy 5 | from . import anthropic_llms 6 | from . import gguf 7 | from . import vllm_causallms 8 | from . import mamba_lm 9 | from . import based_lm 10 | from . import optimum_lm 11 | from . import neuron_optimum 12 | from . import local_lm 13 | from . 
import jrt_lm 14 | # TODO: implement __all__ 15 | 16 | 17 | import os 18 | 19 | try: 20 | # enabling faster model download 21 | import hf_transfer 22 | 23 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" 24 | except ImportError: 25 | pass 26 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/models/based_lm.py: -------------------------------------------------------------------------------- 1 | import re 2 | from transformers import AutoTokenizer 3 | import torch 4 | 5 | # from based.utils.hf import load_config_hf 6 | import json 7 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME 8 | from transformers.utils.hub import cached_file 9 | 10 | from lm_eval.api.registry import register_model 11 | from lm_eval.models.huggingface import HFLM 12 | 13 | def load_config_hf(model_name): 14 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) 15 | return json.load(open(resolved_archive_file)) 16 | 17 | @register_model("based_lm") 18 | class BasedLMWrapper(HFLM): 19 | def __init__( 20 | self, 21 | checkpoint_name: str='hazyresearch/based-1.3b', 22 | arch: str=None, 23 | device: str = "cuda", 24 | **kwargs 25 | ) -> None: 26 | 27 | if arch is None: 28 | arch = checkpoint_name.split("/")[1].split("-")[0] 29 | 30 | assert arch in ['based', 'mamba', 'attn'], print("`arch` must be one of 'based', 'mamba', or 'attn'") 31 | 32 | if "backend" in kwargs: 33 | # based currently only supports causal models 34 | assert kwargs["backend"] == "causal" 35 | 36 | self.checkpoint_name = checkpoint_name 37 | 38 | if arch == "based": 39 | from based.models.gpt import GPTLMHeadModel 40 | model = GPTLMHeadModel.from_pretrained_hf(pretrained_model_name=self.checkpoint_name, device=device) 41 | elif arch == "mamba": 42 | from based.models.mamba import MambaLMHeadModel 43 | model = MambaLMHeadModel.from_pretrained_hf(pretrained_model_name=self.checkpoint_name, device=device) 44 | elif arch == "attn": 45 | from based.models.transformer.gpt import GPTLMHeadModel, GPT2Config, state_dict_from_pretrained; # TODO: construct a loading function 46 | config_data = load_config_hf(self.checkpoint_name) 47 | config = GPT2Config(**config_data) 48 | try: 49 | model = GPTLMHeadModel(config=config, device=device, dtype=torch.float16, multiple_of=256) 50 | state_dict = state_dict_from_pretrained(self.checkpoint_name, dtype=torch.float16) 51 | # remove the 'model.' prefix from the keys 52 | state_dict = {re.sub("^model\.", "", k): v for k, v in state_dict.items()} 53 | # remove Unexpected key(s) in state_dict: "train_metrics.num-tokens.count", "val_metrics.num-tokens.count", "test_metrics.num-tokens.count". 
from the state_dict 54 | state_dict = {k: v for k, v in state_dict.items() if "metrics" not in k} 55 | model.load_state_dict(state_dict) 56 | except: 57 | model = GPTLMHeadModel(config=config, device=device, dtype=torch.float16, multiple_of=128) 58 | state_dict = state_dict_from_pretrained(self.checkpoint_name, dtype=torch.float16) 59 | state_dict = {re.sub("^model\.", "", k): v for k, v in state_dict.items()} 60 | state_dict = {k: v for k, v in state_dict.items() if "metrics" not in k} 61 | model.load_state_dict(state_dict) 62 | else: 63 | raise ValueError(f"Unsupported model {arch}") 64 | 65 | tokenizer_name = kwargs.get("tokenizer", "gpt2") 66 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 67 | tokenizer.model_max_length = 2048 68 | 69 | model.device = device 70 | 71 | super().__init__( 72 | pretrained=model, 73 | # set appropriate defaults for tokenizer, max length, etc 74 | backend=kwargs.get("backend", "causal"), 75 | max_length=kwargs.get("max_length", 2048), 76 | tokenizer=tokenizer, 77 | device=device, 78 | **kwargs, 79 | ) -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/models/dummy.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from lm_eval.api.model import LM 4 | from lm_eval.api.registry import register_model 5 | 6 | 7 | @register_model("dummy") 8 | class DummyLM(LM): 9 | def __init__(self) -> None: 10 | super().__init__() 11 | 12 | @classmethod 13 | def create_from_arg_string(cls, arg_string, additional_config=None): 14 | return cls() 15 | 16 | def loglikelihood(self, requests): 17 | res = [] 18 | 19 | for _ in requests: 20 | res.append((-random.random(), False)) 21 | 22 | return res 23 | 24 | def generate_until(self, requests): 25 | res = [] 26 | 27 | for ctx, _ in requests: 28 | res.append("lol") 29 | assert ctx.strip() != "" 30 | 31 | return res 32 | 33 | def loglikelihood_rolling(self, requests): 34 | res = [] 35 | 36 | for _ in requests: 37 | res.append(-random.random()) 38 | 39 | return res 40 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/models/gguf.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import requests 5 | from requests.exceptions import RequestException 6 | from tqdm import tqdm 7 | 8 | from lm_eval.api.model import LM 9 | from lm_eval.api.registry import register_model 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def get_result(logprobs, context_length): 16 | is_greedy = True 17 | offsets = logprobs["text_offset"] 18 | tokens = logprobs["tokens"] 19 | tokens_logprobs = logprobs["token_logprobs"] 20 | 21 | idx = 0 22 | while offsets[idx] < context_length: 23 | idx += 1 24 | continuation_logprobs = sum(tokens_logprobs[idx:-1]) 25 | for i in range(idx, len(tokens)): 26 | token = tokens[i] 27 | top_tokens = logprobs["top_logprobs"][i] 28 | top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x]) 29 | if top_token != token: 30 | is_greedy = False 31 | break 32 | 33 | return continuation_logprobs, is_greedy 34 | 35 | 36 | @register_model("gguf", "ggml") 37 | class GGUFLM(LM): 38 | def __init__(self, base_url=None, max_length=2048, **kwargs): 39 | super().__init__() 40 | self.base_url = base_url 41 | assert self.base_url, "must pass `base_url` to use GGUF LM!" 
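        # defaults used by gguf_completion() below: request the top-10 token logprobs
        # and greedy (temperature 0.0) decoding from {base_url}/v1/completions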
42 | self.logprobs = 10 43 | self.temperature = 0.0 44 | self.max_length = max_length 45 | 46 | def gguf_completion( 47 | self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs 48 | ): 49 | for _ in range(retries): 50 | try: 51 | prompt = context 52 | request = { 53 | "prompt": prompt, 54 | "logprobs": self.logprobs, 55 | "temperature": self.temperature, 56 | } 57 | if continuation: 58 | prompt += continuation 59 | request.update({"prompt": prompt, "max_tokens": 1, "echo": True}) 60 | if stop is not None: 61 | request["stop"] = stop 62 | response = requests.post( 63 | f"{self.base_url}/v1/completions", json=request 64 | ) 65 | response.raise_for_status() 66 | return response.json() 67 | except RequestException as e: 68 | logger.error(f"RequestException: {e}") 69 | time.sleep(delay) # wait before retrying 70 | else: 71 | raise Exception(f"Failed to get a valid response after {retries} retries.") 72 | 73 | def loglikelihood(self, requests): 74 | if not requests: 75 | return [] 76 | res = [] 77 | for context, continuation in tqdm([req.args for req in requests]): 78 | response = self.gguf_completion(context=context, continuation=continuation) 79 | if response and "choices" in response and response["choices"]: 80 | choice = response["choices"][0] 81 | logprobs = choice.get("logprobs") 82 | if ( 83 | logprobs 84 | and "token_logprobs" in logprobs 85 | and logprobs["token_logprobs"] 86 | ): 87 | logprob, is_greedy = get_result(logprobs, len(context)) 88 | res.append((logprob, is_greedy)) 89 | else: 90 | logger.warning( 91 | "Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list." 92 | ) 93 | else: 94 | logger.error( 95 | f"Invalid response for loglikelihood. Response: {response}" 96 | ) 97 | assert False 98 | return res 99 | 100 | def generate_until(self, requests): 101 | if not requests: 102 | return [] 103 | 104 | res = [] 105 | for request in tqdm([req.args for req in requests]): 106 | inp = request[0] 107 | request_args = request[1] 108 | until = request_args.get("until", [""]) 109 | response = self.gguf_completion(context=inp, stop=until) 110 | if response and "choices" in response and response["choices"]: 111 | choice = response["choices"][0] 112 | if "text" in choice: 113 | generated_text = choice["text"].strip() 114 | res.append(generated_text) 115 | else: 116 | logger.error( 117 | f"Invalid response for greedy_until. Response: {response}" 118 | ) 119 | res.append(None) # Add default value in case of error 120 | else: 121 | logger.error(f"Invalid response for greedy_until. 
Response: {response}") 122 | res.append(None) # Add default value in case of error 123 | return res 124 | 125 | def loglikelihood_rolling(self, requests): 126 | raise NotImplementedError( 127 | "loglikelihood_rolling not yet supported for GGUF models" 128 | ) 129 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/models/jrt_lm.py: -------------------------------------------------------------------------------- 1 | import re 2 | from transformers import AutoTokenizer 3 | import torch 4 | 5 | import json 6 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME 7 | from transformers.utils.hub import cached_file 8 | 9 | from lm_eval.api.registry import register_model 10 | from lm_eval.models.huggingface import HFLM 11 | 12 | def load_config_hf(model_name): 13 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) 14 | return json.load(open(resolved_archive_file)) 15 | 16 | @register_model("jrt_lm") 17 | class JRTLMWrapper(HFLM): 18 | def __init__( 19 | self, 20 | checkpoint_name: str='hazyresearch/based-1.3b', 21 | arch: str=None, 22 | device: str = "cuda", 23 | **kwargs 24 | ) -> None: 25 | 26 | if arch is None: 27 | arch = checkpoint_name.split("/")[1].split("-")[0] 28 | 29 | assert arch in ['JRT', 'based', 'mamba', 'attn'], print("`arch` must be one of 'JRT', 'based', 'mamba', or 'attn'") 30 | 31 | if "backend" in kwargs: 32 | # based currently only supports causal models 33 | assert kwargs["backend"] == "causal" 34 | 35 | self.checkpoint_name = checkpoint_name 36 | 37 | if arch == "based": 38 | from train.src.models.gpt import GPTLMHeadModel 39 | model = GPTLMHeadModel.from_pretrained_hf(pretrained_model_name=self.checkpoint_name, device=device) 40 | elif arch == "JRT": 41 | from train.src.models.gpt import GPTLMHeadModel 42 | model = GPTLMHeadModel.from_pretrained_hf(pretrained_model_name=self.checkpoint_name, device=device) 43 | elif arch == "mamba": 44 | from based.models.mamba import MambaLMHeadModel 45 | model = MambaLMHeadModel.from_pretrained_hf(pretrained_model_name=self.checkpoint_name, device=device) 46 | elif arch == "attn": 47 | from based.models.transformer.gpt import GPTLMHeadModel, GPT2Config, state_dict_from_pretrained; # TODO: construct a loading function 48 | config_data = load_config_hf(self.checkpoint_name) 49 | config = GPT2Config(**config_data) 50 | try: 51 | model = GPTLMHeadModel(config=config, device=device, dtype=torch.float16, multiple_of=256) 52 | state_dict = state_dict_from_pretrained(self.checkpoint_name, dtype=torch.float16) 53 | # remove the 'model.' prefix from the keys 54 | state_dict = {re.sub("^model\.", "", k): v for k, v in state_dict.items()} 55 | # remove Unexpected key(s) in state_dict: "train_metrics.num-tokens.count", "val_metrics.num-tokens.count", "test_metrics.num-tokens.count". 
from the state_dict 56 | state_dict = {k: v for k, v in state_dict.items() if "metrics" not in k} 57 | model.load_state_dict(state_dict) 58 | except: 59 | model = GPTLMHeadModel(config=config, device=device, dtype=torch.float16, multiple_of=128) 60 | state_dict = state_dict_from_pretrained(self.checkpoint_name, dtype=torch.float16) 61 | state_dict = {re.sub("^model\.", "", k): v for k, v in state_dict.items()} 62 | state_dict = {k: v for k, v in state_dict.items() if "metrics" not in k} 63 | model.load_state_dict(state_dict) 64 | else: 65 | raise ValueError(f"Unsupported model {arch}") 66 | 67 | tokenizer_name = kwargs.get("tokenizer", "gpt2") 68 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 69 | tokenizer.model_max_length = 2048 70 | 71 | model.device = device 72 | 73 | super().__init__( 74 | pretrained=model, 75 | # set appropriate defaults for tokenizer, max length, etc 76 | backend=kwargs.get("backend", "causal"), 77 | max_length=kwargs.get("max_length", 2048), 78 | tokenizer=tokenizer, 79 | device=device, 80 | **kwargs, 81 | ) -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/models/local_lm.py: -------------------------------------------------------------------------------- 1 | from .local_utils.loading import load_model, load_tokenizer, load_hf_model 2 | from lm_eval.api.registry import register_model 3 | from lm_eval.models.huggingface import HFLM 4 | 5 | 6 | @register_model("lm_eval_model") 7 | class LMWrapper(HFLM): 8 | def __init__( 9 | self, 10 | checkpoint_name: str, 11 | max_length: int = 2048, 12 | device: str = "cuda", 13 | **kwargs 14 | ) -> None: 15 | 16 | is_hf=not checkpoint_name.startswith("hazy-research") 17 | tokenizer = load_tokenizer(checkpoint_name, is_hf=is_hf) 18 | if is_hf: 19 | model = load_hf_model(checkpoint_name) 20 | else: 21 | model = load_model(checkpoint_name, device=device) 22 | # model.device = device 23 | # import torch 24 | # model.to(device, dtype=torch.float32) 25 | model.to(device) 26 | 27 | super().__init__( 28 | pretrained=model, 29 | backend="causal", 30 | max_length=max_length, 31 | tokenizer=tokenizer, 32 | device=device, 33 | **kwargs, 34 | ) 35 | 36 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/models/local_utils/loading.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Union 4 | import os 5 | import hydra 6 | import sys 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from .jrt_utils import import_object, load_config 10 | 11 | ### This code loads models that we trained ### 12 | def load_model( 13 | run_id: str, 14 | device: Union[int, str] = None, 15 | config: any=None, 16 | ) -> nn.Module: 17 | """ 18 | Load a model from a wandb run ID. 
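The run's Hydra config is fetched from wandb, the model class named in `config["model"]["_target_"]` is instantiated, and the Lightning checkpoint found under the configured checkpoint directory is loaded into it (with the "model." prefix stripped from the state-dict keys).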
19 | Parameters: 20 | run_id (str): A full wandb run id like "hazy-research/attention/159o6asi" 21 | """ 22 | 23 | # 1: Get configuration from wandb 24 | config = load_config(run_id) 25 | path = config["callbacks"]["model_checkpoint"]["dirpath"] 26 | 27 | if config["model"].get("_instantiate_config_", True): 28 | 29 | # SE (01/29): models were trained on flash_attn==2.3.6 30 | # a newer version sets this parameter to 256 by default, so to make it 31 | # compatible while still allowing for upgrades of flash attention, 32 | # we pin it to 128 here 33 | if config["model"]["_target_"] == "flash_attn.models.gpt.GPTLMHeadModel": 34 | config["model"]["config"]["mlp_multiple_of"] = 128 35 | 36 | model_config = hydra.utils.instantiate( 37 | config["model"]["config"], _recursive_=False, _convert_="object" 38 | ) 39 | cls = import_object(config["model"]["_target_"]) 40 | model = cls(model_config).to(device=device) 41 | else: 42 | # SE: need this alternate form for models that accept kwargs, not a config object 43 | model_config = config["model"].pop("config") 44 | model = hydra.utils.instantiate(config["model"], **model_config, _recursive_=False) 45 | 46 | path = path.replace( 47 | "/var/cr05_data/sim_data/checkpoints/", # old machine 48 | '/home/simarora/based-checkpoints/checkpoints/' 49 | ) 50 | 51 | try: 52 | assert os.path.exists(path), f"Path {path} does not exist" 53 | ckpt = torch.load(os.path.join(path, "last.ckpt"), map_location=torch.device(device)) 54 | except Exception: 55 | paths = os.listdir(path) 56 | paths = [p for p in paths if ".ckpt" in p] 57 | print(f'Loading model from {paths[0]}') 58 | ckpt = torch.load(os.path.join(path, paths[0]), map_location=torch.device(device)) 59 | 60 | # 3: Load model 61 | # load the state dict, but remove the "model."
prefix and all other keys from the 62 | # the PyTorch Lightning module that are not in the actual model 63 | model.load_state_dict({ 64 | k[len("model."):]: v 65 | for k, v in ckpt["state_dict"].items() 66 | if k.startswith("model.") 67 | }) 68 | 69 | model = model.to(device=device) 70 | return model 71 | 72 | 73 | def load_hf_model(model_name: str, device:str = 'cuda') -> nn.Module: 74 | if "mamba" in model_name: 75 | 76 | # SA: can't pass in device here https://github.com/pytorch/pytorch/issues/10622 77 | model = MambaLMHeadModel.from_pretrained(model_name, device=device, dtype=torch.float16) 78 | 79 | else: 80 | if "Mixtral" in model_name: 81 | model = AutoModelForCausalLM.from_pretrained( 82 | model_name, trust_remote_code=True, use_flash_attention_2=True, 83 | # load_in_8bit=True, 84 | torch_dtype=torch.bfloat16, device_map="auto" 85 | ) 86 | else: 87 | model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, token="your token here") 88 | model.to(device) 89 | 90 | 91 | try: model.device = device 92 | except: pass 93 | model.eval() 94 | return model 95 | 96 | 97 | def load_tokenizer(model_name: str, is_hf: bool=False) -> nn.Module: 98 | if not is_hf: 99 | tokenizer = AutoTokenizer.from_pretrained("gpt2") 100 | tokenizer.model_max_length = 2048 101 | else: 102 | if "mamba" in model_name or "mpt" in model_name: 103 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 104 | else: 105 | tokenizer = AutoTokenizer.from_pretrained(model_name) 106 | tokenizer.pad_token = tokenizer.eos_token 107 | tokenizer.pad_token_id = tokenizer.eos_token_id 108 | return tokenizer 109 | 110 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/models/mamba_lm.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import torch 4 | 5 | import lm_eval.models.utils 6 | from lm_eval.api.registry import register_model 7 | from lm_eval.models.huggingface import HFLM 8 | 9 | 10 | @register_model("mamba_ssm") 11 | class MambaLMWrapper(HFLM): 12 | def __init__( 13 | self, 14 | checkpoint_name="state-spaces/mamba-130m", 15 | **kwargs, 16 | ) -> None: 17 | """ 18 | Mamba (via the `mamba_ssm` package) supports the following args: 19 | ``` 20 | d_model: int, 21 | n_layer: int, 22 | vocab_size: int, 23 | initializer_cfg=None, 24 | pad_vocab_size_multiple: int = 1, 25 | ssm_cfg=None, 26 | norm_epsilon: float = 1e-5, 27 | rms_norm: bool = False, 28 | initializer_cfg=None, 29 | fused_add_norm=False, 30 | residual_in_fp32=False, 31 | ``` 32 | 33 | See https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L175 for more info. 34 | The above can all be passed via `--model_args` or to this __init__() directly 35 | but we recommend placing many of these within the config.json file uploaded alongside your 36 | Mamba model to the HF Hub instead. 37 | All other HuggingFace from_pretrained() kwargs 38 | such as those related to 39 | `parallelize=True`, PEFT, autoGPTQ, 40 | or any sub-configurations of these advanced args, 41 | are unsupported by the `mamba_ssm` package. 42 | 43 | The HFLM arguments 44 | 45 | `backend`, `tokenizer`, `truncation`, `max_length`, 46 | `device`, `dtype`, `batch_size`, `max_batch_size`, `trust_remote_code`, `use_fast_tokenizer` 47 | 48 | Are all supported by Mamba where they do not conflict 49 | with Mamba-specific restrictions such as causal LMs only. 
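        As a rough illustration, direct construction might look like the following
        (the checkpoint and keyword arguments here are examples, not defaults):

        ```
        lm = MambaLMWrapper(
            checkpoint_name="state-spaces/mamba-370m",
            batch_size=16,
            dtype="float16",
        )
        ```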
50 | """ 51 | 52 | if "backend" in kwargs: 53 | # mamba currently only supports causal models 54 | assert kwargs["backend"] == "causal" 55 | 56 | super().__init__( 57 | pretrained=checkpoint_name, 58 | # set appropriate defaults for tokenizer, max length, etc 59 | backend=kwargs.get("backend", "causal"), 60 | tokenizer=kwargs.get("tokenizer", "EleutherAI/gpt-neox-20b"), 61 | max_length=kwargs.get("max_length", 2048), 62 | **kwargs, 63 | ) 64 | 65 | def _get_config( 66 | self, 67 | checkpoint_name: str, 68 | **kwargs, 69 | ) -> None: 70 | try: 71 | from mamba_ssm.utils.hf import load_config_hf # noqa: F811 72 | except ModuleNotFoundError: 73 | raise Exception( 74 | "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \ 75 | please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`", 76 | ) 77 | 78 | self._config = load_config_hf(checkpoint_name) 79 | 80 | def _create_model( 81 | self, 82 | pretrained: str, 83 | dtype: Optional[Union[str, torch.dtype]] = "float16", 84 | # no `parallelize=True` options 85 | # no PEFT and quantization options 86 | # Mamba does not support arbitrary HF from_pretrained() args 87 | **kwargs, 88 | ) -> None: 89 | try: 90 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel # noqa: F811 91 | except ModuleNotFoundError: 92 | raise Exception( 93 | "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \ 94 | please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`", 95 | ) 96 | 97 | self._model = MambaLMHeadModel.from_pretrained( 98 | pretrained, 99 | device=self._device, 100 | dtype=torch.float16 101 | if dtype == "auto" 102 | else lm_eval.models.utils.get_dtype(dtype), 103 | ) 104 | 105 | def _model_generate(self, context, max_length, stop, **generation_kwargs): 106 | for key in ("do_sample", "attention_mask"): 107 | if key in generation_kwargs: 108 | generation_kwargs.pop(key) 109 | 110 | # mamba's custom GenerationMixin currently does not support 111 | # passing stopping criteria. 112 | # for the time being, we simply generate to max length, 113 | # then truncate (equivalent result) 114 | # -- this should be revisited to speed up generation 115 | # stopping_criteria = stop_sequences_criteria( 116 | # self.tokenizer, stop, 1, context.shape[0] 117 | # ) 118 | 119 | return self.model.generate( 120 | input_ids=context, 121 | max_length=max_length, 122 | # stopping_criteria=stopping_criteria, 123 | # pad_token_id=self.tokenizer.pad_token_id, 124 | # use_cache=True, 125 | **generation_kwargs, 126 | ) 127 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/models/optimum_lm.py: -------------------------------------------------------------------------------- 1 | from importlib.util import find_spec 2 | from pathlib import Path 3 | 4 | from lm_eval.api.registry import register_model 5 | from lm_eval.models.huggingface import HFLM 6 | 7 | 8 | @register_model("openvino") 9 | class OptimumLM(HFLM): 10 | """ 11 | Optimum Intel provides a simple interface to optimize Transformer models and convert them to \ 12 | OpenVINO™ Intermediate Representation (IR) format to accelerate end-to-end pipelines on \ 13 | Intel® architectures using OpenVINO™ runtime. 
14 | """ 15 | 16 | def __init__( 17 | self, 18 | device="cpu", 19 | **kwargs, 20 | ) -> None: 21 | if "backend" in kwargs: 22 | # optimum currently only supports causal models 23 | assert ( 24 | kwargs["backend"] == "causal" 25 | ), "Currently, only OVModelForCausalLM is supported." 26 | 27 | self.openvino_device = device 28 | 29 | super().__init__( 30 | device=self.openvino_device, 31 | backend=kwargs.get("backend", "causal"), 32 | **kwargs, 33 | ) 34 | 35 | def _create_model( 36 | self, 37 | pretrained: str, 38 | revision="main", 39 | dtype="auto", 40 | trust_remote_code=False, 41 | **kwargs, 42 | ) -> None: 43 | if not find_spec("optimum"): 44 | raise Exception( 45 | "package `optimum` is not installed. Please install it via `pip install optimum[openvino]`" 46 | ) 47 | else: 48 | from optimum.intel.openvino import OVModelForCausalLM 49 | 50 | model_kwargs = kwargs if kwargs else {} 51 | model_file = Path(pretrained) / "openvino_model.xml" 52 | if model_file.exists(): 53 | export = False 54 | else: 55 | export = True 56 | kwargs["ov_config"] = { 57 | "PERFORMANCE_HINT": "LATENCY", 58 | "NUM_STREAMS": "1", 59 | "CACHE_DIR": "", 60 | } 61 | 62 | self._model = OVModelForCausalLM.from_pretrained( 63 | pretrained, 64 | revision=revision, 65 | trust_remote_code=trust_remote_code, 66 | export=export, 67 | device=self.openvino_device.upper(), 68 | **model_kwargs, 69 | ) 70 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ast 3 | 4 | from typing import Dict 5 | from lm_eval import utils 6 | from lm_eval.utils import eval_logger 7 | 8 | # Prompt library. 9 | # Stores prompts in a dictionary indexed by 2 levels: 10 | # prompt category name, and prompt name. 
11 | # This allows us to access prompts 12 | PROMPT_REGISTRY: Dict[str, Dict[str, str]] = { 13 | "qa-basic": { 14 | "question-newline-answer": "Question: {{question}}\nAnswer:", 15 | "q-newline-a": "Q: {{question}}\nA:", 16 | }, 17 | } 18 | 19 | 20 | def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None): 21 | # unpack prompt name 22 | category_name, prompt_name = prompt_id.split(":") 23 | if subset_name is None: 24 | dataset_full_name = dataset_name 25 | else: 26 | dataset_full_name = f"{dataset_name}-{subset_name}" 27 | eval_logger.info(f"Loading prompt from {category_name} for {dataset_full_name}") 28 | if category_name == "promptsource": 29 | try: 30 | from promptsource.templates import DatasetTemplates 31 | except ModuleNotFoundError: 32 | raise Exception( 33 | "Tried to load a Promptsource template, but promptsource is not installed ", 34 | "please install promptsource via pip install lm-eval[promptsource] or pip install -e .[promptsource]", 35 | ) 36 | try: 37 | if subset_name is None: 38 | prompts = DatasetTemplates(dataset_name=dataset_name) 39 | else: 40 | prompts = DatasetTemplates( 41 | dataset_name=dataset_name, subset_name=subset_name 42 | ) 43 | except Exception: 44 | raise ValueError(f"{dataset_name} and {subset_name} not found") 45 | if prompt_name in prompts.all_template_names: 46 | return prompts[prompt_name] 47 | else: 48 | raise ValueError( 49 | f"{prompt_name} not in prompt list {prompts.all_template_names}" 50 | ) 51 | elif ".yaml" in category_name: 52 | import yaml 53 | 54 | with open(category_name, "rb") as file: 55 | prompt_yaml_file = yaml.full_load(file) 56 | 57 | prompt_string = prompt_yaml_file["prompts"][prompt_name] 58 | return PromptString(prompt_string) 59 | else: 60 | try: 61 | return PROMPT_REGISTRY[category_name][prompt_name] 62 | except Exception: 63 | raise ValueError( 64 | f"expected only a single `:` as separator between \ 65 | prompt category and name, but got `{prompt_id}` instead" 66 | ) 67 | 68 | 69 | def load_prompt_list( 70 | use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs 71 | ): 72 | category_name, prompt_name = use_prompt.split(":") 73 | 74 | if category_name == "promptsource": 75 | from promptsource.templates import DatasetTemplates 76 | 77 | if subset_name is None: 78 | prompts = DatasetTemplates(dataset_name=dataset_name) 79 | else: 80 | prompts = DatasetTemplates( 81 | dataset_name=dataset_name, subset_name=subset_name 82 | ) 83 | 84 | prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names) 85 | 86 | elif ".yaml" in category_name: 87 | import yaml 88 | 89 | if yaml_path is not None: 90 | category_name = os.path.realpath(os.path.join(yaml_path, category_name)) 91 | 92 | with open(category_name, "rb") as file: 93 | prompt_yaml_file = yaml.full_load(file) 94 | 95 | prompt_list = utils.pattern_match( 96 | prompt_name, prompt_yaml_file["prompts"].keys() 97 | ) 98 | 99 | # category_name, *prompt_name = use_prompt.split(":") 100 | # TODO allow to multiple prompt naming 101 | # if len(prompt_name) > 1: 102 | # prompt_list = [] 103 | # for prompt in prompt_name: 104 | # prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names)) 105 | # else: 106 | # prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names) 107 | return [":".join([category_name, prompt]) for prompt in prompt_list] 108 | 109 | 110 | class PromptString: 111 | def __init__(self, prompt_string): 112 | self.prompt_string = prompt_string 113 | 114 | def apply(self, doc): 
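        """Render the stored prompt against `doc`: apply the `doc_to_text` and
        `doc_to_target` templates and return [text_string, target_string]."""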
115 | doc_to_text = self.prompt_string["doc_to_text"] 116 | doc_to_target = self.prompt_string["doc_to_target"] 117 | 118 | # TODO need a way to process doc_to_choice 119 | if "doc_to_choice" in self.prompt_string: 120 | raise Exception("Not yet implemented to accept doc_to_choice") 121 | 122 | text_string = utils.apply_template(doc_to_text, doc) 123 | target_string = utils.apply_template(doc_to_target, doc) 124 | 125 | return [text_string, target_string] 126 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/README.md: -------------------------------------------------------------------------------- 1 | # v1.0 Tasks 2 | This list keeps track of which tasks' implementations have been ported to YAML / v2.0 of the Eval Harness. 3 | 4 | Boxes should be checked iff tasks are implemented in the refactor and tested for regression. Tasks should be struck through if checked *against original introducing paper* implementation or popularizing implementation. (WIP) Denotes that there exists a PR or person working on this task already. 5 | 6 | - [x] Glue 7 | - [x] SuperGlue 8 | - [x] CoQA 9 | - [x] DROP 10 | - [x] ~~Lambada~~ 11 | - [x] Lambada (Cloze variants) 12 | - [x] ~~Lambada (Multilingual)~~ 13 | - [x] Wikitext 14 | - [x] PiQA 15 | - [x] PROST 16 | - [x] MCTACO 17 | - [x] Pubmed QA 18 | - [x] SciQ 19 | - [x] QASPER 20 | - [x] QA4MRE 21 | - [x] TriviaQA 22 | - [x] AI2 ARC 23 | - [x] LogiQA 24 | - [x] HellaSwag 25 | - [x] SWAG 26 | - [x] OpenBookQA 27 | - [ ] SQuADv2 (Lintang) 28 | - [x] RACE 29 | - [x] HeadQA 30 | - [x] MathQA 31 | - [x] WebQs 32 | - [x] WSC273 33 | - [x] Winogrande 34 | - [x] ANLI 35 | - [x] Hendrycks Ethics (missing some tasks/metrics, see PR 660: for more info) 36 | - [x] TruthfulQA (mc1) 37 | - [x] TruthfulQA (mc2) 38 | - [x] TruthfulQA (gen) 39 | - [x] MuTual 40 | - [ ] Hendrycks Math (Hailey) 41 | - [x] Asdiv 42 | - [ ] GSM8k 43 | - [x] Arithmetic 44 | - [ ] MMMLU (Hailey) 45 | - [x] Translation (WMT) suite 46 | - [x] Unscramble 47 | - [x] ~~Pile (perplexity)~~ 48 | - [x] BLiMP 49 | - [x] ToxiGen 50 | - [x] StoryCloze 51 | - [ ] NaturalQs (Hailey) 52 | - [x] CrowS-Pairs 53 | - [x] XCopa 54 | - [ ] BIG-Bench (Hailey) 55 | - [x] XStoryCloze 56 | - [x] XWinograd 57 | - [x] PAWS-X 58 | - [x] XNLI 59 | - [x] MGSM 60 | - [ ] SCROLLS 61 | - [x] Babi 62 | - [x] Belebele 63 | 64 | # Novel Tasks 65 | Tasks added in the revamped harness that were not previously available. Again, a strikethrough denotes checking performed *against the original task's implementation or published results introducing the task*. 66 | 67 | # Task Wishlist 68 | 69 | - [ ] TheoremQA 70 | - [ ] Theorem Proving evaluations 71 | - [ ] Chain of Thought 72 | - [ ] Self-consistency ; Least-to-Most prompting, etc. 
73 | - [ ] Summarization Tasks 74 | - [ ] Anthropic Model-Written Evals 75 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/based_drop/task.py: -------------------------------------------------------------------------------- 1 | from lm_eval.api.task import ConfigurableTask 2 | from lm_eval.api.instance import Instance 3 | import re 4 | from typing import List 5 | import numpy as np 6 | 7 | 8 | class BasedDrop(ConfigurableTask): 9 | VERSION = "default" 10 | DATASET_PATH = "hazyresearch/based_drop" 11 | DATASET_NAME = None 12 | 13 | def __init__(self): 14 | super().__init__(config={'metadata': {'version': self.VERSION}}) 15 | 16 | def has_training_docs(self): 17 | return False 18 | 19 | def has_validation_docs(self): 20 | return True 21 | 22 | def has_test_docs(self): 23 | return False 24 | 25 | def validation_docs(self): 26 | return self.dataset["validation"] 27 | 28 | def doc_to_text(self, doc): 29 | context = doc["context"].strip() 30 | question = doc["question"].strip() 31 | while(context.lower().endswith(question.lower())): 32 | context = context[:-len(question)] 33 | 34 | out = ( 35 | context.strip().strip(".") + ". " + question 36 | ) 37 | return out 38 | 39 | def should_decontaminate(self): 40 | return True 41 | 42 | def doc_to_decontamination_query(self, doc): 43 | return doc["context"] 44 | 45 | def doc_to_target(self, doc): 46 | answer_list = doc['answers'] 47 | if len(answer_list) > 0: 48 | answer = answer_list[0] 49 | else: 50 | answer = "unanswerable" 51 | return " " + answer 52 | 53 | def construct_requests(self, doc, ctx, **kwargs): 54 | """Uses RequestFactory to construct Requests and returns an iterable of 55 | Requests which will be sent to the LM. 56 | 57 | :param doc: 58 | The document as returned from training_docs, validation_docs, or test_docs. 59 | :param ctx: str 60 | The context string, generated by fewshot_context. This includes the natural 61 | language description, as well as the few shot examples, and the question 62 | part of the document for `doc`. 63 | """ 64 | 65 | return [ 66 | Instance( 67 | request_type="generate_until", 68 | doc=doc, 69 | arguments=(ctx, {"until": ["\n"], "max_gen_toks": 48}), 70 | idx=0, 71 | **kwargs, 72 | ), 73 | ] 74 | 75 | def process_results(self, doc, results): 76 | """Take a single document and the LM results and evaluates, returning a 77 | dict where keys are the names of submetrics and values are the values of 78 | the metric for that one document 79 | 80 | :param doc: 81 | The document as returned from training_docs, validation_docs, or test_docs. 82 | :param results: 83 | The results of the requests created in construct_requests. 
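        The only submetric reported here is `contains`: a case-insensitive check of whether any
        gold answer string occurs in the generated continuation (see `contains_score` below).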
84 | """ 85 | 86 | continuation = results[0] 87 | 88 | return { 89 | "contains": contains_score(continuation, doc["answers"]) 90 | } 91 | 92 | 93 | def aggregation(self): 94 | """ 95 | :returns: {str: [float] -> float} 96 | A dictionary where keys are the names of submetrics and values are 97 | functions that aggregate a list of metrics 98 | """ 99 | return { 100 | "contains": np.mean, 101 | } 102 | 103 | def higher_is_better(self): 104 | """ 105 | :returns: {str: bool} 106 | A dictionary where keys are the names of submetrics and values are 107 | whether a higher value of the submetric is better 108 | """ 109 | return { 110 | "contains": True 111 | } 112 | 113 | def contains_score(prediction: str, labels: List[str]): 114 | return max( 115 | int(bool(re.search(re.compile(re.escape(label), re.IGNORECASE), prediction))) 116 | for label in labels 117 | ) 118 | 119 | class BasedDropTwice(BasedDrop): 120 | 121 | def doc_to_text(self, doc): 122 | context = doc["context"].strip() 123 | question = doc["question"].strip() 124 | while(context.lower().endswith(question.lower())): 125 | context = context[:-len(question)] 126 | 127 | out = context.strip().strip(".").strip() + "." 128 | out += "\n" + out + " " + question 129 | 130 | intro_q = doc['orig_question'].strip(":") 131 | out = f"{intro_q} " + out 132 | return out 133 | 134 | 135 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/based_fda/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSparseLLMs/MoM/2c3ae59f5c0b749f916189433b8dad5a3415dc08/lm-eval-harness/lm_eval/tasks/based_fda/README.md -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/based_fda/task.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import re 3 | import numpy as np 4 | from lm_eval.api.task import ConfigurableTask 5 | from lm_eval.api.instance import Instance 6 | 7 | 8 | class FDA(ConfigurableTask): 9 | VERSION = 0 10 | DATASET_PATH = "hazyresearch/based-fda" 11 | DATASET_NAME = "default" 12 | 13 | def __init__(self): 14 | super().__init__(config={'metadata': {'version': self.VERSION}}) 15 | 16 | def has_training_docs(self): 17 | return False 18 | 19 | def has_validation_docs(self): 20 | return True 21 | 22 | def has_test_docs(self): 23 | return False 24 | 25 | def validation_docs(self): 26 | return self.dataset["validation"] 27 | 28 | def doc_to_text(self, doc): 29 | question = doc["key"]+":" 30 | while(doc["text"].lower().endswith(question.lower())): 31 | doc["text"] = doc["text"][:-len(question)] 32 | upper_key = doc['key'][0].upper() + doc['key'][1:] 33 | question = upper_key +":" 34 | doc['text'] = doc['text'].strip("\n").strip(".") 35 | out = doc["text"] 36 | if not out.endswith("."): out += "." 37 | out += " " + question 38 | return out 39 | 40 | def doc_to_target(self, doc): 41 | return doc["value"] 42 | 43 | def construct_requests(self, doc, ctx, **kwargs): 44 | """Uses RequestFactory to construct Requests and returns an iterable of 45 | Requests which will be sent to the LM. 46 | 47 | :param doc: 48 | The document as returned from training_docs, validation_docs, or test_docs. 49 | :param ctx: str 50 | The context string, generated by fewshot_context. This includes the natural 51 | language description, as well as the few shot examples, and the question 52 | part of the document for `doc`. 
53 | """ 54 | 55 | return [ 56 | Instance( 57 | request_type="generate_until", 58 | doc=doc, 59 | arguments=(ctx, {"until": ["\n"], "max_gen_toks": 48}), 60 | idx=0, 61 | **kwargs, 62 | ), 63 | ] 64 | 65 | def process_results(self, doc, results): 66 | """Take a single document and the LM results and evaluates, returning a 67 | dict where keys are the names of submetrics and values are the values of 68 | the metric for that one document 69 | 70 | :param doc: 71 | The document as returned from training_docs, validation_docs, or test_docs. 72 | :param results: 73 | The results of the requests created in construct_requests. 74 | """ 75 | continuation = results[0] 76 | 77 | return { 78 | "contains": contains_score(continuation, [doc["value"]]) 79 | } 80 | 81 | def aggregation(self): 82 | """ 83 | :returns: {str: [float] -> float} 84 | A dictionary where keys are the names of submetrics and values are 85 | functions that aggregate a list of metrics 86 | """ 87 | return { 88 | "contains": np.mean, # Exact match (the normalized answer exactly match the gold answer) 89 | } 90 | 91 | def higher_is_better(self): 92 | """ 93 | :returns: {str: bool} 94 | A dictionary where keys are the names of submetrics and values are 95 | whether a higher value of the submetric is better 96 | """ 97 | return { 98 | "contains": True, # Exact match (the normalized answer exactly match the gold answer 99 | } 100 | 101 | 102 | def contains_score(prediction: str, labels: List[str]): 103 | return max( 104 | int(bool(re.search(re.compile(re.escape(label), re.IGNORECASE), prediction))) 105 | for label in labels 106 | ) 107 | 108 | class FDATwice(FDA): 109 | 110 | def doc_to_text(self, doc): 111 | question = doc["key"]+":" 112 | while(doc["text"].lower().endswith(question.lower())): 113 | doc["text"] = doc["text"][:-len(question)] 114 | 115 | upper_key = doc['key'][0].upper() + doc['key'][1:] 116 | question = upper_key +":" 117 | doc['text'] = doc['text'].strip("\n").strip(".") 118 | out = doc["text"] 119 | if not out.endswith("."): out += "." 120 | intro_q = question.strip(":") 121 | out = f"Information about {intro_q}. " + out + "\n" + out + " " + question 122 | return out 123 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/based_nq/task.py: -------------------------------------------------------------------------------- 1 | from lm_eval.api.task import ConfigurableTask 2 | from lm_eval.api.instance import Instance 3 | import re 4 | from typing import List 5 | import numpy as np 6 | 7 | 8 | class BasedNQ512(ConfigurableTask): 9 | VERSION = "default" 10 | DATASET_PATH = "hazyresearch/based_nq_512" 11 | DATASET_NAME = None 12 | 13 | def __init__(self): 14 | super().__init__(config={'metadata': {'version': self.VERSION}}) 15 | 16 | def has_training_docs(self): 17 | return False 18 | 19 | def has_validation_docs(self): 20 | return True 21 | 22 | def has_test_docs(self): 23 | return False 24 | 25 | def validation_docs(self): 26 | return self.dataset["validation"] 27 | 28 | def doc_to_text(self, doc): 29 | context = doc["context"].strip() 30 | question = doc["question"].strip() 31 | if context.endswith(question): 32 | context = context[:-len(question)] 33 | 34 | out = context.strip().strip(".") + ". 
" 35 | out = out + doc["question"].strip() 36 | return out 37 | 38 | def should_decontaminate(self): 39 | return True 40 | 41 | def doc_to_decontamination_query(self, doc): 42 | return doc["context"] 43 | 44 | def doc_to_target(self, doc): 45 | answer_list = doc['answers'] 46 | if len(answer_list) > 0: 47 | answer = answer_list[0] 48 | else: 49 | answer = "unanswerable" 50 | return " " + answer 51 | 52 | def construct_requests(self, doc, ctx, **kwargs): 53 | """Uses RequestFactory to construct Requests and returns an iterable of 54 | Requests which will be sent to the LM. 55 | 56 | :param doc: 57 | The document as returned from training_docs, validation_docs, or test_docs. 58 | :param ctx: str 59 | The context string, generated by fewshot_context. This includes the natural 60 | language description, as well as the few shot examples, and the question 61 | part of the document for `doc`. 62 | """ 63 | 64 | return [ 65 | Instance( 66 | request_type="generate_until", 67 | doc=doc, 68 | arguments=(ctx, {"until": ["\n"], "max_gen_toks": 48}), 69 | idx=0, 70 | **kwargs, 71 | ), 72 | ] 73 | 74 | def process_results(self, doc, results): 75 | """Take a single document and the LM results and evaluates, returning a 76 | dict where keys are the names of submetrics and values are the values of 77 | the metric for that one document 78 | 79 | :param doc: 80 | The document as returned from training_docs, validation_docs, or test_docs. 81 | :param results: 82 | The results of the requests created in construct_requests. 83 | """ 84 | continuation = results[0] 85 | return { 86 | "contains": contains_score(continuation, doc["answers"]) 87 | } 88 | 89 | 90 | def aggregation(self): 91 | """ 92 | :returns: {str: [float] -> float} 93 | A dictionary where keys are the names of submetrics and values are 94 | functions that aggregate a list of metrics 95 | """ 96 | return { 97 | "contains": np.mean, 98 | } 99 | 100 | 101 | def higher_is_better(self): 102 | """ 103 | :returns: {str: bool} 104 | A dictionary where keys are the names of submetrics and values are 105 | whether a higher value of the submetric is better 106 | """ 107 | return { 108 | "contains": True 109 | } 110 | 111 | 112 | def contains_score(prediction: str, labels: List[str]): 113 | return max( 114 | int(bool(re.search(re.compile(re.escape(label), re.IGNORECASE), prediction))) 115 | for label in labels 116 | ) 117 | 118 | 119 | class BasedNQTwice512(BasedNQ512): 120 | def doc_to_text(self, doc): 121 | context = doc["context"].strip() 122 | question = doc["question"].strip() 123 | while(context.lower().endswith(question.lower())): 124 | context = context[:-len(question)] 125 | 126 | out = context.strip().strip(".").strip() + "." 
127 | out += "\n" + out + " " + question 128 | 129 | intro_q = doc['orig_question'].strip(":") 130 | out = f"{intro_q} " + out 131 | return out 132 | 133 | 134 | class BasedNQ1024(BasedNQ512): 135 | DATASET_PATH = "hazyresearch/based_nq_1024" 136 | DATASET_NAME = None 137 | 138 | 139 | class BasedNQ2048(BasedNQ512): 140 | DATASET_PATH = "hazyresearch/based_nq_2048" 141 | DATASET_NAME = None 142 | 143 | 144 | class BasedNQTwice1024(BasedNQTwice512): 145 | DATASET_PATH = "hazyresearch/based_nq_1024" 146 | DATASET_NAME = None 147 | 148 | 149 | class BasedNQTwice2048(BasedNQTwice512): 150 | DATASET_PATH = "hazyresearch/based_nq_2048" 151 | DATASET_NAME = None -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/based_squadv2/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSparseLLMs/MoM/2c3ae59f5c0b749f916189433b8dad5a3415dc08/lm-eval-harness/lm_eval/tasks/based_squadv2/README.md -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/based_swde/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSparseLLMs/MoM/2c3ae59f5c0b749f916189433b8dad5a3415dc08/lm-eval-harness/lm_eval/tasks/based_swde/README.md -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/scrolls/README.md: -------------------------------------------------------------------------------- 1 | """ 2 | SCROLLS: Standardized CompaRison Over Long Language Sequences 3 | https://arxiv.org/abs/2201.03533 4 | 5 | SCROLLS is a suite of datasets that require synthesizing information over long texts. 6 | The benchmark includes seven natural language tasks across multiple domains, 7 | including summarization, question answering, and natural language inference. 8 | 9 | Homepage: https://www.scrolls-benchmark.com/ 10 | 11 | Since SCROLLS tasks are generally longer than the maximum sequence length of many models, 12 | it is possible to create "subset" tasks that contain only those samples whose tokenized length 13 | is less than some pre-defined limit. For example, to create a subset of "Qasper" that would 14 | be suitable for a model using the GPTNeoX tokenizer and a 4K maximum sequence length: 15 | 16 | ``` 17 | class QasperGPTNeoX4K(Qasper): 18 | PRUNE_TOKENIZERS = ["EleutherAI/pythia-410m-deduped"] 19 | PRUNE_MAX_TOKENS = 4096 20 | PRUNE_NUM_PROC = _num_cpu_cores() # optional, to speed up pruning of large datasets like NarrativeQA 21 | ``` 22 | 23 | `PRUNE_TOKENIZERS` can contain more than one tokenizer; this will include only samples that are 24 | less than `PRUNE_MAX_TOKENS` for ALL of the tokenizers. This can be useful for comparing models 25 | that use different tokenizers but the same maximum sequence length. 26 | 27 | Once the subset task class has been defined in this file, it can be used by adding the class 28 | to `lm_eval/tasks/__init__.py`. 29 | 30 | NOTE: GovReport may need `max_gen_toks` set larger for causal models.
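As a further sketch of using more than one entry in `PRUNE_TOKENIZERS` (the class name and tokenizer
choices below are illustrative, and it assumes a `NarrativeQA` task class as mentioned above):

```
class NarrativeQADual4K(NarrativeQA):
    PRUNE_TOKENIZERS = ["EleutherAI/pythia-410m-deduped", "gpt2"]
    PRUNE_MAX_TOKENS = 4096
```

Such a subset keeps only samples that fit within 4096 tokens under BOTH tokenizers.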
31 | """ 32 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/super_glue/README.md: -------------------------------------------------------------------------------- 1 | # SuperGLUE 2 | 3 | ### Paper 4 | 5 | Title: `SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding Systems` 6 | Abstract: `https://w4ngatang.github.io/static/papers/superglue.pdf` 7 | 8 | SuperGLUE is a benchmark styled after GLUE with a new set of more difficult language 9 | understanding tasks. 10 | 11 | Homepage: https://super.gluebenchmark.com/ 12 | 13 | ### Citation 14 | 15 | ``` 16 | @inproceedings{NEURIPS2019_4496bf24, 17 | author = {Wang, Alex and Pruksachatkun, Yada and Nangia, Nikita and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel}, 18 | booktitle = {Advances in Neural Information Processing Systems}, 19 | editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett}, 20 | pages = {}, 21 | publisher = {Curran Associates, Inc.}, 22 | title = {SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding Systems}, 23 | url = {https://proceedings.neurips.cc/paper/2019/file/4496bf24afe7fab6f046bf4923da8de6-Paper.pdf}, 24 | volume = {32}, 25 | year = {2019} 26 | } 27 | ``` 28 | 29 | ### Groups and Tasks 30 | 31 | #### Groups 32 | 33 | * `super-glue-lm-eval-v1`: SuperGLUE eval adapted from LM Eval V1 34 | * `super-glue-t5-prompt`: SuperGLUE prompt and evaluation that matches the T5 paper (if using accelerate, will error if record is included.) 35 | 36 | #### Tasks 37 | 38 | Comparison between validation split score on T5x and LM-Eval (T5x models converted to HF) 39 | | T5V1.1 Base | SGLUE | BoolQ | CB | Copa | MultiRC | ReCoRD | RTE | WiC | WSC | 40 | | ----------- | ------| ----- | --------- | ---- | ------- | ------ | --- | --- | --- | 41 | | T5x | 69.47 | 78.47(acc) | 83.93(f1) 87.5(acc) | 50(acc) | 73.81(f1) 33.26(em) | 70.09(em) 71.34(f1) | 78.7(acc) | 63.64(acc) | 75(acc) | 42 | | LM-Eval | 71.35 | 79.36(acc) | 83.63(f1) 87.5(acc) | 63(acc) | 73.45(f1) 33.26(em) | 69.85(em) 68.86(f1) | 78.34(acc) | 65.83(acc) | 75.96(acc) | 43 | 44 | 45 | 46 | * `super-glue-lm-eval-v1` 47 | - `boolq` 48 | - `cb` 49 | - `copa` 50 | - `multirc` 51 | - `record` 52 | - `rte` 53 | - `wic` 54 | - `wsc` 55 | 56 | * `super-glue-t5-prompt` 57 | - `super_glue-boolq-t5-prompt` 58 | - `super_glue-cb-t5-prompt` 59 | - `super_glue-copa-t5-prompt` 60 | - `super_glue-multirc-t5-prompt` 61 | - `super_glue-record-t5-prompt` 62 | - `super_glue-rte-t5-prompt` 63 | - `super_glue-wic-t5-prompt` 64 | - `super_glue-wsc-t5-prompt` 65 | 66 | ### Checklist 67 | 68 | For adding novel benchmarks/datasets to the library: 69 | * [ ] Is the task an existing benchmark in the literature? 70 | * [ ] Have you referenced the original paper that introduced the task? 71 | * [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test? 72 | 73 | 74 | If other tasks on this dataset are already supported: 75 | * [ ] Is the "Main" variant of this task clearly denoted? 76 | * [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates? 77 | * [ ] Have you noted which, if any, published evaluation setups are matched by this variant? 
78 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/super_glue/cb/aggregate.py: -------------------------------------------------------------------------------- 1 | import sklearn 2 | import numpy as np 3 | 4 | 5 | def cb_multi_fi(items): 6 | preds, golds = zip(*items) 7 | preds = np.array(preds) 8 | golds = np.array(golds) 9 | f11 = sklearn.metrics.f1_score(y_true=golds == 0, y_pred=preds == 0) 10 | f12 = sklearn.metrics.f1_score(y_true=golds == 1, y_pred=preds == 1) 11 | f13 = sklearn.metrics.f1_score(y_true=golds == 2, y_pred=preds == 2) 12 | avg_f1 = np.mean([f11, f12, f13]) 13 | return avg_f1 14 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/super_glue/cb/t5_utils.py: -------------------------------------------------------------------------------- 1 | import sklearn.metrics 2 | 3 | 4 | def mean_3class_f1(predictions, references): # This is a passthrough function 5 | string_label = ["entailment", "contradiction", "neutral"] 6 | predictions = ( 7 | string_label.index(predictions[0]) if predictions[0] in string_label else 0 8 | ) 9 | references = string_label.index(references[0]) 10 | 11 | return (predictions, references) 12 | 13 | 14 | def agg_mean_3class_f1(items): 15 | predictions, references = zip(*items) 16 | 17 | """Computes the unweighted average of the F1 per class.""" 18 | metric_str = "fbeta_score" 19 | metric_fn_kwargs = { 20 | "beta": 1, 21 | "labels": range(3), 22 | "average": "macro", 23 | } 24 | 25 | def _fn(predictions, references): 26 | metric_fn = getattr(sklearn.metrics, metric_str) 27 | metric_val = metric_fn(references, predictions, **metric_fn_kwargs) 28 | return metric_val 29 | 30 | return _fn(predictions, references) 31 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/super_glue/copa/utils.py: -------------------------------------------------------------------------------- 1 | def convert_choice(choice): 2 | return choice[0].lower() + choice[1:] 3 | 4 | 5 | def doc_to_text(doc): 6 | # Drop the period 7 | connector = { 8 | "cause": "because", 9 | "effect": "therefore", 10 | }[doc["question"]] 11 | return doc["premise"].strip()[:-1] + f" {connector}" 12 | 13 | 14 | def doc_to_target(doc): 15 | correct_choice = doc["choice1"] if doc["label"] == 0 else doc["choice2"] 16 | # Connect the sentences 17 | return " " + convert_choice(correct_choice) 18 | 19 | 20 | def doc_to_choice(doc): 21 | return [" " + convert_choice(doc["choice1"]), " " + convert_choice(doc["choice2"])] 22 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/super_glue/multirc/t5_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import numpy as np 4 | import sklearn.metrics 5 | 6 | 7 | def f1(predictions, references): # This is a passthrough function 8 | _prediction = predictions[0] 9 | _reference = references[0].split("_")[-1] 10 | string_label = ["False", "True"] 11 | reference = string_label.index(_reference) 12 | prediction = ( 13 | string_label.index(_prediction) 14 | if _prediction in string_label 15 | else not bool(reference) 16 | ) 17 | 18 | return (prediction, reference) 19 | 20 | 21 | def agg_f1(items): 22 | predictions, references = zip(*items) 23 | references, predictions = np.asarray(references), np.asarray(predictions) 24 | 25 | return 
sklearn.metrics.f1_score(references, predictions) 26 | 27 | 28 | def em(predictions, references): # This is a passthrough function 29 | _prediction = predictions[0] 30 | _group, _reference = references[0].split("_") 31 | string_label = ["False", "True"] 32 | reference = string_label.index(_reference) 33 | prediction = ( 34 | string_label.index(_prediction) 35 | if _prediction in string_label 36 | else not bool(reference) 37 | ) 38 | 39 | return (_group, prediction, reference) 40 | 41 | 42 | def agg_em(items): 43 | grouped_values = collections.defaultdict(lambda: ([], [])) 44 | for group, prediction, reference in items: 45 | grouped_values[group][0].append(reference) 46 | grouped_values[group][1].append(prediction) 47 | 48 | group_scores = [] 49 | for group, (targets, predictions) in grouped_values.items(): 50 | score = float(np.array_equal(targets, predictions)) 51 | group_scores.append(score) 52 | 53 | return np.mean(group_scores) 54 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/super_glue/record/t5_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | import collections 4 | import numpy as np 5 | 6 | from datasets import Dataset 7 | 8 | from lm_eval.api.metrics import metric_max_over_ground_truths 9 | 10 | 11 | def doc_to_text(doc): 12 | passage = doc["passage"] 13 | passage = re.sub(r"(\.|\?|\!|\"|\')\n@highlight\n", r"\1 ", passage) 14 | passage = re.sub(r"\n@highlight\n", ". ", passage) 15 | 16 | return " ".join( 17 | [ 18 | "record query:", 19 | doc["query"], 20 | "entities:", 21 | ", ".join(doc["entities"]), 22 | "passage:", 23 | passage, 24 | ] 25 | ) 26 | 27 | 28 | def process_docs(dataset): 29 | def split_answers(doc): 30 | split_doc = { 31 | **{k: [] for k in doc.keys()}, 32 | } 33 | answers = doc.pop("answers") 34 | for idx, answer in enumerate(answers): 35 | for key in split_doc.keys(): 36 | if key in doc: 37 | split_doc[key].append(doc[key]) 38 | 39 | split_doc["answers"].append(answer) 40 | return split_doc 41 | 42 | dataset = dataset.map(split_answers) 43 | new_dataset = {} 44 | for key in dataset.features.keys(): 45 | new_dataset[key] = [x for row in dataset[key] for x in row] 46 | 47 | return Dataset.from_dict(new_dataset) 48 | 49 | 50 | def normalize_squad(answer): 51 | """Normalization used in official SQuAD evaluation script.""" 52 | 53 | def _normalize_answer(text, punc_chars, punc_repl): 54 | """Lower text and remove punctuation, articles and extra whitespace.""" 55 | 56 | def remove_articles(s): 57 | return re.sub(r"\b(a|an|the)\b", " ", s) 58 | 59 | def replace_punctuation(s): 60 | to_replace = set(punc_chars) 61 | return "".join(punc_repl if ch in to_replace else ch for ch in s) 62 | 63 | def white_space_fix(s): 64 | return " ".join(s.split()) 65 | 66 | text = text.lower() 67 | text = replace_punctuation(text) 68 | text = remove_articles(text) 69 | text = white_space_fix(text) 70 | 71 | return text 72 | 73 | return _normalize_answer(answer, punc_chars=string.punctuation, punc_repl="") 74 | 75 | 76 | def em(predictions, references): # This is a passthrough function 77 | return (predictions[0], references[0]) 78 | 79 | 80 | def f1(predictions, references): # This is a passthrough function 81 | return (predictions[0], references[0]) 82 | 83 | 84 | def squad_em_agg(items): 85 | def _exact_match_score(prediction, target): 86 | return target == prediction 87 | 88 | grouped_values = collections.defaultdict(lambda: ([], [])) 89 | for 
prediction, reference in items: 90 | group, reference = reference.split("_") 91 | # if group not in grouped_values: 92 | grouped_values[group][0].append(normalize_squad(prediction)) 93 | grouped_values[group][1].append(normalize_squad(reference)) 94 | 95 | em = [] 96 | for group in grouped_values.keys(): 97 | predictions, targets = grouped_values[group] 98 | for p in predictions: 99 | em.append(metric_max_over_ground_truths(_exact_match_score, p, targets)) 100 | 101 | return np.mean(em) 102 | 103 | 104 | def squad_f1_agg(items): 105 | def _f1_score(prediction, target): 106 | """Computes token f1 score for a single target and prediction.""" 107 | prediction_tokens = prediction.split() 108 | target_tokens = target.split() 109 | common = collections.Counter(prediction_tokens) & collections.Counter( 110 | target_tokens 111 | ) 112 | num_same = sum(common.values()) 113 | if num_same == 0: 114 | return 0 115 | precision = 1.0 * num_same / len(prediction_tokens) 116 | recall = 1.0 * num_same / len(target_tokens) 117 | f1 = (2 * precision * recall) / (precision + recall) 118 | return f1 119 | 120 | grouped_values = collections.defaultdict(lambda: ([], [])) 121 | for prediction, reference in items: 122 | group, reference = reference.split("_") 123 | if group not in grouped_values: 124 | grouped_values[group][0].append(normalize_squad(prediction)) 125 | grouped_values[group][1].append(normalize_squad(reference)) 126 | 127 | f1 = [] 128 | for group in grouped_values.keys(): 129 | p, t = grouped_values[group] 130 | f1.append(metric_max_over_ground_truths(_f1_score, p[0], t)) 131 | 132 | return np.mean(f1) 133 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/super_glue/record/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import transformers.data.metrics.squad_metrics as squad_metrics 3 | 4 | from lm_eval.api.metrics import metric_max_over_ground_truths 5 | 6 | 7 | def doc_to_text(doc): 8 | initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n") 9 | text = initial_text + "\n\n" 10 | for highlight in highlights: 11 | text += f" - {highlight}.\n" 12 | return text 13 | 14 | 15 | def format_answer(query, entity): 16 | return f" - {query}".replace("@placeholder", entity) 17 | 18 | 19 | def doc_to_target(doc): 20 | # We only output the first correct entity in a doc 21 | return format_answer(query=doc["query"], entity=doc["answers"][0]) 22 | 23 | 24 | def process_results(doc, results): 25 | # ReCoRD's evaluation is actually deceptively simple: 26 | # - Pick the maximum likelihood prediction entity 27 | # - Evaluate the accuracy and token F1 PER EXAMPLE 28 | # - Average over all examples 29 | max_idx = np.argmax(np.array([result[0] for result in results])) 30 | 31 | prediction = doc["entities"][max_idx] 32 | gold_label_set = doc["answers"] 33 | f1 = metric_max_over_ground_truths( 34 | squad_metrics.compute_f1, prediction, gold_label_set 35 | ) 36 | em = metric_max_over_ground_truths( 37 | squad_metrics.compute_exact, prediction, gold_label_set 38 | ) 39 | 40 | return { 41 | "f1": f1, 42 | "em": em, 43 | } 44 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/super_glue/wsc/preprocess_wsc.py: -------------------------------------------------------------------------------- 1 | from lm_eval.utils import general_detokenize 2 | 3 | 4 | def default_doc_to_text(x): 5 | raw_passage = x["text"] 6 | # NOTE: 
HuggingFace span indices are word-based not character-based. 7 | pre = " ".join(raw_passage.split()[: x["span2_index"]]) 8 | post = raw_passage[len(pre) + len(x["span2_text"]) + 1 :] 9 | passage = general_detokenize(pre + " *{}*".format(x["span2_text"]) + post) 10 | noun = x["span1_text"] 11 | pronoun = x["span2_text"] 12 | text = ( 13 | f"Passage: {passage}\n" 14 | + f'Question: In the passage above, does the pronoun "*{pronoun}*" refer to "*{noun}*"?\n' 15 | + "Answer:" 16 | ) 17 | return text 18 | -------------------------------------------------------------------------------- /lm-eval-harness/lm_eval/tasks/super_glue/wsc/t5_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List 3 | 4 | def doc_to_text(x): 5 | text = re.sub(r" X ", " *" + x["span2_text"] + "* ", _wsc_inputs(x)) 6 | return "wsc: " + text 7 | 8 | 9 | def _wsc_inputs(x): 10 | words = x["text"].split(" ") 11 | 12 | # We would need some special logic to handle the case where the pronoun is the 13 | # first or last word in the text. None of the examples in WSC seem to have 14 | # this, so we are ignoring these cases. 15 | assert x["span2_index"] > 0 16 | assert x["span2_index"] < len(words) 17 | pronoun_index = x["span2_index"] 18 | 19 | def create_input(): 20 | assert words[pronoun_index] == x["span2_text"] 21 | 22 | return " ".join( 23 | [ 24 | " ".join(words[:pronoun_index]), 25 | "X", 26 | " ".join(words[pronoun_index + 1:]), 27 | ] 28 | ) 29 | 30 | # Handle some special cases. 31 | if ( 32 | x["text"] 33 | == 'The boy continued to whip the pony , and eventually the pony threw him over. John laughed out quite loud. "Good for him," he said. ' 34 | ): 35 | return ( 36 | "The boy continued to whip the pony , and eventually the pony threw " 37 | 'him over. John laughed out quite loud. "Good for X ," he said.' 38 | ) 39 | 40 | # Using the span2_index, we get 'use' instead of 'it'. 41 | if ( 42 | x["text"] 43 | == "When they had eventually calmed down a bit , and had gotten home, Mr. Farley put the magic pebble in an iron safe . Some day they might want to use it , but really for now, what more could they wish for?" 44 | ): 45 | return ( 46 | "When they had eventually calmed down a bit , and had gotten home, " 47 | "Mr. Farley put the magic pebble in an iron safe . Some day they might " 48 | "want to use X , but really for now, what more could they wish for?" 49 | ) 50 | 51 | return create_input() 52 | 53 | 54 | DETERMINERS = { 55 | "a", 56 | "an", 57 | "few", 58 | "her", 59 | "his", 60 | "each", 61 | "every", 62 | "many", 63 | "much", 64 | "my", 65 | "our", 66 | "some", 67 | "that", 68 | "the", 69 | "their", 70 | "these", 71 | "this", 72 | "those", 73 | "which", 74 | "whose", 75 | "your", 76 | } 77 | 78 | 79 | def clean(s: str) -> str: 80 | """Ignore capitalization and determiners.""" 81 | s = s.strip().lower() 82 | return " ".join([w for w in s.split(" ") if w not in DETERMINERS]) 83 | 84 | 85 | def process_results(docs: dict, resps: List): 86 | prediction = clean(resps[0]) 87 | reference = clean(docs["span1_text"]) 88 | 89 | if ("'" in prediction) != ("'" in reference): 90 | # referent is "Bob's hat" as predicting the referent. 91 | predicted_referent = False 92 | else: 93 | prediction_words = set(prediction.split(" ")) 94 | referent_words = set(reference.split(" ")) 95 | 96 | # Handle cases where the prediction is "fuzzy bunny" and the referent is 97 | # "bunny". 
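        # i.e., count the prediction as naming the referent when either bag of words
        # (after stripping determiners via clean()) is a subset of the other.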
98 | predicted_referent = prediction_words.issubset( 99 | referent_words 100 | ) or referent_words.issubset(prediction_words) 101 | 102 | acc = 1.0 if predicted_referent == docs["label"] else 0.0 103 | return {"accuracy": acc} 104 | -------------------------------------------------------------------------------- /lm-eval-harness/prompt_scripts/run_jrt_prompt_hazy.sh: -------------------------------------------------------------------------------- 1 | ### JRT ArXiv Table 1 ### 2 | 3 | output_dir="run_jrt_prompt_hazy" 4 | limit=-1 5 | 6 | 7 | # Default and twice SWDE context length 1000 8 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python launch_jrt.py \ 9 | --batch-size 32 \ 10 | -m hazyresearch/based-1b-50b \ 11 | -m hazyresearch/mamba-1b-50b \ 12 | -m hazyresearch/attn-1b-50bn \ 13 | -t based_swde \ 14 | -t based_swde_twice \ 15 | --output_dir ${output_dir} \ 16 | --context_length 1000 \ 17 | --answer_length 50 \ 18 | --cutting_context \ 19 | --limit ${limit} \ 20 | -p 21 | 22 | 23 | # Default and twice FDA at context length 1000 24 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python launch_jrt.py \ 25 | --batch-size 32 \ 26 | -m hazyresearch/based-1b-50b \ 27 | -m hazyresearch/mamba-1b-50b \ 28 | -m hazyresearch/attn-1b-50bn \ 29 | -t based_fda \ 30 | -t based_fda_twice \ 31 | --output_dir ${output_dir} \ 32 | --context_length 1000 \ 33 | --answer_length 50 \ 34 | --cutting_context \ 35 | --limit ${limit} \ 36 | -p 37 | 38 | 39 | # Default and twice SQUAD completion at context length 1000 40 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python launch_jrt.py \ 41 | --batch-size 32 \ 42 | -m hazyresearch/based-1b-50b \ 43 | -m hazyresearch/mamba-1b-50b \ 44 | -m hazyresearch/attn-1b-50bn \ 45 | -t based_squad_twice \ 46 | -t based_squad \ 47 | --output_dir ${output_dir} \ 48 | --context_length 1000 \ 49 | --answer_length 50 \ 50 | --cutting_context \ 51 | --limit ${limit} \ 52 | -p 53 | 54 | 55 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python launch_jrt.py \ 56 | --batch-size 32 \ 57 | -m hazyresearch/based-1b-50b \ 58 | -m hazyresearch/mamba-1b-50b \ 59 | -m hazyresearch/attn-1b-50bn \ 60 | -t based_drop_twice \ 61 | -t based_drop \ 62 | --output_dir ${output_dir} \ 63 | --context_length 1000 \ 64 | --answer_length 50 \ 65 | --cutting_context \ 66 | --limit ${limit} \ 67 | -p 68 | 69 | 70 | CUDA_VISIBLE_DEVICES=10,1,2,3,4,5,6,7 python launch_jrt.py \ 71 | --batch-size 32 \ 72 | -m hazyresearch/based-1b-50b \ 73 | -m hazyresearch/mamba-1b-50b \ 74 | -m hazyresearch/attn-1b-50bn \ 75 | -t based_nq_1024_twice \ 76 | -t based_nq_1024 \ 77 | --output_dir ${output_dir} \ 78 | --context_length 1000 \ 79 | --answer_length 50 \ 80 | --cutting_context \ 81 | --limit ${limit} \ 82 | -p 83 | 84 | 85 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python launch_jrt.py \ 86 | --batch-size 32 \ 87 | -m hazyresearch/based-1b-50b \ 88 | -m hazyresearch/mamba-1b-50b \ 89 | -m hazyresearch/attn-1b-50bn \ 90 | -t based_triviaqa_twice \ 91 | -t based_triviaqa \ 92 | --output_dir ${output_dir} \ 93 | --context_length 1000 \ 94 | --answer_length 50 \ 95 | --cutting_context \ 96 | --limit ${limit} \ 97 | -p 98 | 99 | -------------------------------------------------------------------------------- /lm-eval-harness/prompt_scripts/run_jrt_prompt_hf.sh: -------------------------------------------------------------------------------- 1 | ### JRT ArXiv Table 1 ### 2 | 3 | output_dir="run_jrt_prompt_hf" 4 | limit=-1 5 | 6 | # Default and twice SWDE context length 1000 7 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python launch_hf.py \ 8 | --batch-size 
32 \ 9 | -m "fla-hub/gla-1.3B-100B" \ 10 | -m "fla-hub/gla-2.7B-100B" \ 11 | -m "state-spaces/mamba-130m" \ 12 | -m "state-spaces/mamba-370m" \ 13 | -m "state-spaces/mamba-1.4b" \ 14 | -m "state-spaces/mamba-2.8b" \ 15 | -m "state-spaces/mamba2-130m" \ 16 | -m "state-spaces/mamba2-370m" \ 17 | -m "state-spaces/mamba2-1.3b" \ 18 | -m "state-spaces/mamba2-2.7b" \ 19 | -t based_swde \ 20 | -t based_swde_twice \ 21 | --output_dir ${output_dir} \ 22 | --context_length 1000 \ 23 | --answer_length 50 \ 24 | --cutting_context \ 25 | --limit ${limit} \ 26 | -p 27 | 28 | 29 | # Default and twice FDA at context length 1000 30 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python launch_hf.py \ 31 | --batch-size 32 \ 32 | -m "fla-hub/gla-1.3B-100B" \ 33 | -m "fla-hub/gla-2.7B-100B" \ 34 | -m "state-spaces/mamba-130m" \ 35 | -m "state-spaces/mamba-370m" \ 36 | -m "state-spaces/mamba-1.4b" \ 37 | -m "state-spaces/mamba-2.8b" \ 38 | -m "state-spaces/mamba2-130m" \ 39 | -m "state-spaces/mamba2-370m" \ 40 | -m "state-spaces/mamba2-1.3b" \ 41 | -m "state-spaces/mamba2-2.7b" \ 42 | -t based_fda \ 43 | -t based_fda_twice \ 44 | --output_dir ${output_dir} \ 45 | --context_length 1000 \ 46 | --answer_length 50 \ 47 | --cutting_context \ 48 | --limit ${limit} \ 49 | -p 50 | 51 | 52 | # Default and twice SQUAD completion at context length 1000 53 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python launch_hf.py \ 54 | --batch-size 32 \ 55 | -m "fla-hub/gla-1.3B-100B" \ 56 | -m "fla-hub/gla-2.7B-100B" \ 57 | -m "state-spaces/mamba-130m" \ 58 | -m "state-spaces/mamba-370m" \ 59 | -m "state-spaces/mamba-1.4b" \ 60 | -m "state-spaces/mamba-2.8b" \ 61 | -m "state-spaces/mamba2-130m" \ 62 | -m "state-spaces/mamba2-370m" \ 63 | -m "state-spaces/mamba2-1.3b" \ 64 | -m "state-spaces/mamba2-2.7b" \ 65 | -t based_squad_twice \ 66 | -t based_squad \ 67 | --output_dir ${output_dir} \ 68 | --context_length 1000 \ 69 | --answer_length 50 \ 70 | --cutting_context \ 71 | --limit ${limit} \ 72 | -p 73 | 74 | 75 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python launch_hf.py \ 76 | --batch-size 32 \ 77 | -m "fla-hub/gla-1.3B-100B" \ 78 | -m "fla-hub/gla-2.7B-100B" \ 79 | -m "state-spaces/mamba-130m" \ 80 | -m "state-spaces/mamba-370m" \ 81 | -m "state-spaces/mamba-1.4b" \ 82 | -m "state-spaces/mamba-2.8b" \ 83 | -m "state-spaces/mamba2-130m" \ 84 | -m "state-spaces/mamba2-370m" \ 85 | -m "state-spaces/mamba2-1.3b" \ 86 | -m "state-spaces/mamba2-2.7b" \ 87 | -t based_drop_twice \ 88 | -t based_drop \ 89 | --output_dir ${output_dir} \ 90 | --context_length 1000 \ 91 | --answer_length 50 \ 92 | --cutting_context \ 93 | --limit ${limit} \ 94 | -p 95 | 96 | 97 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python launch_hf.py \ 98 | --batch-size 32 \ 99 | -m "fla-hub/gla-1.3B-100B" \ 100 | -m "fla-hub/gla-2.7B-100B" \ 101 | -m "state-spaces/mamba-130m" \ 102 | -m "state-spaces/mamba-370m" \ 103 | -m "state-spaces/mamba-1.4b" \ 104 | -m "state-spaces/mamba-2.8b" \ 105 | -m "state-spaces/mamba2-130m" \ 106 | -m "state-spaces/mamba2-370m" \ 107 | -m "state-spaces/mamba2-1.3b" \ 108 | -m "state-spaces/mamba2-2.7b" \ 109 | -t based_nq_1024_twice \ 110 | -t based_nq_1024 \ 111 | --output_dir ${output_dir} \ 112 | --context_length 1000 \ 113 | --answer_length 50 \ 114 | --cutting_context \ 115 | --limit ${limit} \ 116 | -p 117 | 118 | 119 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python launch_hf.py \ 120 | --batch-size 32 \ 121 | -m "fla-hub/gla-1.3B-100B" \ 122 | -m "fla-hub/gla-2.7B-100B" \ 123 | -m "state-spaces/mamba-130m" \ 124 | -m "state-spaces/mamba-370m" \ 125 | -m 
"state-spaces/mamba-1.4b" \ 126 | -m "state-spaces/mamba-2.8b" \ 127 | -m "state-spaces/mamba2-130m" \ 128 | -m "state-spaces/mamba2-370m" \ 129 | -m "state-spaces/mamba2-1.3b" \ 130 | -m "state-spaces/mamba2-2.7b" \ 131 | -t based_triviaqa_twice \ 132 | -t based_triviaqa \ 133 | --output_dir ${output_dir} \ 134 | --context_length 1000 \ 135 | --answer_length 50 \ 136 | --cutting_context \ 137 | --limit ${limit} \ 138 | -p 139 | 140 | -------------------------------------------------------------------------------- /lm-eval-harness/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=40.8.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "lm_eval" 7 | version = "0.4.1" 8 | authors = [ 9 | {name="EleutherAI", email="contact@eleuther.ai"} 10 | ] 11 | description = "A framework for evaluating language models" 12 | readme = "README.md" 13 | classifiers = [ 14 | "Development Status :: 3 - Alpha", 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: MIT License", 17 | "Operating System :: OS Independent", 18 | ] 19 | requires-python = ">=3.8" 20 | license = { "text" = "MIT" } 21 | dependencies = [ 22 | "accelerate>=0.21.0", 23 | "evaluate", 24 | "datasets>=2.14.0", 25 | "evaluate>=0.4.0", 26 | "jsonlines", 27 | "numexpr", 28 | "peft>=0.2.0", 29 | "pybind11>=2.6.2", 30 | "pytablewriter", 31 | "rouge-score>=0.0.4", 32 | "sacrebleu>=1.5.0", 33 | "scikit-learn>=0.24.1", 34 | "sqlitedict", 35 | "torch>=1.8", 36 | "tqdm-multiprocess", 37 | "transformers>=4.1", 38 | "zstandard", 39 | ] 40 | 41 | [tool.setuptools.packages.find] 42 | include = ["lm_eval*"] 43 | 44 | # required to include yaml files in pip installation 45 | [tool.setuptools.package-data] 46 | lm_eval = ["**/*.yaml", "tasks/**/*"] 47 | 48 | [project.scripts] 49 | lm-eval = "lm_eval.__main__:cli_evaluate" 50 | lm_eval = "lm_eval.__main__:cli_evaluate" 51 | 52 | [project.urls] 53 | Homepage = "https://github.com/EleutherAI/lm-evaluation-harness" 54 | Repository = "https://github.com/EleutherAI/lm-evaluation-harness" 55 | 56 | [project.optional-dependencies] 57 | anthropic = ["anthropic"] 58 | dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy"] 59 | gptq = ["auto-gptq[triton]>=0.6.0"] 60 | hf_transfer = ["hf_transfer"] 61 | ifeval = ["langdetect", "immutabledict"] 62 | neuronx = ["optimum[neuronx]"] 63 | mamba = ["mamba_ssm", "causal-conv1d==1.0.2"] 64 | math = ["sympy>=1.12", "antlr4-python3-runtime==4.11"] 65 | multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"] 66 | openai = ["openai==1.3.9", "tiktoken"] 67 | optimum = ["optimum[openvino]"] 68 | promptsource = ["promptsource>=0.2.3"] 69 | sentencepiece = ["sentencepiece>=0.1.98", "protobuf>=4.22.1"] 70 | testing = ["pytest", "pytest-cov", "pytest-xdist"] 71 | vllm = ["vllm<=0.2.5"] 72 | zeno = ["pandas", "zeno-client"] 73 | all = [ 74 | "lm_eval[anthropic]", 75 | "lm_eval[dev]", 76 | "lm_eval[gptq]", 77 | "lm_eval[hf_transfer]", 78 | "lm_eval[ifeval]", 79 | "lm_eval[mamba]", 80 | "lm_eval[math]", 81 | "lm_eval[multilingual]", 82 | "lm_eval[openai]", 83 | "lm_eval[promptsource]", 84 | "lm_eval[sentencepiece]", 85 | "lm_eval[testing]", 86 | "lm_eval[vllm]", 87 | "lm_eval[zeno]", 88 | ] 89 | 90 | [tool.ruff] 91 | extend-exclude = ["lm_eval/tasks/*.py"] 92 | 93 | [tool.ruff.lint] 94 | extend-select = ["I"] 95 | 96 | [tool.ruff.isort] 97 | lines-after-imports = 2 98 | known-first-party = ["lm_eval"] 99 | 100 | 
[tool.ruff.extend-per-file-ignores] 101 | "__init__.py" = ["F401","F402","F403","I"] 102 | "lm_eval/tasks/*"= ["E721"] 103 | -------------------------------------------------------------------------------- /lm-eval-harness/requirements.txt: -------------------------------------------------------------------------------- 1 | -e . 2 | -------------------------------------------------------------------------------- /lm-eval-harness/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | 4 | # This is to make sure that the package supports editable installs 5 | setuptools.setup() 6 | -------------------------------------------------------------------------------- /lm-eval-harness/templates/new_yaml_task/README.md: -------------------------------------------------------------------------------- 1 | # Task-name 2 | 3 | ### Paper 4 | 5 | Title: `paper titles goes here` 6 | 7 | Abstract: `link to paper PDF or arXiv abstract goes here` 8 | 9 | `Short description of paper / benchmark goes here:` 10 | 11 | Homepage: `homepage to the benchmark's website goes here, if applicable` 12 | 13 | 14 | ### Citation 15 | 16 | ``` 17 | BibTeX-formatted citation goes here 18 | ``` 19 | 20 | ### Groups and Tasks 21 | 22 | #### Groups 23 | 24 | * `group_name`: `Short description` 25 | 26 | #### Tasks 27 | 28 | * `task_name`: `1-sentence description of what this particular task does` 29 | * `task_name2`: ... 30 | 31 | ### Checklist 32 | 33 | For adding novel benchmarks/datasets to the library: 34 | * [ ] Is the task an existing benchmark in the literature? 35 | * [ ] Have you referenced the original paper that introduced the task? 36 | * [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test? 37 | 38 | 39 | If other tasks on this dataset are already supported: 40 | * [ ] Is the "Main" variant of this task clearly denoted? 41 | * [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates? 42 | * [ ] Have you noted which, if any, published evaluation setups are matched by this variant? 
43 | -------------------------------------------------------------------------------- /mom/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from mom.layers import MomGatedDeltaNet, MomLinearAttention, MomGatedSlotAttention, MomGatedLinearAttention 4 | from mom.models import (MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel, 5 | MomLinearAttentionForCausalLM, MomLinearAttentionModel, 6 | MomGLAForCausalLM, MomGLAModel, 7 | MomGSAForCausalLM, MomGSAModel) 8 | 9 | __all__ = [ 10 | 'MomGatedDeltaNet', 11 | 'MomGatedLinearAttention', 12 | 'MomGatedSlotAttention', 13 | 'MomLinearAttention', 14 | 'MomGatedDeltaNetForCausalLM', 15 | 'MomGatedDeltaNetModel', 16 | 'MomGLAForCausalLM', 17 | 'MomGLAModel', 18 | 'MomGSAForCausalLM', 19 | 'MomGSAModel', 20 | 'MomLinearAttentionForCausalLM', 21 | 'MomLinearAttentionModel', 22 | ] 23 | 24 | __version__ = '0.1' 25 | -------------------------------------------------------------------------------- /mom/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .mom_gated_deltanet import MomGatedDeltaNet 4 | from .mom_gla import MomGatedLinearAttention 5 | from .mom_gsa import MomGatedSlotAttention 6 | from .mom_linear_attn import MomLinearAttention 7 | 8 | __all__ = [ 9 | 'MomGatedDeltaNet', 10 | 'MomGatedLinearAttention', 11 | 'MomGatedSlotAttention', 12 | 'MomLinearAttention', 13 | ] 14 | -------------------------------------------------------------------------------- /mom/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from mom.models.mom_gla import MomGLAConfig, MomGLAForCausalLM, MomGLAModel 4 | from mom.models.mom_gsa import MomGSAConfig, MomGSAForCausalLM, MomGSAModel 5 | from mom.models.mom_linear_attn import (MomLinearAttentionConfig, 6 | MomLinearAttentionForCausalLM, 7 | MomLinearAttentionModel) 8 | 9 | from mom.models.mom_gated_deltanet import MomGatedDeltaNetConfig, MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel 10 | 11 | __all__ = [ 12 | 'MomGLAConfig', 'MomGLAForCausalLM', 'MomGLAModel', 13 | 'MomGSAConfig', 'MomGSAForCausalLM', 'MomGSAModel', 14 | 'MomLinearAttentionConfig', 'MomLinearAttentionForCausalLM', 'MomLinearAttentionModel', 15 | 'MomGatedDeltaNetConfig', 'MomGatedDeltaNetForCausalLM', 'MomGatedDeltaNetModel' 16 | ] 17 | -------------------------------------------------------------------------------- /mom/models/mom_gated_deltanet/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from mom.models.mom_gated_deltanet.configuration_mom_gated_deltanet import \ 6 | MomGatedDeltaNetConfig 7 | from mom.models.mom_gated_deltanet.modeling_mom_gated_deltanet import ( 8 | MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel) 9 | 10 | AutoConfig.register(MomGatedDeltaNetConfig.model_type, MomGatedDeltaNetConfig) 11 | AutoModel.register(MomGatedDeltaNetConfig, MomGatedDeltaNetModel) 12 | AutoModelForCausalLM.register(MomGatedDeltaNetConfig, MomGatedDeltaNetForCausalLM) 13 | 14 | __all__ = ['MomGatedDeltaNetConfig', 'MomGatedDeltaNetForCausalLM', 'MomGatedDeltaNetModel'] 15 | -------------------------------------------------------------------------------- /mom/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class MomGatedDeltaNetConfig(PretrainedConfig): 9 | model_type = 'mom_gated_deltanet' 10 | keys_to_ignore_at_inference = ['past_key_values'] 11 | def __init__( 12 | self, 13 | attn_mode: str = "chunk", 14 | hidden_size: int = 2048, 15 | expand_v: int = 2, 16 | use_gate: bool = True, 17 | use_short_conv: bool = True, 18 | conv_size: int = 4, 19 | head_dim: int = 256, 20 | num_heads: int = 6, 21 | max_position_embeddings: int = 2048, 22 | hidden_ratio: Optional[int] = 4, 23 | intermediate_size: Optional[int] = None, 24 | hidden_act: str = "swish", 25 | num_hidden_layers: int = 21, 26 | norm_first: bool = False, 27 | norm_eps: float = 1e-6, 28 | attn: Optional[Dict] = None, 29 | use_cache: bool = True, 30 | pad_token_id: int = None, 31 | bos_token_id: int = 1, 32 | eos_token_id: int = 2, 33 | tie_word_embeddings: bool = False, 34 | initializer_range: float = 0.02, 35 | fuse_cross_entropy: bool = True, 36 | vocab_size: int = 32000, 37 | num_memories: int = 8, 38 | topk: int = 2, 39 | capacity: float = 1.0, 40 | use_layer_wise_balance: bool=True, 41 | aux_loss_scale: float=0.01, 42 | shared_mem: bool = False, 43 | single_kv_proj: bool = False, 44 | **kwargs 45 | ): 46 | self.attn_mode = attn_mode 47 | self.hidden_size = hidden_size 48 | self.expand_v = expand_v 49 | self.use_gate = use_gate 50 | self.use_short_conv = use_short_conv 51 | self.conv_size = conv_size 52 | self.head_dim = head_dim 53 | self.num_heads = num_heads 54 | self.max_position_embeddings = max_position_embeddings 55 | 56 | self.hidden_ratio = hidden_ratio 57 | self.intermediate_size = intermediate_size 58 | self.hidden_act = hidden_act 59 | self.num_hidden_layers = num_hidden_layers 60 | self.norm_first = norm_first 61 | self.norm_eps = norm_eps 62 | self.attn = attn 63 | self.use_cache = use_cache 64 | self.initializer_range = initializer_range 65 | self.fuse_cross_entropy = fuse_cross_entropy 66 | self.vocab_size = vocab_size 67 | self.num_memories = num_memories 68 | self.topk = topk 69 | self.capacity = capacity 70 | self.use_layer_wise_balance = use_layer_wise_balance 71 | self.aux_loss_scale = aux_loss_scale 72 | self.shared_mem = shared_mem 73 | self.single_kv_proj = single_kv_proj 74 | 75 | if attn is not None: 76 | if not isinstance(attn, Dict): 77 | raise ValueError("attn must be a dictionary") 78 | if 'layers' not in attn: 79 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers") 80 | if 'num_heads' not in attn: 81 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers") 82 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) 83 | attn['window_size'] = attn.get('window_size', None) 84 | 85 | super().__init__( 86 | pad_token_id=pad_token_id, 87 | bos_token_id=bos_token_id, 88 | eos_token_id=eos_token_id, 89 | tie_word_embeddings=tie_word_embeddings, 90 | **kwargs, 91 | ) -------------------------------------------------------------------------------- /mom/models/mom_gla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from mom.models.mom_gla.configuration_mom_gla import MomGLAConfig 6 | from mom.models.mom_gla.modeling_mom_gla import 
MomGLAForCausalLM, MomGLAModel 7 | 8 | AutoConfig.register(MomGLAConfig.model_type, MomGLAConfig) 9 | AutoModel.register(MomGLAConfig, MomGLAModel) 10 | AutoModelForCausalLM.register(MomGLAConfig, MomGLAForCausalLM) 11 | 12 | 13 | __all__ = ['MomGLAConfig', 'MomGLAForCausalLM', 'MomGLAModel'] 14 | -------------------------------------------------------------------------------- /mom/models/mom_gla/configuration_mom_gla.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class MomGLAConfig(PretrainedConfig): 9 | 10 | model_type = 'mom_gla' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | hidden_size: int = 2048, 16 | expand_k: int = 0.5, 17 | expand_v: int = 1, 18 | hidden_ratio: Optional[int] = 4, 19 | intermediate_size: Optional[int] = None, 20 | num_hidden_layers: int = 24, 21 | num_heads: int = 4, 22 | num_kv_heads: Optional[int] = None, 23 | feature_map: Optional[str] = None, 24 | attn_mode: str = "chunk", 25 | use_short_conv: bool = False, 26 | conv_size: int = 4, 27 | use_output_gate: bool = True, 28 | clamp_min: Optional[float] = None, 29 | hidden_act: str = "swish", 30 | max_position_embeddings: int = 2048, 31 | elementwise_affine: Optional[bool] = True, 32 | norm_eps: float = 1e-6, 33 | use_gk: bool = True, 34 | use_gv: bool = False, 35 | attn: Optional[Dict] = None, 36 | use_cache: bool = True, 37 | pad_token_id: int = None, 38 | bos_token_id: int = 1, 39 | eos_token_id: int = 2, 40 | tie_word_embeddings: bool = False, 41 | initializer_range: float = 0.02, 42 | fuse_norm: bool = True, 43 | fuse_cross_entropy: bool = True, 44 | vocab_size: int = 32000, 45 | num_memories: int = 8, 46 | topk: int = 2, 47 | capacity: float = 1.0, 48 | use_layer_wise_balance: bool=True, 49 | aux_loss_scale: float=0.01, 50 | shared_mem: bool = False, 51 | single_kv_proj: bool = False, 52 | **kwargs 53 | ): 54 | self.hidden_size = hidden_size 55 | self.expand_k = expand_k 56 | self.expand_v = expand_v 57 | self.hidden_ratio = hidden_ratio 58 | self.intermediate_size = intermediate_size 59 | self.num_hidden_layers = num_hidden_layers 60 | self.num_heads = num_heads 61 | self.num_kv_heads = num_kv_heads 62 | self.feature_map = feature_map 63 | self.attn_mode = attn_mode 64 | self.use_short_conv = use_short_conv 65 | self.conv_size = conv_size 66 | self.use_output_gate = use_output_gate 67 | self.clamp_min = clamp_min 68 | self.hidden_act = hidden_act 69 | self.max_position_embeddings = max_position_embeddings 70 | self.elementwise_affine = elementwise_affine 71 | self.norm_eps = norm_eps 72 | self.use_gk = use_gk 73 | self.use_gv = use_gv 74 | self.attn = attn 75 | self.use_cache = use_cache 76 | self.initializer_range = initializer_range 77 | self.fuse_norm = fuse_norm 78 | self.fuse_cross_entropy = fuse_cross_entropy 79 | self.vocab_size = vocab_size 80 | self.num_memories = num_memories 81 | self.topk = topk 82 | self.capacity = capacity 83 | self.use_layer_wise_balance = use_layer_wise_balance 84 | self.aux_loss_scale = aux_loss_scale 85 | self.shared_mem = shared_mem 86 | self.single_kv_proj = single_kv_proj 87 | 88 | if attn is not None: 89 | if not isinstance(attn, Dict): 90 | raise ValueError("attn must be a dictionary") 91 | if 'layers' not in attn: 92 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers") 93 | if 'num_heads' not in 
attn: 94 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers") 95 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) 96 | attn['window_size'] = attn.get('window_size', None) 97 | 98 | super().__init__( 99 | pad_token_id=pad_token_id, 100 | bos_token_id=bos_token_id, 101 | eos_token_id=eos_token_id, 102 | tie_word_embeddings=tie_word_embeddings, 103 | **kwargs, 104 | ) 105 | -------------------------------------------------------------------------------- /mom/models/mom_gsa/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from mom.models.mom_gsa.configuration_mom_gsa import MomGSAConfig 6 | from mom.models.mom_gsa.modeling_mom_gsa import MomGSAForCausalLM, MomGSAModel 7 | 8 | AutoConfig.register(MomGSAConfig.model_type, MomGSAConfig) 9 | AutoModel.register(MomGSAConfig, MomGSAModel) 10 | AutoModelForCausalLM.register(MomGSAConfig, MomGSAForCausalLM) 11 | 12 | 13 | __all__ = ['MomGSAConfig', 'MomGSAForCausalLM', 'MomGSAModel'] 14 | -------------------------------------------------------------------------------- /mom/models/mom_gsa/configuration_mom_gsa.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class MomGSAConfig(PretrainedConfig): 9 | 10 | model_type = 'mom_gsa' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | hidden_size: int = 2048, 16 | gate_logit_normalizer: Optional[int] = 8, 17 | clamp_min: Optional[float] = None, 18 | clamp_max: Optional[float] = None, 19 | hidden_ratio: Optional[int] = 4, 20 | intermediate_size: Optional[int] = None, 21 | num_hidden_layers: int = 24, 22 | num_heads: int = 4, 23 | num_kv_heads: Optional[int] = None, 24 | num_slots: Optional[int] = 64, 25 | use_short_conv: bool = False, 26 | conv_size: int = 4, 27 | exapnd_k: float = 1, 28 | exapnd_v: float = 1, 29 | feature_map: str = 'swish', 30 | use_output_gate: bool = False, 31 | use_norm: bool = True, 32 | max_position_embeddings: int = 2048, 33 | hidden_act: str = "swish", 34 | elementwise_affine: Optional[bool] = True, 35 | norm_first: bool = True, 36 | norm_eps: float = 1e-6, 37 | attn: Optional[Dict] = None, 38 | use_cache: bool = True, 39 | pad_token_id: int = None, 40 | bos_token_id: int = 1, 41 | eos_token_id: int = 2, 42 | initializer_range: float = 0.02, 43 | tie_word_embeddings: bool = False, 44 | fuse_norm: bool = True, 45 | fuse_cross_entropy: bool = True, 46 | vocab_size: int = 32000, 47 | num_experts: int = 8, 48 | topk: int = 2, 49 | capacity: float = 1.0, 50 | use_layer_wise_balance: bool=True, 51 | aux_loss_scale: float=0.01, 52 | shared_mem: bool = False, 53 | **kwargs 54 | ): 55 | self.hidden_size = hidden_size 56 | self.gate_logit_normalizer = gate_logit_normalizer 57 | self.clamp_min = clamp_min 58 | self.clamp_max = clamp_max 59 | self.hidden_ratio = hidden_ratio 60 | self.intermediate_size = intermediate_size 61 | self.num_hidden_layers = num_hidden_layers 62 | self.num_heads = num_heads 63 | self.num_kv_heads = num_kv_heads 64 | self.num_slots = num_slots 65 | self.use_short_conv = use_short_conv 66 | self.conv_size = conv_size 67 | self.expand_k = exapnd_k 68 | self.expand_v = exapnd_v 69 | self.feature_map = feature_map 70 | 
self.use_output_gate = use_output_gate 71 | self.use_norm = use_norm 72 | self.max_position_embeddings = max_position_embeddings 73 | self.hidden_act = hidden_act 74 | self.elementwise_affine = elementwise_affine 75 | self.norm_first = norm_first 76 | self.norm_eps = norm_eps 77 | self.attn = attn 78 | self.use_cache = use_cache 79 | self.initializer_range = initializer_range 80 | self.fuse_cross_entropy = fuse_cross_entropy 81 | self.fuse_norm = fuse_norm 82 | self.vocab_size = vocab_size 83 | self.num_experts = num_experts 84 | self.topk = topk 85 | self.capacity = capacity 86 | self.use_layer_wise_balance = use_layer_wise_balance 87 | self.aux_loss_scale = aux_loss_scale 88 | self.shared_mem = shared_mem 89 | 90 | if attn is not None: 91 | if not isinstance(attn, Dict): 92 | raise ValueError("attn must be a dictionary") 93 | if 'layers' not in attn: 94 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers") 95 | if 'num_heads' not in attn: 96 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers") 97 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) 98 | attn['window_size'] = attn.get('window_size', None) 99 | 100 | super().__init__( 101 | pad_token_id=pad_token_id, 102 | bos_token_id=bos_token_id, 103 | eos_token_id=eos_token_id, 104 | tie_word_embeddings=tie_word_embeddings, 105 | **kwargs, 106 | ) 107 | -------------------------------------------------------------------------------- /mom/models/mom_linear_attn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from mom.models.mom_linear_attn.configuration_mom_linear_attn import \ 6 | MomLinearAttentionConfig 7 | from mom.models.mom_linear_attn.modeling_mom_linear_attn import ( 8 | MomLinearAttentionForCausalLM, MomLinearAttentionModel) 9 | 10 | AutoConfig.register(MomLinearAttentionConfig.model_type, MomLinearAttentionConfig) 11 | AutoModel.register(MomLinearAttentionConfig, MomLinearAttentionModel) 12 | AutoModelForCausalLM.register(MomLinearAttentionConfig, MomLinearAttentionForCausalLM) 13 | 14 | __all__ = ['MomLinearAttentionConfig', 'MomLinearAttentionForCausalLM', 'MomLinearAttentionModel'] 15 | -------------------------------------------------------------------------------- /mom/models/mom_linear_attn/configuration_mom_linear_attn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class MomLinearAttentionConfig(PretrainedConfig): 9 | 10 | model_type = 'mom_linear_attn' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | attn_mode: str = "fused_chunk", 16 | hidden_size: int = 2048, 17 | expand_k: int = 1, 18 | expand_v: int = 1, 19 | hidden_ratio: Optional[int] = 4, 20 | intermediate_size: Optional[int] = None, 21 | num_hidden_layers: int = 24, 22 | num_heads: int = 4, 23 | num_kv_heads: Optional[int] = None, 24 | feature_map: str = "elementwise_product", 25 | tie_feature_map_qk: bool = False, 26 | norm_q: bool = False, 27 | norm_k: bool = False, 28 | norm_feature_map: bool = False, 29 | hidden_act: str = "swish", 30 | max_position_embeddings: int = 2048, 31 | elementwise_affine: Optional[bool] = True, 32 | norm_eps: float = 1e-6, 33 | attn: Optional[Dict] = 
None, 34 | use_cache: bool = True, 35 | pad_token_id: int = None, 36 | bos_token_id: int = 1, 37 | eos_token_id: int = 2, 38 | tie_word_embeddings: bool = False, 39 | initializer_range: float = 0.02, 40 | fuse_cross_entropy: bool = True, 41 | vocab_size: int = 32000, 42 | num_memories: int = 8, 43 | topk: int = 2, 44 | capacity: float = 1.0, 45 | use_layer_wise_balance: bool=True, 46 | aux_loss_scale: float=0.01, 47 | shared_mem: bool = False, 48 | **kwargs 49 | ): 50 | self.attn_mode = attn_mode 51 | self.hidden_size = hidden_size 52 | self.expand_k = expand_k 53 | self.expand_v = expand_v 54 | self.hidden_ratio = hidden_ratio 55 | self.intermediate_size = intermediate_size 56 | self.num_hidden_layers = num_hidden_layers 57 | self.num_heads = num_heads 58 | self.num_kv_heads = num_kv_heads 59 | self.feature_map = feature_map 60 | self.tie_feature_map_qk = tie_feature_map_qk 61 | self.norm_q = norm_q 62 | self.norm_k = norm_k 63 | self.norm_feature_map = norm_feature_map 64 | self.hidden_act = hidden_act 65 | self.max_position_embeddings = max_position_embeddings 66 | self.elementwise_affine = elementwise_affine 67 | self.norm_eps = norm_eps 68 | self.attn = attn 69 | self.use_cache = use_cache 70 | self.initializer_range = initializer_range 71 | self.fuse_cross_entropy = fuse_cross_entropy 72 | self.vocab_size = vocab_size 73 | self.num_memories = num_memories 74 | self.topk = topk 75 | self.capacity = capacity 76 | self.use_layer_wise_balance = use_layer_wise_balance 77 | self.aux_loss_scale = aux_loss_scale 78 | self.shared_mem = shared_mem 79 | 80 | if attn is not None: 81 | if not isinstance(attn, Dict): 82 | raise ValueError("attn must be a dictionary") 83 | if 'layers' not in attn: 84 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers") 85 | if 'num_heads' not in attn: 86 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers") 87 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) 88 | attn['window_size'] = attn.get('window_size', None) 89 | 90 | super().__init__( 91 | pad_token_id=pad_token_id, 92 | bos_token_id=bos_token_id, 93 | eos_token_id=eos_token_id, 94 | tie_word_embeddings=tie_word_embeddings, 95 | **kwargs, 96 | ) 97 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import os 3 | import re 4 | from pathlib import Path 5 | 6 | from setuptools import find_packages, setup 7 | 8 | setup( 9 | name='mom', 10 | version='0.1.0', 11 | description='MoM: Mixture of Memories', 12 | author='Jusen Du', 13 | author_email='dujusen@gmail.com', 14 | url='https://github.com/OpenSparseLLMs/MoM', 15 | long_description=open('README.md').read(), 16 | long_description_content_type='text/markdown', 17 | packages=find_packages(), 18 | classifiers=[ 19 | 'Programming Language :: Python :: 3', 20 | 'License :: OSI Approved :: Apache License', 21 | 'Operating System :: OS Independent', 22 | 'Topic :: Scientific/Engineering :: Artificial Intelligence' 23 | ], 24 | python_requires='>=3.7', 25 | install_requires=[ 26 | 'torch>=2.3', 27 | 'transformers>=4.45.0', 28 | 'triton>=3.0', 29 | 'datasets>=3.1.0', 30 | 'einops', 31 | 'ninja', 32 | 'fla @ git+https://github.com/fla-org/flash-linear-attention' 33 | ], 34 | ) 35 | -------------------------------------------------------------------------------- /training/configs/mom_1.3B.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "attn_mode": "chunk", 3 | "bos_token_id": 1, 4 | "expand_v": 1, 5 | "fuse_cross_entropy": true, 6 | "hidden_act": "swish", 7 | "hidden_ratio": 4, 8 | "hidden_size": 2048, 9 | "initializer_range": 0.02, 10 | "intermediate_size": null, 11 | "model_type": "mom_gated_deltanet", 12 | "num_heads": 4, 13 | "head_dim": 256, 14 | "num_hidden_layers": 24, 15 | "norm_eps": 1e-06, 16 | "tie_word_embeddings": true, 17 | "use_cache": true, 18 | "vocab_size": 32000, 19 | "use_short_conv": true, 20 | "num_memories": 4, 21 | "topk": 2, 22 | "capacity": 1.0, 23 | "use_layer_wise_balance": true, 24 | "aux_loss_scale": 0.01, 25 | "shared_mem": true, 26 | "single_kv_proj": false 27 | } -------------------------------------------------------------------------------- /training/configs/mom_340M.json: -------------------------------------------------------------------------------- 1 | { 2 | "attn_mode": "chunk", 3 | "bos_token_id": 1, 4 | "expand_v": 1, 5 | "fuse_cross_entropy": true, 6 | "hidden_act": "swish", 7 | "hidden_ratio": 3, 8 | "hidden_size": 1024, 9 | "initializer_range": 0.02, 10 | "intermediate_size": null, 11 | "model_type": "mom_gated_deltanet", 12 | "num_heads": 4, 13 | "head_dim": 256, 14 | "num_hidden_layers": 24, 15 | "norm_eps": 1e-06, 16 | "tie_word_embeddings": true, 17 | "use_cache": true, 18 | "vocab_size": 32000, 19 | "use_short_conv": true, 20 | "num_memories": 4, 21 | "topk": 2, 22 | "capacity": 1.0, 23 | "use_layer_wise_balance": true, 24 | "aux_loss_scale": 0.01, 25 | "shared_mem": true, 26 | "single_kv_proj": false 27 | } -------------------------------------------------------------------------------- /training/flame/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSparseLLMs/MoM/2c3ae59f5c0b749f916189433b8dad5a3415dc08/training/flame/__init__.py -------------------------------------------------------------------------------- /training/flame/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import json 4 | import logging 5 | import os 6 | import sys 7 | import time 8 | 9 | from transformers.trainer_callback import (ExportableState, TrainerCallback, 10 | TrainerControl, TrainerState) 11 | from transformers.training_args import TrainingArguments 12 | 13 | 14 | def get_logger(name: str = None) -> logging.Logger: 15 | formatter = logging.Formatter( 16 | fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S" 17 | ) 18 | handler = logging.StreamHandler(sys.stdout) 19 | handler.setFormatter(formatter) 20 | 21 | logger = logging.getLogger(name) 22 | if 'RANK' in os.environ and int(os.environ['RANK']) == 0: 23 | logger.setLevel(logging.INFO) 24 | logger.addHandler(handler) 25 | 26 | return logger 27 | 28 | 29 | logger = get_logger(__name__) 30 | 31 | LOG_FILE_NAME = "trainer_log.jsonl" 32 | 33 | 34 | class LogCallback(TrainerCallback, ExportableState): 35 | def __init__(self, start_time: float = None, elapsed_time: float = None): 36 | 37 | self.start_time = time.time() if start_time is None else start_time 38 | self.elapsed_time = 0 if elapsed_time is None else elapsed_time 39 | self.last_time = self.start_time 40 | 41 | def on_train_begin( 42 | self, 43 | args: TrainingArguments, 44 | state: TrainerState, 45 | control: TrainerControl, 46 | **kwargs 47 | ): 48 | r""" 49 | Event called at the beginning of 
training. 50 | """ 51 | if state.is_local_process_zero: 52 | if not args.resume_from_checkpoint: 53 | self.start_time = time.time() 54 | self.elapsed_time = 0 55 | else: 56 | self.start_time = state.stateful_callbacks['LogCallback']['start_time'] 57 | self.elapsed_time = state.stateful_callbacks['LogCallback']['elapsed_time'] 58 | 59 | if args.save_on_each_node: 60 | if not state.is_local_process_zero: 61 | return 62 | else: 63 | if not state.is_world_process_zero: 64 | return 65 | 66 | self.last_time = time.time() 67 | if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir: 68 | logger.warning("Previous log file in this folder will be deleted.") 69 | os.remove(os.path.join(args.output_dir, LOG_FILE_NAME)) 70 | 71 | def on_log( 72 | self, 73 | args: TrainingArguments, 74 | state: TrainerState, 75 | control: TrainerControl, 76 | logs, 77 | **kwargs 78 | ): 79 | if args.save_on_each_node: 80 | if not state.is_local_process_zero: 81 | return 82 | else: 83 | if not state.is_world_process_zero: 84 | return 85 | 86 | self.elapsed_time += time.time() - self.last_time 87 | self.last_time = time.time() 88 | if 'num_input_tokens_seen' in logs: 89 | logs['num_tokens'] = logs.pop('num_input_tokens_seen') 90 | state.log_history[-1].pop('num_input_tokens_seen') 91 | throughput = logs['num_tokens'] / args.world_size / self.elapsed_time 92 | state.log_history[-1]['throughput'] = logs['throughput'] = throughput 93 | state.stateful_callbacks["LogCallback"] = self.state() 94 | 95 | logs = dict( 96 | current_steps=state.global_step, 97 | total_steps=state.max_steps, 98 | loss=state.log_history[-1].get("loss", None), 99 | eval_loss=state.log_history[-1].get("eval_loss", None), 100 | predict_loss=state.log_history[-1].get("predict_loss", None), 101 | learning_rate=state.log_history[-1].get("learning_rate", None), 102 | epoch=state.log_history[-1].get("epoch", None), 103 | percentage=round(state.global_step / state.max_steps * 100, 2) if state.max_steps != 0 else 100, 104 | ) 105 | 106 | os.makedirs(args.output_dir, exist_ok=True) 107 | with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f: 108 | f.write(json.dumps(logs) + "\n") 109 | 110 | def state(self) -> dict: 111 | return { 112 | 'start_time': self.start_time, 113 | 'elapsed_time': self.elapsed_time 114 | } 115 | 116 | @classmethod 117 | def from_state(cls, state): 118 | return cls(state['start_time'], state['elapsed_time']) 119 | -------------------------------------------------------------------------------- /training/flame/parser.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | from dataclasses import dataclass, field 6 | from typing import Optional 7 | 8 | import transformers 9 | from transformers import HfArgumentParser, TrainingArguments 10 | 11 | from flame.logging import get_logger 12 | 13 | logger = get_logger(__name__) 14 | 15 | 16 | @dataclass 17 | class TrainingArguments(TrainingArguments): 18 | 19 | model_name_or_path: str = field( 20 | default=None, 21 | metadata={ 22 | "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models." 
23 | }, 24 | ) 25 | tokenizer: str = field( 26 | default="fla-hub/gla-1.3B-100B", 27 | metadata={"help": "Name of the tokenizer to use."} 28 | ) 29 | use_fast_tokenizer: bool = field( 30 | default=False, 31 | metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."}, 32 | ) 33 | from_config: bool = field( 34 | default=True, 35 | metadata={"help": "Whether to initialize models from scratch."}, 36 | ) 37 | dataset: Optional[str] = field( 38 | default=None, 39 | metadata={"help": "The dataset(s) to use. Use commas to separate multiple datasets."}, 40 | ) 41 | dataset_name: Optional[str] = field( 42 | default=None, 43 | metadata={"help": "The name of provided dataset(s) to use."}, 44 | ) 45 | cache_dir: str = field( 46 | default=None, 47 | metadata={"help": "Path to the cached tokenized dataset."}, 48 | ) 49 | split: str = field( 50 | default="train", 51 | metadata={"help": "Which dataset split to use for training and evaluation."}, 52 | ) 53 | streaming: bool = field( 54 | default=False, 55 | metadata={"help": "Enable dataset streaming."}, 56 | ) 57 | hf_hub_token: Optional[str] = field( 58 | default=None, 59 | metadata={"help": "Auth token to log in with Hugging Face Hub."}, 60 | ) 61 | preprocessing_num_workers: Optional[int] = field( 62 | default=None, 63 | metadata={"help": "The number of processes to use for the pre-processing."}, 64 | ) 65 | buffer_size: int = field( 66 | default=2048, 67 | metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}, 68 | ) 69 | context_length: int = field( 70 | default=2048, 71 | metadata={"help": "The context length of the tokenized inputs in the dataset."}, 72 | ) 73 | varlen: bool = field( 74 | default=False, 75 | metadata={"help": "Enable training with variable length inputs."}, 76 | ) 77 | 78 | 79 | def get_train_args(): 80 | parser = HfArgumentParser(TrainingArguments) 81 | args, unknown_args = parser.parse_args_into_dataclasses(return_remaining_strings=True) 82 | 83 | if unknown_args: 84 | print(parser.format_help()) 85 | print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args)) 86 | raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args)) 87 | 88 | if args.should_log: 89 | transformers.utils.logging.set_verbosity(args.get_process_log_level()) 90 | transformers.utils.logging.enable_default_handler() 91 | transformers.utils.logging.enable_explicit_format() 92 | # set seeds manually 93 | transformers.set_seed(args.seed) 94 | return args 95 | -------------------------------------------------------------------------------- /training/run.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from datasets import load_from_disk 4 | from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, 5 | Trainer, set_seed) 6 | 7 | import sys 8 | import os 9 | import torch 10 | from torch import nn 11 | import mom 12 | import fla 13 | from flame.data import DataCollatorForLanguageModeling 14 | from flame.logging import LogCallback, get_logger 15 | from flame.parser import get_train_args 16 | import wandb 17 | from torchinfo import summary 18 | 19 | logger = get_logger(__name__) 20 | 21 | 22 | def main(): 23 | # torch.autograd.set_detect_anomaly(True) 24 | args = get_train_args() 25 | logger.info(args) 26 | 27 | tokenizer = AutoTokenizer.from_pretrained( 28 | args.tokenizer, 29 | use_fast=args.use_fast_tokenizer, 30 | 
trust_remote_code=True, 31 | add_bos_token=True, 32 | add_eos_token=False 33 | ) 34 | if tokenizer.pad_token_id is None: 35 | tokenizer.pad_token = tokenizer.eos_token 36 | logger.info("Add pad token: {}".format(tokenizer.pad_token)) 37 | # args.from_config = False 38 | if args.from_config: 39 | logger.info("All model params are randomly initialized for from-scratch training.") 40 | model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(args.model_name_or_path)) 41 | else: 42 | logger.info(f"Loading pretrained checkpoint {args.model_name_or_path}") 43 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path) 44 | for name, param in model.named_parameters(): 45 | if 'gate' in name: 46 | if 'weight' in name: 47 | nn.init.xavier_normal_(param) 48 | model.train() 49 | 50 | # summary(model, depth=6) 51 | # exit(0) 52 | 53 | trainable_params, all_param = model.num_parameters(only_trainable=True), model.num_parameters() 54 | logger.info(f"% of trainable params: {trainable_params:d} / {all_param:d} = {trainable_params / all_param:.2%}") 55 | logger.info(f"{tokenizer}\n{model}\n{model.config}") 56 | 57 | logger.info(f"Loading the `{args.split}` split directly from the cache {args.cache_dir}...") 58 | dataset = load_from_disk(args.cache_dir) 59 | logger.info(f"{dataset}") 60 | logger.info(f"Shuffling the dataset with seed {args.seed}") 61 | dataset = dataset.shuffle(seed=args.seed) 62 | logger.info("Creating the data collator") 63 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, varlen=args.varlen) 64 | logger.info(f"{data_collator}") 65 | 66 | if args.lr_scheduler_type == 'cosine_with_min_lr': 67 | args.lr_scheduler_kwargs = {'min_lr_rate': 0.1} 68 | if args.lr_scheduler_type == 'warmup_stable_decay': 69 | args.lr_scheduler_kwargs = { 70 | 'num_stable_steps': args.max_steps * 0.9 - args.warmup_steps, 71 | 'num_decay_steps': args.max_steps * 0.1 72 | } 73 | 74 | args.logging_steps = 16 75 | trainer = Trainer( 76 | model=model, 77 | args=args, 78 | tokenizer=tokenizer, 79 | data_collator=data_collator, 80 | callbacks=[LogCallback()], 81 | train_dataset=dataset 82 | ) 83 | 84 | def detect_nan_hook(grad, name): 85 | if torch.isnan(grad).any(): 86 | print(f"NaN detected in gradients of {name}!") 87 | print(f"Gradient values: {grad}") 88 | exit() 89 | 90 | # Register the NaN-detection hook on every parameter 91 | for name, param in model.named_parameters(): 92 | param.register_hook(lambda grad, name=name: detect_nan_hook(grad, name)) 93 | 94 | results = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint) 95 | trainer.save_model() 96 | tokenizer.save_pretrained(trainer.args.output_dir) 97 | 98 | trainer.log_metrics("train", results.metrics) 99 | trainer.save_metrics("train", results.metrics) 100 | trainer.save_state() 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | --------------------------------------------------------------------------------
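
Editor's note — minimal usage sketch (not a file from the repository above, and untested here). The class names and config fields come from mom/models/__init__.py and configuration_mom_gated_deltanet.py; everything else is an assumption: the tiny hyperparameter values are illustrative (the published settings live in training/configs/mom_340M.json and mom_1.3B.json), the forward call follows the usual Hugging Face causal-LM convention, and the chunked Triton kernels pulled in via the fla dependency are assumed to require a CUDA device.

# Hedged sketch: build a toy MoM Gated-DeltaNet model from the registered
# config class and run a single forward pass as a smoke test.
import torch

from mom.models import MomGatedDeltaNetConfig, MomGatedDeltaNetForCausalLM

# Toy sizes chosen only to keep the example small; not the trained configuration.
config = MomGatedDeltaNetConfig(
    hidden_size=256,
    num_hidden_layers=2,
    num_heads=2,
    head_dim=64,
    num_memories=4,   # number of memory states
    topk=2,           # memories routed to per token
    shared_mem=True,  # keep the always-active shared memory
    vocab_size=32000,
)

device = "cuda"  # assumption: the Triton kernels expect a GPU
model = MomGatedDeltaNetForCausalLM(config).to(device).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 32), device=device)
with torch.no_grad():
    outputs = model(input_ids=input_ids)  # assumed HF-style CausalLM interface
print(outputs.logits.shape)  # expected: torch.Size([1, 32, 32000])

For full training, training/run.py above is the entry point: it builds the model from one of the JSON configs via AutoConfig/AutoModelForCausalLM (registration happens on `import mom`), loads a pre-tokenized dataset from `--cache_dir`, and hands everything to the Hugging Face Trainer.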