├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── mom_fig1.png
├── benchmarks
│   ├── benchmark_generation.py
│   ├── benchmark_training_throughput.py
│   ├── modules
│   │   ├── benchmark_cross_entropy.py
│   │   └── benchmark_layernorm.py
│   └── ops
│       ├── benchmark.py
│       ├── benchmark_abc.py
│       ├── benchmark_based.py
│       ├── benchmark_delta_rule.py
│       ├── benchmark_fla.py
│       ├── benchmark_gla.py
│       ├── benchmark_gsa.py
│       ├── benchmark_hgrn.py
│       ├── benchmark_retention.py
│       ├── benchmark_rwkv6.py
│       └── benchmark_simple_gla_vs_mamba2.py
├── evals
│   ├── harness.py
│   └── ppl.py
├── lm-eval-harness
│   ├── CITATION.bib
│   ├── CODEOWNERS
│   ├── LICENSE.md
│   ├── README.md
│   ├── docs
│   │   ├── CONTRIBUTING.md
│   │   ├── README.md
│   │   ├── decontamination.md
│   │   ├── img
│   │   │   └── fewshot_example_gpt3.png
│   │   ├── interface.md
│   │   ├── model_guide.md
│   │   ├── new_task_guide.md
│   │   └── task_guide.md
│   ├── ignore.txt
│   ├── launch.py
│   ├── launch_hf.py
│   ├── launch_jrt.py
│   ├── launch_local.py
│   ├── lm_eval
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   ├── api
│   │   │   ├── __init__.py
│   │   │   ├── filter.py
│   │   │   ├── instance.py
│   │   │   ├── metrics.py
│   │   │   ├── model.py
│   │   │   ├── registry.py
│   │   │   ├── samplers.py
│   │   │   └── task.py
│   │   ├── decontamination
│   │   │   ├── __init__.py
│   │   │   ├── archiver.py
│   │   │   ├── decontaminate.py
│   │   │   └── janitor.py
│   │   ├── evaluator.py
│   │   ├── filters
│   │   │   ├── __init__.py
│   │   │   ├── decontamination.py
│   │   │   ├── extraction.py
│   │   │   ├── selection.py
│   │   │   └── transformation.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── anthropic_llms.py
│   │   │   ├── based_lm.py
│   │   │   ├── dummy.py
│   │   │   ├── gguf.py
│   │   │   ├── huggingface.py
│   │   │   ├── jrt_lm.py
│   │   │   ├── local_lm.py
│   │   │   ├── local_utils
│   │   │   │   ├── jrt_utils.py
│   │   │   │   └── loading.py
│   │   │   ├── mamba_lm.py
│   │   │   ├── neuron_optimum.py
│   │   │   ├── openai_completions.py
│   │   │   ├── optimum_lm.py
│   │   │   ├── textsynth.py
│   │   │   ├── utils.py
│   │   │   └── vllm_causallms.py
│   │   ├── prompts
│   │   │   └── __init__.py
│   │   ├── tasks
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── based_drop
│   │   │   │   └── task.py
│   │   │   ├── based_fda
│   │   │   │   ├── README.md
│   │   │   │   └── task.py
│   │   │   ├── based_nq
│   │   │   │   └── task.py
│   │   │   ├── based_squadv2
│   │   │   │   ├── README.md
│   │   │   │   └── task.py
│   │   │   ├── based_swde
│   │   │   │   ├── README.md
│   │   │   │   └── task.py
│   │   │   ├── based_triviaqa
│   │   │   │   └── task.py
│   │   │   ├── scrolls
│   │   │   │   ├── README.md
│   │   │   │   └── task.py
│   │   │   └── super_glue
│   │   │       ├── README.md
│   │   │       ├── cb
│   │   │       │   ├── aggregate.py
│   │   │       │   └── t5_utils.py
│   │   │       ├── copa
│   │   │       │   └── utils.py
│   │   │       ├── multirc
│   │   │       │   └── t5_utils.py
│   │   │       ├── record
│   │   │       │   ├── t5_utils.py
│   │   │       │   └── util.py
│   │   │       └── wsc
│   │   │           ├── preprocess_wsc.py
│   │   │           └── t5_utils.py
│   │   └── utils.py
│   ├── prompt_scripts
│   │   ├── collect_results.py
│   │   ├── run_jrt_prompt_hazy.sh
│   │   └── run_jrt_prompt_hf.sh
│   ├── pyproject.toml
│   ├── requirements.txt
│   ├── setup.py
│   └── templates
│       └── new_yaml_task
│           └── README.md
├── mom
│   ├── __init__.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── mom_gated_deltanet.py
│   │   ├── mom_gla.py
│   │   ├── mom_gsa.py
│   │   └── mom_linear_attn.py
│   └── models
│       ├── __init__.py
│       ├── mom_gated_deltanet
│       │   ├── __init__.py
│       │   ├── configuration_mom_gated_deltanet.py
│       │   └── modeling_mom_gated_deltanet.py
│       ├── mom_gla
│       │   ├── __init__.py
│       │   ├── configuration_mom_gla.py
│       │   └── modeling_mom_gla.py
│       ├── mom_gsa
│       │   ├── __init__.py
│       │   ├── configuration_mom_gsa.py
│       │   └── modeling_mom_gsa.py
│       └── mom_linear_attn
│           ├── __init__.py
│           ├── configuration_mom_linear_attn.py
│           └── modeling_mom_linear_attn.py
├── setup.py
└── training
    ├── README.md
    ├── configs
    │   ├── mom_1.3B.json
    │   └── mom_340M.json
    ├── flame
    │   ├── __init__.py
    │   ├── data.py
    │   ├── logging.py
    │   └── parser.py
    ├── preprocess.py
    ├── run.py
    └── train.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # test file
2 | test.py
3 | mingpt
4 |
5 | # data files
6 | data
7 |
8 | # # bash scripts
9 | # *.sh
10 |
11 | # docs
12 | docs/_build
13 |
14 | # intermediate files
15 | build
16 | dist
17 | *.egg-info
18 |
19 | # experimental results
20 | exp
21 | fineweb
22 | SlimPajama
23 | slimpajama
24 | results
25 | wandb
26 | *.csv
27 | *.html
28 | eval_results
29 |
30 | # log and config files
31 | log.*
32 | *.log
33 | *.cfg
34 | *.ini
35 | *.yml
36 | *.yaml
37 |
38 | # pycache
39 | __pycache__
40 |
41 | # saved model
42 | *.pkl
43 | *.pt
44 |
45 | # hidden files
46 | .*
47 |
48 | # vscode
49 | .vscode
50 |
51 | # macOS
52 | .DS_Store
53 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # MoM: Mixture-of-Memories
4 | [arXiv](https://arxiv.org/abs/2502.13685)
5 | [Hugging Face](https://huggingface.co/linear-moe-hub)
6 | [Zhihu](https://zhuanlan.zhihu.com/p/25066090353)
7 | [GitHub Stars](https://github.com/OpenSparseLLMs/MoM/stargazers)
8 |
9 |
10 |
11 | Welcome to MoM! This repository provides the implementation of [MoM: Linear Sequence Modeling with Mixture-of-Memories](https://arxiv.org/abs/2502.13685) within the Hugging Face ecosystem. MoM is compatible with all kinds of linear sequence modeling methods, such as linear attention, SSMs, and linear RNNs. **An introductory article about MoM (in Chinese) is available on [Zhihu](https://zhuanlan.zhihu.com/p/25066090353).**
12 |
13 |
14 |
15 | ![MoM Architecture](assets/mom_fig1.png)
16 |
17 | MoM Architecture
18 |
19 |
20 | ## Installation
21 |
22 | The following requirements should be satisfied:
23 | - [PyTorch](https://pytorch.org/) >= 2.5
24 | - [Triton](https://github.com/openai/triton) >= 3.0
25 | - [einops](https://einops.rocks/)
26 | - [transformers](https://github.com/huggingface/transformers) >= 4.45.0
27 | - [datasets](https://github.com/huggingface/datasets) >= 3.3.0
28 | - [causal-conv1d](https://github.com/Dao-AILab/causal-conv1d) >= 1.4.0
29 |
30 | Install the package from source:
31 | ```bash
32 | pip install -e .
33 | ```
34 |
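Once installed, the MoM model and config classes are registered with the `transformers` `Auto*` APIs when the `mom` package is imported (the benchmarks follow this pattern, see [benchmarks/benchmark_generation.py](benchmarks/benchmark_generation.py)). Below is a minimal loading sketch; the checkpoint path and prompt are placeholders:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import mom  # noqa: importing registers the MoM model/config classes with transformers

path = "path/to/mom-checkpoint"  # placeholder: a local checkpoint or a Hub repo id
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).cuda().eval()

inputs = tokenizer("Mixture-of-Memories is", return_tensors="pt").to("cuda")
with torch.inference_mode():
    outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
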
35 | ## Getting Started
36 |
37 | ### Data Preparation
38 | Before training, make sure to preprocess your data by following the steps outlined in [training/README.md](training/README.md).
39 |
40 | ### Training From Scratch
41 |
42 | To start training with default setup, simply run:
43 | ```bash
44 | cd training
45 |
46 | bash train.sh \
47 | nodes=4 \
48 | gpus=8 \
49 | type=mom \
50 | lr=3e-4 \
51 | steps=30720 \
52 | batch=8 \
53 | update=1 \
54 | warmup=1024 \
55 | context=2048 \
56 | path=SlimPajama/mom-15B \
57 | project=SlimPajama \
58 | model=configs/mom_340M.json \
59 | tokenizer=fla-hub/gla-1.3B-100B \
60 | data=SlimPajama-627B \
61 | cache=data/chunk1/train
62 | ```
63 |
64 | You can also:
65 | - Modify the script to adjust the modeling and training settings.
66 |   - e.g., modify [training/configs/mom_340M.json](training/configs/mom_340M.json) to adjust the MoM model structure (see the sketch below).
67 |
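If you prefer to tweak the model structure programmatically rather than by hand, the JSON config can be edited with a short script before launching. A minimal sketch, run from the `training` directory; the field name below is illustrative, check [training/configs/mom_340M.json](training/configs/mom_340M.json) for the actual keys:

```python
import json

# Load the 340M config, override a field, and write out a variant.
# "num_hidden_layers" is an illustrative key; use the keys actually present in the file.
with open("configs/mom_340M.json") as f:
    config = json.load(f)

config["num_hidden_layers"] = 12  # hypothetical override
with open("configs/mom_340M_small.json", "w") as f:
    json.dump(config, f, indent=2)
print("Wrote configs/mom_340M_small.json")
```

The new file can then be passed via `model=configs/mom_340M_small.json` in the `train.sh` call above.
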
68 | ### Evaluation
69 |
70 | To evaluate model checkpoints on **commonsense reasoning benchmarks**, we recommend running:
71 | ```bash
72 | MODEL_PATH=training/SlimPajama/mom-15B/checkpoint-30720
73 |
74 | accelerate launch --multi_gpu evals/harness.py --model hf \
75 | --model_args pretrained=$MODEL_PATH,dtype=bfloat16 \
76 | --tasks arc_easy,arc_challenge,hellaswag,lambada_standard,piqa,winogrande,wikitext \
77 | --output_path eval_results \
78 | --batch_size 32 \
79 | --device cuda
80 | ```
81 |
82 | To evaluate model checkpoints on **recall-intensive tasks**, we recommend the following:
83 | 1. Install `lm_eval`:
84 | ```bash
85 | cd lm-eval-harness
86 | pip install -e .
87 | ```
88 | 2. Run the script:
89 | ```bash
90 | MODEL_PATH=../training/SlimPajama/mom-15B/checkpoint-30720
91 |
92 | CUDA_VISIBLE_DEVICES=0,1,2,3,4 python launch_local.py \
93 | --batch-size 32 \
94 | -t based_squad \
95 | -t based_swde \
96 | -t based_fda \
97 | -t based_drop \
98 | -t based_triviaqa \
99 | -t based_nq_2048 \
100 | -m $MODEL_PATH \
101 | --context_length 2048 \
102 | --answer_length 48 \
103 | --cutting_context \
104 | --limit -1 \
105 | -p
106 | ```
107 |
108 | ## Acknowledgement
109 | This repo builds upon the open-source [flash-linear-attention](https://github.com/fla-org/flash-linear-attention), and the evaluation code is based on [prefix-linear-attention](https://github.com/HazyResearch/prefix-linear-attention). Happy experimenting! 🔥🚀🔥
110 |
111 | ## Citation
112 | If you find this repo useful, please consider citing our paper:
113 | ```bib
114 | @article{du2025mom,
115 | title={MoM: Linear Sequence Modeling with Mixture-of-Memories},
116 | author={Du, Jusen and Sun, Weigao and Lan, Disen and Hu, Jiaxi and Cheng, Yu},
117 | journal={arXiv preprint arXiv:2502.13685},
118 | year={2025}
119 | }
120 | ```
--------------------------------------------------------------------------------
/assets/mom_fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSparseLLMs/MoM/2c3ae59f5c0b749f916189433b8dad5a3415dc08/assets/mom_fig1.png
--------------------------------------------------------------------------------
/benchmarks/benchmark_generation.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) 2023-2024, Songlin Yang, Yu Zhang.
3 |
4 | import argparse
5 | import time
6 |
7 | import torch
8 | from datasets import load_dataset
9 | from transformers import AutoModelForCausalLM, AutoTokenizer
10 |
11 | import fla # noqa
12 | import mom # noqa
13 |
14 |
15 | def sizeof_fmt(num, suffix='B'):
16 | for unit in ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi'):
17 | if abs(num) < 1024.0:
18 | return f'{num:3.1f}{unit}{suffix}'
19 | num /= 1024.0
20 | return f'{num:.1f}Yi{suffix}'
21 |
22 |
23 | if __name__ == "__main__":
24 | parser = argparse.ArgumentParser(description="Generation benchmarking")
25 | parser.add_argument("--path", type=str, default="fla-hub/transformer-1.3B-100B")
26 | parser.add_argument("--data", type=str, default="fla-hub/pg19")
27 | parser.add_argument("--length", type=int, default=128)
28 | parser.add_argument("--maxlen", type=int, default=128)
29 | parser.add_argument("--no-cache", action='store_true')
30 | parser.add_argument("--temperature", type=float, default=0.5)
31 | parser.add_argument("--topp", type=float, default=0.2)
32 | parser.add_argument("--repetition_penalty", type=float, default=1.1)
33 | args = parser.parse_args()
34 |
35 | device = "cuda"
36 | dtype = torch.bfloat16
37 | torch.manual_seed(0)
38 |
39 | print(f"Loading {args.path}")
40 | tokenizer = AutoTokenizer.from_pretrained(
41 | args.path,
42 | trust_remote_code=True,
43 | add_eos_token=False
44 | )
45 | tokenizer.pad_token_id = tokenizer.eos_token_id
46 | print(f"{tokenizer}")
47 |
48 | model = AutoModelForCausalLM.from_pretrained(
49 | args.path,
50 | device_map={"": device},
51 | torch_dtype=dtype,
52 | use_cache=not args.no_cache
53 | )
54 | model.eval()
55 | print(f"{model.config}\n{model}\nNumber of parameters: {model.num_parameters()} ({sizeof_fmt(model.num_parameters())})\n")
56 |
57 | print(f"Loading {args.data}")
58 | dataset = load_dataset(args.data, split='train', trust_remote_code=True)
59 | print(f"{dataset}")
60 |
61 | prompt = dataset[0]['text']
62 | tokens = tokenizer(prompt, return_tensors="pt")
63 | input_ids = tokens.input_ids.to(device=device)[:, :args.length].contiguous()
64 | max_length = input_ids.shape[1] + args.maxlen
65 |
66 | torch.cuda.synchronize()
67 | start = time.time()
68 | with torch.inference_mode():
69 | text = model.generate(
70 | input_ids=input_ids,
71 | use_cache=not args.no_cache,
72 | max_length=max_length,
73 | pad_token_id=tokenizer.eos_token_id,
74 | eos_token_id=tokenizer.bos_token_id,
75 | do_sample=True,
76 | temperature=args.temperature,
77 | top_p=args.topp,
78 | repetition_penalty=args.repetition_penalty
79 | )
80 | torch.cuda.synchronize()
81 | elapsed = time.time() - start
82 | print(f"Prompt:\n{tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0].strip()}\n")
83 | print(f"Generated:\n{tokenizer.batch_decode(text, skip_special_tokens=True)[0].strip()}\n")
84 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(text[0]) - len(input_ids[0])}")
85 | print(f"Total prompt processing + decoding time: {elapsed * 1000:.0f}ms")
86 | print(f"Max memory used: {sizeof_fmt(torch.cuda.max_memory_allocated())}")
87 |
--------------------------------------------------------------------------------
/benchmarks/benchmark_training_throughput.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import argparse
4 | import time
5 | from typing import Optional, Tuple
6 |
7 | import torch
8 | from accelerate import Accelerator
9 | from torch.cuda import max_memory_allocated, memory_allocated
10 | from torch.optim import AdamW
11 | from tqdm import trange
12 | from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig
13 | from transformers.optimization import get_cosine_schedule_with_warmup
14 |
15 | import fla
16 | import mom
17 |
18 | classes1 = [getattr(mom.models, i) for i in mom.models.__all__]
19 | classes2 = [getattr(fla.models, i) for i in fla.models.__all__]
20 | classes = classes1 + classes2
21 | configs = {i.model_type: i() for i in classes if issubclass(i, PretrainedConfig)}
22 |
23 |
24 | def sizeof_fmt(num, suffix='B'):
25 | for unit in ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi'):
26 | if abs(num) < 1024.0:
27 | return f'{num:.2f}{unit}{suffix}'
28 | num /= 1024.0
29 | return f'{num:.2f}Yi{suffix}'
30 |
31 |
32 | def prepare_inputs(
33 | batch_size: int,
34 | seq_len: int,
35 | varlen: bool,
36 | vocab_size: int,
37 | device: torch.device
38 | ):
39 | if varlen:
40 |         tokens = torch.randint(high=vocab_size, size=(1, batch_size * seq_len), device=device)  # pack all sequences into a single row
41 |         offsets = torch.cat([  # sequence boundaries: 0, batch_size - 1 random cut points (each >= 16), and the total length; sorted below
42 | torch.tensor([0], dtype=torch.long, device=device),
43 | torch.randperm(batch_size * seq_len - 16, device=device)[:batch_size-1] + 16,
44 | torch.tensor([batch_size * seq_len], dtype=torch.long, device=device)
45 | ], 0).sort()[0]
46 | else:
47 | tokens = torch.randint(high=vocab_size, size=(batch_size, seq_len), device=device)
48 | offsets = None
49 | return tokens, offsets
50 |
51 |
52 | def profile(
53 | name: str,
54 | batch_size: int = 8,
55 | seq_len: int = 2048,
56 | varlen: bool = False,
57 | warmup_steps: int = 16,
58 | steps: int = 32,
59 | total_steps: int = 1024,
60 | lr: float = 3e-4,
61 | betas: Tuple[float] = (0.9, 0.95),
62 | weight_decay: float = 0.1,
63 | dtype: Optional[torch.dtype] = torch.bfloat16,
64 | mixed_precision: str = 'bf16'
65 | ):
66 | device = torch.device('cuda')
67 | config = configs[name] if name in configs else AutoConfig.from_pretrained(name)
68 | model = AutoModelForCausalLM.from_config(config).cuda().to(dtype)
69 | num_parameters = model.num_parameters()
70 | print(f"Initializing {name} model from the config:\n{config}\n{model}")
71 | print(f"Number of parameters in total: {num_parameters} ({sizeof_fmt(num_parameters)})")
72 | print(f"Allocated memory after initialization: {sizeof_fmt(memory_allocated(device))}")
73 |
74 | accelerator = Accelerator(mixed_precision=mixed_precision)
75 | optimizer = AdamW(
76 | model.parameters(),
77 | lr=lr,
78 | betas=betas,
79 | weight_decay=weight_decay,
80 | fused=True
81 | )
82 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, total_steps)
83 |
84 | bar = trange(warmup_steps)
85 |
86 | model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)
87 | torch.cuda.synchronize(device)
88 | for _ in bar:
89 | # forward pass
90 | tokens, offsets = prepare_inputs(
91 | batch_size=batch_size,
92 | seq_len=seq_len,
93 | varlen=varlen,
94 | vocab_size=config.vocab_size,
95 | device=device
96 | )
97 | outputs = model(tokens, labels=tokens, offsets=offsets)
98 | # backward pass
99 | accelerator.backward(outputs.loss)
100 | optimizer.step()
101 | scheduler.step()
102 | optimizer.zero_grad()
103 | bar.set_description_str(f"Max memory allocated: {sizeof_fmt(max_memory_allocated(device))}")
104 |
105 | start, total_tokens = time.time(), 0
106 | bar = trange(steps)
107 | torch.cuda.synchronize(device)
108 | for _ in bar:
109 | # forward pass
110 | tokens, offsets = prepare_inputs(
111 | batch_size=batch_size,
112 | seq_len=seq_len,
113 | varlen=varlen,
114 | vocab_size=config.vocab_size,
115 | device=device
116 | )
117 | outputs = model(tokens, labels=tokens, offsets=offsets)
118 | # backward pass
119 | accelerator.backward(outputs.loss)
120 | optimizer.step()
121 | optimizer.zero_grad()
122 |
123 | total_tokens += batch_size * seq_len
124 | torch.cuda.synchronize(device)
125 | duration = time.time() - start
126 | bar.set_description_str(f"Thoughput: {total_tokens / duration:10.2f} tokens/s")
127 |
128 |
129 | if __name__ == "__main__":
130 | parser = argparse.ArgumentParser()
131 | parser.add_argument("--name", default='retnet')
132 | parser.add_argument("--batch_size", default=8, type=int)
133 | parser.add_argument("--seq_len", default=2048, type=int)
134 | parser.add_argument("--varlen", action='store_true')
135 | parser.add_argument("--warmup_steps", default=16, type=int)
136 | parser.add_argument("--steps", default=32, type=int)
137 | args = parser.parse_args()
138 | profile(
139 | name=args.name,
140 | batch_size=args.batch_size,
141 | seq_len=args.seq_len,
142 | varlen=args.varlen,
143 | warmup_steps=args.warmup_steps,
144 | steps=args.steps
145 | )
146 |
--------------------------------------------------------------------------------
/benchmarks/modules/benchmark_cross_entropy.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import triton
7 |
8 | from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
9 |
10 |
11 | @triton.testing.perf_report(
12 | triton.testing.Benchmark(
13 | # argument names to use as an x-axis for the plot
14 | x_names=['T'],
15 | # different possible values for `x_name`
16 | x_vals=[128 * 2 ** i for i in range(0, 8)],
17 | # argument name whose value corresponds to a different line in the plot
18 | line_arg='provider',
19 | # possible values for `line_arg``
20 | line_vals=['naive', 'fused', 'fused_linear', 'naive_bwd', 'fused_bwd', 'fused_linear_bwd'],
21 | # label name for the lines
22 | line_names=['naive', 'fused', 'fused_linear', 'naive_bwd', 'fused_bwd', 'fused_linear_bwd'],
23 | # line styles
24 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'),
25 | ('cyan', ':'), ('yellow', 'dotted'), ('cyan', '--'), ('cyan', '-'), ('black', ':')],
26 | ylabel="Execution Time (ms)", # label name for the y-axis
27 | # name for the plot. Used also as a file name for saving the plot.
28 | plot_name="Performance",
29 | args={},
30 | )
31 | )
32 | def benchmark(T, provider):
33 | device = 'cuda'
34 | dtype = torch.bfloat16
35 | requires_grad = True
36 | B, H, V = 4, 4096, 120000
37 |
38 | x = torch.randn(B * T, H, device=device, requires_grad=requires_grad, dtype=dtype)
39 | target = torch.randint(0, V, (B * T,), device=device, dtype=torch.int64)
40 | w = torch.randn(V, H, device=device, requires_grad=requires_grad, dtype=dtype)
41 | b = torch.randn(V, device=device, requires_grad=requires_grad, dtype=dtype)
42 |
43 | quantiles = [0.5, 0.2, 0.8]
44 | results = 0, 0, 0
45 | if provider == 'naive':
46 | criterion = nn.CrossEntropyLoss()
47 | results = triton.testing.do_bench(lambda: criterion(F.linear(x, w, b), target), quantiles=quantiles)
48 | elif provider == 'naive_bwd':
49 | criterion = nn.CrossEntropyLoss()
50 | results = triton.testing.do_bench(lambda: criterion(F.linear(x, w, b), target).backward(), quantiles=quantiles)
51 | elif provider == 'fused':
52 | criterion = FusedCrossEntropyLoss()
53 | results = triton.testing.do_bench(lambda: criterion(F.linear(x, w, b), target), quantiles=quantiles)
54 | elif provider == 'fused_bwd':
55 | criterion = FusedCrossEntropyLoss()
56 | results = triton.testing.do_bench(lambda: criterion(F.linear(x, w, b), target).backward(), quantiles=quantiles)
57 | elif provider == 'fused_linear':
58 | criterion = FusedLinearCrossEntropyLoss()
59 | results = triton.testing.do_bench(lambda: criterion(x, target, w, b), quantiles=quantiles)
60 | elif provider == 'fused_linear_bwd':
61 | criterion = FusedLinearCrossEntropyLoss()
62 | results = triton.testing.do_bench(lambda: criterion(x, target, w, b).backward(), quantiles=quantiles)
63 | return results
64 |
65 |
66 | if __name__ == '__main__':
67 | benchmark.run(print_data=True)
68 |
--------------------------------------------------------------------------------
/benchmarks/modules/benchmark_layernorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 | import triton
6 |
7 | from fla.modules import GroupNorm, LayerNorm
8 |
9 |
10 | @triton.testing.perf_report(
11 | triton.testing.Benchmark(
12 | # argument names to use as an x-axis for the plot
13 | x_names=['T'],
14 | # different possible values for `x_name`
15 | x_vals=[128 * 2 ** i for i in range(0, 8)],
16 | # argument name whose value corresponds to a different line in the plot
17 | line_arg='provider',
18 | # possible values for `line_arg``
19 | line_vals=['naive_ln', 'fused_ln', 'naive_gn', 'fused_gn',
20 | 'naive_ln_bwd', 'fused_ln_bwd', 'naive_gn_bwd', 'fused_gn_bwd'],
21 | # label name for the lines
22 | line_names=['naive_ln', 'fused_ln', 'naive_gn', 'fused_gn',
23 | 'naive_ln_bwd', 'fused_ln_bwd', 'naive_gn_bwd', 'fused_gn_bwd'],
24 | # line styles
25 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'),
26 | ('cyan', ':'), ('yellow', 'dotted'), ('cyan', '--'), ('cyan', '-'), ('black', ':')],
27 | ylabel="Execution Time (ms)", # label name for the y-axis
28 | # name for the plot. Used also as a file name for saving the plot.
29 | plot_name="Performance",
30 | args={},
31 | )
32 | )
33 | def benchmark(T, provider):
34 | device = 'cuda'
35 | dtype = torch.bfloat16
36 | requires_grad = True
37 | B, D = 16, 1024
38 |
39 | x = torch.randn(B * T, D, device=device, requires_grad=requires_grad, dtype=dtype)
40 |
41 | quantiles = [0.5, 0.2, 0.8]
42 | results = 0, 0, 0
43 | if provider.startswith('naive_ln'):
44 | norm = nn.LayerNorm(D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype)
45 | results = triton.testing.do_bench(lambda: norm(x), quantiles=quantiles)
46 | if provider.startswith('fused_ln'):
47 | norm = LayerNorm(D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype)
48 | results = triton.testing.do_bench(lambda: norm(x), quantiles=quantiles)
49 | if provider.startswith('naive_gn'):
50 | norm = nn.GroupNorm(4, D).to(device=device, dtype=dtype)
51 | results = triton.testing.do_bench(lambda: norm(x), quantiles=quantiles)
52 | if provider.startswith('fused_gn'):
53 | norm = GroupNorm(4, D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype)
54 | results = triton.testing.do_bench(lambda: norm(x), quantiles=quantiles)
55 | if provider.startswith('naive_ln_bwd'):
56 | norm = nn.LayerNorm(D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype)
57 | results = triton.testing.do_bench(lambda: norm(x).backward(x), quantiles=quantiles)
58 | if provider.startswith('fused_ln_bwd'):
59 | norm = LayerNorm(D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype)
60 | results = triton.testing.do_bench(lambda: norm(x).backward(x), quantiles=quantiles)
61 | if provider.startswith('naive_gn_bwd'):
62 | norm = nn.GroupNorm(4, D).to(device=device, dtype=dtype)
63 | results = triton.testing.do_bench(lambda: norm(x).backward(x), quantiles=quantiles)
64 | if provider.startswith('fused_gn_bwd'):
65 | norm = GroupNorm(4, D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype)
66 | results = triton.testing.do_bench(lambda: norm(x).backward(x), quantiles=quantiles)
67 | return results
68 |
69 |
70 | if __name__ == '__main__':
71 | benchmark.run(print_data=True)
72 |
--------------------------------------------------------------------------------
/benchmarks/ops/benchmark_abc.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import triton
5 | from torch.nn import functional as F
6 |
7 | from fla.ops.abc import chunk_abc
8 | from fla.ops.gla import chunk_gla
9 | from fla.ops.retention import chunk_retention
10 |
11 | try:
12 | from flash_attn import flash_attn_func
13 | HAS_FLASH = True
14 | except BaseException:
15 | HAS_FLASH = False
16 |
17 |
18 | @triton.testing.perf_report(
19 | triton.testing.Benchmark(
20 | # argument names to use as an x-axis for the plot
21 | x_names=['T'],
22 | # different possible values for `x_name`
23 | x_vals=[128 * 2 ** i for i in range(0, 8)],
24 | # argument name whose value corresponds to a different line in the plot
25 | line_arg='provider',
26 | # possible values for `line_arg``
27 | line_vals=['abc', 'gla', 'abc_bwd', 'gla_bwd', 'retention_bwd', 'flash_bwd'],
28 | # label name for the lines
29 | line_names=['abc', 'gla', 'abc_bwd', 'gla_bwd', 'retention_bwd', 'flash_bwd'],
30 | # line styles
31 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'),
32 | ('cyan', ':'), ('yellow', 'dotted'), ('black', ':')],
33 | ylabel="Execution Time (ms)", # label name for the y-axis
34 | # name for the plot. Used also as a file name for saving the plot.
35 | plot_name="Performance",
36 | args={},
37 | )
38 | )
39 | def benchmark(T, provider):
40 | device = 'cuda'
41 | dtype = torch.bfloat16
42 | requires_grad = True
43 | B, H, D, M = 16, 4, 128, 64
44 |
45 | q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
46 | k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
47 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
48 | if provider.startswith('flash'):
49 | q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
50 | k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
51 | v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
52 | if provider.startswith('gla'):
53 | g = F.logsigmoid(torch.randn(B, H, T, D, device=device, dtype=dtype))
54 | g = g.clamp_min(-5).requires_grad_(requires_grad)
55 | if provider.startswith('abc'):
56 | s = torch.randn(B, H, T, M, device=device, requires_grad=requires_grad, dtype=dtype)
57 |
58 | do = torch.ones_like(v, dtype=dtype)
59 |
60 | quantiles = [0.5, 0.2, 0.8]
61 | if provider == 'abc':
62 | results = triton.testing.do_bench(lambda: chunk_abc(q, k, v, s), quantiles=quantiles)
63 | elif provider == 'gla':
64 | results = triton.testing.do_bench(lambda: chunk_gla(q, k, v, g), quantiles=quantiles)
65 | elif provider == 'abc_bwd':
66 | results = triton.testing.do_bench(lambda: chunk_abc(q, k, v, s)[0].backward(do), quantiles=quantiles)
67 | elif provider == 'gla_bwd':
68 | results = triton.testing.do_bench(lambda: chunk_gla(q, k, v, g)[0].backward(do), quantiles=quantiles)
69 | elif provider == 'retention_bwd':
70 | results = triton.testing.do_bench(lambda: chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles)
71 | elif provider == 'flash_bwd':
72 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True).backward(do), quantiles=quantiles)
73 | return results
74 |
75 |
76 | if __name__ == '__main__':
77 | benchmark.run(print_data=True)
78 |
--------------------------------------------------------------------------------
/benchmarks/ops/benchmark_based.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import triton
5 |
6 | from fla.ops.based import fused_chunk_based, parallel_based
7 | from fla.ops.based.naive import naive_chunk_based, naive_parallel_based
8 |
9 | try:
10 | from flash_attn import flash_attn_func
11 | HAS_FLASH = True
12 | except Exception:
13 | HAS_FLASH = False
14 |
15 |
16 | @triton.testing.perf_report(
17 | triton.testing.Benchmark(
18 | # argument names to use as an x-axis for the plot
19 | x_names=['T'],
20 | # different possible values for `x_name`
21 | x_vals=[128 * 2 ** i for i in range(3, 8)],
22 | # argument name whose value corresponds to a different line in the plot
23 | line_arg='provider',
24 | line_vals=['fused_chunk', 'torch', 'parallel', 'parallel_chunk', 'fused_chunk_bwd', 'torch_bwd',
25 | 'parallel_bwd', 'parallel_chunk_bwd'] + (['flash', 'flash_bwd'] if HAS_FLASH else []),
26 | # label name for the lines
27 | line_names=['fused_chunk_fwd', 'torch_fwd', 'parallel_fwd', 'parallel_chunk_fwd',
28 | 'fused_chunk_fwdbwd', 'torch_fwdbwd', 'parallel_fwdbwd',
29 | 'parallel_chunk_fwdbwd'] + (['flash_fwd', 'flash_fwdbwd'] if HAS_FLASH else []),
30 |
31 | # line styles
32 | styles=[('green', '-'), ('blue', '-'), ('red', '-'), ('green', 'dotted'), ('blue', 'dotted'),
33 | ('red', 'dotted'), ('red', '--'), ('red', ':')] + ([('cyan', '-'), ('cyan', 'dotted')] if HAS_FLASH else []),
34 | ylabel="Execution Time (ms)", # label name for the y-axis
35 | # name for the plot. Used also as a file name for saving the plot.
36 | plot_name="Performance",
37 | args={},
38 | )
39 | )
40 | def benchmark(T, provider):
41 | device = 'cuda'
42 | dtype = torch.bfloat16
43 | requires_grad = True
44 | B, H, D = 8, 16, 128
45 |
46 | if provider == 'flash' or provider == 'flash_bwd':
47 | q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
48 | k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
49 | v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
50 | else:
51 | q = torch.randn(B, H, T, 16, device=device, requires_grad=requires_grad, dtype=dtype)
52 | k = torch.randn(B, H, T, 16, device=device, requires_grad=requires_grad, dtype=dtype)
53 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
54 | do = torch.ones_like(v, dtype=dtype)
55 |
56 | quantiles = [0.5, 0.2, 0.8]
57 | results = 0, 0, 0
58 | if provider == 'torch':
59 | if T > 1024:
60 | return results
61 | results = triton.testing.do_bench(lambda: naive_parallel_based(q, k, v), quantiles=quantiles)
62 | elif provider == 'fused_chunk':
63 | results = triton.testing.do_bench(lambda: fused_chunk_based(q, k, v), quantiles=quantiles)
64 | elif provider == 'parallel':
65 | results = triton.testing.do_bench(lambda: parallel_based(q, k, v), quantiles=quantiles)
66 | elif provider == 'parallel_chunk':
67 | results = triton.testing.do_bench(lambda: naive_chunk_based(q, k, v), quantiles=quantiles)
68 | elif provider == 'torch_bwd':
69 | if T > 1024:
70 | return results
71 | results = triton.testing.do_bench(lambda: naive_parallel_based(q, k, v).backward(do), quantiles=quantiles)
72 | elif provider == 'fused_chunk_bwd':
73 | results = triton.testing.do_bench(lambda: fused_chunk_based(q, k, v).backward(do), quantiles=quantiles)
74 | elif provider == 'parallel_bwd':
75 | results = triton.testing.do_bench(lambda: parallel_based(q, k, v).backward(do), quantiles=quantiles)
76 | elif provider == 'flash':
77 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True), quantiles=quantiles)
78 | elif provider == 'flash_bwd':
79 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True).backward(do), quantiles=quantiles)
80 | elif provider == 'parallel_chunk_bwd':
81 | results = triton.testing.do_bench(lambda: naive_chunk_based(q, k, v).backward(do), quantiles=quantiles)
82 | return results
83 |
84 |
85 | if __name__ == '__main__':
86 | benchmark.run(print_data=True, show_plots=True)
87 |
--------------------------------------------------------------------------------
/benchmarks/ops/benchmark_delta_rule.py:
--------------------------------------------------------------------------------
1 | # Install the newest triton version with
2 | # pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
3 |
4 | import torch
5 | from benchmark import benchmark_combined, benchmark_forward, benchmark_backward
6 |
7 | from fla.ops.delta_rule import (chunk_delta_rule,
8 | fused_recurrent_delta_rule, fused_chunk_delta_rule)
9 | from fla.ops.retention import fused_chunk_retention
10 | # from flash_attn import flash_attn_func
11 |
12 |
13 | def time_fwd(func, *args, **kwargs):
14 | time_fb = benchmark_forward(func, *args, **kwargs)
15 | return time_fb[1].mean
16 |
17 |
18 | def time_fwd_bwd(func, *args, **kwargs):
19 | time_fb = benchmark_combined(func, *args, **kwargs)
20 | return time_fb[1].mean
21 |
22 | def time_bwd(func, *args, **kwargs):
23 | time_fb = benchmark_backward(func, *args, **kwargs)
24 | return time_fb[1].mean
25 |
26 |
27 | repeats = 256
28 | device = 'cuda'
29 | dtype = torch.bfloat16
30 |
31 |
32 | bs_seqlen_vals = [(8, 2048), (4, 4096), (2, 8192)]
33 | causal_vals = [True]
34 | headdim_vals = [64, 128, 256]
35 | dim = 2048
36 | dropout_p = 0.0
37 |
38 |
39 | methods = (["chunk_delta_rule", "fused_chunk_delta_rule"])
40 | time_f = {}
41 | time_b = {}
42 | time_f_b = {}
43 | speed_f = {}
44 | speed_b = {}
45 | speed_f_b = {}
46 | for causal in causal_vals:
47 | for headdim in headdim_vals:
48 | for B, seqlen in bs_seqlen_vals:
49 | config = (causal, headdim, B, seqlen)
50 | H = dim // headdim
51 | q = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype)
52 | k = torch.nn.functional.normalize(torch.randn(B, H, seqlen, headdim, device=device, dtype=dtype), p=2, dim=-1).requires_grad_(True)
53 | v = torch.randn(B, H, seqlen, headdim, device=device, requires_grad=True, dtype=dtype)
54 | beta = torch.rand(B, H, seqlen, device=device, dtype=dtype).sigmoid().requires_grad_(True)
55 | o1, _ = chunk_delta_rule(q, k, v, beta)
56 | o1.sum().backward(retain_graph=True)
57 | f_b = time_fwd_bwd(
58 | chunk_delta_rule, q, k, v, beta, verbose=False
59 | )
60 | time_f_b[config, "chunk_delta_rule"] = f_b
61 |
62 | # q = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
63 | # k = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
64 | # v = torch.randn(B, seqlen, H, headdim, device=device, requires_grad=True, dtype=dtype)
65 | f_b = time_fwd_bwd(
66 | fused_chunk_delta_rule, q, k, v, beta, verbose=False
67 | )
68 | time_f_b[config, "fused_chunk_delta_rule"] = f_b
69 |
70 |
71 | print(f"### causal={causal}, headdim={headdim}, B={B}, seqlen={seqlen} ###")
72 | for method in methods:
73 | # time_f_b[config, method] = time_f[config, method] + time_b[config, method]
74 | print(f"{method:>50} fwd + bwd:\t {time_f_b[config, method]*1000:>6.4f} ms ")
75 |
76 | # speed_f[config, method] = efficiency(
77 | # flops(B, seqlen, headdim, H, causal, mode="fwd"),
78 | # time_f[config, method]
79 | # )
80 | # speed_b[config, method] = efficiency(
81 | # flops(B, seqlen, headdim, H, causal, mode="bwd"),
82 | # time_b[config, method]
83 | # )
84 | # speed_f_b[config, method] = efficiency(
85 | # flops(B, seqlen, headdim, H, causal, mode="fwd_bwd"),
86 | # time_f_b[config, method]
87 | # )
88 | # print(
89 | # f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
90 | # f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
91 | # f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
92 | # )
93 |
94 |
95 | # with open('flash2_attn_time.plk', 'wb') as fp:
96 | # pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
97 |
--------------------------------------------------------------------------------
/benchmarks/ops/benchmark_fla.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import triton
5 |
6 | from fla.ops.based import parallel_based
7 | from fla.ops.gla import fused_chunk_gla
8 | from fla.ops.retention import fused_chunk_retention, parallel_retention
9 |
10 | try:
11 | from flash_attn import flash_attn_func
12 | HAS_FLASH = True
13 | except ImportError:
14 | HAS_FLASH = False
15 |
16 |
17 | @triton.testing.perf_report(
18 | triton.testing.Benchmark(
19 | # argument names to use as an x-axis for the plot
20 | x_names=['T'],
21 | # different possible values for `x_name`
22 | x_vals=[128 * 2 ** i for i in range(0, 8)],
23 | # argument name whose value corresponds to a different line in the plot
24 | line_arg='provider',
25 | # possible values for `line_arg``
26 | line_vals=['retention_parallel', 'retention_fused_chunk',
27 | 'gla_fused_chunk', 'based_parallel'] + (['flash'] if HAS_FLASH else []),
28 | # label name for the lines
29 | line_names=['retention_parallel_fwdbwd', 'retention_fused_chunk_fwdbwd',
30 | 'gla_fused_chunk_fwdbwd', 'based_parallel_fwdbwd'] + (['flash_fwdbwd'] if HAS_FLASH else []),
31 | # line styles
32 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':')] + \
33 | ([('yellow', 'dotted')] if HAS_FLASH else []),
34 | ylabel="Execution Time (ms)", # label name for the y-axis
35 | # name for the plot. Used also as a file name for saving the plot.
36 | plot_name="Performance",
37 | args={},
38 | )
39 | )
40 | def benchmark(T, provider):
41 | device = 'cuda'
42 | dtype = torch.bfloat16
43 | requires_grad = True
44 | B, H, D = 16, 8, 128
45 |
46 | if provider == 'flash':
47 | q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
48 | k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
49 | v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
50 | elif "based" in provider:
51 | q = torch.randn(B, H, T, 16, device=device, requires_grad=requires_grad, dtype=dtype)
52 | k = torch.randn(B, H, T, 16, device=device, requires_grad=requires_grad, dtype=dtype)
53 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
54 | elif "gla" in provider:
55 | q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
56 | k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
57 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
58 | g = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
59 | else:
60 | q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
61 | k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
62 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
63 |
64 | do = torch.rand_like(v, dtype=dtype)
65 |
66 | quantiles = [0.5, 0.2, 0.8]
67 | results = 0, 0, 0
68 | if provider == 'flash':
69 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v).backward(do), quantiles=quantiles)
70 | elif provider == 'retention_parallel':
71 | results = triton.testing.do_bench(lambda: parallel_retention(q, k, v).backward(do), quantiles=quantiles)
72 | elif provider == 'retention_fused_chunk':
73 | results = triton.testing.do_bench(lambda: fused_chunk_retention(q, k, v).backward(do), quantiles=quantiles)
74 | elif provider == 'based_parallel':
75 | results = triton.testing.do_bench(lambda: parallel_based(q, k, v).backward(do), quantiles=quantiles)
76 | elif provider == 'gla_fused_chunk':
77 | results = triton.testing.do_bench(lambda: fused_chunk_gla(q, k, v, g).backward(do), quantiles=quantiles)
78 |
79 | return results
80 |
81 |
82 | if __name__ == '__main__':
83 | benchmark.run(print_data=True, show_plots=True)
84 |
--------------------------------------------------------------------------------
/benchmarks/ops/benchmark_gla.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import triton
5 | from torch.nn import functional as F
6 |
7 | from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
8 | from fla.ops.retention import chunk_retention, parallel_retention
9 | from fla.ops.retention.naive import naive_retention
10 |
11 |
12 | @triton.testing.perf_report(
13 | triton.testing.Benchmark(
14 | # argument names to use as an x-axis for the plot
15 | x_names=['T'],
16 | # different possible values for `x_name`
17 | x_vals=[128 * 2 ** i for i in range(0, 8)],
18 | # argument name whose value corresponds to a different line in the plot
19 | line_arg='provider',
20 | # possible values for `line_arg``
21 | line_vals=['fused_chunk_gla', 'recurrent_gla', 'chunk_gla', 'chunk_retention',
22 | 'fused_chunk_gla_bwd', 'recurrent_gla_bwd', 'chunk_gla_bwd', 'chunk_retention_bwd'],
23 | # label name for the lines
24 | line_names=['fused_chunk_gla', 'recurrent_gla', 'chunk_gla', 'chunk_retention',
25 | 'fused_chunk_gla_bwd', 'recurrent_gla_bwd', 'chunk_gla_bwd', 'chunk_retention_bwd'],
26 | # line styles
27 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'),
28 | ('cyan', ':'), ('yellow', 'dotted'), ('cyan', '--'), ('cyan', '-'), ('black', ':')],
29 | ylabel="Execution Time (ms)", # label name for the y-axis
30 | # name for the plot. Used also as a file name for saving the plot.
31 | plot_name="Performance",
32 | args={},
33 | )
34 | )
35 | def benchmark(T, provider):
36 | device = 'cuda'
37 | dtype = torch.bfloat16
38 | requires_grad = True
39 | B, H, D = 16, 8, 128
40 |
41 | q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
42 | k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
43 | g = F.logsigmoid(torch.randn(B, H, T, D, device=device, dtype=dtype)).clamp_min(-5).requires_grad_(requires_grad)
44 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
45 |
46 | do = torch.ones_like(q, dtype=dtype)
47 |
48 | quantiles = [0.5, 0.2, 0.8]
49 | results = 0, 0, 0
50 | if provider == 'torch':
51 | if T > 2048:
52 | return results
53 | results = triton.testing.do_bench(lambda: naive_retention(q, k, v), quantiles=quantiles)
54 | elif provider == 'recurrent_gla':
55 | results = triton.testing.do_bench(lambda: fused_recurrent_gla(q, k, v, g), quantiles=quantiles)
56 | elif provider == 'fused_chunk_gla':
57 | results = triton.testing.do_bench(lambda: fused_chunk_gla(q, k, v, g), quantiles=quantiles)
58 | elif provider == 'chunk_retention':
59 | results = triton.testing.do_bench(lambda: chunk_retention(q, k, v), quantiles=quantiles)
60 | elif provider == 'chunk_gla':
61 | results = triton.testing.do_bench(lambda: chunk_gla(q, k, v, g), quantiles=quantiles)
62 | elif provider == 'parallel':
63 | results = triton.testing.do_bench(lambda: parallel_retention(q, k, v), quantiles=quantiles)
64 | elif provider == 'torch_bwd':
65 | if T > 2048:
66 | return results
67 | elif provider == 'chunk_retention_bwd':
68 | results = triton.testing.do_bench(lambda: chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles)
69 | elif provider == 'recurrent_gla_bwd':
70 | results = triton.testing.do_bench(lambda: fused_recurrent_gla(q, k, v, g)[0].backward(do), quantiles=quantiles)
71 | elif provider == 'fused_chunk_gla_bwd':
72 | results = triton.testing.do_bench(lambda: fused_chunk_gla(q, k, v, g)[0].backward(do), quantiles=quantiles)
73 | elif provider == 'chunk_gla_bwd':
74 | results = triton.testing.do_bench(lambda: chunk_gla(q, k, v, g)[0].backward(do), quantiles=quantiles)
75 | elif provider == 'parallel_bwd':
76 | results = triton.testing.do_bench(lambda: parallel_retention(q, k, v)[0].backward(do), quantiles=quantiles)
77 | return results
78 |
79 |
80 | if __name__ == '__main__':
81 | benchmark.run(print_data=True)
82 |
--------------------------------------------------------------------------------
/benchmarks/ops/benchmark_gsa.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import triton
5 | from torch.nn import functional as F
6 |
7 | from fla.ops.gla import chunk_gla
8 | from fla.ops.gsa import chunk_gsa, fused_recurrent_gsa
9 | from fla.ops.retention import chunk_retention
10 |
11 | try:
12 | from flash_attn import flash_attn_func
13 | HAS_FLASH = True
14 | except BaseException:
15 | HAS_FLASH = False
16 |
17 |
18 | @triton.testing.perf_report(
19 | triton.testing.Benchmark(
20 | # argument names to use as an x-axis for the plot
21 | x_names=['T'],
22 | # different possible values for `x_name`
23 | x_vals=[128 * 2 ** i for i in range(0, 8)],
24 | # argument name whose value corresponds to a different line in the plot
25 | line_arg='provider',
26 | # possible values for `line_arg``
27 | line_vals=['gsa_recurrent', 'gsa_chunk', 'gla',
28 | 'gsa_recurrent_bwd', 'gsa_chunk_bwd', 'gla_bwd', 'retention_bwd', 'flash_bwd'],
29 | # label name for the lines
30 | line_names=['gsa_recurrent', 'gsa_chunk', 'gla',
31 | 'gsa_recurrent_bwd', 'gsa_chunk_bwd', 'gla_bwd', 'retention_bwd', 'flash_bwd'],
32 | # line styles
33 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'),
34 | ('cyan', ':'), ('yellow', 'dotted'), ('black', ':'), ('green', ':'), ('green', 'dotted'), ('green', ':')],
35 | ylabel="Execution Time (ms)", # label name for the y-axis
36 | # name for the plot. Used also as a file name for saving the plot.
37 | plot_name="Performance",
38 | args={},
39 | )
40 | )
41 | def benchmark(T, provider):
42 | device = 'cuda'
43 | dtype = torch.bfloat16
44 | requires_grad = True
45 | B, H, D, M = 16, 4, 128, 64
46 |
47 | q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
48 | k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
49 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
50 | if provider.startswith('gsa'):
51 | f = F.logsigmoid(torch.randn(B, H, T, M, device=device, dtype=dtype))
52 | s = (1 - f.exp()).to(f.dtype)
53 | if provider.startswith('gla'):
54 | g = F.logsigmoid(torch.randn(B, H, T, D, device=device, dtype=dtype))
55 | g = g.clamp_min(-5).requires_grad_(requires_grad)
56 | if provider.startswith('flash'):
57 | q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
58 | k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
59 | v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
60 | do = torch.ones_like(v, dtype=dtype)
61 |
62 | quantiles = [0.5, 0.2, 0.8]
63 | if provider == 'gsa_recurrent':
64 | return triton.testing.do_bench(lambda: fused_recurrent_gsa(q, k, v, s, f), quantiles=quantiles)
65 | if provider == 'gsa_chunk':
66 | return triton.testing.do_bench(lambda: chunk_gsa(q, k, v, s, f), quantiles=quantiles)
67 | elif provider == 'gla':
68 | return triton.testing.do_bench(lambda: chunk_gla(q, k, v, g), quantiles=quantiles)
69 | elif provider == 'gsa_recurrent_bwd':
70 | return triton.testing.do_bench(lambda: fused_recurrent_gsa(q, k, v, s, f)[0].backward(do), quantiles=quantiles)
71 | elif provider == 'gsa_chunk_bwd':
72 | return triton.testing.do_bench(lambda: chunk_gsa(q, k, v, s, f)[0].backward(do), quantiles=quantiles)
73 | elif provider == 'gla_bwd':
74 | return triton.testing.do_bench(lambda: chunk_gla(q, k, v, g)[0].backward(do), quantiles=quantiles)
75 | elif provider == 'retention_bwd':
76 | return triton.testing.do_bench(lambda: chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles)
77 | elif provider == 'flash_bwd':
78 | return triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True).backward(do), quantiles=quantiles)
79 |
80 |
81 | if __name__ == '__main__':
82 | benchmark.run(print_data=True)
83 |
--------------------------------------------------------------------------------
/benchmarks/ops/benchmark_hgrn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import triton
5 |
6 | from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn
7 |
8 |
9 | @triton.testing.perf_report(
10 | triton.testing.Benchmark(
11 | # argument names to use as an x-axis for the plot
12 | x_names=['T'],
13 | # different possible values for `x_name`
14 | x_vals=[128 * 2 ** i for i in range(0, 8)],
15 | # argument name whose value corresponds to a different line in the plot
16 | line_arg='provider',
17 | # possible values for `line_arg``
18 | line_vals=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'],
19 | # label name for the lines
20 | line_names=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'],
21 | # line styles
22 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':'), ('yellow', 'dotted'), ('black', 'dashed')],
23 | ylabel="Execution Time (ms)", # label name for the y-axis
24 | # name for the plot. Used also as a file name for saving the plot.
25 | plot_name="Performance",
26 | args={},
27 | )
28 | )
29 | def benchmark(T, provider):
30 | dtype = torch.bfloat16
31 | B, D = 16, 512
32 |
33 | x = torch.randn((B, T, D), dtype=dtype, device='cuda')
34 | g = torch.randn((B, T, D), dtype=dtype, device='cuda').sigmoid()
35 | x = (1 - g) * x
36 | x, g = (i.detach().clone().to(dtype).requires_grad_() for i in (x, g))
37 | do = torch.randn_like(x, dtype=dtype)
38 | quantiles = [0.5, 0.2, 0.8]
39 | results = 0, 0, 0
40 | if provider == 'chunk':
41 | results = triton.testing.do_bench(lambda: chunk_hgrn(x, g), quantiles=quantiles)
42 | if provider == 'recurrent':
43 | results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g), quantiles=quantiles)
44 | if provider == 'chunk_bwd':
45 | results = triton.testing.do_bench(lambda: chunk_hgrn(x, g)[0].backward(do), quantiles=quantiles)
46 | if provider == 'recurrent_bwd':
47 | results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g)[0].backward(do), quantiles=quantiles)
48 | return results
49 |
50 |
51 | if __name__ == '__main__':
52 | benchmark.run(print_data=True)
53 |
--------------------------------------------------------------------------------
/benchmarks/ops/benchmark_retention.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 |
5 | import torch
6 | import triton
7 |
8 | from fla.ops.retention import (chunk_retention, fused_recurrent_retention,
9 | parallel_retention)
10 | from fla.ops.retention.naive import naive_retention
11 |
12 | try:
13 | from flash_attn import flash_attn_func
14 | HAS_FLASH = True
15 | except BaseException:
16 | HAS_FLASH = False
17 |
18 |
19 | @triton.testing.perf_report(
20 | triton.testing.Benchmark(
21 | # argument names to use as an x-axis for the plot
22 | x_names=['T'],
23 | # different possible values for `x_name`
24 | x_vals=[128 * 2 ** i for i in range(0, 8)],
25 | # argument name whose value corresponds to a different line in the plot
26 | line_arg='provider',
27 | # possible values for `line_arg``
28 | line_vals=['fused_chunk', 'chunk', 'parallel',
29 | 'chunk_bwd', 'parallel_bwd'] + (['flash', 'flash_bwd'] if HAS_FLASH else []),
30 | # label name for the lines
31 | line_names=['fused_chunk_fwd', 'chunk_fwd', 'parallel_fwd',
32 | 'chunk_fwdbwd', 'parallel_fwdbwd'] + (['flash_fwd', 'flash_fwdbwd'] if HAS_FLASH else []),
33 | # line styles
34 | styles=[('green', '-'), ('blue', '-'), ('red', '-'), ('green', 'dotted'), ('blue', 'dotted'),
35 | ('red', 'dotted')] + ([('cyan', '-'), ('cyan', 'dotted')] if HAS_FLASH else []),
36 | ylabel="Execution Time (ms)", # label name for the y-axis
37 | # name for the plot. Used also as a file name for saving the plot.
38 | plot_name="Performance",
39 | args={},
40 | )
41 | )
42 | def benchmark(T, provider):
43 | device = 'cuda'
44 | dtype = torch.bfloat16
45 | requires_grad = True
46 | B, H, D = 4, 8, 256
47 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
48 |
49 | if provider == 'flash' or provider == 'flash_bwd':
50 | q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
51 | k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
52 | v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
53 | else:
54 | q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
55 | k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
56 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
57 | do = torch.ones_like(q, dtype=dtype)
58 |
59 | quantiles = [0.5, 0.2, 0.8]
60 | results = 0, 0, 0
61 | if provider == 'torch':
62 | if T > 2048:
63 | return results
64 | results = triton.testing.do_bench(lambda: naive_retention(q, k, v), quantiles=quantiles)
65 | elif provider == 'recurrent':
66 | results = triton.testing.do_bench(lambda: fused_recurrent_retention(q, k, v), quantiles=quantiles)
67 | elif provider == 'chunk':
68 | results = triton.testing.do_bench(lambda: chunk_retention(q, k, v), quantiles=quantiles)
69 | elif provider == 'parallel':
70 | results = triton.testing.do_bench(lambda: parallel_retention(q, k, v), quantiles=quantiles)
71 | elif provider == 'torch_bwd':
72 | if T > 2048:
73 | return results
74 | results = triton.testing.do_bench(lambda: naive_retention(q, k, v).backward(do), quantiles=quantiles)
75 | elif provider == 'recurrent_bwd':
76 | results = triton.testing.do_bench(lambda: fused_recurrent_retention(q, k, v)[0].backward(do), quantiles=quantiles)
77 | elif provider == 'chunk_bwd':
78 | results = triton.testing.do_bench(lambda: chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles)
79 | elif provider == 'parallel_bwd':
80 | results = triton.testing.do_bench(lambda: parallel_retention(q, k, v)[0].backward(do), quantiles=quantiles)
81 | elif provider == 'flash':
82 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True), quantiles=quantiles)
83 | elif provider == 'flash_bwd':
84 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True).backward(do), quantiles=quantiles)
85 | return results
86 |
87 |
88 | if __name__ == '__main__':
89 | benchmark.run(print_data=True)
90 |
--------------------------------------------------------------------------------
/benchmarks/ops/benchmark_rwkv6.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import triton
5 | from torch.nn import functional as F
6 |
7 | from fla.ops.gla import chunk_gla
8 | from fla.ops.retention import chunk_retention
9 | from fla.ops.rwkv6 import chunk_rwkv6
10 |
11 | try:
12 | from flash_attn import flash_attn_func
13 | HAS_FLASH = True
14 | except BaseException:
15 | HAS_FLASH = False
16 |
17 |
18 | @triton.testing.perf_report(
19 | triton.testing.Benchmark(
20 | # argument names to use as an x-axis for the plot
21 | x_names=['T'],
22 | # different possible values for `x_name`
23 | x_vals=[128 * 2 ** i for i in range(0, 8)],
24 | # argument name whose value corresponds to a different line in the plot
25 | line_arg='provider',
26 | # possible values for `line_arg``
27 | line_vals=['rwkv6', 'gla', 'rwkv6_bwd', 'gla_bwd', 'retention_bwd', 'flash_bwd'],
28 | # label name for the lines
29 | line_names=['rwkv6', 'gla', 'rwkv6_bwd', 'gla_bwd', 'retention_bwd', 'flash_bwd'],
30 | # line styles
31 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'),
32 | ('cyan', ':'), ('yellow', 'dotted'), ('black', ':')],
33 | ylabel="Execution Time (ms)", # label name for the y-axis
34 | # name for the plot. Used also as a file name for saving the plot.
35 | plot_name="Performance",
36 | args={},
37 | )
38 | )
39 | def benchmark(T, provider):
40 | device = 'cuda'
41 | dtype = torch.bfloat16
42 | requires_grad = True
43 | B, H, D = 16, 8, 128
44 |
45 | q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
46 | k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
47 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
48 | if provider.startswith('flash'):
49 | q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
50 | k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
51 | v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
52 | if provider.startswith('gla'):
53 | g = F.logsigmoid(torch.randn(B, H, T, D, device=device, dtype=dtype))
54 | g = g.clamp_min(-5).requires_grad_(requires_grad)
55 | if provider.startswith('rwkv6'):
56 | w = F.logsigmoid(torch.randn(B, H, T, D, device=device, dtype=dtype)).requires_grad_(True)
57 | u = torch.randn(H, D, device=device, dtype=dtype).requires_grad_(True)
58 |
59 | do = torch.ones_like(v, dtype=dtype)
60 |
61 | quantiles = [0.5, 0.2, 0.8]
62 | if provider == 'rwkv6':
63 | results = triton.testing.do_bench(lambda: chunk_rwkv6(q, k, v, w, u), quantiles=quantiles)
64 | elif provider == 'gla':
65 | results = triton.testing.do_bench(lambda: chunk_gla(q, k, v, g), quantiles=quantiles)
66 | elif provider == 'rwkv6_bwd':
67 | results = triton.testing.do_bench(lambda: chunk_rwkv6(q, k, v, w, u)[0].backward(do), quantiles=quantiles)
68 | elif provider == 'gla_bwd':
69 | results = triton.testing.do_bench(lambda: chunk_gla(q, k, v, g)[0].backward(do), quantiles=quantiles)
70 | elif provider == 'retention_bwd':
71 | results = triton.testing.do_bench(lambda: chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles)
72 | elif provider == 'flash_bwd':
73 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True).backward(do), quantiles=quantiles)
74 | return results
75 |
76 |
77 | if __name__ == '__main__':
78 | benchmark.run(print_data=True)
79 |
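80 | # Note: as in the simple_gla-vs-mamba2 script in this directory, `benchmark.run`
81 | # also accepts `save_path='.'` (and `show_plots=True`) if the timing table and
82 | # plot should be written to disk rather than only printed.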
--------------------------------------------------------------------------------
/benchmarks/ops/benchmark_simple_gla_vs_mamba2.py:
--------------------------------------------------------------------------------
1 | """
2 | Dependencies:
3 | $ pip install mamba-ssm==2.2.2 triton==2.3.1
4 |
5 | For correctness check, see:
6 | https://github.com/sustcsonglin/flash-linear-attention/pull/49
7 | """
8 |
9 | import torch
10 | import triton
11 |
12 | from fla.ops.simple_gla import chunk_simple_gla
13 |
14 | from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
15 |
16 |
17 | @triton.testing.perf_report(
18 | triton.testing.Benchmark(
19 | # argument names to use as an x-axis for the plot
20 | x_names=['T'],
21 | # different possible values for `x_name`
22 | x_vals=[64] + [128 * 2 ** i for i in range(0, 8)],
23 | # argument name whose value corresponds to a different line in the plot
24 | line_arg='provider',
25 |         # possible values for `line_arg`
26 | line_vals=["chunk_simple_gla", "mamba2_ssd"],
27 | # label name for the lines
28 | line_names=["chunk_simple_gla", "mamba2_ssd"],
29 | # line styles
30 | styles=[('blue', '-'), ('red', '-')],
31 | ylabel="Execution Time (ms)", # label name for the y-axis
32 | # name for the plot. Used also as a file name for saving the plot.
33 | plot_name="Performance",
34 | args={},
35 | )
36 | )
37 | def benchmark(T, provider):
38 | # TODO: also add bwd pass benchmark
39 | device = 'cuda'
40 | dtype = torch.bfloat16
41 | B, H, D = 16, 8, 128
42 | # TODO: test more shapes
43 | # TODO: different values for D_V and D_QK
44 | # TODO: different values for H_Q and H_KV
45 | final_state = False # does not impact performance
46 |
47 | # initialize Mamba2-format inputs
48 | X_mamba = 0.1 * torch.randn(B, T, H, D, dtype=dtype, device=device)
49 | dt_mamba = torch.ones(B, T, H, dtype=dtype, device=device)
50 | A_mamba = -0.1 * torch.rand(H, dtype=dtype, device=device)
51 | B_mamba = 0.1 * torch.randn(B, T, H, D, dtype=dtype, device=device)
52 | C_mamba = 0.1 * torch.randn(B, T, H, D, dtype=dtype, device=device)
53 |
54 | quantiles = [0.5, 0.2, 0.8]
55 | if provider == 'chunk_simple_gla':
56 | # mapping inputs Mamba2 -> FLA
57 | # C, B, X: [B, T, H, D] -> [B, H, T, D]
58 | # g: [B, T, H] -> [B, H, T]
59 | q = C_mamba.transpose(1, 2).contiguous()
60 | k = B_mamba.transpose(1, 2).contiguous()
61 | v = X_mamba.transpose(1, 2).contiguous()
62 | g = (A_mamba * dt_mamba).transpose(1, 2).contiguous()
63 |         # NOTE: it is debatable whether to include the memory-copy cost of `contiguous()`;
64 |         # it depends on the memory layout used by the surrounding non-SSM layers
65 |
66 | results = triton.testing.do_bench(
67 | lambda: chunk_simple_gla(
68 | q, k, v, g, scale=1.0, output_final_state=final_state
69 | ), quantiles=quantiles
70 | )
71 |
72 | elif provider == 'mamba2_ssd':
73 | # NOTE: `chunk_size` is configurable in mamba2 kernel
74 |         # here it is set to the same hard-coded `BT = 64` as in the simple_gla kernel
75 | # TODO: benchmark different chunk sizes
76 | results = triton.testing.do_bench(
77 | lambda: mamba_chunk_scan_combined(
78 | X_mamba, dt_mamba, A_mamba, B_mamba, C_mamba,
79 | chunk_size=64, D=None, return_final_states=final_state
80 | ),
81 | quantiles=quantiles
82 | )
83 | return results
84 |
85 | if __name__ == '__main__':
86 | benchmark.run(print_data=True, save_path='.')
87 |
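88 | # A minimal, torch-only sketch of the Mamba2 -> FLA layout mapping used above
89 | # (illustrative only; `_layout_mapping_demo` is not part of the benchmark itself):
90 | def _layout_mapping_demo(B=2, T=8, H=4, D=16):
91 |     X = torch.randn(B, T, H, D)                # Mamba2 layout: [B, T, H, D]
92 |     dt = torch.ones(B, T, H)
93 |     A = -0.1 * torch.rand(H)
94 |     v = X.transpose(1, 2).contiguous()         # FLA layout:   [B, H, T, D]
95 |     g = (A * dt).transpose(1, 2).contiguous()  # gate: [B, T, H] -> [B, H, T]
96 |     assert v.shape == (B, H, T, D) and g.shape == (B, H, T)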
--------------------------------------------------------------------------------
/evals/harness.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import annotations
4 |
5 | import sys
6 | import os
7 | import mom # noqa
8 | import fla  # noqa
9 | from lm_eval.__main__ import cli_evaluate
10 | from lm_eval.api.registry import register_model
11 | from lm_eval.models.huggingface import HFLM
12 |
13 |
14 | @register_model('fla')
15 | class FlashLinearAttentionLMWrapper(HFLM):
16 |     def __init__(self, **kwargs) -> None:
17 |
18 | # TODO: provide options for doing inference with different kernels
19 |
20 | super().__init__(**kwargs)
21 |
22 |
23 | if __name__ == "__main__":
24 | cli_evaluate()
25 |
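26 | # Example invocation (a sketch; the checkpoint path is a placeholder and the
27 | # flags follow the standard lm-eval CLI):
28 | #   python evals/harness.py --model fla \
29 | #       --model_args pretrained=<path-or-hub-id> \
30 | #       --tasks wikitext --batch_size 8 --device cuda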
--------------------------------------------------------------------------------
/lm-eval-harness/CITATION.bib:
--------------------------------------------------------------------------------
1 | @misc{eval-harness,
2 | author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy},
3 | title = {A framework for few-shot language model evaluation},
4 | month = 12,
5 | year = 2023,
6 | publisher = {Zenodo},
7 | version = {v0.4.0},
8 | doi = {10.5281/zenodo.10256836},
9 | url = {https://zenodo.org/records/10256836}
10 | }
11 |
--------------------------------------------------------------------------------
/lm-eval-harness/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @haileyschoelkopf @lintangsutawika
2 |
--------------------------------------------------------------------------------
/lm-eval-harness/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 EleutherAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/lm-eval-harness/docs/README.md:
--------------------------------------------------------------------------------
1 | # Eval Harness Documentation
2 |
3 | Welcome to the docs for the LM Evaluation Harness!
4 |
5 | ## Table of Contents
6 |
7 | * To learn about the public interface of the library, as well as how to evaluate via the commandline or as integrated into an external library, see the [Interface](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/interface.md)
8 | * To learn how to add a new library, API, or model type to the library, as well as a quick explainer on the types of ways to evaluate an LM, see the [Model Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/model_guide.md).
9 | * For a crash course on adding new tasks to the library, see our [New Task Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/new_task_guide.md).
10 | * To learn more about pushing the limits of task configuration that the Eval Harness supports, see the [Task Configuration Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/task_guide.md).
11 |
--------------------------------------------------------------------------------
/lm-eval-harness/docs/decontamination.md:
--------------------------------------------------------------------------------
1 | # Decontamination
2 |
3 | ## Usage
4 |
5 | Simply add the `--decontamination_ngrams_path` flag when running `__main__.py`. The provided directory should contain
6 | the ngram files and info.json produced in "Pile Ngram Generation" further down.
7 |
8 | ```bash
9 | python -m lm_eval \
10 | --model gpt2 \
11 | --device 0 \
12 | --tasks sciq \
13 | --decontamination_ngrams_path path/containing/training/set/ngrams
14 | ```
15 |
16 | ## Background
17 | Downstream evaluations test model generalization, and are less useful when test set data also exists in the training set, referred to as leakage or contamination.
18 |
19 | Filtering your training set against the test set is a good first step; however, this isn't always possible, as in the case of a new benchmark or one that wasn't considered prior to model training. When training-set filtering isn't possible, it is useful to measure the impact of test-set leakage by detecting the contaminated test examples and producing a clean version of the benchmark.
20 |
21 | The basis for our decontamination procedure can be found in Appendix C of "Language Models are Few-Shot Learners". OpenAI defined a test document as contaminated if any N-gram overlap existed with any training document. They used a range of N values between 8 and 13 depending on the dataset, while we just used 13 for simplicity.
22 |
23 | ## Implementation
24 | Contamination detection can be found in `lm_eval/decontamination/decontaminate.py`, with supporting code in the rest of `lm_eval/decontamination/`.
25 |
26 | decontaminate.py does the following:
27 | 1. Build dictionaries of all ngrams and their corresponding evaluation/document ids.
28 | 2. Scan through sorted files containing training set n-grams.
29 | 3. If a match is found, the corresponding evaluation/document combinations are marked as contaminated.
30 |
31 | `lm_eval/evaluator.py` can then produce a clean version of the benchmark by excluding the results of contaminated documents. For each metric, a clean version will be shown in the results with a "decontaminate" suffix.
32 |
33 | This is disabled by default for new tasks; to support decontamination on a task, override the "should_decontaminate" and "doc_to_decontamination_query" methods. For more details see the [task guide](task_guide.md).
34 |
35 | ## Pile Ngram Generation
36 | The relevant scripts can be found in `scripts/clean_training_data`, which also import from
37 | `lm_eval/decontamination/`
38 |
39 | 1. git clone https://github.com/EleutherAI/lm-evaluation-harness.git
40 | 2. pip install -r requirements.txt
41 | 3. Download The Pile from [The Eye](https://the-eye.eu/public/AI/pile/train/)
42 | 4. Place pile files in "pile" directory under "lm-evaluation-harness" (or create a symlink)
43 | 5. Run generate_13_grams.
44 |
45 | ```bash
46 | export PYTHONHASHSEED=0
47 | python scripts/clean_training_data/generate_13_grams.py \
48 | -dir path/to/working/directory \
49 | -n 13 \
50 | -buckets 500
51 | ```
52 |
53 | This took approximately 4 days for us. We had the time to wait, but the job could be scaled out by running partial Pile scans on multiple instances of this script and merging the relevant buckets. We fixed PYTHONHASHSEED to ensure the bucket hashing is reproducible in case you need to stop and restart.
54 |
55 | 6. Sort the generated 13-grams.
56 | ```bash
57 | python scripts/clean_training_data/sort_13_gram_buckets.py \
58 | -dir path/to/working/directory/output
59 | ```
60 |
61 | This took approximately 5 days for us. You could speed it up by spreading the files across different machines, running the sort script on each, and then gathering the results back together.
62 |
63 | 7. Compress the sorted 13-gram files and place them together with info.json.
64 |
65 | This step only takes a few hours.
66 |
67 | ```bash
68 | python scripts/clean_training_data/compress_and_package.py \
69 | -dir path/to/working/directory \
70 | -output path/to/final/directory \
71 | -procs 8
72 | ```
73 |
74 | Congratulations, the final directory can now be passed to lm-evaluation-harness with the "--decontamination_ngrams_path" argument.
75 |
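76 | As a footnote on the implementation above, the core of the matching step amounts to a 13-gram lookup, sketched here in plain Python (illustrative only, not the repository's implementation):
77 |
78 | ```python
79 | from collections import defaultdict
80 |
81 | N = 13
82 |
83 | def ngrams(tokens, n=N):
84 |     """Yield every consecutive n-gram of a token list as a tuple."""
85 |     return zip(*(tokens[i:] for i in range(n)))
86 |
87 | # 1. index every 13-gram of every evaluation document
88 | test_docs = {0: "example test document tokens ...".split()}
89 | index = defaultdict(set)
90 | for doc_id, tokens in test_docs.items():
91 |     for gram in ngrams(tokens):
92 |         index[gram].add(doc_id)
93 |
94 | # 2./3. scan training-set text; any shared 13-gram marks the test doc as contaminated
95 | contaminated = set()
96 | for gram in ngrams("example training shard tokens ...".split()):
97 |     contaminated |= index.get(gram, set())
98 | ```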
--------------------------------------------------------------------------------
/lm-eval-harness/docs/img/fewshot_example_gpt3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSparseLLMs/MoM/2c3ae59f5c0b749f916189433b8dad5a3415dc08/lm-eval-harness/docs/img/fewshot_example_gpt3.png
--------------------------------------------------------------------------------
/lm-eval-harness/ignore.txt:
--------------------------------------------------------------------------------
1 | ROUGE
2 | rouge
3 | nin
4 | maka
5 | mor
6 | te
7 | ond
8 | extraversion
9 |
--------------------------------------------------------------------------------
/lm-eval-harness/launch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from typing import List, Optional
4 |
5 | from lm_eval.__main__ import cli_evaluate
6 |
7 |
8 | from datetime import datetime
9 | import os
10 | import importlib.util
11 |
12 | import click
13 | from tqdm import tqdm
14 |
15 |
16 | MAX_WORKERS_PER_GPU = 1
17 |
18 |
19 | def execute_config(
20 | model: str,
21 | task: str,
22 | batch_size: int,
23 | limit: int,
24 | output_dir: str,
25 | num_fewshot: int,
26 | context_length: int = 1000,
27 | answer_length: int = 50,
28 | cutting_context: bool = False,
29 | decode_mode: str = "default",
30 | ):
31 | # Save the original standard output
32 | import subprocess
33 |
34 | output_dir = os.path.join(output_dir, model, task)
35 |
36 | args = [
37 | "lm_eval",
38 | "--model", "based_lm",
39 | "--model_args", f"checkpoint_name={model}",
40 | "--tasks", task,
41 | "--device", "cuda:0",
42 | "--batch_size", str(batch_size),
43 | "--log_samples",
44 | "--output_path", output_dir,
45 | "--decode_mode", decode_mode,
46 | "--num_fewshot", str(num_fewshot),
47 | # ,
48 |
49 | ]
50 |
51 | if cutting_context:
52 | args.extend(["--cutting_context"])
53 | args.extend(["--context_length", str(context_length)])
54 | args.extend(["--answer_length", str(answer_length)])
55 | args.extend(["--context_key", "text"])
56 |
57 | if 'squad' not in task:
58 | args.extend(["--answer_key", "key", "value"])
59 | else:
60 | args.extend(["--answer_key", "value"])
61 |
62 | if limit is not None:
63 | args.extend(["--limit", str(limit)])
64 |
65 | subprocess.run(args)
66 |
67 | print(f"Decoded with mode: {decode_mode}")
68 |
69 |
70 | @click.command()
71 | @click.option("-m", "--model", type=str, multiple=True)
72 | @click.option("-t", "--task", type=str, multiple=True)
73 | @click.option("-p", "--parallelize", is_flag=True)
74 | @click.option("--gpus", default=None, type=str)
75 | @click.option("--batch-size", default=8, type=int)
76 | @click.option("--limit", default=None, type=int)
77 | @click.option("--num_fewshot", default=0, type=int)
78 | @click.option("--context_length", default=1000, type=int)
79 | @click.option("--answer_length", default=50, type=int)
80 | @click.option("--output_dir", default="output", type=str)
81 | @click.option("--cutting_context", is_flag=True)
82 | @click.option("--decode_mode", default="default", type=str)
83 | def main(
84 | model: List[str],
85 | task: List[str],
86 | batch_size: int,
87 | limit: Optional[int],
88 | parallelize: bool,
89 | gpus: str,
90 | num_fewshot: int = 0,
91 | output_dir: str = "output",
92 | context_length: int = 1000,
93 | answer_length: int = 50,
94 | cutting_context: bool = False,
95 | decode_mode: str = 'default'
96 | ):
97 |
98 |     if limit is not None and limit < 0: limit = None
99 |
100 | if gpus is not None:
101 | os.environ["CUDA_VISIBLE_DEVICES"] = gpus
102 |
103 | # Load the given Python file as a module
104 | configs = [
105 | {"model": m, "task": t} for m in model for t in task
106 | ]
107 |
108 | use_ray = parallelize and len(configs) > 0
109 | if use_ray:
110 | import ray
111 | # ray was killing workers due to OOM, but it didn't seem to be necessary
112 | os.environ["RAY_memory_monitor_refresh_ms"] = "0"
113 | ray.init(ignore_reinit_error=True, log_to_driver=True)
114 |
115 | print(f"Running sweep with {len(configs)} configs")
116 |
117 | output_dir = f"{output_dir}/{datetime.now().strftime('%y-%m-%d_%H-%M')}"
118 |
119 | # Run each script in parallel using Ray
120 | if not use_ray:
121 | for config in configs:
122 | execute_config(
123 | **config,
124 | batch_size=batch_size,
125 | limit=limit,
126 | output_dir=output_dir,
127 | num_fewshot=num_fewshot,
128 | context_length=context_length,
129 | answer_length=answer_length,
130 | cutting_context=cutting_context,
131 | decode_mode=decode_mode,
132 | )
133 | else:
134 | completed = 0
135 | total = len(configs)
136 | print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}")
137 |
138 |         remote = ray.remote(num_gpus=(1 / MAX_WORKERS_PER_GPU))(execute_config)
139 | futures = [remote.remote(
140 | **config, batch_size=batch_size, limit=limit, output_dir=output_dir, num_fewshot=num_fewshot,
141 | context_length=context_length, answer_length=answer_length, cutting_context=cutting_context,
142 | decode_mode=decode_mode,
143 | ) for config in configs]
144 |
145 | while futures:
146 | complete, futures = ray.wait(futures)
147 | completed += len(complete)
148 | print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}")
149 |
150 | ray.shutdown()
151 |
152 | if __name__ == "__main__":
153 | main()
154 |
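155 | # Example sweep (a sketch; the model and task names below are only illustrative):
156 | #   python launch.py -m hazyresearch/based-1.3b -t based_squadv2 \
157 | #       --batch-size 8 --num_fewshot 0 --cutting_context --context_length 1000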
--------------------------------------------------------------------------------
/lm-eval-harness/launch_hf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from typing import List, Optional
4 |
5 | from lm_eval.__main__ import cli_evaluate
6 |
7 |
8 | from datetime import datetime
9 | import os
10 | import importlib.util
11 |
12 | import click
13 | from tqdm import tqdm
14 |
15 |
16 | MAX_WORKERS_PER_GPU = 1
17 |
18 |
19 | def execute_config(
20 | model: str,
21 | task: str,
22 | batch_size: int,
23 | limit: int,
24 | output_dir: str,
25 | context_length: int = 1000,
26 | cutting_context: bool = False,
27 | answer_length: int=50,
28 | ):
29 | # Save the original standard output
30 | import subprocess
31 |
32 | output_dir = os.path.join(output_dir, model, task)
33 |
34 | if 'mamba' in model.lower() and 'rw' not in model.lower(): model_name = "mamba_ssm"
35 | else: model_name = 'hf-auto'
36 |
37 | args = [
38 | "lm_eval",
39 | "--model", f"{model_name}",
40 | "--model_args", f"checkpoint_name={model}",
41 | "--tasks", task,
42 | "--device", "cuda:0",
43 | "--batch_size", str(batch_size),
44 | "--log_samples",
45 | "--output_path", output_dir
46 | ]
47 |
48 | if cutting_context:
49 | args.extend(["--cutting_context"])
50 | args.extend(["--context_length", str(context_length)])
51 | args.extend(["--answer_length", str(answer_length)])
52 | args.extend(["--context_key", "text"])
53 |
54 | if limit is not None:
55 | args.extend(["--limit", str(limit)])
56 |
57 | subprocess.run(args)
58 |
59 |
60 |
61 | @click.command()
62 | @click.option("-m", "--model", type=str, multiple=True)
63 | @click.option("-t", "--task", type=str, multiple=True)
64 | @click.option("-p", "--parallelize", is_flag=True)
65 | @click.option("--gpus", default=None, type=str)
66 | @click.option("--batch-size", default=8, type=int)
67 | @click.option("--limit", default=None, type=int)
68 | @click.option("--context_length", default=1000, type=int)
69 | @click.option("--answer_length", default=50, type=int)
70 | @click.option("--cutting_context", is_flag=True)
71 | @click.option("--output_dir", default="output", type=str)
72 | def main(
73 | model: List[str],
74 | task: List[str],
75 | batch_size: int,
76 | limit: Optional[int],
77 | parallelize: bool,
78 | gpus: str,
79 | context_length: int,
80 | cutting_context: bool,
81 | answer_length: int,
82 | output_dir: str,
83 | ):
84 |     if limit is not None and limit < 0: limit = None
85 |
86 | if gpus is not None:
87 | os.environ["CUDA_VISIBLE_DEVICES"] = gpus
88 |
89 | # Load the given Python file as a module
90 | configs = [
91 | {"model": m, "task": t} for m in model for t in task
92 | ]
93 |
94 | use_ray = parallelize and len(configs) > 0
95 | if use_ray:
96 | import ray
97 | # ray was killing workers due to OOM, but it didn't seem to be necessary
98 | os.environ["RAY_memory_monitor_refresh_ms"] = "0"
99 | ray.init(ignore_reinit_error=True, log_to_driver=True)
100 |
101 | print(f"Running sweep with {len(configs)} configs")
102 |
103 | output_dir = f"{output_dir}/{datetime.now().strftime('%y-%m-%d_%H-%M')}"
104 |
105 | # Run each script in parallel using Ray
106 | if not use_ray:
107 | for config in configs:
108 | execute_config(
109 | **config,
110 | batch_size=batch_size,
111 | limit=limit,
112 | output_dir=output_dir,
113 | context_length=context_length,
114 | answer_length=answer_length,
115 | cutting_context=cutting_context
116 | )
117 | else:
118 | completed = 0
119 | total = len(configs)
120 | print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}")
121 |
122 |         remote = ray.remote(num_gpus=(1 / MAX_WORKERS_PER_GPU))(execute_config)
123 | futures = [remote.remote(
124 | **config, batch_size=batch_size, limit=limit, output_dir=output_dir,
125 | answer_length=answer_length, cutting_context=cutting_context, context_length=context_length,
126 | ) for config in configs]
127 |
128 | while futures:
129 | complete, futures = ray.wait(futures)
130 | completed += len(complete)
131 | print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}")
132 |
133 | ray.shutdown()
134 |
135 |
136 |
137 | if __name__ == "__main__":
138 | main()
139 |
140 |
--------------------------------------------------------------------------------
/lm-eval-harness/launch_jrt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from typing import List, Optional
4 |
5 | from lm_eval.__main__ import cli_evaluate
6 |
7 |
8 | from datetime import datetime
9 | import os
10 | import importlib.util
11 |
12 | import click
13 | from tqdm import tqdm
14 |
15 |
16 | MAX_WORKERS_PER_GPU = 1
17 |
18 |
19 | def execute_config(
20 | model: str,
21 | task: str,
22 | batch_size: int,
23 | limit: int,
24 | output_dir: str,
25 | num_fewshot: int,
26 | context_length: int = 1000,
27 | answer_length: int = 50,
28 | cutting_context: bool = False,
29 | decode_mode: str = "default",
30 | ):
31 | # Save the original standard output
32 | import subprocess
33 |
34 | output_dir = os.path.join(output_dir, model, task)
35 |
36 | args = [
37 | "lm_eval",
38 | "--model", "jrt_lm",
39 | "--model_args", f"checkpoint_name={model}",
40 | "--tasks", task,
41 | "--device", "cuda:0",
42 | "--batch_size", str(batch_size),
43 | "--log_samples",
44 | "--output_path", output_dir,
45 | "--decode_mode", decode_mode,
46 | "--num_fewshot", str(num_fewshot),
47 | ]
48 |
49 | if cutting_context:
50 | args.extend(["--cutting_context"])
51 | args.extend(["--context_length", str(context_length)])
52 | args.extend(["--answer_length", str(answer_length)])
53 | args.extend(["--context_key", "text"])
54 |
55 | if 'squad' not in task:
56 | args.extend(["--answer_key", "key", "value"])
57 | else:
58 | args.extend(["--answer_key", "value"])
59 |
60 | if limit is not None:
61 | args.extend(["--limit", str(limit)])
62 |
63 | subprocess.run(args)
64 |
65 | print(f"Decoded with mode: {decode_mode}")
66 |
67 |
68 | @click.command()
69 | @click.option("-m", "--model", type=str, multiple=True)
70 | @click.option("-t", "--task", type=str, multiple=True)
71 | @click.option("-p", "--parallelize", is_flag=True)
72 | @click.option("--gpus", default=None, type=str)
73 | @click.option("--batch-size", default=8, type=int)
74 | @click.option("--limit", default=None, type=int)
75 | @click.option("--num_fewshot", default=0, type=int)
76 | @click.option("--context_length", default=1000, type=int)
77 | @click.option("--answer_length", default=50, type=int)
78 | @click.option("--output_dir", default="output", type=str)
79 | @click.option("--cutting_context", is_flag=True)
80 | @click.option("--decode_mode", default="default", type=str)
81 | def main(
82 | model: List[str],
83 | task: List[str],
84 | batch_size: int,
85 | limit: Optional[int],
86 | parallelize: bool,
87 | gpus: str,
88 | num_fewshot: int = 0,
89 | output_dir: str = "output",
90 | context_length: int = 1000,
91 | answer_length: int = 50,
92 | cutting_context: bool = False,
93 | decode_mode: str = 'default'
94 | ):
95 |
96 |     if limit is not None and limit < 0: limit = None
97 |
98 | if gpus is not None:
99 | os.environ["CUDA_VISIBLE_DEVICES"] = gpus
100 |
101 | # Load the given Python file as a module
102 | configs = [
103 | {"model": m, "task": t} for m in model for t in task
104 | ]
105 |
106 | use_ray = parallelize and len(configs) > 0
107 | if use_ray:
108 | import ray
109 | # ray was killing workers due to OOM, but it didn't seem to be necessary
110 | os.environ["RAY_memory_monitor_refresh_ms"] = "0"
111 | ray.init(ignore_reinit_error=True, log_to_driver=True)
112 |
113 | print(f"Running sweep with {len(configs)} configs")
114 |
115 | output_dir = f"{output_dir}/{datetime.now().strftime('%y-%m-%d_%H-%M')}"
116 |
117 | # Run each script in parallel using Ray
118 | if not use_ray:
119 | for config in configs:
120 | execute_config(
121 | **config,
122 | batch_size=batch_size,
123 | limit=limit,
124 | output_dir=output_dir,
125 | num_fewshot=num_fewshot,
126 | context_length=context_length,
127 | answer_length=answer_length,
128 | cutting_context=cutting_context,
129 | decode_mode=decode_mode,
130 | )
131 | else:
132 | completed = 0
133 | total = len(configs)
134 | print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}")
135 |
136 |         remote = ray.remote(num_gpus=(1 / MAX_WORKERS_PER_GPU))(execute_config)
137 | futures = [remote.remote(
138 | **config, batch_size=batch_size, limit=limit, output_dir=output_dir, num_fewshot=num_fewshot,
139 | context_length=context_length, answer_length=answer_length, cutting_context=cutting_context,
140 | decode_mode=decode_mode,
141 | ) for config in configs]
142 |
143 | while futures:
144 | complete, futures = ray.wait(futures)
145 | completed += len(complete)
146 | print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}")
147 |
148 | ray.shutdown()
149 |
150 | if __name__ == "__main__":
151 | main()
152 |
--------------------------------------------------------------------------------
/lm-eval-harness/launch_local.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import torch
4 | import click
5 |
6 | from typing import List, Optional
7 |
8 | from lm_eval.__main__ import cli_evaluate
9 | from datetime import datetime
10 | import os
11 | import importlib.util
12 |
13 | from tqdm import tqdm
14 |
15 | MAX_WORKERS_PER_GPU = 1
16 | DEVICE = "cuda"
17 | torch.random.manual_seed(0)
18 |
19 |
20 | def execute_config(
21 | model: str,
22 | task: str,
23 | batch_size: int,
24 | limit: int,
25 | output_dir: str,
26 | num_fewshot: int,
27 |
28 | context_length: int = 1000,
29 | answer_length: int = 50,
30 | cutting_context: bool = False,
31 | decode_mode: str = 'default',
32 | ):
33 | # Save the original standard output
34 | import subprocess
35 |
36 | output_dir = os.path.join(output_dir, model, task)
37 |
38 | # pass flags to cli_evaluate() to override the defaults in argparse
39 | args = [
40 | "lm_eval",
41 | "--model", "lm_eval_model",
42 | "--model_args", f" checkpoint_name={model}",
43 | "--task", task,
44 | "--device", "cuda:0",
45 | "--batch_size", str(batch_size),
46 | "--output_path", output_dir,
47 | "--num_fewshot", str(num_fewshot),
48 | "--decode_mode", decode_mode,
49 |
50 | "--log_samples",
51 | "--write_out",
52 | # "--output", f"{run_name}/",
53 | ]
54 |
55 | if cutting_context:
56 | args.extend(["--cutting_context"])
57 | args.extend(["--context_length", str(context_length)])
58 | args.extend(["--answer_length", str(answer_length)])
59 | args.extend(["--context_key", "text"])
60 |
61 | if limit is not None:
62 | args.extend(["--limit", str(limit)])
63 |
64 | subprocess.run(args)
65 |
66 | print(f"Decoded with mode: {decode_mode}")
67 |
68 |
69 | @click.command()
70 | @click.option("-m", "--model", type=str, multiple=True)
71 | @click.option("-t", "--task", type=str, multiple=True)
72 | @click.option("-p", "--parallelize", is_flag=True)
73 | @click.option("--gpus", default=None, type=str)
74 | @click.option("--batch-size", default=8, type=int)
75 | @click.option("--limit", default=None, type=int)
76 | @click.option("--num_fewshot", default=0, type=int)
77 | @click.option("--context_length", default=1000, type=int)
78 | @click.option("--answer_length", default=50, type=int)
79 | @click.option("--output_dir", default="output", type=str)
80 | @click.option("--cutting_context", is_flag=True)
81 | @click.option("--decode_mode", default="default", type=str)
82 | def main(
83 | model: List[str],
84 | task: List[str],
85 | batch_size: int,
86 | limit: Optional[int],
87 | parallelize: bool,
88 | gpus: str,
89 | num_fewshot: int = 0,
90 | context_length: int = 1000,
91 | answer_length: int = 50,
92 | cutting_context: bool = False,
93 | output_dir: str = "output",
94 | decode_mode: str = 'default'
95 | ):
96 |
97 | if limit is not None and limit < 0: limit = None
98 |
99 | if gpus is not None:
100 | os.environ["CUDA_VISIBLE_DEVICES"] = gpus
101 |
102 | # Load the given Python file as a module
103 | configs = [
104 | {"model": m, "task": t} for m in model for t in task
105 | ]
106 |
107 | use_ray = parallelize and len(configs) > 0
108 | if use_ray:
109 | import ray
110 | # ray was killing workers due to OOM, but it didn't seem to be necessary
111 | os.environ["RAY_memory_monitor_refresh_ms"] = "0"
112 | ray.init(ignore_reinit_error=True, log_to_driver=True)
113 |
114 | print(f"Running sweep with {len(configs)} configs")
115 |
116 | output_dir = f"{output_dir}/{datetime.now().strftime('%y-%m-%d_%H-%M')}"
117 |
118 | # Run each script in parallel using Ray
119 | if not use_ray:
120 | for config in configs:
121 | execute_config(
122 | **config,
123 | batch_size=batch_size,
124 | limit=limit,
125 | output_dir=output_dir,
126 | num_fewshot=num_fewshot,
127 | context_length=context_length,
128 | answer_length=answer_length,
129 | cutting_context=cutting_context,
130 | decode_mode=decode_mode
131 | )
132 | else:
133 | completed = 0
134 | total = len(configs)
135 | print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}")
136 |
137 |         remote = ray.remote(num_gpus=(1 / MAX_WORKERS_PER_GPU))(execute_config)
138 | futures = [remote.remote(
139 | **config, batch_size=batch_size, limit=limit, output_dir=output_dir,
140 | num_fewshot=num_fewshot, context_length=context_length, answer_length=answer_length, cutting_context=cutting_context,
141 | decode_mode=decode_mode
142 | ) for config in configs]
143 |
144 | while futures:
145 | complete, futures = ray.wait(futures)
146 | completed += len(complete)
147 | print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}")
148 |
149 | ray.shutdown()
150 |
151 |
152 | if __name__ == "__main__":
153 | main()
154 |
155 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/__init__.py:
--------------------------------------------------------------------------------
1 | from .evaluator import evaluate, simple_evaluate
2 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/api/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSparseLLMs/MoM/2c3ae59f5c0b749f916189433b8dad5a3415dc08/lm-eval-harness/lm_eval/api/__init__.py
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/api/filter.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from dataclasses import dataclass
3 | from typing import Callable, Iterable, List, Union
4 |
5 | from lm_eval.api.instance import Instance
6 |
7 |
8 | class Filter(ABC):
9 | """
10 | Filter classes operate on a per-task level.
11 | They take all model outputs (`instance.resps` for all `task.instances`)
12 | across all instances of a task, and perform operations.
13 | In a single run, one can configure any number of separate filters or lists of filters.
14 |
15 | """
16 |
17 | def __init__(self, **kwargs) -> None:
18 | """
19 | Can define custom behavior here, if an individual instantiation of a Filter class should have state.
20 | """
21 |
22 | @abstractmethod
23 | def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
24 | """
25 | Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
26 | Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
27 |         if passed [<inst.resps for instance 0>, <inst.resps for instance 1>], it should return
28 |         [<filtered resps for instance 0>, <filtered resps for instance 1>].
29 | """
30 | return resps
31 |
32 |
33 | @dataclass
34 | class FilterEnsemble:
35 | """
36 | FilterEnsemble creates a pipeline applying multiple filters.
37 | Its intended usage is to stack multiple post-processing steps in order.
38 | `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
39 | pipeline separately.
40 | """
41 |
42 | name: str
43 | filters: List[Callable[[], Filter]]
44 |
45 | def apply(self, instances: List[Instance]) -> None:
46 | resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
47 | resps, docs = list(resps), list(docs)
48 |
49 | for f in self.filters:
50 | # apply filters in sequence
51 | resps = f().apply(resps, docs)
52 |
53 | # add the end results after filtering to filtered_requests of their respective source instances.
54 | # has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
55 | for inst, resp in zip(instances, resps):
56 | inst.filtered_resps[self.name] = resp
57 |
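58 |
59 | # A minimal sketch of a concrete Filter (illustrative only; `_ExampleStripFilter`
60 | # is not part of the library). Passing the class as
61 | # `FilterEnsemble(name="strip", filters=[_ExampleStripFilter])` would write the
62 | # stripped responses to `inst.filtered_resps["strip"]` when applied to instances.
63 | class _ExampleStripFilter(Filter):
64 |     def apply(self, resps, docs):
65 |         # strip surrounding whitespace from every response of every instance
66 |         return [[r.strip() for r in resp_list] for resp_list in resps]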
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/api/instance.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Literal, Tuple
3 |
4 |
5 | @dataclass
6 | class Instance:
7 | request_type: Literal[
8 | "loglikelihood",
9 | "loglikelihood_rolling",
10 | "generate_until",
11 | "multiple_choice",
12 | ]
13 | doc: dict
14 | arguments: tuple
15 | idx: int
16 | metadata: Tuple[str, int, int] = field(
17 | default_factory=lambda: (None, None, None)
18 | ) # TODO: better typehints here
19 | resps: list = field(default_factory=list)
20 | filtered_resps: dict = field(default_factory=dict)
21 |
22 | # initialized after init
23 | task_name: str = None
24 | doc_id: str = None
25 | repeats: str = None
26 |
27 | def __post_init__(self) -> None:
28 | # unpack metadata field
29 | self.task_name, self.doc_id, self.repeats = self.metadata
30 |
31 | @property
32 | def args(self):
33 | """
34 | Returns (string,) where `string` is the string to calculate loglikelihood over
35 | """
36 | return (
37 | self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
38 | )
39 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/api/samplers.py:
--------------------------------------------------------------------------------
1 | class ContextSampler:
2 | def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
3 | self.rnd = rnd
4 | assert self.rnd, "must pass rnd to FewShotSampler!"
5 |
6 | self.task = task
7 | self.config = task._config
8 |
9 | self.target_delimiter = self.config.target_delimiter
10 | self.fewshot_delimiter = self.config.fewshot_delimiter
11 |
12 | self.doc_to_text = self.task.doc_to_text
13 | self.doc_to_target = self.task.doc_to_target
14 | self.doc_to_choice = self.task.doc_to_choice
15 |
16 | self.docs = docs # HF dataset split, provided by task._fewshot_docs()
17 | if fewshot_indices: # subset few-shot docs from
18 | self.docs = self.docs.select(fewshot_indices)
19 |
20 | def get_context(self, doc, num_fewshot):
21 | # draw an extra fewshot sample if using same split as evaluating on
22 | n_samples = (
23 | num_fewshot + 1
24 | if self.config.fewshot_split == self.config.test_split
25 | else num_fewshot
26 | )
27 |
28 | # draw `n_samples` docs from fewshot_docs
29 | fewshotex = self.sample(n_samples)
30 |
31 | # get rid of the doc that's the one we're evaluating, if it's in the fewshot
32 | # TODO: should we just stop people from using fewshot from same split as evaluating?
33 | selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
34 |
35 | labeled_examples = (
36 | self.fewshot_delimiter.join(
37 | [
38 | # TODO: is separating doc_to_text and doc_to_target by one space always desired?
39 | (
40 | self.doc_to_text(doc)
41 | if (
42 | self.config.doc_to_choice is None
43 | or isinstance(self.doc_to_text(doc), str)
44 | )
45 | else self.doc_to_choice(doc)[self.doc_to_text(doc)]
46 | )
47 | + self.target_delimiter
48 | + (
49 | str(self.doc_to_target(doc)[0])
50 | if isinstance(self.doc_to_target(doc), list)
51 | else self.doc_to_target(doc)
52 | if (
53 | self.config.doc_to_choice is None
54 | or isinstance(self.doc_to_target(doc), str)
55 | )
56 | else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
57 | )
58 | for doc in selected_docs
59 | ]
60 | )
61 | + self.fewshot_delimiter
62 | )
63 |
64 | return labeled_examples
65 |
66 | def sample(self, n):
67 | """
68 | Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
69 | """
70 |
71 | return self.rnd.sample(self.docs, n)
72 |
73 |
74 | class FirstNSampler(ContextSampler):
75 | def sample(self, n) -> None:
76 | """
77 | Draw the first `n` samples in order from the specified split.
78 | Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
79 | """
80 | assert (
81 | n <= len(self.docs)
82 | ), f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
83 | return self.docs[:n]
84 |
85 |
86 | class BalancedSampler(ContextSampler):
87 | def sample(self, n) -> None:
88 | """
89 | TODO: this should return approximately class-balanced samples from our fewshot examples.
90 | TODO: what order should they be in? maybe random?
91 | """
92 |
93 | pass
94 |
95 |
96 | class ManualSampler(ContextSampler):
97 | def sample(self, n) -> None:
98 | """ """
99 | pass
100 |
101 |
102 | SAMPLER_REGISTRY = {
103 | "default": ContextSampler,
104 | "first_n": FirstNSampler,
105 | }
106 |
107 |
108 | def get_sampler(name):
109 | try:
110 | return SAMPLER_REGISTRY[name]
111 | except KeyError:
112 | raise ValueError(
113 | f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}"
114 | )
115 |
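116 | # Usage sketch (illustrative): `get_sampler("first_n")` returns the FirstNSampler
117 | # class; instantiating it as `get_sampler("first_n")(docs, task, rnd=random.Random(0))`
118 | # yields a sampler that draws canonical, ordered few-shot examples. Unregistered
119 | # names raise a ValueError listing the supported strategies.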
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/decontamination/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSparseLLMs/MoM/2c3ae59f5c0b749f916189433b8dad5a3415dc08/lm-eval-harness/lm_eval/decontamination/__init__.py
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/filters/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 | from functools import partial
3 |
4 | from lm_eval.api.filter import FilterEnsemble
5 | from . import selection
6 | from . import extraction
7 | from . import transformation
8 |
9 |
10 | FILTER_REGISTRY = {
11 | "take_first": selection.TakeFirstFilter,
12 | "regex": extraction.RegexFilter,
13 | "majority_vote": selection.MajorityVoteFilter,
14 | "take_first_k": selection.TakeKFilter,
15 | "remove_whitespace": extraction.WhitespaceFilter,
16 | "lowercase": transformation.LowercaseFilter,
17 | "uppercase": transformation.UppercaseFilter,
18 | "map": transformation.MapFilter,
19 | # TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
20 | # that takes an input and returns a scalar and then should select the max reward,
21 | # or should implement different filters for different ways of handling a reward model's inference.
22 | # "arg_max": selection.ArgMaxFilter,
23 | }
24 |
25 |
26 | def get_filter(filter_name: str) -> Union[type, str]:
27 | if filter_name in FILTER_REGISTRY:
28 | return FILTER_REGISTRY[filter_name]
29 | else:
30 | return filter_name
31 |
32 |
33 | def build_filter_ensemble(
34 | filter_name: str, components: List[List[str]]
35 | ) -> FilterEnsemble:
36 | """
37 | Create a filtering pipeline.
38 | """
39 | filters = []
40 | for function, kwargs in components:
41 | if kwargs is None:
42 | kwargs = {}
43 | # create a filter given its name in the registry
44 | f = partial(get_filter(function), **kwargs)
45 | # add the filter as a pipeline step
46 | filters.append(f)
47 |
48 | return FilterEnsemble(name=filter_name, filters=filters)
49 |
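50 | # Usage sketch (illustrative; the pattern and the ensemble name are placeholders):
51 | #   ensemble = build_filter_ensemble(
52 | #       "postprocess", [["regex", {"regex_pattern": r"(\d+)"}], ["take_first", None]]
53 | #   )
54 | # The resulting FilterEnsemble first extracts the first numeric match from each
55 | # response (RegexFilter), then keeps only the first response per document
56 | # (TakeFirstFilter); `ensemble.apply(instances)` stores the results under
57 | # `inst.filtered_resps["postprocess"]`.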
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/filters/decontamination.py:
--------------------------------------------------------------------------------
1 | from lm_eval.api.filter import Filter
2 |
3 |
4 | class DecontaminationFilter(Filter):
5 |
6 | """
7 | A filter which evaluates
8 | """
9 |
10 | name = "track_decontamination"
11 |
12 | def __init__(self, path) -> None:
13 | """
14 |
15 | TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
16 | should further cache result on a given (task_name, doc_id)
17 | """
18 | self._decontam_results = None
19 |
20 | def apply(self, resps, docs) -> None:
21 | """
22 | Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
23 | """
24 | pass
25 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/filters/extraction.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from lm_eval.api.filter import Filter
4 |
5 |
6 | class RegexFilter(Filter):
7 | """ """
8 |
9 | def __init__(
10 | self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", fallback: str = "[invalid]"
11 | ) -> None:
12 | """
13 | pass a string `regex` to run `re.compile(r"regex")` on.
14 | `fallback` defines the output returned if no matches for the regex are located.
15 | """
16 | self.regex_pattern = regex_pattern
17 | self.regex = re.compile(regex_pattern)
18 | self.fallback = fallback
19 |
20 | def apply(self, resps, docs):
21 | # here, we assume we have a list, in which each element is
22 | # a list of model responses for some particular input/target pair.
23 | # so we process each of these (same input/target response sets)
24 | # independently (and keep them a list.)
25 | def filter_set(inst):
26 | filtered = []
27 | for resp in inst:
28 | match = self.regex.search(resp)
29 | if match:
30 | match = match.group(1).strip()
31 | else:
32 | match = self.fallback
33 | filtered.append(match)
34 | return filtered
35 |
36 | # print(resps)
37 | filtered_resps = list(map(lambda x: filter_set(x), resps))
38 | # print(filtered_resps)
39 |
40 | return filtered_resps
41 |
42 |
43 | class WhitespaceFilter(Filter):
44 | """ """
45 |
46 | def __init__(self) -> None:
47 | pass
48 |
49 | def apply(self, resps, docs):
50 | def filter_set(inst):
51 | filtered_resp = []
52 | for resp in inst:
53 | if resp.startswith(" "):
54 | resp = resp[1:]
55 |
56 | filtered_resp.append(resp)
57 |
58 | return filtered_resp
59 |
60 | filtered_resps = [filter_set(resp) for resp in resps]
61 |
62 | return filtered_resps
63 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/filters/selection.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 |
3 | from lm_eval.api.filter import Filter
4 |
5 |
6 | class TakeFirstFilter(Filter):
7 | def __init__(self) -> None:
8 | """
9 | Can define custom behavior here, if an individual instantiation of a Filter class should have state.
10 | """
11 |
12 | def apply(self, resps, docs):
13 | """
14 | Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
15 | """
16 | return map(lambda r: r[0], resps)
17 |
18 |
19 | class TakeKFilter(Filter):
20 | def __init__(self, **kwargs) -> None:
21 | self.k = kwargs.pop("k")
22 |
23 | super().__init__(**kwargs)
24 |
25 | def apply(self, resps, docs):
26 | # need resp to be subscriptable to check below
27 | resps = list(resps)
28 | # check we have at least k responses per doc, else we can't take the first k
29 | assert (
30 | len(resps[0]) >= self.k
31 | ), f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
32 | return map(lambda r: r[: self.k], resps)
33 |
34 |
35 | class MajorityVoteFilter(Filter):
36 | def __init__(self) -> None:
37 | """
38 | Can define custom behavior here, if an individual instantiation of a Filter class should have state.
39 | """
40 |
41 | def apply(self, resps, docs):
42 | """
43 | Each entry of `resps` is a list of model responses.
44 | We select the response that occurs most frequently in each entry of `resps`.
45 | """
46 |
47 | def select_majority(resp):
48 | counts = Counter(resp)
49 | vote = counts.most_common(1)[0][0]
50 | return vote
51 |
52 | return map(lambda r: [select_majority(r)], resps)
53 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/filters/transformation.py:
--------------------------------------------------------------------------------
1 | from lm_eval.api.filter import Filter
2 |
3 |
4 | class LowercaseFilter(Filter):
5 | def __init__(self) -> None:
6 | pass
7 |
8 | def apply(self, resps, docs):
9 | def filter_set(inst):
10 | return [resp.lower() for resp in inst]
11 |
12 | return [filter_set(resp) for resp in resps]
13 |
14 |
15 | class UppercaseFilter(Filter):
16 | def __init__(self) -> None:
17 | pass
18 |
19 | def apply(self, resps, docs):
20 | def filter_set(inst):
21 | return [resp.upper() for resp in inst]
22 |
23 | return [filter_set(resp) for resp in resps]
24 |
25 |
26 | class MapFilter(Filter):
27 | def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
28 | """
29 | Initializes the MapFilter with a given mapping dictionary and default value.
30 |
31 | Args:
32 | - mapping_dict (dict): A dictionary containing the key-value mappings.
33 | Default is an empty dictionary.
34 | - default_value (Any): The value to be returned when a key is not found in the mapping_dict.
35 | Default is None.
36 |
37 | Example:
38 | mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
39 | """
40 | if mapping_dict is None:
41 | mapping_dict = {}
42 | assert isinstance(
43 | mapping_dict, dict
44 | ), "Provided mapping_dict is not a dictionary"
45 | self.mapping_dict = mapping_dict
46 | self.default_value = default_value
47 |
48 | def apply(self, resps, docs):
49 | def filter_set(inst):
50 | return [self.mapping_dict.get(resp, self.default_value) for resp in inst]
51 |
52 | return [filter_set(resp) for resp in resps]
53 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/models/__init__.py:
--------------------------------------------------------------------------------
1 | from . import huggingface
2 | from . import openai_completions
3 | from . import textsynth
4 | from . import dummy
5 | from . import anthropic_llms
6 | from . import gguf
7 | from . import vllm_causallms
8 | from . import mamba_lm
9 | from . import based_lm
10 | from . import optimum_lm
11 | from . import neuron_optimum
12 | from . import local_lm
13 | from . import jrt_lm
14 | # TODO: implement __all__
15 |
16 |
17 | import os
18 |
19 | try:
20 | # enabling faster model download
21 | import hf_transfer
22 |
23 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
24 | except ImportError:
25 | pass
26 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/models/based_lm.py:
--------------------------------------------------------------------------------
1 | import re
2 | from transformers import AutoTokenizer
3 | import torch
4 |
5 | # from based.utils.hf import load_config_hf
6 | import json
7 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
8 | from transformers.utils.hub import cached_file
9 |
10 | from lm_eval.api.registry import register_model
11 | from lm_eval.models.huggingface import HFLM
12 |
13 | def load_config_hf(model_name):
14 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
15 | return json.load(open(resolved_archive_file))
16 |
17 | @register_model("based_lm")
18 | class BasedLMWrapper(HFLM):
19 | def __init__(
20 | self,
21 | checkpoint_name: str='hazyresearch/based-1.3b',
22 | arch: str=None,
23 | device: str = "cuda",
24 | **kwargs
25 | ) -> None:
26 |
27 | if arch is None:
28 | arch = checkpoint_name.split("/")[1].split("-")[0]
29 |
30 |         assert arch in ['based', 'mamba', 'attn'], "`arch` must be one of 'based', 'mamba', or 'attn'"
31 |
32 | if "backend" in kwargs:
33 | # based currently only supports causal models
34 | assert kwargs["backend"] == "causal"
35 |
36 | self.checkpoint_name = checkpoint_name
37 |
38 | if arch == "based":
39 | from based.models.gpt import GPTLMHeadModel
40 | model = GPTLMHeadModel.from_pretrained_hf(pretrained_model_name=self.checkpoint_name, device=device)
41 | elif arch == "mamba":
42 | from based.models.mamba import MambaLMHeadModel
43 | model = MambaLMHeadModel.from_pretrained_hf(pretrained_model_name=self.checkpoint_name, device=device)
44 | elif arch == "attn":
45 | from based.models.transformer.gpt import GPTLMHeadModel, GPT2Config, state_dict_from_pretrained; # TODO: construct a loading function
46 | config_data = load_config_hf(self.checkpoint_name)
47 | config = GPT2Config(**config_data)
48 | try:
49 | model = GPTLMHeadModel(config=config, device=device, dtype=torch.float16, multiple_of=256)
50 | state_dict = state_dict_from_pretrained(self.checkpoint_name, dtype=torch.float16)
51 | # remove the 'model.' prefix from the keys
52 |                 state_dict = {re.sub(r"^model\.", "", k): v for k, v in state_dict.items()}
53 |                 # drop metric-tracking keys ("train_metrics.num-tokens.count", "val_metrics.num-tokens.count", "test_metrics.num-tokens.count") that are not model weights
54 | state_dict = {k: v for k, v in state_dict.items() if "metrics" not in k}
55 | model.load_state_dict(state_dict)
56 |             except Exception:
57 |                 model = GPTLMHeadModel(config=config, device=device, dtype=torch.float16, multiple_of=128)
58 |                 state_dict = state_dict_from_pretrained(self.checkpoint_name, dtype=torch.float16)
59 |                 state_dict = {re.sub(r"^model\.", "", k): v for k, v in state_dict.items()}
60 | state_dict = {k: v for k, v in state_dict.items() if "metrics" not in k}
61 | model.load_state_dict(state_dict)
62 | else:
63 | raise ValueError(f"Unsupported model {arch}")
64 |
65 | tokenizer_name = kwargs.get("tokenizer", "gpt2")
66 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
67 | tokenizer.model_max_length = 2048
68 |
69 | model.device = device
70 |
71 | super().__init__(
72 | pretrained=model,
73 | # set appropriate defaults for tokenizer, max length, etc
74 | backend=kwargs.get("backend", "causal"),
75 | max_length=kwargs.get("max_length", 2048),
76 | tokenizer=tokenizer,
77 | device=device,
78 | **kwargs,
79 | )
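80 |
81 | # Usage sketch (illustrative): this wrapper is selected through the lm-eval CLI,
82 | # e.g. `lm_eval --model based_lm --model_args checkpoint_name=hazyresearch/based-1.3b ...`,
83 | # which matches the argument list that launch.py builds via subprocess.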
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/models/dummy.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from lm_eval.api.model import LM
4 | from lm_eval.api.registry import register_model
5 |
6 |
7 | @register_model("dummy")
8 | class DummyLM(LM):
9 | def __init__(self) -> None:
10 | super().__init__()
11 |
12 | @classmethod
13 | def create_from_arg_string(cls, arg_string, additional_config=None):
14 | return cls()
15 |
16 | def loglikelihood(self, requests):
17 | res = []
18 |
19 | for _ in requests:
20 | res.append((-random.random(), False))
21 |
22 | return res
23 |
24 | def generate_until(self, requests):
25 | res = []
26 |
27 | for ctx, _ in requests:
28 | res.append("lol")
29 | assert ctx.strip() != ""
30 |
31 | return res
32 |
33 | def loglikelihood_rolling(self, requests):
34 | res = []
35 |
36 | for _ in requests:
37 | res.append(-random.random())
38 |
39 | return res
40 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/models/gguf.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 |
4 | import requests
5 | from requests.exceptions import RequestException
6 | from tqdm import tqdm
7 |
8 | from lm_eval.api.model import LM
9 | from lm_eval.api.registry import register_model
10 |
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | def get_result(logprobs, context_length):
16 | is_greedy = True
17 | offsets = logprobs["text_offset"]
18 | tokens = logprobs["tokens"]
19 | tokens_logprobs = logprobs["token_logprobs"]
20 |
21 | idx = 0
22 | while offsets[idx] < context_length:
23 | idx += 1
24 | continuation_logprobs = sum(tokens_logprobs[idx:-1])
25 | for i in range(idx, len(tokens)):
26 | token = tokens[i]
27 | top_tokens = logprobs["top_logprobs"][i]
28 | top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
29 | if top_token != token:
30 | is_greedy = False
31 | break
32 |
33 | return continuation_logprobs, is_greedy
34 |
35 |
36 | @register_model("gguf", "ggml")
37 | class GGUFLM(LM):
38 | def __init__(self, base_url=None, max_length=2048, **kwargs):
39 | super().__init__()
40 | self.base_url = base_url
41 | assert self.base_url, "must pass `base_url` to use GGUF LM!"
42 | self.logprobs = 10
43 | self.temperature = 0.0
44 | self.max_length = max_length
45 |
46 | def gguf_completion(
47 | self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs
48 | ):
49 | for _ in range(retries):
50 | try:
51 | prompt = context
52 | request = {
53 | "prompt": prompt,
54 | "logprobs": self.logprobs,
55 | "temperature": self.temperature,
56 | }
57 | if continuation:
58 | prompt += continuation
59 | request.update({"prompt": prompt, "max_tokens": 1, "echo": True})
60 | if stop is not None:
61 | request["stop"] = stop
62 | response = requests.post(
63 | f"{self.base_url}/v1/completions", json=request
64 | )
65 | response.raise_for_status()
66 | return response.json()
67 | except RequestException as e:
68 | logger.error(f"RequestException: {e}")
69 | time.sleep(delay) # wait before retrying
70 | else:
71 | raise Exception(f"Failed to get a valid response after {retries} retries.")
72 |
73 | def loglikelihood(self, requests):
74 | if not requests:
75 | return []
76 | res = []
77 | for context, continuation in tqdm([req.args for req in requests]):
78 | response = self.gguf_completion(context=context, continuation=continuation)
79 | if response and "choices" in response and response["choices"]:
80 | choice = response["choices"][0]
81 | logprobs = choice.get("logprobs")
82 | if (
83 | logprobs
84 | and "token_logprobs" in logprobs
85 | and logprobs["token_logprobs"]
86 | ):
87 | logprob, is_greedy = get_result(logprobs, len(context))
88 | res.append((logprob, is_greedy))
89 | else:
90 | logger.warning(
91 | "Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list."
92 | )
93 | else:
94 | logger.error(
95 | f"Invalid response for loglikelihood. Response: {response}"
96 | )
97 | assert False
98 | return res
99 |
100 | def generate_until(self, requests):
101 | if not requests:
102 | return []
103 |
104 | res = []
105 | for request in tqdm([req.args for req in requests]):
106 | inp = request[0]
107 | request_args = request[1]
108 | until = request_args.get("until", [""])
109 | response = self.gguf_completion(context=inp, stop=until)
110 | if response and "choices" in response and response["choices"]:
111 | choice = response["choices"][0]
112 | if "text" in choice:
113 | generated_text = choice["text"].strip()
114 | res.append(generated_text)
115 | else:
116 | logger.error(
117 | f"Invalid response for greedy_until. Response: {response}"
118 | )
119 | res.append(None) # Add default value in case of error
120 | else:
121 | logger.error(f"Invalid response for greedy_until. Response: {response}")
122 | res.append(None) # Add default value in case of error
123 | return res
124 |
125 | def loglikelihood_rolling(self, requests):
126 | raise NotImplementedError(
127 | "loglikelihood_rolling not yet supported for GGUF models"
128 | )
129 |
--------------------------------------------------------------------------------
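A minimal, self-contained sketch of the logprob bookkeeping that `get_result` performs above. The dictionary mirrors the fields (`text_offset`, `tokens`, `token_logprobs`, `top_logprobs`) returned by a llama.cpp-style `/v1/completions` server with `echo=True`; the values are made up for illustration only.

```python
# Invented echo response; in practice this comes from the GGUF server.
logprobs = {
    "text_offset": [0, 4, 9, 14],           # character offset of each token in the prompt
    "tokens": ["The ", "cat ", "sat ", "."],
    "token_logprobs": [None, -1.2, -0.7, -2.0],
    "top_logprobs": [{}, {"cat ": -1.2}, {"sat ": -0.7}, {"down": -0.5, ".": -2.0}],
}
context_length = 4  # everything from offset 4 onward is the continuation

# Skip tokens that belong to the context.
idx = 0
while idx < len(logprobs["text_offset"]) and logprobs["text_offset"][idx] < context_length:
    idx += 1

# Sum the continuation tokens' logprobs, excluding the final echoed token
# (which is the single newly generated token when max_tokens=1), as above.
continuation_logprob = sum(logprobs["token_logprobs"][idx:-1])

# The continuation is "greedy" if every continuation token was the top candidate.
is_greedy = all(
    max(top, key=top.get) == tok
    for tok, top in zip(logprobs["tokens"][idx:], logprobs["top_logprobs"][idx:])
    if top
)
print(continuation_logprob, is_greedy)  # roughly -1.9, False
```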
/lm-eval-harness/lm_eval/models/jrt_lm.py:
--------------------------------------------------------------------------------
1 | import re
2 | from transformers import AutoTokenizer
3 | import torch
4 |
5 | import json
6 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
7 | from transformers.utils.hub import cached_file
8 |
9 | from lm_eval.api.registry import register_model
10 | from lm_eval.models.huggingface import HFLM
11 |
12 | def load_config_hf(model_name):
13 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
14 | return json.load(open(resolved_archive_file))
15 |
16 | @register_model("jrt_lm")
17 | class JRTLMWrapper(HFLM):
18 | def __init__(
19 | self,
20 | checkpoint_name: str='hazyresearch/based-1.3b',
21 | arch: str=None,
22 | device: str = "cuda",
23 | **kwargs
24 | ) -> None:
25 |
26 | if arch is None:
27 | arch = checkpoint_name.split("/")[1].split("-")[0]
28 |
29 |         assert arch in ['JRT', 'based', 'mamba', 'attn'], "`arch` must be one of 'JRT', 'based', 'mamba', or 'attn'"
30 |
31 | if "backend" in kwargs:
32 | # based currently only supports causal models
33 | assert kwargs["backend"] == "causal"
34 |
35 | self.checkpoint_name = checkpoint_name
36 |
37 | if arch == "based":
38 | from train.src.models.gpt import GPTLMHeadModel
39 | model = GPTLMHeadModel.from_pretrained_hf(pretrained_model_name=self.checkpoint_name, device=device)
40 | elif arch == "JRT":
41 | from train.src.models.gpt import GPTLMHeadModel
42 | model = GPTLMHeadModel.from_pretrained_hf(pretrained_model_name=self.checkpoint_name, device=device)
43 | elif arch == "mamba":
44 | from based.models.mamba import MambaLMHeadModel
45 | model = MambaLMHeadModel.from_pretrained_hf(pretrained_model_name=self.checkpoint_name, device=device)
46 | elif arch == "attn":
47 |             from based.models.transformer.gpt import GPTLMHeadModel, GPT2Config, state_dict_from_pretrained  # TODO: construct a loading function
48 | config_data = load_config_hf(self.checkpoint_name)
49 | config = GPT2Config(**config_data)
50 | try:
51 | model = GPTLMHeadModel(config=config, device=device, dtype=torch.float16, multiple_of=256)
52 | state_dict = state_dict_from_pretrained(self.checkpoint_name, dtype=torch.float16)
53 | # remove the 'model.' prefix from the keys
54 |                 state_dict = {re.sub(r"^model\.", "", k): v for k, v in state_dict.items()}
55 | # remove Unexpected key(s) in state_dict: "train_metrics.num-tokens.count", "val_metrics.num-tokens.count", "test_metrics.num-tokens.count". from the state_dict
56 | state_dict = {k: v for k, v in state_dict.items() if "metrics" not in k}
57 | model.load_state_dict(state_dict)
58 |             except Exception:
59 | model = GPTLMHeadModel(config=config, device=device, dtype=torch.float16, multiple_of=128)
60 | state_dict = state_dict_from_pretrained(self.checkpoint_name, dtype=torch.float16)
61 |                 state_dict = {re.sub(r"^model\.", "", k): v for k, v in state_dict.items()}
62 | state_dict = {k: v for k, v in state_dict.items() if "metrics" not in k}
63 | model.load_state_dict(state_dict)
64 | else:
65 | raise ValueError(f"Unsupported model {arch}")
66 |
67 | tokenizer_name = kwargs.get("tokenizer", "gpt2")
68 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
69 | tokenizer.model_max_length = 2048
70 |
71 | model.device = device
72 |
73 | super().__init__(
74 | pretrained=model,
75 | # set appropriate defaults for tokenizer, max length, etc
76 | backend=kwargs.get("backend", "causal"),
77 | max_length=kwargs.get("max_length", 2048),
78 | tokenizer=tokenizer,
79 | device=device,
80 | **kwargs,
81 | )
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/models/local_lm.py:
--------------------------------------------------------------------------------
1 | from .local_utils.loading import load_model, load_tokenizer, load_hf_model
2 | from lm_eval.api.registry import register_model
3 | from lm_eval.models.huggingface import HFLM
4 |
5 |
6 | @register_model("lm_eval_model")
7 | class LMWrapper(HFLM):
8 | def __init__(
9 | self,
10 | checkpoint_name: str,
11 | max_length: int = 2048,
12 | device: str = "cuda",
13 | **kwargs
14 | ) -> None:
15 |
16 |         is_hf = not checkpoint_name.startswith("hazy-research")
17 | tokenizer = load_tokenizer(checkpoint_name, is_hf=is_hf)
18 | if is_hf:
19 | model = load_hf_model(checkpoint_name)
20 | else:
21 | model = load_model(checkpoint_name, device=device)
22 | # model.device = device
23 | # import torch
24 | # model.to(device, dtype=torch.float32)
25 | model.to(device)
26 |
27 | super().__init__(
28 | pretrained=model,
29 | backend="causal",
30 | max_length=max_length,
31 | tokenizer=tokenizer,
32 | device=device,
33 | **kwargs,
34 | )
35 |
36 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/models/local_utils/loading.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from typing import Any, Union
4 | import os
5 | import hydra
6 | import sys
7 |
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | from .jrt_utils import import_object, load_config
10 |
11 | ### This code loads models that we trained ###
12 | def load_model(
13 | run_id: str,
14 | device: Union[int, str] = None,
15 |     config: Any = None,
16 | ) -> nn.Module:
17 | """
18 | Load a model from a wandb run ID.
19 | Parameters:
20 | run_id (str): A full wandb run id like "hazy-research/attention/159o6asi"
21 | """
22 |
23 | # 1: Get configuration from wandb
24 | config = load_config(run_id)
25 | path = config["callbacks"]["model_checkpoint"]["dirpath"]
26 |
27 | if config["model"].get("_instantiate_config_", True):
28 |
29 | # SE (01/29): models were trained on flash_attn==2.3.6
30 |         # a newer version sets this parameter to 256 by default, so to make it
31 |         # compatible while still allowing for upgrades of flash attention,
32 |         # we set it to 128 here
33 | if config["model"]["_target_"] == "flash_attn.models.gpt.GPTLMHeadModel":
34 | config["model"]["config"]["mlp_multiple_of"] = 128
35 |
36 | model_config = hydra.utils.instantiate(
37 | config["model"]["config"], _recursive_=False, _convert_="object"
38 | )
39 | cls = import_object(config["model"]["_target_"])
40 | model = cls(model_config).to(device=device)
41 | else:
42 | # SE: need this alternate form for models that accept kwargs, not a config object
43 | model_config = config["model"].pop("config")
44 | model = hydra.utils.instantiate(config["model"], **model_config, _recursive_=False)
45 |
46 | path = path.replace(
47 | "/var/cr05_data/sim_data/checkpoints/", # old machine
48 | '/home/simarora/based-checkpoints/checkpoints/'
49 | )
50 |
51 | try:
52 |         assert os.path.exists(path), f"Path {path} does not exist"
53 | ckpt = torch.load(os.path.join(path, "last.ckpt"), map_location=torch.device(device))
54 |     except Exception:
55 | paths = os.listdir(path)
56 | paths = [p for p in paths if ".ckpt" in p]
57 | print(f'Loading model from {paths[0]}')
58 | ckpt = torch.load(os.path.join(path, paths[0]), map_location=torch.device(device))
59 |
60 | # 3: Load model
61 | # load the state dict, but remove the "model." prefix and all other keys from the
62 | # the PyTorch Lightning module that are not in the actual model
63 | model.load_state_dict({
64 | k[len("model."):]: v
65 | for k, v in ckpt["state_dict"].items()
66 | if k.startswith("model.")
67 | })
68 |
69 | model = model.to(device=device)
70 | return model
71 |
72 |
73 | def load_hf_model(model_name: str, device:str = 'cuda') -> nn.Module:
74 | if "mamba" in model_name:
75 |         from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel  # imported lazily; only needed for Mamba checkpoints
76 |         # SA: can't pass in device here https://github.com/pytorch/pytorch/issues/10622
77 |         model = MambaLMHeadModel.from_pretrained(model_name, device=device, dtype=torch.float16)
78 |
79 | else:
80 | if "Mixtral" in model_name:
81 | model = AutoModelForCausalLM.from_pretrained(
82 | model_name, trust_remote_code=True, use_flash_attention_2=True,
83 | # load_in_8bit=True,
84 | torch_dtype=torch.bfloat16, device_map="auto"
85 | )
86 | else:
87 | model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, token="your token here")
88 | model.to(device)
89 |
90 |
91 | try: model.device = device
92 |     except Exception: pass
93 | model.eval()
94 | return model
95 |
96 |
97 | def load_tokenizer(model_name: str, is_hf: bool=False) -> nn.Module:
98 | if not is_hf:
99 | tokenizer = AutoTokenizer.from_pretrained("gpt2")
100 | tokenizer.model_max_length = 2048
101 | else:
102 | if "mamba" in model_name or "mpt" in model_name:
103 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
104 | else:
105 | tokenizer = AutoTokenizer.from_pretrained(model_name)
106 | tokenizer.pad_token = tokenizer.eos_token
107 | tokenizer.pad_token_id = tokenizer.eos_token_id
108 | return tokenizer
109 |
110 |
--------------------------------------------------------------------------------
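The trickiest part of `load_model` above is remapping the PyTorch Lightning checkpoint keys: the `model.` prefix is stripped and Lightning metric buffers are dropped. A tiny self-contained illustration of that filtering, using an in-memory dict with placeholder values instead of a real checkpoint:

```python
# Keys as they appear in a Lightning `state_dict`; values are placeholders.
ckpt_state_dict = {
    "model.embed.weight": "tensor-0",
    "model.lm_head.weight": "tensor-1",
    "train_metrics.num-tokens.count": "buffer",  # Lightning bookkeeping, not a model weight
}

# Same remapping as in load_model(): keep only "model.*" keys, strip the prefix.
model_state_dict = {
    k[len("model."):]: v
    for k, v in ckpt_state_dict.items()
    if k.startswith("model.")
}
assert model_state_dict == {"embed.weight": "tensor-0", "lm_head.weight": "tensor-1"}
```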
/lm-eval-harness/lm_eval/models/mamba_lm.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import torch
4 |
5 | import lm_eval.models.utils
6 | from lm_eval.api.registry import register_model
7 | from lm_eval.models.huggingface import HFLM
8 |
9 |
10 | @register_model("mamba_ssm")
11 | class MambaLMWrapper(HFLM):
12 | def __init__(
13 | self,
14 | checkpoint_name="state-spaces/mamba-130m",
15 | **kwargs,
16 | ) -> None:
17 | """
18 | Mamba (via the `mamba_ssm` package) supports the following args:
19 | ```
20 | d_model: int,
21 | n_layer: int,
22 | vocab_size: int,
23 | initializer_cfg=None,
24 | pad_vocab_size_multiple: int = 1,
25 | ssm_cfg=None,
26 | norm_epsilon: float = 1e-5,
27 | rms_norm: bool = False,
28 | initializer_cfg=None,
29 | fused_add_norm=False,
30 | residual_in_fp32=False,
31 | ```
32 |
33 | See https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L175 for more info.
34 | The above can all be passed via `--model_args` or to this __init__() directly
35 | but we recommend placing many of these within the config.json file uploaded alongside your
36 | Mamba model to the HF Hub instead.
37 | All other HuggingFace from_pretrained() kwargs
38 | such as those related to
39 | `parallelize=True`, PEFT, autoGPTQ,
40 | or any sub-configurations of these advanced args,
41 | are unsupported by the `mamba_ssm` package.
42 |
43 | The HFLM arguments
44 |
45 | `backend`, `tokenizer`, `truncation`, `max_length`,
46 | `device`, `dtype`, `batch_size`, `max_batch_size`, `trust_remote_code`, `use_fast_tokenizer`
47 |
48 |         are all supported by Mamba where they do not conflict
49 | with Mamba-specific restrictions such as causal LMs only.
50 | """
51 |
52 | if "backend" in kwargs:
53 | # mamba currently only supports causal models
54 | assert kwargs["backend"] == "causal"
55 |
56 | super().__init__(
57 | pretrained=checkpoint_name,
58 | # set appropriate defaults for tokenizer, max length, etc
59 | backend=kwargs.get("backend", "causal"),
60 | tokenizer=kwargs.get("tokenizer", "EleutherAI/gpt-neox-20b"),
61 | max_length=kwargs.get("max_length", 2048),
62 | **kwargs,
63 | )
64 |
65 | def _get_config(
66 | self,
67 | checkpoint_name: str,
68 | **kwargs,
69 | ) -> None:
70 | try:
71 | from mamba_ssm.utils.hf import load_config_hf # noqa: F811
72 | except ModuleNotFoundError:
73 | raise Exception(
74 | "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
75 | please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
76 | )
77 |
78 | self._config = load_config_hf(checkpoint_name)
79 |
80 | def _create_model(
81 | self,
82 | pretrained: str,
83 | dtype: Optional[Union[str, torch.dtype]] = "float16",
84 | # no `parallelize=True` options
85 | # no PEFT and quantization options
86 | # Mamba does not support arbitrary HF from_pretrained() args
87 | **kwargs,
88 | ) -> None:
89 | try:
90 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel # noqa: F811
91 | except ModuleNotFoundError:
92 | raise Exception(
93 | "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
94 | please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
95 | )
96 |
97 | self._model = MambaLMHeadModel.from_pretrained(
98 | pretrained,
99 | device=self._device,
100 | dtype=torch.float16
101 | if dtype == "auto"
102 | else lm_eval.models.utils.get_dtype(dtype),
103 | )
104 |
105 | def _model_generate(self, context, max_length, stop, **generation_kwargs):
106 | for key in ("do_sample", "attention_mask"):
107 | if key in generation_kwargs:
108 | generation_kwargs.pop(key)
109 |
110 | # mamba's custom GenerationMixin currently does not support
111 | # passing stopping criteria.
112 | # for the time being, we simply generate to max length,
113 | # then truncate (equivalent result)
114 | # -- this should be revisited to speed up generation
115 | # stopping_criteria = stop_sequences_criteria(
116 | # self.tokenizer, stop, 1, context.shape[0]
117 | # )
118 |
119 | return self.model.generate(
120 | input_ids=context,
121 | max_length=max_length,
122 | # stopping_criteria=stopping_criteria,
123 | # pad_token_id=self.tokenizer.pad_token_id,
124 | # use_cache=True,
125 | **generation_kwargs,
126 | )
127 |
--------------------------------------------------------------------------------
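For reference, a hedged usage sketch: assuming this fork exposes `simple_evaluate` at the top level as upstream lm-eval does (an assumption, not shown in this listing) and that `mamba_ssm` is installed, the wrapper above can be selected by its registered name `mamba_ssm`. The checkpoint and task names are examples only.

```python
import lm_eval

# "mamba_ssm" is the registry name declared by @register_model above;
# checkpoint_name is forwarded to MambaLMHeadModel.from_pretrained.
results = lm_eval.simple_evaluate(
    model="mamba_ssm",
    model_args="checkpoint_name=state-spaces/mamba-130m",
    tasks=["lambada_openai"],
)
print(results["results"])
```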
/lm-eval-harness/lm_eval/models/optimum_lm.py:
--------------------------------------------------------------------------------
1 | from importlib.util import find_spec
2 | from pathlib import Path
3 |
4 | from lm_eval.api.registry import register_model
5 | from lm_eval.models.huggingface import HFLM
6 |
7 |
8 | @register_model("openvino")
9 | class OptimumLM(HFLM):
10 | """
11 | Optimum Intel provides a simple interface to optimize Transformer models and convert them to \
12 | OpenVINO™ Intermediate Representation (IR) format to accelerate end-to-end pipelines on \
13 | Intel® architectures using OpenVINO™ runtime.
14 | """
15 |
16 | def __init__(
17 | self,
18 | device="cpu",
19 | **kwargs,
20 | ) -> None:
21 | if "backend" in kwargs:
22 | # optimum currently only supports causal models
23 | assert (
24 | kwargs["backend"] == "causal"
25 | ), "Currently, only OVModelForCausalLM is supported."
26 |
27 | self.openvino_device = device
28 |
29 | super().__init__(
30 | device=self.openvino_device,
31 | backend=kwargs.get("backend", "causal"),
32 | **kwargs,
33 | )
34 |
35 | def _create_model(
36 | self,
37 | pretrained: str,
38 | revision="main",
39 | dtype="auto",
40 | trust_remote_code=False,
41 | **kwargs,
42 | ) -> None:
43 | if not find_spec("optimum"):
44 | raise Exception(
45 | "package `optimum` is not installed. Please install it via `pip install optimum[openvino]`"
46 | )
47 | else:
48 | from optimum.intel.openvino import OVModelForCausalLM
49 |
50 | model_kwargs = kwargs if kwargs else {}
51 | model_file = Path(pretrained) / "openvino_model.xml"
52 | if model_file.exists():
53 | export = False
54 | else:
55 | export = True
56 |             model_kwargs["ov_config"] = {
57 | "PERFORMANCE_HINT": "LATENCY",
58 | "NUM_STREAMS": "1",
59 | "CACHE_DIR": "",
60 | }
61 |
62 | self._model = OVModelForCausalLM.from_pretrained(
63 | pretrained,
64 | revision=revision,
65 | trust_remote_code=trust_remote_code,
66 | export=export,
67 | device=self.openvino_device.upper(),
68 | **model_kwargs,
69 | )
70 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/prompts/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ast
3 |
4 | from typing import Dict
5 | from lm_eval import utils
6 | from lm_eval.utils import eval_logger
7 |
8 | # Prompt library.
9 | # Stores prompts in a dictionary indexed by 2 levels:
10 | # prompt category name, and prompt name.
11 | # This allows us to access prompts by category and name.
12 | PROMPT_REGISTRY: Dict[str, Dict[str, str]] = {
13 | "qa-basic": {
14 | "question-newline-answer": "Question: {{question}}\nAnswer:",
15 | "q-newline-a": "Q: {{question}}\nA:",
16 | },
17 | }
18 |
19 |
20 | def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None):
21 | # unpack prompt name
22 | category_name, prompt_name = prompt_id.split(":")
23 | if subset_name is None:
24 | dataset_full_name = dataset_name
25 | else:
26 | dataset_full_name = f"{dataset_name}-{subset_name}"
27 | eval_logger.info(f"Loading prompt from {category_name} for {dataset_full_name}")
28 | if category_name == "promptsource":
29 | try:
30 | from promptsource.templates import DatasetTemplates
31 | except ModuleNotFoundError:
32 | raise Exception(
33 | "Tried to load a Promptsource template, but promptsource is not installed ",
34 | "please install promptsource via pip install lm-eval[promptsource] or pip install -e .[promptsource]",
35 | )
36 | try:
37 | if subset_name is None:
38 | prompts = DatasetTemplates(dataset_name=dataset_name)
39 | else:
40 | prompts = DatasetTemplates(
41 | dataset_name=dataset_name, subset_name=subset_name
42 | )
43 | except Exception:
44 | raise ValueError(f"{dataset_name} and {subset_name} not found")
45 | if prompt_name in prompts.all_template_names:
46 | return prompts[prompt_name]
47 | else:
48 | raise ValueError(
49 | f"{prompt_name} not in prompt list {prompts.all_template_names}"
50 | )
51 | elif ".yaml" in category_name:
52 | import yaml
53 |
54 | with open(category_name, "rb") as file:
55 | prompt_yaml_file = yaml.full_load(file)
56 |
57 | prompt_string = prompt_yaml_file["prompts"][prompt_name]
58 | return PromptString(prompt_string)
59 | else:
60 | try:
61 | return PROMPT_REGISTRY[category_name][prompt_name]
62 | except Exception:
63 | raise ValueError(
64 | f"expected only a single `:` as separator between \
65 | prompt category and name, but got `{prompt_id}` instead"
66 | )
67 |
68 |
69 | def load_prompt_list(
70 | use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs
71 | ):
72 | category_name, prompt_name = use_prompt.split(":")
73 |
74 | if category_name == "promptsource":
75 | from promptsource.templates import DatasetTemplates
76 |
77 | if subset_name is None:
78 | prompts = DatasetTemplates(dataset_name=dataset_name)
79 | else:
80 | prompts = DatasetTemplates(
81 | dataset_name=dataset_name, subset_name=subset_name
82 | )
83 |
84 | prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
85 |
86 | elif ".yaml" in category_name:
87 | import yaml
88 |
89 | if yaml_path is not None:
90 | category_name = os.path.realpath(os.path.join(yaml_path, category_name))
91 |
92 | with open(category_name, "rb") as file:
93 | prompt_yaml_file = yaml.full_load(file)
94 |
95 | prompt_list = utils.pattern_match(
96 | prompt_name, prompt_yaml_file["prompts"].keys()
97 | )
98 |
99 | # category_name, *prompt_name = use_prompt.split(":")
100 | # TODO allow to multiple prompt naming
101 | # if len(prompt_name) > 1:
102 | # prompt_list = []
103 | # for prompt in prompt_name:
104 | # prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names))
105 | # else:
106 | # prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
107 | return [":".join([category_name, prompt]) for prompt in prompt_list]
108 |
109 |
110 | class PromptString:
111 | def __init__(self, prompt_string):
112 | self.prompt_string = prompt_string
113 |
114 | def apply(self, doc):
115 | doc_to_text = self.prompt_string["doc_to_text"]
116 | doc_to_target = self.prompt_string["doc_to_target"]
117 |
118 | # TODO need a way to process doc_to_choice
119 | if "doc_to_choice" in self.prompt_string:
120 | raise Exception("Not yet implemented to accept doc_to_choice")
121 |
122 | text_string = utils.apply_template(doc_to_text, doc)
123 | target_string = utils.apply_template(doc_to_target, doc)
124 |
125 | return [text_string, target_string]
126 |
--------------------------------------------------------------------------------
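A short illustration of the two-level `PROMPT_REGISTRY` lookup implemented by `get_prompt` above, assuming `utils.apply_template` renders Jinja-style templates as it does in `PromptString.apply`; the document dict is hypothetical.

```python
from lm_eval import utils
from lm_eval.prompts import get_prompt

# "qa-basic:question-newline-answer" splits into category "qa-basic"
# and prompt name "question-newline-answer".
template = get_prompt("qa-basic:question-newline-answer")

doc = {"question": "What is the capital of France?"}  # hypothetical document
print(utils.apply_template(template, doc))
# Question: What is the capital of France?
# Answer:
```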
/lm-eval-harness/lm_eval/tasks/README.md:
--------------------------------------------------------------------------------
1 | # v1.0 Tasks
2 | This list keeps track of which tasks' implementations have been ported to YAML / v2.0 of the Eval Harness.
3 |
4 | Boxes should be checked iff tasks are implemented in the refactor and tested for regression. Tasks should be struck through if they have been checked *against the original paper's implementation* or a popularizing implementation. (WIP) denotes that a PR or person is already working on this task.
5 |
6 | - [x] Glue
7 | - [x] SuperGlue
8 | - [x] CoQA
9 | - [x] DROP
10 | - [x] ~~Lambada~~
11 | - [x] Lambada (Cloze variants)
12 | - [x] ~~Lambada (Multilingual)~~
13 | - [x] Wikitext
14 | - [x] PiQA
15 | - [x] PROST
16 | - [x] MCTACO
17 | - [x] Pubmed QA
18 | - [x] SciQ
19 | - [x] QASPER
20 | - [x] QA4MRE
21 | - [x] TriviaQA
22 | - [x] AI2 ARC
23 | - [x] LogiQA
24 | - [x] HellaSwag
25 | - [x] SWAG
26 | - [x] OpenBookQA
27 | - [ ] SQuADv2 (Lintang)
28 | - [x] RACE
29 | - [x] HeadQA
30 | - [x] MathQA
31 | - [x] WebQs
32 | - [x] WSC273
33 | - [x] Winogrande
34 | - [x] ANLI
35 | - [x] Hendrycks Ethics (missing some tasks/metrics, see PR 660 for more info)
36 | - [x] TruthfulQA (mc1)
37 | - [x] TruthfulQA (mc2)
38 | - [x] TruthfulQA (gen)
39 | - [x] MuTual
40 | - [ ] Hendrycks Math (Hailey)
41 | - [x] Asdiv
42 | - [ ] GSM8k
43 | - [x] Arithmetic
44 | - [ ] MMMLU (Hailey)
45 | - [x] Translation (WMT) suite
46 | - [x] Unscramble
47 | - [x] ~~Pile (perplexity)~~
48 | - [x] BLiMP
49 | - [x] ToxiGen
50 | - [x] StoryCloze
51 | - [ ] NaturalQs (Hailey)
52 | - [x] CrowS-Pairs
53 | - [x] XCopa
54 | - [ ] BIG-Bench (Hailey)
55 | - [x] XStoryCloze
56 | - [x] XWinograd
57 | - [x] PAWS-X
58 | - [x] XNLI
59 | - [x] MGSM
60 | - [ ] SCROLLS
61 | - [x] Babi
62 | - [x] Belebele
63 |
64 | # Novel Tasks
65 | Tasks added in the revamped harness that were not previously available. Again, a strikethrough denotes checking performed *against the original task's implementation or published results introducing the task*.
66 |
67 | # Task Wishlist
68 |
69 | - [ ] TheoremQA
70 | - [ ] Theorem Proving evaluations
71 | - [ ] Chain of Thought
72 | - [ ] Self-consistency ; Least-to-Most prompting, etc.
73 | - [ ] Summarization Tasks
74 | - [ ] Anthropic Model-Written Evals
75 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/tasks/based_drop/task.py:
--------------------------------------------------------------------------------
1 | from lm_eval.api.task import ConfigurableTask
2 | from lm_eval.api.instance import Instance
3 | import re
4 | from typing import List
5 | import numpy as np
6 |
7 |
8 | class BasedDrop(ConfigurableTask):
9 | VERSION = "default"
10 | DATASET_PATH = "hazyresearch/based_drop"
11 | DATASET_NAME = None
12 |
13 | def __init__(self):
14 | super().__init__(config={'metadata': {'version': self.VERSION}})
15 |
16 | def has_training_docs(self):
17 | return False
18 |
19 | def has_validation_docs(self):
20 | return True
21 |
22 | def has_test_docs(self):
23 | return False
24 |
25 | def validation_docs(self):
26 | return self.dataset["validation"]
27 |
28 | def doc_to_text(self, doc):
29 | context = doc["context"].strip()
30 | question = doc["question"].strip()
31 | while(context.lower().endswith(question.lower())):
32 | context = context[:-len(question)]
33 |
34 | out = (
35 | context.strip().strip(".") + ". " + question
36 | )
37 | return out
38 |
39 | def should_decontaminate(self):
40 | return True
41 |
42 | def doc_to_decontamination_query(self, doc):
43 | return doc["context"]
44 |
45 | def doc_to_target(self, doc):
46 | answer_list = doc['answers']
47 | if len(answer_list) > 0:
48 | answer = answer_list[0]
49 | else:
50 | answer = "unanswerable"
51 | return " " + answer
52 |
53 | def construct_requests(self, doc, ctx, **kwargs):
54 | """Uses RequestFactory to construct Requests and returns an iterable of
55 | Requests which will be sent to the LM.
56 |
57 | :param doc:
58 | The document as returned from training_docs, validation_docs, or test_docs.
59 | :param ctx: str
60 | The context string, generated by fewshot_context. This includes the natural
61 | language description, as well as the few shot examples, and the question
62 | part of the document for `doc`.
63 | """
64 |
65 | return [
66 | Instance(
67 | request_type="generate_until",
68 | doc=doc,
69 | arguments=(ctx, {"until": ["\n"], "max_gen_toks": 48}),
70 | idx=0,
71 | **kwargs,
72 | ),
73 | ]
74 |
75 | def process_results(self, doc, results):
76 | """Take a single document and the LM results and evaluates, returning a
77 | dict where keys are the names of submetrics and values are the values of
78 | the metric for that one document
79 |
80 | :param doc:
81 | The document as returned from training_docs, validation_docs, or test_docs.
82 | :param results:
83 | The results of the requests created in construct_requests.
84 | """
85 |
86 | continuation = results[0]
87 |
88 | return {
89 | "contains": contains_score(continuation, doc["answers"])
90 | }
91 |
92 |
93 | def aggregation(self):
94 | """
95 | :returns: {str: [float] -> float}
96 | A dictionary where keys are the names of submetrics and values are
97 | functions that aggregate a list of metrics
98 | """
99 | return {
100 | "contains": np.mean,
101 | }
102 |
103 | def higher_is_better(self):
104 | """
105 | :returns: {str: bool}
106 | A dictionary where keys are the names of submetrics and values are
107 | whether a higher value of the submetric is better
108 | """
109 | return {
110 | "contains": True
111 | }
112 |
113 | def contains_score(prediction: str, labels: List[str]):
114 | return max(
115 | int(bool(re.search(re.compile(re.escape(label), re.IGNORECASE), prediction)))
116 | for label in labels
117 | )
118 |
119 | class BasedDropTwice(BasedDrop):
120 |
121 | def doc_to_text(self, doc):
122 | context = doc["context"].strip()
123 | question = doc["question"].strip()
124 | while(context.lower().endswith(question.lower())):
125 | context = context[:-len(question)]
126 |
127 | out = context.strip().strip(".").strip() + "."
128 | out += "\n" + out + " " + question
129 |
130 | intro_q = doc['orig_question'].strip(":")
131 | out = f"{intro_q} " + out
132 | return out
133 |
134 |
135 |
--------------------------------------------------------------------------------
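The `contains` metric above is a case-insensitive substring check rather than exact match. A quick worked example, restating the helper with made-up strings:

```python
import re
from typing import List


def contains_score(prediction: str, labels: List[str]) -> int:
    # 1 if any gold label appears (case-insensitively) anywhere in the generation
    return max(
        int(bool(re.search(re.compile(re.escape(label), re.IGNORECASE), prediction)))
        for label in labels
    )


assert contains_score("The answer is Paris.", ["paris", "Lyon"]) == 1
assert contains_score("I do not know.", ["Paris"]) == 0
```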
/lm-eval-harness/lm_eval/tasks/based_fda/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSparseLLMs/MoM/2c3ae59f5c0b749f916189433b8dad5a3415dc08/lm-eval-harness/lm_eval/tasks/based_fda/README.md
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/tasks/based_fda/task.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import re
3 | import numpy as np
4 | from lm_eval.api.task import ConfigurableTask
5 | from lm_eval.api.instance import Instance
6 |
7 |
8 | class FDA(ConfigurableTask):
9 | VERSION = 0
10 | DATASET_PATH = "hazyresearch/based-fda"
11 | DATASET_NAME = "default"
12 |
13 | def __init__(self):
14 | super().__init__(config={'metadata': {'version': self.VERSION}})
15 |
16 | def has_training_docs(self):
17 | return False
18 |
19 | def has_validation_docs(self):
20 | return True
21 |
22 | def has_test_docs(self):
23 | return False
24 |
25 | def validation_docs(self):
26 | return self.dataset["validation"]
27 |
28 | def doc_to_text(self, doc):
29 | question = doc["key"]+":"
30 | while(doc["text"].lower().endswith(question.lower())):
31 | doc["text"] = doc["text"][:-len(question)]
32 | upper_key = doc['key'][0].upper() + doc['key'][1:]
33 | question = upper_key +":"
34 | doc['text'] = doc['text'].strip("\n").strip(".")
35 | out = doc["text"]
36 | if not out.endswith("."): out += "."
37 | out += " " + question
38 | return out
39 |
40 | def doc_to_target(self, doc):
41 | return doc["value"]
42 |
43 | def construct_requests(self, doc, ctx, **kwargs):
44 | """Uses RequestFactory to construct Requests and returns an iterable of
45 | Requests which will be sent to the LM.
46 |
47 | :param doc:
48 | The document as returned from training_docs, validation_docs, or test_docs.
49 | :param ctx: str
50 | The context string, generated by fewshot_context. This includes the natural
51 | language description, as well as the few shot examples, and the question
52 | part of the document for `doc`.
53 | """
54 |
55 | return [
56 | Instance(
57 | request_type="generate_until",
58 | doc=doc,
59 | arguments=(ctx, {"until": ["\n"], "max_gen_toks": 48}),
60 | idx=0,
61 | **kwargs,
62 | ),
63 | ]
64 |
65 | def process_results(self, doc, results):
66 | """Take a single document and the LM results and evaluates, returning a
67 | dict where keys are the names of submetrics and values are the values of
68 | the metric for that one document
69 |
70 | :param doc:
71 | The document as returned from training_docs, validation_docs, or test_docs.
72 | :param results:
73 | The results of the requests created in construct_requests.
74 | """
75 | continuation = results[0]
76 |
77 | return {
78 | "contains": contains_score(continuation, [doc["value"]])
79 | }
80 |
81 | def aggregation(self):
82 | """
83 | :returns: {str: [float] -> float}
84 | A dictionary where keys are the names of submetrics and values are
85 | functions that aggregate a list of metrics
86 | """
87 | return {
88 |             "contains": np.mean,  # fraction of documents whose generation contains a gold answer
89 | }
90 |
91 | def higher_is_better(self):
92 | """
93 | :returns: {str: bool}
94 | A dictionary where keys are the names of submetrics and values are
95 | whether a higher value of the submetric is better
96 | """
97 | return {
98 |             "contains": True,  # a higher "contains" rate is better
99 | }
100 |
101 |
102 | def contains_score(prediction: str, labels: List[str]):
103 | return max(
104 | int(bool(re.search(re.compile(re.escape(label), re.IGNORECASE), prediction)))
105 | for label in labels
106 | )
107 |
108 | class FDATwice(FDA):
109 |
110 | def doc_to_text(self, doc):
111 | question = doc["key"]+":"
112 | while(doc["text"].lower().endswith(question.lower())):
113 | doc["text"] = doc["text"][:-len(question)]
114 |
115 | upper_key = doc['key'][0].upper() + doc['key'][1:]
116 | question = upper_key +":"
117 | doc['text'] = doc['text'].strip("\n").strip(".")
118 | out = doc["text"]
119 | if not out.endswith("."): out += "."
120 | intro_q = question.strip(":")
121 | out = f"Information about {intro_q}. " + out + "\n" + out + " " + question
122 | return out
123 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/tasks/based_nq/task.py:
--------------------------------------------------------------------------------
1 | from lm_eval.api.task import ConfigurableTask
2 | from lm_eval.api.instance import Instance
3 | import re
4 | from typing import List
5 | import numpy as np
6 |
7 |
8 | class BasedNQ512(ConfigurableTask):
9 | VERSION = "default"
10 | DATASET_PATH = "hazyresearch/based_nq_512"
11 | DATASET_NAME = None
12 |
13 | def __init__(self):
14 | super().__init__(config={'metadata': {'version': self.VERSION}})
15 |
16 | def has_training_docs(self):
17 | return False
18 |
19 | def has_validation_docs(self):
20 | return True
21 |
22 | def has_test_docs(self):
23 | return False
24 |
25 | def validation_docs(self):
26 | return self.dataset["validation"]
27 |
28 | def doc_to_text(self, doc):
29 | context = doc["context"].strip()
30 | question = doc["question"].strip()
31 | if context.endswith(question):
32 | context = context[:-len(question)]
33 |
34 | out = context.strip().strip(".") + ". "
35 | out = out + doc["question"].strip()
36 | return out
37 |
38 | def should_decontaminate(self):
39 | return True
40 |
41 | def doc_to_decontamination_query(self, doc):
42 | return doc["context"]
43 |
44 | def doc_to_target(self, doc):
45 | answer_list = doc['answers']
46 | if len(answer_list) > 0:
47 | answer = answer_list[0]
48 | else:
49 | answer = "unanswerable"
50 | return " " + answer
51 |
52 | def construct_requests(self, doc, ctx, **kwargs):
53 | """Uses RequestFactory to construct Requests and returns an iterable of
54 | Requests which will be sent to the LM.
55 |
56 | :param doc:
57 | The document as returned from training_docs, validation_docs, or test_docs.
58 | :param ctx: str
59 | The context string, generated by fewshot_context. This includes the natural
60 | language description, as well as the few shot examples, and the question
61 | part of the document for `doc`.
62 | """
63 |
64 | return [
65 | Instance(
66 | request_type="generate_until",
67 | doc=doc,
68 | arguments=(ctx, {"until": ["\n"], "max_gen_toks": 48}),
69 | idx=0,
70 | **kwargs,
71 | ),
72 | ]
73 |
74 | def process_results(self, doc, results):
75 | """Take a single document and the LM results and evaluates, returning a
76 | dict where keys are the names of submetrics and values are the values of
77 | the metric for that one document
78 |
79 | :param doc:
80 | The document as returned from training_docs, validation_docs, or test_docs.
81 | :param results:
82 | The results of the requests created in construct_requests.
83 | """
84 | continuation = results[0]
85 | return {
86 | "contains": contains_score(continuation, doc["answers"])
87 | }
88 |
89 |
90 | def aggregation(self):
91 | """
92 | :returns: {str: [float] -> float}
93 | A dictionary where keys are the names of submetrics and values are
94 | functions that aggregate a list of metrics
95 | """
96 | return {
97 | "contains": np.mean,
98 | }
99 |
100 |
101 | def higher_is_better(self):
102 | """
103 | :returns: {str: bool}
104 | A dictionary where keys are the names of submetrics and values are
105 | whether a higher value of the submetric is better
106 | """
107 | return {
108 | "contains": True
109 | }
110 |
111 |
112 | def contains_score(prediction: str, labels: List[str]):
113 | return max(
114 | int(bool(re.search(re.compile(re.escape(label), re.IGNORECASE), prediction)))
115 | for label in labels
116 | )
117 |
118 |
119 | class BasedNQTwice512(BasedNQ512):
120 | def doc_to_text(self, doc):
121 | context = doc["context"].strip()
122 | question = doc["question"].strip()
123 | while(context.lower().endswith(question.lower())):
124 | context = context[:-len(question)]
125 |
126 | out = context.strip().strip(".").strip() + "."
127 | out += "\n" + out + " " + question
128 |
129 | intro_q = doc['orig_question'].strip(":")
130 | out = f"{intro_q} " + out
131 | return out
132 |
133 |
134 | class BasedNQ1024(BasedNQ512):
135 | DATASET_PATH = "hazyresearch/based_nq_1024"
136 | DATASET_NAME = None
137 |
138 |
139 | class BasedNQ2048(BasedNQ512):
140 | DATASET_PATH = "hazyresearch/based_nq_2048"
141 | DATASET_NAME = None
142 |
143 |
144 | class BasedNQTwice1024(BasedNQTwice512):
145 | DATASET_PATH = "hazyresearch/based_nq_1024"
146 | DATASET_NAME = None
147 |
148 |
149 | class BasedNQTwice2048(BasedNQTwice512):
150 | DATASET_PATH = "hazyresearch/based_nq_2048"
151 | DATASET_NAME = None
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/tasks/based_squadv2/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSparseLLMs/MoM/2c3ae59f5c0b749f916189433b8dad5a3415dc08/lm-eval-harness/lm_eval/tasks/based_squadv2/README.md
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/tasks/based_swde/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSparseLLMs/MoM/2c3ae59f5c0b749f916189433b8dad5a3415dc08/lm-eval-harness/lm_eval/tasks/based_swde/README.md
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/tasks/scrolls/README.md:
--------------------------------------------------------------------------------
1 | """
2 | SCROLLS: Standardized CompaRison Over Long Language Sequences
3 | https://arxiv.org/abs/2201.03533
4 |
5 | SCROLLS is a suite of datasets that require synthesizing information over long texts.
6 | The benchmark includes seven natural language tasks across multiple domains,
7 | including summarization, question answering, and natural language inference.
8 |
9 | Homepage: https://www.scrolls-benchmark.com/
10 |
11 | Since SCROLLS tasks are generally longer than the maximum sequence length of many models,
12 | it is possible to create "subset" tasks that contain only those samples whose tokenized length
13 | is less than some pre-defined limit. For example, to create a subset of "Qasper" that would
14 | be suitable for a model using the GPTNeoX tokenizer and a 4K maximum sequence length:
15 |
16 | ```
17 | class QasperGPTNeoX4K(Qasper):
18 | PRUNE_TOKENIZERS = ["EleutherAI/pythia-410m-deduped"]
19 | PRUNE_MAX_TOKENS = 4096
20 | PRUNE_NUM_PROC = _num_cpu_cores() # optional, to speed up pruning of large datasets like NarrativeQA
21 | ```
22 |
23 | `PRUNE_TOKENIZERS` can contain more than one tokenizer; this will include only samples that are
24 | less than `PRUNE_MAX_TOKENS` for ALL of the tokenizers. This can be useful for comparing models
25 | that use different tokenizers but the same maximum sequence length.
26 |
27 | Once the subset task class has been defined in this file, it can be used by adding the class
28 | to `lm_eval/tasks/__init__.py`.
29 |
30 | NOTE: GovReport may need `max_gen_toks` set larger for causal models.
31 | """
32 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/tasks/super_glue/README.md:
--------------------------------------------------------------------------------
1 | # SuperGLUE
2 |
3 | ### Paper
4 |
5 | Title: `SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding Systems`
6 | Abstract: `https://w4ngatang.github.io/static/papers/superglue.pdf`
7 |
8 | SuperGLUE is a benchmark styled after GLUE with a new set of more difficult language
9 | understanding tasks.
10 |
11 | Homepage: https://super.gluebenchmark.com/
12 |
13 | ### Citation
14 |
15 | ```
16 | @inproceedings{NEURIPS2019_4496bf24,
17 | author = {Wang, Alex and Pruksachatkun, Yada and Nangia, Nikita and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel},
18 | booktitle = {Advances in Neural Information Processing Systems},
19 | editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett},
20 | pages = {},
21 | publisher = {Curran Associates, Inc.},
22 | title = {SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding Systems},
23 | url = {https://proceedings.neurips.cc/paper/2019/file/4496bf24afe7fab6f046bf4923da8de6-Paper.pdf},
24 | volume = {32},
25 | year = {2019}
26 | }
27 | ```
28 |
29 | ### Groups and Tasks
30 |
31 | #### Groups
32 |
33 | * `super-glue-lm-eval-v1`: SuperGLUE eval adapted from LM Eval V1
34 | * `super-glue-t5-prompt`: SuperGLUE prompt and evaluation that matches the T5 paper (if using accelerate, this will error if `record` is included).
35 |
36 | #### Tasks
37 |
38 | Comparison between validation split score on T5x and LM-Eval (T5x models converted to HF)
39 | | T5V1.1 Base | SGLUE | BoolQ | CB | Copa | MultiRC | ReCoRD | RTE | WiC | WSC |
40 | | ----------- | ------| ----- | --------- | ---- | ------- | ------ | --- | --- | --- |
41 | | T5x | 69.47 | 78.47(acc) | 83.93(f1) 87.5(acc) | 50(acc) | 73.81(f1) 33.26(em) | 70.09(em) 71.34(f1) | 78.7(acc) | 63.64(acc) | 75(acc) |
42 | | LM-Eval | 71.35 | 79.36(acc) | 83.63(f1) 87.5(acc) | 63(acc) | 73.45(f1) 33.26(em) | 69.85(em) 68.86(f1) | 78.34(acc) | 65.83(acc) | 75.96(acc) |
43 |
44 |
45 |
46 | * `super-glue-lm-eval-v1`
47 | - `boolq`
48 | - `cb`
49 | - `copa`
50 | - `multirc`
51 | - `record`
52 | - `rte`
53 | - `wic`
54 | - `wsc`
55 |
56 | * `super-glue-t5-prompt`
57 | - `super_glue-boolq-t5-prompt`
58 | - `super_glue-cb-t5-prompt`
59 | - `super_glue-copa-t5-prompt`
60 | - `super_glue-multirc-t5-prompt`
61 | - `super_glue-record-t5-prompt`
62 | - `super_glue-rte-t5-prompt`
63 | - `super_glue-wic-t5-prompt`
64 | - `super_glue-wsc-t5-prompt`
65 |
66 | ### Checklist
67 |
68 | For adding novel benchmarks/datasets to the library:
69 | * [ ] Is the task an existing benchmark in the literature?
70 | * [ ] Have you referenced the original paper that introduced the task?
71 | * [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
72 |
73 |
74 | If other tasks on this dataset are already supported:
75 | * [ ] Is the "Main" variant of this task clearly denoted?
76 | * [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
77 | * [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
78 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/tasks/super_glue/cb/aggregate.py:
--------------------------------------------------------------------------------
1 | import sklearn.metrics
2 | import numpy as np
3 |
4 |
5 | def cb_multi_fi(items):
6 | preds, golds = zip(*items)
7 | preds = np.array(preds)
8 | golds = np.array(golds)
9 | f11 = sklearn.metrics.f1_score(y_true=golds == 0, y_pred=preds == 0)
10 | f12 = sklearn.metrics.f1_score(y_true=golds == 1, y_pred=preds == 1)
11 | f13 = sklearn.metrics.f1_score(y_true=golds == 2, y_pred=preds == 2)
12 | avg_f1 = np.mean([f11, f12, f13])
13 | return avg_f1
14 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/tasks/super_glue/cb/t5_utils.py:
--------------------------------------------------------------------------------
1 | import sklearn.metrics
2 |
3 |
4 | def mean_3class_f1(predictions, references): # This is a passthrough function
5 | string_label = ["entailment", "contradiction", "neutral"]
6 | predictions = (
7 | string_label.index(predictions[0]) if predictions[0] in string_label else 0
8 | )
9 | references = string_label.index(references[0])
10 |
11 | return (predictions, references)
12 |
13 |
14 | def agg_mean_3class_f1(items):
15 | predictions, references = zip(*items)
16 |
17 | """Computes the unweighted average of the F1 per class."""
18 | metric_str = "fbeta_score"
19 | metric_fn_kwargs = {
20 | "beta": 1,
21 | "labels": range(3),
22 | "average": "macro",
23 | }
24 |
25 | def _fn(predictions, references):
26 | metric_fn = getattr(sklearn.metrics, metric_str)
27 | metric_val = metric_fn(references, predictions, **metric_fn_kwargs)
28 | return metric_val
29 |
30 | return _fn(predictions, references)
31 |
--------------------------------------------------------------------------------
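The macro F1 computed by `agg_mean_3class_f1` above is simply `fbeta_score` with `beta=1` averaged over the three CB labels. A tiny worked example with invented (prediction, reference) index pairs:

```python
import sklearn.metrics

# label indices: 0=entailment, 1=contradiction, 2=neutral, as in mean_3class_f1 above
predictions = [0, 0, 1, 2, 2, 1]
references = [0, 1, 1, 2, 0, 1]

# unweighted mean of the per-class F1 scores
macro_f1 = sklearn.metrics.fbeta_score(
    references, predictions, beta=1, labels=range(3), average="macro"
)
print(macro_f1)
```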
/lm-eval-harness/lm_eval/tasks/super_glue/copa/utils.py:
--------------------------------------------------------------------------------
1 | def convert_choice(choice):
2 | return choice[0].lower() + choice[1:]
3 |
4 |
5 | def doc_to_text(doc):
6 | # Drop the period
7 | connector = {
8 | "cause": "because",
9 | "effect": "therefore",
10 | }[doc["question"]]
11 | return doc["premise"].strip()[:-1] + f" {connector}"
12 |
13 |
14 | def doc_to_target(doc):
15 | correct_choice = doc["choice1"] if doc["label"] == 0 else doc["choice2"]
16 | # Connect the sentences
17 | return " " + convert_choice(correct_choice)
18 |
19 |
20 | def doc_to_choice(doc):
21 | return [" " + convert_choice(doc["choice1"]), " " + convert_choice(doc["choice2"])]
22 |
--------------------------------------------------------------------------------
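A quick end-to-end illustration of the COPA helpers above on a hand-written document; the import path mirrors the file location and assumes the package is importable from the repo root.

```python
from lm_eval.tasks.super_glue.copa.utils import doc_to_choice, doc_to_target, doc_to_text

# hand-written COPA-style document
doc = {
    "premise": "The man broke his toe.",
    "question": "cause",          # "cause" -> "because", "effect" -> "therefore"
    "choice1": "He got a hole in his sock.",
    "choice2": "He dropped a hammer on his foot.",
    "label": 1,
}

print(doc_to_text(doc))    # The man broke his toe because
print(doc_to_target(doc))  #  he dropped a hammer on his foot.
print(doc_to_choice(doc))  # [' he got a hole in his sock.', ' he dropped a hammer on his foot.']
```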
/lm-eval-harness/lm_eval/tasks/super_glue/multirc/t5_utils.py:
--------------------------------------------------------------------------------
1 | import collections
2 |
3 | import numpy as np
4 | import sklearn.metrics
5 |
6 |
7 | def f1(predictions, references): # This is a passthrough function
8 | _prediction = predictions[0]
9 | _reference = references[0].split("_")[-1]
10 | string_label = ["False", "True"]
11 | reference = string_label.index(_reference)
12 | prediction = (
13 | string_label.index(_prediction)
14 | if _prediction in string_label
15 | else not bool(reference)
16 | )
17 |
18 | return (prediction, reference)
19 |
20 |
21 | def agg_f1(items):
22 | predictions, references = zip(*items)
23 | references, predictions = np.asarray(references), np.asarray(predictions)
24 |
25 | return sklearn.metrics.f1_score(references, predictions)
26 |
27 |
28 | def em(predictions, references): # This is a passthrough function
29 | _prediction = predictions[0]
30 | _group, _reference = references[0].split("_")
31 | string_label = ["False", "True"]
32 | reference = string_label.index(_reference)
33 | prediction = (
34 | string_label.index(_prediction)
35 | if _prediction in string_label
36 | else not bool(reference)
37 | )
38 |
39 | return (_group, prediction, reference)
40 |
41 |
42 | def agg_em(items):
43 | grouped_values = collections.defaultdict(lambda: ([], []))
44 | for group, prediction, reference in items:
45 | grouped_values[group][0].append(reference)
46 | grouped_values[group][1].append(prediction)
47 |
48 | group_scores = []
49 | for group, (targets, predictions) in grouped_values.items():
50 | score = float(np.array_equal(targets, predictions))
51 | group_scores.append(score)
52 |
53 | return np.mean(group_scores)
54 |
--------------------------------------------------------------------------------
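The exact-match aggregation above scores a MultiRC question as correct only if every answer option in its group is predicted correctly. A small worked example of the grouping inside `agg_em`, with hand-made (group, prediction, reference) triples:

```python
import collections

import numpy as np

# (group_id, prediction, reference) triples, as produced by em() above
items = [
    ("q1", 1, 1), ("q1", 0, 0),  # every option of q1 correct -> 1.0
    ("q2", 1, 0), ("q2", 1, 1),  # one option of q2 wrong     -> 0.0
]

grouped = collections.defaultdict(lambda: ([], []))
for group, prediction, reference in items:
    grouped[group][0].append(reference)
    grouped[group][1].append(prediction)

scores = [float(np.array_equal(targets, preds)) for targets, preds in grouped.values()]
print(np.mean(scores))  # 0.5
```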
/lm-eval-harness/lm_eval/tasks/super_glue/record/t5_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 | import collections
4 | import numpy as np
5 |
6 | from datasets import Dataset
7 |
8 | from lm_eval.api.metrics import metric_max_over_ground_truths
9 |
10 |
11 | def doc_to_text(doc):
12 | passage = doc["passage"]
13 | passage = re.sub(r"(\.|\?|\!|\"|\')\n@highlight\n", r"\1 ", passage)
14 | passage = re.sub(r"\n@highlight\n", ". ", passage)
15 |
16 | return " ".join(
17 | [
18 | "record query:",
19 | doc["query"],
20 | "entities:",
21 | ", ".join(doc["entities"]),
22 | "passage:",
23 | passage,
24 | ]
25 | )
26 |
27 |
28 | def process_docs(dataset):
29 | def split_answers(doc):
30 | split_doc = {
31 | **{k: [] for k in doc.keys()},
32 | }
33 | answers = doc.pop("answers")
34 | for idx, answer in enumerate(answers):
35 | for key in split_doc.keys():
36 | if key in doc:
37 | split_doc[key].append(doc[key])
38 |
39 | split_doc["answers"].append(answer)
40 | return split_doc
41 |
42 | dataset = dataset.map(split_answers)
43 | new_dataset = {}
44 | for key in dataset.features.keys():
45 | new_dataset[key] = [x for row in dataset[key] for x in row]
46 |
47 | return Dataset.from_dict(new_dataset)
48 |
49 |
50 | def normalize_squad(answer):
51 | """Normalization used in official SQuAD evaluation script."""
52 |
53 | def _normalize_answer(text, punc_chars, punc_repl):
54 | """Lower text and remove punctuation, articles and extra whitespace."""
55 |
56 | def remove_articles(s):
57 | return re.sub(r"\b(a|an|the)\b", " ", s)
58 |
59 | def replace_punctuation(s):
60 | to_replace = set(punc_chars)
61 | return "".join(punc_repl if ch in to_replace else ch for ch in s)
62 |
63 | def white_space_fix(s):
64 | return " ".join(s.split())
65 |
66 | text = text.lower()
67 | text = replace_punctuation(text)
68 | text = remove_articles(text)
69 | text = white_space_fix(text)
70 |
71 | return text
72 |
73 | return _normalize_answer(answer, punc_chars=string.punctuation, punc_repl="")
74 |
75 |
76 | def em(predictions, references): # This is a passthrough function
77 | return (predictions[0], references[0])
78 |
79 |
80 | def f1(predictions, references): # This is a passthrough function
81 | return (predictions[0], references[0])
82 |
83 |
84 | def squad_em_agg(items):
85 | def _exact_match_score(prediction, target):
86 | return target == prediction
87 |
88 | grouped_values = collections.defaultdict(lambda: ([], []))
89 | for prediction, reference in items:
90 | group, reference = reference.split("_")
91 | # if group not in grouped_values:
92 | grouped_values[group][0].append(normalize_squad(prediction))
93 | grouped_values[group][1].append(normalize_squad(reference))
94 |
95 | em = []
96 | for group in grouped_values.keys():
97 | predictions, targets = grouped_values[group]
98 | for p in predictions:
99 | em.append(metric_max_over_ground_truths(_exact_match_score, p, targets))
100 |
101 | return np.mean(em)
102 |
103 |
104 | def squad_f1_agg(items):
105 | def _f1_score(prediction, target):
106 | """Computes token f1 score for a single target and prediction."""
107 | prediction_tokens = prediction.split()
108 | target_tokens = target.split()
109 | common = collections.Counter(prediction_tokens) & collections.Counter(
110 | target_tokens
111 | )
112 | num_same = sum(common.values())
113 | if num_same == 0:
114 | return 0
115 | precision = 1.0 * num_same / len(prediction_tokens)
116 | recall = 1.0 * num_same / len(target_tokens)
117 | f1 = (2 * precision * recall) / (precision + recall)
118 | return f1
119 |
120 | grouped_values = collections.defaultdict(lambda: ([], []))
121 | for prediction, reference in items:
122 | group, reference = reference.split("_")
123 | if group not in grouped_values:
124 | grouped_values[group][0].append(normalize_squad(prediction))
125 | grouped_values[group][1].append(normalize_squad(reference))
126 |
127 | f1 = []
128 | for group in grouped_values.keys():
129 | p, t = grouped_values[group]
130 | f1.append(metric_max_over_ground_truths(_f1_score, p[0], t))
131 |
132 | return np.mean(f1)
133 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/tasks/super_glue/record/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import transformers.data.metrics.squad_metrics as squad_metrics
3 |
4 | from lm_eval.api.metrics import metric_max_over_ground_truths
5 |
6 |
7 | def doc_to_text(doc):
8 | initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n")
9 | text = initial_text + "\n\n"
10 | for highlight in highlights:
11 | text += f" - {highlight}.\n"
12 | return text
13 |
14 |
15 | def format_answer(query, entity):
16 | return f" - {query}".replace("@placeholder", entity)
17 |
18 |
19 | def doc_to_target(doc):
20 | # We only output the first correct entity in a doc
21 | return format_answer(query=doc["query"], entity=doc["answers"][0])
22 |
23 |
24 | def process_results(doc, results):
25 | # ReCoRD's evaluation is actually deceptively simple:
26 | # - Pick the maximum likelihood prediction entity
27 | # - Evaluate the accuracy and token F1 PER EXAMPLE
28 | # - Average over all examples
29 | max_idx = np.argmax(np.array([result[0] for result in results]))
30 |
31 | prediction = doc["entities"][max_idx]
32 | gold_label_set = doc["answers"]
33 | f1 = metric_max_over_ground_truths(
34 | squad_metrics.compute_f1, prediction, gold_label_set
35 | )
36 | em = metric_max_over_ground_truths(
37 | squad_metrics.compute_exact, prediction, gold_label_set
38 | )
39 |
40 | return {
41 | "f1": f1,
42 | "em": em,
43 | }
44 |
--------------------------------------------------------------------------------
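A sketch of the ReCoRD scoring path: `process_results` above receives one loglikelihood score per candidate entity, picks the argmax, and compares that entity against the gold answers. The scores and document below are fabricated (no model involved); the F1/EM step via `squad_metrics` is omitted for brevity.

```python
import numpy as np

doc = {
    "entities": ["Paris", "London", "Berlin"],
    "answers": ["Paris"],
}
# one (loglikelihood, is_greedy) pair per candidate entity; the highest likelihood wins
results = [(-4.2, False), (-9.1, False), (-7.3, False)]

max_idx = np.argmax(np.array([result[0] for result in results]))
prediction = doc["entities"][max_idx]
print(prediction, prediction in doc["answers"])  # Paris True
```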
/lm-eval-harness/lm_eval/tasks/super_glue/wsc/preprocess_wsc.py:
--------------------------------------------------------------------------------
1 | from lm_eval.utils import general_detokenize
2 |
3 |
4 | def default_doc_to_text(x):
5 | raw_passage = x["text"]
6 | # NOTE: HuggingFace span indices are word-based not character-based.
7 | pre = " ".join(raw_passage.split()[: x["span2_index"]])
8 | post = raw_passage[len(pre) + len(x["span2_text"]) + 1 :]
9 | passage = general_detokenize(pre + " *{}*".format(x["span2_text"]) + post)
10 | noun = x["span1_text"]
11 | pronoun = x["span2_text"]
12 | text = (
13 | f"Passage: {passage}\n"
14 | + f'Question: In the passage above, does the pronoun "*{pronoun}*" refer to "*{noun}*"?\n'
15 | + "Answer:"
16 | )
17 | return text
18 |
--------------------------------------------------------------------------------
/lm-eval-harness/lm_eval/tasks/super_glue/wsc/t5_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import List
3 |
4 | def doc_to_text(x):
5 | text = re.sub(r" X ", " *" + x["span2_text"] + "* ", _wsc_inputs(x))
6 | return "wsc: " + text
7 |
8 |
9 | def _wsc_inputs(x):
10 | words = x["text"].split(" ")
11 |
12 | # We would need some special logic to handle the case where the pronoun is the
13 | # first or last word in the text. None of the examples in WSC seem to have
14 | # this, so we are ignoring these cases.
15 | assert x["span2_index"] > 0
16 | assert x["span2_index"] < len(words)
17 | pronoun_index = x["span2_index"]
18 |
19 | def create_input():
20 | assert words[pronoun_index] == x["span2_text"]
21 |
22 | return " ".join(
23 | [
24 | " ".join(words[:pronoun_index]),
25 | "X",
26 | " ".join(words[pronoun_index + 1:]),
27 | ]
28 | )
29 |
30 | # Handle some special cases.
31 | if (
32 | x["text"]
33 | == 'The boy continued to whip the pony , and eventually the pony threw him over. John laughed out quite loud. "Good for him," he said. '
34 | ):
35 | return (
36 | "The boy continued to whip the pony , and eventually the pony threw "
37 | 'him over. John laughed out quite loud. "Good for X ," he said.'
38 | )
39 |
40 | # Using the span2_index, we get 'use' instead of 'it'.
41 | if (
42 | x["text"]
43 | == "When they had eventually calmed down a bit , and had gotten home, Mr. Farley put the magic pebble in an iron safe . Some day they might want to use it , but really for now, what more could they wish for?"
44 | ):
45 | return (
46 | "When they had eventually calmed down a bit , and had gotten home, "
47 | "Mr. Farley put the magic pebble in an iron safe . Some day they might "
48 | "want to use X , but really for now, what more could they wish for?"
49 | )
50 |
51 | return create_input()
52 |
53 |
54 | DETERMINERS = {
55 | "a",
56 | "an",
57 | "few",
58 | "her",
59 | "his",
60 | "each",
61 | "every",
62 | "many",
63 | "much",
64 | "my",
65 | "our",
66 | "some",
67 | "that",
68 | "the",
69 | "their",
70 | "these",
71 | "this",
72 | "those",
73 | "which",
74 | "whose",
75 | "your",
76 | }
77 |
78 |
79 | def clean(s: str) -> str:
80 | """Ignore capitalization and determiners."""
81 | s = s.strip().lower()
82 | return " ".join([w for w in s.split(" ") if w not in DETERMINERS])
83 |
84 |
85 | def process_results(docs: dict, resps: List):
86 | prediction = clean(resps[0])
87 | reference = clean(docs["span1_text"])
88 |
89 | if ("'" in prediction) != ("'" in reference):
90 |         # don't count cases where the prediction is "Bob" but the referent is "Bob's hat" as predicting the referent
91 | predicted_referent = False
92 | else:
93 | prediction_words = set(prediction.split(" "))
94 | referent_words = set(reference.split(" "))
95 |
96 | # Handle cases where the prediction is "fuzzy bunny" and the referent is
97 | # "bunny".
98 | predicted_referent = prediction_words.issubset(
99 | referent_words
100 | ) or referent_words.issubset(prediction_words)
101 |
102 | acc = 1.0 if predicted_referent == docs["label"] else 0.0
103 | return {"accuracy": acc}
104 |
--------------------------------------------------------------------------------
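The WSC post-processing above boils down to `clean()` (lowercasing and dropping determiners) plus a word-subset test in either direction. A short illustration with invented strings; the inline `clean` uses only a subset of the `DETERMINERS` set above.

```python
def clean(s: str) -> str:
    # lowercase and drop determiners (subset of the DETERMINERS set above)
    determiners = {"a", "an", "the", "his", "her", "their"}
    return " ".join(w for w in s.strip().lower().split(" ") if w not in determiners)


prediction = "the fuzzy bunny"   # model output (invented)
reference = "Bunny"              # gold referent span (invented)

prediction_words = set(clean(prediction).split(" "))
referent_words = set(clean(reference).split(" "))

# fuzzy match: either set of words may contain the other
predicted_referent = prediction_words.issubset(referent_words) or referent_words.issubset(
    prediction_words
)
print(predicted_referent)  # True: {"bunny"} is a subset of {"fuzzy", "bunny"}
```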
/lm-eval-harness/prompt_scripts/run_jrt_prompt_hazy.sh:
--------------------------------------------------------------------------------
1 | ### JRT ArXiv Table 1 ###
2 |
3 | output_dir="run_jrt_prompt_hazy"
4 | limit=-1
5 |
6 |
7 | # Default and twice SWDE context length 1000
8 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python launch_jrt.py \
9 | --batch-size 32 \
10 | -m hazyresearch/based-1b-50b \
11 | -m hazyresearch/mamba-1b-50b \
12 | -m hazyresearch/attn-1b-50bn \
13 | -t based_swde \
14 | -t based_swde_twice \
15 | --output_dir ${output_dir} \
16 | --context_length 1000 \
17 | --answer_length 50 \
18 | --cutting_context \
19 | --limit ${limit} \
20 | -p
21 |
22 |
23 | # Default and twice FDA at context length 1000
24 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python launch_jrt.py \
25 | --batch-size 32 \
26 | -m hazyresearch/based-1b-50b \
27 | -m hazyresearch/mamba-1b-50b \
28 | -m hazyresearch/attn-1b-50bn \
29 | -t based_fda \
30 | -t based_fda_twice \
31 | --output_dir ${output_dir} \
32 | --context_length 1000 \
33 | --answer_length 50 \
34 | --cutting_context \
35 | --limit ${limit} \
36 | -p
37 |
38 |
39 | # Default and twice SQUAD completion at context length 1000
40 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python launch_jrt.py \
41 | --batch-size 32 \
42 | -m hazyresearch/based-1b-50b \
43 | -m hazyresearch/mamba-1b-50b \
44 | -m hazyresearch/attn-1b-50bn \
45 | -t based_squad_twice \
46 | -t based_squad \
47 | --output_dir ${output_dir} \
48 | --context_length 1000 \
49 | --answer_length 50 \
50 | --cutting_context \
51 | --limit ${limit} \
52 | -p
53 |
54 | # Default and twice DROP at context length 1000
55 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python launch_jrt.py \
56 | --batch-size 32 \
57 | -m hazyresearch/based-1b-50b \
58 | -m hazyresearch/mamba-1b-50b \
59 | -m hazyresearch/attn-1b-50bn \
60 | -t based_drop_twice \
61 | -t based_drop \
62 | --output_dir ${output_dir} \
63 | --context_length 1000 \
64 | --answer_length 50 \
65 | --cutting_context \
66 | --limit ${limit} \
67 | -p
68 |
69 | # Default and twice NQ (based_nq_1024) at context length 1000
70 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python launch_jrt.py \
71 | --batch-size 32 \
72 | -m hazyresearch/based-1b-50b \
73 | -m hazyresearch/mamba-1b-50b \
74 | -m hazyresearch/attn-1b-50bn \
75 | -t based_nq_1024_twice \
76 | -t based_nq_1024 \
77 | --output_dir ${output_dir} \
78 | --context_length 1000 \
79 | --answer_length 50 \
80 | --cutting_context \
81 | --limit ${limit} \
82 | -p
83 |
84 | # Default and twice TriviaQA at context length 1000
85 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python launch_jrt.py \
86 | --batch-size 32 \
87 | -m hazyresearch/based-1b-50b \
88 | -m hazyresearch/mamba-1b-50b \
89 | -m hazyresearch/attn-1b-50bn \
90 | -t based_triviaqa_twice \
91 | -t based_triviaqa \
92 | --output_dir ${output_dir} \
93 | --context_length 1000 \
94 | --answer_length 50 \
95 | --cutting_context \
96 | --limit ${limit} \
97 | -p
98 |
99 |
--------------------------------------------------------------------------------
/lm-eval-harness/prompt_scripts/run_jrt_prompt_hf.sh:
--------------------------------------------------------------------------------
1 | ### JRT ArXiv Table 1 ###
2 |
3 | output_dir="run_jrt_prompt_hf"
4 | limit=-1
5 |
6 | # Default and twice SWDE context length 1000
7 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python launch_hf.py \
8 | --batch-size 32 \
9 | -m "fla-hub/gla-1.3B-100B" \
10 | -m "fla-hub/gla-2.7B-100B" \
11 | -m "state-spaces/mamba-130m" \
12 | -m "state-spaces/mamba-370m" \
13 | -m "state-spaces/mamba-1.4b" \
14 | -m "state-spaces/mamba-2.8b" \
15 | -m "state-spaces/mamba2-130m" \
16 | -m "state-spaces/mamba2-370m" \
17 | -m "state-spaces/mamba2-1.3b" \
18 | -m "state-spaces/mamba2-2.7b" \
19 | -t based_swde \
20 | -t based_swde_twice \
21 | --output_dir ${output_dir} \
22 | --context_length 1000 \
23 | --answer_length 50 \
24 | --cutting_context \
25 | --limit ${limit} \
26 | -p
27 |
28 |
29 | # Default and twice FDA at context length 1000
30 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python launch_hf.py \
31 | --batch-size 32 \
32 | -m "fla-hub/gla-1.3B-100B" \
33 | -m "fla-hub/gla-2.7B-100B" \
34 | -m "state-spaces/mamba-130m" \
35 | -m "state-spaces/mamba-370m" \
36 | -m "state-spaces/mamba-1.4b" \
37 | -m "state-spaces/mamba-2.8b" \
38 | -m "state-spaces/mamba2-130m" \
39 | -m "state-spaces/mamba2-370m" \
40 | -m "state-spaces/mamba2-1.3b" \
41 | -m "state-spaces/mamba2-2.7b" \
42 | -t based_fda \
43 | -t based_fda_twice \
44 | --output_dir ${output_dir} \
45 | --context_length 1000 \
46 | --answer_length 50 \
47 | --cutting_context \
48 | --limit ${limit} \
49 | -p
50 |
51 |
52 | # Default and twice SQUAD completion at context length 1000
53 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python launch_hf.py \
54 | --batch-size 32 \
55 | -m "fla-hub/gla-1.3B-100B" \
56 | -m "fla-hub/gla-2.7B-100B" \
57 | -m "state-spaces/mamba-130m" \
58 | -m "state-spaces/mamba-370m" \
59 | -m "state-spaces/mamba-1.4b" \
60 | -m "state-spaces/mamba-2.8b" \
61 | -m "state-spaces/mamba2-130m" \
62 | -m "state-spaces/mamba2-370m" \
63 | -m "state-spaces/mamba2-1.3b" \
64 | -m "state-spaces/mamba2-2.7b" \
65 | -t based_squad_twice \
66 | -t based_squad \
67 | --output_dir ${output_dir} \
68 | --context_length 1000 \
69 | --answer_length 50 \
70 | --cutting_context \
71 | --limit ${limit} \
72 | -p
73 |
74 | # Default and twice DROP at context length 1000
75 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python launch_hf.py \
76 | --batch-size 32 \
77 | -m "fla-hub/gla-1.3B-100B" \
78 | -m "fla-hub/gla-2.7B-100B" \
79 | -m "state-spaces/mamba-130m" \
80 | -m "state-spaces/mamba-370m" \
81 | -m "state-spaces/mamba-1.4b" \
82 | -m "state-spaces/mamba-2.8b" \
83 | -m "state-spaces/mamba2-130m" \
84 | -m "state-spaces/mamba2-370m" \
85 | -m "state-spaces/mamba2-1.3b" \
86 | -m "state-spaces/mamba2-2.7b" \
87 | -t based_drop_twice \
88 | -t based_drop \
89 | --output_dir ${output_dir} \
90 | --context_length 1000 \
91 | --answer_length 50 \
92 | --cutting_context \
93 | --limit ${limit} \
94 | -p
95 |
96 | # Default and twice NQ (based_nq_1024) at context length 1000
97 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python launch_hf.py \
98 | --batch-size 32 \
99 | -m "fla-hub/gla-1.3B-100B" \
100 | -m "fla-hub/gla-2.7B-100B" \
101 | -m "state-spaces/mamba-130m" \
102 | -m "state-spaces/mamba-370m" \
103 | -m "state-spaces/mamba-1.4b" \
104 | -m "state-spaces/mamba-2.8b" \
105 | -m "state-spaces/mamba2-130m" \
106 | -m "state-spaces/mamba2-370m" \
107 | -m "state-spaces/mamba2-1.3b" \
108 | -m "state-spaces/mamba2-2.7b" \
109 | -t based_nq_1024_twice \
110 | -t based_nq_1024 \
111 | --output_dir ${output_dir} \
112 | --context_length 1000 \
113 | --answer_length 50 \
114 | --cutting_context \
115 | --limit ${limit} \
116 | -p
117 |
118 | # Default and twice TriviaQA at context length 1000
119 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python launch_hf.py \
120 | --batch-size 32 \
121 | -m "fla-hub/gla-1.3B-100B" \
122 | -m "fla-hub/gla-2.7B-100B" \
123 | -m "state-spaces/mamba-130m" \
124 | -m "state-spaces/mamba-370m" \
125 | -m "state-spaces/mamba-1.4b" \
126 | -m "state-spaces/mamba-2.8b" \
127 | -m "state-spaces/mamba2-130m" \
128 | -m "state-spaces/mamba2-370m" \
129 | -m "state-spaces/mamba2-1.3b" \
130 | -m "state-spaces/mamba2-2.7b" \
131 | -t based_triviaqa_twice \
132 | -t based_triviaqa \
133 | --output_dir ${output_dir} \
134 | --context_length 1000 \
135 | --answer_length 50 \
136 | --cutting_context \
137 | --limit ${limit} \
138 | -p
139 |
140 |
--------------------------------------------------------------------------------
/lm-eval-harness/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=40.8.0", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "lm_eval"
7 | version = "0.4.1"
8 | authors = [
9 | {name="EleutherAI", email="contact@eleuther.ai"}
10 | ]
11 | description = "A framework for evaluating language models"
12 | readme = "README.md"
13 | classifiers = [
14 | "Development Status :: 3 - Alpha",
15 | "Programming Language :: Python :: 3",
16 | "License :: OSI Approved :: MIT License",
17 | "Operating System :: OS Independent",
18 | ]
19 | requires-python = ">=3.8"
20 | license = { "text" = "MIT" }
21 | dependencies = [
22 | "accelerate>=0.21.0",
23 | "evaluate",
24 | "datasets>=2.14.0",
25 | "evaluate>=0.4.0",
26 | "jsonlines",
27 | "numexpr",
28 | "peft>=0.2.0",
29 | "pybind11>=2.6.2",
30 | "pytablewriter",
31 | "rouge-score>=0.0.4",
32 | "sacrebleu>=1.5.0",
33 | "scikit-learn>=0.24.1",
34 | "sqlitedict",
35 | "torch>=1.8",
36 | "tqdm-multiprocess",
37 | "transformers>=4.1",
38 | "zstandard",
39 | ]
40 |
41 | [tool.setuptools.packages.find]
42 | include = ["lm_eval*"]
43 |
44 | # required to include yaml files in pip installation
45 | [tool.setuptools.package-data]
46 | lm_eval = ["**/*.yaml", "tasks/**/*"]
47 |
48 | [project.scripts]
49 | lm-eval = "lm_eval.__main__:cli_evaluate"
50 | lm_eval = "lm_eval.__main__:cli_evaluate"
51 |
52 | [project.urls]
53 | Homepage = "https://github.com/EleutherAI/lm-evaluation-harness"
54 | Repository = "https://github.com/EleutherAI/lm-evaluation-harness"
55 |
56 | [project.optional-dependencies]
57 | anthropic = ["anthropic"]
58 | dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy"]
59 | gptq = ["auto-gptq[triton]>=0.6.0"]
60 | hf_transfer = ["hf_transfer"]
61 | ifeval = ["langdetect", "immutabledict"]
62 | neuronx = ["optimum[neuronx]"]
63 | mamba = ["mamba_ssm", "causal-conv1d==1.0.2"]
64 | math = ["sympy>=1.12", "antlr4-python3-runtime==4.11"]
65 | multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"]
66 | openai = ["openai==1.3.9", "tiktoken"]
67 | optimum = ["optimum[openvino]"]
68 | promptsource = ["promptsource>=0.2.3"]
69 | sentencepiece = ["sentencepiece>=0.1.98", "protobuf>=4.22.1"]
70 | testing = ["pytest", "pytest-cov", "pytest-xdist"]
71 | vllm = ["vllm<=0.2.5"]
72 | zeno = ["pandas", "zeno-client"]
73 | all = [
74 | "lm_eval[anthropic]",
75 | "lm_eval[dev]",
76 | "lm_eval[gptq]",
77 | "lm_eval[hf_transfer]",
78 | "lm_eval[ifeval]",
79 | "lm_eval[mamba]",
80 | "lm_eval[math]",
81 | "lm_eval[multilingual]",
82 | "lm_eval[openai]",
83 | "lm_eval[promptsource]",
84 | "lm_eval[sentencepiece]",
85 | "lm_eval[testing]",
86 | "lm_eval[vllm]",
87 | "lm_eval[zeno]",
88 | ]
89 |
90 | [tool.ruff]
91 | extend-exclude = ["lm_eval/tasks/*.py"]
92 |
93 | [tool.ruff.lint]
94 | extend-select = ["I"]
95 |
96 | [tool.ruff.isort]
97 | lines-after-imports = 2
98 | known-first-party = ["lm_eval"]
99 |
100 | [tool.ruff.extend-per-file-ignores]
101 | "__init__.py" = ["F401","F402","F403","I"]
102 | "lm_eval/tasks/*"= ["E721"]
103 |
--------------------------------------------------------------------------------
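The `[project.scripts]` table maps both the `lm-eval` and `lm_eval` console commands to `lm_eval.__main__:cli_evaluate`, so the same entry point can be driven from Python. A minimal sketch; the model, task, and limit values are purely illustrative:

```python
# Invoke the console-script entry point declared in [project.scripts] directly.
import sys

from lm_eval.__main__ import cli_evaluate

# Same flags the installed `lm-eval` command would receive (values are examples only).
sys.argv = ["lm-eval", "--model", "hf", "--tasks", "based_swde", "--limit", "10"]
cli_evaluate()
```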
/lm-eval-harness/requirements.txt:
--------------------------------------------------------------------------------
1 | -e .
2 |
--------------------------------------------------------------------------------
/lm-eval-harness/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 |
4 | # This is to make sure that the package supports editable installs
5 | setuptools.setup()
6 |
--------------------------------------------------------------------------------
/lm-eval-harness/templates/new_yaml_task/README.md:
--------------------------------------------------------------------------------
1 | # Task-name
2 |
3 | ### Paper
4 |
5 | Title: `paper title goes here`
6 |
7 | Abstract: `link to paper PDF or arXiv abstract goes here`
8 |
9 | `Short description of paper / benchmark goes here:`
10 |
11 | Homepage: `homepage to the benchmark's website goes here, if applicable`
12 |
13 |
14 | ### Citation
15 |
16 | ```
17 | BibTeX-formatted citation goes here
18 | ```
19 |
20 | ### Groups and Tasks
21 |
22 | #### Groups
23 |
24 | * `group_name`: `Short description`
25 |
26 | #### Tasks
27 |
28 | * `task_name`: `1-sentence description of what this particular task does`
29 | * `task_name2`: ...
30 |
31 | ### Checklist
32 |
33 | For adding novel benchmarks/datasets to the library:
34 | * [ ] Is the task an existing benchmark in the literature?
35 | * [ ] Have you referenced the original paper that introduced the task?
36 | * [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
37 |
38 |
39 | If other tasks on this dataset are already supported:
40 | * [ ] Is the "Main" variant of this task clearly denoted?
41 | * [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
42 | * [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
43 |
--------------------------------------------------------------------------------
/mom/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from mom.layers import MomGatedDeltaNet, MomLinearAttention, MomGatedSlotAttention, MomGatedLinearAttention
4 | from mom.models import (MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel,
5 | MomLinearAttentionForCausalLM, MomLinearAttentionModel,
6 | MomGLAForCausalLM, MomGLAModel,
7 | MomGSAForCausalLM, MomGSAModel)
8 |
9 | __all__ = [
10 | 'MomGatedDeltaNet',
11 | 'MomGatedLinearAttention',
12 | 'MomGatedSlotAttention',
13 | 'MomLinearAttention',
14 | 'MomGatedDeltaNetForCausalLM',
15 | 'MomGatedDeltaNetModel',
16 | 'MomGLAForCausalLM',
17 | 'MomGLAModel',
18 | 'MomGSAForCausalLM',
19 | 'MomGSAModel',
20 | 'MomLinearAttentionForCausalLM',
21 | 'MomLinearAttentionModel',
22 | ]
23 |
24 | __version__ = '0.1'
25 |
--------------------------------------------------------------------------------
/mom/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .mom_gated_deltanet import MomGatedDeltaNet
4 | from .mom_gla import MomGatedLinearAttention
5 | from .mom_gsa import MomGatedSlotAttention
6 | from .mom_linear_attn import MomLinearAttention
7 |
8 | __all__ = [
9 | 'MomGatedDeltaNet',
10 | 'MomGatedLinearAttention',
11 | 'MomGatedSlotAttention',
12 | 'MomLinearAttention',
13 | ]
14 |
--------------------------------------------------------------------------------
/mom/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from mom.models.mom_gla import MomGLAConfig, MomGLAForCausalLM, MomGLAModel
4 | from mom.models.mom_gsa import MomGSAConfig, MomGSAForCausalLM, MomGSAModel
5 | from mom.models.mom_linear_attn import (MomLinearAttentionConfig,
6 | MomLinearAttentionForCausalLM,
7 | MomLinearAttentionModel)
8 |
9 | from mom.models.mom_gated_deltanet import MomGatedDeltaNetConfig, MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel
10 |
11 | __all__ = [
12 | 'MomGLAConfig', 'MomGLAForCausalLM', 'MomGLAModel',
13 | 'MomGSAConfig', 'MomGSAForCausalLM', 'MomGSAModel',
14 | 'MomLinearAttentionConfig', 'MomLinearAttentionForCausalLM', 'MomLinearAttentionModel',
15 | 'MomGatedDeltaNetConfig', 'MomGatedDeltaNetForCausalLM', 'MomGatedDeltaNetModel'
16 | ]
17 |
--------------------------------------------------------------------------------
/mom/models/mom_gated_deltanet/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4 |
5 | from mom.models.mom_gated_deltanet.configuration_mom_gated_deltanet import \
6 | MomGatedDeltaNetConfig
7 | from mom.models.mom_gated_deltanet.modeling_mom_gated_deltanet import (
8 | MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel)
9 |
10 | AutoConfig.register(MomGatedDeltaNetConfig.model_type, MomGatedDeltaNetConfig)
11 | AutoModel.register(MomGatedDeltaNetConfig, MomGatedDeltaNetModel)
12 | AutoModelForCausalLM.register(MomGatedDeltaNetConfig, MomGatedDeltaNetForCausalLM)
13 |
14 | __all__ = ['MomGatedDeltaNetConfig', 'MomGatedDeltaNetForCausalLM', 'MomGatedDeltaNetModel']
15 |
--------------------------------------------------------------------------------
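Because this `__init__` registers the config and model classes with the `transformers` Auto* factories, importing `mom` is all that is needed before building a model through `AutoModelForCausalLM`. A minimal sketch with deliberately tiny, arbitrary dimensions (not a released configuration), assuming the dependencies listed in `setup.py` are installed:

```python
from transformers import AutoModelForCausalLM

import mom  # noqa: F401  (importing mom performs the Auto* registration above)
from mom.models import MomGatedDeltaNetConfig

# Tiny, arbitrary sizes chosen only to exercise the registration path.
config = MomGatedDeltaNetConfig(
    hidden_size=256, head_dim=64, num_heads=4,
    num_hidden_layers=2, num_memories=4, topk=2, vocab_size=1000,
)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # MomGatedDeltaNetForCausalLM
```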
/mom/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from typing import Dict, Optional
4 |
5 | from transformers.configuration_utils import PretrainedConfig
6 |
7 |
8 | class MomGatedDeltaNetConfig(PretrainedConfig):
9 | model_type = 'mom_gated_deltanet'
10 | keys_to_ignore_at_inference = ['past_key_values']
11 | def __init__(
12 | self,
13 | attn_mode: str = "chunk",
14 | hidden_size: int = 2048,
15 | expand_v: int = 2,
16 | use_gate: bool = True,
17 | use_short_conv: bool = True,
18 | conv_size: int = 4,
19 | head_dim: int = 256,
20 | num_heads: int = 6,
21 | max_position_embeddings: int = 2048,
22 | hidden_ratio: Optional[int] = 4,
23 | intermediate_size: Optional[int] = None,
24 | hidden_act: str = "swish",
25 | num_hidden_layers: int = 21,
26 | norm_first: bool = False,
27 | norm_eps: float = 1e-6,
28 | attn: Optional[Dict] = None,
29 | use_cache: bool = True,
30 | pad_token_id: int = None,
31 | bos_token_id: int = 1,
32 | eos_token_id: int = 2,
33 | tie_word_embeddings: bool = False,
34 | initializer_range: float = 0.02,
35 | fuse_cross_entropy: bool = True,
36 | vocab_size: int = 32000,
37 | num_memories: int = 8,
38 | topk: int = 2,
39 | capacity: float = 1.0,
40 |         use_layer_wise_balance: bool = True,
41 |         aux_loss_scale: float = 0.01,
42 | shared_mem: bool = False,
43 | single_kv_proj: bool = False,
44 | **kwargs
45 | ):
46 | self.attn_mode = attn_mode
47 | self.hidden_size = hidden_size
48 | self.expand_v = expand_v
49 | self.use_gate = use_gate
50 | self.use_short_conv = use_short_conv
51 | self.conv_size = conv_size
52 | self.head_dim = head_dim
53 | self.num_heads = num_heads
54 | self.max_position_embeddings = max_position_embeddings
55 |
56 | self.hidden_ratio = hidden_ratio
57 | self.intermediate_size = intermediate_size
58 | self.hidden_act = hidden_act
59 | self.num_hidden_layers = num_hidden_layers
60 | self.norm_first = norm_first
61 | self.norm_eps = norm_eps
62 | self.attn = attn
63 | self.use_cache = use_cache
64 | self.initializer_range = initializer_range
65 | self.fuse_cross_entropy = fuse_cross_entropy
66 | self.vocab_size = vocab_size
67 | self.num_memories = num_memories
68 | self.topk = topk
69 | self.capacity = capacity
70 | self.use_layer_wise_balance = use_layer_wise_balance
71 | self.aux_loss_scale = aux_loss_scale
72 | self.shared_mem = shared_mem
73 | self.single_kv_proj = single_kv_proj
74 |
75 | if attn is not None:
76 | if not isinstance(attn, Dict):
77 | raise ValueError("attn must be a dictionary")
78 | if 'layers' not in attn:
79 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
80 | if 'num_heads' not in attn:
81 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
82 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
83 | attn['window_size'] = attn.get('window_size', None)
84 |
85 | super().__init__(
86 | pad_token_id=pad_token_id,
87 | bos_token_id=bos_token_id,
88 | eos_token_id=eos_token_id,
89 | tie_word_embeddings=tie_word_embeddings,
90 | **kwargs,
91 | )
--------------------------------------------------------------------------------
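The `attn` argument switches the listed layers to standard softmax attention; the validation block above insists on `layers` and `num_heads` and fills in the `num_kv_heads`/`window_size` defaults. A small sketch of that normalization (the layer indices and head count are arbitrary):

```python
from mom.models import MomGatedDeltaNetConfig

# Hybrid example: layers 5 and 11 use softmax attention (indices are arbitrary).
cfg = MomGatedDeltaNetConfig(attn={"layers": [5, 11], "num_heads": 8})
assert cfg.attn["num_kv_heads"] == 8      # defaulted to num_heads
assert cfg.attn["window_size"] is None    # defaulted to full (unwindowed) attention

# Omitting a required key raises the corresponding ValueError.
try:
    MomGatedDeltaNetConfig(attn={"num_heads": 8})
except ValueError as e:
    print(e)  # Layer indices must be provided to initialize hybrid attention layers
```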
/mom/models/mom_gla/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4 |
5 | from mom.models.mom_gla.configuration_mom_gla import MomGLAConfig
6 | from mom.models.mom_gla.modeling_mom_gla import MomGLAForCausalLM, MomGLAModel
7 |
8 | AutoConfig.register(MomGLAConfig.model_type, MomGLAConfig)
9 | AutoModel.register(MomGLAConfig, MomGLAModel)
10 | AutoModelForCausalLM.register(MomGLAConfig, MomGLAForCausalLM)
11 |
12 |
13 | __all__ = ['MomGLAConfig', 'MomGLAForCausalLM', 'MomGLAModel']
14 |
--------------------------------------------------------------------------------
/mom/models/mom_gla/configuration_mom_gla.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from typing import Dict, Optional
4 |
5 | from transformers.configuration_utils import PretrainedConfig
6 |
7 |
8 | class MomGLAConfig(PretrainedConfig):
9 |
10 | model_type = 'mom_gla'
11 | keys_to_ignore_at_inference = ['past_key_values']
12 |
13 | def __init__(
14 | self,
15 | hidden_size: int = 2048,
16 |         expand_k: float = 0.5,
17 | expand_v: int = 1,
18 | hidden_ratio: Optional[int] = 4,
19 | intermediate_size: Optional[int] = None,
20 | num_hidden_layers: int = 24,
21 | num_heads: int = 4,
22 | num_kv_heads: Optional[int] = None,
23 | feature_map: Optional[str] = None,
24 | attn_mode: str = "chunk",
25 | use_short_conv: bool = False,
26 | conv_size: int = 4,
27 | use_output_gate: bool = True,
28 | clamp_min: Optional[float] = None,
29 | hidden_act: str = "swish",
30 | max_position_embeddings: int = 2048,
31 | elementwise_affine: Optional[bool] = True,
32 | norm_eps: float = 1e-6,
33 | use_gk: bool = True,
34 | use_gv: bool = False,
35 | attn: Optional[Dict] = None,
36 | use_cache: bool = True,
37 | pad_token_id: int = None,
38 | bos_token_id: int = 1,
39 | eos_token_id: int = 2,
40 | tie_word_embeddings: bool = False,
41 | initializer_range: float = 0.02,
42 | fuse_norm: bool = True,
43 | fuse_cross_entropy: bool = True,
44 | vocab_size: int = 32000,
45 | num_memories: int = 8,
46 | topk: int = 2,
47 | capacity: float = 1.0,
48 |         use_layer_wise_balance: bool = True,
49 |         aux_loss_scale: float = 0.01,
50 | shared_mem: bool = False,
51 | single_kv_proj: bool = False,
52 | **kwargs
53 | ):
54 | self.hidden_size = hidden_size
55 | self.expand_k = expand_k
56 | self.expand_v = expand_v
57 | self.hidden_ratio = hidden_ratio
58 | self.intermediate_size = intermediate_size
59 | self.num_hidden_layers = num_hidden_layers
60 | self.num_heads = num_heads
61 | self.num_kv_heads = num_kv_heads
62 | self.feature_map = feature_map
63 | self.attn_mode = attn_mode
64 | self.use_short_conv = use_short_conv
65 | self.conv_size = conv_size
66 | self.use_output_gate = use_output_gate
67 | self.clamp_min = clamp_min
68 | self.hidden_act = hidden_act
69 | self.max_position_embeddings = max_position_embeddings
70 | self.elementwise_affine = elementwise_affine
71 | self.norm_eps = norm_eps
72 | self.use_gk = use_gk
73 | self.use_gv = use_gv
74 | self.attn = attn
75 | self.use_cache = use_cache
76 | self.initializer_range = initializer_range
77 | self.fuse_norm = fuse_norm
78 | self.fuse_cross_entropy = fuse_cross_entropy
79 | self.vocab_size = vocab_size
80 | self.num_memories = num_memories
81 | self.topk = topk
82 | self.capacity = capacity
83 | self.use_layer_wise_balance = use_layer_wise_balance
84 | self.aux_loss_scale = aux_loss_scale
85 | self.shared_mem = shared_mem
86 | self.single_kv_proj = single_kv_proj
87 |
88 | if attn is not None:
89 | if not isinstance(attn, Dict):
90 | raise ValueError("attn must be a dictionary")
91 | if 'layers' not in attn:
92 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
93 | if 'num_heads' not in attn:
94 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
95 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
96 | attn['window_size'] = attn.get('window_size', None)
97 |
98 | super().__init__(
99 | pad_token_id=pad_token_id,
100 | bos_token_id=bos_token_id,
101 | eos_token_id=eos_token_id,
102 | tie_word_embeddings=tie_word_embeddings,
103 | **kwargs,
104 | )
105 |
--------------------------------------------------------------------------------
/mom/models/mom_gsa/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4 |
5 | from mom.models.mom_gsa.configuration_mom_gsa import MomGSAConfig
6 | from mom.models.mom_gsa.modeling_mom_gsa import MomGSAForCausalLM, MomGSAModel
7 |
8 | AutoConfig.register(MomGSAConfig.model_type, MomGSAConfig)
9 | AutoModel.register(MomGSAConfig, MomGSAModel)
10 | AutoModelForCausalLM.register(MomGSAConfig, MomGSAForCausalLM)
11 |
12 |
13 | __all__ = ['MomGSAConfig', 'MomGSAForCausalLM', 'MomGSAModel']
14 |
--------------------------------------------------------------------------------
/mom/models/mom_gsa/configuration_mom_gsa.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from typing import Dict, Optional
4 |
5 | from transformers.configuration_utils import PretrainedConfig
6 |
7 |
8 | class MomGSAConfig(PretrainedConfig):
9 |
10 | model_type = 'mom_gsa'
11 | keys_to_ignore_at_inference = ['past_key_values']
12 |
13 | def __init__(
14 | self,
15 | hidden_size: int = 2048,
16 | gate_logit_normalizer: Optional[int] = 8,
17 | clamp_min: Optional[float] = None,
18 | clamp_max: Optional[float] = None,
19 | hidden_ratio: Optional[int] = 4,
20 | intermediate_size: Optional[int] = None,
21 | num_hidden_layers: int = 24,
22 | num_heads: int = 4,
23 | num_kv_heads: Optional[int] = None,
24 | num_slots: Optional[int] = 64,
25 | use_short_conv: bool = False,
26 | conv_size: int = 4,
27 |         expand_k: float = 1,
28 |         expand_v: float = 1,
29 | feature_map: str = 'swish',
30 | use_output_gate: bool = False,
31 | use_norm: bool = True,
32 | max_position_embeddings: int = 2048,
33 | hidden_act: str = "swish",
34 | elementwise_affine: Optional[bool] = True,
35 | norm_first: bool = True,
36 | norm_eps: float = 1e-6,
37 | attn: Optional[Dict] = None,
38 | use_cache: bool = True,
39 | pad_token_id: int = None,
40 | bos_token_id: int = 1,
41 | eos_token_id: int = 2,
42 | initializer_range: float = 0.02,
43 | tie_word_embeddings: bool = False,
44 | fuse_norm: bool = True,
45 | fuse_cross_entropy: bool = True,
46 | vocab_size: int = 32000,
47 | num_experts: int = 8,
48 | topk: int = 2,
49 | capacity: float = 1.0,
50 |         use_layer_wise_balance: bool = True,
51 |         aux_loss_scale: float = 0.01,
52 | shared_mem: bool = False,
53 | **kwargs
54 | ):
55 | self.hidden_size = hidden_size
56 | self.gate_logit_normalizer = gate_logit_normalizer
57 | self.clamp_min = clamp_min
58 | self.clamp_max = clamp_max
59 | self.hidden_ratio = hidden_ratio
60 | self.intermediate_size = intermediate_size
61 | self.num_hidden_layers = num_hidden_layers
62 | self.num_heads = num_heads
63 | self.num_kv_heads = num_kv_heads
64 | self.num_slots = num_slots
65 | self.use_short_conv = use_short_conv
66 | self.conv_size = conv_size
67 |         self.expand_k = expand_k
68 |         self.expand_v = expand_v
69 | self.feature_map = feature_map
70 | self.use_output_gate = use_output_gate
71 | self.use_norm = use_norm
72 | self.max_position_embeddings = max_position_embeddings
73 | self.hidden_act = hidden_act
74 | self.elementwise_affine = elementwise_affine
75 | self.norm_first = norm_first
76 | self.norm_eps = norm_eps
77 | self.attn = attn
78 | self.use_cache = use_cache
79 | self.initializer_range = initializer_range
80 | self.fuse_cross_entropy = fuse_cross_entropy
81 | self.fuse_norm = fuse_norm
82 | self.vocab_size = vocab_size
83 | self.num_experts = num_experts
84 | self.topk = topk
85 | self.capacity = capacity
86 | self.use_layer_wise_balance = use_layer_wise_balance
87 | self.aux_loss_scale = aux_loss_scale
88 | self.shared_mem = shared_mem
89 |
90 | if attn is not None:
91 | if not isinstance(attn, Dict):
92 | raise ValueError("attn must be a dictionary")
93 | if 'layers' not in attn:
94 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
95 | if 'num_heads' not in attn:
96 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
97 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
98 | attn['window_size'] = attn.get('window_size', None)
99 |
100 | super().__init__(
101 | pad_token_id=pad_token_id,
102 | bos_token_id=bos_token_id,
103 | eos_token_id=eos_token_id,
104 | tie_word_embeddings=tie_word_embeddings,
105 | **kwargs,
106 | )
107 |
--------------------------------------------------------------------------------
/mom/models/mom_linear_attn/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4 |
5 | from mom.models.mom_linear_attn.configuration_mom_linear_attn import \
6 | MomLinearAttentionConfig
7 | from mom.models.mom_linear_attn.modeling_mom_linear_attn import (
8 | MomLinearAttentionForCausalLM, MomLinearAttentionModel)
9 |
10 | AutoConfig.register(MomLinearAttentionConfig.model_type, MomLinearAttentionConfig)
11 | AutoModel.register(MomLinearAttentionConfig, MomLinearAttentionModel)
12 | AutoModelForCausalLM.register(MomLinearAttentionConfig, MomLinearAttentionForCausalLM)
13 |
14 | __all__ = ['MomLinearAttentionConfig', 'MomLinearAttentionForCausalLM', 'MomLinearAttentionModel']
15 |
--------------------------------------------------------------------------------
/mom/models/mom_linear_attn/configuration_mom_linear_attn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from typing import Dict, Optional
4 |
5 | from transformers.configuration_utils import PretrainedConfig
6 |
7 |
8 | class MomLinearAttentionConfig(PretrainedConfig):
9 |
10 | model_type = 'mom_linear_attn'
11 | keys_to_ignore_at_inference = ['past_key_values']
12 |
13 | def __init__(
14 | self,
15 | attn_mode: str = "fused_chunk",
16 | hidden_size: int = 2048,
17 | expand_k: int = 1,
18 | expand_v: int = 1,
19 | hidden_ratio: Optional[int] = 4,
20 | intermediate_size: Optional[int] = None,
21 | num_hidden_layers: int = 24,
22 | num_heads: int = 4,
23 | num_kv_heads: Optional[int] = None,
24 | feature_map: str = "elementwise_product",
25 | tie_feature_map_qk: bool = False,
26 | norm_q: bool = False,
27 | norm_k: bool = False,
28 | norm_feature_map: bool = False,
29 | hidden_act: str = "swish",
30 | max_position_embeddings: int = 2048,
31 | elementwise_affine: Optional[bool] = True,
32 | norm_eps: float = 1e-6,
33 | attn: Optional[Dict] = None,
34 | use_cache: bool = True,
35 | pad_token_id: int = None,
36 | bos_token_id: int = 1,
37 | eos_token_id: int = 2,
38 | tie_word_embeddings: bool = False,
39 | initializer_range: float = 0.02,
40 | fuse_cross_entropy: bool = True,
41 | vocab_size: int = 32000,
42 | num_memories: int = 8,
43 | topk: int = 2,
44 | capacity: float = 1.0,
45 |         use_layer_wise_balance: bool = True,
46 |         aux_loss_scale: float = 0.01,
47 | shared_mem: bool = False,
48 | **kwargs
49 | ):
50 | self.attn_mode = attn_mode
51 | self.hidden_size = hidden_size
52 | self.expand_k = expand_k
53 | self.expand_v = expand_v
54 | self.hidden_ratio = hidden_ratio
55 | self.intermediate_size = intermediate_size
56 | self.num_hidden_layers = num_hidden_layers
57 | self.num_heads = num_heads
58 | self.num_kv_heads = num_kv_heads
59 | self.feature_map = feature_map
60 | self.tie_feature_map_qk = tie_feature_map_qk
61 | self.norm_q = norm_q
62 | self.norm_k = norm_k
63 | self.norm_feature_map = norm_feature_map
64 | self.hidden_act = hidden_act
65 | self.max_position_embeddings = max_position_embeddings
66 | self.elementwise_affine = elementwise_affine
67 | self.norm_eps = norm_eps
68 | self.attn = attn
69 | self.use_cache = use_cache
70 | self.initializer_range = initializer_range
71 | self.fuse_cross_entropy = fuse_cross_entropy
72 | self.vocab_size = vocab_size
73 | self.num_memories = num_memories
74 | self.topk = topk
75 | self.capacity = capacity
76 | self.use_layer_wise_balance = use_layer_wise_balance
77 | self.aux_loss_scale = aux_loss_scale
78 | self.shared_mem = shared_mem
79 |
80 | if attn is not None:
81 | if not isinstance(attn, Dict):
82 | raise ValueError("attn must be a dictionary")
83 | if 'layers' not in attn:
84 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
85 | if 'num_heads' not in attn:
86 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
87 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
88 | attn['window_size'] = attn.get('window_size', None)
89 |
90 | super().__init__(
91 | pad_token_id=pad_token_id,
92 | bos_token_id=bos_token_id,
93 | eos_token_id=eos_token_id,
94 | tie_word_embeddings=tie_word_embeddings,
95 | **kwargs,
96 | )
97 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import os
3 | import re
4 | from pathlib import Path
5 |
6 | from setuptools import find_packages, setup
7 |
8 | setup(
9 | name='mom',
10 | version='0.1.0',
11 | description='MoM: Mixture of Memories',
12 | author='Jusen Du',
13 | author_email='dujusen@gmail.com',
14 | url='https://github.com/OpenSparseLLMs/MoM',
15 | long_description=open('README.md').read(),
16 | long_description_content_type='text/markdown',
17 | packages=find_packages(),
18 | classifiers=[
19 | 'Programming Language :: Python :: 3',
20 |         'License :: OSI Approved :: Apache Software License',
21 | 'Operating System :: OS Independent',
22 | 'Topic :: Scientific/Engineering :: Artificial Intelligence'
23 | ],
24 | python_requires='>=3.7',
25 | install_requires=[
26 | 'torch>=2.3',
27 | 'transformers>=4.45.0',
28 | 'triton>=3.0',
29 | 'datasets>=3.1.0',
30 | 'einops',
31 | 'ninja',
32 | 'fla @ git+https://github.com/fla-org/flash-linear-attention'
33 | ],
34 | )
35 |
--------------------------------------------------------------------------------
/training/configs/mom_1.3B.json:
--------------------------------------------------------------------------------
1 | {
2 | "attn_mode": "chunk",
3 | "bos_token_id": 1,
4 | "expand_v": 1,
5 | "fuse_cross_entropy": true,
6 | "hidden_act": "swish",
7 | "hidden_ratio": 4,
8 | "hidden_size": 2048,
9 | "initializer_range": 0.02,
10 | "intermediate_size": null,
11 | "model_type": "mom_gated_deltanet",
12 | "num_heads": 4,
13 | "head_dim": 256,
14 | "num_hidden_layers": 24,
15 | "norm_eps": 1e-06,
16 | "tie_word_embeddings": true,
17 | "use_cache": true,
18 | "vocab_size": 32000,
19 | "use_short_conv": true,
20 | "num_memories": 4,
21 | "topk": 2,
22 | "capacity": 1.0,
23 | "use_layer_wise_balance": true,
24 | "aux_loss_scale": 0.01,
25 | "shared_mem": true,
26 | "single_kv_proj": false
27 | }
--------------------------------------------------------------------------------
/training/configs/mom_340M.json:
--------------------------------------------------------------------------------
1 | {
2 | "attn_mode": "chunk",
3 | "bos_token_id": 1,
4 | "expand_v": 1,
5 | "fuse_cross_entropy": true,
6 | "hidden_act": "swish",
7 | "hidden_ratio": 3,
8 | "hidden_size": 1024,
9 | "initializer_range": 0.02,
10 | "intermediate_size": null,
11 | "model_type": "mom_gated_deltanet",
12 | "num_heads": 4,
13 | "head_dim": 256,
14 | "num_hidden_layers": 24,
15 | "norm_eps": 1e-06,
16 | "tie_word_embeddings": true,
17 | "use_cache": true,
18 | "vocab_size": 32000,
19 | "use_short_conv": true,
20 | "num_memories": 4,
21 | "topk": 2,
22 | "capacity": 1.0,
23 | "use_layer_wise_balance": true,
24 | "aux_loss_scale": 0.01,
25 | "shared_mem": true,
26 | "single_kv_proj": false
27 | }
--------------------------------------------------------------------------------
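Both JSON configs declare `"model_type": "mom_gated_deltanet"`, so once `mom` has been imported for registration they resolve to `MomGatedDeltaNetConfig`; this is essentially the from-scratch path taken by `training/run.py`. A minimal sketch, assuming it is run from the repository root so the `training/configs/` paths resolve:

```python
from transformers import AutoConfig, AutoModelForCausalLM

import mom  # noqa: F401  (registers mom_gated_deltanet with the Auto* classes)

config = AutoConfig.from_pretrained("training/configs/mom_340M.json")
print(config.num_memories, config.topk, config.shared_mem)  # 4 2 True

# Randomly initialized model, mirroring the from_config branch in training/run.py.
model = AutoModelForCausalLM.from_config(config)
```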
/training/flame/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSparseLLMs/MoM/2c3ae59f5c0b749f916189433b8dad5a3415dc08/training/flame/__init__.py
--------------------------------------------------------------------------------
/training/flame/logging.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import json
4 | import logging
5 | import os
6 | import sys
7 | import time
8 |
9 | from transformers.trainer_callback import (ExportableState, TrainerCallback,
10 | TrainerControl, TrainerState)
11 | from transformers.training_args import TrainingArguments
12 |
13 |
14 | def get_logger(name: str = None) -> logging.Logger:
15 | formatter = logging.Formatter(
16 | fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
17 | )
18 | handler = logging.StreamHandler(sys.stdout)
19 | handler.setFormatter(formatter)
20 |
21 | logger = logging.getLogger(name)
22 | if 'RANK' in os.environ and int(os.environ['RANK']) == 0:
23 | logger.setLevel(logging.INFO)
24 | logger.addHandler(handler)
25 |
26 | return logger
27 |
28 |
29 | logger = get_logger(__name__)
30 |
31 | LOG_FILE_NAME = "trainer_log.jsonl"
32 |
33 |
34 | class LogCallback(TrainerCallback, ExportableState):
35 | def __init__(self, start_time: float = None, elapsed_time: float = None):
36 |
37 | self.start_time = time.time() if start_time is None else start_time
38 | self.elapsed_time = 0 if elapsed_time is None else elapsed_time
39 | self.last_time = self.start_time
40 |
41 | def on_train_begin(
42 | self,
43 | args: TrainingArguments,
44 | state: TrainerState,
45 | control: TrainerControl,
46 | **kwargs
47 | ):
48 | r"""
49 | Event called at the beginning of training.
50 | """
51 | if state.is_local_process_zero:
52 | if not args.resume_from_checkpoint:
53 | self.start_time = time.time()
54 | self.elapsed_time = 0
55 | else:
56 | self.start_time = state.stateful_callbacks['LogCallback']['start_time']
57 | self.elapsed_time = state.stateful_callbacks['LogCallback']['elapsed_time']
58 |
59 | if args.save_on_each_node:
60 | if not state.is_local_process_zero:
61 | return
62 | else:
63 | if not state.is_world_process_zero:
64 | return
65 |
66 | self.last_time = time.time()
67 | if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
68 | logger.warning("Previous log file in this folder will be deleted.")
69 | os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
70 |
71 | def on_log(
72 | self,
73 | args: TrainingArguments,
74 | state: TrainerState,
75 | control: TrainerControl,
76 | logs,
77 | **kwargs
78 | ):
79 | if args.save_on_each_node:
80 | if not state.is_local_process_zero:
81 | return
82 | else:
83 | if not state.is_world_process_zero:
84 | return
85 |
86 | self.elapsed_time += time.time() - self.last_time
87 | self.last_time = time.time()
88 | if 'num_input_tokens_seen' in logs:
89 | logs['num_tokens'] = logs.pop('num_input_tokens_seen')
90 | state.log_history[-1].pop('num_input_tokens_seen')
91 | throughput = logs['num_tokens'] / args.world_size / self.elapsed_time
92 | state.log_history[-1]['throughput'] = logs['throughput'] = throughput
93 | state.stateful_callbacks["LogCallback"] = self.state()
94 |
95 | logs = dict(
96 | current_steps=state.global_step,
97 | total_steps=state.max_steps,
98 | loss=state.log_history[-1].get("loss", None),
99 | eval_loss=state.log_history[-1].get("eval_loss", None),
100 | predict_loss=state.log_history[-1].get("predict_loss", None),
101 | learning_rate=state.log_history[-1].get("learning_rate", None),
102 | epoch=state.log_history[-1].get("epoch", None),
103 | percentage=round(state.global_step / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
104 | )
105 |
106 | os.makedirs(args.output_dir, exist_ok=True)
107 |         with open(os.path.join(args.output_dir, LOG_FILE_NAME), "a", encoding="utf-8") as f:
108 | f.write(json.dumps(logs) + "\n")
109 |
110 | def state(self) -> dict:
111 | return {
112 | 'start_time': self.start_time,
113 | 'elapsed_time': self.elapsed_time
114 | }
115 |
116 | @classmethod
117 | def from_state(cls, state):
118 | return cls(state['start_time'], state['elapsed_time'])
119 |
--------------------------------------------------------------------------------
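`LogCallback` keeps its wall-clock bookkeeping inside the trainer state via the `ExportableState` protocol, so elapsed time survives `--resume_from_checkpoint`. A minimal sketch of that round trip outside a real `Trainer`, run from the `training/` directory (as `run.py` is) so `flame` is importable:

```python
import time

from flame.logging import LogCallback

cb = LogCallback()
time.sleep(0.1)                            # stand-in for some training time
cb.elapsed_time += time.time() - cb.last_time

saved = cb.state()                         # what Trainer stores under stateful_callbacks
restored = LogCallback.from_state(saved)   # what on_train_begin reads back after a resume
print(restored.start_time == cb.start_time, restored.elapsed_time >= 0.1)
```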
/training/flame/parser.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import annotations
4 |
5 | from dataclasses import dataclass, field
6 | from typing import Optional
7 |
8 | import transformers
9 | from transformers import HfArgumentParser, TrainingArguments
10 |
11 | from flame.logging import get_logger
12 |
13 | logger = get_logger(__name__)
14 |
15 |
16 | @dataclass
17 | class TrainingArguments(TrainingArguments):
18 |
19 | model_name_or_path: str = field(
20 | default=None,
21 | metadata={
22 | "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
23 | },
24 | )
25 | tokenizer: str = field(
26 | default="fla-hub/gla-1.3B-100B",
27 | metadata={"help": "Name of the tokenizer to use."}
28 | )
29 | use_fast_tokenizer: bool = field(
30 | default=False,
31 | metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
32 | )
33 | from_config: bool = field(
34 | default=True,
35 | metadata={"help": "Whether to initialize models from scratch."},
36 | )
37 | dataset: Optional[str] = field(
38 | default=None,
39 | metadata={"help": "The dataset(s) to use. Use commas to separate multiple datasets."},
40 | )
41 | dataset_name: Optional[str] = field(
42 | default=None,
43 | metadata={"help": "The name of provided dataset(s) to use."},
44 | )
45 | cache_dir: str = field(
46 | default=None,
47 | metadata={"help": "Path to the cached tokenized dataset."},
48 | )
49 | split: str = field(
50 | default="train",
51 | metadata={"help": "Which dataset split to use for training and evaluation."},
52 | )
53 | streaming: bool = field(
54 | default=False,
55 | metadata={"help": "Enable dataset streaming."},
56 | )
57 | hf_hub_token: Optional[str] = field(
58 | default=None,
59 | metadata={"help": "Auth token to log in with Hugging Face Hub."},
60 | )
61 | preprocessing_num_workers: Optional[int] = field(
62 | default=None,
63 | metadata={"help": "The number of processes to use for the pre-processing."},
64 | )
65 | buffer_size: int = field(
66 | default=2048,
67 | metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
68 | )
69 | context_length: int = field(
70 | default=2048,
71 | metadata={"help": "The context length of the tokenized inputs in the dataset."},
72 | )
73 | varlen: bool = field(
74 | default=False,
75 | metadata={"help": "Enable training with variable length inputs."},
76 | )
77 |
78 |
79 | def get_train_args():
80 | parser = HfArgumentParser(TrainingArguments)
81 | args, unknown_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
82 |
83 | if unknown_args:
84 | print(parser.format_help())
85 | print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
86 | raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
87 |
88 | if args.should_log:
89 | transformers.utils.logging.set_verbosity(args.get_process_log_level())
90 | transformers.utils.logging.enable_default_handler()
91 | transformers.utils.logging.enable_explicit_format()
92 | # set seeds manually
93 | transformers.set_seed(args.seed)
94 | return args
95 |
--------------------------------------------------------------------------------
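`get_train_args()` is a thin wrapper over `HfArgumentParser`, so the extra fields above simply become additional `--flags` on top of the standard `TrainingArguments`. A minimal sketch of parsing them programmatically; the flag values (config path, cache directory, output directory) are illustrative placeholders, not a training recipe:

```python
import sys

from flame.parser import get_train_args

# Illustrative flags only; any unrecognized flag triggers the ValueError above.
sys.argv = [
    "run.py",
    "--model_name_or_path", "configs/mom_340M.json",
    "--tokenizer", "fla-hub/gla-1.3B-100B",
    "--cache_dir", "data/tokenized",        # hypothetical cached-dataset path
    "--output_dir", "exp/mom-340M",         # hypothetical output directory
    "--context_length", "2048",
]
args = get_train_args()
print(args.model_name_or_path, args.tokenizer, args.context_length)
```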
/training/run.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from datasets import load_from_disk
4 | from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
5 | Trainer, set_seed)
6 |
7 | import sys
8 | import os
9 | import torch
10 | from torch import nn
11 | import mom  # noqa: F401  (side-effect import: registers MoM configs/models with the transformers Auto* classes)
12 | import fla  # noqa: F401  (side-effect import: registers FLA configs/models)
13 | from flame.data import DataCollatorForLanguageModeling
14 | from flame.logging import LogCallback, get_logger
15 | from flame.parser import get_train_args
16 | import wandb
17 | from torchinfo import summary
18 |
19 | logger = get_logger(__name__)
20 |
21 |
22 | def main():
23 | # torch.autograd.set_detect_anomaly(True)
24 | args = get_train_args()
25 | logger.info(args)
26 |
27 | tokenizer = AutoTokenizer.from_pretrained(
28 | args.tokenizer,
29 | use_fast=args.use_fast_tokenizer,
30 | trust_remote_code=True,
31 | add_bos_token=True,
32 | add_eos_token=False
33 | )
34 | if tokenizer.pad_token_id is None:
35 | tokenizer.pad_token = tokenizer.eos_token
36 | logger.info("Add pad token: {}".format(tokenizer.pad_token))
37 | # args.from_config = False
38 | if args.from_config:
39 | logger.info("All model params are randomly initialized for from-scratch training.")
40 | model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(args.model_name_or_path))
41 | else:
42 | logger.info(f"Loading pretrained checkpoint {args.model_name_or_path}")
43 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
44 | for name, param in model.named_parameters():
45 | if 'gate' in name:
46 | if 'weight' in name:
47 | nn.init.xavier_normal_(param)
48 | model.train()
49 |
50 | # summary(model, depth=6)
51 | # exit(0)
52 |
53 | trainable_params, all_param = model.num_parameters(only_trainable=True), model.num_parameters()
54 | logger.info(f"% of trainable params: {trainable_params:d} / {all_param:d} = {trainable_params / all_param:.2%}")
55 | logger.info(f"{tokenizer}\n{model}\n{model.config}")
56 |
57 | logger.info(f"Loading the `{args.split}` split directly from the cache {args.cache_dir}...")
58 | dataset = load_from_disk(args.cache_dir)
59 | logger.info(f"{dataset}")
60 | logger.info(f"Shuffling the dataset with seed {args.seed}")
61 | dataset = dataset.shuffle(seed=args.seed)
62 | logger.info("Creating the data collator")
63 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, varlen=args.varlen)
64 | logger.info(f"{data_collator}")
65 |
66 | if args.lr_scheduler_type == 'cosine_with_min_lr':
67 | args.lr_scheduler_kwargs = {'min_lr_rate': 0.1}
68 | if args.lr_scheduler_type == 'warmup_stable_decay':
69 | args.lr_scheduler_kwargs = {
70 | 'num_stable_steps': args.max_steps * 0.9 - args.warmup_steps,
71 | 'num_decay_steps': args.max_steps * 0.1
72 | }
73 |
74 | args.logging_steps = 16
75 | trainer = Trainer(
76 | model=model,
77 | args=args,
78 | tokenizer=tokenizer,
79 | data_collator=data_collator,
80 | callbacks=[LogCallback()],
81 | train_dataset=dataset
82 | )
83 |
84 | def detect_nan_hook(grad, name):
85 | if torch.isnan(grad).any():
86 | print(f"NaN detected in gradients of {name}!")
87 | print(f"Gradient values: {grad}")
88 | exit()
89 |
90 |     # Register the NaN-detection hook on every parameter
91 | for name, param in model.named_parameters():
92 | param.register_hook(lambda grad, name=name: detect_nan_hook(grad, name))
93 |
94 | results = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
95 | trainer.save_model()
96 | tokenizer.save_pretrained(trainer.args.output_dir)
97 |
98 | trainer.log_metrics("train", results.metrics)
99 | trainer.save_metrics("train", results.metrics)
100 | trainer.save_state()
101 |
102 |
103 | if __name__ == "__main__":
104 | main()
105 |
--------------------------------------------------------------------------------