├── cfgs
│   ├── run
│   │   ├── default.yaml
│   │   ├── base_memory_policy
│   │   │   ├── none.yaml
│   │   │   └── deep
│   │   │       └── bam
│   │   │           └── base_bam.yaml
│   │   ├── namm_bam_i1.yaml
│   │   ├── namm_bam_i2.yaml
│   │   ├── namm_bam_i3.yaml
│   │   ├── namm_bam_eval.yaml
│   │   └── namm_bam_eval_choubun.yaml
│   ├── auxiliary_loss
│   │   └── none.yaml
│   ├── typing
│   │   ├── default.yaml
│   │   └── half_attn.yaml
│   ├── policy
│   │   ├── deep_scoring
│   │   │   ├── bam.yaml
│   │   │   ├── mlp.yaml
│   │   │   └── attn.yaml
│   │   ├── none.yaml
│   │   ├── deep_selection
│   │   │   ├── topk.yaml
│   │   │   ├── dynamic.yaml
│   │   │   └── binary.yaml
│   │   ├── deep.yaml
│   │   └── deep_embedding
│   │       ├── norm_base.yaml
│   │       ├── attn_spec_norm.yaml
│   │       └── attn_spec_base.yaml
│   ├── task
│   │   ├── lb_2subset.yaml
│   │   ├── passage_retrieval_en.yaml
│   │   ├── lb_3subset_incr.yaml
│   │   ├── choubun_full.yaml
│   │   ├── base_sampler.yaml
│   │   └── lb_full.yaml
│   ├── trainer
│   │   ├── eval.yaml
│   │   └── default.yaml
│   ├── evolution
│   │   ├── dummy.yaml
│   │   └── cma_es.yaml
│   ├── model
│   │   ├── wrapped_llm
│   │   │   ├── llama3-8b-rope-x4NTK.yaml
│   │   │   └── base.yaml
│   │   └── hf_evaluator.yaml
│   ├── config_run_eval.yaml
│   └── config.yaml
├── figures
│   └── logo.png
├── memory_evolution
│   ├── __init__.py
│   └── base.py
├── memory_llms
│   ├── __init__.py
│   └── base.py
├── ChouBun
│   └── config
│       ├── dataset2maxlen.json
│       └── dataset2prompt.json
├── stateless_parallel_modules
│   ├── __init__.py
│   ├── mlp.py
│   ├── base.py
│   └── attention.py
├── LongBench
│   └── config
│       ├── model2maxlen.json
│       ├── dataset2maxlen.json
│       ├── model2path.json
│       └── dataset2prompt.json
├── env_minimal.yaml
├── memory_policy
│   ├── __init__.py
│   ├── deep_embedding_shared.py
│   ├── deep_scoring_bam.py
│   ├── deep_embedding.py
│   ├── deep_selection.py
│   ├── shared.py
│   ├── deep_embedding_spectogram.py
│   ├── auxiliary_losses.py
│   └── base_dynamic.py
├── utils_log.py
├── choubun_metrics.py
├── utils_hydra.py
├── README.md
├── longbench_metrics.py
├── main.py
├── utils_longbench.py
└── env.yaml
/cfgs/run/default.yaml:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cfgs/auxiliary_loss/none.yaml:
--------------------------------------------------------------------------------
1 | auxiliary_loss:
2 |
3 |
4 |
--------------------------------------------------------------------------------
/figures/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SakanaAI/evo-memory/HEAD/figures/logo.png
--------------------------------------------------------------------------------
/memory_evolution/__init__.py:
--------------------------------------------------------------------------------
1 | from .cma_es import CMA_ES
2 | from .base import MemoryEvolution, DummyEvolution
3 |
--------------------------------------------------------------------------------
/memory_llms/__init__.py:
--------------------------------------------------------------------------------
1 | from .llama import WrappedLlamaForCausalLM
2 | from .base import MemoryModelWrapper
3 |
--------------------------------------------------------------------------------
/cfgs/typing/default.yaml:
--------------------------------------------------------------------------------
1 | # main model dtype
2 | dtype: 'bfloat16'
3 | output_attentions_full_precision: true
4 |
5 | dtype_conf_name: ''
--------------------------------------------------------------------------------
/ChouBun/config/dataset2maxlen.json:
--------------------------------------------------------------------------------
1 | {
2 | "wiki_qa": 64,
3 | "edinet_qa": 64,
4 | "corp_sec_qa": 64,
5 | "corp_sec_sum": 128
6 | }
--------------------------------------------------------------------------------
/cfgs/typing/half_attn.yaml:
--------------------------------------------------------------------------------
1 | # main model dtype
2 | dtype: 'bfloat16'
3 | output_attentions_full_precision: false
4 |
5 | dtype_conf_name: 'halfAttn'
--------------------------------------------------------------------------------
/cfgs/policy/deep_scoring/bam.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - attn
3 | - _self_
4 |
5 | mult: true
6 | mult_nonlinearity:
7 | # --
8 |
9 | scoring_mp_log_name: bam
--------------------------------------------------------------------------------
/cfgs/policy/none.yaml:
--------------------------------------------------------------------------------
1 | memory_policy:
2 | _target_: memory_policy.Recency
3 | cache_size: ${cache_size}
4 |
5 | cache_size: null
6 |
7 |
8 | mp_log_name: dummy
--------------------------------------------------------------------------------
/cfgs/task/lb_2subset.yaml:
--------------------------------------------------------------------------------
1 | # TODO
2 |
3 | defaults:
4 | - base_sampler
5 | - _self_
6 |
7 | tasks: ["lb/passage_retrieval_en", "lb/dureader"]
8 | metrics: 'perf'
9 |
10 | tasks_log_name: lb2subset
--------------------------------------------------------------------------------
/cfgs/task/passage_retrieval_en.yaml:
--------------------------------------------------------------------------------
1 | # TODO
2 |
3 | defaults:
4 | - base_sampler
5 | - _self_
6 |
7 | tasks: ["lb/passage_retrieval_en"]
8 | metrics: 'perf'
9 |
10 | tasks_log_name: passage-retrieval-en
--------------------------------------------------------------------------------
/cfgs/task/lb_3subset_incr.yaml:
--------------------------------------------------------------------------------
1 | # TODO
2 |
3 | defaults:
4 | - base_sampler
5 | - _self_
6 |
7 | tasks: ["lb/passage_retrieval_en", "lb/dureader", "lb/narrativeqa"]
8 | metrics: 'perf'
9 |
10 | tasks_log_name: lb3subset-i
--------------------------------------------------------------------------------
/cfgs/policy/deep_selection/topk.yaml:
--------------------------------------------------------------------------------
1 | selection_criteria: ${topk_selection}
2 |
3 | topk_selection:
4 | _target_: memory_policy.TopKSelection
5 | cache_size: ${cache_size}
6 |
7 | cache_size: 8192
8 |
9 | selection_mp_log_name: topk-${cache_size}cs
--------------------------------------------------------------------------------
/cfgs/task/choubun_full.yaml:
--------------------------------------------------------------------------------
1 | # TODO
2 |
3 | defaults:
4 | - base_sampler
5 | - _self_
6 |
7 | tasks:
8 | - "choubun/wiki_qa"
9 | - "choubun/edinet_qa"
10 | - "choubun/corp_sec_qa"
11 | - "choubun/corp_sec_sum"
12 | metrics: 'perf'
13 |
--------------------------------------------------------------------------------
/stateless_parallel_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .mlp import StatelessGeneralizedMLP
2 | from .attention import StatelessAttention, StatelessAttentionParams
3 |
4 | from .base import (
5 | StatelessGeneralizedOperation, GeneralizedLinear, get_nonlinearity,
6 | StatelessGeneralizedModule)
7 |
--------------------------------------------------------------------------------
/cfgs/task/base_sampler.yaml:
--------------------------------------------------------------------------------
1 | task_sampler:
2 | _target_: task_sampler.TaskSampler
3 | tasks: ${tasks}
4 | metrics: ${metrics}
5 | training_tasks_subset: ${training_tasks_subset}
6 | test_tasks_subset: ${test_tasks_subset}
7 |
8 |
9 | tasks: ['sciq']
10 | metrics: ['acc']
11 | training_tasks_subset:
12 | test_tasks_subset:
13 |
14 | tasks_log_name:
--------------------------------------------------------------------------------
/cfgs/trainer/eval.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default
3 | - _self_
4 |
5 | max_iters: 0
6 | eval_only: true
7 | always_save_checkpoint: true
8 | scratch: false
9 | record_advanced_eval_stats: true
10 | eval_samples_batch_size:
11 | allow_distributed_eval: true
12 |
13 | score_aggregation: 'mean'
14 | score_normalization_reference:
15 | store_eval_results_locally: true
--------------------------------------------------------------------------------
/cfgs/evolution/dummy.yaml:
--------------------------------------------------------------------------------
1 | evolution_algorithm:
2 | _target_: memory_evolution.base.DummyEvolution
3 | pop_size: ${pop_size}
4 | param_size: ${param_size}
5 | param_clip: ${param_clip}
6 | score_processing: ${score_processing}
7 |
8 | pop_size: 1
9 | param_size:
10 | param_clip:
11 | score_processing:
12 |
13 | evolution_algorithm_name: dummy
14 |
15 | evolution_algorithm_log_name: ${evolution_algorithm_name}-p${pop_size}
--------------------------------------------------------------------------------
/cfgs/policy/deep_selection/dynamic.yaml:
--------------------------------------------------------------------------------
1 | selection_criteria: ${dynamic_selection}
2 |
3 | dynamic_selection:
4 | _target_: memory_policy.DynamicSelection
5 | per_layer: ${per_layer}
6 | per_head: ${per_head}
7 | shared: ${selection_shared}
8 | cache_size: ${cache_size}
9 | dynamic_thresh: ${dynamic_thresh}
10 |
11 | selection_shared: true
12 | cache_size:
13 | # 1/2048
14 | dynamic_thresh: 0.00048828125
15 |
16 | selection_mp_log_name: dynamic-${cache_size}cs
--------------------------------------------------------------------------------
/cfgs/policy/deep_selection/binary.yaml:
--------------------------------------------------------------------------------
1 | selection_criteria: ${dynamic_selection}
2 |
3 | dynamic_selection:
4 | _target_: memory_policy.BinarySelection
5 | per_layer: ${per_layer}
6 | per_head: ${per_head}
7 | shared: ${selection_shared}
8 | cache_size: ${cache_size}
9 | is_probabilistic: ${is_probabilistic}
10 | temp: ${temp}
11 | learned_temp: ${learned_temp}
12 |
13 |
14 | is_probabilistic: false
15 | temp: 1.0
16 | learned_temp: false
17 |
18 | selection_shared: true
19 | cache_size:
20 |
21 | selection_mp_log_name: binary-${cache_size}cs
--------------------------------------------------------------------------------
/LongBench/config/model2maxlen.json:
--------------------------------------------------------------------------------
1 | {
2 | "llama2-7b-chat-4k": 3500,
3 | "longchat-v1.5-7b-32k": 31500,
4 | "xgen-7b-8k": 7500,
5 | "internlm-7b-8k": 7500,
6 | "chatglm2-6b": 31500,
7 | "chatglm2-6b-32k": 31500,
8 | "chatglm3-6b-32k": 31500,
9 | "vicuna-v1.5-7b-16k": 15500,
10 | "Meta-Llama-3-8B": 7500,
11 | "Meta-Llama-3-8B-x4NTK": 31500,
12 | "Meta-Llama-3-8B-x4Linear": 31500,
13 | "Mixtral-8x22B-v0.1": 63500,
14 | "Mistral-7B-Instruct-v0.2": 31500,
15 | "Mistral-NoChatPrompt-7B-Instruct-v0.2": 31500
16 | }
17 |
--------------------------------------------------------------------------------
/ChouBun/config/dataset2prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "wiki_qa": "以下の文書を注意深く読み、続く質問に答えてください。回答は必ず文書内の連続した文字列(テキストスパン)から抽出してください。\n### 文書 ###\n{context}\n### 質問 ###\n{input}\n### 回答 ###\n",
3 | "edinet_qa": "以下の文書を注意深く読み、続く質問に答えてください。回答は必ず文書内の連続した文字列(テキストスパン)から抽出してください。\n### 文書 ###\n{context}\n### 質問 ###\n{input}\n### 回答 ###\n",
4 | "corp_sec_qa": "以下の文書を注意深く読み、続く質問に答えてください。回答は必ず文書内の連続した文字列(テキストスパン)から抽出してください。### 文書 ###\n{context}\n### 質問 ###\n{input}\n### 回答 ###\n",
5 | "corp_sec_sum": "以下の文書を要約してください。要約は短く簡潔にしてください。また、文書の全体的な考え方、傾向、洞察を反映させてください。\n### 文書 ###\n{context}\n### 要約 ###\n"
6 | }
--------------------------------------------------------------------------------
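Note: a sketch of how these two ChouBun config files are typically consumed. The three QA templates above instruct the model (in Japanese) to read the document carefully and answer with a contiguous text span extracted from it; the summarization template asks for a short, concise summary reflecting the document's overall ideas, trends, and insights. `dataset2prompt` supplies a template filled with each sample's `context` and `input` fields, and `dataset2maxlen` bounds the number of generated tokens; the sample record below is hypothetical.

```python
import json

# Fill a ChouBun prompt template and look up its generation budget.
with open("ChouBun/config/dataset2prompt.json") as f:
    dataset2prompt = json.load(f)
with open("ChouBun/config/dataset2maxlen.json") as f:
    dataset2maxlen = json.load(f)

sample = {"context": "...", "input": "..."}  # hypothetical wiki_qa record
prompt = dataset2prompt["wiki_qa"].format(**sample)
max_new_tokens = dataset2maxlen["wiki_qa"]  # 64 generated tokens for wiki_qa
```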
/LongBench/config/dataset2maxlen.json:
--------------------------------------------------------------------------------
1 | {
2 | "narrativeqa": 128,
3 | "qasper": 128,
4 | "multifieldqa_en": 64,
5 | "multifieldqa_zh": 64,
6 | "hotpotqa": 32,
7 | "2wikimqa": 32,
8 | "musique": 32,
9 | "dureader": 128,
10 | "gov_report": 512,
11 | "qmsum": 512,
12 | "multi_news": 512,
13 | "vcsum": 512,
14 | "trec": 64,
15 | "triviaqa": 32,
16 | "samsum": 128,
17 | "lsht": 64,
18 | "passage_count": 32,
19 | "passage_retrieval_en": 32,
20 | "passage_retrieval_zh": 32,
21 | "lcc": 64,
22 | "repobench-p": 64
23 | }
--------------------------------------------------------------------------------
/cfgs/task/lb_full.yaml:
--------------------------------------------------------------------------------
1 | # TODO
2 |
3 | defaults:
4 | - base_sampler
5 | - _self_
6 |
7 | tasks:
8 | - "lb/narrativeqa"
9 | - "lb/qasper"
10 | - "lb/multifieldqa_en"
11 | - "lb/multifieldqa_zh"
12 | - "lb/hotpotqa"
13 | - "lb/2wikimqa"
14 | - "lb/musique"
15 | - "lb/dureader"
16 | - "lb/gov_report"
17 | - "lb/qmsum"
18 | - "lb/multi_news"
19 | - "lb/vcsum"
20 | - "lb/trec"
21 | - "lb/triviaqa"
22 | - "lb/samsum"
23 | - "lb/lsht"
24 | - "lb/passage_count"
25 | - "lb/passage_retrieval_en"
26 | - "lb/passage_retrieval_zh"
27 | - "lb/lcc"
28 | - "lb/repobench-p"
29 | metrics: 'perf'
30 |
--------------------------------------------------------------------------------
/cfgs/policy/deep.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - deep_embedding@_global_: joint_recency_attn_spec
4 | - deep_scoring@_global_: mlp
5 | - deep_selection@_global_: dynamic
6 |
7 |
8 | memory_policy:
9 | _target_: memory_policy.DeepMP
10 | pop_size: ${pop_size}
11 | per_head: ${per_head}
12 | per_layer: ${per_layer}
13 | token_embedding: ${token_embedding}
14 | scoring_network: ${scoring_network}
15 | selection_criteria: ${selection_criteria}
16 | lazy_param_num: ${lazy_param_num}
17 |
18 | per_head: false
19 | per_layer: true
20 | lazy_param_num: false
21 |
22 | mp_log_name: NAMM/${embedding_mp_log_name}/${scoring_mp_log_name}/${selection_mp_log_name}
--------------------------------------------------------------------------------
/env_minimal.yaml:
--------------------------------------------------------------------------------
1 | name: th2
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - python=3.10
9 | - pytorch
10 | - pytorch-cuda=12.1
11 | - torchvision
12 | - torchaudio
13 | - ipykernel
14 | - matplotlib
15 | - ipywidgets
16 | - pip:
17 | - vllm
18 | - numpy
19 | - transformers==4.41.2
20 | - accelerate
21 | - datasets
22 | - tiktoken
23 | - wandb
24 | - tqdm
25 | - hydra-core
26 | - crfm-helm
27 | - lm-eval==0.4.2
28 | - fugashi==1.3.2
29 | - ftfy
30 | - bitsandbytes
31 | - rouge
32 | - jieba
33 | - fuzzywuzzy
34 | - einops
35 | - scipy==1.13.0
36 | - seaborn
37 |
--------------------------------------------------------------------------------
/cfgs/model/wrapped_llm/llama3-8b-rope-x4NTK.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - base
3 | - _self_
4 |
5 | pretrained_llm_name: meta-llama/Meta-Llama-3-8B
6 |
7 | memory_model:
8 | _target_: memory_llms.llama.WrappedLlamaForCausalLM
9 | model: ${pretrained_llm}
10 | memory_policy: ${memory_policy}
11 | max_new_tokens: ${max_new_tokens}
12 |
13 | pretrained_llm:
14 | _target_: utils_hydra.AutoModelForCausalLM.from_pretrained
15 | pretrained_model_name_or_path: ${pretrained_llm_name}
16 | rope_scaling: ${rope_scaling}
17 |
18 | max_new_tokens:
19 | max_position_id: 32768
20 |
21 | rope_scaling:
22 | type: dynamic
23 | override: 'ntk'
24 | factor: 4
25 | alpha: 2
26 |
27 |
28 | llm_log_name: Meta-Llama-3-8B-x4NTK
--------------------------------------------------------------------------------
/cfgs/policy/deep_embedding/norm_base.yaml:
--------------------------------------------------------------------------------
1 | emb_online_output_normalization: true
2 | emb_update_norm_during_training: false
3 | emb_update_norm_during_eval: true
4 |
5 | synchronized_buffers_freeze_after: 1
6 |
7 | embedding_output_params:
8 | _target_: memory_policy.ComponentOutputParams
9 | requires_recomputation: ${embedding_requires_recomputation}
10 | reduction_mode: ${embedding_reduction_mode}
11 | ema_params: ${embedding_ema_params}
12 | output_past_non_reduced_history: ${embedding_output_past_non_reduced_history}
13 | max_non_reduced_history_len: ${embedding_max_non_reduced_history_len}
14 | online_output_normalization: ${emb_online_output_normalization}
15 | update_norm_during_training: ${emb_update_norm_during_training}
16 | update_norm_during_eval: ${emb_update_norm_during_eval}
17 |
--------------------------------------------------------------------------------
/LongBench/config/model2path.json:
--------------------------------------------------------------------------------
1 | {
2 | "llama2-7b-chat-4k": "meta-llama/Llama-2-7b-chat-hf",
3 | "longchat-v1.5-7b-32k": "lmsys/longchat-7b-v1.5-32k",
4 | "xgen-7b-8k": "Salesforce/xgen-7b-8k-inst",
5 | "internlm-7b-8k": "internlm/internlm-chat-7b-8k",
6 | "chatglm2-6b": "THUDM/chatglm2-6b",
7 | "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
8 | "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
9 | "vicuna-v1.5-7b-16k": "lmsys/vicuna-7b-v1.5-16k",
10 | "Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B",
11 | "Meta-Llama-3-8B-x4NTK": "meta-llama/Meta-Llama-3-8B",
12 | "Meta-Llama-3-8B-x4Linear": "meta-llama/Meta-Llama-3-8B",
13 | "Mixtral-8x22B-v0.1": "mistralai/Mixtral-8x22B-v0.1",
14 | "Mistral-7B-Instruct-v0.2": "mistralai/Mistral-7B-Instruct-v0.2",
15 | "Mistral-NoChatPrompt-7B-Instruct-v0.2": "mistralai/Mistral-7B-Instruct-v0.2"
16 | }
17 |
--------------------------------------------------------------------------------
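Note: a sketch pairing the two LongBench model maps above: `model2path` resolves a short model name to its Hugging Face repository, while `model2maxlen` caps the input length for that model. Illustrative only; LongBench's own prediction script is not part of this dump.

```python
import json

model_name = "Meta-Llama-3-8B-x4NTK"
with open("LongBench/config/model2path.json") as f:
    model_path = json.load(f)[model_name]
with open("LongBench/config/model2maxlen.json") as f:
    max_input_len = json.load(f)[model_name]
print(model_path, max_input_len)  # meta-llama/Meta-Llama-3-8B 31500
```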
/cfgs/config_run_eval.yaml:
--------------------------------------------------------------------------------
1 |
2 | # TODO
3 | defaults:
4 | - _self_
5 | - trainer@_global_: default
6 | - model@_global_: hf_evaluator
7 | - policy@_global_: deep
8 | - auxiliary_loss@_global_: none
9 | - evolution@_global_: cma_es
10 | - task@_global_: passage_retrieval_en
11 | - run@_global_: default
12 |
13 | params_to_sweep:
14 | memory_policy:
15 | cache_size: [256, 512, 1024]
16 |
17 | out_folder: ./exp_local
18 | out_dir: '${out_folder}/${wandb_project}/${wandb_group_name}/${wandb_run_name}/${seed}'
19 |
20 | # wandb logging
21 | wandb_log: true # enabled by default
22 | wandb_project: memory_evolution_hf
23 | wandb_run_name: default_configs_test
24 | wandb_group_name: tests
25 |
26 |
27 | # system
28 | backend: 'nccl' # 'nccl', 'gloo', etc.
29 | device: 'cuda'
30 | seed: 1337
31 | deterministic_behavior: false
32 |
33 |
34 | hydra:
35 | run:
36 | dir: ${out_dir}/
--------------------------------------------------------------------------------
/cfgs/model/wrapped_llm/base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 |
4 |
5 | pretrained_llm:
6 | _target_: utils_hydra.AutoModelForCausalLM.from_pretrained
7 | pretrained_model_name_or_path: ${pretrained_llm_name}
8 |
9 | tokenizer:
10 | _target_: utils_hydra.AutoTokenizer.from_pretrained
11 | pretrained_model_name_or_path: ${pretrained_llm_name}
12 |
13 | pretrained_llm_name: meta-llama/Meta-Llama-3-8B
14 |
15 |
16 | memory_model:
17 | _target_: memory_llms.llama.WrappedLlamaForCausalLM
18 | model: ${pretrained_llm}
19 | memory_policy: ${memory_policy}
20 | max_new_tokens: ${max_new_tokens}
21 | memory_policy_fixed_delay: ${memory_policy_fixed_delay}
22 | output_attentions_full_precision: ${output_attentions_full_precision}
23 |
24 | # needs to specify only for non memory_models
25 | max_new_tokens:
26 | max_position_id: 2048
27 | max_conditioning_length: ${max_position_id}
28 | memory_policy_fixed_delay:
--------------------------------------------------------------------------------
/cfgs/run/base_memory_policy/none.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /model/wrapped_llm@_global_: llama3-8b-rope-x4NTK
4 | - override /task@_global_: passage_retrieval_en
5 | - override /evolution@_global_: dummy
6 | - override /policy@_global_: none
7 | - override /typing@_global_: half_attn
8 | - override /auxiliary_loss@_global_: none
9 | - _self_
10 |
11 | run_name_suffix: ''
12 |
13 | wandb_project: memory_evolution_hf
14 |
15 | wandb_run_name: ${tasks_log_name}-${evolution_algorithm_log_name}-shared-${pop_size}pop-${samples_batch_size}qs-${memory_policy_fixed_delay}fixDel-${run_name_suffix}
16 | wandb_group_name: ${llm_log_name}/${mp_log_name}
17 |
18 | add_bos_token: true # to match ref. scores setting
19 | score_normalization_reference: 'lb_reference_scores/per_request/llama3-8b-32k-x2NTK2alpha-v2.json'
20 |
21 | batch_size: 1
22 |
23 | max_new_tokens: 128
24 |
25 | cache_size: ${max_position_id} # do not discard any tokens
--------------------------------------------------------------------------------
/cfgs/run/namm_bam_i1.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - base_memory_policy/deep/bam/base_bam
4 | - _self_
5 |
6 | run_name_suffix: 'stage1'
7 |
8 | cache_size:
9 | max_new_tokens: 128
10 |
11 | scoring_reduction_mode:
12 | embedding_reduction_mode: ema
13 |
14 | # Ema params.
15 | embedding_ema_coeff: 0.99
16 | embedding_reduction_learned: false
17 |
18 |
19 | # --
20 | # increasing the hidden dim increases time cost marginally
21 | scoring_attn_hidden_dim: 32
22 | scoring_attn_output_dim:
23 | scoring_attn_num_heads: 1
24 | scoring_attn_bias: true
25 | scoring_attn_use_rope: false
26 | scoring_attn_rope_theta: 50000
27 | scoring_attn_masking_strategy: backward
28 | #--
29 |
30 | # --
31 | # Stft params
32 | n_fft: 32
33 | hop_length: 16
34 | window_fn:
35 | _target_: utils_hydra.torch.hann_window
36 | window_length: ${n_fft}
37 | periodic: true
38 | window_fn_log_name: hann
39 | pad_mode: constant
40 | output_magnitudes: true
41 | # --
42 |
43 | pop_size: 32
44 |
45 | scoring_initializer: 0
46 |
47 | keep_past_epoch_checkpoints_every: 1
--------------------------------------------------------------------------------
/cfgs/evolution/cma_es.yaml:
--------------------------------------------------------------------------------
1 | evolution_algorithm:
2 | _target_: memory_evolution.cma_es.CMA_ES
3 | pop_size: ${pop_size}
4 | param_size: ${param_size}
5 | param_clip: ${param_clip}
6 | clip_param_min: ${clip_param_min}
7 | clip_param_max: ${clip_param_max}
8 | score_processing: ${score_processing}
9 | elite_ratio: ${elite_ratio}
10 | c_1: ${c_1}
11 | c_mu: ${c_mu}
12 | c_sigma: ${c_sigma}
13 | d_sigma: ${d_sigma}
14 | c_c: ${c_c}
15 | c_m: ${c_m}
16 | init_sigma: ${init_sigma}
17 | init_param_range: ${init_param_range}
18 | prefer_mean_to_best: ${prefer_mean_to_best}
19 |
20 |
21 | pop_size: 16
22 | param_size:
23 | param_clip:
24 | clip_param_min:
25 | clip_param_max:
26 | # CMA-ES has its own custom score processing
27 | score_processing:
28 | elite_ratio: 0.5
29 | c_1:
30 | c_mu:
31 | c_sigma:
32 | d_sigma:
33 | c_c:
34 |
35 |
36 | c_m: 1.0
37 | init_sigma: 0.065
38 | init_param_range:
39 |
40 | prefer_mean_to_best: false
41 |
42 | evolution_algorithm_name: cma-es
43 |
44 | evolution_algorithm_log_name: ${evolution_algorithm_name}-p${pop_size}-rMean${prefer_mean_to_best}
--------------------------------------------------------------------------------
/cfgs/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - trainer@_global_: default
4 | - model@_global_: hf_evaluator
5 | - typing@_global_: default
6 | - policy@_global_: deep
7 | - auxiliary_loss@_global_: none
8 | - evolution@_global_: cma_es
9 | - task@_global_: passage_retrieval_en
10 | - run@_global_: default
11 |
12 | wandb_config:
13 | _target_: memory_trainer.WandbConfig
14 | wandb_log: ${wandb_log}
15 | wandb_project: ${wandb_project}
16 | wandb_run_name: ${wandb_run_name}
17 | wandb_group_name: ${wandb_group_name}
18 |
19 |
20 | out_folder: ./exp_local
21 | out_dir: '${out_folder}/${wandb_project}/${wandb_group_name}/${wandb_run_name}/${seed}'
22 |
23 | # wandb logging
24 | wandb_log: true # enabled by default
25 | wandb_project: memory_evolution_hf
26 | wandb_run_name: default_configs_test
27 | wandb_group_name: tests
28 |
29 | # system
30 | device: 'cuda'
31 | seed: 1337
32 | deterministic_behavior: false
33 |
34 | # ddp settings
35 | backend: 'nccl' # 'nccl', 'gloo', etc.
36 | ddp_timeout_limit: '0:6:0' # days:hours:minutes
37 |
38 | hydra:
39 | run:
40 | dir: ${out_dir}/
--------------------------------------------------------------------------------
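Note: config.yaml above is the primary Hydra config; as a sketch of how its config groups compose outside main.py (mirroring `initialize_cfg` in utils_hydra.py further below; the group paths follow the cfgs/ tree at the top of this dump), assuming hydra-core is installed:

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the primary config and override two of the config groups
# declared in its defaults list, plus one plain value.
with initialize(version_base=None, config_path="cfgs"):
    cfg = compose(
        config_name="config",
        overrides=[
            "task@_global_=lb_2subset",             # swap the task group
            "policy/deep_selection@_global_=topk",  # swap the selection criterion
            "cache_size=1024",                      # plain value override
        ],
    )
print(OmegaConf.to_yaml(cfg))
```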
/cfgs/policy/deep_embedding/attn_spec_norm.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - attn_spec_base
3 | - norm_base
4 | - _self_
5 |
6 | token_embedding:
7 | _target_: memory_policy.RecencyEmbeddingWrapper
8 | token_embedding: ${attn_spectrogram}
9 | recency_embedding: ${recency_deep_embedding}
10 | start_recency_from: ${start_recency_from}
11 | wrapper_output_dim: ${wrapper_output_dim}
12 | processing_layers: ${processing_layers}
13 | joining_strategy: ${joining_strategy}
14 | output_params: ${wrapper_output_params}
15 |
16 |
17 | recency_deep_embedding:
18 | _target_: memory_policy.PositionalEmbedding
19 | max_position_id: ${max_position_id}
20 | embed_dim: ${recency_embed_dim}
21 | max_freq: ${recency_max_freq}
22 |
23 | wrapper_output_dim:
24 | processing_layers: 0
25 | joining_strategy: 'append'
26 | wrapper_output_params:
27 |
28 | recency_embedding_name: ${joining_strategy}RecAbsPoc${recency_embed_dim}D
29 |
30 | # must be even
31 | recency_embed_dim: 8
32 | recency_max_freq: 50000
33 | start_recency_from: 1 # NOTE: 1 for legacy compatibility purposes
34 |
35 | embedding_mp_log_name: attn-spec-norm
--------------------------------------------------------------------------------
/cfgs/run/namm_bam_i2.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - base_memory_policy/deep/bam/base_bam
4 | - override /task@_global_: lb_2subset
5 | - _self_
6 |
7 | run_name_suffix: 'stage2'
8 |
9 | cache_size:
10 | max_new_tokens: 128
11 |
12 | scoring_reduction_mode:
13 | embedding_reduction_mode: ema
14 |
15 | # Ema params.
16 | embedding_ema_coeff: 0.99
17 | embedding_reduction_learned: false
18 |
19 | # --
20 | # increasing the hidden dim increases time cost marginally
21 | scoring_attn_hidden_dim: 32
22 | scoring_attn_output_dim:
23 | scoring_attn_num_heads: 1
24 | scoring_attn_bias: true
25 | scoring_attn_use_rope: false
26 | scoring_attn_rope_theta: 50000
27 | scoring_attn_masking_strategy: backward
28 | #--
29 |
30 | # --
31 | # Stft params
32 | n_fft: 32
33 | hop_length: 16
34 | window_fn:
35 | _target_: utils_hydra.torch.hann_window
36 | window_length: ${n_fft}
37 | periodic: true
38 | window_fn_log_name: hann
39 | pad_mode: constant
40 | output_magnitudes: true
41 | # --
42 |
43 |
44 | init_from: 'path/to/stage1.pt'
45 |
46 |
47 | pop_size: 32
48 |
49 | scoring_initializer: 0
50 | keep_past_epoch_checkpoints_every: 1
51 |
52 | scratch: false
--------------------------------------------------------------------------------
/cfgs/run/namm_bam_i3.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - base_memory_policy/deep/bam/base_bam
4 | - override /task@_global_: lb_3subset_incr
5 | - _self_
6 |
7 | run_name_suffix: 'stage3'
8 |
9 | cache_size:
10 | max_new_tokens: 128
11 |
12 | scoring_reduction_mode:
13 | embedding_reduction_mode: ema
14 |
15 | # Ema params.
16 | embedding_ema_coeff: 0.99
17 | embedding_reduction_learned: false
18 |
19 | # --
20 | # increasing the hidden dim increases time cost marginally
21 | scoring_attn_hidden_dim: 32
22 | scoring_attn_output_dim:
23 | scoring_attn_num_heads: 1
24 | scoring_attn_bias: true
25 | scoring_attn_use_rope: false
26 | scoring_attn_rope_theta: 50000
27 | scoring_attn_masking_strategy: backward
28 | #--
29 |
30 | # --
31 | # Stft params
32 | n_fft: 32
33 | hop_length: 16
34 | window_fn:
35 | _target_: utils_hydra.torch.hann_window
36 | window_length: ${n_fft}
37 | periodic: true
38 | window_fn_log_name: hann
39 | pad_mode: constant
40 | output_magnitudes: true
41 | # --
42 |
43 |
44 | init_from: 'path/to/stage2.pt'
45 |
46 |
47 | pop_size: 32
48 |
49 | scoring_initializer: 0
50 | keep_past_epoch_checkpoints_every: 1
51 |
52 | scratch: false
--------------------------------------------------------------------------------
/cfgs/model/hf_evaluator.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - wrapped_llm@_global_: pythia160m
4 |
5 | memory_evaluator:
6 | _target_: memory_evaluator.MemoryHFEvaluator
7 | model: ${pretrained_llm}
8 | tokenizer: ${tokenizer}
9 |
10 | evaluation_ctx_steps: ${evaluation_ctx_steps}
11 | add_bos_token: ${add_bos_token}
12 |
13 | eval_max_batch_size: ${eval_max_batch_size}
14 | memory_batch_size: ${memory_batch_size}
15 |
16 | batch_size: ${batch_size}
17 | max_conditioning_length: ${max_conditioning_length}
18 |
19 | max_memory_length: ${max_memory_length}
20 | max_gen_tokens: ${max_gen_tokens}
21 | full_context_gen: ${full_context_gen}
22 |
23 | per_timestep_loglikelihood: ${per_timestep_loglikelihood}
24 | force_clear_cache: ${force_clear_cache}
25 |
26 | device: ${device}
27 |
28 | log_misc: ${log_misc}
29 |
30 | evaluation_ctx_steps: 1 # default, unused
31 | add_bos_token: true
32 |
33 | eval_max_batch_size: 256
34 | memory_batch_size: false
35 |
36 | batch_size: "auto"
37 |
38 | max_memory_length: 2048
39 | max_gen_tokens: 512 # from the HFLM class
40 | full_context_gen: true
41 |
42 | per_timestep_loglikelihood: true
43 | force_clear_cache: true
44 |
45 | log_misc: false
--------------------------------------------------------------------------------
/cfgs/run/namm_bam_eval.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - base_memory_policy/deep/bam/base_bam
4 | - override /task@_global_: lb_full
5 | - _self_
6 |
7 | run_name_suffix: 'stage3'
8 |
9 | cache_size:
10 | max_new_tokens: 128
11 |
12 | scoring_reduction_mode:
13 | embedding_reduction_mode: ema
14 |
15 | # Ema params.
16 | embedding_ema_coeff: 0.99
17 | embedding_reduction_learned: false
18 |
19 | # --
20 | # increasing the hidden dim increases time cost marginally
21 | scoring_attn_hidden_dim: 32
22 | scoring_attn_output_dim:
23 | scoring_attn_num_heads: 1
24 | scoring_attn_bias: true
25 | scoring_attn_use_rope: false
26 | scoring_attn_rope_theta: 50000
27 | scoring_attn_masking_strategy: backward
28 | #--
29 |
30 | # --
31 | # Stft params
32 | n_fft: 32
33 | hop_length: 16
34 | window_fn:
35 | _target_: utils_hydra.torch.hann_window
36 | window_length: ${n_fft}
37 | periodic: true
38 | window_fn_log_name: hann
39 | pad_mode: constant
40 | output_magnitudes: true
41 | # --
42 |
43 |
44 | init_from: 'path/to/namm/results/ckpt.pt'
45 |
46 | pop_size: 32
47 |
48 | scoring_initializer: 0
49 | keep_past_epoch_checkpoints_every: 1
50 |
51 | scratch: false
52 |
53 | eval_only: true
54 | # avoid timeout when evaluating on all tasks
55 | ddp_timeout_limit: '9:0:0'
--------------------------------------------------------------------------------
/memory_policy/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .base import (
3 | MemoryPolicy, ParamMemoryPolicy, Recency, AttnRequiringRecency,)
4 | from .base_dynamic import (
5 | DynamicMemoryPolicy, DynamicParamMemoryPolicy,
6 | RecencyParams, AttentionParams,
7 | )
8 |
9 | from .auxiliary_losses import MemoryPolicyAuxiliaryLoss, SparsityAuxiliaryLoss, L2NormAuxiliaryLoss
10 |
11 | from memory_policy.deep import DeepMP
12 | from memory_policy.deep_embedding_spectogram import (
13 | STFTParams, AttentionSpectrogram, fft_avg_mask, fft_ema_mask,
14 | )
15 | from memory_policy.deep_embedding import (
16 | RecencyExponents, NormalizedRecencyExponents)
17 | from memory_policy.deep_scoring import (
18 | MLPScoring, GeneralizedScoring, make_scaled_one_hot_init, TCNScoring)
19 | from memory_policy.deep_selection import (
20 | DynamicSelection, TopKSelection, BinarySelection)
21 | from memory_policy.base_deep_components import (
22 | EMAParams, ComponentOutputParams, wrap_torch_initializer,
23 | DeepMemoryPolicyComponent, TokenEmbedding, JointEmbeddings,
24 | ScoringNetwork, SelectionNetwork,
25 | )
26 | from .shared import SynchronizableBufferStorage, RegistrationCompatible
27 |
28 | from .deep_embedding_shared import PositionalEmbedding, Embedding
29 | from .deep_embedding_wrappers import RecencyEmbeddingWrapper
--------------------------------------------------------------------------------
/cfgs/run/namm_bam_eval_choubun.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - base_memory_policy/deep/bam/base_bam
4 | - override /task@_global_: choubun_full
5 | - _self_
6 |
7 | run_name_suffix: 'stage3'
8 |
9 | cache_size:
10 | max_new_tokens: 128
11 |
12 | scoring_reduction_mode:
13 | embedding_reduction_mode: ema
14 |
15 | # Ema params.
16 | embedding_ema_coeff: 0.99
17 | embedding_reduction_learned: false
18 |
19 | # --
20 | # increasing the hidden dim increases time cost marginally
21 | scoring_attn_hidden_dim: 32
22 | scoring_attn_output_dim:
23 | scoring_attn_num_heads: 1
24 | scoring_attn_bias: true
25 | scoring_attn_use_rope: false
26 | scoring_attn_rope_theta: 50000
27 | scoring_attn_masking_strategy: backward
28 | #--
29 |
30 | # --
31 | # Stft params
32 | n_fft: 32
33 | hop_length: 16
34 | window_fn:
35 | _target_: utils_hydra.torch.hann_window
36 | window_length: ${n_fft}
37 | periodic: true
38 | window_fn_log_name: hann
39 | pad_mode: constant
40 | output_magnitudes: true
41 | # --
42 |
43 |
44 | init_from: 'path/to/namm/results/ckpt.pt'
45 |
46 | pop_size: 32
47 |
48 | scoring_initializer: 0
49 | keep_past_epoch_checkpoints_every: 1
50 |
51 | scratch: false
52 |
53 | eval_only: true
54 | # avoid timeout when evaluating on all tasks
55 | ddp_timeout_limit: '9:0:0'
56 |
57 | score_normalization_reference:
--------------------------------------------------------------------------------
/cfgs/policy/deep_embedding/attn_spec_base.yaml:
--------------------------------------------------------------------------------
1 | attn_spectrogram:
2 | _target_: memory_policy.AttentionSpectrogram
3 | per_layer: ${per_layer}
4 | per_head: ${per_head}
5 | shared: ${embedding_shared}
6 | output_params: ${embedding_output_params}
7 | stft_params: ${stft_params}
8 |
9 | embedding_output_params:
10 | _target_: memory_policy.ComponentOutputParams
11 | requires_recomputation: ${embedding_requires_recomputation}
12 | reduction_mode: ${embedding_reduction_mode}
13 | ema_params: ${embedding_ema_params}
14 | output_past_non_reduced_history: ${embedding_output_past_non_reduced_history}
15 | max_non_reduced_history_len: ${embedding_max_non_reduced_history_len}
16 |
17 | embedding_ema_params:
18 | _target_: memory_policy.EMAParams
19 | coeff: ${embedding_ema_coeff}
20 | learned: ${embedding_reduction_learned}
21 | reduction_stride: ${hop_length}
22 |
23 | stft_params:
24 | _target_: memory_policy.STFTParams
25 | n_fft: ${n_fft}
26 | hop_length: ${hop_length}
27 | window_fn: ${window_fn}
28 | pad_mode: ${pad_mode}
29 | output_magnitudes: ${output_magnitudes}
30 |
31 | # Spectrogram params.
32 | embedding_shared: true
33 | # Output params.
34 | embedding_requires_recomputation: true
35 | embedding_reduction_mode:
36 | # Ema params.
37 | embedding_ema_coeff: 0.975
38 | embedding_reduction_learned: false
39 | # --
40 | embedding_output_past_non_reduced_history: false
41 | embedding_max_non_reduced_history_len:
42 | # --
43 | # Stft params
44 | n_fft: 32
45 | hop_length: 16
46 | window_fn:
47 | _target_: memory_policy.fft_avg_mask
48 | window_length: ${n_fft}
49 | pad_mode: constant
50 | output_magnitudes: true
51 | # --
52 |
53 | window_fn_log_name: avg
54 | embedding_spec_mp_log_name: attn-spec
--------------------------------------------------------------------------------
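Note: for intuition, the `stft_params` above configure a short-time Fourier transform over per-token attention histories; below is a minimal sketch using `torch.stft` directly, with the hann window variant from the run configs (the repo's own wrapper is `AttentionSpectrogram` in memory_policy/deep_embedding_spectogram.py; the input trace here is hypothetical).

```python
import torch

# A length-T attention trace becomes an (n_fft // 2 + 1) x num_frames
# magnitude spectrogram; hop_length sets the stride between frames.
attn_values = torch.randn(1, 128)  # hypothetical per-token attention history
window = torch.hann_window(32, periodic=True)
spec = torch.stft(
    attn_values, n_fft=32, hop_length=16, window=window,
    pad_mode="constant", return_complex=True,
).abs()  # output_magnitudes: true
print(spec.shape)  # torch.Size([1, 17, 9])
```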
/cfgs/policy/deep_scoring/mlp.yaml:
--------------------------------------------------------------------------------
1 | scoring_network: ${mlp_scoring}
2 |
3 | mlp_scoring:
4 | _target_: memory_policy.MLPScoring
5 | per_layer: ${per_layer}
6 | per_head: ${per_head}
7 | shared: ${scoring_shared}
8 | output_params: ${scoring_output_params}
9 | hidden_features: ${hidden_features}
10 | hidden_depth: ${hidden_depth}
11 | bias: ${scoring_mlp_bias}
12 | non_linearity: ${non_linearity}
13 | initializer: ${scoring_initializer}
14 | residual: ${residual}
15 | residual_first: ${residual_first}
16 |
17 |
18 | scoring_output_params:
19 | _target_: memory_policy.ComponentOutputParams
20 | requires_recomputation: ${scoring_requires_recomputation}
21 | reduction_mode: ${scoring_reduction_mode}
22 | ema_params: ${scoring_ema_params}
23 | output_past_non_reduced_history: ${scoring_output_past_non_reduced_history}
24 | max_non_reduced_history_len: ${scoring_max_non_reduced_history_len}
25 |
26 | scoring_ema_params:
27 | _target_: memory_policy.EMAParams
28 | coeff: ${scoring_ema_coeff}
29 | learned: ${scoring_reduction_learned}
30 | reduction_stride: ${hop_length}
31 |
32 | # MLP params
33 | scoring_shared: true
34 | # Output params.
35 | scoring_requires_recomputation: true
36 | scoring_reduction_mode: ema
37 | # Ema params.
38 | scoring_ema_coeff: 0.99
39 | scoring_reduction_learned: false
40 | # --
41 | scoring_output_past_non_reduced_history: false
42 | scoring_max_non_reduced_history_len:
43 | # --
44 | hidden_features:
45 | hidden_depth: 1
46 | scoring_mlp_bias: false
47 | non_linearity: relu
48 | scoring_initializer:
49 | _target_: memory_policy.make_scaled_one_hot_init
50 | idxs_to_scale: {}
51 | idxs_to_ones: []
52 | residual: true
53 | residual_first: true
54 | # --
55 |
56 |
57 | scoring_mp_log_name: mlp
--------------------------------------------------------------------------------
/utils_log.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
5 |
6 | from lm_eval.models.utils import Collator
7 |
8 |
9 | def stat_fn(list_of_lists, index=None):
10 | if index is None:
11 | flat_list = np.array(
12 | [v for vs in list_of_lists for v in vs])
13 |     else:
14 |         flat_list = np.array(list_of_lists[index])
15 |     mean = np.mean(flat_list)
16 | std = np.std(flat_list)
17 | above_avg = np.mean(flat_list > mean)
18 | max_v = np.max(flat_list)
19 | min_v = np.min(flat_list)
20 | return mean, std, above_avg, max_v, min_v
21 |
22 |
23 | def initialize_stat_objects_for(
24 | self,
25 | score_name,
26 | stats=['mean', 'std', 'above_mean', 'max', 'min'],
27 | ):
28 | for stat in stats:
29 | init_list = [[] for _ in range(self.num_memory_layers)]
30 | setattr(self, f'{score_name}_{stat}', init_list)
31 | raise NotImplementedError
32 |
33 |
34 | class COLOR:
35 | # ANSI color codes and tools
36 | BLACK = "\033[0;30m"
37 | RED = "\033[0;31m"
38 | GREEN = "\033[0;32m"
39 | BROWN = "\033[0;33m"
40 | BLUE = "\033[0;34m"
41 | PURPLE = "\033[0;35m"
42 | CYAN = "\033[0;36m"
43 | LIGHT_GRAY = "\033[0;37m"
44 | DARK_GRAY = "\033[1;30m"
45 | LIGHT_RED = "\033[1;31m"
46 | LIGHT_GREEN = "\033[1;32m"
47 | YELLOW = "\033[1;33m"
48 | LIGHT_BLUE = "\033[1;34m"
49 | LIGHT_PURPLE = "\033[1;35m"
50 | LIGHT_CYAN = "\033[1;36m"
51 | LIGHT_WHITE = "\033[1;37m"
52 | BOLD = "\033[1m"
53 | FAINT = "\033[2m"
54 | ITALIC = "\033[3m"
55 | UNDERLINE = "\033[4m"
56 | BLINK = "\033[5m"
57 | NEGATIVE = "\033[7m"
58 | CROSSED = "\033[9m"
59 | END = "\033[0m"
60 |
--------------------------------------------------------------------------------
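Note: a quick check of `stat_fn`'s semantics after the fix above: with `index=None` all sublists are flattened into one array, while an integer `index` restricts the statistics to that sublist.

```python
from utils_log import stat_fn

mean, std, above_avg, max_v, min_v = stat_fn([[1.0, 2.0], [3.0, 4.0]])
print(mean, above_avg, max_v, min_v)  # 2.5 0.5 4.0 1.0
```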
/choubun_metrics.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 |
4 | from collections import Counter
5 | from rouge import Rouge
6 |
7 |
8 | def normalize_ja_answer(s):
9 | """Lower text and remove punctuation, extra whitespace."""
10 |
11 | def white_space_fix(text):
12 | return "".join(text.split())
13 |
14 | def remove_punc(text):
15 | cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
16 | all_punctuation = set(string.punctuation + cn_punctuation)
17 | return "".join(ch for ch in text if ch not in all_punctuation)
18 |
19 | def lower(text):
20 | return text.lower()
21 |
22 | return white_space_fix(remove_punc(lower(s)))
23 |
24 |
25 | def rouge_score(prediction, ground_truth, **kwargs):
26 | rouge = Rouge()
27 | try:
28 | scores = rouge.get_scores([prediction], [ground_truth], avg=True)
29 |     except Exception:  # e.g., Rouge raises ValueError on empty inputs
30 | return 0.0
31 | return scores["rouge-l"]["f"]
32 |
33 |
34 | def rouge_ja_score(prediction, ground_truth, tokenizer, **kwargs):
35 | prediction_tokens = [word.surface for word in tokenizer(prediction)]
36 | ground_truth_tokens = [word.surface for word in tokenizer(ground_truth)]
37 | prediction = " ".join(prediction_tokens)
38 | ground_truth = " ".join(ground_truth_tokens)
39 | score = rouge_score(prediction, ground_truth)
40 | return score
41 |
42 |
43 | def f1_score(prediction, ground_truth, **kwargs):
44 | common = Counter(prediction) & Counter(ground_truth)
45 | num_same = sum(common.values())
46 | if num_same == 0:
47 | return 0
48 | precision = 1.0 * num_same / len(prediction)
49 | recall = 1.0 * num_same / len(ground_truth)
50 | f1 = (2 * precision * recall) / (precision + recall)
51 | return f1
52 |
53 |
54 | def qa_f1_ja_score(prediction, ground_truth, tokenizer, **kwargs):
55 | prediction_tokens = [word.surface for word in tokenizer(prediction)]
56 | ground_truth_tokens = [word.surface for word in tokenizer(ground_truth)]
57 | prediction_tokens = [normalize_ja_answer(token) for token in prediction_tokens]
58 | ground_truth_tokens = [normalize_ja_answer(token) for token in ground_truth_tokens]
59 | prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
60 | ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
61 | return f1_score(prediction_tokens, ground_truth_tokens)
--------------------------------------------------------------------------------
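Note: a minimal usage sketch for the Japanese metrics above, assuming fugashi's `Tagger` (pinned in env_minimal.yaml, and requiring a MeCab dictionary such as unidic-lite) as the `tokenizer` argument; it yields morphemes exposing the `.surface` attribute these functions expect.

```python
from fugashi import Tagger

from choubun_metrics import qa_f1_ja_score, rouge_ja_score

tagger = Tagger()
f1 = qa_f1_ja_score("東京都です", "東京都", tokenizer=tagger)
rl = rouge_ja_score("東京都です", "東京都", tokenizer=tagger)
print(f1, rl)
```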
/cfgs/trainer/default.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 |
4 | trainer:
5 | _target_: memory_trainer.MemoryTrainer
6 | trainer_config: ${trainer_config}
7 | wandb_config: ${wandb_config}
8 | device: ${device}
9 | scratch: ${scratch}
10 |
11 | trainer_config:
12 | _target_: memory_trainer.TrainerConfig
13 | out_dir: ${out_dir}
14 | max_iters: ${max_iters}
15 | task_batch_size: ${task_batch_size}
16 | samples_batch_size: ${samples_batch_size}
17 | eval_samples_batch_size: ${eval_samples_batch_size}
18 | allow_distributed_eval: ${allow_distributed_eval}
19 |
20 | pop_accumulation_steps: ${pop_accumulation_steps}
21 |
22 | score_aggregation: ${score_aggregation}
23 | score_normalization_reference: ${score_normalization_reference}
24 |
25 | synchronized_buffers_aggregation: ${synchronized_buffers_aggregation}
26 | synchronized_buffers_freeze_after: ${synchronized_buffers_freeze_after}
27 |
28 | prefetch_task_tensors: ${prefetch_task_tensors}
29 | override_prefetched_tensors: ${override_prefetched_tensors}
30 |
31 | eval_interval: ${eval_interval}
32 | early_stop_patience: ${early_stop_patience}
33 | log_interval: ${log_interval}
34 | eval_iters: ${eval_iters}
35 | eval_only: ${eval_only}
36 | eval_candidate_samples: ${eval_candidate_samples}
37 | eval_candidate_temp: ${eval_candidate_temp}
38 | record_advanced_eval_stats: ${record_advanced_eval_stats}
39 | store_eval_results_locally: ${store_eval_results_locally}
40 | record_per_task_eval_stats: ${record_per_task_eval_stats}
41 |
42 | always_save_checkpoint: ${always_save_checkpoint}
43 | keep_past_epoch_checkpoints_every: ${keep_past_epoch_checkpoints_every}
44 | use_amp: ${use_amp}
45 | init_from: ${init_from}
46 | dtype: ${dtype}
47 |
48 | # main trainer cfgs
49 | max_iters: 10000
50 | task_batch_size:
51 | samples_batch_size: 16
52 | eval_samples_batch_size: 128
53 | allow_distributed_eval: false
54 |
55 | pop_accumulation_steps: 1
56 | score_aggregation: 'mean'
57 | score_normalization_reference:
58 |
59 | # use 'mean' or 'best' population member to aggregate sync. buffers
60 | synchronized_buffers_aggregation: 'mean'
61 | synchronized_buffers_freeze_after:
62 |
63 | # prefetching (only implemented for generation tasks)
64 | prefetch_task_tensors: false
65 | override_prefetched_tensors: false
66 |
67 | # logging
68 | eval_interval: 100
69 | early_stop_patience: -1
70 | log_interval: 50
71 | eval_iters: 1
72 | eval_only: false
73 | eval_candidate_samples:
74 | eval_candidate_temp:
75 |
76 | always_save_checkpoint: true
77 | keep_past_epoch_checkpoints_every:
78 | scratch: true
79 | record_advanced_eval_stats: true
80 | store_eval_results_locally: true
81 | record_per_task_eval_stats: false
82 |
83 | # precision
84 | use_amp: false
85 | init_from:
86 | dtype: 'bfloat16'
87 |
88 |
--------------------------------------------------------------------------------
/memory_policy/deep_embedding_shared.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | import copy
4 | import math
5 | import numbers
6 | import abc
7 | import numpy as np
8 | from dataclasses import dataclass
9 | from typing import Optional, Tuple, Union, Dict, Callable, List
10 |
11 | from omegaconf import OmegaConf, DictConfig
12 | import hydra
13 |
14 | import torch
15 | from torch import nn
16 | import torch.utils.checkpoint
17 |
18 |
19 | def convert_to_tensor(
20 | el: Union[List[float], np.ndarray, torch.Tensor],
21 | ) -> torch.Tensor:
22 | if isinstance(el, torch.Tensor):
23 | return el
24 | else:
25 | el = torch.tensor(el)
26 | return el
27 |
28 | def cos_sin_seq_embeddings(
29 | length,
30 | embed_dim,
31 | max_freq=10000,
32 |
33 |
34 | ):
35 | assert embed_dim % 2 == 0, 'Embedding dimension should be even '
36 | positions = np.arange(length)
37 | embed_dim_per_op = embed_dim // 2
38 |
39 |
40 | freq_coeff = np.arange(embed_dim_per_op, dtype=np.float64)/embed_dim_per_op
41 | freq_coeff = 1/(max_freq**freq_coeff)
42 |
43 | freq_values = np.expand_dims(positions, axis=-1)*freq_coeff
44 | embeddings = np.concatenate(
45 | [np.sin(freq_values), np.cos(freq_values)], axis=-1)
46 | return embeddings
47 |
48 | class Embedding(nn.Module, abc.ABC):
49 | def __init__(self, embed_dim):
50 | nn.Module.__init__(self=self,)
51 | self.embed_dim = embed_dim
52 | if embed_dim is not None:
53 | self.set_embed_dim(embed_dim=embed_dim)
54 |
55 | def set_embed_dim(self, embed_dim):
56 | self.embed_dim: int = embed_dim
57 | assert self.embed_dim is not None, 'make sure embed_dim is not None'
58 |
59 | @abc.abstractmethod
60 | def forward(self, x: torch.Tensor) -> torch.Tensor:
61 | raise NotImplementedError
62 |
63 |
64 | class PositionalEmbedding(Embedding):
65 | def __init__(self, max_position_id, embed_dim, max_freq):
66 | self.max_position_id = max_position_id
67 | self.max_freq = max_freq
68 | Embedding.__init__(self=self, embed_dim=embed_dim)
69 |
70 |
71 | def set_embed_dim(self, embed_dim):
72 | Embedding.set_embed_dim(self=self, embed_dim=embed_dim)
73 | embeddings = cos_sin_seq_embeddings(
74 |             # one row of sin/cos features per position id
75 |             # in [0, self.max_position_id]
76 | length=self.max_position_id + 1,
77 | embed_dim=embed_dim,
78 | max_freq=self.max_freq,
79 | ).astype('float32')
80 | self.register_buffer('embeddings', torch.tensor(embeddings))
81 |
82 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
83 |         # x holds integer position ids; index the precomputed sin/cos
84 |         # buffer rather than recomputing the embeddings
85 |         emb = self.embeddings[x]
86 |         return emb
87 |
--------------------------------------------------------------------------------
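Note: a usage sketch for `PositionalEmbedding` above (exported via memory_policy/__init__.py): position ids index a precomputed `[max_position_id + 1, embed_dim]` sin/cos buffer, so `embed_dim` must be even.

```python
import torch

from memory_policy import PositionalEmbedding

pos_emb = PositionalEmbedding(max_position_id=32, embed_dim=8, max_freq=50000)
features = pos_emb(torch.tensor([0, 1, 2, 31]))
print(features.shape)  # torch.Size([4, 8])
```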
/utils_hydra.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
2 |
3 | from torch.nn import functional as F
4 | import torch
5 | import transformers
6 | import numpy as np
7 |
8 |
9 | from hydra import compose, initialize
10 | from main import make_eval_model, make_task_sampler, wandb_init
11 | import omegaconf
12 | import hydra
13 | import time
14 |
15 |
16 | # File containing imports to be used in hydra configs/for instantiating hydra
17 | # objects outside main
18 |
19 | def initialize_cfg(
20 | config_path="cfgs",
21 | hydra_overrides: dict = {},
22 | log_yaml: bool = False,
23 |
24 | ):
25 | with initialize(version_base=None, config_path=config_path,
26 | job_name="test_app"):
27 | cfg = compose(config_name="config",
28 | overrides=hydra_overrides,
29 | )
30 | if log_yaml:
31 | print('Loading the following configurations:')
32 | print(omegaconf.OmegaConf.to_yaml(cfg))
33 | return cfg
34 |
35 |
36 | def load_run_cfgs_trainer(
37 | run_file_path_location: str,
38 | hydra_overrides_from_dict: dict = dict(
39 | wandb_log="false",
40 | wandb_project="scratch",
41 | ),
42 | task_sampler_kwargs: dict = dict(
43 | tasks=["lb/passage_retrieval_en"],
44 | ),
45 | config_path="cfgs",
46 | batch_size: int = 1,
47 | ):
48 | print(f"Loading model specified in {run_file_path_location}")
49 | start_time = time.time()
50 | hydra_overrides = [
51 | f"run@_global_={run_file_path_location}",
52 | ]
53 |
54 | hydra_overrides_from_dict = [f"{k}={v}" for k, v in
55 | hydra_overrides_from_dict.items()]
56 |
57 | hydra_overrides = hydra_overrides + hydra_overrides_from_dict
58 |
59 | cfg = initialize_cfg(
60 | config_path=config_path,
61 | hydra_overrides=hydra_overrides,
62 | )
63 |
64 | (memory_policy, memory_model, memory_evaluator, evolution_algorithm,
65 | auxiliary_loss) = make_eval_model(cfg=cfg)
66 |
67 | task_sampler = make_task_sampler(cfg=cfg, **task_sampler_kwargs)
68 |
69 | trainer = hydra.utils.instantiate(
70 | cfg.trainer,
71 | evaluation_model=memory_evaluator,
72 | task_sampler=task_sampler,
73 | evolution_algorithm=evolution_algorithm,
74 | auxiliary_loss=auxiliary_loss,
75 | )
76 | params, buffers = trainer.sample_and_synchronize_params(best=True)
77 |     # load the sampled (best) parameters into the wrapped memory model
78 | memory_model.set_memory_params(params=params)
79 |
80 | if memory_model.memory_policy_has_buffers_to_merge():
81 | memory_model.load_buffers_dict(buffers_dict=buffers)
82 |
83 | batch_idxs = np.zeros([batch_size])
84 | memory_policy.set_params_batch_idxs(batch_idxs)
85 | print("Time taken:", round(time.time() - start_time))
86 | return trainer # contains all other models
87 |
--------------------------------------------------------------------------------
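Note: a sketch of loading a trained run for interactive evaluation with `load_run_cfgs_trainer` above. The run name mirrors the README commands, and the checkpoint referenced by the run config's `init_from` is assumed to exist; by default wandb logging is disabled through the override dict.

```python
from utils_hydra import load_run_cfgs_trainer

trainer = load_run_cfgs_trainer(
    run_file_path_location="namm_bam_eval.yaml",
    task_sampler_kwargs=dict(tasks=["lb/passage_retrieval_en"]),
)
# the returned trainer bundles the evaluator, task sampler, and models
```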
/cfgs/policy/deep_scoring/attn.yaml:
--------------------------------------------------------------------------------
1 | scoring_network: ${generalized_scoring_net}
2 |
3 | generalized_scoring_net:
4 | _target_: memory_policy.GeneralizedScoring
5 | per_layer: ${per_layer}
6 | per_head: ${per_head}
7 | shared: ${scoring_shared}
8 | output_params: ${scoring_output_params}
9 | initializer: ${scoring_initializer}
10 | stateless_modules_list: ${stateless_modules_list}
11 | residual: ${residual}
12 | mult: ${mult}
13 | mult_nonlinearity: ${mult_nonlinearity}
14 |
15 | mlp_layer:
16 | _target_: stateless_parallel_modules.StatelessGeneralizedMLP
17 | input_features: ${scoring_attn_output_dim}
18 | hidden_features: ${scoring_mlp_hidden_features}
19 | hidden_depth: ${scoring_mlp_hidden_depth}
20 | output_features: ${scoring_mlp_output_features}
21 | bias: ${scoring_mlp_bias}
22 | non_linearity: ${scoring_mlp_non_linearity}
23 | residual: ${scoring_mlp_residual}
24 | residual_first: ${scoring_mlp_residual_first}
25 |
26 | scoring_mlp_hidden_features:
27 | scoring_mlp_hidden_depth: 1
28 | scoring_mlp_output_features: 1
29 | scoring_mlp_bias: true
30 | scoring_mlp_non_linearity: relu
31 | scoring_mlp_residual: true
32 | scoring_mlp_residual_first: false
33 |
34 |
35 | attention_layer:
36 | _target_: stateless_parallel_modules.StatelessAttention
37 | attention_params: ${attention_params}
38 |
39 | attention_params:
40 | _target_: stateless_parallel_modules.StatelessAttentionParams
41 | input_dim:
42 | hidden_dim: ${scoring_attn_hidden_dim}
43 | output_dim: ${scoring_attn_output_dim}
44 | num_heads: ${scoring_attn_num_heads}
45 | bias: ${scoring_attn_bias}
46 | max_position_id: ${max_position_id}
47 | use_rope: ${scoring_attn_use_rope}
48 | rope_theta: ${scoring_attn_rope_theta}
49 | masking_strategy: ${scoring_attn_masking_strategy}
50 |
51 | scoring_attn_hidden_dim: 32
52 | scoring_attn_output_dim:
53 | scoring_attn_num_heads: 4
54 | scoring_attn_bias: false
55 | scoring_attn_use_rope: false
56 | scoring_attn_rope_theta: 10000
57 | scoring_attn_masking_strategy:
58 |
59 | scoring_output_params:
60 | _target_: memory_policy.ComponentOutputParams
61 | requires_recomputation: ${scoring_requires_recomputation}
62 | reduction_mode: ${scoring_reduction_mode}
63 | ema_params: ${scoring_ema_params}
64 | output_past_non_reduced_history: ${scoring_output_past_non_reduced_history}
65 | max_non_reduced_history_len: ${scoring_max_non_reduced_history_len}
66 |
67 | scoring_ema_params:
68 | _target_: memory_policy.EMAParams
69 | coeff: ${scoring_ema_coeff}
70 | learned: ${scoring_reduction_learned}
71 | reduction_stride: ${hop_length}
72 |
73 | # Scoring params
74 | scoring_shared: true
75 | # Output params.
76 | scoring_requires_recomputation: true
77 | scoring_reduction_mode: ema
78 | # Ema params.
79 | scoring_ema_coeff: 0.975
80 | scoring_reduction_learned: false
81 | # --
82 | scoring_output_past_non_reduced_history: false
83 | scoring_max_non_reduced_history_len:
84 | # --
85 | scoring_initializer: 0
86 | stateless_modules_list:
87 | - ${attention_layer}
88 | - ${mlp_layer}
89 | residual: true
90 | mult: false
91 | mult_nonlinearity:
92 |
93 | # --
94 |
95 |
96 | scoring_mp_log_name: bam-no-mult
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | 
4 |
5 | # An Evolved Universal Transformer Memory
6 |
7 |
8 | 📄 [Paper] |
9 | 🤗 [Hugging Face] |
10 | 📁 [Dataset]
11 |
12 |
13 | ## Installation
14 |
15 | We provide two ways to set up this repository with [conda](https://docs.conda.io/projects/conda/en/latest/index.html):
16 |
17 | For the full set of dependencies with fixed versions (provided to ensure some level of long-term reproducibility):
18 |
19 | ```bash
20 | conda env create --file=env.yaml
21 | ```
22 |
23 | For a more minimal and less constrained set of dependencies (for future development/extensions):
24 |
25 | ```bash
26 | conda env create --file=env_minimal.yaml
27 | ```
28 |
29 | ## Usage
30 |
31 | ### Training
32 |
33 | Training with the incremental setup described in our work can be replicated via the following [hydra](https://hydra.cc/) commands:
34 |
35 | Stage 1 training:
36 | ```bash
37 | torchrun --standalone --nproc_per_node=$NUM_OF_GPUs main.py run@_global_=namm_bam_i1.yaml
38 | ```
39 |
40 | Stage 2 training:
41 | ```bash
42 | torchrun --standalone --nproc_per_node=$NUM_OF_GPUs main.py run@_global_=namm_bam_i2.yaml init_from='path/to/stage1/results/ckpt.pt'
43 | ```
44 |
45 | Stage 3 training:
46 | ```bash
47 | torchrun --standalone --nproc_per_node=$NUM_OF_GPUs main.py run@_global_=namm_bam_i3.yaml init_from='path/to/stage2/results/ckpt.pt'
48 | ```
49 |
50 | ### Evaluation
51 |
52 | Evaluation on the full set of LongBench tasks can be replicated for both trained NAMMs with the following command:
53 |
54 | ```bash
55 | torchrun --standalone --nproc_per_node=$NUM_OF_GPUs main.py run@_global_=namm_bam_eval.yaml init_from='path/to/results/ckpt.pt'
56 | ```
57 |
58 | Evaluation on the full set of ChouBun tasks can be replicated with the following command:
59 |
60 | ```bash
61 | torchrun --standalone --nproc_per_node=$NUM_OF_GPUs main.py run@_global_=namm_bam_eval_choubun.yaml init_from='path/to/results/ckpt.pt'
62 | ```
63 |
64 | ### Additional notes
65 |
66 | Using [wandb](https://wandb.ai/) to log the results (through the hydra setting `wandb_log=true`) requires authenticating to the wandb server via the following command:
67 |
68 | ```bash
69 | wandb login
70 | ```
71 |
72 | and entering your account's API key (which you can find [here](https://wandb.ai/authorize)).
73 |
74 | ### Gated models (e.g., Llama)
75 |
76 | Using gated models requires authenticating to the Hugging Face Hub by running:
77 |
78 | ```bash
79 | huggingface-cli login
80 | ```
81 |
82 | and entering one of your account's access tokens (which you can find [here](https://huggingface.co/settings/tokens)).
83 |
84 |
85 | ## Bibtex
86 |
87 | To cite our work, you can use the following:
88 |
89 | ```bibtex
90 | @article{sakana2024memory,
91 | title={An Evolved Universal Transformer Memory},
92 | author={Edoardo Cetin and Qi Sun and Tianyu Zhao and Yujin Tang},
93 | year={2024},
94 | eprint={2410.13166},
95 | archivePrefix={arXiv},
96 | primaryClass={cs.LG},
97 | url={https://arxiv.org/abs/2410.13166},
98 | }
99 | ```
100 |
101 |
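102 | ### Additional examples
103 | 
104 | Beyond choosing a run config, any individual setting defined under `cfgs` (e.g., `pop_size`, `max_iters`, `wandb_log`) can be overridden directly from the command line using hydra's `key=value` syntax. An illustrative sketch (the override values below are arbitrary):
105 | 
106 | ```bash
107 | torchrun --standalone --nproc_per_node=$NUM_OF_GPUs main.py run@_global_=namm_bam_i1.yaml pop_size=16 wandb_log=false
108 | ```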
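109 | 
110 | For non-interactive environments (e.g., scheduled batch jobs), both of the logins above can be performed without a prompt. A sketch assuming recent versions of wandb and huggingface_hub (fill in the placeholder values):
111 | 
112 | ```bash
113 | export WANDB_API_KEY=<your-wandb-api-key>        # read by wandb in place of `wandb login`
114 | huggingface-cli login --token <your-hf-token>    # non-interactive Hugging Face Hub login
115 | ```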
--------------------------------------------------------------------------------
/cfgs/run/base_memory_policy/deep/bam/base_bam.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /model/wrapped_llm@_global_: llama3-8b-rope-x4NTK
4 | - override /task@_global_: passage_retrieval_en
5 | - override /evolution@_global_: cma_es
6 | - override /policy@_global_: deep
7 | - override /policy/deep_embedding@_global_: attn_spec_norm
8 | - override /policy/deep_scoring@_global_: bam
9 | - override /policy/deep_selection@_global_: binary
10 | - override /typing@_global_: half_attn
11 | - override /auxiliary_loss@_global_: none
12 | - _self_
13 |
14 | run_name_suffix: ''
15 |
16 | per_head: false
17 | per_layer: false
18 |
19 | # ATTN params
20 | scoring_shared: true
21 | # Output params.
22 | scoring_requires_recomputation: true
23 | scoring_reduction_mode:
24 | # Ema params.
25 | scoring_ema_coeff: 0.99
26 | scoring_reduction_learned: false
27 | # --
28 | scoring_output_past_non_reduced_history: false
29 | scoring_max_non_reduced_history_len:
30 |
31 |
32 | # --
33 | # increasing the hidden dim increases the time cost only marginally
34 | scoring_attn_hidden_dim: 16
35 | scoring_attn_output_dim:
36 | scoring_attn_num_heads: 1
37 | scoring_attn_bias: false
38 | scoring_attn_use_rope: false
39 | scoring_attn_rope_theta: 50000
40 | scoring_attn_masking_strategy: backward
41 | # --
42 |
43 |
44 | # Spectrogram params.
45 | embedding_shared: true
46 | # Output params.
47 | embedding_requires_recomputation: true
48 | embedding_reduction_mode: ema
49 | # Ema params.
50 | embedding_ema_coeff: 0.99
51 | embedding_reduction_learned: false
52 | # --
53 | embedding_output_past_non_reduced_history: false
54 | embedding_max_non_reduced_history_len:
55 | # --
56 | # Stft params
57 | n_fft: 32
58 | hop_length: 16
59 | window_fn:
60 | _target_: utils_hydra.torch.hann_window
61 | window_length: ${n_fft}
62 | periodic: true
63 | window_fn_log_name: hann
64 | pad_mode: constant
65 | output_magnitudes: true
66 | # --
67 | scoring_initializer: 0
68 |
69 | scoring_mlp_bias: true
70 |
71 |
72 | cache_size:
73 |
74 |
75 | always_save_checkpoint: true
76 |
77 | wandb_log: true
78 |
79 | log_misc: false
80 | wandb_project: memory_evolution_hf
81 |
82 | wandb_run_name: ${tasks_log_name}-${evolution_algorithm_log_name}-shared-${pop_size}pop-${samples_batch_size}qs-${memory_policy_fixed_delay}fixDel-${run_name_suffix}
83 | wandb_group_name: ${llm_log_name}/${mp_log_name}
84 |
85 | add_bos_token: true # to match ref. scores setting
86 | score_normalization_reference: 'lb_reference_scores/per_request/llama3-8b-32k-x2NTK2alpha-v2.json'
87 |
88 | eval_max_batch_size: 32
89 | per_timestep_loglikelihood: false
90 |
91 | batch_size: 1 # NOTE: super low for the moment...
92 |
93 | max_iters: 1000
94 |
95 | # total samples per step = 64 x 3 (tasks) x 16 (pop) = 3072
96 | pop_size: 16
97 |
98 | # CMA-ES params
99 | elite_ratio: 0.5
100 | c_m: 1.0
101 | init_sigma: 0.065
102 |
103 | pop_accumulation_steps: 1
104 | score_aggregation: 'mean'
105 |
106 | # logging
107 | eval_samples_batch_size:
108 |
109 | eval_interval: 10
110 | early_stop_patience: -1
111 | log_interval: 1
112 | eval_iters: 1
113 |
114 | # precision
115 | use_amp: false
116 | dtype: bfloat16
117 |
118 | init_mean: 0.0
119 | init_stdev: 1.0
120 |
121 | allow_distributed_eval: true
122 | memory_policy_fixed_delay: 512
123 | max_new_tokens:
124 |
125 | synchronized_buffers_freeze_after: 1
126 |
127 | prefer_mean_to_best: true
128 | samples_batch_size: 64
129 |
130 | start_recency_from: 0
--------------------------------------------------------------------------------
/memory_llms/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Optional
3 | from memory_policy import MemoryPolicy
4 |
5 |
6 | class MemoryModelWrapper(abc.ABC):
7 | def __init__(self,
8 | config,
9 | memory_policy: MemoryPolicy,
10 | registration_kwargs: dict,
11 | # if set, recomputes memory policy every fixed number of steps
12 | memory_policy_fixed_delay: Optional[int] = None,
13 | output_attentions_full_precision: bool = True,
14 | max_new_tokens: Optional[int] = None,
15 | ):
16 | self.memory_policy: MemoryPolicy = memory_policy
17 | self.registration_kwargs = registration_kwargs
18 | self.config = config
19 | self.memory_policy.register_new_memory_model(
20 | self.config, self.registration_kwargs)
21 | self.memory_policy.finalize_registration()
22 | self.memory_policy_fixed_delay = memory_policy_fixed_delay
23 |
24 | if not hasattr(self, 'max_new_tokens'):
25 | self.max_new_tokens = max_new_tokens
26 |
27 | if (self.memory_policy_fixed_delay is not None) and (
28 | self.max_new_tokens is not None):
29 | assert self.memory_policy_fixed_delay % self.max_new_tokens == 0
30 | self.output_attentions_full_precision = output_attentions_full_precision
31 | self._past_length = 0
32 |
33 | def load_partial_state_dict(self, state_dict):
34 | current_state = self.state_dict()
35 | for name in self.base_model_param_keys:
36 | target_param = state_dict[name]
37 | current_state[name].copy_(target_param.data)
38 |
39 |
40 | def swap_memory_policy(self, new_memory_policy: MemoryPolicy):
41 | self.memory_policy = new_memory_policy
42 | self.memory_policy.register_new_memory_model(
43 | self.config, self.registration_kwargs)
44 | self.memory_policy.finalize_registration()
45 | self.memory_requires_attn = self.memory_policy.requires_attn_scores
46 | self.memory_requires_queries = self.memory_policy.requires_queries
47 |
48 | @property
49 | def cache_size(self,):
50 | return self.memory_policy.cache_size
51 |
52 | def set_memory_params(self, params) -> None:
53 | self.memory_policy.set_params(params=params)
54 |
55 | def get_memory_params(self,):
56 | return self.memory_policy.get_layer_params()
57 |
58 | def set_memory_params_batch_idxs(self, param_idxs) -> None:
59 | self.memory_policy.set_params_batch_idxs(param_idxs=param_idxs)
60 |
61 | def get_param_size(self,):
62 | return self.memory_policy.param_size
63 |
64 | def get_param_stats(self,):
65 | return self.memory_policy.get_param_stats()
66 |
67 | def get_buffers_list(self,):
68 | return self.memory_policy.get_buffers_list()
69 |
70 | def self_merge(self,):
71 | return self.memory_policy.self_merge()
72 |
73 | def merge_buffers_list(self, buffers_to_merge):
74 | return self.memory_policy.merge_buffers_list(
75 | buffers_to_merge=buffers_to_merge)
76 |
77 | def receive_buffers_list(self, buffers_list):
78 | return self.memory_policy.receive_buffers_list(
79 | buffers_list=buffers_list)
80 |
81 | # functions to save (e.g., normalization) buffers along with the checkpoints
82 | def get_buffers_dict(self,):
83 | return self.memory_policy.get_buffers_dict()
84 |
85 | def load_buffers_dict(self, buffers_dict):
86 | return self.memory_policy.load_buffers_dict(buffers_dict=buffers_dict)
87 |
88 | def memory_policy_has_buffers_to_merge(self,):
89 | return self.memory_policy._has_buffers_to_merge
90 |
91 |     # functions to signal the memory policy that it is being trained/evaluated
92 | def training_mode(self,):
93 | self.memory_policy.training_mode()
94 |
95 | def evaluation_mode(self,):
96 | self.memory_policy.evaluation_mode()
97 |
98 | def freeze_sync_buffers(self, freeze=True):
99 | self.memory_policy.freeze_sync_buffers(freeze=freeze)
100 |
101 | def unfreeze_sync_buffers(self,):
102 | self.memory_policy.unfreeze_sync_buffers()
104 | def are_sync_buffers_frozen(self,):
105 | return self.memory_policy.are_sync_buffers_frozen()
106 |
107 | class MemoryDecoderLayer(abc.ABC):
108 |
109 |     @abc.abstractmethod
110 | def __init__(self,):
111 | pass
112 |
113 | class MemoryAttention(abc.ABC):
114 |     @abc.abstractmethod
115 | def __init__(self,):
116 | pass
117 |
--------------------------------------------------------------------------------
/LongBench/config/dataset2prompt.json:
--------------------------------------------------------------------------------
1 | {
2 |     "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question as concisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story as concisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
3 | "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
4 | "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
5 | "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
6 | "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
7 | "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
8 | "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
9 | "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
10 | "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:",
11 | "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
12 | "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:",
13 | "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:",
14 | "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}",
15 | "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}",
16 | "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}",
17 | "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}",
18 | "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
19 | "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ",
20 | "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:",
21 | "lcc": "Please complete the code given below. \n{context}Next line of code:\n",
22 | "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n"
23 | }
--------------------------------------------------------------------------------
/longbench_metrics.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 |
4 | import jieba
5 | from fuzzywuzzy import fuzz
6 | import difflib
7 |
8 | from typing import List
9 | from collections import Counter
10 | from rouge import Rouge
11 |
12 | # copied from https://github.com/THUDM/LongBench/blob/main/metrics.py
13 |
14 |
15 | def normalize_answer(s):
16 | """Lower text and remove punctuation, articles and extra whitespace."""
17 |
18 | def remove_articles(text):
19 | return re.sub(r"\b(a|an|the)\b", " ", text)
20 |
21 | def white_space_fix(text):
22 | return " ".join(text.split())
23 |
24 | def remove_punc(text):
25 | exclude = set(string.punctuation)
26 | return "".join(ch for ch in text if ch not in exclude)
27 |
28 | def lower(text):
29 | return text.lower()
30 |
31 | return white_space_fix(remove_articles(remove_punc(lower(s))))
32 |
33 |
34 | def normalize_zh_answer(s):
35 | """Lower text and remove punctuation, extra whitespace."""
36 |
37 | def white_space_fix(text):
38 | return "".join(text.split())
39 |
40 | def remove_punc(text):
41 | cn_punctuation = (
42 | "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》"
43 | "「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏.")
44 | all_punctuation = set(string.punctuation + cn_punctuation)
45 | return "".join(ch for ch in text if ch not in all_punctuation)
46 |
47 | def lower(text):
48 | return text.lower()
49 |
50 | return white_space_fix(remove_punc(lower(s)))
51 |
52 |
53 | def count_score(prediction, ground_truth, **kwargs):
54 | numbers = re.findall(r"\d+", prediction)
55 | right_num = 0
56 | for number in numbers:
57 | if str(number) == str(ground_truth):
58 | right_num += 1
59 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
60 | return float(final_score)
61 |
62 |
63 | def retrieval_score(prediction, ground_truth, **kwargs):
64 | pattern = r'Paragraph (\d+)'
65 | matches = re.findall(pattern, ground_truth)
66 | ground_truth_id = matches[0]
67 | numbers = re.findall(r"\d+", prediction)
68 | right_num = 0
69 | for number in numbers:
70 | if str(number) == str(ground_truth_id):
71 | right_num += 1
72 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
73 | return float(final_score)
74 |
75 |
76 | def retrieval_zh_score(prediction, ground_truth, **kwargs):
77 | pattern = r'段落(\d+)'
78 | matches = re.findall(pattern, ground_truth)
79 | ground_truth_id = matches[0]
80 | numbers = re.findall(r"\d+", prediction)
81 | right_num = 0
82 | for number in numbers:
83 | if str(number) == str(ground_truth_id):
84 | right_num += 1
85 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
86 | return float(final_score)
87 |
88 |
89 | def code_sim_score(prediction, ground_truth, **kwargs):
90 | all_lines = prediction.lstrip('\n').split('\n')
91 | prediction = ""
92 | for line in all_lines:
93 | if ('`' not in line) and ('#' not in line) and ('//' not in line):
94 | prediction = line
95 | break
96 | return (fuzz.ratio(prediction, ground_truth) / 100)
97 |
98 |
99 | def classification_score(prediction, ground_truth, **kwargs):
100 | em_match_list = []
101 | all_classes = kwargs["all_classes"]
102 | for class_name in all_classes:
103 | if class_name in prediction:
104 | em_match_list.append(class_name)
105 | for match_term in em_match_list:
106 | if match_term in ground_truth and match_term != ground_truth:
107 | em_match_list.remove(match_term)
108 | if ground_truth in em_match_list:
109 | score = (1.0 / len(em_match_list))
110 | else:
111 | score = 0.0
112 | return score
113 |
114 |
115 | def rouge_score(prediction, ground_truth, **kwargs):
116 | rouge = Rouge()
117 | try:
118 | scores = rouge.get_scores([prediction], [ground_truth], avg=True)
119 |     except Exception:
120 | return 0.0
121 | return scores["rouge-l"]["f"]
122 |
123 |
124 | def rouge_zh_score(prediction, ground_truth, **kwargs):
125 | prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
126 | ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
127 | score = rouge_score(prediction, ground_truth)
128 | return score
129 |
130 |
131 | def f1_score(prediction, ground_truth, **kwargs):
132 | common = Counter(prediction) & Counter(ground_truth)
133 | num_same = sum(common.values())
134 | if num_same == 0:
135 | return 0
136 | precision = 1.0 * num_same / len(prediction)
137 | recall = 1.0 * num_same / len(ground_truth)
138 | f1 = (2 * precision * recall) / (precision + recall)
139 | return f1
140 |
141 |
142 | def qa_f1_score(prediction, ground_truth, **kwargs):
143 | normalized_prediction = normalize_answer(prediction)
144 | normalized_ground_truth = normalize_answer(ground_truth)
145 |
146 | prediction_tokens = normalized_prediction.split()
147 | ground_truth_tokens = normalized_ground_truth.split()
148 | return f1_score(prediction_tokens, ground_truth_tokens)
149 |
150 |
151 | def qa_f1_zh_score(prediction, ground_truth, **kwargs):
152 | prediction_tokens = list(jieba.cut(prediction, cut_all=False))
153 | ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
154 | prediction_tokens = [normalize_zh_answer(
155 | token) for token in prediction_tokens]
156 | ground_truth_tokens = [normalize_zh_answer(
157 | token) for token in ground_truth_tokens]
158 | prediction_tokens = [
159 | token for token in prediction_tokens if len(token) > 0]
160 | ground_truth_tokens = [
161 | token for token in ground_truth_tokens if len(token) > 0]
162 | return f1_score(prediction_tokens, ground_truth_tokens)
163 |
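164 | 
165 | if __name__ == '__main__':
166 |     # Illustrative sanity checks, not part of the original LongBench file:
167 |     # exact matches should score 1.0 under each metric used below.
168 |     assert qa_f1_score('The cat sat', 'a cat sat') == 1.0
169 |     assert count_score('I count 5 paragraphs', '5') == 1.0
170 |     assert retrieval_score('The answer is: Paragraph 7', 'Paragraph 7') == 1.0
171 |     print('metric sanity checks passed')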
--------------------------------------------------------------------------------
/memory_evolution/base.py:
--------------------------------------------------------------------------------
1 | import typing as tp
2 | from abc import ABC, abstractmethod
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 |
7 |
8 | class MemoryEvolution(nn.Module, ABC):
9 |     '''Base class for evolution strategies, exposing an ask/tell interface
10 |     for optimizing memory policy parameters.'''
11 |
12 | def __init__(self,
13 |                  pop_size,  # population size, i.e., number of parameter replicas
14 | param_size,
15 | score_processing: tp.Optional[str] = None,
16 | param_clip: tp.Optional[float] = None,
17 | clip_param_min: tp.Optional[float] = None,
18 | clip_param_max: tp.Optional[float] = None,
19 | prefer_mean_to_best: bool = False,
20 |
21 | ):
22 | nn.Module.__init__(self=self,)
23 | self.stored_buffers_to_save = nn.ParameterDict()
24 | self.best_stored_buffers_to_save = nn.ParameterDict()
25 | self.pop_size = pop_size
26 | self.param_size = param_size
27 | self.clip_param_min = clip_param_min
28 | self.clip_param_max = clip_param_max
29 |
30 | if clip_param_min is None and param_clip is not None:
31 | self.clip_param_min = -param_clip
32 | if clip_param_max is None and param_clip is not None:
33 | self.clip_param_max = param_clip
34 |
35 | if isinstance(score_processing, str):
36 | score_processing = score_processing.lower()
37 | assert score_processing in ['none', 'rank']
38 | if score_processing == 'rank':
39 | scores_values = torch.arange(pop_size)/(pop_size - 1) - 0.5
40 | self.register_buffer(
41 | name='scores_values',
42 | tensor=scores_values,
43 | persistent=False,
44 | )
45 |
46 | self.score_processing = score_processing
47 |
48 | self.prefer_mean_to_best = prefer_mean_to_best
49 |
50 | @abstractmethod
51 | def ask(self,) -> torch.Tensor:
52 | """Ask the algorithm for a population of parameters and update internal
53 | state."""
54 | raise NotImplementedError()
55 |
56 | @abstractmethod
57 | def tell(self, fitness) -> None:
58 | """Report the fitness of the population to the algorithm."""
59 | raise NotImplementedError()
60 |
61 | @property
62 | def best_params(self) -> torch.Tensor:
63 | raise NotImplementedError()
64 |
65 | @property
66 | def best_buffer(self) -> torch.Tensor:
67 | raise NotImplementedError()
68 |
69 | @abstractmethod
70 | def sample_candidates(
71 | self,
72 | num_candidates: int,
73 | temperature: float = 1.0,
74 | ) -> torch.Tensor:
75 | """Ask the algorithm for a population of parameters, temperature should
76 | optionally indicate how much the parameters should differ from the
77 | best/mean solution, depending on the specific algorithm."""
78 | raise NotImplementedError()
79 |
80 | def process_scores(self, fitness) -> torch.Tensor:
81 | """Preprocess the scores"""
82 | # fitness should be a pop_size-dim. vector
83 | if self.score_processing is None:
84 | return fitness
85 | elif self.score_processing == 'none':
86 | return fitness
87 | elif self.score_processing == 'rank':
88 | sorted_idxs = fitness.argsort(-1).argsort(-1) # from min to max
89 | # from -0.5 to 0.5
90 | scores = self.scores_values[sorted_idxs]
91 | return scores
92 | else:
93 | raise NotImplementedError
94 |
95 | def store_best_params(self, x, fitness=None):
96 | raise NotImplementedError
97 |
98 | def store_best_buffers(self, buffers):
99 | self.store_buffers(buffers=buffers, best=True)
100 |
101 | def forward(self,
102 | ):
103 |         '''Returns the current population of parameters (via ask).'''
104 | return self.ask()
105 |
106 | def get_stats(self,
107 | ):
108 | '''Returns statistics for logging'''
109 | return {}
110 |
111 | def load_init(self, init_param):
112 | '''Loads custom initialization values'''
113 | pass
114 |
115 | def store_buffers(
116 | self,
117 | buffers: tp.Dict[str, torch.Tensor] = None,
118 | best: bool = False,
119 | ):
120 | if best:
121 | self.best_stored_buffers_to_save.update(
122 | {n: nn.Parameter(v, requires_grad=False)
123 | for n, v in buffers.items()})
124 | else:
125 | self.stored_buffers_to_save.update(
126 | {n: nn.Parameter(v, requires_grad=False)
127 | for n, v in buffers.items()})
129 |
130 | def get_stored_buffers(self, best=False):
131 | if best and (not self.prefer_mean_to_best):
132 | tensor_dict = {k: v.data.clone() for k, v
133 | in self.best_stored_buffers_to_save.items()}
134 | else:
135 | tensor_dict = {k: v.data.clone() for k, v
136 | in self.stored_buffers_to_save.items()}
137 | return tensor_dict
138 |
139 |
140 | class DummyEvolution(MemoryEvolution):
141 |     '''Used when there are no parameters to optimize, i.e., param_size=0'''
142 |
143 | def __init__(self,
144 | pop_size,
145 | param_size,
146 | param_clip,
147 | score_processing):
148 | nn.Module.__init__(self,)
149 | assert param_size == 0
150 | self.pop_size = pop_size
151 | self.param_size = param_size
152 | self.param_clip = param_clip
153 | self.register_buffer(
154 | name='dummy_params', tensor=torch.zeros(
155 | [self.pop_size, self.param_size],), persistent=False,)
156 |
157 | def ask(self) -> torch.Tensor:
158 | return self.dummy_params
159 |
160 | def tell(self, fitness) -> None:
161 | """Report the fitness of the population to the algorithm."""
162 | pass
163 |
164 | @property
165 | def best_params(self) -> torch.Tensor:
166 | return self.dummy_params[0]
167 |
168 | def forward(self,
169 | ):
170 |         '''Returns the current population of parameters (via ask).'''
171 | return self.ask()
172 |
173 | def sample_candidates(
174 | self,
175 | num_candidates: int,
176 | temperature: float = 1.0,
177 | ) -> torch.Tensor:
178 | pass
179 |
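180 | 
181 | if __name__ == '__main__':
182 |     # Illustrative ask/tell round-trip with the zero-parameter DummyEvolution
183 |     # (real runs instantiate the evolution strategy through hydra instead).
184 |     evo = DummyEvolution(pop_size=4, param_size=0, param_clip=None,
185 |                          score_processing=None)
186 |     population = evo.ask()               # tensor of shape [pop_size, param_size]
187 |     evo.tell(torch.zeros(evo.pop_size))  # fitness reporting is a no-op here
188 |     print(population.shape)              # torch.Size([4, 0])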
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import numpy as np
5 | from torch.distributed import init_process_group, destroy_process_group
6 | from omegaconf import DictConfig, OmegaConf
7 | import hydra
8 | import copy
9 | import datetime
10 |
11 |
12 | def ddp_setup(
13 | backend: str = "nccl",
14 | ddp_timeout_limit: str = '0:6:0', # days:hours:minutes
15 | ):
16 |     days, hours, minutes = ddp_timeout_limit.split(':')
17 |     timeout_delta = datetime.timedelta(
18 |         days=int(days), hours=int(hours), minutes=int(minutes))
19 | init_process_group(backend=backend, timeout=timeout_delta)
20 | torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
21 |
22 |
23 | def get_dist_info():
24 | local_rank = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK', -1))
25 | global_rank = int(os.getenv('OMPI_COMM_WORLD_RANK', -1))
26 | world_size = int(os.getenv('OMPI_COMM_WORLD_SIZE', '-1'))
27 |
28 | local_rank = int(os.getenv(
29 | 'LOCAL_RANK', -1)) if local_rank == -1 else local_rank
30 | global_rank = int(os.getenv(
31 | 'RANK', -1)) if global_rank == -1 else global_rank
32 | world_size = int(os.getenv(
33 | 'WORLD_SIZE', -1)) if world_size == -1 else world_size
34 |
35 | if global_rank != -1:
36 | os.environ['RANK'] = str(global_rank) # Needed by torch distributed.
37 | os.environ['LOCAL_RANK'] = str(local_rank)
38 | os.environ['WORLD_SIZE'] = str(world_size)
39 |
40 | return local_rank, global_rank, world_size
41 |
42 |
43 | def wandb_init(cfg):
44 | import wandb
45 | config_dict = OmegaConf.to_container(
46 | # allow missing values for memory experiments
47 | cfg, resolve=True, throw_on_missing=False,
48 | )
49 | # wandb has a 128-size character limit on the group name
50 | wandb.init(
51 | project=cfg.wandb_config.wandb_project,
52 | group=cfg.wandb_config.wandb_group_name[:127],
53 | name=cfg.wandb_config.wandb_run_name[:127],
54 | config=config_dict,
55 | )
56 | return wandb
57 |
58 |
59 | def make_eval_model(cfg, log_prefix='...'):
60 |     print(log_prefix + 'Instantiating llm...')
61 | pretrained_llm = hydra.utils.call(cfg.pretrained_llm, _convert_="object")
62 | tokenizer = hydra.utils.call(cfg.tokenizer)
63 |
64 |     print(log_prefix + 'Instantiating memory policy...')
65 | memory_policy = hydra.utils.instantiate(cfg.memory_policy,
66 | _convert_="object")
67 |
68 |     print(log_prefix + 'Instantiating memory llm...')
69 | memory_model = hydra.utils.instantiate(
70 | cfg.memory_model, model=pretrained_llm, memory_policy=memory_policy,)
71 |
72 |     print(log_prefix + 'Instantiating evaluation module...')
73 | memory_evaluator = hydra.utils.instantiate(
74 | cfg.memory_evaluator, model=memory_model, tokenizer=tokenizer)
75 |
76 |     print(log_prefix + 'Instantiating evolution module...')
77 | evolution_algorithm = hydra.utils.instantiate(
78 | cfg.evolution_algorithm, param_size=memory_policy.param_size,
79 | _recursive_=False)
80 |
81 | init_param = memory_policy.get_init_param_values()
82 |
83 | evolution_algorithm.load_init(init_param=init_param)
84 |
85 | if cfg.auxiliary_loss is not None:
86 | auxiliary_loss = hydra.utils.instantiate(cfg.auxiliary_loss,
87 | memory_policy=memory_policy)
88 | else:
89 | auxiliary_loss = None
90 |
91 | print(log_prefix + 'Finished instantiations.')
92 | return (memory_policy, memory_model, memory_evaluator, evolution_algorithm,
93 | auxiliary_loss)
94 |
95 |
96 | def make_task_sampler(cfg, log_prefix='', **task_sampler_kwargs):
97 |     print(log_prefix + f'Instantiating tasks: {cfg.task_sampler.tasks}; with '
98 |           f'corresponding metrics: {cfg.task_sampler.metrics}')
99 | task_sampler = hydra.utils.instantiate(
100 | cfg.task_sampler, _convert_='none',
101 | **task_sampler_kwargs)
102 | return task_sampler
103 |
104 |
105 | def stochasticity_setup(cfg, seed_offset=0, log_prefix=''):
106 | print(log_prefix + f'Global rank used for seed offset {seed_offset}')
107 | np.random.seed(cfg.seed + seed_offset)
108 | torch.manual_seed(cfg.seed + seed_offset)
109 | random.seed(cfg.seed + seed_offset)
110 |
111 | # NOTE: likely can remove offset
112 | if cfg.deterministic_behavior:
113 | print('WARNING: training with deterministic behavior')
114 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
115 | torch.backends.cudnn.benchmark = False
116 | torch.backends.cudnn.deterministic = True
117 | torch.use_deterministic_algorithms(True)
118 |
119 |
120 | @hydra.main(version_base=None, config_path='cfgs', config_name='config')
121 | def main(cfg: DictConfig):
122 | _, global_rank, n_ddp = get_dist_info()
123 | is_ddp = global_rank > -1
124 | if is_ddp:
125 | ddp_setup(backend=cfg.backend, ddp_timeout_limit=cfg.ddp_timeout_limit)
126 | master_process = global_rank == 0
127 | seed_offset = global_rank
128 | else:
129 | master_process = True
130 | seed_offset = 0
131 |
132 | if master_process:
133 | print(f"SHARED Working directory: {os.getcwd()}")
134 | print(f"SHARED Output directory: " +
135 | f"{hydra.core.hydra_config.HydraConfig.get().runtime.output_dir}")
136 |
137 | log_prefix = ''
138 | if is_ddp:
139 | log_prefix = f'RANK {global_rank} ({n_ddp} total): '
140 |
141 | stochasticity_setup(cfg=cfg, seed_offset=seed_offset,
142 | log_prefix=log_prefix)
143 |
144 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
145 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
146 |
147 | with torch.no_grad():
148 | (memory_policy, memory_model, memory_evaluator, evolution_algorithm,
149 | auxiliary_loss) = make_eval_model(cfg=cfg, log_prefix=log_prefix)
150 |
151 | task_sampler = make_task_sampler(cfg=cfg, log_prefix=log_prefix)
152 |
153 | trainer = hydra.utils.instantiate(
154 | cfg.trainer,
155 | evaluation_model=memory_evaluator,
156 | task_sampler=task_sampler,
157 | evolution_algorithm=evolution_algorithm,
158 | auxiliary_loss=auxiliary_loss,
159 | )
160 |
161 | if cfg.wandb_config.wandb_log and master_process:
162 | wandb_init(cfg=cfg)
163 |
164 | with torch.no_grad():
165 | trainer.train()
166 |
167 | if is_ddp:
168 | destroy_process_group()
169 |
170 |
171 | if __name__ == "__main__":
172 | with torch.no_grad():
173 | main()
174 |
--------------------------------------------------------------------------------
/stateless_parallel_modules/mlp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | import copy
4 | import math
5 | import numpy as np
6 | from dataclasses import dataclass
7 | from typing import Optional, Tuple, Union, Dict, Callable
8 |
9 |
10 | import torch
11 | from torch import nn
12 | import torch.utils.checkpoint
13 |
14 | from .base import (
15 | StatelessGeneralizedOperation, GeneralizedLinear,
16 | StatelessGeneralizedModule)
17 |
18 | from utils import get_nonlinearity
19 |
20 |
21 | class StatelessGeneralizedMLP(StatelessGeneralizedModule):
22 | def __init__(
23 | self,
24 | input_features: Optional[int],
25 | hidden_features: Optional[int],
26 | output_features: Optional[int],
27 | # depth = 1, means a single linear operation
28 | hidden_depth: int,
29 | bias: bool,
30 | # base_parallel_operations: Optional[int] = None,
31 | non_linearity: Union[str, Callable] = 'relu',
32 | # add a residual connection for the hidden layers
33 | residual: bool = True,
34 | # add a residual connection from the first layer
35 | residual_first: bool = False,
36 | ):
37 |
38 | StatelessGeneralizedModule.__init__(
39 | self=self,
40 | input_features=input_features,
41 | output_features=output_features,
42 | init_module=True,
43 | )
44 |
45 | self.hidden_features = hidden_features
46 | self.hidden_depth = hidden_depth
47 | self.bias = bias
48 |
49 | self.non_linearity = get_nonlinearity(
50 | nonlinearity=non_linearity,
51 | )
52 |
53 | self.residual_first = residual_first
54 | self.residual = residual
55 |
56 | self.is_linear = hidden_depth == 1
57 | self.has_intermediate = hidden_depth > 2
58 |
59 | if (self.input_features is not None and
60 | self.output_features is not None and
61 | self.hidden_features is not None):
62 | print('instantiating ops')
63 | self.instantiate_and_setup_ops(
64 | input_features=input_features,
65 | hidden_features=hidden_features,
66 | output_features=output_features,
67 | preceding_module=None,
68 | default_output_features_mult=1,
69 | )
70 |
71 | def instantiate_and_setup_ops(
72 | self,
73 | input_features: Optional[int] = None,
74 | hidden_features: Optional[int] = None,
75 | output_features: Optional[int] = None,
76 | preceding_module=None,
77 | default_output_features_mult: int = 1,
78 | **kwargs,
79 | ):
80 |
81 | if (self.input_features is None or
82 | self.output_features is None or
83 | self.hidden_features is None):
84 |
85 | self.instantiate_model(
86 | input_features=input_features,
87 | output_features=output_features,
88 | preceding_module=preceding_module,
89 | default_output_features_mult=default_output_features_mult,
90 | )
91 | if self.hidden_features is None and hidden_features is not None:
92 | self.hidden_features = hidden_features
93 | elif self.hidden_features is None:
94 | print('Warning: hidden features not specified setting to ' +
95 | f'{self.output_features} (output features)')
96 | self.hidden_features = self.output_features
97 |
98 | if self.residual_first:
99 | # assert hidden_depth > 1
100 | assert self.input_features == self.hidden_features
101 |
102 | self.linear_op = GeneralizedLinear()
103 |
104 | in_dims = self.input_features
105 | self.input_output_features_tuple = []
106 | for _ in range(self.hidden_depth - 1):
107 | out_dims = self.hidden_features
108 | self.input_output_features_tuple.append(
109 | (in_dims, out_dims))
110 | in_dims = self.hidden_features
111 | out_dims = self.output_features
112 | self.input_output_features_tuple.append((in_dims, out_dims))
113 |
114 | operation_list = [
115 | self.linear_op for _ in range(self.hidden_depth)]
116 | operation_kwargs = dict(
117 | bias=self.bias,
118 | )
119 | operation_kwargs_overrides_list = [
120 | dict(in_features=in_dims, out_features=out_dims)
121 | for in_dims, out_dims in self.input_output_features_tuple]
122 |
123 | self.setup_operations(
124 | operations=operation_list,
125 | operation_kwargs=operation_kwargs,
126 | operation_kwargs_overrides_list=operation_kwargs_overrides_list,
127 | save_as_module_list=True,
128 | )
129 |
130 | def forward(
131 | self,
132 | # parallel_operations x batch_size x input_features
133 | inputs: torch.Tensor,
134 | *args,
135 | n_parallel_dimensions: Optional[int] = None,
136 | **kwargs,
137 | ):
138 |
139 | if n_parallel_dimensions is not None:
140 | input_shape = inputs.shape
141 | inputs = inputs.flatten(
142 | start_dim=0, end_dim=n_parallel_dimensions-1)
143 | # this is guaranteed to be 3-dimensional
144 | inputs = inputs.flatten(start_dim=1, end_dim=-2)
145 |
146 | weight, bias = self.parameters_per_layer[0]
148 | h = inputs
150 | h_out = self.linear_op(
151 | input=h,
152 | weight=weight,
153 | bias=bias,
154 | parallel_operations=self.parallel_operations,
155 | # initial input already flattened above
156 | n_parallel_dimensions=None,
157 | )
158 |
159 | if not self.is_linear:
160 | h_out = self.non_linearity(h_out)
161 | if self.residual_first:
162 | h = h + h_out
163 | else:
164 | h = h_out
165 | if self.has_intermediate:
166 | for weight, bias in self.parameters_per_layer[1:-1]:
167 | h_out = self.linear_op(
168 | input=h,
169 | weight=weight,
170 | bias=bias,
171 | parallel_operations=self.parallel_operations,
172 | # initial input already flattened above
173 | n_parallel_dimensions=None,
174 | )
175 |                 h_out = self.non_linearity(h_out)
176 | if self.residual:
177 | h = h + h_out
178 | else:
179 | h = h_out
180 | weight, bias = self.parameters_per_layer[-1]
181 | h_out = self.linear_op(
182 | input=h,
183 | weight=weight,
184 | bias=bias,
185 | parallel_operations=self.parallel_operations,
186 | # initial input already flattened above
187 | n_parallel_dimensions=None,
188 | )
189 | if n_parallel_dimensions is not None:
190 | h_out = h_out.view(*input_shape[:-1], self.output_features)
191 | return h_out
192 |
--------------------------------------------------------------------------------
/utils_longbench.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 |
6 | from choubun_metrics import (
7 | rouge_ja_score,
8 | qa_f1_ja_score
9 | )
10 | from longbench_metrics import (
11 | qa_f1_score,
12 | rouge_zh_score,
13 | qa_f1_zh_score,
14 | rouge_score,
15 | classification_score,
16 | retrieval_score,
17 | retrieval_zh_score,
18 | count_score,
19 | code_sim_score,
20 | )
21 |
22 | dataset2metric = {
23 | # LongBench
24 | "narrativeqa": qa_f1_score,
25 | "qasper": qa_f1_score,
26 | "multifieldqa_en": qa_f1_score,
27 | "multifieldqa_zh": qa_f1_zh_score,
28 | "hotpotqa": qa_f1_score,
29 | "2wikimqa": qa_f1_score,
30 | "musique": qa_f1_score,
31 | "dureader": rouge_zh_score,
32 | "gov_report": rouge_score,
33 | "qmsum": rouge_score,
34 | "multi_news": rouge_score,
35 | "vcsum": rouge_zh_score,
36 | "trec": classification_score,
37 | "triviaqa": qa_f1_score,
38 | "samsum": rouge_score,
39 | "lsht": classification_score,
40 | "passage_retrieval_en": retrieval_score,
41 | "passage_count": count_score,
42 | "passage_retrieval_zh": retrieval_zh_score,
43 | "lcc": code_sim_score,
44 | "repobench-p": code_sim_score,
45 | # ChouBun
46 | "wiki_qa": qa_f1_ja_score,
47 | "edinet_qa": qa_f1_ja_score,
48 | "corp_sec_qa": qa_f1_ja_score,
49 | "corp_sec_sum": rouge_ja_score
50 |
51 | }
52 |
53 | # Customized prompt building for chat models
54 |
55 |
56 | def build_chat(lm, prompt):
59 | tokenizer = lm.tokenizer
60 | model_name = lm.model_name
61 | if "chatglm3" in model_name:
62 | tokenized_prompt = tokenizer.build_chat_input(prompt)
63 | elif "chatglm" in model_name:
64 | prompt = tokenizer.build_prompt(prompt)
65 | elif "longchat" in model_name or "vicuna" in model_name:
66 | raise NotImplementedError
67 | elif "llama2" in model_name:
68 | prompt = f"[INST]{prompt}[/INST]"
69 | elif "xgen" in model_name:
70 | header = (
71 | "A chat between a curious human and an artificial intelligence "
72 | "assistant. The assistant gives helpful, detailed, and polite "
73 | "answers to the human's questions.\n\n"
74 | )
75 | prompt = header + f" ### Human: {prompt}\n###"
76 | elif "internlm" in model_name:
77 | prompt = f"<|User|>:{prompt}\n<|Bot|>:"
78 |     else:
80 |         raise NotImplementedError
81 |
82 | if not ("chatglm3" in model_name):
83 | tokenized_prompt = tokenizer(prompt, truncation=False,
84 | return_tensors="pt")
85 | return tokenized_prompt
86 |
87 |
88 | def parse_args(args=None):
89 | parser = argparse.ArgumentParser()
90 | parser.add_argument('--model', type=str, default=None)
91 | parser.add_argument('--e', action='store_true',
92 | help="Evaluate on LongBench-E")
93 | return parser.parse_args(args)
94 |
95 |
96 | def scorer_e(dataset, predictions, answers, lengths, all_classes):
97 | scores = {"0-4k": [], "4-8k": [], "8k+": []}
98 | for (prediction, ground_truths, length) in zip(
99 | predictions, answers, lengths):
100 | score = 0.
101 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
102 | prediction = prediction.lstrip('\n').split('\n')[0]
103 | for ground_truth in ground_truths:
104 | score = max(score, dataset2metric[dataset](
105 | prediction, ground_truth, all_classes=all_classes))
106 | if length < 4000:
107 | scores["0-4k"].append(score)
108 | elif length < 8000:
109 | scores["4-8k"].append(score)
110 | else:
111 | scores["8k+"].append(score)
112 | for key in scores.keys():
113 | scores[key] = round(100 * np.mean(scores[key]), 2)
114 | return scores
115 |
116 |
117 | def scorer(dataset, predictions, answers, all_classes):
118 | total_score = 0.
119 | for (prediction, ground_truths) in zip(predictions, answers):
120 | score = 0.
121 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
122 | prediction = prediction.lstrip('\n').split('\n')[0]
123 | for ground_truth in ground_truths:
124 | score = max(score, dataset2metric[dataset](
125 | prediction, ground_truth, all_classes=all_classes))
126 | total_score += score
127 | return round(100 * total_score / len(predictions), 2)
128 |
129 |
130 | def get_all_scores(task, predictions, answers, all_classes):
131 | all_scores = []
132 | for (prediction, ground_truths) in zip(predictions, answers):
133 | score = 0.
134 | if task in ["trec", "triviaqa", "samsum", "lsht"]:
135 | prediction = prediction.lstrip('\n').split('\n')[0]
136 | for ground_truth in ground_truths:
137 | score = max(score, dataset2metric[task](prediction, ground_truth,
138 | all_classes=all_classes))
139 | all_scores.append(score)
140 | return all_scores
141 |
142 |
143 | def get_score(task, predictions, answers, all_classes):
144 | total_score = 0.
145 | all_scores = []
146 |
147 | # Instantiate tokenizer for ChouBun tasks
148 | if task in ["wiki_qa", "edinet_qa", "corp_sec_qa", "corp_sec_sum"]:
149 | from fugashi import Tagger
150 | tokenizer = Tagger('-Owakati')
151 | else:
152 | tokenizer = None
153 |
154 | for (prediction, ground_truths) in zip(predictions, answers):
155 | score = 0.
156 | if task in ["trec", "triviaqa", "samsum", "lsht"]:
157 | prediction = prediction.lstrip('\n').split('\n')[0]
158 | for ground_truth in ground_truths:
159 | cur_score = dataset2metric[task](
160 | prediction,
161 | ground_truth,
162 | tokenizer=tokenizer,
163 | all_classes=all_classes
164 | )
165 | score = max(score, cur_score)
166 | all_scores.append(score)
167 | total_score += score
168 | mean_score = 100 * total_score / len(predictions)
169 | return mean_score, all_scores
170 |
171 |
172 | if __name__ == '__main__':
173 | args = parse_args()
174 | scores = dict()
175 | if args.e:
176 | path = f"pred_e/{args.model}/"
177 | else:
178 | path = f"pred/{args.model}/"
179 | all_files = os.listdir(path)
180 | print("Evaluating on:", all_files)
181 | for filename in all_files:
182 | if not filename.endswith("jsonl"):
183 | continue
184 | predictions, answers, lengths = [], [], []
185 | dataset = filename.split('.')[0]
186 | with open(f"{path}{filename}", "r", encoding="utf-8") as f:
187 | for line in f:
188 | data = json.loads(line)
189 | predictions.append(data["pred"])
190 | answers.append(data["answers"])
191 | all_classes = data["all_classes"]
192 | if "length" in data:
193 | lengths.append(data["length"])
194 | if args.e:
195 | score = scorer_e(dataset, predictions,
196 | answers, lengths, all_classes)
197 | else:
198 | score = scorer(dataset, predictions, answers, all_classes)
199 | scores[dataset] = score
200 | if args.e:
201 | out_path = f"pred_e/{args.model}/result.json"
202 | else:
203 | out_path = f"pred/{args.model}/result.json"
204 | with open(out_path, "w") as f:
205 | json.dump(scores, f, ensure_ascii=False, indent=4)
206 |
--------------------------------------------------------------------------------
/memory_policy/deep_scoring_bam.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | import copy
4 | import math
5 | import numbers
6 | import numpy as np
7 | from dataclasses import dataclass
8 | from typing import Optional, Tuple, Union, Dict, Callable, List
9 |
10 |
11 | import torch
12 | from torch import nn
13 | import torch.utils.checkpoint
14 | import torch.nn.functional as F
15 | from torch.cuda.amp import autocast
16 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
17 | from transformers import LlamaPreTrainedModel
18 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
19 | from .base import MemoryPolicy, ParamMemoryPolicy
20 | from .base_dynamic import (
21 | DynamicMemoryPolicy, DynamicParamMemoryPolicy,
22 | RecencyParams, AttentionParams, threshold_score_idxs
23 | )
24 | from .base_deep_components import (
25 | ScoringNetwork, TokenEmbedding, SelectionNetwork, wrap_torch_initializer,
26 | ComponentOutputParams)
27 | from stateless_parallel_modules import StatelessGeneralizedMLP
28 |
29 |
30 | def make_scaled_one_hot_init(
31 | idxs_to_scale: dict,
32 | idxs_to_ones: Union[List[int], np.ndarray, torch.Tensor]):
33 | def _init_fn(shape):
34 | tensor = torch.zeros(shape)
35 | for idx in idxs_to_ones:
36 | tensor[..., idx] = 1.0
37 | for idx, value in idxs_to_scale.items():
38 | tensor[..., int(idx)] = float(value)
39 |
40 | return tensor
41 | return _init_fn
42 |
43 | class MLPScoring(ScoringNetwork):
44 | '''MLP scoring layer, producing score as the NN output combination of the
45 | embeddings.'''
46 | def __init__(
47 | self,
48 | per_layer: bool,
49 | per_head: bool,
50 | shared: bool,
51 | output_params: ComponentOutputParams,
52 | hidden_features: Optional[int],
53 | hidden_depth: int,
54 | bias: bool,
55 | non_linearity: Optional[Union[str, Callable]] = 'relu',
56 | initializer: numbers.Number = 0,
57 | residual: bool = True,
58 | residual_first: bool = False,
59 |
60 | ):
61 | ScoringNetwork.__init__(
62 | self,
63 | per_layer=per_layer,
64 | per_head=per_head,
65 | shared=shared,
66 | output_params=output_params,
67 | buffer_names=['past_scores'],
68 | initializer=initializer,
69 | )
70 | self.hidden_features = hidden_features
71 | self.hidden_depth = hidden_depth
72 | self.bias = bias
73 | self.non_linearity = non_linearity
74 | self.residual = residual
75 | self.residual_first = residual_first
76 |
77 | def register_embedding(self, embedding_module: TokenEmbedding):
78 | ScoringNetwork.register_embedding(
79 | self=self,
80 | embedding_module=embedding_module,
81 | )
82 | if self.hidden_features is None:
83 |             self.hidden_features = self.input_embedding_dim
84 | self.mlp = StatelessGeneralizedMLP(
85 | input_features=self.input_embedding_dim,
86 | hidden_features=self.hidden_features,
87 | output_features=1,
88 | hidden_depth=self.hidden_depth,
89 | bias=self.bias,
90 | non_linearity=self.non_linearity,
91 | residual=self.residual,
92 | residual_first=self.residual_first,
93 | )
94 | self.mlp_base_parameters = self.mlp.total_base_parameter_dims
95 |
96 | def get_tokens_score(
97 | self,
98 | layer_id,
99 | parameters,
100 | token_embeddings: torch.Tensor,
101 | new_sequences,
102 | num_new_tokens,
103 | attn_weights=None,
104 | attn_mask=None,
105 | position_ids=None,
106 |
107 | **kwargs,
108 | ) -> torch.Tensor:
109 | '''Produces score for each KV cache token embedding'''
110 |
111 | if not self.requires_recomputation:
112 | token_embeddings = token_embeddings[..., -num_new_tokens:, :]
113 |
114 | if self.is_reduced_input:
115 | batch_size, n_heads, n_out_tokens, emb_dim = token_embeddings.shape
116 | else:
117 | batch_size, n_heads, non_reduced_outputs, n_out_tokens, emb_dim = (
118 | token_embeddings.shape)
119 | token_embeddings = token_embeddings.flatten(start_dim=2, end_dim=3)
120 | 
124 |         parallel_operations = batch_size
125 | 
127 |         if self.shared and self.per_head:
129 |             parallel_operations = parallel_operations*n_heads
130 |             token_embeddings = token_embeddings.flatten(start_dim=0, end_dim=1)
131 |         else:
133 |             token_embeddings = token_embeddings.flatten(start_dim=1, end_dim=2)
134 | 
137 | self.mlp.load_parameters(
138 | parameters=parameters,
139 | parallel_operations=parallel_operations,
140 | )
141 | scores = self.mlp(inputs=token_embeddings)
142 |
143 | if self.is_reduced_input:
144 | scores = scores.view(batch_size, n_heads, n_out_tokens)
145 | else:
146 | scores = scores.view(
147 | batch_size, n_heads, non_reduced_outputs, n_out_tokens)
148 |
149 | if not self.requires_recomputation:
150 | if not new_sequences:
151 |
152 | past_scores: torch.Tensor = self.past_scores[layer_id]
153 | scores = torch.concat([past_scores, scores], dim=-1)
154 | self.past_scores[layer_id] = scores
155 |
156 | if (self.reduction_mode is not None) and (not self.is_reduced_input):
157 |
158 | scores = self.process_output(
159 | layer_id=layer_id,
160 | ema_coeff=self.ema_coeff,
161 | num_new_tokens=num_new_tokens,
162 | new_sequences=new_sequences,
163 | component_output=scores,
164 | **kwargs,
165 | )
166 | return scores
167 |
168 | def filter_buffer_values(
169 | self,
170 | layer_id: int,
171 |
172 | retained_idxs: torch.Tensor,
173 | ):
174 | ScoringNetwork.filter_buffer_values(
175 | self=self,
176 | layer_id=layer_id,
177 | retained_idxs=retained_idxs,
178 | )
179 | if not self.requires_recomputation:
182 |             past_scores: torch.Tensor = self.past_scores[layer_id]
192 | self.past_scores[layer_id] = torch.gather(
193 | input=past_scores, dim=-1, index=retained_idxs)
194 |
195 | def net_param_size(self,) -> int:
196 | return self.mlp_base_parameters
197 |
198 |
199 | class LinearScoring(MLPScoring):
200 | '''Linear scoring layer, producing score as a linear combination of the
201 | embeddings.'''
202 | def __init__(
203 | self,
204 | per_layer: bool,
205 | per_head: bool,
206 | shared: bool,
207 | output_params: ComponentOutputParams,
208 | bias: bool,
209 |
210 | initializer: numbers.Number = 0,
211 | ):
212 | MLPScoring.__init__(
213 | self,
214 | per_layer=per_layer,
215 | per_head=per_head,
216 | shared=shared,
217 | output_params=output_params,
218 | hidden_features=0,
219 | hidden_depth=1,
220 | bias=bias,
221 | non_linearity=None,
222 | initializer=initializer,
223 | )
224 | self.bias = bias
225 |
226 |
--------------------------------------------------------------------------------
/memory_policy/deep_embedding.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | import copy
4 | import math
5 | import numbers
6 | import abc
7 | import numpy as np
8 | from dataclasses import dataclass
9 | from typing import Optional, Tuple, Union, Dict, Callable, List
10 |
11 | from omegaconf import OmegaConf, DictConfig
12 | import hydra
13 |
14 | import torch
15 | from torch import nn
16 | import torch.utils.checkpoint
17 | import torch.nn.functional as F
18 | from torch.cuda.amp import autocast
19 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
20 | from transformers import LlamaPreTrainedModel
21 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
22 | from .base import MemoryPolicy, ParamMemoryPolicy
23 | from .base_dynamic import compute_recency
24 |
25 | from .base_deep_components import (
26 | ScoringNetwork, TokenEmbedding, SelectionNetwork, wrap_torch_initializer,
27 | ComponentOutputParams,)
28 |
29 | from .deep_embedding_shared import Embedding, PositionalEmbedding
30 |
31 | from stateless_parallel_modules import StatelessGeneralizedMLP
32 |
33 | def convert_to_tensor(
34 | el: Union[List[float], np.ndarray, torch.Tensor],
35 | ) -> torch.Tensor:
36 | if isinstance(el, torch.Tensor):
37 | return el
38 | else:
39 | el = torch.tensor(el)
40 | return el
41 |
42 |
43 | class RecencyExponents(TokenEmbedding):
44 | '''Representing each KV, via a polynomial vector of its recency'''
45 | def __init__(
46 | self,
47 | per_layer: bool,
48 | per_head: bool,
49 | shared: bool,
50 | initial_exponents: Union[List[float], np.ndarray, torch.Tensor],
51 | dtype: Optional[Union[str, torch.dtype]] = None
52 | ):
53 | initial_exponents = convert_to_tensor(initial_exponents)
54 | assert len(initial_exponents.shape) == 1
55 | self._num_recency_exponents = initial_exponents.shape[-1]
56 | TokenEmbedding.__init__(
57 | self,
58 | per_layer=per_layer,
59 | per_head=per_head,
60 | shared=shared,
61 | output_params=ComponentOutputParams(
62 | requires_recomputation=False,),
63 | buffer_names=[],
64 | initializer=initial_exponents,
65 | dtype=dtype,
66 | )
67 | self.initial_recency_exponents = initial_exponents
68 |
69 | def get_tokens_embedding(
70 | self,
71 | layer_id,
72 | parameters,
73 | key_cache,
74 | value_cache,
75 | new_sequences,
76 | num_new_tokens,
77 | position_ids,
78 | attn_mask=None,
79 | **kwargs,
80 | ) -> torch.Tensor:
81 | '''Builds a tensor representation for each KV cache token'''
82 |
83 | exponents = parameters
84 |
85 |
86 | unsqueezed_exponents = exponents.unsqueeze(dim=-2)
87 |
88 |
89 | cache_recencies = compute_recency(position_ids=position_ids)
90 | unsqueezed_recencies = cache_recencies.unsqueeze(dim=-1)
91 |
92 | embeddings = torch.pow(
93 | input=unsqueezed_recencies, exponent=unsqueezed_exponents)
94 |
95 | if self._custom_dtype is not None:
96 | embeddings = embeddings.to(dtype=self.ptdtype)
97 |
98 |
99 | embeddings = self.process_output(
100 | layer_id=layer_id,
101 | ema_coeff=self.ema_coeff,
102 | num_new_tokens=num_new_tokens,
103 | new_sequences=new_sequences,
104 | component_output=embeddings,
105 | attn_mask=attn_mask,
106 | **kwargs,
107 | )
108 |
109 | return embeddings
110 |
111 | def get_embedding_dim(self,) -> int:
112 | return self._num_recency_exponents
113 |
114 | def net_param_size(self,) -> int:
115 | return self._num_recency_exponents
116 |
117 | def aux_param_size(self) -> int:
118 | return 0
119 |
120 | def get_net_params_stats(self, parameters: torch.Tensor):
121 | stats = dict()
122 | learned_exps = parameters.split(split_size=1, dim=-1)
123 | for i, learned_exp in enumerate(learned_exps):
124 | stats[f'net_params/rec_exp_{i}'] = learned_exp.mean().item()
125 | return stats
126 |
127 | @property
128 | def requires_position_ids(self,):
129 |
130 | return True
131 |
132 | @property
133 | def requires_recomputation(self,):
134 |
135 | return False
136 |
137 | @property
138 | def reduced_output(self,):
139 | return True
143 |
144 |
145 | class NormalizedRecencyExponents(TokenEmbedding):
146 | '''Representing each KV, via a polynomial vector of its normalized
147 | recency - a score ranging from 1 to 0 (most recent to oldest possible
148 | i.e., max_position_id)'''
149 | def __init__(
150 | self,
151 | per_layer: bool,
152 | per_head: bool,
153 | shared: bool,
154 | max_position_id: int,
155 | initial_exponents: Union[List[float], np.ndarray, torch.Tensor],
156 | only_positive_exponents: bool = True,
157 | dtype: Optional[Union[str, torch.dtype]] = None
158 | ):
159 | initial_exponents = convert_to_tensor(initial_exponents)
160 | assert len(initial_exponents.shape) == 1
161 | self._num_recency_exponents = initial_exponents.shape[-1]
162 | self.max_position_id = max_position_id
163 | self.only_positive_exponents = only_positive_exponents
164 |
165 | if only_positive_exponents:
166 | assert torch.all(initial_exponents > 0)
167 | log_initial_exponents = torch.log(
168 | initial_exponents)
169 | initializer = log_initial_exponents
170 | else:
171 | initializer = initial_exponents
172 |
173 | TokenEmbedding.__init__(
174 | self,
175 | per_layer=per_layer,
176 | per_head=per_head,
177 | shared=shared,
178 | output_params=ComponentOutputParams(
179 | requires_recomputation=False,),
180 | buffer_names=[],
181 | initializer=initializer,
182 | dtype=dtype,
183 | )
184 |
185 | def get_tokens_embedding(
186 | self,
187 | layer_id,
188 | parameters,
189 | key_cache,
190 | value_cache,
191 | new_sequences,
192 | num_new_tokens,
193 | position_ids,
194 | attn_mask=None,
195 | **kwargs,
196 | ) -> torch.Tensor:
197 | '''Builds a tensor representation for each KV cache token'''
198 |
199 | if self.only_positive_exponents:
200 | exponents = torch.exp(parameters)
201 | else:
202 | exponents = parameters
203 |
204 |
205 | unsqueezed_exponents = exponents.unsqueeze(dim=-2)
206 |
207 |
208 |
209 | cache_recencies = compute_recency(
210 | position_ids=position_ids)/self.max_position_id
211 |
212 |
213 |
214 | cache_recencies = 1 - cache_recencies
215 |
216 | unsqueezed_recencies = cache_recencies.unsqueeze(dim=-1)
217 |
218 | embeddings = torch.pow(
219 | input=unsqueezed_recencies, exponent=unsqueezed_exponents)
220 |
221 | if self._custom_dtype is not None:
222 | embeddings = embeddings.to(dtype=self.ptdtype)
223 |
224 | embeddings = self.process_output(
225 | layer_id=layer_id,
226 | ema_coeff=self.ema_coeff,
227 | num_new_tokens=num_new_tokens,
228 | new_sequences=new_sequences,
229 | component_output=embeddings,
230 | attn_mask=attn_mask,
231 | **kwargs,
232 | )
233 | return embeddings
234 |
235 | def get_embedding_dim(self,) -> int:
236 | return self._num_recency_exponents
237 |
238 | def net_param_size(self,) -> int:
239 | return self._num_recency_exponents
240 |
241 | def aux_param_size(self) -> int:
242 | return 0
243 |
244 | def get_net_params_stats(self, parameters: torch.Tensor):
245 | stats = dict()
246 | learned_exps = parameters.split(split_size=1, dim=-1)
247 | for i, learned_exp in enumerate(learned_exps):
248 | stats[f'net_params/rec_exp_{i}'] = learned_exp.mean().item()
249 | return stats
250 |
251 | @property
252 | def requires_position_ids(self,):
253 |
254 | return True
255 |
256 | @property
257 | def requires_recomputation(self,):
258 |
259 | return False
260 |
261 | @property
262 | def reduced_output(self,):
263 | return True
264 |
267 |
268 |
269 |
--------------------------------------------------------------------------------
/memory_policy/deep_selection.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | import copy
4 | import math
5 | import numpy as np
6 | from dataclasses import dataclass
7 | from typing import Optional, Tuple, Union, Dict
8 |
9 |
10 | import torch
11 | from torch import nn
12 | import torch.utils.checkpoint
13 | import torch.nn.functional as F
14 | from torch.cuda.amp import autocast
15 |
16 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
17 | from transformers import LlamaPreTrainedModel
18 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
19 | from .base import MemoryPolicy, ParamMemoryPolicy
20 | from .base_dynamic import (
21 | DynamicMemoryPolicy, DynamicParamMemoryPolicy,
22 | RecencyParams, AttentionParams, threshold_score_idxs
23 | )
24 | from .base_deep_components import SelectionNetwork, ComponentOutputParams
25 |
26 |
27 | class DynamicSelection(SelectionNetwork):
28 | '''Replicates the default selection criteria of dynamic memory policies'''
29 | def __init__(self,
30 | per_layer: bool,
31 | per_head: bool,
32 | shared: bool,
33 | cache_size: Optional[int],
34 | dynamic_thresh: float,
35 | ):
36 | SelectionNetwork.__init__(
37 | self,
38 | per_layer=per_layer,
39 | per_head=per_head,
40 | shared=shared,
41 | output_params=ComponentOutputParams(
42 | requires_recomputation=False,
43 | reduction_mode=None,
44 | ema_params=None,
45 | output_past_non_reduced_history=False,
46 | max_non_reduced_history_len=None,
47 | ),
48 | buffer_names=[],
49 | initializer=dynamic_thresh,
50 | )
51 | self.cache_size = cache_size
52 |
53 | def select_new_tokens(
54 | self,
55 | parameters,
56 | token_scores,
57 | new_sequences,
58 | num_new_tokens,
59 | attn_weights=None,
60 | attn_mask=None,
61 | position_ids=None,
62 | threshold_shift: float = 0.0,
63 | **kwargs,
64 | ) -> torch.Tensor:
65 | '''Produces indexes for the selected KV cache tokens and a selection
66 | mask.'''
67 | dynamic_thresh = parameters
68 | min_value = torch.finfo(token_scores.dtype).min
69 |         max_value = torch.finfo(token_scores.dtype).max
70 |
71 |
72 | masked_full_scores = torch.where(
73 | attn_mask.bool(), token_scores, min_value)
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
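        # pin the newest token's score to the dtype maximum so that the
        # thresholding below can never evict it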
83 | masked_full_scores[..., -1] = max_value
84 | retained_idxs, new_mask = threshold_score_idxs(
85 | masked_full_scores=masked_full_scores,
86 | dynamic_thresh=dynamic_thresh + threshold_shift,
87 | preserve_order=True,
88 | cache_size=self.cache_size,
89 | )
90 |
91 |
92 | return retained_idxs, new_mask
93 |
94 | def get_cache_size(self,) -> Optional[int]:
95 | return self.cache_size
96 |
97 | def net_param_size(self,) -> int:
98 | return 1
99 |
100 | def get_param_scaling(self,) -> Optional[Union[str, Tuple[float, float]]]:
101 |
102 |
103 | return 'exp'
104 |
105 | class BinarySelection(SelectionNetwork):
106 |     '''Maintains tokens whose scores are > 0 - can instead be
107 |     probabilistic, sampling based on a logistic distribution'''
108 |
109 | def __init__(self,
110 | per_layer: bool,
111 | per_head: bool,
112 | shared: bool,
113 | cache_size: Optional[int],
114 | is_probabilistic: bool = False,
115 | temp: float = 1.0,
116 | learned_temp: bool = False,
117 | ):
118 |
119 | if learned_temp:
120 | assert is_probabilistic, (
121 | 'If is not probabilistic, the temperature will not be used')
122 | self._needs_learned_temp = True
123 | else:
124 | self._needs_learned_temp = False
125 | self.is_probabilistic = is_probabilistic
126 | self.initial_temp = temp
127 | SelectionNetwork.__init__(
128 | self,
129 | per_layer=per_layer,
130 | per_head=per_head,
131 | shared=shared,
132 | output_params=ComponentOutputParams(
133 | requires_recomputation=False,
134 | reduction_mode=None,
135 | ema_params=None,
136 | output_past_non_reduced_history=False,
137 | max_non_reduced_history_len=None,
138 | ),
139 | buffer_names=[],
140 | initializer=temp,
141 | )
142 | self.cache_size = cache_size
143 |
144 |
145 | def select_new_tokens(
146 | self,
147 | parameters,
148 | token_scores,
149 | new_sequences,
150 | num_new_tokens,
151 | attn_weights=None,
152 | attn_mask=None,
153 | position_ids=None,
154 | threshold_shift: float = 0.0,
155 | **kwargs,
156 | ) -> torch.Tensor:
157 | '''Produces indexes for the selected KV cache tokens and a selection
158 | mask.'''
159 |
160 |
161 |
162 |
163 |
164 | min_value = torch.finfo(token_scores.dtype).min
165 | max_value = torch.finfo(token_scores.dtype).max
166 |
167 | if self.is_probabilistic:
168 | if self._needs_learned_temp:
169 | temp = parameters
170 | else:
171 | temp = self.initial_temp
172 |
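            # Bernoulli retention: keep token i with probability
            # sigmoid(score_i / temp); the sampled {0, 1} outcomes are
            # shifted to {-0.5, +0.5} so the threshold-at-zero logic
            # below applies unchanged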
173 |             probabilities = F.sigmoid(token_scores/temp)
174 | random_samples = torch.rand_like(probabilities)
175 |
176 |
177 | token_scores = (probabilities >= random_samples).to(
178 | probabilities.dtype) - 0.5
179 |
180 | masked_full_scores = torch.where(
181 | attn_mask.bool(), token_scores, min_value)
182 | masked_full_scores[..., -1] = max_value
183 | retained_idxs, new_mask = threshold_score_idxs(
184 | masked_full_scores=masked_full_scores,
185 | dynamic_thresh=threshold_shift,
186 | preserve_order=True,
187 | cache_size=self.cache_size,
188 | )
189 |
190 |
191 | return retained_idxs, new_mask
192 |
193 | def get_cache_size(self,) -> Optional[int]:
194 | return self.cache_size
195 |
196 | def net_param_size(self,) -> int:
197 | if self._needs_learned_temp:
198 | return 1
199 | else:
200 | return 0
201 |
202 | def get_param_scaling(self,) -> Optional[Union[str, Tuple[float, float]]]:
203 |
204 |
205 | return 'exp'
206 |
207 |
208 | class TopKSelection(SelectionNetwork):
209 |     '''Simply collects the top K scores with no thresholding'''
210 |     def __init__(self,
211 |                  cache_size: Optional[int],
212 |                  ):
217 | SelectionNetwork.__init__(
218 | self,
219 | per_layer=False,
220 | per_head=False,
221 | shared=True,
222 | output_params=ComponentOutputParams(
223 | requires_recomputation=False,
224 | reduction_mode=None,
225 | ema_params=None,
226 | output_past_non_reduced_history=False,
227 | max_non_reduced_history_len=None,
228 | ),
229 | buffer_names=[],
230 | )
231 | self.cache_size = cache_size
232 |
233 | def select_new_tokens(
234 | self,
235 | parameters,
236 | token_scores,
237 | new_sequences,
238 | num_new_tokens,
239 | attn_weights=None,
240 | attn_mask=None,
241 | position_ids=None,
242 | **kwargs,
243 | ) -> torch.Tensor:
244 | '''Produces indexes for the selected KV cache tokens and a selection
245 | mask.'''
246 |
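        # when the cache overflows, keep the K highest-scoring tokens;
        # masked-out positions are pushed to the dtype minimum so they are
        # never selected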
247 | num_samples = token_scores.shape[-1]
248 | if self.cache_size is not None and num_samples > self.cache_size:
249 | min_value = torch.finfo(token_scores.dtype).min
250 | masked_full_scores = torch.where(
251 | attn_mask.bool(), token_scores, min_value)
252 | _, retained_idxs = torch.topk(
253 | masked_full_scores, k=self.cache_size, sorted=False, dim=-1)
254 | retained_idxs, _ = retained_idxs.sort(descending=False, dim=-1,)
255 | else:
256 | retained_idxs = torch.arange(
257 | num_samples, device=token_scores.device,).view(
258 | 1, 1, num_samples).expand_as(token_scores)
259 | if self.cache_size is not None:
260 | attn_mask = attn_mask[..., -self.cache_size:]
261 | new_mask = torch.ones_like(retained_idxs, dtype=torch.bool)*attn_mask
262 | return retained_idxs, new_mask
263 |
264 | def get_cache_size(self,) -> Optional[int]:
265 | return self.cache_size
266 |
267 | def net_param_size(self,) -> int:
268 | return 0
269 |
270 |
271 |
272 |
273 |
274 |
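# A minimal usage sketch (hedged: shapes are illustrative assumptions, and
# SelectionNetwork's constructor is taken as used by the classes above).
# Selection networks consume per-token scores [batch, heads, num_tokens]
# plus an attention mask and return (retained_idxs, new_mask) for pruning.
if __name__ == '__main__':
    selector = TopKSelection(cache_size=4)
    scores = torch.randn(1, 2, 8)
    mask = torch.ones(1, 2, 8)
    idxs, new_mask = selector.select_new_tokens(
        parameters=None, token_scores=scores, new_sequences=True,
        num_new_tokens=8, attn_mask=mask)
    assert idxs.shape[-1] == 4  # only the top-4 scoring tokens survive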
--------------------------------------------------------------------------------
/memory_policy/shared.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union, Dict, List
2 |
3 | from collections import OrderedDict
4 | import abc
5 | import copy
6 | import torch
7 | from torch import nn
8 |
9 |
10 | class RegistrationCompatible(abc.ABC):
11 | def register_auxiliary_loss_callback(self, auxiliary_loss):
12 | self.auxiliary_loss = auxiliary_loss
13 | self.auxiliary_loss_callback = True
14 |         print(
15 |             'ERROR: Using an auxiliary loss with a component that has '
16 |             'no parameters')
16 | raise NotImplementedError
17 |
18 | def register_new_memory_layer(self, config, registration_kwargs):
19 |
20 |
21 |
22 |
23 | curr_layer_id = self.num_memory_layers
24 | self.num_memory_layers = self.num_memory_layers + 1
25 | self.config = config
26 | self.registration_kwargs = registration_kwargs
27 | return curr_layer_id
28 |
29 | def register_new_memory_model(self, config, registration_kwargs):
30 | assert 'num_memory_layers' in registration_kwargs
31 | self.num_memory_layers = registration_kwargs['num_memory_layers']
32 | self.num_heads = registration_kwargs['num_heads']
33 | self.num_key_value_groups = registration_kwargs.get(
34 | 'num_key_value_groups', None)
35 | self.config = config
36 | self.registration_kwargs = registration_kwargs
37 | self.hidden_size = registration_kwargs.get(
38 | 'hidden_size', None)
39 | if self.hidden_size is None:
40 | if hasattr(config, 'hidden_size'):
41 | self.hidden_size = config.hidden_size
42 | else:
43 | raise NotImplementedError
44 |
45 |
46 | class SynchronizableBufferStorage(abc.ABC):
47 | def __init__(
48 | self,
49 | buffers_to_merge: Union[List[str], dict] = [],
50 | sub_buffer_storages: list = [],
51 | ):
52 | self.initialize_buffer_dicts_to_merge(
53 | buffers_to_merge=buffers_to_merge,
54 | sub_buffer_storages=sub_buffer_storages)
55 | self.training_mode()
56 | self.unfreeze_sync_buffers()
57 |
58 | def are_sync_buffers_frozen(self,):
59 | return self._frozen_sync_buffers
60 |
61 | def freeze_sync_buffers(self, freeze=True):
62 | self._frozen_sync_buffers = freeze
63 | if self._has_sub_buffers_to_merge:
64 | for sub_buffer_n, sub_buffer in (
65 | self._sub_buffer_ordered_references.items()):
66 | sub_buffer.freeze_sync_buffers(freeze=freeze)
67 |
68 | def unfreeze_sync_buffers(self,):
69 | self.freeze_sync_buffers(freeze=False)
70 |
71 | def training_mode(self,):
72 | self._is_in_training_mode = True
73 | if self._has_sub_buffers_to_merge:
74 | for sub_buffer_n, sub_buffer in (
75 | self._sub_buffer_ordered_references.items()):
76 | sub_buffer.training_mode()
77 |
78 | def evaluation_mode(self,):
79 | self._is_in_training_mode = False
80 | if self._has_sub_buffers_to_merge:
81 | for sub_buffer_n, sub_buffer in (
82 | self._sub_buffer_ordered_references.items()):
83 | sub_buffer.evaluation_mode()
84 |
85 | def get_buffers_to_merge_keys(self,):
86 | return self._buffers_to_merge_keys
87 |
88 |
89 | def get_buffers_list(self,):
90 | buffers_dict = self.get_buffers_dict()
91 | assert len(buffers_dict) == len(self._buffers_to_merge_keys)
92 | return [buffers_dict[k] for k in self._buffers_to_merge_keys]
93 |
94 | def merge_buffers_list(
95 | self,
96 | buffers_to_merge: List[List[torch.Tensor]],
97 | ) -> List[torch.Tensor]:
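        '''Merges a flat list of buffer groups: the first entries belong to
        buffers owned by this storage, the rest are routed to sub-storages
        in registration order.'''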
98 | merged_buffers = []
99 | buffer_group_idx = 0
100 | if self._has_owned_buffers_to_merge:
101 | merged_buffers += self._merge_own_buffers(
102 | buffers_to_merge=buffers_to_merge[
103 | :self._num_total_owned_buffers_to_merge])
104 |             buffer_group_idx += self._num_total_owned_buffers_to_merge
105 | if self._has_sub_buffers_to_merge:
106 | for n_buffers, (sub_buffer_n, sub_buffer) in zip(
107 | self.num_buffers_per_sub_buffer,
108 | self._sub_buffer_ordered_references.items()):
109 |
110 | end_buffer_group_idx = buffer_group_idx + n_buffers
111 | rel_buffers_to_merge = buffers_to_merge[
112 | buffer_group_idx:end_buffer_group_idx]
113 | merged_buffers += sub_buffer.merge_buffers_list(
114 | buffers_to_merge=rel_buffers_to_merge)
115 | buffer_group_idx = end_buffer_group_idx
116 | assert len(merged_buffers) == len(buffers_to_merge)
117 | return merged_buffers
118 |
119 | def _merge_own_buffers(
120 | self,
121 | buffers_to_merge: List[List[torch.Tensor]],
122 | ) -> List[torch.Tensor]:
123 | raise NotImplementedError
124 |
125 | def receive_buffers_list(self, buffers_list):
126 | assert len(buffers_list) == len(self._buffers_to_merge_keys)
127 | buffers_dict = {k: v for k, v in zip(
128 | self._buffers_to_merge_keys, buffers_list)}
129 | self.load_buffers_dict(buffers_dict=buffers_dict)
130 |
131 | def get_buffers_dict(self,):
132 | buffers_dict = {}
133 | if self._has_owned_buffers_to_merge:
134 | buffers_dict.update(self.buffers_to_merge_dict)
135 | if self._has_sub_buffers_to_merge:
136 | buffers_dict.update(self.get_dict_from_sub_buffers())
137 | return buffers_dict
138 |
139 | def load_buffers_dict(self, buffers_dict):
140 | if len(buffers_dict) == 0:
141 | return
142 | else:
143 | assert set(buffers_dict.keys()) == set(
144 | self.get_buffers_to_merge_keys())
145 |
146 | for k in self.buffers_to_merge_dict:
147 | self.buffers_to_merge_dict[k] = buffers_dict[k]
148 | if self._has_sub_buffers_to_merge:
149 | self.load_dict_to_sub_buffers(buffers_dict=buffers_dict)
150 |
151 | def _self_merge_own_buffers(self,) -> List[torch.Tensor]:
152 | raise NotImplementedError
153 |
154 | def self_merge(self,) -> List[torch.Tensor]:
155 | merged_buffers = []
156 | if self._has_owned_buffers_to_merge:
157 | merged_buffers += self._self_merge_own_buffers()
158 | if self._has_sub_buffers_to_merge:
159 | for k, sub_buffer in self._sub_buffer_ordered_references.items():
160 | merged_buffers += sub_buffer.self_merge()
161 | return merged_buffers
162 |
163 | def initialize_buffer_dicts_to_merge(
164 | self, buffers_to_merge: Union[List[str], dict],
165 | sub_buffer_storages: list,
166 | reset: bool = True,
167 | ):
168 |
169 | assert reset
170 | if isinstance(buffers_to_merge, dict):
171 | self.buffers_to_merge_dict = OrderedDict(buffers_to_merge)
172 | else:
173 | self.buffers_to_merge_dict = OrderedDict(
174 | [(k, 0) for k in buffers_to_merge])
175 |
176 | self._buffers_to_merge_keys = list(self.buffers_to_merge_dict.keys())
177 | self._owned_buffers_to_merge_keys = copy.copy(
178 | self._buffers_to_merge_keys)
179 | self._num_total_owned_buffers_to_merge = len(
180 | self.buffers_to_merge_dict)
181 | self._num_total_buffers_to_merge = (
182 | self._num_total_owned_buffers_to_merge)
183 | self._has_buffers_to_merge = self._has_owned_buffers_to_merge = (
184 | self._num_total_buffers_to_merge > 0)
185 | self.register_sub_buffers_to_merge(
186 | sub_buffer_storages=sub_buffer_storages,)
187 |
188 | def register_sub_buffers_to_merge(self, sub_buffer_storages: list):
189 | self.num_buffers_per_sub_buffer = []
190 | self._buffers_to_merge_keys_from_sub_buffers = []
191 | self._sub_buffer_ordered_references: Dict[
192 | str, SynchronizableBufferStorage] = OrderedDict()
193 | for i, buffer in enumerate(sub_buffer_storages):
194 | if isinstance(buffer, str):
195 | assert hasattr(self, buffer)
196 | buffer_obj: SynchronizableBufferStorage = getattr(self, buffer)
197 | buffer_name = buffer
198 | else:
199 | buffer_obj: SynchronizableBufferStorage = buffer
200 | buffer_name = f'{i}'
201 |
202 | if buffer_obj._has_buffers_to_merge:
203 | sub_buffer_keys = buffer_obj._buffers_to_merge_keys
204 | self._buffers_to_merge_keys_from_sub_buffers += [
205 | buffer_name + '_' + k for k in sub_buffer_keys]
206 | self.num_buffers_per_sub_buffer.append(len(sub_buffer_keys))
207 | self._sub_buffer_ordered_references[buffer_name] = buffer_obj
208 |
209 | self._buffers_to_merge_keys += (
210 | self._buffers_to_merge_keys_from_sub_buffers)
211 |
212 | self._num_total_buffers_to_merge_from_sub_buffers = sum(
213 | self.num_buffers_per_sub_buffer)
214 | self._has_sub_buffers_to_merge = (
215 | self._num_total_buffers_to_merge_from_sub_buffers > 0)
216 | self._has_buffers_to_merge = (
217 | self._has_buffers_to_merge or self._has_sub_buffers_to_merge)
218 | self._num_total_buffers_to_merge += (
219 | self._num_total_buffers_to_merge_from_sub_buffers)
220 |
221 | def get_dict_from_sub_buffers(self,) -> Dict[str, torch.Tensor]:
222 | buffers_dict = OrderedDict()
223 | for buffer_name, buffer_obj in (
224 | self._sub_buffer_ordered_references.items()):
225 | for k, v in buffer_obj.get_buffers_dict().items():
226 | buffers_dict[buffer_name + '_' + k] = v
227 | return buffers_dict
228 |
229 | def load_dict_to_sub_buffers(self, buffers_dict) -> Dict[str, torch.Tensor]:
230 | for buffer_name, buffer_obj in (
231 | self._sub_buffer_ordered_references.items()):
232 |
233 | buffer_sub_dict = {}
234 | for k, v in buffers_dict.items():
235 | target_prefix = f'{buffer_name}_'
236 | if k.startswith(target_prefix):
237 | buffer_sub_dict[k.removeprefix(target_prefix)] = v
238 |
239 | buffer_obj.load_buffers_dict(buffer_sub_dict)
240 |
241 | return buffers_dict
242 |
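# Hedged sketch of the merge protocol: a concrete storage owning one buffer,
# where population copies are merged by averaging (the buffer name, shape,
# and mean-merge rule are illustrative assumptions, not part of the classes
# above).
class _MeanMergedStorage(SynchronizableBufferStorage):
    def __init__(self):
        SynchronizableBufferStorage.__init__(
            self,
            buffers_to_merge={'ema': torch.zeros(4)},
            sub_buffer_storages=[])

    def _merge_own_buffers(self, buffers_to_merge):
        # each entry groups one owned buffer's copies across the population
        return [torch.stack(group).mean(dim=0) for group in buffers_to_merge]


if __name__ == '__main__':
    storage = _MeanMergedStorage()
    merged = storage.merge_buffers_list([[torch.zeros(4), torch.ones(4)]])
    assert torch.allclose(merged[0], torch.full([4], 0.5))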
--------------------------------------------------------------------------------
/memory_policy/deep_embedding_spectogram.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | import copy
4 | import math
5 | import numbers
6 | import numpy as np
7 | from dataclasses import dataclass
8 | from typing import Optional, Tuple, Union, Dict, Callable, List
9 | import hydra
10 | from omegaconf import DictConfig
11 |
12 | import torch
13 | from torch import nn
14 | import torch.utils.checkpoint
15 | import torch.nn.functional as F
16 | from torch.cuda.amp import autocast
17 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
18 | from transformers import LlamaPreTrainedModel
19 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
20 | from .base import MemoryPolicy, ParamMemoryPolicy
21 | from .base_dynamic import (
22 | DynamicMemoryPolicy, DynamicParamMemoryPolicy,
23 | RecencyParams, AttentionParams, threshold_score_idxs
24 | )
25 | from .base_deep_components import (
26 | ScoringNetwork, TokenEmbedding, SelectionNetwork, wrap_torch_initializer,
27 | ComponentOutputParams, reduce_ema_values)
28 | from stateless_parallel_modules import StatelessGeneralizedMLP
29 |
30 |
31 | def fft_ema_mask(window_length, ema_coeff, hop_length):
32 | '''Creates a mask mimicking an exponential moving average across the fft
33 |     values. For consistent usage, ensure window_length == hop_length.'''
34 | discount_vector_exponents = torch.arange(
35 | start=window_length-1, end=-1, step=-1,)
36 |
37 | discount_vector = torch.pow(ema_coeff, discount_vector_exponents)
38 | rescale_factor = (1-ema_coeff)/(1-ema_coeff**hop_length)
39 | return discount_vector*rescale_factor
40 |
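# For example, window_length == hop_length == 3 and ema_coeff == 0.9 give
# [0.81, 0.9, 1.0] * (0.1 / (1 - 0.9**3)): a windowed dot product then
# weights frames like one step of an EMA with coefficient 0.9.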
41 |
42 | def fft_avg_mask(window_length):
43 | return torch.ones([window_length])/window_length
44 |
45 |
46 | @dataclass
47 | class STFTParams:
48 | n_fft: int
49 | hop_length: int
50 | window_fn: Optional[Union[
51 | Callable, np.ndarray, torch.Tensor, DictConfig]] = None
52 | pad_mode: str = 'constant'
53 | output_magnitudes: bool = True
54 |
55 |
56 | class AttentionSpectrogram(TokenEmbedding):
57 |     '''Represents each KV via the underlying frequencies of its column
58 |     in the attention matrix:
59 |     each token is attended to by all of its future tokens; given N new
60 |     tokens, we construct the spectrogram of all tokens from the attention
61 |     of the new tokens. We either store all per-token spectrogram frames,
62 |     prune old ones, or accumulate them with output_params (e.g., via EMA)'''
63 |
64 | def __init__(
65 | self,
66 | per_layer: bool,
67 | per_head: bool,
68 | shared: bool,
69 | output_params: ComponentOutputParams,
70 | stft_params: STFTParams,
71 | dtype: Optional[Union[str, torch.dtype]] = None,
72 | ):
73 |
74 | self.prev_attn_buffer: List[Optional[torch.Tensor]]
75 | assert output_params.requires_recomputation == True
76 | TokenEmbedding.__init__(
77 | self,
78 | per_layer=per_layer,
79 | per_head=per_head,
80 | shared=shared,
81 | output_params=output_params,
82 | buffer_names=['prev_attn_buffer'],
83 | dtype=dtype,
84 |
85 |
86 |
87 | )
88 | self.store_stft_params(stft_params=stft_params)
89 |
90 | def store_stft_params(self, stft_params: STFTParams):
91 | self.n_fft = stft_params.n_fft
92 | assert self.n_fft % 2 == 0
93 |
94 | self.window_length = self.n_fft
95 | self.stft_stride = stft_params.hop_length
96 | self.reduction_stride = self.stft_stride
97 | self.window_fn = stft_params.window_fn
98 |
99 | self.pad_attn_required = self.window_length - self.stft_stride
100 | self.base_n_fft = self.n_fft//2 + 1
101 | if self.window_fn is not None:
102 | if isinstance(self.window_fn, DictConfig):
103 | window = hydra.utils.call(self.window_fn, self.window_length)
104 | elif isinstance(self.window_fn, Callable):
105 | window = self.window_fn(self.window_length)
106 | else:
107 | window = self.window_fn
108 | assert window.shape[-1] == self.window_length
109 | self.register_buffer(
110 | 'stft_window', tensor=window, persistent=False)
111 | else:
112 | self.stft_window = None
113 | self.output_magnitudes = stft_params.output_magnitudes
114 | self.stft_pad_mode = stft_params.pad_mode
115 | assert self.stft_pad_mode == 'constant', (
116 | 'TODO: Deviating from constant padding might lead to inconsistent '
117 | 'results, needs to be tested')
118 |
119 | def get_tokens_embedding(
120 | self,
121 | layer_id,
122 | parameters,
123 | key_cache,
124 | value_cache,
125 | new_sequences,
126 | num_new_tokens,
127 | attn_weights,
128 | attn_mask=None,
129 | position_ids=None,
130 | analyze=False,
131 | **kwargs,
132 | ) -> torch.Tensor:
133 | '''Builds a tensor representation for each KV cache token'''
134 | parameters, aux_params = self.split_net_and_aux_params(
135 | parameters=parameters)
136 |
137 | device = key_cache.device
138 |
139 | batch_size, n_heads, num_all_tokens, emb_dim = key_cache.shape
140 | batch_size, n_heads, num_new_tokens, num_all_tokens = attn_weights.shape
141 |
142 | if attn_mask is not None:
143 |
144 | attn_mask = attn_mask[..., -num_all_tokens:].unsqueeze(-2) > 0
145 | else:
146 | attn_mask = torch.ones([batch_size, 1, num_all_tokens],
147 | device=device, dtype=torch.bool)
148 |
149 | num_new_embeddings = num_new_tokens // self.stft_stride
150 | stride_carry_over = num_new_tokens % self.stft_stride
151 | if stride_carry_over > 0:
152 |
153 | raise NotImplementedError
154 |
155 | attn_weights = attn_weights.transpose(dim0=-2, dim1=-1)
156 | if self.pad_attn_required > 0:
157 | if new_sequences:
158 |
159 | pad_tuple = [self.pad_attn_required, 0]
160 | rel_attn_weights = F.pad(
161 | input=attn_weights,
162 | pad=pad_tuple,
163 | mode=self.stft_pad_mode,
164 | )
165 | else:
166 |
167 | prev_attn_weights = self.prev_attn_buffer[layer_id]
168 | rel_prev_attn_weights = prev_attn_weights[
169 | ..., -self.pad_attn_required:]
170 | pad_tuple = [0, 0, 0, num_new_tokens]
171 |
172 | padded_rel_prev_attn_weights = F.pad(
173 | input=rel_prev_attn_weights,
174 | pad=pad_tuple,
175 | mode='constant',
176 | )
177 |
178 | rel_attn_weights = torch.concat(
179 | [padded_rel_prev_attn_weights, attn_weights],
180 | dim=-1,
181 | )
182 |
183 | if not analyze:
184 | self.prev_attn_buffer[layer_id] = attn_weights
185 | else:
186 | rel_attn_weights = attn_weights
187 |
188 | flat_rel_attn_weights = rel_attn_weights.flatten(
189 | start_dim=0, end_dim=-2)
190 |
191 | flat_attn_stft = torch.stft(
192 | input=flat_rel_attn_weights,
193 | n_fft=self.n_fft,
194 | hop_length=self.stft_stride,
195 | center=False,
196 | pad_mode='constant',
197 | normalized=False,
198 | onesided=True,
199 | return_complex=True,
200 | window=self.stft_window,
201 | )
202 |
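        # flat_attn_stft: [batch*heads*num_all_tokens, n_fft//2 + 1, frames];
        # restore the leading dims, then move the frame dim after the heads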
203 | attn_stft = flat_attn_stft.view(
204 | batch_size, n_heads, num_all_tokens,
205 | self.base_n_fft, num_new_embeddings)
206 |
207 | attn_stft = attn_stft.permute(dims=[0, 1, 4, 2, 3])
208 |
209 | if self.output_magnitudes:
210 | attn_stft = attn_stft.abs()
211 | else:
212 |
213 | attn_stft = torch.view_as_real(attn_stft).flatten(
214 | start_dim=-2, end_dim=-1)
215 |
216 | if self._custom_dtype is not None:
217 | attn_stft = attn_stft.to(dtype=self.ptdtype)
218 |
219 | if self.output_past_non_reduced_history:
220 |
221 | past_attn_spectr = self.past_outputs_buffer[layer_id]
222 |
223 |             # pad the stored history along the token dim so it matches
224 |             # the new frames (which also cover the num_new_tokens tokens)
225 |             pad_tuple = [0, 0, 0, num_new_tokens]
226 |             past_attn_spectr = F.pad(
227 |                 input=past_attn_spectr,
228 |                 pad=pad_tuple,
229 |                 mode='constant',
230 |             )
229 | attn_stft = torch.concat([past_attn_spectr, attn_stft], dim=-3)
230 | if self.limit_past_history_size:
231 | attn_stft = attn_stft[
232 | ..., -self.max_non_reduced_history_len:, :, :]
233 | embeddings = attn_stft
234 |
235 | if not analyze:
236 | self.past_outputs_buffer[layer_id] = attn_stft
237 |
238 | else:
239 | embeddings = attn_stft
240 |
241 | embeddings = self.process_output(
242 | layer_id=layer_id,
243 | ema_coeff=self.ema_coeff,
244 | num_new_tokens=num_new_tokens,
245 | new_sequences=new_sequences,
246 | component_output=embeddings,
247 | aux_params=aux_params,
248 | attn_mask=attn_mask,
249 | analyze=analyze,
250 | **kwargs,
251 | )
252 |
253 | return embeddings
254 |
255 | def get_embedding_dim(self,) -> int:
256 |
257 | if self.output_magnitudes:
258 | return self.base_n_fft
259 | else:
260 | return self.base_n_fft*2
261 |
262 | def net_param_size(self,) -> int:
263 | return 0
264 |
265 | @property
266 | def requires_attn_scores(self,):
267 |
268 | return True
269 |
270 | @property
271 | def requires_recomputation(self,):
272 |
273 | return True
274 |
275 | def filter_buffer_values(
276 | self,
277 | layer_id: int,
278 |
279 | retained_idxs: torch.Tensor):
280 | if self.pad_attn_required > 0:
281 | prev_attn_buffer: torch.Tensor = self.prev_attn_buffer[layer_id]
282 | prev_new_tokens = prev_attn_buffer.shape[-1]
283 | expanded_idxs = retained_idxs.unsqueeze(-1).expand(
284 | -1, -1, -1, prev_new_tokens)
285 | prev_attn_buffer = torch.gather(
286 | input=prev_attn_buffer,
287 | dim=-2,
288 | index=expanded_idxs,
289 | )
290 | self.prev_attn_buffer[layer_id] = prev_attn_buffer
291 | TokenEmbedding.filter_buffer_values(
292 | self=self,
293 | layer_id=layer_id,
294 | retained_idxs=retained_idxs)
295 |
--------------------------------------------------------------------------------
/memory_policy/auxiliary_losses.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | import copy
4 | import math
5 | import numpy as np
6 | from dataclasses import dataclass
7 | from typing import Optional, Tuple, Union, List
8 |
9 | import abc
10 | import torch
11 | from torch import nn
12 | import torch.utils.checkpoint
13 | import torch.nn.functional as F
14 | from torch.cuda.amp import autocast
15 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
16 | from transformers import LlamaPreTrainedModel
17 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
18 | from .base import MemoryPolicy, ParamMemoryPolicy
19 | from .base_dynamic import DynamicMemoryPolicy, DynamicParamMemoryPolicy
20 |
21 | from omegaconf import OmegaConf, DictConfig
22 | import hydra
23 | import numpy as np
24 |
25 |
26 | class MemoryPolicyAuxiliaryLoss(nn.Module, abc.ABC):
27 |
28 | def __init__(self,
29 | memory_policy: ParamMemoryPolicy,
30 | coeff: float = 1.0,
31 | adaptive_target: Optional[float] = None,
32 | adaptive_target_param: str = 'exp',
33 | optimizer=None,
34 | adaptive_update_every: int = 1,
35 | device: str = 'cuda',
36 | ):
37 | nn.Module.__init__(self,)
38 | self.device = device
39 | assert isinstance(memory_policy, ParamMemoryPolicy)
40 | self.pop_size = memory_policy.param_pop_size
41 |
42 | self.num_memory_layers = memory_policy.num_memory_layers
43 |
44 | self.step_losses = torch.zeros([self.pop_size], dtype=torch.long,
45 | device=device)
46 |
47 | self.is_adaptive = adaptive_target is not None
48 | self.adaptive_target = adaptive_target
49 |
50 | assert adaptive_target_param in ['exp', 'linear']
51 | self.is_exp = adaptive_target_param == 'exp'
52 | self.adaptive_target_param = adaptive_target_param
53 | if self.is_adaptive:
54 | assert optimizer is not None
55 | if self.is_exp:
56 | assert coeff > 0
57 | init_coeff = np.log(coeff)
58 | else:
59 | init_coeff = coeff
60 | self.raw_coeff = nn.Parameter(
61 | data=torch.zeros([1]) + float(init_coeff), requires_grad=False)
62 | else:
63 | self.register_buffer('raw_coeff', tensor=torch.tensor(coeff),
64 | persistent=False)
65 | self.adaptive_update_every = adaptive_update_every
66 | self.num_adaptive_updates = 0
67 | self.stored_losses = []
68 |
69 | self.num_samples_per_pop = torch.zeros(
70 | [self.pop_size], dtype=torch.long, device=device)
71 | self.total_losses_per_pop = torch.zeros(
72 | [self.pop_size], dtype=torch.long, device=device)
73 | memory_policy.register_auxiliary_loss_callback(auxiliary_loss=self)
74 |
75 | def restart_recording(self,):
76 |
77 | self.num_samples_per_pop.data.copy_(torch.zeros_like(
78 | self.num_samples_per_pop))
79 | self.total_losses_per_pop.data.copy_(torch.zeros_like(
80 | self.total_losses_per_pop))
81 |
82 | @abc.abstractmethod
83 | def memory_policy_layer_callback(
84 | self,
85 | layer_id,
86 | pop_idxs,
87 | new_sequences,
88 | key_cache,
89 | value_cache,
90 | dynamic_mask=None):
91 | raise NotImplementedError
92 |
93 | @abc.abstractmethod
94 | def memory_policy_update_callback(
95 | self,
97 | pop_idxs,
98 | new_sequences,
99 | new_kv_cache,):
100 | raise NotImplementedError
101 |
102 | @property
103 | def coeff(self,):
104 | if self.is_adaptive and self.is_exp:
105 | return torch.exp(self.raw_coeff)
106 | else:
107 | return self.raw_coeff
108 |
109 | @abc.abstractmethod
110 | def get_loss(self,):
111 | raise NotImplementedError
112 |
113 | def optim_params(self, loss):
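        # dual-ascent style update on the loss coefficient: when the measured
        # loss exceeds adaptive_target the gradient is negative, so the
        # optimizer step raises the (log-)coefficient, and vice versa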
114 | self.optimizer.zero_grad(set_to_none=True)
115 | dual_difference = self.adaptive_target - loss
116 | self.raw_coeff.grad = dual_difference
117 | self.optimizer.step()
118 | return dual_difference
119 |
120 | def setup_optimizer(self, optimizer):
121 | assert optimizer is not None
122 |         if isinstance(optimizer, DictConfig):
123 |
124 | self.optimizer: torch.optim.Optimizer = hydra.utils.instantiate(
125 | optimizer, params=[self.raw_coeff], _convert_='all')
126 | else:
127 | self.optimizer: torch.optim.Optimizer = optimizer(
128 | params=[self.raw_coeff])
129 |
130 | def forward(self,):
131 | loss = self.get_loss()
132 |
133 | if self.is_adaptive:
134 | if self.num_adaptive_updates % self.adaptive_update_every == 0:
135 | _ = self.optim_params(loss=loss)
136 | self.num_adaptive_updates += 1
137 | scaled_loss = self.coeff*loss
138 |
139 | return scaled_loss
140 |
141 |
142 | class JointAuxiliaryLosses(MemoryPolicyAuxiliaryLoss):
143 |
144 | def __init__(self, memory_policy: ParamMemoryPolicy,
145 | auxiliary_losses: List[MemoryPolicyAuxiliaryLoss]):
146 | MemoryPolicyAuxiliaryLoss.__init__(self, memory_policy=memory_policy)
147 | assert len(auxiliary_losses) > 0
148 | self.auxiliary_losses = nn.ModuleList(auxiliary_losses)
149 | self.adaptive_losses = [
150 | loss for loss in auxiliary_losses if loss.is_adaptive]
151 | self.is_adaptive = len(self.adaptive_losses) > 0
152 |
153 | def get_loss(self,):
154 | loss = 0
155 | for aux_loss in self.auxiliary_losses:
156 | loss = loss + aux_loss.get_loss()
157 | return loss
158 |
159 |
160 | class SparsityAuxiliaryLoss(MemoryPolicyAuxiliaryLoss):
161 |
162 | def __init__(self,
163 | memory_policy: DynamicParamMemoryPolicy,
164 | coeff: float = 1.0,
165 | adaptive_target: Optional[float] = None,
166 | adaptive_target_param: str = 'exp',
167 | optimizer=None,
168 | adaptive_update_every: int = 1,
169 | sparsity_mode: str = 'mean',
170 | sparsity_per_head: bool = False,
171 | device: str = 'cuda',
172 | ):
173 | MemoryPolicyAuxiliaryLoss.__init__(
174 | self,
175 | memory_policy=memory_policy,
176 | coeff=coeff,
177 | adaptive_target=adaptive_target,
178 | adaptive_target_param=adaptive_target_param,
179 | optimizer=optimizer,
180 | adaptive_update_every=adaptive_update_every,
181 | device=device,
182 | )
183 | self.sparsity_mode = sparsity_mode
184 | assert sparsity_mode in ['mean',]
185 | self.sparsity_per_head = sparsity_per_head
186 |
187 | self.losses_per_pop_per_layer = torch.zeros(
188 | [self.num_memory_layers, self.pop_size], dtype=torch.long,
189 | device=device)
190 |
191 | def memory_policy_layer_callback(
192 | self,
193 | layer_id,
194 | pop_idxs,
195 | new_sequences,
196 | key_cache,
197 | value_cache,
198 | dynamic_mask=None,
199 | scoring_network_params=None):
200 |
201 | if dynamic_mask is not None:
202 | unmasked_samples_per_head: torch.Tensor = dynamic_mask.to(
203 | dtype=self.losses_per_pop_per_layer.dtype).sum(-1)
204 | if self.sparsity_per_head:
205 |
206 |                 layer_loss = (unmasked_samples_per_head.sum(-1) /
207 |                               unmasked_samples_per_head.shape[-1]).to(
208 |                                   dtype=torch.long)
209 | else:
210 | layer_loss = unmasked_samples_per_head.max(-1)[0]
211 | else:
212 | layer_loss = key_cache.size(-2)*torch.ones(
213 | [key_cache.size(0)],
214 | device=self.losses_per_pop_per_layer.device,
215 | dtype=self.losses_per_pop_per_layer.dtype,
216 | )
217 | self.losses_per_pop_per_layer[layer_id].data.copy_(torch.zeros_like(
218 | self.losses_per_pop_per_layer[layer_id]))
219 |         self.losses_per_pop_per_layer[layer_id].scatter_add_(
220 |             dim=0, index=pop_idxs, src=layer_loss)
221 |
222 | def memory_policy_update_callback(
223 | self,
224 | pop_idxs,
225 | new_sequences,
226 | new_kv_cache,):
227 | losses_per_pop = self.losses_per_pop_per_layer.sum(dim=0)
228 |
229 | self.num_samples_per_pop.scatter_reduce_(
230 | dim=0, index=pop_idxs,
232 | src=torch.ones_like(pop_idxs).to(dtype=losses_per_pop.dtype),
233 | reduce='sum',
234 | include_self=True)
235 | self.total_losses_per_pop.data.add_(losses_per_pop)
236 |
237 | def get_loss(self,):
238 | return self.total_losses_per_pop/(
239 | torch.clamp_min_(
240 | self.num_samples_per_pop*self.num_memory_layers, 1))
241 |
242 |
243 | class L2NormAuxiliaryLoss(MemoryPolicyAuxiliaryLoss):
244 | def __init__(self,
245 | memory_policy: DynamicParamMemoryPolicy,
246 | coeff: float = 1.0,
247 | adaptive_target: Optional[float] = None,
248 | adaptive_target_param: str = 'linear',
249 | optimizer=None,
250 | adaptive_update_every: int = 1,
251 | device: str = 'cuda',
252 | ):
253 | MemoryPolicyAuxiliaryLoss.__init__(
254 | self,
255 | memory_policy=memory_policy,
256 | coeff=coeff,
257 | adaptive_target=adaptive_target,
258 | adaptive_target_param=adaptive_target_param,
259 | optimizer=optimizer,
260 | adaptive_update_every=adaptive_update_every,
261 | device=device,
262 | )
263 |
264 | self.losses_per_pop_per_layer = torch.zeros(
265 | [self.num_memory_layers, self.pop_size], device=device)
266 |
267 | self.total_losses_per_pop = torch.zeros(
268 | [self.pop_size], dtype=torch.float, device=device)
269 |
270 | def memory_policy_layer_callback(
271 | self,
272 | layer_id,
273 | pop_idxs,
274 | new_sequences,
275 | key_cache,
276 | value_cache,
277 | dynamic_mask=None,
278 | scoring_network_params=None):
279 |
280 | self.total_losses_per_pop[pop_idxs] = (
281 | scoring_network_params ** 2).mean().to(self.device)
282 |
283 | def memory_policy_update_callback(
284 | self,
285 | pop_idxs,
286 | new_sequences,
287 | new_kv_cache,):
288 | pass
289 |
290 | def get_loss(self,):
291 | return self.total_losses_per_pop
292 |
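# Hedged sketch of the adaptive-coefficient dynamics in isolation (the SGD
# optimizer, learning rate, and target below are illustrative assumptions):
# with the 'exp' parameterization, a dual-ascent step with grad = target -
# loss grows the multiplier while the measured loss exceeds its target.
if __name__ == '__main__':
    raw_coeff = nn.Parameter(torch.zeros(1))  # log of the coefficient
    opt = torch.optim.SGD([raw_coeff], lr=0.1)
    target, loss = 0.5, torch.tensor([0.9])
    for _ in range(3):
        opt.zero_grad(set_to_none=True)
        raw_coeff.grad = target - loss  # mirrors optim_params above
        opt.step()
    assert torch.exp(raw_coeff).item() > 1.0  # coefficient increased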
--------------------------------------------------------------------------------
/stateless_parallel_modules/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | import copy
4 | import math
5 | import numpy as np
6 | from dataclasses import dataclass
7 | from typing import Optional, Tuple, Union, Dict, Callable, List
8 | import abc
9 |
10 | import torch
11 | from torch import nn
12 | import torch.utils.checkpoint
13 | import torch.nn.functional as F
14 | from torch.cuda.amp import autocast
15 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
16 | from transformers import LlamaPreTrainedModel
17 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
18 | from utils import get_nonlinearity
19 |
20 |
21 | class StatelessGeneralizedOperation(nn.Module, abc.ABC):
22 | def __init__(
23 | self,
24 | ):
25 | nn.Module.__init__(self=self)
26 |
27 | @abc.abstractmethod
28 | def total_parameters(
29 | self,
30 | *args,
31 | parallel_operations: Optional[int] = None,
32 | **kwargs,
33 |     ) -> Tuple[int, int]:
34 | raise NotImplementedError
35 |
36 | @abc.abstractmethod
37 |     def prepare_parameters(self,
38 | parameters: torch.Tensor,
39 | *args,
40 | parallel_operations: Optional[int] = None,
41 | **kwargs,
42 | ):
43 | raise NotImplementedError
44 |
45 | @abc.abstractmethod
46 |     def forward(self,
47 | # parallel_operations x batch_size x in_features
48 | input: torch.Tensor,
49 | *args,
50 | parallel_operations: Optional[int] = None,
51 | n_parallel_dimensions: Optional[int] = None,
52 | **kwargs,
53 | ) -> torch.Tensor:
54 | raise NotImplementedError
55 |
56 |
57 | class GeneralizedLinear(StatelessGeneralizedOperation):
58 | def __init__(
59 | self,
60 | ):
61 | StatelessGeneralizedOperation.__init__(self=self,)
62 |
63 | def total_parameters(
64 | self,
65 | in_features: int,
66 | out_features: int,
67 | bias: bool,
68 | parallel_operations: Optional[int] = None,
69 | **kwargs,
70 |     ) -> Tuple[int, int]:
71 | total_weight_dimension = in_features*out_features
72 | total_base_parameters_dimension = total_weight_dimension
73 | if bias:
74 | total_base_parameters_dimension += out_features
75 | if parallel_operations is not None:
76 | total_parameters_dimension = (
77 | parallel_operations*total_base_parameters_dimension)
78 | else:
79 | total_parameters_dimension = total_base_parameters_dimension
80 | return total_parameters_dimension, total_base_parameters_dimension
81 |
82 | def prepare_parameters(
83 | self,
84 | in_features: int,
85 | out_features: int,
86 | bias: bool,
87 | parameters: torch.Tensor,
88 | parallel_operations: Optional[int] = None,
89 | **kwargs,
90 | ):
91 | if parallel_operations is not None:
92 | parameters = parameters.view(parallel_operations, -1)
93 | if bias:
94 | total_weight_dimension = in_features*out_features
95 | flat_w, flat_b = parameters.split_with_sizes(
96 | [total_weight_dimension, out_features], dim=-1)
97 | # parallel_ops x 1 x out_features
98 | b = flat_b.unsqueeze_(-2)
99 | else:
100 | flat_w = parameters
101 | b = None
102 | # shape required for baddbmm/addbmm
103 | w = flat_w.view(parallel_operations, in_features, out_features)
104 | else:
105 | if bias:
106 | total_weight_dimension = in_features*out_features
107 | flat_w, b = parameters.split_with_sizes(
108 | [total_weight_dimension, out_features], dim=-1)
109 | else:
110 | flat_w = parameters
111 | b = None
112 | # shape required for F.linear
113 | w = flat_w.view(out_features, in_features)
114 | return w, b
115 |
116 | def forward(
117 | self,
118 | input: torch.Tensor, # parallel_operations x batch_size x in_features
119 | weight: torch.Tensor,
120 | bias: Optional[torch.Tensor] = None,
121 | parallel_operations: Optional[int] = None,
122 | n_parallel_dimensions: Optional[int] = None,
123 | **kwargs,
124 | ) -> torch.Tensor:
125 | if n_parallel_dimensions is not None:
126 | input_shape = input.shape
127 |
128 | # this is guaranteed to be 3-dimensional
129 | input = input.flatten(start_dim=1, end_dim=-2)
130 |
131 | if parallel_operations is not None:
132 | if bias is not None:
133 | out = torch.baddbmm(input=bias, batch1=input, batch2=weight)
134 | else:
135 | out = input @ weight
136 | else:
137 | out = F.linear(input=input, weight=weight, bias=bias)
138 | if n_parallel_dimensions is not None:
139 | out = out.view(*input_shape[:-1], -1)
140 | return out
141 |
142 |
143 | class StatelessGeneralizedModule(nn.Module, abc.ABC):
144 | def __init__(
145 | self,
146 | input_features: Optional[int],
147 | output_features: Optional[int],
148 | init_module: bool = True,
149 | operation_kwargs: Optional[dict] = None
150 | ):
151 | if init_module:
152 | nn.Module.__init__(self=self,)
153 |
154 | if operation_kwargs is None:
155 | self.operation_kwargs = {}
156 | else:
157 | self.operation_kwargs = operation_kwargs
158 |
159 | self.input_features = input_features
160 | self.output_features = output_features
161 | if not hasattr(self, 'total_base_parameter_dims'):
162 | # num. of params needed for forward (for each parallel op.)
163 | self.total_base_parameter_dims = 0
164 |
165 | if not hasattr(self, 'n_base_parameters_per_layer'):
166 | self.n_base_parameters_per_layer: List[int] = []
167 | if not hasattr(self, 'parameters_per_layer'):
168 | self.parameters_per_layer: List[torch.Tensor] = []
169 | if not hasattr(self, 'kwargs_per_op'):
170 | self.kwargs_per_op: List[dict] = []
171 | if not hasattr(self, 'generalized_ops'):
172 | self.generalized_ops: List[StatelessGeneralizedOperation] = []
173 | self.parallel_operations: Optional[int] = None
174 |
175 | def get_buffer_names(self,):
176 | return []
177 |
178 | def instantiate_and_setup_ops(
179 | self,
180 | input_features: Optional[int] = None,
181 | output_features: Optional[int] = None,
182 | preceding_module=None,
183 | default_output_features_mult: int = 1,
184 | **kwargs,
185 | ):
186 | raise NotImplementedError
187 |
188 | def setup_operations(
189 | self,
190 | operations: Union[
191 | List[StatelessGeneralizedOperation],
192 | StatelessGeneralizedOperation],
193 | operation_kwargs: Optional[dict] = None,
194 | operation_kwargs_overrides_list: Optional[List[dict]] = None,
195 | # save list of operations as nn.ModuleList
196 | save_as_module_list: bool = True,
197 | ):
198 |
199 | if operation_kwargs is not None:
200 | self.operation_kwargs = operation_kwargs
201 |
202 | operation_kwargs = self.operation_kwargs
203 |
204 |         if operation_kwargs_overrides_list is not None:
205 |             assert len(operation_kwargs_overrides_list) == len(operations)
206 |         elif isinstance(operations, StatelessGeneralizedOperation):
207 |             operation_kwargs_overrides_list = [{}]
208 |         else:
209 |             operation_kwargs_overrides_list = [{} for _ in operations]
210 |
211 |         if isinstance(operations, StatelessGeneralizedOperation):
212 |             operations = [operations for _ in range(
213 |                 len(operation_kwargs_overrides_list))]
214 |
216 | self.generalized_ops = operations
217 | for op, op_kwargs_overrides in zip(
218 | self.generalized_ops, operation_kwargs_overrides_list):
219 | op: StatelessGeneralizedOperation
220 | op_kwargs = {}
221 | op_kwargs.update(operation_kwargs)
222 | op_kwargs.update(op_kwargs_overrides)
223 | self.kwargs_per_op.append(op_kwargs)
224 | op_base_params, _ = op.total_parameters(
225 | parallel_operations=None, **op_kwargs)
226 | self.n_base_parameters_per_layer.append(op_base_params)
227 |
228 | self.total_base_parameter_dims = int(sum(
229 | self.n_base_parameters_per_layer))
230 | if save_as_module_list:
231 | self.generalized_ops_module_list = nn.ModuleList(
232 | self.generalized_ops)
233 |
234 | def instantiate_model(
235 | self,
236 | input_features: Optional[int] = None,
237 | output_features: Optional[int] = None,
238 | preceding_module=None,
239 | default_output_features_mult: int = 1,
240 | ):
241 | if self.input_features is None and input_features is not None:
242 | self.input_features = input_features
243 | elif self.input_features is None:
244 | assert preceding_module is not None
245 | assert hasattr(preceding_module, 'output_features')
246 | self.input_features: int = preceding_module.output_features
247 |             print('Warning: input features not specified; setting to ' +
248 | f'{self.input_features} (output features of ' +
249 | 'preceding module)')
250 |
251 | assert isinstance(self.input_features, int)
252 | if self.output_features is None and output_features is not None:
253 | self.output_features = output_features
254 | elif self.output_features is None:
255 | self.output_features = (
256 | self.input_features*default_output_features_mult)
257 |             print('Warning: output features not specified; setting to ' +
258 | f'{self.output_features} ' +
259 | '(input features*default_output_features_mult)')
260 |
261 | def format_parameters(
262 | self,
263 | parameters: torch.Tensor,
264 | # should match first dimension of parameters
265 | parallel_operations: Optional[int] = None,
266 | ):
267 |
268 |         if parallel_operations is not None:
269 |             # sanity check, unneeded/redundant (trivial overhead)
270 |             parameters = parameters.view(
271 |                 parallel_operations, self.total_base_parameter_dims)
272 |
273 | self.parallel_operations = parallel_operations
274 |
275 | # assumes parameters is a flattened tensor
276 | flat_parameters_per_layer = parameters.split_with_sizes(
277 | self.n_base_parameters_per_layer, dim=-1)
278 |
279 | parameters_per_layer: list[torch.Tensor] = []
280 |
281 | for op, op_kwargs, layer_params in zip(
282 | self.generalized_ops,
283 | self.kwargs_per_op,
284 | flat_parameters_per_layer,
285 | ):
286 |
287 | prepared_params = op.prepare_parameters(
288 | parameters=layer_params,
289 | parallel_operations=parallel_operations,
290 | **op_kwargs,
291 | )
292 |
293 | parameters_per_layer.append(prepared_params)
294 | return parameters_per_layer
295 |
296 | def load_parameters(
297 | self,
298 | parameters: torch.Tensor,
299 | # should match first dimension of parameters
300 | parallel_operations: Optional[int] = None,
301 | ):
302 |
303 | if parallel_operations is not None:
304 |             # sanity check, unneeded/redundant (trivial overhead)
305 | parameters = parameters.view(
306 | parallel_operations, self.total_base_parameter_dims)
307 |
308 | self.parallel_operations = parallel_operations
309 |
310 | # assumes parameters is a flattened tensor
311 | flat_parameters_per_layer = parameters.split_with_sizes(
312 | self.n_base_parameters_per_layer, dim=-1)
313 |
314 | self.parameters_per_layer: list[torch.Tensor] = []
315 |
316 | for op, op_kwargs, layer_params in zip(
317 | self.generalized_ops,
318 | self.kwargs_per_op,
319 | flat_parameters_per_layer,
320 | ):
321 |
322 | prepared_params = op.prepare_parameters(
323 | **op_kwargs,
324 | parameters=layer_params,
325 | parallel_operations=parallel_operations,
326 | )
327 |
328 | self.parameters_per_layer.append(prepared_params)
329 |
330 | @abc.abstractmethod
331 | def forward(
332 | self,
333 | # parallel_operations x batch_size x input_features
334 | inputs: torch.Tensor,
335 | *args,
336 | n_parallel_dimensions: Optional[int] = None,
337 | **kwargs,
338 | ):
339 | raise NotImplementedError
340 |
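# Hedged usage sketch: run P independent linear layers in one batched matmul
# by packing their weights and biases into a single flat parameter tensor
# (the sizes below are illustrative assumptions).
if __name__ == '__main__':
    lin = GeneralizedLinear()
    P, B, DIN, DOUT = 3, 5, 4, 2
    total, per_op = lin.total_parameters(
        in_features=DIN, out_features=DOUT, bias=True, parallel_operations=P)
    flat = torch.randn(total)
    w, b = lin.prepare_parameters(
        in_features=DIN, out_features=DOUT, bias=True,
        parameters=flat, parallel_operations=P)
    out = lin(input=torch.randn(P, B, DIN), weight=w, bias=b,
              parallel_operations=P)
    assert out.shape == (P, B, DOUT)  # one output per parallel operation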
--------------------------------------------------------------------------------
/env.yaml:
--------------------------------------------------------------------------------
1 | name: th2
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=conda_forge
9 | - _openmp_mutex=4.5=2_gnu
10 | - aom=3.9.1=hac33072_0
11 | - asttokens=2.4.1=pyhd8ed1ab_0
12 | - blas=1.0=mkl
13 | - brotli-python=1.1.0=py310hc6cd4ac_1
14 | - bzip2=1.0.8=hd590300_5
15 | - ca-certificates=2024.6.2=hbcca054_0
16 | - cairo=1.18.0=h3faef2a_0
17 | - certifi=2024.6.2=pyhd8ed1ab_0
18 | - charset-normalizer=3.3.2=pyhd8ed1ab_0
19 | - comm=0.2.2=pyhd8ed1ab_0
20 | - cuda-cudart=12.1.105=0
21 | - cuda-cupti=12.1.105=0
22 | - cuda-libraries=12.1.0=0
23 | - cuda-nvrtc=12.1.105=0
24 | - cuda-nvtx=12.1.105=0
25 | - cuda-opencl=12.5.39=0
26 | - cuda-runtime=12.1.0=0
27 | - cuda-version=12.5=3
28 | - dav1d=1.2.1=hd590300_0
29 | - debugpy=1.8.2=py310h76e45a6_0
30 | - decorator=5.1.1=pyhd8ed1ab_0
31 | - executing=2.0.1=pyhd8ed1ab_0
32 | - expat=2.6.2=h59595ed_0
33 | - ffmpeg=7.0.1=gpl_hb399a10_100
34 | - filelock=3.15.1=pyhd8ed1ab_0
35 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
36 | - font-ttf-inconsolata=3.000=h77eed37_0
37 | - font-ttf-source-code-pro=2.038=h77eed37_0
38 | - font-ttf-ubuntu=0.83=h77eed37_2
39 | - fontconfig=2.14.2=h14ed4e7_0
40 | - fonts-conda-ecosystem=1=0
41 | - fonts-conda-forge=1=0
42 | - freetype=2.12.1=h267a509_2
43 | - fribidi=1.0.10=h36c2ea0_0
44 | - gettext=0.22.5=h59595ed_2
45 | - gettext-tools=0.22.5=h59595ed_2
46 | - gmp=6.3.0=h59595ed_1
47 | - gmpy2=2.1.5=py310hc7909c9_1
48 | - gnutls=3.7.9=hb077bed_0
49 | - graphite2=1.3.13=h59595ed_1003
50 | - harfbuzz=8.5.0=hfac3d4d_0
51 | - icu=73.2=h59595ed_0
52 | - idna=3.7=pyhd8ed1ab_0
53 | - importlib-metadata=8.0.0=pyha770c72_0
54 | - importlib_metadata=8.0.0=hd8ed1ab_0
55 | - intel-openmp=2022.0.1=h06a4308_3633
56 | - ipykernel=6.29.5=pyh3099207_0
57 | - ipython=8.26.0=pyh707e725_0
58 | - ipywidgets=8.1.3=pyhd8ed1ab_0
59 | - jedi=0.19.1=pyhd8ed1ab_0
60 | - jinja2=3.1.4=pyhd8ed1ab_0
61 | - jupyter_client=8.6.2=pyhd8ed1ab_0
62 | - jupyter_core=5.7.2=py310hff52083_0
63 | - jupyterlab_widgets=3.0.11=pyhd8ed1ab_0
64 | - keyutils=1.6.1=h166bdaf_0
65 | - krb5=1.21.3=h659f571_0
66 | - lame=3.100=h166bdaf_1003
67 | - lcms2=2.16=hb7c19ff_0
68 | - ld_impl_linux-64=2.40=hf3520f5_7
69 | - lerc=4.0.0=h27087fc_0
70 | - libabseil=20240116.2=cxx17_h59595ed_0
71 | - libasprintf=0.22.5=h661eb56_2
72 | - libasprintf-devel=0.22.5=h661eb56_2
73 | - libass=0.17.1=h8fe9dca_1
74 | - libblas=3.9.0=16_linux64_mkl
75 | - libcblas=3.9.0=16_linux64_mkl
76 | - libcublas=12.1.0.26=0
77 | - libcufft=11.0.2.4=0
78 | - libcufile=1.10.0.4=0
79 | - libcurand=10.3.6.39=0
80 | - libcusolver=11.4.4.55=0
81 | - libcusparse=12.0.2.55=0
82 | - libdeflate=1.20=hd590300_0
83 | - libdrm=2.4.120=hd590300_0
84 | - libedit=3.1.20191231=he28a2e2_2
85 | - libexpat=2.6.2=h59595ed_0
86 | - libffi=3.4.2=h7f98852_5
87 | - libgcc-ng=13.2.0=h77fa898_10
88 | - libgettextpo=0.22.5=h59595ed_2
89 | - libgettextpo-devel=0.22.5=h59595ed_2
90 | - libglib=2.80.2=h8a4344b_1
91 | - libgomp=13.2.0=h77fa898_10
92 | - libhwloc=2.10.0=default_h5622ce7_1001
93 | - libiconv=1.17=hd590300_2
94 | - libidn2=2.3.7=hd590300_0
95 | - libjpeg-turbo=3.0.0=hd590300_1
96 | - liblapack=3.9.0=16_linux64_mkl
97 | - libnpp=12.0.2.50=0
98 | - libnsl=2.0.1=hd590300_0
99 | - libnvjitlink=12.1.105=0
100 | - libnvjpeg=12.1.1.14=0
101 | - libopenvino=2024.1.0=h2da1b83_7
102 | - libopenvino-auto-batch-plugin=2024.1.0=hb045406_7
103 | - libopenvino-auto-plugin=2024.1.0=hb045406_7
104 | - libopenvino-hetero-plugin=2024.1.0=h5c03a75_7
105 | - libopenvino-intel-cpu-plugin=2024.1.0=h2da1b83_7
106 | - libopenvino-intel-gpu-plugin=2024.1.0=h2da1b83_7
107 | - libopenvino-intel-npu-plugin=2024.1.0=he02047a_7
108 | - libopenvino-ir-frontend=2024.1.0=h5c03a75_7
109 | - libopenvino-onnx-frontend=2024.1.0=h07e8aee_7
110 | - libopenvino-paddle-frontend=2024.1.0=h07e8aee_7
111 | - libopenvino-pytorch-frontend=2024.1.0=he02047a_7
112 | - libopenvino-tensorflow-frontend=2024.1.0=h39126c6_7
113 | - libopenvino-tensorflow-lite-frontend=2024.1.0=he02047a_7
114 | - libopus=1.3.1=h7f98852_1
115 | - libpciaccess=0.18=hd590300_0
116 | - libpng=1.6.43=h2797004_0
117 | - libprotobuf=4.25.3=h08a7969_0
118 | - libsodium=1.0.18=h36c2ea0_1
119 | - libsqlite=3.46.0=hde9e2c9_0
120 | - libstdcxx-ng=13.2.0=hc0a3c3a_10
121 | - libtasn1=4.19.0=h166bdaf_0
122 | - libtiff=4.6.0=h1dd3fc0_3
123 | - libunistring=0.9.10=h7f98852_0
124 | - libuuid=2.38.1=h0b41bf4_0
125 | - libva=2.21.0=h4ab18f5_2
126 | - libvpx=1.14.1=hac33072_0
127 | - libwebp-base=1.4.0=hd590300_0
128 | - libxcb=1.15=h0b41bf4_0
129 | - libxcrypt=4.4.36=hd590300_1
130 | - libxml2=2.12.7=hc051c1a_1
131 | - libzlib=1.3.1=h4ab18f5_1
132 | - llvm-openmp=15.0.7=h0cdce71_0
133 | - markupsafe=2.1.5=py310h2372a71_0
134 | - matplotlib-inline=0.1.7=pyhd8ed1ab_0
135 | - mkl=2022.1.0=hc2b9512_224
136 | - mpc=1.3.1=hfe3b2da_0
137 | - mpfr=4.2.1=h9458935_1
138 | - mpmath=1.3.0=pyhd8ed1ab_0
139 | - ncurses=6.5=h59595ed_0
140 | - nest-asyncio=1.6.0=pyhd8ed1ab_0
141 | - nettle=3.9.1=h7ab15ed_0
142 | - networkx=3.3=pyhd8ed1ab_1
143 | - ocl-icd=2.3.2=hd590300_1
144 | - openh264=2.4.1=h59595ed_0
145 | - openjpeg=2.5.2=h488ebb8_0
146 | - openssl=3.3.1=h4ab18f5_0
147 | - p11-kit=0.24.1=hc5aa10d_0
148 | - packaging=24.1=pyhd8ed1ab_0
149 | - parso=0.8.4=pyhd8ed1ab_0
150 | - pcre2=10.44=h0f59acf_0
151 | - pexpect=4.9.0=pyhd8ed1ab_0
152 | - pickleshare=0.7.5=py_1003
153 | - pillow=10.3.0=py310hf73ecf8_0
154 | - pip=24.0=pyhd8ed1ab_0
155 | - pixman=0.43.2=h59595ed_0
156 | - platformdirs=4.2.2=pyhd8ed1ab_0
157 | - prompt-toolkit=3.0.47=pyha770c72_0
158 | - pthread-stubs=0.4=h36c2ea0_1001
159 | - ptyprocess=0.7.0=pyhd3deb0d_0
160 | - pugixml=1.14=h59595ed_0
161 | - pure_eval=0.2.2=pyhd8ed1ab_0
162 | - pygments=2.18.0=pyhd8ed1ab_0
163 | - pysocks=1.7.1=pyha2e5f31_6
164 | - python=3.10.14=hd12c33a_0_cpython
165 | - python_abi=3.10=4_cp310
166 | - pytorch-cuda=12.1=ha16c6d3_5
167 | - pytorch-mutex=1.0=cuda
168 | - pyyaml=6.0.1=py310h2372a71_1
169 | - pyzmq=26.0.3=py310h6883aea_0
170 | - readline=8.2=h8228510_1
171 | - requests=2.32.3=pyhd8ed1ab_0
172 | - setuptools=70.0.0=pyhd8ed1ab_0
173 | - six=1.16.0=pyh6c4a22f_0
174 | - snappy=1.2.0=hdb0a2a9_1
175 | - stack_data=0.6.2=pyhd8ed1ab_0
176 | - svt-av1=2.1.0=hac33072_0
177 | - sympy=1.12.1=pypyh2585a3b_103
178 | - tbb=2021.12.0=h297d8ca_1
179 | - tk=8.6.13=noxft_h4845f30_101
180 | - torchaudio=2.3.1=py310_cu121
181 | - torchvision=0.18.1=py310_cu121
182 | - tornado=6.4.1=py310hc51659f_0
183 | - traitlets=5.14.3=pyhd8ed1ab_0
184 | - typing_extensions=4.12.2=pyha770c72_0
185 | - urllib3=2.2.2=pyhd8ed1ab_0
186 | - wcwidth=0.2.13=pyhd8ed1ab_0
187 | - wheel=0.43.0=pyhd8ed1ab_1
188 | - widgetsnbextension=4.0.11=pyhd8ed1ab_0
189 | - x264=1!164.3095=h166bdaf_2
190 | - x265=3.5=h924138e_3
191 | - xorg-fixesproto=5.0=h7f98852_1002
192 | - xorg-kbproto=1.0.7=h7f98852_1002
193 | - xorg-libice=1.1.1=hd590300_0
194 | - xorg-libsm=1.2.4=h7391055_0
195 | - xorg-libx11=1.8.9=h8ee46fc_0
196 | - xorg-libxau=1.0.11=hd590300_0
197 | - xorg-libxdmcp=1.1.3=h7f98852_0
198 | - xorg-libxext=1.3.4=h0b41bf4_2
199 | - xorg-libxfixes=5.0.3=h7f98852_1004
200 | - xorg-libxrender=0.9.11=hd590300_0
201 | - xorg-renderproto=0.11.1=h7f98852_1002
202 | - xorg-xextproto=7.3.0=h0b41bf4_1003
203 | - xorg-xproto=7.0.31=h7f98852_1007
204 | - xz=5.2.6=h166bdaf_0
205 | - yaml=0.2.5=h7f98852_2
206 | - zeromq=4.3.5=h75354e8_4
207 | - zipp=3.19.2=pyhd8ed1ab_0
208 | - zlib=1.3.1=h4ab18f5_1
209 | - zstd=1.5.6=ha6fb4c9_0
210 | - pip:
211 | - absl-py==2.1.0
212 | - accelerate==0.31.0
213 | - aiohttp==3.9.5
214 | - aiosignal==1.3.1
215 | - annotated-types==0.7.0
216 | - antlr4-python3-runtime==4.9.3
217 | - anyio==4.4.0
218 | - async-timeout==4.0.3
219 | - attrs==23.2.0
220 | - bitsandbytes==0.43.1
221 | - blis==0.7.11
222 | - bottle==0.12.25
223 | - cachetools==5.3.3
224 | - catalogue==2.0.10
225 | - cattrs==22.2.0
226 | - chardet==5.2.0
227 | - click==8.1.7
228 | - cloudpathlib==0.18.1
229 | - cloudpickle==3.0.0
230 | - cmake==3.29.5.1
231 | - colorama==0.4.6
232 | - confection==0.1.5
233 | - contourpy==1.2.1
234 | - crfm-helm==0.5.2
235 | - cycler==0.12.1
236 | - cymem==2.0.8
237 | - dacite==1.8.1
238 | - dataproperty==1.0.1
239 | - datasets==2.20.0
240 | - dill==0.3.8
241 | - diskcache==5.6.3
242 | - distro==1.9.0
243 | - dnspython==2.6.1
244 | - docker-pycreds==0.4.0
245 | - einops==0.8.0
246 | - email-validator==2.1.2
247 | - evaluate==0.4.2
248 | - exceptiongroup==1.2.1
249 | - fastapi==0.111.0
250 | - fastapi-cli==0.0.4
251 | - fonttools==4.53.1
252 | - frozenlist==1.4.1
253 | - fsspec==2024.5.0
254 | - ftfy==6.2.0
255 | - fugashi==1.3.2
256 | - fuzzywuzzy==0.18.0
257 | - gitdb==4.0.11
258 | - gitpython==3.1.43
259 | - google-api-core==2.19.0
260 | - google-api-python-client==2.133.0
261 | - google-auth==2.30.0
262 | - google-auth-httplib2==0.2.0
263 | - googleapis-common-protos==1.63.1
264 | - h11==0.14.0
265 | - httpcore==1.0.5
266 | - httplib2==0.22.0
267 | - httptools==0.6.1
268 | - httpx==0.27.0
269 | - huggingface-hub==0.23.4
270 | - hydra-core==1.3.2
271 | - importlib-resources==5.13.0
272 | - interegular==0.3.3
273 | - jieba==0.42.1
274 | - joblib==1.4.2
275 | - jsonlines==4.0.0
276 | - jsonschema==4.22.0
277 | - jsonschema-specifications==2023.12.1
278 | - kiwisolver==1.4.5
279 | - langcodes==3.4.0
280 | - language-data==1.2.0
281 | - lark==1.1.9
282 | - llvmlite==0.43.0
283 | - lm-eval==0.4.2
284 | - lm-format-enforcer==0.10.1
285 | - lxml==5.2.2
286 | - mako==1.3.5
287 | - marisa-trie==1.2.0
288 | - markdown-it-py==3.0.0
289 | - matplotlib==3.9.1
290 | - mbstrdecoder==1.1.3
291 | - mdurl==0.1.2
292 | - more-itertools==10.3.0
293 | - msgpack==1.0.8
294 | - multidict==6.0.5
295 | - multiprocess==0.70.16
296 | - murmurhash==1.0.10
297 | - ninja==1.11.1.1
298 | - nltk==3.8.1
299 | - numba==0.60.0
300 | - numexpr==2.10.0
301 | - numpy==1.26.4
302 | - nvidia-cublas-cu12==12.1.3.1
303 | - nvidia-cuda-cupti-cu12==12.1.105
304 | - nvidia-cuda-nvrtc-cu12==12.1.105
305 | - nvidia-cuda-runtime-cu12==12.1.105
306 | - nvidia-cudnn-cu12==8.9.2.26
307 | - nvidia-cufft-cu12==11.0.2.54
308 | - nvidia-curand-cu12==10.3.2.106
309 | - nvidia-cusolver-cu12==11.4.5.107
310 | - nvidia-cusparse-cu12==12.1.0.106
311 | - nvidia-ml-py==12.555.43
312 | - nvidia-nccl-cu12==2.20.5
313 | - nvidia-nvjitlink-cu12==12.5.40
314 | - nvidia-nvtx-cu12==12.1.105
315 | - omegaconf==2.3.0
316 | - openai==1.34.0
317 | - orjson==3.10.5
318 | - outlines==0.0.45
319 | - pandas==2.2.2
320 | - parameterized==0.9.0
321 | - pathvalidate==3.2.0
322 | - peft==0.11.1
323 | - portalocker==2.8.2
324 | - preshed==3.0.9
325 | - prometheus-client==0.20.0
326 | - prometheus-fastapi-instrumentator==7.0.0
327 | - proto-plus==1.23.0
328 | - protobuf==4.25.3
329 | - psutil==5.9.8
330 | - py-cpuinfo==9.0.0
331 | - pyairports==2.1.1
332 | - pyarrow==16.1.0
333 | - pyarrow-hotfix==0.6
334 | - pyasn1==0.6.0
335 | - pyasn1-modules==0.4.0
336 | - pybind11==2.12.0
337 | - pycountry==24.6.1
338 | - pydantic==2.7.4
339 | - pydantic-core==2.18.4
340 | - pyext==0.7
341 | - pyhocon==0.3.61
342 | - pyparsing==3.1.2
343 | - pytablewriter==1.2.0
344 | - python-dateutil==2.9.0.post0
345 | - python-dotenv==1.0.1
346 | - python-multipart==0.0.9
347 | - pytz==2024.1
348 | - ray==2.24.0
349 | - referencing==0.35.1
350 | - regex==2024.5.15
351 | - retrying==1.3.4
352 | - rich==13.7.1
353 | - rouge==1.0.1
354 | - rouge-score==0.1.2
355 | - rpds-py==0.18.1
356 | - rsa==4.9
357 | - sacrebleu==2.4.2
358 | - safetensors==0.4.3
359 | - scikit-learn==1.5.0
360 | - scipy==1.13.1
361 | - seaborn==0.13.2
362 | - sentencepiece==0.2.0
363 | - sentry-sdk==2.5.1
364 | - setproctitle==1.3.3
365 | - shellingham==1.5.4
366 | - smart-open==7.0.4
367 | - smmap==5.0.1
368 | - sniffio==1.3.1
369 | - spacy==3.7.5
370 | - spacy-legacy==3.0.12
371 | - spacy-loggers==1.0.5
372 | - sqlitedict==1.7.0
373 | - srsly==2.4.8
374 | - starlette==0.37.2
375 | - tabledata==1.3.3
376 | - tabulate==0.9.0
377 | - tcolorpy==0.1.6
378 | - thinc==8.2.4
379 | - threadpoolctl==3.5.0
380 | - tiktoken==0.7.0
381 | - tokenizers==0.19.1
382 | - torch==2.3.0
383 | - tqdm==4.66.4
384 | - tqdm-multiprocess==0.0.11
385 | - transformers==4.41.2
386 | - triton==2.3.0
387 | - typepy==1.3.2
388 | - typer==0.12.3
389 | - tzdata==2024.1
390 | - ujson==5.10.0
391 | - uncertainty-calibration==0.1.4
392 | - uritemplate==4.1.1
393 | - uvicorn==0.30.1
394 | - uvloop==0.19.0
395 | - vllm==0.5.0.post1
396 | - vllm-flash-attn==2.5.9
397 | - wandb==0.17.2
398 | - wasabi==1.1.3
399 | - watchfiles==0.22.0
400 | - weasel==0.4.1
401 | - websockets==12.0
402 | - word2number==1.1
403 | - wrapt==1.16.0
404 | - xformers==0.0.26.post1
405 | - xxhash==3.4.1
406 | - yarl==1.9.4
407 | - zstandard==0.18.0
408 | - highlight-text
409 |
--------------------------------------------------------------------------------
/stateless_parallel_modules/attention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from dataclasses import dataclass
3 | from typing import Optional, Tuple, Union, Dict, Callable
4 |
5 |
6 | import torch
7 | from torch import nn
8 | import torch.utils.checkpoint
9 | import torch.nn.functional as F
10 |
11 | from transformers.models.llama.modeling_llama import (
12 | apply_rotary_pos_emb)
13 |
14 | from .base import (
15 | StatelessGeneralizedOperation, GeneralizedLinear,
16 | StatelessGeneralizedModule)
17 | from utils import get_nonlinearity
18 |
19 |
20 | class RotaryEmbedding(nn.Module):
21 | '''Based on transformers.models.llama.modeling_llama.LlamaRotaryEmbedding'''
22 |
23 | def __init__(
24 | self,
25 | dim,
26 | max_position_embeddings=2048,
27 | base=10000,
28 | device=None,
29 | scaling_factor=1.0,
30 | ):
31 | super().__init__()
32 | self.scaling_factor = scaling_factor
33 | self.dim = dim
34 | self.max_position_embeddings = max_position_embeddings
35 | self.base = base
36 | inv_freq = 1.0 / (self.base ** (
37 | torch.arange(0, self.dim, 2,
38 | dtype=torch.int64).float().to(device) / self.dim))
39 | self.register_buffer("inv_freq", inv_freq, persistent=False)
40 |
41 | self.max_seq_len_cached = max_position_embeddings
42 | t = torch.arange(
43 | self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(
44 | self.inv_freq)
45 | t = t / self.scaling_factor
46 | freqs = torch.outer(t, self.inv_freq)
47 |
48 | emb = torch.cat((freqs, freqs), dim=-1)
49 | self.register_buffer(
50 | "_cos_cached", emb.cos().to(torch.get_default_dtype()),
51 | persistent=False)
52 | self.register_buffer(
53 | "_sin_cached", emb.sin().to(torch.get_default_dtype()),
54 | persistent=False)
55 |
56 | @torch.no_grad()
57 | def forward(self, x, position_ids):
58 | inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
59 | position_ids.shape[0], -1, 1)
60 | position_ids_expanded = position_ids[:, None, :].float()
61 |
62 | # Force float32 since bfloat16 loses precision on long contexts
63 | # See https://github.com/huggingface/transformers/pull/29285
64 | device_type = x.device.type
65 | device_type = device_type if isinstance(
66 | device_type, str) and device_type != "mps" else "cpu"
67 | with torch.autocast(device_type=device_type, enabled=False):
68 | freqs = (
69 | inv_freq_expanded.float() @
70 | position_ids_expanded.float()).transpose(1, 2)
71 | emb = torch.cat((freqs, freqs), dim=-1)
72 | cos = emb.cos()
73 | sin = emb.sin()
74 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
75 |
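# Usage sketch for RotaryEmbedding (illustrative; the shapes and values below
# are assumptions, not taken from the repo's configs):
#   rope = RotaryEmbedding(dim=64, max_position_embeddings=4096, base=10000)
#   position_ids = torch.arange(128).unsqueeze(0)  # (batch=1, num_tokens=128)
#   x = torch.randn(1, 128, 64)                    # only dtype/device are read
#   cos, sin = rope(x, position_ids)               # each of shape (1, 128, 64)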
76 |
77 | @dataclass
78 | class StatelessAttentionParams:
79 | input_dim: Optional[int]
80 | # defaults to hidden dim.
81 | hidden_dim: Optional[int]
82 | output_dim: Optional[int]
83 | num_heads: int
84 | bias: bool
85 | max_position_id: int
86 | # rope params
87 | use_rope: bool
88 | rope_theta: float
89 | # None, forward, or backward
90 | masking_strategy: Optional[str]
91 |
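# Example configuration (hypothetical values, for illustration only):
#   params = StatelessAttentionParams(
#       input_dim=128, hidden_dim=128, output_dim=128, num_heads=4,
#       bias=False, max_position_id=4096, use_rope=True, rope_theta=10000.0,
#       masking_strategy='forward')  # None, 'forward', or 'backward'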
92 |
93 | class StatelessAttention(StatelessGeneralizedModule):
94 | def __init__(self,
95 | attention_params: StatelessAttentionParams,):
96 | self.save_configs(attention_params=attention_params)
97 | StatelessGeneralizedModule.__init__(
98 | self=self,
99 | input_features=self.input_dim,
100 | output_features=self.output_dim,
101 | init_module=True,)
102 | if (self.input_dim is not None and
103 | self.output_dim is not None and
104 | self.hidden_dim is not None):
105 |
106 | self.instantiate_and_setup_ops(
107 | input_features=self.input_dim,
108 | hidden_features=self.hidden_dim,
109 | output_features=self.output_dim,
110 | preceding_module=None,
111 | default_output_features_mult=1,
112 | )
113 |
114 | def instantiate_and_setup_ops(
115 | self,
116 | input_features: Optional[int] = None,
117 | hidden_features: Optional[int] = None,
118 | output_features: Optional[int] = None,
119 | preceding_module=None,
120 | default_output_features_mult: int = 1,
121 | **kwargs,
122 | ):
123 |
124 | if (self.input_dim is None or
125 | self.output_dim is None or
126 | self.hidden_dim is None):
127 |
128 | self.instantiate_model(
129 | input_features=input_features,
130 | output_features=output_features,
131 | preceding_module=preceding_module,
132 | default_output_features_mult=default_output_features_mult,
133 | )
134 | self.input_dim = self.input_features
135 | self.output_dim = self.output_features
136 | if self.hidden_dim is None and hidden_features is not None:
137 | self.hidden_dim = hidden_features
138 | elif self.hidden_dim is None:
139 |             print('Warning: hidden features not specified; setting to ' +
140 | f'{self.output_features} (output features)')
141 | self.hidden_dim = self.output_features
142 |
143 | self.head_dim = self.hidden_dim // self.num_heads
144 | if (self.head_dim * self.num_heads) != self.hidden_dim:
145 | raise ValueError("hidden_dim must be divisible by num_heads")
146 |
147 | self.linear_op = GeneralizedLinear()
148 |
149 | operation_kwargs = dict(
150 | bias=self.bias,
151 | )
152 | qkv_kwargs = dict(
153 | in_features=self.input_dim,
154 | out_features=self.hidden_dim*3,
155 | )
156 | o_kwargs = dict(
157 | in_features=self.hidden_dim,
158 | out_features=self.output_dim,
159 | )
160 |
161 | self._init_rope()
162 |
163 | self.setup_operations(
164 | operations=[self.linear_op, self.linear_op],
165 | operation_kwargs=operation_kwargs,
166 | operation_kwargs_overrides_list=[qkv_kwargs, o_kwargs],
167 | )
168 |
169 | def save_configs(self, attention_params: StatelessAttentionParams,):
170 | self.config = attention_params
171 | self.input_dim = attention_params.input_dim
172 | self.hidden_dim = attention_params.hidden_dim
173 | self.output_dim = attention_params.output_dim
174 | self.num_heads = attention_params.num_heads
175 |
176 | self.multiple_heads = self.num_heads > 1
177 |
178 | self.bias = attention_params.bias
179 | self.max_position_id = attention_params.max_position_id
180 | self.use_rope = attention_params.use_rope
181 | self.rope_theta = attention_params.rope_theta
182 | self.masking_strategy = attention_params.masking_strategy
183 | self.apply_causal_mask = False
184 | self.backward_causal_mask = False
185 | if self.masking_strategy is not None:
186 | self.apply_causal_mask = True
187 | if self.masking_strategy == 'backward':
188 | self.backward_causal_mask = True
189 | else:
190 | assert self.masking_strategy == 'forward'
191 |
192 | def _init_rope(self):
193 | self.rotary_emb = RotaryEmbedding(
194 | dim=self.head_dim,
195 | max_position_embeddings=self.max_position_id,
196 | base=self.rope_theta,
197 | )
198 |
199 | def forward(
200 | self,
201 | inputs: torch.Tensor,
202 | *args,
203 | n_parallel_dimensions: Optional[int] = None,
204 | attn_mask: Optional[torch.Tensor] = None,
205 | position_ids: Optional[torch.LongTensor] = None,
206 | **kwargs,
207 | ) -> torch.Tensor:
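        # Shape walkthrough (illustrative): inputs (*batch_dims, T, input_dim)
        # -> fused qkv projection (*batch_dims, T, 3*hidden_dim) -> per-head
        # split (flat_batch, num_heads, T, head_dim) -> scaled dot-product
        # attention -> merged heads -> output projection of shape
        # (*batch_dims, T, output_dim).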
208 | *batch_dims, num_tokens, input_dim = inputs.size()
209 |
210 | weight, bias = self.parameters_per_layer[0]
211 | qkv_states = self.linear_op(
212 | input=inputs,
213 | weight=weight,
214 | bias=bias,
215 | parallel_operations=self.parallel_operations,
216 | n_parallel_dimensions=n_parallel_dimensions,
217 | )
218 |
219 |         # flatten all batch dims so qkv_states is 3-dimensional
220 | qkv_states = qkv_states.flatten(start_dim=0, end_dim=-3)
221 |         # flatten all batch dims so position_ids is 2-dimensional
222 | position_ids = position_ids.flatten(start_dim=0, end_dim=-2)
223 |
224 | if not self.multiple_heads:
225 | # add singleton dimension to use attention over faster kernel
226 | qkv_states = qkv_states.unsqueeze_(1)
227 |
228 | query_states, key_states, value_states = torch.chunk(
229 | qkv_states, chunks=3, dim=-1)
230 |
231 | if self.multiple_heads:
232 | total_batch_dim = qkv_states.shape[0]
233 | query_states = query_states.view(
234 | total_batch_dim, num_tokens,
235 | self.num_heads, self.head_dim).transpose(-2, -3)
236 | key_states = key_states.view(
237 | total_batch_dim, num_tokens,
238 | self.num_heads, self.head_dim).transpose(-2, -3)
239 | value_states = value_states.view(
240 | total_batch_dim, num_tokens,
241 | self.num_heads, self.head_dim).transpose(-2, -3)
242 |
243 | if self.use_rope:
244 | cos, sin = self.rotary_emb(value_states, position_ids)
245 | query_states, key_states = apply_rotary_pos_emb(
246 | query_states, key_states, cos, sin)
247 |
248 | if self.apply_causal_mask or attn_mask is not None:
249 | min_dtype = torch.finfo(inputs.dtype).min
250 | if self.apply_causal_mask:
251 | causal_mask = torch.ones(
252 | (num_tokens, num_tokens),
253 | device=key_states.device,
254 | dtype=torch.bool,
255 | )
256 | if self.backward_causal_mask:
257 | # one for all lower diagonal tokens to be masked
258 | causal_mask = torch.tril(causal_mask, diagonal=-1)
259 | else:
260 | # one for all upper diagonal tokens to be masked
261 | causal_mask = torch.triu(causal_mask, diagonal=1)
262 | if attn_mask is not None:
263 | # In case num tokens < size of the attention mask
264 | inv_attn_mask = torch.logical_not(attn_mask[..., -num_tokens:])
265 |
266 | if self.apply_causal_mask:
267 | causal_mask = torch.logical_or(inv_attn_mask, causal_mask)
268 | else:
269 | causal_mask = inv_attn_mask
270 | causal_mask = causal_mask*min_dtype
271 | elif self.apply_causal_mask:
272 | causal_mask = causal_mask*min_dtype
273 | else:
274 | causal_mask = None
275 |
276 | attn_output = F.scaled_dot_product_attention(
277 | query=query_states,
278 | key=key_states,
279 | value=value_states,
280 | attn_mask=causal_mask,
281 | is_causal=False,
282 | scale=None,
283 | )
284 | if self.multiple_heads:
285 | attn_output = attn_output.transpose(-2, -3).contiguous()
286 |
287 | attn_output = attn_output.reshape(
288 | *batch_dims, num_tokens, self.hidden_dim)
289 | else:
290 | attn_output = attn_output.view(
291 | *batch_dims, num_tokens, self.hidden_dim)
292 |
293 | weight, bias = self.parameters_per_layer[1]
294 | attn_output = self.linear_op(
295 | input=attn_output,
296 | weight=weight,
297 | bias=bias,
298 | parallel_operations=self.parallel_operations,
299 | n_parallel_dimensions=n_parallel_dimensions,
300 | )
301 | return attn_output
302 |
303 |
304 | class MonoHeadStatelessAttention(StatelessAttention):
305 |     """Single-head variant of StatelessAttention (num_heads must be 1)"""
306 | # NOTE: hydra specification should have this as partial, since input_dim
307 | # might be unwieldy to always manually specify
308 |
309 | def __init__(self,
310 | attention_params: StatelessAttentionParams,):
311 | # TODO
312 | self.save_configs(attention_params=attention_params)
313 | StatelessAttention.__init__(
314 | self=self, attention_params=attention_params)
315 | assert self.num_heads == 1
316 |
317 | def forward(
318 | self,
319 | inputs: torch.Tensor,
320 | *args,
321 | n_parallel_dimensions: Optional[int] = None,
322 | attn_mask: Optional[torch.Tensor] = None,
323 | position_ids: Optional[torch.LongTensor] = None,
324 | **kwargs,
325 | ) -> torch.Tensor:
326 | *batch_dims, num_tokens, input_dim = inputs.size()
327 |
328 |
329 | weight, bias = self.parameters_per_layer[0]
330 | qkv_states = self.linear_op(
331 | input=inputs,
332 | weight=weight,
333 | bias=bias,
334 | parallel_operations=self.parallel_operations,
335 | # handle reshaping internally
336 | n_parallel_dimensions=n_parallel_dimensions,
337 | )
338 |
339 |         # flatten all batch dims so qkv_states is 3-dimensional
340 | qkv_states = qkv_states.flatten(start_dim=0, end_dim=-3)
341 |         # flatten all batch dims so position_ids is 2-dimensional
342 | position_ids = position_ids.flatten(start_dim=0, end_dim=-2)
343 |
344 | # total_batch_dim x num_tokens x head_dim (single head)
345 | query_states, key_states, value_states = torch.chunk(
346 | qkv_states, chunks=3, dim=-1)
347 |
348 | if self.use_rope:
349 | cos, sin = self.rotary_emb(value_states, position_ids)
350 | query_states, key_states = apply_rotary_pos_emb(
351 | query_states, key_states, cos, sin)
352 |
353 | if self.apply_causal_mask:
354 | causal_mask = torch.ones(
355 | (num_tokens, num_tokens),
356 | device=key_states.device,
357 | dtype=torch.bool,
358 | )
359 | if self.backward_causal_mask:
360 | # one for all lower diagonal tokens to be masked
361 | causal_mask = torch.tril(causal_mask, diagonal=-1)
362 | else:
363 | # one for all upper diagonal tokens to be masked
364 | causal_mask = torch.triu(causal_mask, diagonal=1)
365 | if attn_mask is not None:
366 | # In case num tokens < size of the attention mask
367 | inv_attn_mask = torch.logical_not(attn_mask[..., -num_tokens:])
368 |
369 | if self.apply_causal_mask:
370 | causal_mask = torch.logical_or(inv_attn_mask, causal_mask)
371 | else:
372 | causal_mask = inv_attn_mask
373 |
374 | min_dtype = torch.finfo(inputs.dtype).min
375 | causal_mask = causal_mask*min_dtype
376 | elif self.apply_causal_mask:
377 |             causal_mask = causal_mask*torch.finfo(inputs.dtype).min
378 | else:
379 | causal_mask = None
380 |
381 | attn_output = F.scaled_dot_product_attention(
382 | query=query_states,
383 | key=key_states,
384 | value=value_states,
385 | attn_mask=causal_mask,
386 | is_causal=False,
387 |             scale=None,  # defaults to 1/sqrt(head_dim)
388 | )
389 |
390 | attn_output = attn_output.view(
391 | *batch_dims, num_tokens, self.hidden_dim)
392 |
393 | weight, bias = self.parameters_per_layer[1]
394 | attn_output = self.linear_op(
395 | input=attn_output,
396 | weight=weight,
397 | bias=bias,
398 | parallel_operations=self.parallel_operations,
399 | n_parallel_dimensions=n_parallel_dimensions,
400 | )
401 |
402 | return attn_output
403 |
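# Design note: this single-head path skips the per-head view/transpose of the
# parent class; for num_heads == 1 the two classes are intended to be
# numerically equivalent.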
--------------------------------------------------------------------------------
/memory_policy/base_dynamic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | import copy
4 | import math
5 | import numpy as np
6 | from dataclasses import dataclass
7 | from typing import Optional, Tuple, Union, Dict, List
8 |
9 | import abc
10 | import torch
11 | from torch import nn
12 | import torch.utils.checkpoint
13 | import torch.nn.functional as F
14 | from torch.cuda.amp import autocast
15 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
16 | from transformers import LlamaPreTrainedModel
17 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
18 | from .base import MemoryPolicy, ParamMemoryPolicy
19 |
20 |
21 | def compute_recency(position_ids, start_recency_from=1):
22 | most_recent_positions = position_ids[..., [-1]]
23 |
24 | cache_age = most_recent_positions - position_ids + start_recency_from
25 | return cache_age
26 |
27 | def compute_recency_scores(position_ids, recency_exp, recency_coeff,):
28 |     # score = (1 / cache_age) ** recency_exp; returns (scaled, raw) scores
29 | cache_age = compute_recency(position_ids=position_ids, start_recency_from=1)
30 | recency_scores = torch.pow(1/cache_age, exponent=recency_exp)
31 | return recency_scores*recency_coeff, recency_scores
32 |
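# Worked example (assumed values): position_ids [0, 1, 2, 3] give cache ages
# [4, 3, 2, 1]; with recency_exp=1.0 and recency_coeff=0.5 the weighted scores
# are 0.5 * [1/4, 1/3, 1/2, 1/1] = [0.125, 0.167, 0.25, 0.5], favoring the
# newest entries.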
33 | def threshold_score_idxs(
34 | masked_full_scores,
35 | dynamic_thresh,
36 |
37 |
38 | preserve_order=True,
39 | cache_size=None,
40 | ):
41 |     '''Keep cache entries whose scores clear dynamic_thresh (optionally
42 |     capped at cache_size); returns the retained indices plus a per-row
43 |     mask of which retained slots are actually above threshold.'''
44 | num_samples = masked_full_scores.shape[-1]
45 | if cache_size is not None and num_samples > cache_size:
46 | sorted_scores, indices = torch.topk(
47 | masked_full_scores, k=cache_size, sorted=True, dim=-1)
48 | sorted_scores = sorted_scores.flip(dims=(-1,))
49 | indices = indices.flip(dims=(-1,))
50 | else:
51 | sorted_scores, indices = torch.sort(
52 | masked_full_scores, descending=False, dim=-1)
53 |
54 |
55 |
56 | thresholded_scores = sorted_scores >= dynamic_thresh
57 |
58 |
59 | first_above_thresh = torch.sum(~thresholded_scores, dim=-1,
60 | dtype=torch.long)
61 |
62 | discard_idx = torch.min(first_above_thresh)
63 |     # only positions that every row discards can be physically dropped;
64 |     # rows that keep more entries rely on the returned mask instead
65 |
66 | retained_idxs = indices[..., discard_idx:]
67 |
68 |
69 | new_mask = thresholded_scores[..., discard_idx:]
70 |
71 | if preserve_order:
72 |
73 | retained_idxs, _ = retained_idxs.sort(descending=False, dim=-1,)
74 | return retained_idxs, new_mask
75 |
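# Worked example (assumed values): scores [[0.9, 0.1, 0.4, 0.7]] with
# dynamic_thresh=0.3 sort ascending to [0.1, 0.4, 0.7, 0.9]; only one entry
# falls below the threshold, so discard_idx=1 and retained_idxs=[[2, 3, 0]]
# (re-sorted to [[0, 2, 3]] when preserve_order=True), with new_mask all True.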
76 | @dataclass
77 | class RecencyParams:
78 | recency_coeff: float
79 | recency_exp: float
80 |
81 | @dataclass
82 | class AttentionParams:
83 | attn_coeff: float
84 | attn_ema_coeff: float
85 |
86 | back_attn_coeff: float = 0
87 |
88 |
89 |
90 |
91 |
92 | class DynamicMemoryPolicy(MemoryPolicy):
93 |     '''Base class for memory policies whose cache content is selected
94 |     dynamically from per-token scores (e.g. recency and attention).'''
95 | def __init__(self,
96 | cache_size: Optional[int] = None,
97 | init_module: bool = True,
98 | ):
99 |
100 | MemoryPolicy.__init__(
101 | self,
102 | cache_size=cache_size,
103 | init_module=init_module,
104 | )
105 |
106 | self._record_mask_based_sparsity = False
107 | self._record_stats_per_head = False
108 | self._record_recency_stats = False
109 |
110 | @property
111 | def record_mask_based_sparsity(self,):
112 | return self._record_mask_based_sparsity
113 |
114 | @property
115 | def record_stats_per_head(self,):
116 |         return self._record_stats_per_head
117 |
118 | @record_mask_based_sparsity.setter
119 | def record_mask_based_sparsity(self, value):
120 | self._record_mask_based_sparsity = value
121 |
122 | @record_stats_per_head.setter
123 | def record_stats_per_head(self, value):
124 | self._record_stats_per_head = value
125 |
126 | def is_dynamic(self,):
127 | return True
128 |
129 | def select_max_score_idxs(
130 | self,
131 | masked_full_scores,
132 | cache_size,
133 |
134 |
135 | preserve_order=True,
136 | ):
137 |
138 |
139 |
140 |
141 |
142 | sorted_top_scores, retained_idxs = torch.topk(
143 | input=masked_full_scores, k=cache_size, largest=True,
144 | sorted=False,
145 | )
146 |
147 | if preserve_order:
148 |
149 | retained_idxs, _ = retained_idxs.sort(descending=False, dim=-1,)
150 |
151 | new_mask = torch.ones_like(retained_idxs)
152 |
153 | return retained_idxs, new_mask
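    # Example (assumed values): scores [[0.2, 0.8, 0.5]] with cache_size=2
    # keep indices [[1, 2]] after re-sorting; new_mask is all ones, since
    # top-k selection never masks the retained slots.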
154 |
155 | def threshold_score_idxs(
156 | self,
157 | masked_full_scores,
158 | dynamic_thresh,
159 |
160 |
161 | preserve_order=True,
162 | cache_size=None,
163 | ):
164 |
165 | retained_idxs, new_mask = threshold_score_idxs(
166 | masked_full_scores=masked_full_scores,
167 | dynamic_thresh=dynamic_thresh,
168 | preserve_order=preserve_order,
169 | cache_size=cache_size,
170 | )
171 | return retained_idxs, new_mask
172 |
173 | def select_new_dynamic_idxs(
174 | self,
175 | masked_full_scores,
176 | dynamic_thresh,
177 | cache_size,
178 | preserve_order=True,
179 | ):
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 | return self.threshold_score_idxs(
188 | masked_full_scores=masked_full_scores,
189 | dynamic_thresh=dynamic_thresh,
190 | preserve_order=preserve_order,
191 | cache_size=cache_size,
192 | )
193 |
194 | def need_new_dynamic_idxs(
195 | self,
196 | masked_full_scores,
197 | cache_size,
198 | ):
199 | if cache_size is None:
200 | return True
201 | else:
202 | num_samples = masked_full_scores.shape[-1]
203 | if num_samples > cache_size:
204 | return True
205 | else:
206 | return False
207 |
208 |
209 |
210 | def save_recency_params(self, recency_params: RecencyParams):
211 | self.recency_coeff = recency_params.recency_coeff
212 | self.recency_exp = recency_params.recency_exp
213 | self.use_recency_scores = False
214 | if self.recency_coeff is not None:
215 | if self.recency_coeff > 0:
216 | self.use_recency_scores = True
217 |
218 | print(f'Using recency score: {self.use_recency_scores}')
219 |
220 |
221 | def save_attention_params(self, attention_params: AttentionParams):
222 |
223 | self.attn_coeff = attention_params.attn_coeff
224 | self.attn_ema_coeff = attention_params.attn_ema_coeff
225 |
226 | self.use_full_attn_scores = False
227 | if self.attn_coeff is not None:
228 | if self.attn_coeff > 0:
229 | self.use_full_attn_scores = True
230 |
231 |
232 |
233 |
234 | self.back_attn_coeff = attention_params.back_attn_coeff
235 |
236 |
237 | self.use_back_attn_scores = False
238 |
239 | if self.back_attn_coeff is not None:
240 | if self.back_attn_coeff > 0:
241 | self.use_back_attn_scores = True
242 |
243 |
244 |
245 |
246 |
247 | self.use_attention_scores = (self.use_full_attn_scores or
248 | self.use_back_attn_scores)
249 |         # true only when forward and backward attention scores are both
250 |         # enabled
251 |         self.use_both_attn_scores = (self.use_full_attn_scores and
252 |                                      self.use_back_attn_scores)
253 |
254 |
255 | print(f'Using attention score: {self.use_attention_scores}')
256 |
257 |
258 |
259 | def process_position_ids(self, position_ids, num_all_tokens, num_new_tokens,
260 | attention_mask):
261 | '''Replicates position ids for each head'''
262 | if position_ids is None:
263 | assert num_all_tokens == num_new_tokens
264 | assert attention_mask is not None
265 |
266 | position_ids = torch.cumsum(attention_mask, dim=-1) - 1
267 | position_ids = position_ids.unsqueeze(-2)
268 | return position_ids.expand(-1, self.num_heads, -1)
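    # Example (assumed values): attention_mask [[0, 1, 1, 1]] yields
    # position_ids [[-1, 0, 1, 2]] (padding maps to -1), broadcast to shape
    # (batch, num_heads, num_tokens).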
269 |
270 | def compute_recency_scores(self, position_ids, recency_exp, recency_coeff):
271 | return compute_recency_scores(
272 | position_ids=position_ids,
273 | recency_exp=recency_exp,
274 | recency_coeff=recency_coeff,
275 | )
276 |
277 | def initialize_cache_masks(self,):
278 | self.cache_masks = [None for _ in range(self.num_memory_layers)]
279 |
280 | def finalize_registration(self,):
281 | MemoryPolicy.finalize_registration(self,)
282 | self.initialize_cache_masks()
283 | self.initialize_stat_objects()
284 |
285 |
286 |     def initialize_stat_objects(self, initialize_mask_sparsity=True):
287 | self.dynamic_cache_sizes = [[] for _ in range(
288 | self.num_memory_layers)]
289 | self.final_dynamic_cache_sizes = [[] for _ in range(
290 | self.num_memory_layers)]
291 |
292 |         if initialize_mask_sparsity:
293 | self.initialize_mask_based_sparsity()
294 | if self._record_recency_stats:
295 | self.initialize_recency_stats()
296 |
297 | def initialize_mask_based_sparsity(self, ):
298 | self.dynamic_mask_sample_sparsity = [[] for _ in range(
299 | self.num_memory_layers)]
300 | self.dynamic_mask_head_sparsity = [[] for _ in range(
301 | self.num_memory_layers)]
302 | if self.record_stats_per_head:
303 | raise NotImplementedError
304 | self.dynamic_mask_head_sparsity_dicts = [{}]
305 |
306 | def initialize_recency_stats(self,):
307 | self.recorded_final_recencies = [[] for _ in range(
308 | self.num_memory_layers)]
309 | self._record_recency_stats = True
310 |
311 | def record_recency_stats(self, layer_id, position_ids):
312 | if position_ids is not None:
313 | recency = compute_recency(position_ids=position_ids)
314 | self.recorded_final_recencies[layer_id].append(
315 | recency.float().mean().item())
316 |
317 | def get_param_stats(self, reset=True) -> dict:
318 | stats = dict()
319 | if self.record_eval_stats:
320 | if len(self.dynamic_cache_sizes[0]):
321 | all_final_cache_sizes = []
322 | for i in range(self.num_memory_layers):
323 | stats_key_prefix = f'mem_stats/layer_id_{i}/'
324 | stats[stats_key_prefix + 'dynamic_cache_sizes'] = np.mean(
325 | self.dynamic_cache_sizes[i])
326 | final_cache_sizes = (self.final_dynamic_cache_sizes[i] +
327 | [self.dynamic_cache_sizes[i][-1]])
328 | all_final_cache_sizes += final_cache_sizes
329 | stats[stats_key_prefix + 'final_dynamic_cache_sizes'] = (
330 | np.mean(final_cache_sizes))
331 |
332 | stats_key_prefix = 'mem_stats/overall/'
333 | stats[stats_key_prefix + 'dynamic_cache_sizes'] = np.mean(
334 | [v for vs in self.dynamic_cache_sizes for v in vs])
335 |             # average of the final cache size reached on each processed
336 |             # sequence, pooled across all layers
337 |             stats[stats_key_prefix + 'final_dynamic_cache_sizes'] = (
338 |                 np.mean(all_final_cache_sizes))
339 | if self.record_eval_stats or self.record_mask_based_sparsity:
340 | if len(self.dynamic_mask_sample_sparsity[0]):
341 | for i in range(self.num_memory_layers):
342 | stats_key_prefix = f'mem_stats/layer_id_{i}/'
343 | stats[stats_key_prefix + 'unmasked_samples'] = np.mean(
344 | self.dynamic_mask_sample_sparsity[i])
345 | stats[stats_key_prefix + 'unmasked_samples_per_head'] = (
346 | np.mean(self.dynamic_mask_head_sparsity[i]))
347 | stats[stats_key_prefix + 'unmasked_sample_final'] = np.mean(
348 | self.dynamic_mask_sample_sparsity[i][-1])
349 | stats[stats_key_prefix +
350 | 'unmasked_sample_per_head_final'] = np.mean(
351 | self.dynamic_mask_head_sparsity[i][-1])
352 |
353 |
354 |
355 | stats_key_prefix = 'mem_stats/overall/'
356 | stats[stats_key_prefix + 'unmasked_samples'] = np.mean(
357 | [v for vs in self.dynamic_mask_sample_sparsity for v in vs])
358 | stats[stats_key_prefix + 'unmasked_samples_per_head'] = np.mean(
359 | [v for vs in self.dynamic_mask_head_sparsity for v in vs])
360 | stats[stats_key_prefix + 'unmasked_sample_final'] = np.mean(
361 | [cs[-1] for cs in self.dynamic_mask_sample_sparsity])
362 | stats[stats_key_prefix + 'unmasked_sample_per_head_final'] = (
363 | np.mean([cs[-1] for cs in self.dynamic_mask_head_sparsity]))
364 |
365 |
366 |
367 | if self._record_recency_stats:
368 | for i in range(self.num_memory_layers):
369 | stats_key_prefix = f'mem_stats/layer_id_{i}/'
370 | stats[stats_key_prefix + 'final_recencies'] = np.mean(
371 | self.recorded_final_recencies[i])
372 | stats_key_prefix = 'mem_stats/overall/'
373 | stats[stats_key_prefix + 'final_recencies'] = np.mean(
374 | [v for vs in self.recorded_final_recencies for v in vs])
375 |
376 | if reset:
377 | self.initialize_stat_objects()
378 | return stats
379 |
380 | def record_dynamic_stats(self, layer_id, cache_size, new_sequences=False):
381 | if new_sequences and len(self.dynamic_cache_sizes[layer_id]) > 0:
382 | self.final_dynamic_cache_sizes[layer_id].append(
383 | self.dynamic_cache_sizes[layer_id][-1])
384 | self.dynamic_cache_sizes[layer_id].append(int(cache_size))
385 |
386 | def record_mask_dynamic_stats(
387 | self, layer_id,
388 | cache_mask,
389 | ):
390 |
391 | unmasked_samples_per_head = cache_mask.to(torch.float32).sum(-1)
392 |
393 | self.dynamic_mask_sample_sparsity[layer_id].append(
394 | torch.max(unmasked_samples_per_head, dim=-1)[0].mean().item())
395 | self.dynamic_mask_head_sparsity[layer_id].append(
396 | unmasked_samples_per_head.mean().item())
397 |
398 | class DynamicParamMemoryPolicy(ParamMemoryPolicy, DynamicMemoryPolicy):
399 |
400 | def __init__(
401 | self,
402 | base_param_size,
403 | pop_size,
404 | per_head,
405 | per_layer,
406 | additional_shared_params=0,
407 | learnable_params: Optional[Dict[str, Union[str, tuple]]] = None,
408 | learned_params: Optional[Dict[str, bool]] = None,
409 | component_names: Optional[List[str]] = None,
410 | cache_size: Optional[int] = None,
411 | init_module: bool = True,
412 | lazy_param_num: bool = False,
413 | ):
414 |
415 |
416 |
417 |
418 |
419 |
420 |
421 |
422 | ParamMemoryPolicy.__init__(
423 | self, cache_size=cache_size, base_param_size=base_param_size,
424 | pop_size=pop_size, per_head=per_head, per_layer=per_layer,
425 | additional_shared_params=additional_shared_params,
426 | learnable_params=learnable_params,
427 | learned_params=learned_params,
428 | component_names=component_names,
429 | init_module=init_module,
430 | lazy_param_num=lazy_param_num,
431 | )
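        # Design note: evolvable-parameter bookkeeping (population size,
        # per-head/per-layer parameter shapes) comes from ParamMemoryPolicy,
        # while the dynamic cache masks and statistics are inherited from
        # DynamicMemoryPolicy via finalize_registration below.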
432 |
433 | def is_dynamic(self,):
434 | return True
435 |
436 | def finalize_registration(self,):
437 |
438 | ParamMemoryPolicy.finalize_registration(self,)
439 | self.initialize_cache_masks()
440 | self.initialize_stat_objects()
441 |
442 |
443 |
--------------------------------------------------------------------------------