├── cfgs ├── run │ ├── default.yaml │ ├── base_memory_policy │ │ ├── none.yaml │ │ └── deep │ │ │ └── bam │ │ │ └── base_bam.yaml │ ├── namm_bam_i1.yaml │ ├── namm_bam_i2.yaml │ ├── namm_bam_i3.yaml │ ├── namm_bam_eval.yaml │ └── namm_bam_eval_choubun.yaml ├── auxiliary_loss │ └── none.yaml ├── typing │ ├── default.yaml │ └── half_attn.yaml ├── policy │ ├── deep_scoring │ │ ├── bam.yaml │ │ ├── mlp.yaml │ │ └── attn.yaml │ ├── none.yaml │ ├── deep_selection │ │ ├── topk.yaml │ │ ├── dynamic.yaml │ │ └── binary.yaml │ ├── deep.yaml │ └── deep_embedding │ │ ├── norm_base.yaml │ │ ├── attn_spec_norm.yaml │ │ └── attn_spec_base.yaml ├── task │ ├── lb_2subset.yaml │ ├── passage_retrieval_en.yaml │ ├── lb_3subset_incr.yaml │ ├── choubun_full.yaml │ ├── base_sampler.yaml │ └── lb_full.yaml ├── trainer │ ├── eval.yaml │ └── default.yaml ├── evolution │ ├── dummy.yaml │ └── cma_es.yaml ├── model │ ├── wrapped_llm │ │ ├── llama3-8b-rope-x4NTK.yaml │ │ └── base.yaml │ └── hf_evaluator.yaml ├── config_run_eval.yaml └── config.yaml ├── figures └── logo.png ├── memory_evolution ├── __init__.py └── base.py ├── memory_llms ├── __init__.py └── base.py ├── ChouBun └── config │ ├── dataset2maxlen.json │ └── dataset2prompt.json ├── stateless_parallel_modules ├── __init__.py ├── mlp.py ├── base.py └── attention.py ├── LongBench └── config │ ├── model2maxlen.json │ ├── dataset2maxlen.json │ ├── model2path.json │ └── dataset2prompt.json ├── env_minimal.yaml ├── memory_policy ├── __init__.py ├── deep_embedding_shared.py ├── deep_scoring_bam.py ├── deep_embedding.py ├── deep_selection.py ├── shared.py ├── deep_embedding_spectogram.py ├── auxiliary_losses.py └── base_dynamic.py ├── utils_log.py ├── choubun_metrics.py ├── utils_hydra.py ├── README.md ├── longbench_metrics.py ├── main.py ├── utils_longbench.py └── env.yaml /cfgs/run/default.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cfgs/auxiliary_loss/none.yaml: -------------------------------------------------------------------------------- 1 | auxiliary_loss: 2 | 3 | 4 | -------------------------------------------------------------------------------- /figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SakanaAI/evo-memory/HEAD/figures/logo.png -------------------------------------------------------------------------------- /memory_evolution/__init__.py: -------------------------------------------------------------------------------- 1 | from .cma_es import CMA_ES 2 | from .base import MemoryEvolution, DummyEvolution 3 | -------------------------------------------------------------------------------- /memory_llms/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama import WrappedLlamaForCausalLM 2 | from .base import MemoryModelWrapper 3 | -------------------------------------------------------------------------------- /cfgs/typing/default.yaml: -------------------------------------------------------------------------------- 1 | # main model dtype 2 | dtype: 'bfloat16' 3 | output_attentions_full_precision: true 4 | 5 | dtype_conf_name: '' -------------------------------------------------------------------------------- /ChouBun/config/dataset2maxlen.json: -------------------------------------------------------------------------------- 1 | { 2 | "wiki_qa": 64, 3 | "edinet_qa": 64, 4 | "corp_sec_qa": 64, 
5 | "corp_sec_sum": 128 6 | } -------------------------------------------------------------------------------- /cfgs/typing/half_attn.yaml: -------------------------------------------------------------------------------- 1 | # main model dtype 2 | dtype: 'bfloat16' 3 | output_attentions_full_precision: false 4 | 5 | dtype_conf_name: 'halfAttn' -------------------------------------------------------------------------------- /cfgs/policy/deep_scoring/bam.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - attn 3 | - _self_ 4 | 5 | mult: true 6 | mult_nonlinearity: 7 | # -- 8 | 9 | scoring_mp_log_name: bam -------------------------------------------------------------------------------- /cfgs/policy/none.yaml: -------------------------------------------------------------------------------- 1 | memory_policy: 2 | _target_: memory_policy.Recency 3 | cache_size: ${cache_size} 4 | 5 | cache_size: null 6 | 7 | 8 | mp_log_name: dummy -------------------------------------------------------------------------------- /cfgs/task/lb_2subset.yaml: -------------------------------------------------------------------------------- 1 | # TODO 2 | 3 | defaults: 4 | - base_sampler 5 | - _self_ 6 | 7 | tasks: ["lb/passage_retrieval_en", "lb/dureader"] 8 | metrics: 'perf' 9 | 10 | tasks_log_name: lb2subset -------------------------------------------------------------------------------- /cfgs/task/passage_retrieval_en.yaml: -------------------------------------------------------------------------------- 1 | # TODO 2 | 3 | defaults: 4 | - base_sampler 5 | - _self_ 6 | 7 | tasks: ["lb/passage_retrieval_en"] 8 | metrics: 'perf' 9 | 10 | tasks_log_name: passage-retrieval-en -------------------------------------------------------------------------------- /cfgs/task/lb_3subset_incr.yaml: -------------------------------------------------------------------------------- 1 | # TODO 2 | 3 | defaults: 4 | - base_sampler 5 | - _self_ 6 | 7 | tasks: ["lb/passage_retrieval_en", "lb/dureader", "lb/narrativeqa"] 8 | metrics: 'perf' 9 | 10 | tasks_log_name: lb3subset-i -------------------------------------------------------------------------------- /cfgs/policy/deep_selection/topk.yaml: -------------------------------------------------------------------------------- 1 | selection_criteria: ${topk_selection} 2 | 3 | topk_selection: 4 | _target_: memory_policy.TopKSelection 5 | cache_size: ${cache_size} 6 | 7 | cache_size: 8192 8 | 9 | selection_mp_log_name: topk-${cache_size}cs -------------------------------------------------------------------------------- /cfgs/task/choubun_full.yaml: -------------------------------------------------------------------------------- 1 | # TODO 2 | 3 | defaults: 4 | - base_sampler 5 | - _self_ 6 | 7 | tasks: 8 | - "choubun/wiki_qa" 9 | - "choubun/edinet_qa" 10 | - "choubun/corp_sec_qa" 11 | - "choubun/corp_sec_sum" 12 | metrics: 'perf' 13 | -------------------------------------------------------------------------------- /stateless_parallel_modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import StatelessGeneralizedMLP 2 | from .attention import StatelessAttention, StatelessAttentionParams 3 | 4 | from .base import ( 5 | StatelessGeneralizedOperation, GeneralizedLinear, get_nonlinearity, 6 | StatelessGeneralizedModule) 7 | -------------------------------------------------------------------------------- /cfgs/task/base_sampler.yaml: 
-------------------------------------------------------------------------------- 1 | task_sampler: 2 | _target_: task_sampler.TaskSampler 3 | tasks: ${tasks} 4 | metrics: ${metrics} 5 | training_tasks_subset: ${training_tasks_subset} 6 | test_tasks_subset: ${test_tasks_subset} 7 | 8 | 9 | tasks: ['sciq'] 10 | metrics: ['acc'] 11 | training_tasks_subset: 12 | test_tasks_subset: 13 | 14 | tasks_log_name: -------------------------------------------------------------------------------- /cfgs/trainer/eval.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | max_iters: 0 6 | eval_only: true 7 | always_save_checkpoint: true 8 | scratch: false 9 | record_advanced_eval_stats: true 10 | eval_samples_batch_size: 11 | allow_distributed_eval: true 12 | 13 | score_aggregation: 'mean' 14 | score_normalization_reference: 15 | store_eval_results_locally: true -------------------------------------------------------------------------------- /cfgs/evolution/dummy.yaml: -------------------------------------------------------------------------------- 1 | evolution_algorithm: 2 | _target_: memory_evolution.base.DummyEvolution 3 | pop_size: ${pop_size} 4 | param_size: ${param_size} 5 | param_clip: ${param_clip} 6 | score_processing: ${score_processing} 7 | 8 | pop_size: 1 9 | param_size: 10 | param_clip: 11 | score_processing: 12 | 13 | evolution_algorithm_name: dummy 14 | 15 | evolution_algorithm_log_name: ${evolution_algorithm_name}-p${pop_size} -------------------------------------------------------------------------------- /cfgs/policy/deep_selection/dynamic.yaml: -------------------------------------------------------------------------------- 1 | selection_criteria: ${dynamic_selection} 2 | 3 | dynamic_selection: 4 | _target_: memory_policy.DynamicSelection 5 | per_layer: ${per_layer} 6 | per_head: ${per_head} 7 | shared: ${selection_shared} 8 | cache_size: ${cache_size} 9 | dynamic_thresh: ${dynamic_thresh} 10 | 11 | selection_shared: true 12 | cache_size: 13 | # 1/2048 14 | dynamic_thresh: 0.00048828125 15 | 16 | selection_mp_log_name: dynamic-${cache_size}cs -------------------------------------------------------------------------------- /cfgs/policy/deep_selection/binary.yaml: -------------------------------------------------------------------------------- 1 | selection_criteria: ${dynamic_selection} 2 | 3 | dynamic_selection: 4 | _target_: memory_policy.BinarySelection 5 | per_layer: ${per_layer} 6 | per_head: ${per_head} 7 | shared: ${selection_shared} 8 | cache_size: ${cache_size} 9 | is_probabilistic: ${is_probabilistic} 10 | temp: ${temp} 11 | learned_temp: ${learned_temp} 12 | 13 | 14 | is_probabilistic: false 15 | temp: 1.0 16 | learned_temp: false 17 | 18 | selection_shared: true 19 | cache_size: 20 | 21 | selection_mp_log_name: binary-${cache_size}cs -------------------------------------------------------------------------------- /LongBench/config/model2maxlen.json: -------------------------------------------------------------------------------- 1 | { 2 | "llama2-7b-chat-4k": 3500, 3 | "longchat-v1.5-7b-32k": 31500, 4 | "xgen-7b-8k": 7500, 5 | "internlm-7b-8k": 7500, 6 | "chatglm2-6b": 31500, 7 | "chatglm2-6b-32k": 31500, 8 | "chatglm3-6b-32k": 31500, 9 | "vicuna-v1.5-7b-16k": 15500, 10 | "Meta-Llama-3-8B": 7500, 11 | "Meta-Llama-3-8B-x4NTK": 31500, 12 | "Meta-Llama-3-8B-x4Linear": 31500, 13 | "Mixtral-8x22B-v0.1": 63500, 14 | "Mistral-7B-Instruct-v0.2": 31500, 15 | "Mistral-NoChatPrompt-7B-Instruct-v0.2": 31500 
16 | } 17 | -------------------------------------------------------------------------------- /ChouBun/config/dataset2prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "wiki_qa": "以下の文書を注意深く読み、続く質問に答えてください。回答は必ず文書内の連続した文字列(テキストスパン)から抽出してください。\n### 文書 ###\n{context}\n### 質問 ###\n{input}\n### 回答 ###\n", 3 | "edinet_qa": "以下の文書を注意深く読み、続く質問に答えてください。回答は必ず文書内の連続した文字列(テキストスパン)から抽出してください。\n### 文書 ###\n{context}\n### 質問 ###\n{input}\n### 回答 ###\n", 4 | "corp_sec_qa": "以下の文書を注意深く読み、続く質問に答えてください。回答は必ず文書内の連続した文字列(テキストスパン)から抽出してください。### 文書 ###\n{context}\n### 質問 ###\n{input}\n### 回答 ###\n", 5 | "corp_sec_sum": "以下の文書を要約してください。要約は短く簡潔にしてください。また、文書の全体的な考え方、傾向、洞察を反映させてください。\n### 文書 ###\n{context}\n### 要約 ###\n" 6 | } -------------------------------------------------------------------------------- /LongBench/config/dataset2maxlen.json: -------------------------------------------------------------------------------- 1 | { 2 | "narrativeqa": 128, 3 | "qasper": 128, 4 | "multifieldqa_en": 64, 5 | "multifieldqa_zh": 64, 6 | "hotpotqa": 32, 7 | "2wikimqa": 32, 8 | "musique": 32, 9 | "dureader": 128, 10 | "gov_report": 512, 11 | "qmsum": 512, 12 | "multi_news": 512, 13 | "vcsum": 512, 14 | "trec": 64, 15 | "triviaqa": 32, 16 | "samsum": 128, 17 | "lsht": 64, 18 | "passage_count": 32, 19 | "passage_retrieval_en": 32, 20 | "passage_retrieval_zh": 32, 21 | "lcc": 64, 22 | "repobench-p": 64 23 | } -------------------------------------------------------------------------------- /cfgs/task/lb_full.yaml: -------------------------------------------------------------------------------- 1 | # TODO 2 | 3 | defaults: 4 | - base_sampler 5 | - _self_ 6 | 7 | tasks: 8 | - "lb/narrativeqa" 9 | - "lb/qasper" 10 | - "lb/multifieldqa_en" 11 | - "lb/multifieldqa_zh" 12 | - "lb/hotpotqa" 13 | - "lb/2wikimqa" 14 | - "lb/musique" 15 | - "lb/dureader" 16 | - "lb/gov_report" 17 | - "lb/qmsum" 18 | - "lb/multi_news" 19 | - "lb/vcsum" 20 | - "lb/trec" 21 | - "lb/triviaqa" 22 | - "lb/samsum" 23 | - "lb/lsht" 24 | - "lb/passage_count" 25 | - "lb/passage_retrieval_en" 26 | - "lb/passage_retrieval_zh" 27 | - "lb/lcc" 28 | - "lb/repobench-p" 29 | metrics: 'perf' 30 | -------------------------------------------------------------------------------- /cfgs/policy/deep.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - deep_embedding@_global_: joint_recency_attn_spec 4 | - deep_scoring@_global_: mlp 5 | - deep_selection@_global_: dynamic 6 | 7 | 8 | memory_policy: 9 | _target_: memory_policy.DeepMP 10 | pop_size: ${pop_size} 11 | per_head: ${per_head} 12 | per_layer: ${per_layer} 13 | token_embedding: ${token_embedding} 14 | scoring_network: ${scoring_network} 15 | selection_criteria: ${selection_criteria} 16 | lazy_param_num: ${lazy_param_num} 17 | 18 | per_head: false 19 | per_layer: true 20 | lazy_param_num: false 21 | 22 | mp_log_name: NAMM/${embedding_mp_log_name}/${scoring_mp_log_name}/${selection_mp_log_name} -------------------------------------------------------------------------------- /env_minimal.yaml: -------------------------------------------------------------------------------- 1 | name: th2 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python=3.10 9 | - pytorch 10 | - pytorch-cuda=12.1 11 | - torchvision 12 | - torchaudio 13 | - ipykernel 14 | - matplotlib 15 | - ipywidgets 16 | - pip: 17 | - vllm 18 | - numpy 19 | - transformers==4.41.2 20 | - accelerate 21 | 
- datasets 22 | - tiktoken 23 | - wandb 24 | - tqdm 25 | - hydra-core 26 | - crfm-helm 27 | - lm-eval==0.4.2 28 | - fugashi==1.3.2 29 | - ftfy 30 | - bitsandbytes 31 | - rouge 32 | - jieba 33 | - fuzzywuzzy 34 | - einops 35 | - scipy==1.13.0 36 | - seaborn 37 | -------------------------------------------------------------------------------- /cfgs/model/wrapped_llm/llama3-8b-rope-x4NTK.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - _self_ 4 | 5 | pretrained_llm_name: meta-llama/Meta-Llama-3-8B 6 | 7 | memory_model: 8 | _target_: memory_llms.llama.WrappedLlamaForCausalLM 9 | model: ${pretrained_llm} 10 | memory_policy: ${memory_policy} 11 | max_new_tokens: ${max_new_tokens} 12 | 13 | pretrained_llm: 14 | _target_: utils_hydra.AutoModelForCausalLM.from_pretrained 15 | pretrained_model_name_or_path: ${pretrained_llm_name} 16 | rope_scaling: ${rope_scaling} 17 | 18 | max_new_tokens: 19 | max_position_id: 32768 20 | 21 | rope_scaling: 22 | type: dynamic 23 | override: 'ntk' 24 | factor: 4 25 | alpha: 2 26 | 27 | 28 | llm_log_name: Meta-Llama-3-8B-x4NTK -------------------------------------------------------------------------------- /cfgs/policy/deep_embedding/norm_base.yaml: -------------------------------------------------------------------------------- 1 | emb_online_output_normalization: true 2 | emb_update_norm_during_training: false 3 | emb_update_norm_during_eval: true 4 | 5 | synchronized_buffers_freeze_after: 1 6 | 7 | embedding_output_params: 8 | _target_: memory_policy.ComponentOutputParams 9 | requires_recomputation: ${embedding_requires_recomputation} 10 | reduction_mode: ${embedding_reduction_mode} 11 | ema_params: ${embedding_ema_params} 12 | output_past_non_reduced_history: ${embedding_output_past_non_reduced_history} 13 | max_non_reduced_history_len: ${embedding_max_non_reduced_history_len} 14 | online_output_normalization: ${emb_online_output_normalization} 15 | update_norm_during_training: ${emb_update_norm_during_training} 16 | update_norm_during_eval: ${emb_update_norm_during_eval} 17 | -------------------------------------------------------------------------------- /LongBench/config/model2path.json: -------------------------------------------------------------------------------- 1 | { 2 | "llama2-7b-chat-4k": "meta-llama/Llama-2-7b-chat-hf", 3 | "longchat-v1.5-7b-32k": "lmsys/longchat-7b-v1.5-32k", 4 | "xgen-7b-8k": "Salesforce/xgen-7b-8k-inst", 5 | "internlm-7b-8k": "internlm/internlm-chat-7b-8k", 6 | "chatglm2-6b": "THUDM/chatglm2-6b", 7 | "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k", 8 | "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k", 9 | "vicuna-v1.5-7b-16k": "lmsys/vicuna-7b-v1.5-16k", 10 | "Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B", 11 | "Meta-Llama-3-8B-x4NTK": "meta-llama/Meta-Llama-3-8B", 12 | "Meta-Llama-3-8B-x4Linear": "meta-llama/Meta-Llama-3-8B", 13 | "Mixtral-8x22B-v0.1": "mistralai/Mixtral-8x22B-v0.1", 14 | "Mistral-7B-Instruct-v0.2": "mistralai/Mistral-7B-Instruct-v0.2", 15 | "Mistral-NoChatPrompt-7B-Instruct-v0.2": "mistralai/Mistral-7B-Instruct-v0.2" 16 | } 17 | -------------------------------------------------------------------------------- /cfgs/config_run_eval.yaml: -------------------------------------------------------------------------------- 1 | 2 | # TODO 3 | defaults: 4 | - _self_ 5 | - trainer@_global_: default 6 | - model@_global_: hf_evaluator 7 | - policy@_global_: deep 8 | - auxiliary_loss@_global_: none 9 | - evolution@_global_: cma_es 10 | - task@_global_: 
passage_retrieval_en 11 | - run@_global_: default 12 | 13 | params_to_sweep: 14 | memory_policy: 15 | cache_size: [256, 512, 1024] 16 | 17 | out_folder: ./exp_local 18 | out_dir: '${out_folder}/${wandb_project}/${wandb_group_name}/${wandb_run_name}/${seed}' 19 | 20 | # wandb logging 21 | wandb_log: true # enabled by default 22 | wandb_project: memory_evolution_hf 23 | wandb_run_name: default_configs_test 24 | wandb_group_name: tests 25 | 26 | 27 | # system 28 | backend: 'nccl' # 'nccl', 'gloo', etc. 29 | device: 'cuda' 30 | seed: 1337 31 | deterministic_behavior: false 32 | 33 | 34 | hydra: 35 | run: 36 | dir: ${out_dir}/ -------------------------------------------------------------------------------- /cfgs/model/wrapped_llm/base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | 5 | pretrained_llm: 6 | _target_: utils_hydra.AutoModelForCausalLM.from_pretrained 7 | pretrained_model_name_or_path: ${pretrained_llm_name} 8 | 9 | tokenizer: 10 | _target_: utils_hydra.AutoTokenizer.from_pretrained 11 | pretrained_model_name_or_path: ${pretrained_llm_name} 12 | 13 | pretrained_llm_name: meta-llama/Meta-Llama-3-8B 14 | 15 | 16 | memory_model: 17 | _target_: memory_llms.llama.WrappedLlamaForCausalLM 18 | model: ${pretrained_llm} 19 | memory_policy: ${memory_policy} 20 | max_new_tokens: ${max_new_tokens} 21 | memory_policy_fixed_delay: ${memory_policy_fixed_delay} 22 | output_attentions_full_precision: ${output_attentions_full_precision} 23 | 24 | # needs to specify only for non memory_models 25 | max_new_tokens: 26 | max_position_id: 2048 27 | max_conditioning_length: ${max_position_id} 28 | memory_policy_fixed_delay: -------------------------------------------------------------------------------- /cfgs/run/base_memory_policy/none.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model/wrapped_llm@_global_: llama3-8b-rope-x4NTK 4 | - override /task@_global_: passage_retrieval_en 5 | - override /evolution@_global_: dummy 6 | - override /policy@_global_: none 7 | - override /typing@_global_: half_attn 8 | - override /auxiliary_loss@_global_: none 9 | - _self_ 10 | 11 | run_name_suffix: '' 12 | 13 | wandb_project: memory_evolution_hf 14 | 15 | wandb_run_name: ${tasks_log_name}-${evolution_algorithm_log_name}-shared-${pop_size}pop-${samples_batch_size}qs-${memory_policy_fixed_delay}fixDel-${run_name_suffix} 16 | wandb_group_name: ${llm_log_name}/${mp_log_name} 17 | 18 | add_bos_token: true # to match ref. scores setting 19 | score_normalization_reference: 'lb_reference_scores/per_request/llama3-8b-32k-x2NTK2alpha-v2.json' 20 | 21 | batch_size: 1 22 | 23 | max_new_tokens: 128 24 | 25 | cache_size: ${max_position_id} # do not discard any tokens -------------------------------------------------------------------------------- /cfgs/run/namm_bam_i1.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - base_memory_policy/deep/bam/base_bam 4 | - _self_ 5 | 6 | run_name_suffix: 'stage1' 7 | 8 | cache_size: 9 | max_new_tokens: 128 10 | 11 | scoring_reduction_mode: 12 | embedding_reduction_mode: ema 13 | 14 | # Ema params. 
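# (not spelled out here, but presumably the standard EMA recurrence over the
#  reduced features, e_t = coeff * e_{t-1} + (1 - coeff) * x_t; coeff=0.99
#  keeps a long-horizon average of the attention-spectrogram embedding)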
15 | embedding_ema_coeff: 0.99 16 | embedding_reduction_learned: false 17 | 18 | 19 | # -- 20 | # increasing the hidden dim, increases time cost marginally 21 | scoring_attn_hidden_dim: 32 22 | scoring_attn_output_dim: 23 | scoring_attn_num_heads: 1 24 | scoring_attn_bias: true 25 | scoring_attn_use_rope: false 26 | scoring_attn_rope_theta: 50000 27 | scoring_attn_masking_strategy: backward 28 | #-- 29 | 30 | # -- 31 | # Stft params 32 | n_fft: 32 33 | hop_length: 16 34 | window_fn: 35 | _target_: utils_hydra.torch.hann_window 36 | window_length: ${n_fft} 37 | periodic: true 38 | window_fn_log_name: hann 39 | pad_mode: constant 40 | output_magnitudes: true 41 | # -- 42 | 43 | pop_size: 32 44 | 45 | scoring_initializer: 0 46 | 47 | keep_past_epoch_checkpoints_every: 1 -------------------------------------------------------------------------------- /cfgs/evolution/cma_es.yaml: -------------------------------------------------------------------------------- 1 | evolution_algorithm: 2 | _target_: memory_evolution.cma_es.CMA_ES 3 | pop_size: ${pop_size} 4 | param_size: ${param_size} 5 | param_clip: ${param_clip} 6 | clip_param_min: ${clip_param_min} 7 | clip_param_max: ${clip_param_max} 8 | score_processing: ${score_processing} 9 | elite_ratio: ${elite_ratio} 10 | c_1: ${c_1} 11 | c_mu: ${c_mu} 12 | c_sigma: ${c_sigma} 13 | d_sigma: ${d_sigma} 14 | c_c: ${c_c} 15 | c_m: ${c_m} 16 | init_sigma: ${init_sigma} 17 | init_param_range: ${init_param_range} 18 | prefer_mean_to_best: ${prefer_mean_to_best} 19 | 20 | 21 | pop_size: 16 22 | param_size: 23 | param_clip: 24 | clip_param_min: 25 | clip_param_max: 26 | # cma-es has own custom score processinhg 27 | score_processing: 28 | elite_ratio: 0.5 29 | c_1: 30 | c_mu: 31 | c_sigma: 32 | d_sigma: 33 | c_c: 34 | 35 | 36 | c_m: 1.0 37 | init_sigma: 0.065 38 | init_param_range: 39 | 40 | prefer_mean_to_best: false 41 | 42 | evolution_algorithm_name: cma-es 43 | 44 | evolution_algorithm_log_name: ${evolution_algorithm_name}-p${pop_size}-rMean${prefer_mean_to_best} -------------------------------------------------------------------------------- /cfgs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - trainer@_global_: default 4 | - model@_global_: hf_evaluator 5 | - typing@_global_: default 6 | - policy@_global_: deep 7 | - auxiliary_loss@_global_: none 8 | - evolution@_global_: cma_es 9 | - task@_global_: passage_retrieval_en 10 | - run@_global_: default 11 | 12 | wandb_config: 13 | _target_: memory_trainer.WandbConfig 14 | wandb_log: ${wandb_log} 15 | wandb_project: ${wandb_project} 16 | wandb_run_name: ${wandb_run_name} 17 | wandb_group_name: ${wandb_group_name} 18 | 19 | 20 | out_folder: ./exp_local 21 | out_dir: '${out_folder}/${wandb_project}/${wandb_group_name}/${wandb_run_name}/${seed}' 22 | 23 | # wandb logging 24 | wandb_log: true # enabled by default 25 | wandb_project: memory_evolution_hf 26 | wandb_run_name: default_configs_test 27 | wandb_group_name: tests 28 | 29 | # system 30 | device: 'cuda' 31 | seed: 1337 32 | deterministic_behavior: false 33 | 34 | # ddp settings 35 | backend: 'nccl' # 'nccl', 'gloo', etc. 
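# (the eval-only runs override this to '9:0:0' to avoid timeouts when
#  evaluating on the full task sets)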
36 | ddp_timeout_limit: '0:6:0' # days:hours:minutes 37 | 38 | hydra: 39 | run: 40 | dir: ${out_dir}/ -------------------------------------------------------------------------------- /cfgs/policy/deep_embedding/attn_spec_norm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - attn_spec_base 3 | - norm_base 4 | - _self_ 5 | 6 | token_embedding: 7 | _target_: memory_policy.RecencyEmbeddingWrapper 8 | token_embedding: ${attn_spectrogram} 9 | recency_embedding: ${recency_deep_embedding} 10 | start_recency_from: ${start_recency_from} 11 | wrapper_output_dim: ${wrapper_output_dim} 12 | processing_layers: ${processing_layers} 13 | joining_strategy: ${joining_strategy} 14 | output_params: ${wrapper_output_params} 15 | 16 | 17 | recency_deep_embedding: 18 | _target_: memory_policy.PositionalEmbedding 19 | max_position_id: ${max_position_id} 20 | embed_dim: ${recency_embed_dim} 21 | max_freq: ${recency_max_freq} 22 | 23 | wrapper_output_dim: 24 | processing_layers: 0 25 | joining_strategy: 'append' 26 | wrapper_output_params: 27 | 28 | recency_embedding_name: ${joining_strategy}RecAbsPoc${recency_embed_dim}D 29 | 30 | # must be even 31 | recency_embed_dim: 8 32 | recency_max_freq: 50000 33 | start_recency_from: 1 # NOTE: 1 for legacy compatibility purposes 34 | 35 | embedding_mp_log_name: attn-spec-norm -------------------------------------------------------------------------------- /cfgs/run/namm_bam_i2.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - base_memory_policy/deep/bam/base_bam 4 | - override /task@_global_: lb_2subset 5 | - _self_ 6 | 7 | run_name_suffix: 'stage2' 8 | 9 | cache_size: 10 | max_new_tokens: 128 11 | 12 | scoring_reduction_mode: 13 | embedding_reduction_mode: ema 14 | 15 | # Ema params. 16 | embedding_ema_coeff: 0.99 17 | embedding_reduction_learned: false 18 | 19 | # -- 20 | # increasing the hidden dim, increases time cost marginally 21 | scoring_attn_hidden_dim: 32 22 | scoring_attn_output_dim: 23 | scoring_attn_num_heads: 1 24 | scoring_attn_bias: true 25 | scoring_attn_use_rope: false 26 | scoring_attn_rope_theta: 50000 27 | scoring_attn_masking_strategy: backward 28 | #-- 29 | 30 | # -- 31 | # Stft params 32 | n_fft: 32 33 | hop_length: 16 34 | window_fn: 35 | _target_: utils_hydra.torch.hann_window 36 | window_length: ${n_fft} 37 | periodic: true 38 | window_fn_log_name: hann 39 | pad_mode: constant 40 | output_magnitudes: true 41 | # -- 42 | 43 | 44 | init_from: 'path/to/stage1.pt' 45 | 46 | 47 | pop_size: 32 48 | 49 | scoring_initializer: 0 50 | keep_past_epoch_checkpoints_every: 1 51 | 52 | scratch: false -------------------------------------------------------------------------------- /cfgs/run/namm_bam_i3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - base_memory_policy/deep/bam/base_bam 4 | - override /task@_global_: lb_3subset_incr 5 | - _self_ 6 | 7 | run_name_suffix: 'stage3' 8 | 9 | cache_size: 10 | max_new_tokens: 128 11 | 12 | scoring_reduction_mode: 13 | embedding_reduction_mode: ema 14 | 15 | # Ema params. 
16 | embedding_ema_coeff: 0.99 17 | embedding_reduction_learned: false 18 | 19 | # -- 20 | # increasing the hidden dim, increases time cost marginally 21 | scoring_attn_hidden_dim: 32 22 | scoring_attn_output_dim: 23 | scoring_attn_num_heads: 1 24 | scoring_attn_bias: true 25 | scoring_attn_use_rope: false 26 | scoring_attn_rope_theta: 50000 27 | scoring_attn_masking_strategy: backward 28 | #-- 29 | 30 | # -- 31 | # Stft params 32 | n_fft: 32 33 | hop_length: 16 34 | window_fn: 35 | _target_: utils_hydra.torch.hann_window 36 | window_length: ${n_fft} 37 | periodic: true 38 | window_fn_log_name: hann 39 | pad_mode: constant 40 | output_magnitudes: true 41 | # -- 42 | 43 | 44 | init_from: 'path/to/stage2.pt' 45 | 46 | 47 | pop_size: 32 48 | 49 | scoring_initializer: 0 50 | keep_past_epoch_checkpoints_every: 1 51 | 52 | scratch: false -------------------------------------------------------------------------------- /cfgs/model/hf_evaluator.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - wrapped_llm@_global_: pythia160m 4 | 5 | memory_evaluator: 6 | _target_: memory_evaluator.MemoryHFEvaluator 7 | model: ${pretrained_llm} 8 | tokenizer: ${tokenizer} 9 | 10 | evaluation_ctx_steps: ${evaluation_ctx_steps} 11 | add_bos_token: ${add_bos_token} 12 | 13 | eval_max_batch_size: ${eval_max_batch_size} 14 | memory_batch_size: ${memory_batch_size} 15 | 16 | batch_size: ${batch_size} 17 | max_conditioning_length: ${max_conditioning_length} 18 | 19 | max_memory_length: ${max_memory_length} 20 | max_gen_tokens: ${max_gen_tokens} 21 | full_context_gen: ${full_context_gen} 22 | 23 | per_timestep_loglikelihood: ${per_timestep_loglikelihood} 24 | force_clear_cache: ${force_clear_cache} 25 | 26 | device: ${device} 27 | 28 | log_misc: ${log_misc} 29 | 30 | evaluation_ctx_steps: 1 # default, unused 31 | add_bos_token: true 32 | 33 | eval_max_batch_size: 256 34 | memory_batch_size: false 35 | 36 | batch_size: "auto" 37 | 38 | max_memory_length: 2048 39 | max_gen_tokens: 512 # from the HFLM class 40 | full_context_gen: true 41 | 42 | per_timestep_loglikelihood: true 43 | force_clear_cache: true 44 | 45 | log_misc: false -------------------------------------------------------------------------------- /cfgs/run/namm_bam_eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - base_memory_policy/deep/bam/base_bam 4 | - override /task@_global_: lb_full 5 | - _self_ 6 | 7 | run_name_suffix: 'stage3' 8 | 9 | cache_size: 10 | max_new_tokens: 128 11 | 12 | scoring_reduction_mode: 13 | embedding_reduction_mode: ema 14 | 15 | # Ema params. 
16 | embedding_ema_coeff: 0.99 17 | embedding_reduction_learned: false 18 | 19 | # -- 20 | # increasing the hidden dim, increases time cost marginally 21 | scoring_attn_hidden_dim: 32 22 | scoring_attn_output_dim: 23 | scoring_attn_num_heads: 1 24 | scoring_attn_bias: true 25 | scoring_attn_use_rope: false 26 | scoring_attn_rope_theta: 50000 27 | scoring_attn_masking_strategy: backward 28 | #-- 29 | 30 | # -- 31 | # Stft params 32 | n_fft: 32 33 | hop_length: 16 34 | window_fn: 35 | _target_: utils_hydra.torch.hann_window 36 | window_length: ${n_fft} 37 | periodic: true 38 | window_fn_log_name: hann 39 | pad_mode: constant 40 | output_magnitudes: true 41 | # -- 42 | 43 | 44 | init_from: 'path/to/namm/results/ckpt.pt' 45 | 46 | pop_size: 32 47 | 48 | scoring_initializer: 0 49 | keep_past_epoch_checkpoints_every: 1 50 | 51 | scratch: false 52 | 53 | eval_only: true 54 | # avoid timeout when evaluating on all tasks 55 | ddp_timeout_limit: '9:0:0' -------------------------------------------------------------------------------- /memory_policy/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .base import ( 3 | MemoryPolicy, ParamMemoryPolicy, Recency, AttnRequiringRecency,) 4 | from .base_dynamic import ( 5 | DynamicMemoryPolicy, DynamicParamMemoryPolicy, 6 | RecencyParams, AttentionParams, 7 | ) 8 | 9 | from .auxiliary_losses import MemoryPolicyAuxiliaryLoss, SparsityAuxiliaryLoss, L2NormAuxiliaryLoss 10 | 11 | from memory_policy.deep import DeepMP 12 | from memory_policy.deep_embedding_spectogram import ( 13 | STFTParams, AttentionSpectrogram, fft_avg_mask, fft_ema_mask, 14 | ) 15 | from memory_policy.deep_embedding import ( 16 | RecencyExponents, NormalizedRecencyExponents) 17 | from memory_policy.deep_scoring import ( 18 | MLPScoring, GeneralizedScoring, make_scaled_one_hot_init, TCNScoring) 19 | from memory_policy.deep_selection import ( 20 | DynamicSelection, TopKSelection, BinarySelection) 21 | from memory_policy.base_deep_components import ( 22 | EMAParams, ComponentOutputParams, wrap_torch_initializer, 23 | DeepMemoryPolicyComponent, TokenEmbedding, JointEmbeddings, 24 | ScoringNetwork, SelectionNetwork, 25 | ) 26 | from .shared import SynchronizableBufferStorage, RegistrationCompatible 27 | 28 | from .deep_embedding_shared import PositionalEmbedding, Embedding 29 | from .deep_embedding_wrappers import RecencyEmbeddingWrapper -------------------------------------------------------------------------------- /cfgs/run/namm_bam_eval_choubun.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - base_memory_policy/deep/bam/base_bam 4 | - override /task@_global_: choubun_full 5 | - _self_ 6 | 7 | run_name_suffix: 'stage3' 8 | 9 | cache_size: 10 | max_new_tokens: 128 11 | 12 | scoring_reduction_mode: 13 | embedding_reduction_mode: ema 14 | 15 | # Ema params. 
16 | embedding_ema_coeff: 0.99 17 | embedding_reduction_learned: false 18 | 19 | # -- 20 | # increasing the hidden dim, increases time cost marginally 21 | scoring_attn_hidden_dim: 32 22 | scoring_attn_output_dim: 23 | scoring_attn_num_heads: 1 24 | scoring_attn_bias: true 25 | scoring_attn_use_rope: false 26 | scoring_attn_rope_theta: 50000 27 | scoring_attn_masking_strategy: backward 28 | #-- 29 | 30 | # -- 31 | # Stft params 32 | n_fft: 32 33 | hop_length: 16 34 | window_fn: 35 | _target_: utils_hydra.torch.hann_window 36 | window_length: ${n_fft} 37 | periodic: true 38 | window_fn_log_name: hann 39 | pad_mode: constant 40 | output_magnitudes: true 41 | # -- 42 | 43 | 44 | init_from: 'path/to/namm/results/ckpt.pt' 45 | 46 | pop_size: 32 47 | 48 | scoring_initializer: 0 49 | keep_past_epoch_checkpoints_every: 1 50 | 51 | scratch: false 52 | 53 | eval_only: true 54 | # avoid timeout when evaluating on all tasks 55 | ddp_timeout_limit: '9:0:0' 56 | 57 | score_normalization_reference: -------------------------------------------------------------------------------- /cfgs/policy/deep_embedding/attn_spec_base.yaml: -------------------------------------------------------------------------------- 1 | attn_spectrogram: 2 | _target_: memory_policy.AttentionSpectrogram 3 | per_layer: ${per_layer} 4 | per_head: ${per_head} 5 | shared: ${embedding_shared} 6 | output_params: ${embedding_output_params} 7 | stft_params: ${stft_params} 8 | 9 | embedding_output_params: 10 | _target_: memory_policy.ComponentOutputParams 11 | requires_recomputation: ${embedding_requires_recomputation} 12 | reduction_mode: ${embedding_reduction_mode} 13 | ema_params: ${embedding_ema_params} 14 | output_past_non_reduced_history: ${embedding_output_past_non_reduced_history} 15 | max_non_reduced_history_len: ${embedding_max_non_reduced_history_len} 16 | 17 | embedding_ema_params: 18 | _target_: memory_policy.EMAParams 19 | coeff: ${embedding_ema_coeff} 20 | learned: ${embedding_reduction_learned} 21 | reduction_stride: ${hop_length} 22 | 23 | stft_params: 24 | _target_: memory_policy.STFTParams 25 | n_fft: ${n_fft} 26 | hop_length: ${hop_length} 27 | window_fn: ${window_fn} 28 | pad_mode: ${pad_mode} 29 | output_magnitudes: ${output_magnitudes} 30 | 31 | # Spectrogram params. 32 | embedding_shared: true 33 | # Output params. 34 | embedding_requires_recomputation: true 35 | embedding_reduction_mode: 36 | # Ema params. 
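# (an EMA with coefficient c averages over roughly 1/(1-c) reduction steps,
#  so c=0.975 corresponds to a window of about 40 strided updates)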
37 | embedding_ema_coeff: 0.975 38 | embedding_reduction_learned: false 39 | # -- 40 | embedding_output_past_non_reduced_history: false 41 | embedding_max_non_reduced_history_len: 42 | # -- 43 | # Stft params 44 | n_fft: 32 45 | hop_length: 16 46 | window_fn: 47 | _target_: memory_policy.fft_avg_mask 48 | window_length: ${n_fft} 49 | pad_mode: constant 50 | output_magnitudes: true 51 | # -- 52 | 53 | window_fn_log_name: avg 54 | embedding_spec_mp_log_name: attn-spec -------------------------------------------------------------------------------- /cfgs/policy/deep_scoring/mlp.yaml: -------------------------------------------------------------------------------- 1 | scoring_network: ${mlp_scoring} 2 | 3 | mlp_scoring: 4 | _target_: memory_policy.MLPScoring 5 | per_layer: ${per_layer} 6 | per_head: ${per_head} 7 | shared: ${scoring_shared} 8 | output_params: ${scoring_output_params} 9 | hidden_features: ${hidden_features} 10 | hidden_depth: ${hidden_depth} 11 | bias: ${scoring_mlp_bias} 12 | non_linearity: ${non_linearity} 13 | initializer: ${scoring_initializer} 14 | residual: ${residual} 15 | residual_first: ${residual_first} 16 | 17 | 18 | scoring_output_params: 19 | _target_: memory_policy.ComponentOutputParams 20 | requires_recomputation: ${scoring_requires_recomputation} 21 | reduction_mode: ${scoring_reduction_mode} 22 | ema_params: ${scoring_ema_params} 23 | output_past_non_reduced_history: ${scoring_output_past_non_reduced_history} 24 | max_non_reduced_history_len: ${scoring_max_non_reduced_history_len} 25 | 26 | scoring_ema_params: 27 | _target_: memory_policy.EMAParams 28 | coeff: ${scoring_ema_coeff} 29 | learned: ${scoring_reduction_learned} 30 | reduction_stride: ${hop_length} 31 | 32 | # MLP params 33 | scoring_shared: true 34 | # Output params. 35 | scoring_requires_recomputation: true 36 | scoring_reduction_mode: ema 37 | # Ema params. 
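# (reduction_stride above is tied to the STFT hop_length, so scores are
#  presumably reduced once per spectrogram frame rather than once per token)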
38 | scoring_ema_coeff: 0.99 39 | scoring_reduction_learned: false 40 | # -- 41 | scoring_output_past_non_reduced_history: false 42 | scoring_max_non_reduced_history_len: 43 | # -- 44 | hidden_features: 45 | hidden_depth: 1 46 | scoring_mlp_bias: false 47 | non_linearity: relu 48 | scoring_initializer: 49 | _target_: memory_policy.make_scaled_one_hot_init 50 | idxs_to_scale: {} 51 | idxs_to_ones: [] 52 | residual: true 53 | residual_first: true 54 | # -- 55 | 56 | 57 | scoring_mp_log_name: mlp -------------------------------------------------------------------------------- /utils_log.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 5 | 6 | from lm_eval.models.utils import Collator 7 | 8 | 9 | def stat_fn(list_of_lists, index=None): 10 | if index is None: 11 | flat_list = np.array( 12 | [v for vs in list_of_lists for v in vs]) 13 | else: 14 | mean = np.mean(list_of_lists[i]) 15 | mean = np.mean(flat_list) 16 | std = np.std(flat_list) 17 | above_avg = np.mean(flat_list > mean) 18 | max_v = np.max(flat_list) 19 | min_v = np.min(flat_list) 20 | return mean, std, above_avg, max_v, min_v 21 | 22 | 23 | def initialize_stat_objects_for( 24 | self, 25 | score_name, 26 | stats=['mean', 'std', 'above_mean', 'max', 'min'], 27 | ): 28 | for stat in stats: 29 | init_list = [[] for _ in range(self.num_memory_layers)] 30 | setattr(self, f'{score_name}_{stat}', init_list) 31 | raise NotImplementedError 32 | 33 | 34 | class COLOR: 35 | # ANSI color codes and tools 36 | BLACK = "\033[0;30m" 37 | RED = "\033[0;31m" 38 | GREEN = "\033[0;32m" 39 | BROWN = "\033[0;33m" 40 | BLUE = "\033[0;34m" 41 | PURPLE = "\033[0;35m" 42 | CYAN = "\033[0;36m" 43 | LIGHT_GRAY = "\033[0;37m" 44 | DARK_GRAY = "\033[1;30m" 45 | LIGHT_RED = "\033[1;31m" 46 | LIGHT_GREEN = "\033[1;32m" 47 | YELLOW = "\033[1;33m" 48 | LIGHT_BLUE = "\033[1;34m" 49 | LIGHT_PURPLE = "\033[1;35m" 50 | LIGHT_CYAN = "\033[1;36m" 51 | LIGHT_WHITE = "\033[1;37m" 52 | BOLD = "\033[1m" 53 | FAINT = "\033[2m" 54 | ITALIC = "\033[3m" 55 | UNDERLINE = "\033[4m" 56 | BLINK = "\033[5m" 57 | NEGATIVE = "\033[7m" 58 | CROSSED = "\033[9m" 59 | END = "\033[0m" 60 | -------------------------------------------------------------------------------- /choubun_metrics.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | 4 | from collections import Counter 5 | from rouge import Rouge 6 | 7 | 8 | def normalize_ja_answer(s): 9 | """Lower text and remove punctuation, extra whitespace.""" 10 | 11 | def white_space_fix(text): 12 | return "".join(text.split()) 13 | 14 | def remove_punc(text): 15 | cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏." 
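        # combine ASCII and full-width CJK punctuation so either form is
        # stripped before answers are compared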
16 | all_punctuation = set(string.punctuation + cn_punctuation) 17 | return "".join(ch for ch in text if ch not in all_punctuation) 18 | 19 | def lower(text): 20 | return text.lower() 21 | 22 | return white_space_fix(remove_punc(lower(s))) 23 | 24 | 25 | def rouge_score(prediction, ground_truth, **kwargs): 26 | rouge = Rouge() 27 | try: 28 | scores = rouge.get_scores([prediction], [ground_truth], avg=True) 29 | except: 30 | return 0.0 31 | return scores["rouge-l"]["f"] 32 | 33 | 34 | def rouge_ja_score(prediction, ground_truth, tokenizer, **kwargs): 35 | prediction_tokens = [word.surface for word in tokenizer(prediction)] 36 | ground_truth_tokens = [word.surface for word in tokenizer(ground_truth)] 37 | prediction = " ".join(prediction_tokens) 38 | ground_truth = " ".join(ground_truth_tokens) 39 | score = rouge_score(prediction, ground_truth) 40 | return score 41 | 42 | 43 | def f1_score(prediction, ground_truth, **kwargs): 44 | common = Counter(prediction) & Counter(ground_truth) 45 | num_same = sum(common.values()) 46 | if num_same == 0: 47 | return 0 48 | precision = 1.0 * num_same / len(prediction) 49 | recall = 1.0 * num_same / len(ground_truth) 50 | f1 = (2 * precision * recall) / (precision + recall) 51 | return f1 52 | 53 | 54 | def qa_f1_ja_score(prediction, ground_truth, tokenizer, **kwargs): 55 | prediction_tokens = [word.surface for word in tokenizer(prediction)] 56 | ground_truth_tokens = [word.surface for word in tokenizer(ground_truth)] 57 | prediction_tokens = [normalize_ja_answer(token) for token in prediction_tokens] 58 | ground_truth_tokens = [normalize_ja_answer(token) for token in ground_truth_tokens] 59 | prediction_tokens = [token for token in prediction_tokens if len(token) > 0] 60 | ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0] 61 | return f1_score(prediction_tokens, ground_truth_tokens) -------------------------------------------------------------------------------- /cfgs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | trainer: 5 | _target_: memory_trainer.MemoryTrainer 6 | trainer_config: ${trainer_config} 7 | wandb_config: ${wandb_config} 8 | device: ${device} 9 | scratch: ${scratch} 10 | 11 | trainer_config: 12 | _target_: memory_trainer.TrainerConfig 13 | out_dir: ${out_dir} 14 | max_iters: ${max_iters} 15 | task_batch_size: ${task_batch_size} 16 | samples_batch_size: ${samples_batch_size} 17 | eval_samples_batch_size: ${eval_samples_batch_size} 18 | allow_distributed_eval: ${allow_distributed_eval} 19 | 20 | pop_accumulation_steps: ${pop_accumulation_steps} 21 | 22 | score_aggregation: ${score_aggregation} 23 | score_normalization_reference: ${score_normalization_reference} 24 | 25 | synchronized_buffers_aggregation: ${synchronized_buffers_aggregation} 26 | synchronized_buffers_freeze_after: ${synchronized_buffers_freeze_after} 27 | 28 | prefetch_task_tensors: ${prefetch_task_tensors} 29 | override_prefetched_tensors: ${override_prefetched_tensors} 30 | 31 | eval_interval: ${eval_interval} 32 | early_stop_patience: ${early_stop_patience} 33 | log_interval: ${log_interval} 34 | eval_iters: ${eval_iters} 35 | eval_only: ${eval_only} 36 | eval_candidate_samples: ${eval_candidate_samples} 37 | eval_candidate_temp: ${eval_candidate_temp} 38 | record_advanced_eval_stats: ${record_advanced_eval_stats} 39 | store_eval_results_locally: ${store_eval_results_locally} 40 | record_per_task_eval_stats: ${record_per_task_eval_stats} 41 | 42 | 
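  # checkpointing and precision passthroughs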
always_save_checkpoint: ${always_save_checkpoint} 43 | keep_past_epoch_checkpoints_every: ${keep_past_epoch_checkpoints_every} 44 | use_amp: ${use_amp} 45 | init_from: ${init_from} 46 | dtype: ${dtype} 47 | 48 | # main trainer cfgs 49 | max_iters: 10000 50 | task_batch_size: 51 | samples_batch_size: 16 52 | eval_samples_batch_size: 128 53 | allow_distributed_eval: false 54 | 55 | pop_accumulation_steps: 1 56 | score_aggregation: 'mean' 57 | score_normalization_reference: 58 | 59 | # use 'mean' or 'best' population member to aggregate sync. buffers 60 | synchronized_buffers_aggregation: 'mean' 61 | synchronized_buffers_freeze_after: 62 | 63 | # prefetching (only implemented for generation tasks) 64 | prefetch_task_tensors: false 65 | override_prefetched_tensors: false 66 | 67 | # logging 68 | eval_interval: 100 69 | early_stop_patience: -1 70 | log_interval: 50 71 | eval_iters: 1 72 | eval_only: false 73 | eval_candidate_samples: 74 | eval_candidate_temp: 75 | 76 | always_save_checkpoint: true 77 | keep_past_epoch_checkpoints_every: 78 | scratch: true 79 | record_advanced_eval_stats: true 80 | store_eval_results_locally: true 81 | record_per_task_eval_stats: false 82 | 83 | # precision 84 | use_amp: false 85 | init_from: 86 | dtype: 'bfloat16' 87 | 88 | -------------------------------------------------------------------------------- /memory_policy/deep_embedding_shared.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import copy 4 | import math 5 | import numbers 6 | import abc 7 | import numpy as np 8 | from dataclasses import dataclass 9 | from typing import Optional, Tuple, Union, Dict, Callable, List 10 | 11 | from omegaconf import OmegaConf, DictConfig 12 | import hydra 13 | 14 | import torch 15 | from torch import nn 16 | import torch.utils.checkpoint 17 | 18 | 19 | def convert_to_tensor( 20 | el: Union[List[float], np.ndarray, torch.Tensor], 21 | ) -> torch.Tensor: 22 | if isinstance(el, torch.Tensor): 23 | return el 24 | else: 25 | el = torch.tensor(el) 26 | return el 27 | 28 | def cos_sin_seq_embeddings( 29 | length, 30 | embed_dim, 31 | max_freq=10000, 32 | 33 | 34 | ): 35 | assert embed_dim % 2 == 0, 'Embedding dimension should be even ' 36 | positions = np.arange(length) 37 | embed_dim_per_op = embed_dim // 2 38 | 39 | 40 | freq_coeff = np.arange(embed_dim_per_op, dtype=np.float64)/embed_dim_per_op 41 | freq_coeff = 1/(max_freq**freq_coeff) 42 | 43 | freq_values = np.expand_dims(positions, axis=-1)*freq_coeff 44 | embeddings = np.concatenate( 45 | [np.sin(freq_values), np.cos(freq_values)], axis=-1) 46 | return embeddings 47 | 48 | class Embedding(nn.Module, abc.ABC): 49 | def __init__(self, embed_dim): 50 | nn.Module.__init__(self=self,) 51 | self.embed_dim = embed_dim 52 | if embed_dim is not None: 53 | self.set_embed_dim(embed_dim=embed_dim) 54 | 55 | def set_embed_dim(self, embed_dim): 56 | self.embed_dim: int = embed_dim 57 | assert self.embed_dim is not None, 'make sure embed_dim is not None' 58 | 59 | @abc.abstractmethod 60 | def forward(self, x: torch.Tensor) -> torch.Tensor: 61 | raise NotImplementedError 62 | 63 | 64 | class PositionalEmbedding(Embedding): 65 | def __init__(self, max_position_id, embed_dim, max_freq): 66 | self.max_position_id = max_position_id 67 | self.max_freq = max_freq 68 | Embedding.__init__(self=self, embed_dim=embed_dim) 69 | 70 | 71 | def set_embed_dim(self, embed_dim): 72 | Embedding.set_embed_dim(self=self, embed_dim=embed_dim) 73 | embeddings = cos_sin_seq_embeddings( 74 | 
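            # rows cover position ids 0..max_position_id inclusive, hence the +1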
75 | 76 | length=self.max_position_id + 1, 77 | embed_dim=embed_dim, 78 | max_freq=self.max_freq, 79 | ).astype('float32') 80 | self.register_buffer('embeddings', torch.tensor(embeddings)) 81 | 82 | def forward(self, x: torch.Tensor) -> torch.Tensor: 83 | 84 | 85 | 86 | 87 | 88 | emb = self.embeddings[x] 89 | 90 | return emb 91 | 92 | -------------------------------------------------------------------------------- /utils_hydra.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 2 | 3 | from torch.nn import functional as F 4 | import torch 5 | import transformers 6 | import numpy as np 7 | 8 | 9 | from hydra import compose, initialize 10 | from main import make_eval_model, make_task_sampler, wandb_init 11 | import omegaconf 12 | import hydra 13 | import time 14 | 15 | 16 | # File containing imports to be used in hydra configs/for instantiating hydra 17 | # objects outside main 18 | 19 | def initialize_cfg( 20 | config_path="cfgs", 21 | hydra_overrides: dict = {}, 22 | log_yaml: bool = False, 23 | 24 | ): 25 | with initialize(version_base=None, config_path=config_path, 26 | job_name="test_app"): 27 | cfg = compose(config_name="config", 28 | overrides=hydra_overrides, 29 | ) 30 | if log_yaml: 31 | print('Loading the following configurations:') 32 | print(omegaconf.OmegaConf.to_yaml(cfg)) 33 | return cfg 34 | 35 | 36 | def load_run_cfgs_trainer( 37 | run_file_path_location: str, 38 | hydra_overrides_from_dict: dict = dict( 39 | wandb_log="false", 40 | wandb_project="scratch", 41 | ), 42 | task_sampler_kwargs: dict = dict( 43 | tasks=["lb/passage_retrieval_en"], 44 | ), 45 | config_path="cfgs", 46 | batch_size: int = 1, 47 | ): 48 | print(f"Loading model specified in {run_file_path_location}") 49 | start_time = time.time() 50 | hydra_overrides = [ 51 | f"run@_global_={run_file_path_location}", 52 | ] 53 | 54 | hydra_overrides_from_dict = [f"{k}={v}" for k, v in 55 | hydra_overrides_from_dict.items()] 56 | 57 | hydra_overrides = hydra_overrides + hydra_overrides_from_dict 58 | 59 | cfg = initialize_cfg( 60 | config_path=config_path, 61 | hydra_overrides=hydra_overrides, 62 | ) 63 | 64 | (memory_policy, memory_model, memory_evaluator, evolution_algorithm, 65 | auxiliary_loss) = make_eval_model(cfg=cfg) 66 | 67 | task_sampler = make_task_sampler(cfg=cfg, **task_sampler_kwargs) 68 | 69 | trainer = hydra.utils.instantiate( 70 | cfg.trainer, 71 | evaluation_model=memory_evaluator, 72 | task_sampler=task_sampler, 73 | evolution_algorithm=evolution_algorithm, 74 | auxiliary_loss=auxiliary_loss, 75 | ) 76 | params, buffers = trainer.sample_and_synchronize_params(best=True) 77 | memory_model = memory_model 78 | memory_model.set_memory_params(params=params) 79 | 80 | if memory_model.memory_policy_has_buffers_to_merge(): 81 | memory_model.load_buffers_dict(buffers_dict=buffers) 82 | 83 | batch_idxs = np.zeros([batch_size]) 84 | memory_policy.set_params_batch_idxs(batch_idxs) 85 | print("Time taken:", round(time.time() - start_time)) 86 | return trainer # contains all other models 87 | -------------------------------------------------------------------------------- /cfgs/policy/deep_scoring/attn.yaml: -------------------------------------------------------------------------------- 1 | scoring_network: ${generalized_scoring_net} 2 | 3 | generalized_scoring_net: 4 | _target_: memory_policy.GeneralizedScoring 5 | per_layer: ${per_layer} 6 | per_head: ${per_head} 7 | shared: ${scoring_shared} 8 | 
output_params: ${scoring_output_params} 9 | initializer: ${scoring_initializer} 10 | stateless_modules_list: ${stateless_modules_list} 11 | residual: ${residual} 12 | mult: ${mult} 13 | mult_nonlinearity: ${mult_nonlinearity} 14 | 15 | mlp_layer: 16 | _target_: stateless_parallel_modules.StatelessGeneralizedMLP 17 | input_features: ${scoring_attn_output_dim} 18 | hidden_features: ${scoring_mlp_hidden_features} 19 | hidden_depth: ${scoring_mlp_hidden_depth} 20 | output_features: ${scoring_mlp_output_features} 21 | bias: ${scoring_mlp_bias} 22 | non_linearity: ${scoring_mlp_non_linearity} 23 | residual: ${scoring_mlp_residual} 24 | residual_first: ${scoring_mlp_residual_first} 25 | 26 | scoring_mlp_hidden_features: 27 | scoring_mlp_hidden_depth: 1 28 | scoring_mlp_output_features: 1 29 | scoring_mlp_bias: true 30 | scoring_mlp_non_linearity: relu 31 | scoring_mlp_residual: true 32 | scoring_mlp_residual_first: false 33 | 34 | 35 | attention_layer: 36 | _target_: stateless_parallel_modules.StatelessAttention 37 | attention_params: ${attention_params} 38 | 39 | attention_params: 40 | _target_: stateless_parallel_modules.StatelessAttentionParams 41 | input_dim: 42 | hidden_dim: ${scoring_attn_hidden_dim} 43 | output_dim: ${scoring_attn_output_dim} 44 | num_heads: ${scoring_attn_num_heads} 45 | bias: ${scoring_attn_bias} 46 | max_position_id: ${max_position_id} 47 | use_rope: ${scoring_attn_use_rope} 48 | rope_theta: ${scoring_attn_rope_theta} 49 | masking_strategy: ${scoring_attn_masking_strategy} 50 | 51 | scoring_attn_hidden_dim: 32 52 | scoring_attn_output_dim: 53 | scoring_attn_num_heads: 4 54 | scoring_attn_bias: false 55 | scoring_attn_use_rope: false 56 | scoring_attn_rope_theta: 10000 57 | scoring_attn_masking_strategy: 58 | 59 | scoring_output_params: 60 | _target_: memory_policy.ComponentOutputParams 61 | requires_recomputation: ${scoring_requires_recomputation} 62 | reduction_mode: ${scoring_reduction_mode} 63 | ema_params: ${scoring_ema_params} 64 | output_past_non_reduced_history: ${scoring_output_past_non_reduced_history} 65 | max_non_reduced_history_len: ${scoring_max_non_reduced_history_len} 66 | 67 | scoring_ema_params: 68 | _target_: memory_policy.EMAParams 69 | coeff: ${scoring_ema_coeff} 70 | learned: ${scoring_reduction_learned} 71 | reduction_stride: ${hop_length} 72 | 73 | # MLP params 74 | scoring_shared: true 75 | # Output params. 76 | scoring_requires_recomputation: true 77 | scoring_reduction_mode: ema 78 | # Ema params. 79 | scoring_ema_coeff: 0.975 80 | scoring_reduction_learned: false 81 | # -- 82 | scoring_output_past_non_reduced_history: false 83 | scoring_max_non_reduced_history_len: 84 | # -- 85 | scoring_initializer: 0 86 | stateless_modules_list: 87 | - ${attention_layer} 88 | - ${mlp_layer} 89 | residual: true 90 | mult: false 91 | mult_nonlinearity: 92 | 93 | # -- 94 | 95 | 96 | scoring_mp_log_name: bam-no-mult -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 
4 | # An Evolved Universal Transformer Memory
5 | 
8 | 📄 [Paper] | 
9 | 🤗 [Hugging Face] | 
10 | 📁 [Dataset]
11 | 
12 | 13 | ## Installation 14 | 15 | We provide means to install this repository with [conda](https://docs.conda.io/projects/conda/en/latest/index.html): 16 | 17 | For the full set of dependencies with fixed versions (provided to ensure some level of long-term reproducibility): 18 | 19 | ```bash 20 | conda env create --file=env.yaml 21 | ``` 22 | 23 | For a more minimal and less constrained set of dependencies (for future development/extensions): 24 | 25 | ```bash 26 | conda env create --file=env_minimal.yaml 27 | ``` 28 | 29 | ## Usage 30 | 31 | ### Training 32 | 33 | Training following the incremental setup described in our work can be replicated via the following [hydra](https://hydra.cc/) commands: 34 | 35 | stage 1 training: 36 | ```bash 37 | torchrun --standalone --nproc_per_node=$NUM_OF_GPUs main.py run@_global_=namm_bam_i1.yaml 38 | ``` 39 | 40 | stage 2 training: 41 | ```bash 42 | torchrun --standalone --nproc_per_node=$NUM_OF_GPUs main.py run@_global_=namm_bam_i2.yaml init_from='path/to/stage1/results/ckpt.pt' 43 | ``` 44 | 45 | stage 3 training: 46 | ```bash 47 | torchrun --standalone --nproc_per_node=$NUM_OF_GPUs main.py run@_global_=namm_bam_i3.yaml init_from='path/to/stage2/results/ckpt.pt' 48 | ``` 49 | 50 | ### Evaluation 51 | 52 | Evaluating trained NAMMs on the full set of LongBench tasks can be replicated for both NAMMs with the following command: 53 | 54 | ```bash 55 | torchrun --standalone --nproc_per_node=$NUM_OF_GPUs main.py run@_global_=namm_bam_eval.yaml init_from='path/to/results/ckpt.pt' 56 | ``` 57 | 58 | Evaluating trained NAMMs on the full set of ChouBun tasks can be replicated with the following command: 59 | 60 | ```bash 61 | torchrun --standalone --nproc_per_node=$NUM_OF_GPUs main.py run@_global_=namm_bam_eval_choubun.yaml init_from='path/to/results/ckpt.pt' 62 | ``` 63 | 64 | ### Additional notes 65 | 66 | Using [wandb](https://wandb.ai/) to log the results (through the hydra setting wandb_log=true) requires authenticating to the wandb server via the following command: 67 | 68 | ```bash 69 | wandb login 70 | ``` 71 | 72 | and using your account's API key (which you should be able to find [here](https://wandb.ai/authorize)) 73 | 74 | ### Gated models (e.g., Llama) 75 | 76 | Using gated models requires authenticating to the hugging face hub by running: 77 | 78 | ```bash 79 | huggingface-cli login 80 | ``` 81 | 82 | and using your account's access tokens (which you should be able to find [here](https://huggingface.co/settings/tokens)) 83 | 84 | 85 | ## Bibtex 86 | 87 | To cite our work, you can use the following: 88 | 89 | ``` 90 | @article{sakana2024memory, 91 | title={An Evolved Universal Transformer Memory}, 92 | author={Edoardo Cetin and Qi Sun and Tianyu Zhao and Yujin Tang}, 93 | year={2024}, 94 | eprint={2410.13166}, 95 | archivePrefix={arXiv}, 96 | primaryClass={cs.LG}, 97 | url={https://arxiv.org/abs/2410.13166}, 98 | } 99 | ``` 100 | 101 | -------------------------------------------------------------------------------- /cfgs/run/base_memory_policy/deep/bam/base_bam.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model/wrapped_llm@_global_: llama3-8b-rope-x4NTK 4 | - override /task@_global_: passage_retrieval_en 5 | - override /evolution@_global_: cma_es 6 | - override /policy@_global_: deep 7 | - override /policy/deep_embedding@_global_: attn_spec_norm 8 | - override /policy/deep_scoring@_global_: bam 9 | - override /policy/deep_selection@_global_: binary 10 
| - override /typing@_global_: half_attn 11 | - override /auxiliary_loss@_global_: none 12 | - _self_ 13 | 14 | run_name_suffix: '' 15 | 16 | per_head: false 17 | per_layer: false 18 | 19 | # ATTN params 20 | scoring_shared: true 21 | # Output params. 22 | scoring_requires_recomputation: true 23 | scoring_reduction_mode: 24 | # Ema params. 25 | scoring_ema_coeff: 0.99 26 | scoring_reduction_learned: false 27 | # -- 28 | scoring_output_past_non_reduced_history: false 29 | scoring_max_non_reduced_history_len: 30 | 31 | 32 | # -- 33 | # increasing the hidden dim, increases time cost marginally 34 | scoring_attn_hidden_dim: 16 35 | scoring_attn_output_dim: 36 | scoring_attn_num_heads: 1 37 | scoring_attn_bias: false 38 | scoring_attn_use_rope: false 39 | scoring_attn_rope_theta: 50000 40 | scoring_attn_masking_strategy: backward 41 | #-- 42 | 43 | 44 | # Spectrogram params. 45 | embedding_shared: true 46 | # Output params. 47 | embedding_requires_recomputation: true 48 | embedding_reduction_mode: ema 49 | # Ema params. 50 | embedding_ema_coeff: 0.99 51 | embedding_reduction_learned: false 52 | # -- 53 | embedding_output_past_non_reduced_history: false 54 | embedding_max_non_reduced_history_len: 55 | # -- 56 | # Stft params 57 | n_fft: 32 58 | hop_length: 16 59 | window_fn: 60 | _target_: utils_hydra.torch.hann_window 61 | window_length: ${n_fft} 62 | periodic: true 63 | window_fn_log_name: hann 64 | pad_mode: constant 65 | output_magnitudes: true 66 | # -- 67 | scoring_initializer: 0 68 | 69 | scoring_mlp_bias: true 70 | 71 | 72 | cache_size: 73 | 74 | 75 | always_save_checkpoint: true 76 | 77 | wandb_log: true #false # NOTE 78 | 79 | log_misc: false 80 | wandb_project: memory_evolution_hf 81 | 82 | wandb_run_name: ${tasks_log_name}-${evolution_algorithm_log_name}-shared-${pop_size}pop-${samples_batch_size}qs-${memory_policy_fixed_delay}fixDel-${run_name_suffix} 83 | wandb_group_name: ${llm_log_name}/${mp_log_name} 84 | 85 | add_bos_token: true # to match ref. scores setting 86 | score_normalization_reference: 'lb_reference_scores/per_request/llama3-8b-32k-x2NTK2alpha-v2.json' 87 | 88 | eval_max_batch_size: 32 89 | per_timestep_loglikelihood: false 90 | 91 | batch_size: 1 # NOTE: super low for the moment... 
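# NOTE: each population member is still scored on samples_batch_size task samples per step (see the total-samples-per-step note below), so this low batch size bounds memory use rather than the fitness-estimation budget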
92 | 93 | max_iters: 1000 94 | 95 | # total samples per step = 64 x 3 (tasks) x 16 (pop) = 3072 96 | pop_size: 16 97 | 98 | # CMA-ES params 99 | elite_ratio: 0.5 100 | c_m: 1.0 101 | init_sigma: 0.065 102 | 103 | pop_accumulation_steps: 1 104 | score_aggregation: 'mean' 105 | 106 | # logging 107 | eval_samples_batch_size: 108 | 109 | eval_interval: 10 110 | early_stop_patience: -1 111 | log_interval: 1 112 | eval_iters: 1 113 | 114 | # precision 115 | use_amp: false 116 | dtype: bfloat16 117 | 118 | init_mean: 0.0 119 | init_stdev: 1.0 120 | 121 | allow_distributed_eval: true 122 | memory_policy_fixed_delay: 512 123 | max_new_tokens: 124 | 125 | synchronized_buffers_freeze_after: 1 126 | 127 | prefer_mean_to_best: true 128 | samples_batch_size: 64 129 | 130 | start_recency_from: 0 -------------------------------------------------------------------------------- /memory_llms/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Optional 3 | from memory_policy import MemoryPolicy 4 | 5 | 6 | class MemoryModelWrapper(abc.ABC): 7 | def __init__(self, 8 | config, 9 | memory_policy: MemoryPolicy, 10 | registration_kwargs: dict, 11 | # if set, recomputes memory policy every fixed number of steps 12 | memory_policy_fixed_delay: Optional[int] = None, 13 | output_attentions_full_precision: bool = True, 14 | max_new_tokens: Optional[int] = None, 15 | ): 16 | self.memory_policy: MemoryPolicy = memory_policy 17 | self.registration_kwargs = registration_kwargs 18 | self.config = config 19 | self.memory_policy.register_new_memory_model( 20 | self.config, self.registration_kwargs) 21 | self.memory_policy.finalize_registration() 22 | self.memory_policy_fixed_delay = memory_policy_fixed_delay 23 | 24 | if not hasattr(self, 'max_new_tokens'): 25 | self.max_new_tokens = max_new_tokens 26 | 27 | if (self.memory_policy_fixed_delay is not None) and ( 28 | self.max_new_tokens is not None): 29 | assert self.memory_policy_fixed_delay % self.max_new_tokens == 0 30 | self.output_attentions_full_precision = output_attentions_full_precision 31 | self._past_length = 0 32 | 33 | def load_partial_state_dict(self, state_dict): 34 | current_state = self.state_dict() 35 | for name in self.base_model_param_keys: 36 | target_param = state_dict[name] 37 | current_state[name].copy_(target_param.data) 38 | 39 | 40 | def swap_memory_policy(self, new_memory_policy: MemoryPolicy): 41 | self.memory_policy = new_memory_policy 42 | self.memory_policy.register_new_memory_model( 43 | self.config, self.registration_kwargs) 44 | self.memory_policy.finalize_registration() 45 | self.memory_requires_attn = self.memory_policy.requires_attn_scores 46 | self.memory_requires_queries = self.memory_policy.requires_queries 47 | 48 | @property 49 | def cache_size(self,): 50 | return self.memory_policy.cache_size 51 | 52 | def set_memory_params(self, params) -> None: 53 | self.memory_policy.set_params(params=params) 54 | 55 | def get_memory_params(self,): 56 | return self.memory_policy.get_layer_params() 57 | 58 | def set_memory_params_batch_idxs(self, param_idxs) -> None: 59 | self.memory_policy.set_params_batch_idxs(param_idxs=param_idxs) 60 | 61 | def get_param_size(self,): 62 | return self.memory_policy.param_size 63 | 64 | def get_param_stats(self,): 65 | return self.memory_policy.get_param_stats() 66 | 67 | def get_buffers_list(self,): 68 | return self.memory_policy.get_buffers_list() 69 | 70 | def self_merge(self,): 71 | return self.memory_policy.self_merge() 72 | 73 | def 
merge_buffers_list(self, buffers_to_merge): 74 | return self.memory_policy.merge_buffers_list( 75 | buffers_to_merge=buffers_to_merge) 76 | 77 | def receive_buffers_list(self, buffers_list): 78 | return self.memory_policy.receive_buffers_list( 79 | buffers_list=buffers_list) 80 | 81 | # functions to save (e.g., normalization) buffers along with the checkpoints 82 | def get_buffers_dict(self,): 83 | return self.memory_policy.get_buffers_dict() 84 | 85 | def load_buffers_dict(self, buffers_dict): 86 | return self.memory_policy.load_buffers_dict(buffers_dict=buffers_dict) 87 | 88 | def memory_policy_has_buffers_to_merge(self,): 89 | return self.memory_policy._has_buffers_to_merge 90 | 91 | # functions to signal the memory policy that it is being trained/evaluated 92 | def training_mode(self,): 93 | self.memory_policy.training_mode() 94 | 95 | def evaluation_mode(self,): 96 | self.memory_policy.evaluation_mode() 97 | 98 | def freeze_sync_buffers(self, freeze=True): 99 | self.memory_policy.freeze_sync_buffers(freeze=freeze) 100 | 101 | def unfreeze_sync_buffers(self,): 102 | self.memory_policy.unfreeze_sync_buffers() 103 | 104 | def are_sync_buffers_frozen(self,): 105 | return self.memory_policy.are_sync_buffers_frozen() 106 | 107 | class MemoryDecoderLayer(abc.ABC): 108 | 109 | @abc.abstractmethod 110 | def __init__(self,): 111 | pass 112 | 113 | class MemoryAttention(abc.ABC): 114 | @abc.abstractmethod 115 | def __init__(self,): 116 | pass 117 | -------------------------------------------------------------------------------- /LongBench/config/dataset2prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question as concisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story as concisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", 3 | "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", 4 | "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 5 | "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", 6 | "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages.
Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 7 | "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 8 | "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 9 | "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", 10 | "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", 11 | "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", 12 | "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", 13 | "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", 14 | "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", 15 | "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", 16 | "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", 17 | "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", 18 | "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", 19 | "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ", 20 | "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:", 21 | "lcc": "Please complete the code given below. \n{context}Next line of code:\n", 22 | "repobench-p": "Please complete the code given below. 
\n{context}{input}Next line of code:\n" 23 | } -------------------------------------------------------------------------------- /longbench_metrics.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | 4 | import jieba 5 | from fuzzywuzzy import fuzz 6 | import difflib 7 | 8 | from typing import List 9 | from collections import Counter 10 | from rouge import Rouge 11 | 12 | # copied from https://github.com/THUDM/LongBench/blob/main/metrics.py 13 | 14 | 15 | def normalize_answer(s): 16 | """Lower text and remove punctuation, articles and extra whitespace.""" 17 | 18 | def remove_articles(text): 19 | return re.sub(r"\b(a|an|the)\b", " ", text) 20 | 21 | def white_space_fix(text): 22 | return " ".join(text.split()) 23 | 24 | def remove_punc(text): 25 | exclude = set(string.punctuation) 26 | return "".join(ch for ch in text if ch not in exclude) 27 | 28 | def lower(text): 29 | return text.lower() 30 | 31 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 32 | 33 | 34 | def normalize_zh_answer(s): 35 | """Lower text and remove punctuation, extra whitespace.""" 36 | 37 | def white_space_fix(text): 38 | return "".join(text.split()) 39 | 40 | def remove_punc(text): 41 | cn_punctuation = ( 42 | "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》" 43 | "「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏.") 44 | all_punctuation = set(string.punctuation + cn_punctuation) 45 | return "".join(ch for ch in text if ch not in all_punctuation) 46 | 47 | def lower(text): 48 | return text.lower() 49 | 50 | return white_space_fix(remove_punc(lower(s))) 51 | 52 | 53 | def count_score(prediction, ground_truth, **kwargs): 54 | numbers = re.findall(r"\d+", prediction) 55 | right_num = 0 56 | for number in numbers: 57 | if str(number) == str(ground_truth): 58 | right_num += 1 59 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 60 | return float(final_score) 61 | 62 | 63 | def retrieval_score(prediction, ground_truth, **kwargs): 64 | pattern = r'Paragraph (\d+)' 65 | matches = re.findall(pattern, ground_truth) 66 | ground_truth_id = matches[0] 67 | numbers = re.findall(r"\d+", prediction) 68 | right_num = 0 69 | for number in numbers: 70 | if str(number) == str(ground_truth_id): 71 | right_num += 1 72 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 73 | return float(final_score) 74 | 75 | 76 | def retrieval_zh_score(prediction, ground_truth, **kwargs): 77 | pattern = r'段落(\d+)' 78 | matches = re.findall(pattern, ground_truth) 79 | ground_truth_id = matches[0] 80 | numbers = re.findall(r"\d+", prediction) 81 | right_num = 0 82 | for number in numbers: 83 | if str(number) == str(ground_truth_id): 84 | right_num += 1 85 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 86 | return float(final_score) 87 | 88 | 89 | def code_sim_score(prediction, ground_truth, **kwargs): 90 | all_lines = prediction.lstrip('\n').split('\n') 91 | prediction = "" 92 | for line in all_lines: 93 | if ('`' not in line) and ('#' not in line) and ('//' not in line): 94 | prediction = line 95 | break 96 | return (fuzz.ratio(prediction, ground_truth) / 100) 97 | 98 | 99 | def classification_score(prediction, ground_truth, **kwargs): 100 | em_match_list = [] 101 | all_classes = kwargs["all_classes"] 102 | for class_name in all_classes: 103 | if class_name in prediction: 104 | em_match_list.append(class_name) 105 | for match_term in em_match_list: 106 | if match_term in ground_truth and match_term != ground_truth: 107 | 
em_match_list.remove(match_term) 108 | if ground_truth in em_match_list: 109 | score = (1.0 / len(em_match_list)) 110 | else: 111 | score = 0.0 112 | return score 113 | 114 | 115 | def rouge_score(prediction, ground_truth, **kwargs): 116 | rouge = Rouge() 117 | try: 118 | scores = rouge.get_scores([prediction], [ground_truth], avg=True) 119 | except: 120 | return 0.0 121 | return scores["rouge-l"]["f"] 122 | 123 | 124 | def rouge_zh_score(prediction, ground_truth, **kwargs): 125 | prediction = " ".join(list(jieba.cut(prediction, cut_all=False))) 126 | ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False))) 127 | score = rouge_score(prediction, ground_truth) 128 | return score 129 | 130 | 131 | def f1_score(prediction, ground_truth, **kwargs): 132 | common = Counter(prediction) & Counter(ground_truth) 133 | num_same = sum(common.values()) 134 | if num_same == 0: 135 | return 0 136 | precision = 1.0 * num_same / len(prediction) 137 | recall = 1.0 * num_same / len(ground_truth) 138 | f1 = (2 * precision * recall) / (precision + recall) 139 | return f1 140 | 141 | 142 | def qa_f1_score(prediction, ground_truth, **kwargs): 143 | normalized_prediction = normalize_answer(prediction) 144 | normalized_ground_truth = normalize_answer(ground_truth) 145 | 146 | prediction_tokens = normalized_prediction.split() 147 | ground_truth_tokens = normalized_ground_truth.split() 148 | return f1_score(prediction_tokens, ground_truth_tokens) 149 | 150 | 151 | def qa_f1_zh_score(prediction, ground_truth, **kwargs): 152 | prediction_tokens = list(jieba.cut(prediction, cut_all=False)) 153 | ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False)) 154 | prediction_tokens = [normalize_zh_answer( 155 | token) for token in prediction_tokens] 156 | ground_truth_tokens = [normalize_zh_answer( 157 | token) for token in ground_truth_tokens] 158 | prediction_tokens = [ 159 | token for token in prediction_tokens if len(token) > 0] 160 | ground_truth_tokens = [ 161 | token for token in ground_truth_tokens if len(token) > 0] 162 | return f1_score(prediction_tokens, ground_truth_tokens) 163 | -------------------------------------------------------------------------------- /memory_evolution/base.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | from abc import ABC, abstractmethod 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | 8 | class MemoryEvolution(nn.Module, ABC): 9 | '''Strategy for parameterizing the queries is independent of memory 10 | and strategy for optimizing the queries''' 11 | 12 | def __init__(self, 13 | pop_size, # equal to n_replicas (of dividing) 14 | param_size, 15 | score_processing: tp.Optional[str] = None, 16 | param_clip: tp.Optional[float] = None, 17 | clip_param_min: tp.Optional[float] = None, 18 | clip_param_max: tp.Optional[float] = None, 19 | prefer_mean_to_best: bool = False, 20 | 21 | ): 22 | nn.Module.__init__(self=self,) 23 | self.stored_buffers_to_save = nn.ParameterDict() 24 | self.best_stored_buffers_to_save = nn.ParameterDict() 25 | self.pop_size = pop_size 26 | self.param_size = param_size 27 | self.clip_param_min = clip_param_min 28 | self.clip_param_max = clip_param_max 29 | 30 | if clip_param_min is None and param_clip is not None: 31 | self.clip_param_min = -param_clip 32 | if clip_param_max is None and param_clip is not None: 33 | self.clip_param_max = param_clip 34 | 35 | if isinstance(score_processing, str): 36 | score_processing = score_processing.lower() 37 | assert 
score_processing in ['none', 'rank'] 38 | if score_processing == 'rank': 39 | scores_values = torch.arange(pop_size)/(pop_size - 1) - 0.5 40 | self.register_buffer( 41 | name='scores_values', 42 | tensor=scores_values, 43 | persistent=False, 44 | ) 45 | 46 | self.score_processing = score_processing 47 | 48 | self.prefer_mean_to_best = prefer_mean_to_best 49 | 50 | @abstractmethod 51 | def ask(self,) -> torch.Tensor: 52 | """Ask the algorithm for a population of parameters and update internal 53 | state.""" 54 | raise NotImplementedError() 55 | 56 | @abstractmethod 57 | def tell(self, fitness) -> None: 58 | """Report the fitness of the population to the algorithm.""" 59 | raise NotImplementedError() 60 | 61 | @property 62 | def best_params(self) -> torch.Tensor: 63 | raise NotImplementedError() 64 | 65 | @property 66 | def best_buffer(self) -> torch.Tensor: 67 | raise NotImplementedError() 68 | 69 | @abstractmethod 70 | def sample_candidates( 71 | self, 72 | num_candidates: int, 73 | temperature: float = 1.0, 74 | ) -> torch.Tensor: 75 | """Ask the algorithm for a population of parameters, temperature should 76 | optionally indicate how much the parameters should differ from the 77 | best/mean solution, depending on the specific algorithm.""" 78 | raise NotImplementedError() 79 | 80 | def process_scores(self, fitness) -> torch.Tensor: 81 | """Preprocess the scores""" 82 | # fitness should be a pop_size-dim. vector 83 | if self.score_processing is None: 84 | return fitness 85 | elif self.score_processing == 'none': 86 | return fitness 87 | elif self.score_processing == 'rank': 88 | sorted_idxs = fitness.argsort(-1).argsort(-1) # from min to max 89 | # from -0.5 to 0.5 90 | scores = self.scores_values[sorted_idxs] 91 | return scores 92 | else: 93 | raise NotImplementedError 94 | 95 | def store_best_params(self, x, fitness=None): 96 | raise NotImplementedError 97 | 98 | def store_best_buffers(self, buffers): 99 | self.store_buffers(buffers=buffers, best=True) 100 | 101 | def forward(self, 102 | ): 103 | '''Computes queries''' 104 | return self.ask() 105 | 106 | def get_stats(self, 107 | ): 108 | '''Returns statistics for logging''' 109 | return {} 110 | 111 | def load_init(self, init_param): 112 | '''Loads custom initialization values''' 113 | pass 114 | 115 | def store_buffers( 116 | self, 117 | buffers: tp.Dict[str, torch.Tensor] = None, 118 | best: bool = False, 119 | ): 120 | if best: 121 | self.best_stored_buffers_to_save.update( 122 | {n: nn.Parameter(v, requires_grad=False) 123 | for n, v in buffers.items()}) 124 | else: 125 | self.stored_buffers_to_save.update( 126 | {n: nn.Parameter(v, requires_grad=False) 127 | for n, v in buffers.items()}) 128 | # buffers) 129 | 130 | def get_stored_buffers(self, best=False): 131 | if best and (not self.prefer_mean_to_best): 132 | tensor_dict = {k: v.data.clone() for k, v 133 | in self.best_stored_buffers_to_save.items()} 134 | else: 135 | tensor_dict = {k: v.data.clone() for k, v 136 | in self.stored_buffers_to_save.items()} 137 | return tensor_dict 138 | 139 | 140 | class DummyEvolution(MemoryEvolution): 141 | '''To be returned if no params to optimize, i.e., param_size=0''' 142 | 143 | def __init__(self, 144 | pop_size, 145 | param_size, 146 | param_clip, 147 | score_processing): 148 | nn.Module.__init__(self,) 149 | assert param_size == 0 150 | self.pop_size = pop_size 151 | self.param_size = param_size 152 | self.param_clip = param_clip 153 | self.register_buffer( 154 | name='dummy_params', tensor=torch.zeros( 155 | [self.pop_size, 
self.param_size],), persistent=False,) 156 | 157 | def ask(self) -> torch.Tensor: 158 | return self.dummy_params 159 | 160 | def tell(self, fitness) -> None: 161 | """Report the fitness of the population to the algorithm.""" 162 | pass 163 | 164 | @property 165 | def best_params(self) -> torch.Tensor: 166 | return self.dummy_params[0] 167 | 168 | def forward(self, 169 | ): 170 | '''Computes queries''' 171 | return self.ask() 172 | 173 | def sample_candidates( 174 | self, 175 | num_candidates: int, 176 | temperature: float = 1.0, 177 | ) -> torch.Tensor: 178 | pass 179 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import numpy as np 5 | from torch.distributed import init_process_group, destroy_process_group 6 | from omegaconf import DictConfig, OmegaConf 7 | import hydra 8 | import copy 9 | import datetime 10 | 11 | 12 | def ddp_setup( 13 | backend: str = "nccl", 14 | ddp_timeout_limit: str = '0:6:0', # days:hours:minutes 15 | ): 16 | days, hours, minutes = ddp_timeout_limit.split(':') 17 | timeout_delta = datetime.timedelta( 18 | days=int(days), hours=int(hours), minutes=int(minutes)) 19 | init_process_group(backend=backend, timeout=timeout_delta) 20 | torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) 21 | 22 | 23 | def get_dist_info(): 24 | local_rank = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK', -1)) 25 | global_rank = int(os.getenv('OMPI_COMM_WORLD_RANK', -1)) 26 | world_size = int(os.getenv('OMPI_COMM_WORLD_SIZE', '-1')) 27 | 28 | local_rank = int(os.getenv( 29 | 'LOCAL_RANK', -1)) if local_rank == -1 else local_rank 30 | global_rank = int(os.getenv( 31 | 'RANK', -1)) if global_rank == -1 else global_rank 32 | world_size = int(os.getenv( 33 | 'WORLD_SIZE', -1)) if world_size == -1 else world_size 34 | 35 | if global_rank != -1: 36 | os.environ['RANK'] = str(global_rank) # Needed by torch distributed.
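# propagate the resolved values below as well, so that init_process_group sees consistent LOCAL_RANK and WORLD_SIZE settings regardless of the launcher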
37 | os.environ['LOCAL_RANK'] = str(local_rank) 38 | os.environ['WORLD_SIZE'] = str(world_size) 39 | 40 | return local_rank, global_rank, world_size 41 | 42 | 43 | def wandb_init(cfg): 44 | import wandb 45 | config_dict = OmegaConf.to_container( 46 | # allow missing values for memory experiments 47 | cfg, resolve=True, throw_on_missing=False, 48 | ) 49 | # wandb has a 128-size character limit on the group name 50 | wandb.init( 51 | project=cfg.wandb_config.wandb_project, 52 | group=cfg.wandb_config.wandb_group_name[:127], 53 | name=cfg.wandb_config.wandb_run_name[:127], 54 | config=config_dict, 55 | ) 56 | return wandb 57 | 58 | 59 | def make_eval_model(cfg, log_prefix='...'): 60 | print(log_prefix + 'Instantiating llm...') 61 | pretrained_llm = hydra.utils.call(cfg.pretrained_llm, _convert_="object") 62 | tokenizer = hydra.utils.call(cfg.tokenizer) 63 | 64 | print(log_prefix + 'Instantiating memory policy...') 65 | memory_policy = hydra.utils.instantiate(cfg.memory_policy, 66 | _convert_="object") 67 | 68 | print(log_prefix + 'Instantiating memory llm...') 69 | memory_model = hydra.utils.instantiate( 70 | cfg.memory_model, model=pretrained_llm, memory_policy=memory_policy,) 71 | 72 | print(log_prefix + 'Instantiating evaluation module...') 73 | memory_evaluator = hydra.utils.instantiate( 74 | cfg.memory_evaluator, model=memory_model, tokenizer=tokenizer) 75 | 76 | print(log_prefix + 'Instantiating evolution module...') 77 | evolution_algorithm = hydra.utils.instantiate( 78 | cfg.evolution_algorithm, param_size=memory_policy.param_size, 79 | _recursive_=False) 80 | 81 | init_param = memory_policy.get_init_param_values() 82 | 83 | evolution_algorithm.load_init(init_param=init_param) 84 | 85 | if cfg.auxiliary_loss is not None: 86 | auxiliary_loss = hydra.utils.instantiate(cfg.auxiliary_loss, 87 | memory_policy=memory_policy) 88 | else: 89 | auxiliary_loss = None 90 | 91 | print(log_prefix + 'Finished instantiations.') 92 | return (memory_policy, memory_model, memory_evaluator, evolution_algorithm, 93 | auxiliary_loss) 94 | 95 | 96 | def make_task_sampler(cfg, log_prefix='', **task_sampler_kwargs): 97 | print(log_prefix + f'Instantiating tasks: {cfg.task_sampler.tasks}; with ' 98 | f'corresponding metrics: {cfg.task_sampler.metrics}') 99 | task_sampler = hydra.utils.instantiate( 100 | cfg.task_sampler, _convert_='none', 101 | **task_sampler_kwargs) 102 | return task_sampler 103 | 104 | 105 | def stochasticity_setup(cfg, seed_offset=0, log_prefix=''): 106 | print(log_prefix + f'Global rank used for seed offset {seed_offset}') 107 | np.random.seed(cfg.seed + seed_offset) 108 | torch.manual_seed(cfg.seed + seed_offset) 109 | random.seed(cfg.seed + seed_offset) 110 | 111 | # NOTE: likely can remove offset 112 | if cfg.deterministic_behavior: 113 | print('WARNING: training with deterministic behavior') 114 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 115 | torch.backends.cudnn.benchmark = False 116 | torch.backends.cudnn.deterministic = True 117 | torch.use_deterministic_algorithms(True) 118 | 119 | 120 | @hydra.main(version_base=None, config_path='cfgs', config_name='config') 121 | def main(cfg: DictConfig): 122 | _, global_rank, n_ddp = get_dist_info() 123 | is_ddp = global_rank > -1 124 | if is_ddp: 125 | ddp_setup(backend=cfg.backend, ddp_timeout_limit=cfg.ddp_timeout_limit) 126 | master_process = global_rank == 0 127 | seed_offset = global_rank 128 | else: 129 | master_process = True 130 | seed_offset = 0 131 | 132 | if master_process: 133 | print(f"SHARED Working directory: 
{os.getcwd()}") 134 | print(f"SHARED Output directory: " + 135 | f"{hydra.core.hydra_config.HydraConfig.get().runtime.output_dir}") 136 | 137 | log_prefix = '' 138 | if is_ddp: 139 | log_prefix = f'RANK {global_rank} ({n_ddp} total): ' 140 | 141 | stochasticity_setup(cfg=cfg, seed_offset=seed_offset, 142 | log_prefix=log_prefix) 143 | 144 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 145 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 146 | 147 | with torch.no_grad(): 148 | (memory_policy, memory_model, memory_evaluator, evolution_algorithm, 149 | auxiliary_loss) = make_eval_model(cfg=cfg, log_prefix=log_prefix) 150 | 151 | task_sampler = make_task_sampler(cfg=cfg, log_prefix=log_prefix) 152 | 153 | trainer = hydra.utils.instantiate( 154 | cfg.trainer, 155 | evaluation_model=memory_evaluator, 156 | task_sampler=task_sampler, 157 | evolution_algorithm=evolution_algorithm, 158 | auxiliary_loss=auxiliary_loss, 159 | ) 160 | 161 | if cfg.wandb_config.wandb_log and master_process: 162 | wandb_init(cfg=cfg) 163 | 164 | with torch.no_grad(): 165 | trainer.train() 166 | 167 | if is_ddp: 168 | destroy_process_group() 169 | 170 | 171 | if __name__ == "__main__": 172 | with torch.no_grad(): 173 | main() 174 | -------------------------------------------------------------------------------- /stateless_parallel_modules/mlp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import copy 4 | import math 5 | import numpy as np 6 | from dataclasses import dataclass 7 | from typing import Optional, Tuple, Union, Dict, Callable 8 | 9 | 10 | import torch 11 | from torch import nn 12 | import torch.utils.checkpoint 13 | 14 | from .base import ( 15 | StatelessGeneralizedOperation, GeneralizedLinear, 16 | StatelessGeneralizedModule) 17 | 18 | from utils import get_nonlinearity 19 | 20 | 21 | class StatelessGeneralizedMLP(StatelessGeneralizedModule): 22 | def __init__( 23 | self, 24 | input_features: Optional[int], 25 | hidden_features: Optional[int], 26 | output_features: Optional[int], 27 | # depth = 1, means a single linear operation 28 | hidden_depth: int, 29 | bias: bool, 30 | # base_parallel_operations: Optional[int] = None, 31 | non_linearity: Union[str, Callable] = 'relu', 32 | # add a residual connection for the hidden layers 33 | residual: bool = True, 34 | # add a residual connection from the first layer 35 | residual_first: bool = False, 36 | ): 37 | 38 | StatelessGeneralizedModule.__init__( 39 | self=self, 40 | input_features=input_features, 41 | output_features=output_features, 42 | init_module=True, 43 | ) 44 | 45 | self.hidden_features = hidden_features 46 | self.hidden_depth = hidden_depth 47 | self.bias = bias 48 | 49 | self.non_linearity = get_nonlinearity( 50 | nonlinearity=non_linearity, 51 | ) 52 | 53 | self.residual_first = residual_first 54 | self.residual = residual 55 | 56 | self.is_linear = hidden_depth == 1 57 | self.has_intermediate = hidden_depth > 2 58 | 59 | if (self.input_features is not None and 60 | self.output_features is not None and 61 | self.hidden_features is not None): 62 | print('instantiating ops') 63 | self.instantiate_and_setup_ops( 64 | input_features=input_features, 65 | hidden_features=hidden_features, 66 | output_features=output_features, 67 | preceding_module=None, 68 | default_output_features_mult=1, 69 | ) 70 | 71 | def instantiate_and_setup_ops( 72 | self, 73 | input_features: Optional[int] = None, 74 | hidden_features: Optional[int] = None, 75 | 
output_features: Optional[int] = None, 76 | preceding_module=None, 77 | default_output_features_mult: int = 1, 78 | **kwargs, 79 | ): 80 | 81 | if (self.input_features is None or 82 | self.output_features is None or 83 | self.hidden_features is None): 84 | 85 | self.instantiate_model( 86 | input_features=input_features, 87 | output_features=output_features, 88 | preceding_module=preceding_module, 89 | default_output_features_mult=default_output_features_mult, 90 | ) 91 | if self.hidden_features is None and hidden_features is not None: 92 | self.hidden_features = hidden_features 93 | elif self.hidden_features is None: 94 | print('Warning: hidden features not specified; setting to ' + 95 | f'{self.output_features} (output features)') 96 | self.hidden_features = self.output_features 97 | 98 | if self.residual_first: 99 | # assert hidden_depth > 1 100 | assert self.input_features == self.hidden_features 101 | 102 | self.linear_op = GeneralizedLinear() 103 | 104 | in_dims = self.input_features 105 | self.input_output_features_tuple = [] 106 | for _ in range(self.hidden_depth - 1): 107 | out_dims = self.hidden_features 108 | self.input_output_features_tuple.append( 109 | (in_dims, out_dims)) 110 | in_dims = self.hidden_features 111 | out_dims = self.output_features 112 | self.input_output_features_tuple.append((in_dims, out_dims)) 113 | 114 | operation_list = [ 115 | self.linear_op for _ in range(self.hidden_depth)] 116 | operation_kwargs = dict( 117 | bias=self.bias, 118 | ) 119 | operation_kwargs_overrides_list = [ 120 | dict(in_features=in_dims, out_features=out_dims) 121 | for in_dims, out_dims in self.input_output_features_tuple] 122 | 123 | self.setup_operations( 124 | operations=operation_list, 125 | operation_kwargs=operation_kwargs, 126 | operation_kwargs_overrides_list=operation_kwargs_overrides_list, 127 | save_as_module_list=True, 128 | ) 129 | 130 | def forward( 131 | self, 132 | # parallel_operations x batch_size x input_features 133 | inputs: torch.Tensor, 134 | *args, 135 | n_parallel_dimensions: Optional[int] = None, 136 | **kwargs, 137 | ): 138 | 139 | if n_parallel_dimensions is not None: 140 | input_shape = inputs.shape 141 | inputs = inputs.flatten( 142 | start_dim=0, end_dim=n_parallel_dimensions-1) 143 | # this is guaranteed to be 3-dimensional 144 | inputs = inputs.flatten(start_dim=1, end_dim=-2) 145 | 146 | weight, bias = self.parameters_per_layer[0] 147 | 148 | h = inputs 149 | # first linear layer 150 | h_out = self.linear_op( 151 | input=h, 152 | weight=weight, 153 | bias=bias, 154 | parallel_operations=self.parallel_operations, 155 | # initial input already flattened above 156 | n_parallel_dimensions=None, 157 | ) 158 | 159 | if not self.is_linear: 160 | h_out = self.non_linearity(h_out) 161 | if self.residual_first: 162 | h = h + h_out 163 | else: 164 | h = h_out 165 | if self.has_intermediate: 166 | for weight, bias in self.parameters_per_layer[1:-1]: 167 | h_out = self.linear_op( 168 | input=h, 169 | weight=weight, 170 | bias=bias, 171 | parallel_operations=self.parallel_operations, 172 | # initial input already flattened above 173 | n_parallel_dimensions=None, 174 | ) 175 | h_out = self.non_linearity(h_out) 176 | if self.residual: 177 | h = h + h_out 178 | else: 179 | h = h_out 180 | weight, bias = self.parameters_per_layer[-1] 181 | h_out = self.linear_op( 182 | input=h, 183 | weight=weight, 184 | bias=bias, 185 | parallel_operations=self.parallel_operations, 186 | # initial input already flattened above 187 | n_parallel_dimensions=None, 188 | ) 189 | 
if n_parallel_dimensions is not None: 190 | h_out = h_out.view(*input_shape[:-1], self.output_features) 191 | return h_out 192 | -------------------------------------------------------------------------------- /utils_longbench.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | 6 | from choubun_metrics import ( 7 | rouge_ja_score, 8 | qa_f1_ja_score 9 | ) 10 | from longbench_metrics import ( 11 | qa_f1_score, 12 | rouge_zh_score, 13 | qa_f1_zh_score, 14 | rouge_score, 15 | classification_score, 16 | retrieval_score, 17 | retrieval_zh_score, 18 | count_score, 19 | code_sim_score, 20 | ) 21 | 22 | dataset2metric = { 23 | # LongBench 24 | "narrativeqa": qa_f1_score, 25 | "qasper": qa_f1_score, 26 | "multifieldqa_en": qa_f1_score, 27 | "multifieldqa_zh": qa_f1_zh_score, 28 | "hotpotqa": qa_f1_score, 29 | "2wikimqa": qa_f1_score, 30 | "musique": qa_f1_score, 31 | "dureader": rouge_zh_score, 32 | "gov_report": rouge_score, 33 | "qmsum": rouge_score, 34 | "multi_news": rouge_score, 35 | "vcsum": rouge_zh_score, 36 | "trec": classification_score, 37 | "triviaqa": qa_f1_score, 38 | "samsum": rouge_score, 39 | "lsht": classification_score, 40 | "passage_retrieval_en": retrieval_score, 41 | "passage_count": count_score, 42 | "passage_retrieval_zh": retrieval_zh_score, 43 | "lcc": code_sim_score, 44 | "repobench-p": code_sim_score, 45 | # ChouBun 46 | "wiki_qa": qa_f1_ja_score, 47 | "edinet_qa": qa_f1_ja_score, 48 | "corp_sec_qa": qa_f1_ja_score, 49 | "corp_sec_sum": rouge_ja_score 50 | 51 | } 52 | 53 | # Customized prompt construction for chat models 54 | 55 | 56 | def build_chat(lm, prompt): 57 | print('IN PROMPT') 58 | print(prompt) 59 | tokenizer = lm.tokenizer 60 | model_name = lm.model_name 61 | if "chatglm3" in model_name: 62 | tokenized_prompt = tokenizer.build_chat_input(prompt) 63 | elif "chatglm" in model_name: 64 | prompt = tokenizer.build_prompt(prompt) 65 | elif "longchat" in model_name or "vicuna" in model_name: 66 | raise NotImplementedError 67 | elif "llama2" in model_name: 68 | prompt = f"[INST]{prompt}[/INST]" 69 | elif "xgen" in model_name: 70 | header = ( 71 | "A chat between a curious human and an artificial intelligence " 72 | "assistant. The assistant gives helpful, detailed, and polite " 73 | "answers to the human's questions.\n\n" 74 | ) 75 | prompt = header + f" ### Human: {prompt}\n###" 76 | elif "internlm" in model_name: 77 | prompt = f"<|User|>:{prompt}\n<|Bot|>:" 78 | else: 79 | # unrecognized model name 80 | raise NotImplementedError 81 | 82 | if not ("chatglm3" in model_name): 83 | tokenized_prompt = tokenizer(prompt, truncation=False, 84 | return_tensors="pt") 85 | return tokenized_prompt 86 | 87 | 88 | def parse_args(args=None): 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument('--model', type=str, default=None) 91 | parser.add_argument('--e', action='store_true', 92 | help="Evaluate on LongBench-E") 93 | return parser.parse_args(args) 94 | 95 | 96 | def scorer_e(dataset, predictions, answers, lengths, all_classes): 97 | scores = {"0-4k": [], "4-8k": [], "8k+": []} 98 | for (prediction, ground_truths, length) in zip( 99 | predictions, answers, lengths): 100 | score = 0.
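# few-shot tasks are scored on the first generated line only; the final score is the best match over all reference answers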
101 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]: 102 | prediction = prediction.lstrip('\n').split('\n')[0] 103 | for ground_truth in ground_truths: 104 | score = max(score, dataset2metric[dataset]( 105 | prediction, ground_truth, all_classes=all_classes)) 106 | if length < 4000: 107 | scores["0-4k"].append(score) 108 | elif length < 8000: 109 | scores["4-8k"].append(score) 110 | else: 111 | scores["8k+"].append(score) 112 | for key in scores.keys(): 113 | scores[key] = round(100 * np.mean(scores[key]), 2) 114 | return scores 115 | 116 | 117 | def scorer(dataset, predictions, answers, all_classes): 118 | total_score = 0. 119 | for (prediction, ground_truths) in zip(predictions, answers): 120 | score = 0. 121 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]: 122 | prediction = prediction.lstrip('\n').split('\n')[0] 123 | for ground_truth in ground_truths: 124 | score = max(score, dataset2metric[dataset]( 125 | prediction, ground_truth, all_classes=all_classes)) 126 | total_score += score 127 | return round(100 * total_score / len(predictions), 2) 128 | 129 | 130 | def get_all_scores(task, predictions, answers, all_classes): 131 | all_scores = [] 132 | for (prediction, ground_truths) in zip(predictions, answers): 133 | score = 0. 134 | if task in ["trec", "triviaqa", "samsum", "lsht"]: 135 | prediction = prediction.lstrip('\n').split('\n')[0] 136 | for ground_truth in ground_truths: 137 | score = max(score, dataset2metric[task](prediction, ground_truth, 138 | all_classes=all_classes)) 139 | all_scores.append(score) 140 | return all_scores 141 | 142 | 143 | def get_score(task, predictions, answers, all_classes): 144 | total_score = 0. 145 | all_scores = [] 146 | 147 | # Instantiate tokenizer for ChouBun tasks 148 | if task in ["wiki_qa", "edinet_qa", "corp_sec_qa", "corp_sec_sum"]: 149 | from fugashi import Tagger 150 | tokenizer = Tagger('-Owakati') 151 | else: 152 | tokenizer = None 153 | 154 | for (prediction, ground_truths) in zip(predictions, answers): 155 | score = 0. 
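# as in scorer/scorer_e above: truncate few-shot predictions to their first line, then keep the best score across references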
156 | if task in ["trec", "triviaqa", "samsum", "lsht"]: 157 | prediction = prediction.lstrip('\n').split('\n')[0] 158 | for ground_truth in ground_truths: 159 | cur_score = dataset2metric[task]( 160 | prediction, 161 | ground_truth, 162 | tokenizer=tokenizer, 163 | all_classes=all_classes 164 | ) 165 | score = max(score, cur_score) 166 | all_scores.append(score) 167 | total_score += score 168 | mean_score = 100 * total_score / len(predictions) 169 | return mean_score, all_scores 170 | 171 | 172 | if __name__ == '__main__': 173 | args = parse_args() 174 | scores = dict() 175 | if args.e: 176 | path = f"pred_e/{args.model}/" 177 | else: 178 | path = f"pred/{args.model}/" 179 | all_files = os.listdir(path) 180 | print("Evaluating on:", all_files) 181 | for filename in all_files: 182 | if not filename.endswith("jsonl"): 183 | continue 184 | predictions, answers, lengths = [], [], [] 185 | dataset = filename.split('.')[0] 186 | with open(f"{path}{filename}", "r", encoding="utf-8") as f: 187 | for line in f: 188 | data = json.loads(line) 189 | predictions.append(data["pred"]) 190 | answers.append(data["answers"]) 191 | all_classes = data["all_classes"] 192 | if "length" in data: 193 | lengths.append(data["length"]) 194 | if args.e: 195 | score = scorer_e(dataset, predictions, 196 | answers, lengths, all_classes) 197 | else: 198 | score = scorer(dataset, predictions, answers, all_classes) 199 | scores[dataset] = score 200 | if args.e: 201 | out_path = f"pred_e/{args.model}/result.json" 202 | else: 203 | out_path = f"pred/{args.model}/result.json" 204 | with open(out_path, "w") as f: 205 | json.dump(scores, f, ensure_ascii=False, indent=4) 206 | -------------------------------------------------------------------------------- /memory_policy/deep_scoring_bam.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import copy 4 | import math 5 | import numbers 6 | import numpy as np 7 | from dataclasses import dataclass 8 | from typing import Optional, Tuple, Union, Dict, Callable, List 9 | 10 | 11 | import torch 12 | from torch import nn 13 | import torch.utils.checkpoint 14 | import torch.nn.functional as F 15 | from torch.cuda.amp import autocast 16 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 17 | from transformers import LlamaPreTrainedModel 18 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 19 | from .base import MemoryPolicy, ParamMemoryPolicy 20 | from .base_dynamic import ( 21 | DynamicMemoryPolicy, DynamicParamMemoryPolicy, 22 | RecencyParams, AttentionParams, threshold_score_idxs 23 | ) 24 | from .base_deep_components import ( 25 | ScoringNetwork, TokenEmbedding, SelectionNetwork, wrap_torch_initializer, 26 | ComponentOutputParams) 27 | from stateless_parallel_modules import StatelessGeneralizedMLP 28 | 29 | 30 | def make_scaled_one_hot_init( 31 | idxs_to_scale: dict, 32 | idxs_to_ones: Union[List[int], np.ndarray, torch.Tensor]): 33 | def _init_fn(shape): 34 | tensor = torch.zeros(shape) 35 | for idx in idxs_to_ones: 36 | tensor[..., idx] = 1.0 37 | for idx, value in idxs_to_scale.items(): 38 | tensor[..., int(idx)] = float(value) 39 | 40 | return tensor 41 | return _init_fn 42 | 43 | class MLPScoring(ScoringNetwork): 44 | '''MLP scoring layer, producing score as the NN output combination of the 45 | embeddings.''' 46 | def __init__( 47 | self, 48 | per_layer: bool, 49 | per_head: bool, 50 | shared: bool, 51 | output_params: ComponentOutputParams, 52 | hidden_features: 
Optional[int], 53 | hidden_depth: int, 54 | bias: bool, 55 | non_linearity: Optional[Union[str, Callable]] = 'relu', 56 | initializer: numbers.Number = 0, 57 | residual: bool = True, 58 | residual_first: bool = False, 59 | 60 | ): 61 | ScoringNetwork.__init__( 62 | self, 63 | per_layer=per_layer, 64 | per_head=per_head, 65 | shared=shared, 66 | output_params=output_params, 67 | buffer_names=['past_scores'], 68 | initializer=initializer, 69 | ) 70 | self.hidden_features = hidden_features 71 | self.hidden_depth = hidden_depth 72 | self.bias = bias 73 | self.non_linearity = non_linearity 74 | self.residual = residual 75 | self.residual_first = residual_first 76 | 77 | def register_embedding(self, embedding_module: TokenEmbedding): 78 | ScoringNetwork.register_embedding( 79 | self=self, 80 | embedding_module=embedding_module, 81 | ) 82 | if self.hidden_features is None: 83 | self.hidden_features=self.input_embedding_dim 84 | self.mlp = StatelessGeneralizedMLP( 85 | input_features=self.input_embedding_dim, 86 | hidden_features=self.hidden_features, 87 | output_features=1, 88 | hidden_depth=self.hidden_depth, 89 | bias=self.bias, 90 | non_linearity=self.non_linearity, 91 | residual=self.residual, 92 | residual_first=self.residual_first, 93 | ) 94 | self.mlp_base_parameters = self.mlp.total_base_parameter_dims 95 | 96 | def get_tokens_score( 97 | self, 98 | layer_id, 99 | parameters, 100 | token_embeddings: torch.Tensor, 101 | new_sequences, 102 | num_new_tokens, 103 | attn_weights=None, 104 | attn_mask=None, 105 | position_ids=None, 106 | 107 | **kwargs, 108 | ) -> torch.Tensor: 109 | '''Produces score for each KV cache token embedding''' 110 | 111 | if not self.requires_recomputation: 112 | token_embeddings = token_embeddings[..., -num_new_tokens:, :] 113 | 114 | if self.is_reduced_input: 115 | batch_size, n_heads, n_out_tokens, emb_dim = token_embeddings.shape 116 | else: 117 | batch_size, n_heads, non_reduced_outputs, n_out_tokens, emb_dim = ( 118 | token_embeddings.shape) 119 | token_embeddings = token_embeddings.flatten(start_dim=2, end_dim=3) 120 | 121 | 122 | 123 | n_out_tokens = n_out_tokens 124 | parallel_operations = batch_size 125 | 126 | 127 | if self.shared and self.per_head: 128 | 129 | parallel_operations = parallel_operations*n_heads 130 | token_embeddings = token_embeddings.flatten(start_dim=0, end_dim=1) 131 | else: 132 | 133 | token_embeddings = token_embeddings.flatten(start_dim=1, end_dim=2) 134 | 135 | 136 | 137 | self.mlp.load_parameters( 138 | parameters=parameters, 139 | parallel_operations=parallel_operations, 140 | ) 141 | scores = self.mlp(inputs=token_embeddings) 142 | 143 | if self.is_reduced_input: 144 | scores = scores.view(batch_size, n_heads, n_out_tokens) 145 | else: 146 | scores = scores.view( 147 | batch_size, n_heads, non_reduced_outputs, n_out_tokens) 148 | 149 | if not self.requires_recomputation: 150 | if not new_sequences: 151 | 152 | past_scores: torch.Tensor = self.past_scores[layer_id] 153 | scores = torch.concat([past_scores, scores], dim=-1) 154 | self.past_scores[layer_id] = scores 155 | 156 | if (self.reduction_mode is not None) and (not self.is_reduced_input): 157 | 158 | scores = self.process_output( 159 | layer_id=layer_id, 160 | ema_coeff=self.ema_coeff, 161 | num_new_tokens=num_new_tokens, 162 | new_sequences=new_sequences, 163 | component_output=scores, 164 | **kwargs, 165 | ) 166 | return scores 167 | 168 | def filter_buffer_values( 169 | self, 170 | layer_id: int, 171 | 172 | retained_idxs: torch.Tensor, 173 | ): 174 | 
ScoringNetwork.filter_buffer_values( 175 | self=self, 176 | layer_id=layer_id, 177 | retained_idxs=retained_idxs, 178 | ) 179 | if not self.requires_recomputation: 180 | 181 | 182 | past_scores: torch.Tensor = self.past_scores[layer_id] 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | self.past_scores[layer_id] = torch.gather( 193 | input=past_scores, dim=-1, index=retained_idxs) 194 | 195 | def net_param_size(self,) -> int: 196 | return self.mlp_base_parameters 197 | 198 | 199 | class LinearScoring(MLPScoring): 200 | '''Linear scoring layer, producing score as a linear combination of the 201 | embeddings.''' 202 | def __init__( 203 | self, 204 | per_layer: bool, 205 | per_head: bool, 206 | shared: bool, 207 | output_params: ComponentOutputParams, 208 | bias: bool, 209 | 210 | initializer: numbers.Number = 0, 211 | ): 212 | MLPScoring.__init__( 213 | self, 214 | per_layer=per_layer, 215 | per_head=per_head, 216 | shared=shared, 217 | output_params=output_params, 218 | hidden_features=0, 219 | hidden_depth=1, 220 | bias=bias, 221 | non_linearity=None, 222 | initializer=initializer, 223 | ) 224 | self.bias = bias 225 | 226 | -------------------------------------------------------------------------------- /memory_policy/deep_embedding.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import copy 4 | import math 5 | import numbers 6 | import abc 7 | import numpy as np 8 | from dataclasses import dataclass 9 | from typing import Optional, Tuple, Union, Dict, Callable, List 10 | 11 | from omegaconf import OmegaConf, DictConfig 12 | import hydra 13 | 14 | import torch 15 | from torch import nn 16 | import torch.utils.checkpoint 17 | import torch.nn.functional as F 18 | from torch.cuda.amp import autocast 19 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 20 | from transformers import LlamaPreTrainedModel 21 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 22 | from .base import MemoryPolicy, ParamMemoryPolicy 23 | from .base_dynamic import compute_recency 24 | 25 | from .base_deep_components import ( 26 | ScoringNetwork, TokenEmbedding, SelectionNetwork, wrap_torch_initializer, 27 | ComponentOutputParams,) 28 | 29 | from .deep_embedding_shared import Embedding, PositionalEmbedding 30 | 31 | from stateless_parallel_modules import StatelessGeneralizedMLP 32 | 33 | def convert_to_tensor( 34 | el: Union[List[float], np.ndarray, torch.Tensor], 35 | ) -> torch.Tensor: 36 | if isinstance(el, torch.Tensor): 37 | return el 38 | else: 39 | el = torch.tensor(el) 40 | return el 41 | 42 | 43 | class RecencyExponents(TokenEmbedding): 44 | '''Representing each KV, via a polynomial vector of its recency''' 45 | def __init__( 46 | self, 47 | per_layer: bool, 48 | per_head: bool, 49 | shared: bool, 50 | initial_exponents: Union[List[float], np.ndarray, torch.Tensor], 51 | dtype: Optional[Union[str, torch.dtype]] = None 52 | ): 53 | initial_exponents = convert_to_tensor(initial_exponents) 54 | assert len(initial_exponents.shape) == 1 55 | self._num_recency_exponents = initial_exponents.shape[-1] 56 | TokenEmbedding.__init__( 57 | self, 58 | per_layer=per_layer, 59 | per_head=per_head, 60 | shared=shared, 61 | output_params=ComponentOutputParams( 62 | requires_recomputation=False,), 63 | buffer_names=[], 64 | initializer=initial_exponents, 65 | dtype=dtype, 66 | ) 67 | self.initial_recency_exponents = initial_exponents 68 | 69 | def get_tokens_embedding( 70 | self, 71 | layer_id, 72 | parameters, 
73 | key_cache, 74 | value_cache, 75 | new_sequences, 76 | num_new_tokens, 77 | position_ids, 78 | attn_mask=None, 79 | **kwargs, 80 | ) -> torch.Tensor: 81 | '''Builds a tensor representation for each KV cache token''' 82 | 83 | exponents = parameters 84 | 85 | 86 | unsqueezed_exponents = exponents.unsqueeze(dim=-2) 87 | 88 | 89 | cache_recencies = compute_recency(position_ids=position_ids) 90 | unsqueezed_recencies = cache_recencies.unsqueeze(dim=-1) 91 | 92 | embeddings = torch.pow( 93 | input=unsqueezed_recencies, exponent=unsqueezed_exponents) 94 | 95 | if self._custom_dtype is not None: 96 | embeddings = embeddings.to(dtype=self.ptdtype) 97 | 98 | 99 | embeddings = self.process_output( 100 | layer_id=layer_id, 101 | ema_coeff=self.ema_coeff, 102 | num_new_tokens=num_new_tokens, 103 | new_sequences=new_sequences, 104 | component_output=embeddings, 105 | attn_mask=attn_mask, 106 | **kwargs, 107 | ) 108 | 109 | return embeddings 110 | 111 | def get_embedding_dim(self,) -> int: 112 | return self._num_recency_exponents 113 | 114 | def net_param_size(self,) -> int: 115 | return self._num_recency_exponents 116 | 117 | def aux_param_size(self) -> int: 118 | return 0 119 | 120 | def get_net_params_stats(self, parameters: torch.Tensor): 121 | stats = dict() 122 | learned_exps = parameters.split(split_size=1, dim=-1) 123 | for i, learned_exp in enumerate(learned_exps): 124 | stats[f'net_params/rec_exp_{i}'] = learned_exp.mean().item() 125 | return stats 126 | 127 | @property 128 | def requires_position_ids(self,): 129 | 130 | return True 131 | 132 | @property 133 | def requires_recomputation(self,): 134 | 135 | return False 136 | 137 | @property 138 | def reduced_output(self,): 139 | return True 140 | 141 | def net_param_size(self,) -> int: 142 | return self._num_recency_exponents 143 | 144 | 145 | class NormalizedRecencyExponents(TokenEmbedding): 146 | '''Representing each KV, via a polynomial vector of its normalized 147 | recency - a score ranging from 1 to 0 (most recent to oldest possible 148 | i.e., max_position_id)''' 149 | def __init__( 150 | self, 151 | per_layer: bool, 152 | per_head: bool, 153 | shared: bool, 154 | max_position_id: int, 155 | initial_exponents: Union[List[float], np.ndarray, torch.Tensor], 156 | only_positive_exponents: bool = True, 157 | dtype: Optional[Union[str, torch.dtype]] = None 158 | ): 159 | initial_exponents = convert_to_tensor(initial_exponents) 160 | assert len(initial_exponents.shape) == 1 161 | self._num_recency_exponents = initial_exponents.shape[-1] 162 | self.max_position_id = max_position_id 163 | self.only_positive_exponents = only_positive_exponents 164 | 165 | if only_positive_exponents: 166 | assert torch.all(initial_exponents > 0) 167 | log_initial_exponents = torch.log( 168 | initial_exponents) 169 | initializer = log_initial_exponents 170 | else: 171 | initializer = initial_exponents 172 | 173 | TokenEmbedding.__init__( 174 | self, 175 | per_layer=per_layer, 176 | per_head=per_head, 177 | shared=shared, 178 | output_params=ComponentOutputParams( 179 | requires_recomputation=False,), 180 | buffer_names=[], 181 | initializer=initializer, 182 | dtype=dtype, 183 | ) 184 | 185 | def get_tokens_embedding( 186 | self, 187 | layer_id, 188 | parameters, 189 | key_cache, 190 | value_cache, 191 | new_sequences, 192 | num_new_tokens, 193 | position_ids, 194 | attn_mask=None, 195 | **kwargs, 196 | ) -> torch.Tensor: 197 | '''Builds a tensor representation for each KV cache token''' 198 | 199 | if self.only_positive_exponents: 200 | exponents = 
torch.exp(parameters) 201 | else: 202 | exponents = parameters 203 | 204 | 205 | unsqueezed_exponents = exponents.unsqueeze(dim=-2) 206 | 207 | 208 | 209 | cache_recencies = compute_recency( 210 | position_ids=position_ids)/self.max_position_id 211 | 212 | 213 | 214 | cache_recencies = 1 - cache_recencies 215 | 216 | unsqueezed_recencies = cache_recencies.unsqueeze(dim=-1) 217 | 218 | embeddings = torch.pow( 219 | input=unsqueezed_recencies, exponent=unsqueezed_exponents) 220 | 221 | if self._custom_dtype is not None: 222 | embeddings = embeddings.to(dtype=self.ptdtype) 223 | 224 | embeddings = self.process_output( 225 | layer_id=layer_id, 226 | ema_coeff=self.ema_coeff, 227 | num_new_tokens=num_new_tokens, 228 | new_sequences=new_sequences, 229 | component_output=embeddings, 230 | attn_mask=attn_mask, 231 | **kwargs, 232 | ) 233 | return embeddings 234 | 235 | def get_embedding_dim(self,) -> int: 236 | return self._num_recency_exponents 237 | 238 | def net_param_size(self,) -> int: 239 | return self._num_recency_exponents 240 | 241 | def aux_param_size(self) -> int: 242 | return 0 243 | 244 | def get_net_params_stats(self, parameters: torch.Tensor): 245 | stats = dict() 246 | learned_exps = parameters.split(split_size=1, dim=-1) 247 | for i, learned_exp in enumerate(learned_exps): 248 | stats[f'net_params/rec_exp_{i}'] = learned_exp.mean().item() 249 | return stats 250 | 251 | @property 252 | def requires_position_ids(self,): 253 | 254 | return True 255 | 256 | @property 257 | def requires_recomputation(self,): 258 | 259 | return False 260 | 261 | @property 262 | def reduced_output(self,): 263 | return True 264 | 265 | def net_param_size(self,) -> int: 266 | return self._num_recency_exponents 267 | 268 | 269 | -------------------------------------------------------------------------------- /memory_policy/deep_selection.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import copy 4 | import math 5 | import numpy as np 6 | from dataclasses import dataclass 7 | from typing import Optional, Tuple, Union, Dict 8 | 9 | 10 | import torch 11 | from torch import nn 12 | import torch.utils.checkpoint 13 | import torch.nn.functional as F 14 | from torch.cuda.amp import autocast 15 | 16 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 17 | from transformers import LlamaPreTrainedModel 18 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 19 | from .base import MemoryPolicy, ParamMemoryPolicy 20 | from .base_dynamic import ( 21 | DynamicMemoryPolicy, DynamicParamMemoryPolicy, 22 | RecencyParams, AttentionParams, threshold_score_idxs 23 | ) 24 | from .base_deep_components import SelectionNetwork, ComponentOutputParams 25 | 26 | 27 | class DynamicSelection(SelectionNetwork): 28 | '''Replicates the default selection criteria of dynamic memory policies''' 29 | def __init__(self, 30 | per_layer: bool, 31 | per_head: bool, 32 | shared: bool, 33 | cache_size: Optional[int], 34 | dynamic_thresh: float, 35 | ): 36 | SelectionNetwork.__init__( 37 | self, 38 | per_layer=per_layer, 39 | per_head=per_head, 40 | shared=shared, 41 | output_params=ComponentOutputParams( 42 | requires_recomputation=False, 43 | reduction_mode=None, 44 | ema_params=None, 45 | output_past_non_reduced_history=False, 46 | max_non_reduced_history_len=None, 47 | ), 48 | buffer_names=[], 49 | initializer=dynamic_thresh, 50 | ) 51 | self.cache_size = cache_size 52 | 53 | def select_new_tokens( 54 | self, 55 | parameters, 56 | 
267 | 268 | -------------------------------------------------------------------------------- /memory_policy/deep_selection.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import copy 4 | import math 5 | import numpy as np 6 | from dataclasses import dataclass 7 | from typing import Optional, Tuple, Union, Dict 8 | 9 | 10 | import torch 11 | from torch import nn 12 | import torch.utils.checkpoint 13 | import torch.nn.functional as F 14 | from torch.cuda.amp import autocast 15 | 16 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 17 | from transformers import LlamaPreTrainedModel 18 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 19 | from .base import MemoryPolicy, ParamMemoryPolicy 20 | from .base_dynamic import ( 21 | DynamicMemoryPolicy, DynamicParamMemoryPolicy, 22 | RecencyParams, AttentionParams, threshold_score_idxs 23 | ) 24 | from .base_deep_components import SelectionNetwork, ComponentOutputParams 25 | 26 | 27 | class DynamicSelection(SelectionNetwork): 28 | '''Replicates the default selection criteria of dynamic memory policies''' 29 | def __init__(self, 30 | per_layer: bool, 31 | per_head: bool, 32 | shared: bool, 33 | cache_size: Optional[int], 34 | dynamic_thresh: float, 35 | ): 36 | SelectionNetwork.__init__( 37 | self, 38 | per_layer=per_layer, 39 | per_head=per_head, 40 | shared=shared, 41 | output_params=ComponentOutputParams( 42 | requires_recomputation=False, 43 | reduction_mode=None, 44 | ema_params=None, 45 | output_past_non_reduced_history=False, 46 | max_non_reduced_history_len=None, 47 | ), 48 | buffer_names=[], 49 | initializer=dynamic_thresh, 50 | ) 51 | self.cache_size = cache_size 52 | 53 | def select_new_tokens( 54 | self, 55 | parameters, 56 | token_scores, 57 | new_sequences, 58 | num_new_tokens, 59 | attn_weights=None, 60 | attn_mask=None, 61 | position_ids=None, 62 | threshold_shift: float = 0.0, 63 | **kwargs, 64 | ) -> torch.Tensor: 65 | '''Produces indexes for the selected KV cache tokens and a selection 66 | mask.''' 67 | dynamic_thresh = parameters 68 | min_value = torch.finfo(token_scores.dtype).min 69 | max_value = torch.finfo(dynamic_thresh.dtype).max 70 | 71 | 72 | masked_full_scores = torch.where( 73 | attn_mask.bool(), token_scores, min_value) 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | masked_full_scores[..., -1] = max_value 84 | retained_idxs, new_mask = threshold_score_idxs( 85 | masked_full_scores=masked_full_scores, 86 | dynamic_thresh=dynamic_thresh + threshold_shift, 87 | preserve_order=True, 88 | cache_size=self.cache_size, 89 | ) 90 | 91 | 92 | return retained_idxs, new_mask 93 | 94 | def get_cache_size(self,) -> Optional[int]: 95 | return self.cache_size 96 | 97 | def net_param_size(self,) -> int: 98 | return 1 99 | 100 | def get_param_scaling(self,) -> Optional[Union[str, Tuple[float, float]]]: 101 | 102 | 103 | return 'exp' 104 | 105 | class BinarySelection(SelectionNetwork): 106 | '''Maintains tokens when scores > 0 - can be probabilistic, based on a 107 | logistic distribution instead''' 108 | 109 | def __init__(self, 110 | per_layer: bool, 111 | per_head: bool, 112 | shared: bool, 113 | cache_size: Optional[int], 114 | is_probabilistic: bool = False, 115 | temp: float = 1.0, 116 | learned_temp: bool = False, 117 | ): 118 | 119 | if learned_temp: 120 | assert is_probabilistic, ( 121 | 'If not probabilistic, the temperature will not be used') 122 | self._needs_learned_temp = True 123 | else: 124 | self._needs_learned_temp = False 125 | self.is_probabilistic = is_probabilistic 126 | self.initial_temp = temp 127 | SelectionNetwork.__init__( 128 | self, 129 | per_layer=per_layer, 130 | per_head=per_head, 131 | shared=shared, 132 | output_params=ComponentOutputParams( 133 | requires_recomputation=False, 134 | reduction_mode=None, 135 | ema_params=None, 136 | output_past_non_reduced_history=False, 137 | max_non_reduced_history_len=None, 138 | ), 139 | buffer_names=[], 140 | initializer=temp, 141 | ) 142 | self.cache_size = cache_size 143 | 144 | 145 | def select_new_tokens( 146 | self, 147 | parameters, 148 | token_scores, 149 | new_sequences, 150 | num_new_tokens, 151 | attn_weights=None, 152 | attn_mask=None, 153 | position_ids=None, 154 | threshold_shift: float = 0.0, 155 | **kwargs, 156 | ) -> torch.Tensor: 157 | '''Produces indexes for the selected KV cache tokens and a selection 158 | mask.''' 159 | 160 | 161 | 162 | 163 | 164 | min_value = torch.finfo(token_scores.dtype).min 165 | max_value = torch.finfo(token_scores.dtype).max 166 | 167 | if self.is_probabilistic: 168 | if self._needs_learned_temp: 169 | temp = parameters 170 | else: 171 | temp = self.initial_temp 172 | 173 | probabilities = F.sigmoid(token_scores/temp) 174 | random_samples = torch.rand_like(probabilities) 175 | 176 | 177 | token_scores = (probabilities >= random_samples).to( 178 | probabilities.dtype) - 0.5 179 | 180 | masked_full_scores = torch.where( 181 | attn_mask.bool(), token_scores, min_value) 182 | masked_full_scores[..., -1] = max_value 183 | retained_idxs, new_mask = threshold_score_idxs( 184 | masked_full_scores=masked_full_scores, 185 | dynamic_thresh=threshold_shift, 186 | preserve_order=True, 187 | cache_size=self.cache_size, 188 | ) 189 | 190 | 191 | return retained_idxs, new_mask 192 | 
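# A toy sketch of the thresholded selection used above. The real work is
# done by threshold_score_idxs (defined in memory_policy/base_dynamic.py);
# its exact semantics are assumed here: keep, in order, the indices whose
# masked score exceeds the threshold, with the newest token pinned to max.
import torch

scores = torch.tensor([[0.9, -0.2, 0.4, 0.1]])
attn_mask = torch.tensor([[1, 1, 1, 1]])
thresh = 0.3

min_value = torch.finfo(scores.dtype).min
masked = torch.where(attn_mask.bool(), scores, min_value)
masked[..., -1] = torch.finfo(scores.dtype).max    # newest token always kept

keep = masked > thresh                             # [[True, False, True, True]]
retained_idxs = keep[0].nonzero(as_tuple=True)[0]  # tensor([0, 2, 3]), in order
print(retained_idxs)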
193 | def get_cache_size(self,) -> Optional[int]: 194 | return self.cache_size 195 | 196 | def net_param_size(self,) -> int: 197 | if self._needs_learned_temp: 198 | return 1 199 | else: 200 | return 0 201 | 202 | def get_param_scaling(self,) -> Optional[Union[str, Tuple[float, float]]]: 203 | 204 | 205 | return 'exp' 206 | 207 | 208 | class TopKSelection(SelectionNetwork): 209 | '''Simply collects the top K scores with no thresholding''' 210 | def __init__(self, 211 | 212 | 213 | 214 | cache_size: Optional[int], 215 | 216 | ): 217 | SelectionNetwork.__init__( 218 | self, 219 | per_layer=False, 220 | per_head=False, 221 | shared=True, 222 | output_params=ComponentOutputParams( 223 | requires_recomputation=False, 224 | reduction_mode=None, 225 | ema_params=None, 226 | output_past_non_reduced_history=False, 227 | max_non_reduced_history_len=None, 228 | ), 229 | buffer_names=[], 230 | ) 231 | self.cache_size = cache_size 232 | 233 | def select_new_tokens( 234 | self, 235 | parameters, 236 | token_scores, 237 | new_sequences, 238 | num_new_tokens, 239 | attn_weights=None, 240 | attn_mask=None, 241 | position_ids=None, 242 | **kwargs, 243 | ) -> torch.Tensor: 244 | '''Produces indexes for the selected KV cache tokens and a selection 245 | mask.''' 246 | 247 | num_samples = token_scores.shape[-1] 248 | if self.cache_size is not None and num_samples > self.cache_size: 249 | min_value = torch.finfo(token_scores.dtype).min 250 | masked_full_scores = torch.where( 251 | attn_mask.bool(), token_scores, min_value) 252 | _, retained_idxs = torch.topk( 253 | masked_full_scores, k=self.cache_size, sorted=False, dim=-1) 254 | retained_idxs, _ = retained_idxs.sort(descending=False, dim=-1,) 255 | else: 256 | retained_idxs = torch.arange( 257 | num_samples, device=token_scores.device,).view( 258 | 1, 1, num_samples).expand_as(token_scores) 259 | if self.cache_size is not None: 260 | attn_mask = attn_mask[..., -self.cache_size:] 261 | new_mask = torch.ones_like(retained_idxs, dtype=torch.bool)*attn_mask 262 | return retained_idxs, new_mask 263 | 264 | def get_cache_size(self,) -> Optional[int]: 265 | return self.cache_size 266 | 267 | def net_param_size(self,) -> int: 268 | return 0 269 | 270 | 271 | 272 | 273 | 274 | -------------------------------------------------------------------------------- /memory_policy/shared.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union, Dict, List 2 | 3 | from collections import OrderedDict 4 | import abc 5 | import copy 6 | import torch 7 | from torch import nn 8 | 9 | 10 | class RegistrationCompatible(abc.ABC): 11 | def register_auxiliary_loss_callback(self, auxiliary_loss): 12 | self.auxiliary_loss = auxiliary_loss 13 | self.auxiliary_loss_callback = True 14 | print( 15 | 'ERROR: Using an auxiliary loss with a component that has no parameters') 16 | raise NotImplementedError 17 | 18 | def register_new_memory_layer(self, config, registration_kwargs): 19 | 20 | 21 | 22 | 23 | curr_layer_id = self.num_memory_layers 24 | self.num_memory_layers = self.num_memory_layers + 1 25 | self.config = config 26 | self.registration_kwargs = registration_kwargs 27 | return curr_layer_id 28 | 29 | def register_new_memory_model(self, config, registration_kwargs): 30 | assert 'num_memory_layers' in registration_kwargs 31 | self.num_memory_layers = registration_kwargs['num_memory_layers'] 32 | self.num_heads = registration_kwargs['num_heads'] 33 | self.num_key_value_groups = registration_kwargs.get( 34 | 'num_key_value_groups', None)
35 | self.config = config 36 | self.registration_kwargs = registration_kwargs 37 | self.hidden_size = registration_kwargs.get( 38 | 'hidden_size', None) 39 | if self.hidden_size is None: 40 | if hasattr(config, 'hidden_size'): 41 | self.hidden_size = config.hidden_size 42 | else: 43 | raise NotImplementedError 44 | 45 | 46 | class SynchronizableBufferStorage(abc.ABC): 47 | def __init__( 48 | self, 49 | buffers_to_merge: Union[List[str], dict] = [], 50 | sub_buffer_storages: list = [], 51 | ): 52 | self.initialize_buffer_dicts_to_merge( 53 | buffers_to_merge=buffers_to_merge, 54 | sub_buffer_storages=sub_buffer_storages) 55 | self.training_mode() 56 | self.unfreeze_sync_buffers() 57 | 58 | def are_sync_buffers_frozen(self,): 59 | return self._frozen_sync_buffers 60 | 61 | def freeze_sync_buffers(self, freeze=True): 62 | self._frozen_sync_buffers = freeze 63 | if self._has_sub_buffers_to_merge: 64 | for sub_buffer_n, sub_buffer in ( 65 | self._sub_buffer_ordered_references.items()): 66 | sub_buffer.freeze_sync_buffers(freeze=freeze) 67 | 68 | def unfreeze_sync_buffers(self,): 69 | self.freeze_sync_buffers(freeze=False) 70 | 71 | def training_mode(self,): 72 | self._is_in_training_mode = True 73 | if self._has_sub_buffers_to_merge: 74 | for sub_buffer_n, sub_buffer in ( 75 | self._sub_buffer_ordered_references.items()): 76 | sub_buffer.training_mode() 77 | 78 | def evaluation_mode(self,): 79 | self._is_in_training_mode = False 80 | if self._has_sub_buffers_to_merge: 81 | for sub_buffer_n, sub_buffer in ( 82 | self._sub_buffer_ordered_references.items()): 83 | sub_buffer.evaluation_mode() 84 | 85 | def get_buffers_to_merge_keys(self,): 86 | return self._buffers_to_merge_keys 87 | 88 | 89 | def get_buffers_list(self,): 90 | buffers_dict = self.get_buffers_dict() 91 | assert len(buffers_dict) == len(self._buffers_to_merge_keys) 92 | return [buffers_dict[k] for k in self._buffers_to_merge_keys] 93 | 94 | def merge_buffers_list( 95 | self, 96 | buffers_to_merge: List[List[torch.Tensor]], 97 | ) -> List[torch.Tensor]: 98 | merged_buffers = [] 99 | buffer_group_idx = 0 100 | if self._has_owned_buffers_to_merge: 101 | merged_buffers += self._merge_own_buffers( 102 | buffers_to_merge=buffers_to_merge[ 103 | :self._num_total_owned_buffers_to_merge]) 104 | buffer_group_idx += self._num_total_buffers_to_merge 105 | if self._has_sub_buffers_to_merge: 106 | for n_buffers, (sub_buffer_n, sub_buffer) in zip( 107 | self.num_buffers_per_sub_buffer, 108 | self._sub_buffer_ordered_references.items()): 109 | 110 | end_buffer_group_idx = buffer_group_idx + n_buffers 111 | rel_buffers_to_merge = buffers_to_merge[ 112 | buffer_group_idx:end_buffer_group_idx] 113 | merged_buffers += sub_buffer.merge_buffers_list( 114 | buffers_to_merge=rel_buffers_to_merge) 115 | buffer_group_idx = end_buffer_group_idx 116 | assert len(merged_buffers) == len(buffers_to_merge) 117 | return merged_buffers 118 | 119 | def _merge_own_buffers( 120 | self, 121 | buffers_to_merge: List[List[torch.Tensor]], 122 | ) -> List[torch.Tensor]: 123 | raise NotImplementedError 124 | 125 | def receive_buffers_list(self, buffers_list): 126 | assert len(buffers_list) == len(self._buffers_to_merge_keys) 127 | buffers_dict = {k: v for k, v in zip( 128 | self._buffers_to_merge_keys, buffers_list)} 129 | self.load_buffers_dict(buffers_dict=buffers_dict) 130 | 131 | def get_buffers_dict(self,): 132 | buffers_dict = {} 133 | if self._has_owned_buffers_to_merge: 134 | buffers_dict.update(self.buffers_to_merge_dict) 135 | if self._has_sub_buffers_to_merge: 136 
| buffers_dict.update(self.get_dict_from_sub_buffers()) 137 | return buffers_dict 138 | 139 | def load_buffers_dict(self, buffers_dict): 140 | if len(buffers_dict) == 0: 141 | return 142 | else: 143 | assert set(buffers_dict.keys()) == set( 144 | self.get_buffers_to_merge_keys()) 145 | 146 | for k in self.buffers_to_merge_dict: 147 | self.buffers_to_merge_dict[k] = buffers_dict[k] 148 | if self._has_sub_buffers_to_merge: 149 | self.load_dict_to_sub_buffers(buffers_dict=buffers_dict) 150 | 151 | def _self_merge_own_buffers(self,) -> List[torch.Tensor]: 152 | raise NotImplementedError 153 | 154 | def self_merge(self,) -> List[torch.Tensor]: 155 | merged_buffers = [] 156 | if self._has_owned_buffers_to_merge: 157 | merged_buffers += self._self_merge_own_buffers() 158 | if self._has_sub_buffers_to_merge: 159 | for k, sub_buffer in self._sub_buffer_ordered_references.items(): 160 | merged_buffers += sub_buffer.self_merge() 161 | return merged_buffers 162 | 163 | def initialize_buffer_dicts_to_merge( 164 | self, buffers_to_merge: Union[List[str], dict], 165 | sub_buffer_storages: list, 166 | reset: bool = True, 167 | ): 168 | 169 | assert reset 170 | if isinstance(buffers_to_merge, dict): 171 | self.buffers_to_merge_dict = OrderedDict(buffers_to_merge) 172 | else: 173 | self.buffers_to_merge_dict = OrderedDict( 174 | [(k, 0) for k in buffers_to_merge]) 175 | 176 | self._buffers_to_merge_keys = list(self.buffers_to_merge_dict.keys()) 177 | self._owned_buffers_to_merge_keys = copy.copy( 178 | self._buffers_to_merge_keys) 179 | self._num_total_owned_buffers_to_merge = len( 180 | self.buffers_to_merge_dict) 181 | self._num_total_buffers_to_merge = ( 182 | self._num_total_owned_buffers_to_merge) 183 | self._has_buffers_to_merge = self._has_owned_buffers_to_merge = ( 184 | self._num_total_buffers_to_merge > 0) 185 | self.register_sub_buffers_to_merge( 186 | sub_buffer_storages=sub_buffer_storages,) 187 | 188 | def register_sub_buffers_to_merge(self, sub_buffer_storages: list): 189 | self.num_buffers_per_sub_buffer = [] 190 | self._buffers_to_merge_keys_from_sub_buffers = [] 191 | self._sub_buffer_ordered_references: Dict[ 192 | str, SynchronizableBufferStorage] = OrderedDict() 193 | for i, buffer in enumerate(sub_buffer_storages): 194 | if isinstance(buffer, str): 195 | assert hasattr(self, buffer) 196 | buffer_obj: SynchronizableBufferStorage = getattr(self, buffer) 197 | buffer_name = buffer 198 | else: 199 | buffer_obj: SynchronizableBufferStorage = buffer 200 | buffer_name = f'{i}' 201 | 202 | if buffer_obj._has_buffers_to_merge: 203 | sub_buffer_keys = buffer_obj._buffers_to_merge_keys 204 | self._buffers_to_merge_keys_from_sub_buffers += [ 205 | buffer_name + '_' + k for k in sub_buffer_keys] 206 | self.num_buffers_per_sub_buffer.append(len(sub_buffer_keys)) 207 | self._sub_buffer_ordered_references[buffer_name] = buffer_obj 208 | 209 | self._buffers_to_merge_keys += ( 210 | self._buffers_to_merge_keys_from_sub_buffers) 211 | 212 | self._num_total_buffers_to_merge_from_sub_buffers = sum( 213 | self.num_buffers_per_sub_buffer) 214 | self._has_sub_buffers_to_merge = ( 215 | self._num_total_buffers_to_merge_from_sub_buffers > 0) 216 | self._has_buffers_to_merge = ( 217 | self._has_buffers_to_merge or self._has_sub_buffers_to_merge) 218 | self._num_total_buffers_to_merge += ( 219 | self._num_total_buffers_to_merge_from_sub_buffers) 220 | 221 | def get_dict_from_sub_buffers(self,) -> Dict[str, torch.Tensor]: 222 | buffers_dict = OrderedDict() 223 | for buffer_name, buffer_obj in ( 224 | 
self._sub_buffer_ordered_references.items()): 225 | for k, v in buffer_obj.get_buffers_dict().items(): 226 | buffers_dict[buffer_name + '_' + k] = v 227 | return buffers_dict 228 | 229 | def load_dict_to_sub_buffers(self, buffers_dict) -> Dict[str, torch.Tensor]: 230 | for buffer_name, buffer_obj in ( 231 | self._sub_buffer_ordered_references.items()): 232 | 233 | buffer_sub_dict = {} 234 | for k, v in buffers_dict.items(): 235 | target_prefix = f'{buffer_name}_' 236 | if k.startswith(target_prefix): 237 | buffer_sub_dict[k.removeprefix(target_prefix)] = v 238 | 239 | buffer_obj.load_buffers_dict(buffer_sub_dict) 240 | 241 | return buffers_dict 242 | -------------------------------------------------------------------------------- /memory_policy/deep_embedding_spectogram.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import copy 4 | import math 5 | import numbers 6 | import numpy as np 7 | from dataclasses import dataclass 8 | from typing import Optional, Tuple, Union, Dict, Callable, List 9 | import hydra 10 | from omegaconf import DictConfig 11 | 12 | import torch 13 | from torch import nn 14 | import torch.utils.checkpoint 15 | import torch.nn.functional as F 16 | from torch.cuda.amp import autocast 17 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 18 | from transformers import LlamaPreTrainedModel 19 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 20 | from .base import MemoryPolicy, ParamMemoryPolicy 21 | from .base_dynamic import ( 22 | DynamicMemoryPolicy, DynamicParamMemoryPolicy, 23 | RecencyParams, AttentionParams, threshold_score_idxs 24 | ) 25 | from .base_deep_components import ( 26 | ScoringNetwork, TokenEmbedding, SelectionNetwork, wrap_torch_initializer, 27 | ComponentOutputParams, reduce_ema_values) 28 | from stateless_parallel_modules import StatelessGeneralizedMLP 29 | 30 | 31 | def fft_ema_mask(window_length, ema_coeff, hop_length): 32 | '''Creates a mask mimicking an exponential moving average across the fft 33 | values. For consistent usage, ensure window_length == hop_length.''' 34 | discount_vector_exponents = torch.arange( 35 | start=window_length-1, end=-1, step=-1,) 36 | 37 | discount_vector = torch.pow(ema_coeff, discount_vector_exponents) 38 | rescale_factor = (1-ema_coeff)/(1-ema_coeff**hop_length) 39 | return discount_vector*rescale_factor 40 | 
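# A quick numeric check of fft_ema_mask (illustrative sketch): with
# window_length == hop_length, the rescaled discount vector weights each
# STFT frame like a normalized EMA, so a constant signal sums to exactly 1.
import torch

window_length = hop_length = 4
ema_coeff = 0.9
exps = torch.arange(start=window_length - 1, end=-1, step=-1)
mask = torch.pow(ema_coeff, exps) * (1 - ema_coeff) / (1 - ema_coeff ** hop_length)
print(mask.sum())  # -> 1.0: a constant signal passes through unchanged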
41 | 42 | def fft_avg_mask(window_length): 43 | return torch.ones([window_length])/window_length 44 | 45 | 46 | @dataclass 47 | class STFTParams: 48 | n_fft: int 49 | hop_length: int 50 | window_fn: Optional[Union[ 51 | Callable, np.ndarray, torch.Tensor, DictConfig]] = None 52 | pad_mode: str = 'constant' 53 | output_magnitudes: bool = True 54 | 55 | 56 | class AttentionSpectrogram(TokenEmbedding): 57 | '''Represents each KV token via the underlying frequencies of the 58 | attention it receives: 59 | each cached token is attended to by all of its future tokens, so given 60 | N new tokens we extend every cached token's spectrogram with the 61 | attention received from them; depending on output_params we store the 62 | full spectrogram history, prune old frames, or accumulate them (e.g., via EMA)''' 63 | 64 | def __init__( 65 | self, 66 | per_layer: bool, 67 | per_head: bool, 68 | shared: bool, 69 | output_params: ComponentOutputParams, 70 | stft_params: STFTParams, 71 | dtype: Optional[Union[str, torch.dtype]] = None, 72 | ): 73 | 74 | self.prev_attn_buffer: List[Optional[torch.Tensor]] 75 | assert output_params.requires_recomputation == True 76 | TokenEmbedding.__init__( 77 | self, 78 | per_layer=per_layer, 79 | per_head=per_head, 80 | shared=shared, 81 | output_params=output_params, 82 | buffer_names=['prev_attn_buffer'], 83 | dtype=dtype, 84 | 85 | 86 | 87 | ) 88 | self.store_stft_params(stft_params=stft_params) 89 | 90 | def store_stft_params(self, stft_params: STFTParams): 91 | self.n_fft = stft_params.n_fft 92 | assert self.n_fft % 2 == 0 93 | 94 | self.window_length = self.n_fft 95 | self.stft_stride = stft_params.hop_length 96 | self.reduction_stride = self.stft_stride 97 | self.window_fn = stft_params.window_fn 98 | 99 | self.pad_attn_required = self.window_length - self.stft_stride 100 | self.base_n_fft = self.n_fft//2 + 1 101 | if self.window_fn is not None: 102 | if isinstance(self.window_fn, DictConfig): 103 | window = hydra.utils.call(self.window_fn, self.window_length) 104 | elif isinstance(self.window_fn, Callable): 105 | window = self.window_fn(self.window_length) 106 | else: 107 | window = self.window_fn 108 | assert window.shape[-1] == self.window_length 109 | self.register_buffer( 110 | 'stft_window', tensor=window, persistent=False) 111 | else: 112 | self.stft_window = None 113 | self.output_magnitudes = stft_params.output_magnitudes 114 | self.stft_pad_mode = stft_params.pad_mode 115 | assert self.stft_pad_mode == 'constant', ( 116 | 'TODO: Deviating from constant padding might lead to inconsistent ' 117 | 'results, needs to be tested') 118 | 119 | def get_tokens_embedding( 120 | self, 121 | layer_id, 122 | parameters, 123 | key_cache, 124 | value_cache, 125 | new_sequences, 126 | num_new_tokens, 127 | attn_weights, 128 | attn_mask=None, 129 | position_ids=None, 130 | analyze=False, 131 | **kwargs, 132 | ) -> torch.Tensor: 133 | '''Builds a tensor representation for each KV cache token''' 134 | parameters, aux_params = self.split_net_and_aux_params( 135 | parameters=parameters) 136 | 137 | device = key_cache.device 138 | 139 | batch_size, n_heads, num_all_tokens, emb_dim = key_cache.shape 140 | batch_size, n_heads, num_new_tokens, num_all_tokens = attn_weights.shape 141 | 142 | if attn_mask is not None: 143 | 144 | attn_mask = attn_mask[..., -num_all_tokens:].unsqueeze(-2) > 0 145 | else: 146 | attn_mask = torch.ones([batch_size, 1, num_all_tokens], 147 | device=device, dtype=torch.bool) 148 | 149 | num_new_embeddings = num_new_tokens // self.stft_stride 150 | stride_carry_over = num_new_tokens % self.stft_stride 151 | if stride_carry_over > 0: 152 | 153 | raise NotImplementedError 154 | 155 | attn_weights = attn_weights.transpose(dim0=-2, dim1=-1) 156 | if self.pad_attn_required > 0: 157 | if new_sequences: 158 | 159 | pad_tuple = [self.pad_attn_required, 0] 160 | rel_attn_weights = F.pad( 161 | input=attn_weights, 162 | pad=pad_tuple, 163 | mode=self.stft_pad_mode, 164 | ) 165 | else: 166 | 167 | prev_attn_weights = 
self.prev_attn_buffer[layer_id] 168 | rel_prev_attn_weights = prev_attn_weights[ 169 | ..., -self.pad_attn_required:] 170 | pad_tuple = [0, 0, 0, num_new_tokens] 171 | 172 | padded_rel_prev_attn_weights = F.pad( 173 | input=rel_prev_attn_weights, 174 | pad=pad_tuple, 175 | mode='constant', 176 | ) 177 | 178 | rel_attn_weights = torch.concat( 179 | [padded_rel_prev_attn_weights, attn_weights], 180 | dim=-1, 181 | ) 182 | 183 | if not analyze: 184 | self.prev_attn_buffer[layer_id] = attn_weights 185 | else: 186 | rel_attn_weights = attn_weights 187 | 188 | flat_rel_attn_weights = rel_attn_weights.flatten( 189 | start_dim=0, end_dim=-2) 190 | 191 | flat_attn_stft = torch.stft( 192 | input=flat_rel_attn_weights, 193 | n_fft=self.n_fft, 194 | hop_length=self.stft_stride, 195 | center=False, 196 | pad_mode='constant', 197 | normalized=False, 198 | onesided=True, 199 | return_complex=True, 200 | window=self.stft_window, 201 | ) 202 | 203 | attn_stft = flat_attn_stft.view( 204 | batch_size, n_heads, num_all_tokens, 205 | self.base_n_fft, num_new_embeddings) 206 | 207 | attn_stft = attn_stft.permute(dims=[0, 1, 4, 2, 3]) 208 | 209 | if self.output_magnitudes: 210 | attn_stft = attn_stft.abs() 211 | else: 212 | 213 | attn_stft = torch.view_as_real(attn_stft).flatten( 214 | start_dim=-2, end_dim=-1) 215 | 216 | if self._custom_dtype is not None: 217 | attn_stft = attn_stft.to(dtype=self.ptdtype) 218 | 219 | if self.output_past_non_reduced_history: 220 | 221 | past_attn_spectr = self.past_outputs_buffer[layer_id] 222 | 223 | pad_tuple = [0, 0, 0, num_new_tokens] 224 | padded_rel_prev_attn_weights = F.pad( 225 | input=rel_prev_attn_weights, 226 | pad=pad_tuple, 227 | mode='constant', 228 | ) 229 | attn_stft = torch.concat([past_attn_spectr, attn_stft], dim=-3) 230 | if self.limit_past_history_size: 231 | attn_stft = attn_stft[ 232 | ..., -self.max_non_reduced_history_len:, :, :] 233 | embeddings = attn_stft 234 | 235 | if not analyze: 236 | self.past_outputs_buffer[layer_id] = attn_stft 237 | 238 | else: 239 | embeddings = attn_stft 240 | 241 | embeddings = self.process_output( 242 | layer_id=layer_id, 243 | ema_coeff=self.ema_coeff, 244 | num_new_tokens=num_new_tokens, 245 | new_sequences=new_sequences, 246 | component_output=embeddings, 247 | aux_params=aux_params, 248 | attn_mask=attn_mask, 249 | analyze=analyze, 250 | **kwargs, 251 | ) 252 | 253 | return embeddings 254 | 255 | def get_embedding_dim(self,) -> int: 256 | 257 | if self.output_magnitudes: 258 | return self.base_n_fft 259 | else: 260 | return self.base_n_fft*2 261 | 262 | def net_param_size(self,) -> int: 263 | return 0 264 | 265 | @property 266 | def requires_attn_scores(self,): 267 | 268 | return True 269 | 270 | @property 271 | def requires_recomputation(self,): 272 | 273 | return True 274 | 275 | def filter_buffer_values( 276 | self, 277 | layer_id: int, 278 | 279 | retained_idxs: torch.Tensor): 280 | if self.pad_attn_required > 0: 281 | prev_attn_buffer: torch.Tensor = self.prev_attn_buffer[layer_id] 282 | prev_new_tokens = prev_attn_buffer.shape[-1] 283 | expanded_idxs = retained_idxs.unsqueeze(-1).expand( 284 | -1, -1, -1, prev_new_tokens) 285 | prev_attn_buffer = torch.gather( 286 | input=prev_attn_buffer, 287 | dim=-2, 288 | index=expanded_idxs, 289 | ) 290 | self.prev_attn_buffer[layer_id] = prev_attn_buffer 291 | TokenEmbedding.filter_buffer_values( 292 | self=self, 293 | layer_id=layer_id, 294 | retained_idxs=retained_idxs) 295 | -------------------------------------------------------------------------------- 
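A standalone sketch of the core transform in AttentionSpectrogram: an STFT taken over each cached token's stream of received attention. Shapes only — the buffering, masking, and EMA reduction above are omitted, and the toy sizes below are assumptions:

import torch

batch, heads, num_tokens, num_new = 1, 2, 6, 8
n_fft = hop = 4                     # window_length == n_fft, so no carry-over

# attention received by each cached token from the num_new newest queries
attn = torch.rand(batch, heads, num_new, num_tokens).transpose(-2, -1)

flat = attn.flatten(start_dim=0, end_dim=-2)          # (batch*heads*tokens, num_new)
stft = torch.stft(flat, n_fft=n_fft, hop_length=hop,
                  center=False, return_complex=True)  # (rows, n_fft//2+1, frames)
frames = num_new // hop
spec = stft.view(batch, heads, num_tokens, n_fft // 2 + 1, frames)
print(spec.abs().shape)             # torch.Size([1, 2, 6, 3, 2])

--------------------------------------------------------------------------------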
/memory_policy/auxiliary_losses.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import copy 4 | import math 5 | import numpy as np 6 | from dataclasses import dataclass 7 | from typing import Optional, Tuple, Union, List 8 | 9 | import abc 10 | import torch 11 | from torch import nn 12 | import torch.utils.checkpoint 13 | import torch.nn.functional as F 14 | from torch.cuda.amp import autocast 15 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 16 | from transformers import LlamaPreTrainedModel 17 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 18 | from .base import MemoryPolicy, ParamMemoryPolicy 19 | from .base_dynamic import DynamicMemoryPolicy, DynamicParamMemoryPolicy 20 | 21 | from omegaconf import OmegaConf, DictConfig 22 | import hydra 23 | import numpy as np 24 | 25 | 26 | class MemoryPolicyAuxiliaryLoss(nn.Module, abc.ABC): 27 | 28 | def __init__(self, 29 | memory_policy: ParamMemoryPolicy, 30 | coeff: float = 1.0, 31 | adaptive_target: Optional[float] = None, 32 | adaptive_target_param: str = 'exp', 33 | optimizer=None, 34 | adaptive_update_every: int = 1, 35 | device: str = 'cuda', 36 | ): 37 | nn.Module.__init__(self,) 38 | self.device = device 39 | assert isinstance(memory_policy, ParamMemoryPolicy) 40 | self.pop_size = memory_policy.param_pop_size 41 | 42 | self.num_memory_layers = memory_policy.num_memory_layers 43 | 44 | self.step_losses = torch.zeros([self.pop_size], dtype=torch.long, 45 | device=device) 46 | 47 | self.is_adaptive = adaptive_target is not None 48 | self.adaptive_target = adaptive_target 49 | 50 | assert adaptive_target_param in ['exp', 'linear'] 51 | self.is_exp = adaptive_target_param == 'exp' 52 | self.adaptive_target_param = adaptive_target_param 53 | if self.is_adaptive: 54 | assert optimizer is not None 55 | if self.is_exp: 56 | assert coeff > 0 57 | init_coeff = np.log(coeff) 58 | else: 59 | init_coeff = coeff 60 | self.raw_coeff = nn.Parameter( 61 | data=torch.zeros([1]) + float(init_coeff), requires_grad=False) 62 | else: 63 | self.register_buffer('raw_coeff', tensor=torch.tensor(coeff), 64 | persistent=False) 65 | self.adaptive_update_every = adaptive_update_every 66 | self.num_adaptive_updates = 0 67 | self.stored_losses = [] 68 | 69 | self.num_samples_per_pop = torch.zeros( 70 | [self.pop_size], dtype=torch.long, device=device) 71 | self.total_losses_per_pop = torch.zeros( 72 | [self.pop_size], dtype=torch.long, device=device) 73 | memory_policy.register_auxiliary_loss_callback(auxiliary_loss=self) 74 | 75 | def restart_recording(self,): 76 | 77 | self.num_samples_per_pop.data.copy_(torch.zeros_like( 78 | self.num_samples_per_pop)) 79 | self.total_losses_per_pop.data.copy_(torch.zeros_like( 80 | self.total_losses_per_pop)) 81 | 82 | @abc.abstractmethod 83 | def memory_policy_layer_callback( 84 | self, 85 | layer_id, 86 | pop_idxs, 87 | new_sequences, 88 | key_cache, 89 | value_cache, 90 | dynamic_mask=None): 91 | raise NotImplementedError 92 | 93 | @abc.abstractmethod 94 | def memory_policy_update_callback( 95 | self, 96 | layer_id, 97 | pop_idxs, 98 | new_sequences, 99 | new_kv_cache,): 100 | raise NotImplementedError 101 | 102 | @property 103 | def coeff(self,): 104 | if self.is_adaptive and self.is_exp: 105 | return torch.exp(self.raw_coeff) 106 | else: 107 | return self.raw_coeff 108 | 109 | @abc.abstractmethod 110 | def get_loss(self,): 111 | raise NotImplementedError 112 | 113 | def optim_params(self, loss): 114 | 
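# Dual-ascent-style coefficient update: the gradient on raw_coeff is set to
# (adaptive_target - loss), so a descent step raises the coefficient while
# the recorded loss is above target and lowers it once the target is met;
# with the 'exp' parametrization, coeff = exp(raw_coeff) stays positive.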
self.optimizer.zero_grad(set_to_none=True) 115 | dual_difference = self.adaptive_target - loss 116 | self.raw_coeff.grad = dual_difference 117 | self.optimizer.step() 118 | return dual_difference 119 | 120 | def setup_optimizer(self, optimizer): 121 | assert optimizer is not None 122 | if isinstance(optimizer, DictConfig): 123 | 124 | self.optimizer: torch.optim.Optimizer = hydra.utils.instantiate( 125 | optimizer, params=[self.raw_coeff], _convert_='all') 126 | else: 127 | self.optimizer: torch.optim.Optimizer = optimizer( 128 | params=[self.raw_coeff]) 129 | 130 | def forward(self,): 131 | loss = self.get_loss() 132 | 133 | if self.is_adaptive: 134 | if self.num_adaptive_updates % self.adaptive_update_every == 0: 135 | _ = self.optim_params(loss=loss) 136 | self.num_adaptive_updates += 1 137 | scaled_loss = self.coeff*loss 138 | 139 | return scaled_loss 140 | 
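# A self-contained toy run of this adaptive scheme (hypothetical numbers;
# SGD stands in for whatever optimizer the config supplies): the coefficient
# grows while the recorded loss stays above its target.
import torch

target = 0.5
raw_coeff = torch.nn.Parameter(torch.zeros(1), requires_grad=False)  # log-space
opt = torch.optim.SGD([raw_coeff], lr=0.1)

for step in range(3):
    loss = torch.tensor([0.8])       # pretend the recorded loss is 0.8
    opt.zero_grad(set_to_none=True)
    raw_coeff.grad = target - loss   # dual-style gradient, as in optim_params
    opt.step()                       # loss > target -> coefficient increases
    print(step, raw_coeff.exp().item())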
141 | 142 | class JointAuxiliaryLosses(MemoryPolicyAuxiliaryLoss): 143 | 144 | def __init__(self, memory_policy: ParamMemoryPolicy, 145 | auxiliary_losses: List[MemoryPolicyAuxiliaryLoss]): 146 | MemoryPolicyAuxiliaryLoss.__init__(self, memory_policy=memory_policy) 147 | assert len(auxiliary_losses) > 0 148 | self.auxiliary_losses = nn.ModuleList(auxiliary_losses) 149 | self.adaptive_losses = [ 150 | loss for loss in auxiliary_losses if loss.is_adaptive] 151 | self.is_adaptive = len(self.adaptive_losses) > 0 152 | 153 | def get_loss(self,): 154 | loss = 0 155 | for aux_loss in self.auxiliary_losses: 156 | loss = loss + aux_loss.get_loss() 157 | return loss 158 | 159 | 160 | class SparsityAuxiliaryLoss(MemoryPolicyAuxiliaryLoss): 161 | 162 | def __init__(self, 163 | memory_policy: DynamicParamMemoryPolicy, 164 | coeff: float = 1.0, 165 | adaptive_target: Optional[float] = None, 166 | adaptive_target_param: str = 'exp', 167 | optimizer=None, 168 | adaptive_update_every: int = 1, 169 | sparsity_mode: str = 'mean', 170 | sparsity_per_head: bool = False, 171 | device: str = 'cuda', 172 | ): 173 | MemoryPolicyAuxiliaryLoss.__init__( 174 | self, 175 | memory_policy=memory_policy, 176 | coeff=coeff, 177 | adaptive_target=adaptive_target, 178 | adaptive_target_param=adaptive_target_param, 179 | optimizer=optimizer, 180 | adaptive_update_every=adaptive_update_every, 181 | device=device, 182 | ) 183 | self.sparsity_mode = sparsity_mode 184 | assert sparsity_mode in ['mean',] 185 | self.sparsity_per_head = sparsity_per_head 186 | 187 | self.losses_per_pop_per_layer = torch.zeros( 188 | [self.num_memory_layers, self.pop_size], dtype=torch.long, 189 | device=device) 190 | 191 | def memory_policy_layer_callback( 192 | self, 193 | layer_id, 194 | pop_idxs, 195 | new_sequences, 196 | key_cache, 197 | value_cache, 198 | dynamic_mask=None, 199 | scoring_network_params=None): 200 | 201 | if dynamic_mask is not None: 202 | unmasked_samples_per_head: torch.Tensor = dynamic_mask.to( 203 | dtype=self.losses_per_pop_per_layer.dtype).sum(-1) 204 | if self.sparsity_per_head: 205 | 206 | layer_loss = (unmasked_samples_per_head.sum(-1) / 207 | unmasked_samples_per_head.numel()).to( 208 | dtype=torch.long) 209 | else: 210 | layer_loss = unmasked_samples_per_head.max(-1)[0] 211 | else: 212 | layer_loss = key_cache.size(-2)*torch.ones( 213 | [key_cache.size(0)], 214 | device=self.losses_per_pop_per_layer.device, 215 | dtype=self.losses_per_pop_per_layer.dtype, 216 | ) 217 | self.losses_per_pop_per_layer[layer_id].data.copy_(torch.zeros_like( 218 | self.losses_per_pop_per_layer[layer_id])) 219 | self.losses_per_pop_per_layer[layer_id].scatter_add_( 220 | dim=0, index=pop_idxs, src=layer_loss,) 221 | 222 | def memory_policy_update_callback( 223 | self, 224 | pop_idxs, 225 | new_sequences, 226 | new_kv_cache,): 227 | losses_per_pop = self.losses_per_pop_per_layer.sum(dim=0) 228 | 229 | self.num_samples_per_pop.scatter_reduce_( 230 | dim=0, index=pop_idxs, 231 | 232 | src=torch.ones_like(pop_idxs).to(dtype=losses_per_pop.dtype), 233 | reduce='sum', 234 | include_self=True) 235 | self.total_losses_per_pop.data.add_(losses_per_pop) 236 | 237 | def get_loss(self,): 238 | return self.total_losses_per_pop/( 239 | torch.clamp_min_( 240 | self.num_samples_per_pop*self.num_memory_layers, 1)) 241 | 242 | 243 | class L2NormAuxiliaryLoss(MemoryPolicyAuxiliaryLoss): 244 | def __init__(self, 245 | memory_policy: DynamicParamMemoryPolicy, 246 | coeff: float = 1.0, 247 | adaptive_target: Optional[float] = None, 248 | adaptive_target_param: str = 'linear', 249 | optimizer=None, 250 | adaptive_update_every: int = 1, 251 | device: str = 'cuda', 252 | ): 253 | MemoryPolicyAuxiliaryLoss.__init__( 254 | self, 255 | memory_policy=memory_policy, 256 | coeff=coeff, 257 | adaptive_target=adaptive_target, 258 | adaptive_target_param=adaptive_target_param, 259 | optimizer=optimizer, 260 | adaptive_update_every=adaptive_update_every, 261 | device=device, 262 | ) 263 | 264 | self.losses_per_pop_per_layer = torch.zeros( 265 | [self.num_memory_layers, self.pop_size], device=device) 266 | 267 | self.total_losses_per_pop = torch.zeros( 268 | [self.pop_size], dtype=torch.float, device=device) 269 | 270 | def memory_policy_layer_callback( 271 | self, 272 | layer_id, 273 | pop_idxs, 274 | new_sequences, 275 | key_cache, 276 | value_cache, 277 | dynamic_mask=None, 278 | scoring_network_params=None): 279 | 280 | self.total_losses_per_pop[pop_idxs] = ( 281 | scoring_network_params ** 2).mean().to(self.device) 282 | 283 | def memory_policy_update_callback( 284 | self, 285 | pop_idxs, 286 | new_sequences, 287 | new_kv_cache,): 288 | pass 289 | 290 | def get_loss(self,): 291 | return self.total_losses_per_pop 292 | -------------------------------------------------------------------------------- /stateless_parallel_modules/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import copy 4 | import math 5 | import numpy as np 6 | from dataclasses import dataclass 7 | from typing import Optional, Tuple, Union, Dict, Callable, List 8 | import abc 9 | 10 | import torch 11 | from torch import nn 12 | import torch.utils.checkpoint 13 | import torch.nn.functional as F 14 | from torch.cuda.amp import autocast 15 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 16 | from transformers import LlamaPreTrainedModel 17 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 18 | from utils import get_nonlinearity 19 | 20 | 21 | class StatelessGeneralizedOperation(nn.Module, abc.ABC): 22 | def __init__( 23 | self, 24 | ): 25 | nn.Module.__init__(self=self) 26 | 27 | @abc.abstractmethod 28 | def total_parameters( 29 | self, 30 | *args, 31 | parallel_operations: Optional[int] = None, 32 | **kwargs, 33 | ) -> int: 34 | raise NotImplementedError 35 | 36 | @abc.abstractmethod 37 | def prepare_parameters( 38 | self, parameters: torch.Tensor, 39 | *args, 40 | parallel_operations: Optional[int] = None, 41 | **kwargs, 42 | ): 43 | raise NotImplementedError 44 | 45 | @abc.abstractmethod 
46 | def forward( 47 | # parallel_operations x batch_size x in_features 48 | self, input: torch.Tensor, 49 | *args, 50 | parallel_operations: Optional[int] = None, 51 | n_parallel_dimensions: Optional[int] = None, 52 | **kwargs, 53 | ) -> torch.Tensor: 54 | raise NotImplementedError 55 | 56 | 57 | class GeneralizedLinear(StatelessGeneralizedOperation): 58 | def __init__( 59 | self, 60 | ): 61 | StatelessGeneralizedOperation.__init__(self=self,) 62 | 63 | def total_parameters( 64 | self, 65 | in_features: int, 66 | out_features: int, 67 | bias: bool, 68 | parallel_operations: Optional[int] = None, 69 | **kwargs, 70 | ) -> int: 71 | total_weight_dimension = in_features*out_features 72 | total_base_parameters_dimension = total_weight_dimension 73 | if bias: 74 | total_base_parameters_dimension += out_features 75 | if parallel_operations is not None: 76 | total_parameters_dimension = ( 77 | parallel_operations*total_base_parameters_dimension) 78 | else: 79 | total_parameters_dimension = total_base_parameters_dimension 80 | return total_parameters_dimension, total_base_parameters_dimension 81 | 82 | def prepare_parameters( 83 | self, 84 | in_features: int, 85 | out_features: int, 86 | bias: bool, 87 | parameters: torch.Tensor, 88 | parallel_operations: Optional[int] = None, 89 | **kwargs, 90 | ): 91 | if parallel_operations is not None: 92 | parameters = parameters.view(parallel_operations, -1) 93 | if bias: 94 | total_weight_dimension = in_features*out_features 95 | flat_w, flat_b = parameters.split_with_sizes( 96 | [total_weight_dimension, out_features], dim=-1) 97 | # parallel_ops x 1 x out_features 98 | b = flat_b.unsqueeze_(-2) 99 | else: 100 | flat_w = parameters 101 | b = None 102 | # shape required for baddbmm/addbmm 103 | w = flat_w.view(parallel_operations, in_features, out_features) 104 | else: 105 | if bias: 106 | total_weight_dimension = in_features*out_features 107 | flat_w, b = parameters.split_with_sizes( 108 | [total_weight_dimension, out_features], dim=-1) 109 | else: 110 | flat_w = parameters 111 | b = None 112 | # shape required for F.linear 113 | w = flat_w.view(out_features, in_features) 114 | return w, b 115 | 116 | def forward( 117 | self, 118 | input: torch.Tensor, # parallel_operations x batch_size x in_features 119 | weight: torch.Tensor, 120 | bias: Optional[torch.Tensor] = None, 121 | parallel_operations: Optional[int] = None, 122 | n_parallel_dimensions: Optional[int] = None, 123 | **kwargs, 124 | ) -> torch.Tensor: 125 | if n_parallel_dimensions is not None: 126 | input_shape = input.shape 127 | 128 | # this is guaranteed to be 3-dimensional 129 | input = input.flatten(start_dim=1, end_dim=-2) 130 | 131 | if parallel_operations is not None: 132 | if bias is not None: 133 | out = torch.baddbmm(input=bias, batch1=input, batch2=weight) 134 | else: 135 | out = input @ weight 136 | else: 137 | out = F.linear(input=input, weight=weight, bias=bias) 138 | if n_parallel_dimensions is not None: 139 | out = out.view(*input_shape[:-1], -1) 140 | return out 141 | 
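# A minimal use of the packed-linear pattern above, assuming the flat layout
# prepare_parameters expects: per parallel op, the weight block first, then
# the bias. One baddbmm applies all parallel linears at once.
import torch

parallel_ops, in_f, out_f, batch = 4, 3, 2, 5
flat = torch.randn(parallel_ops, in_f * out_f + out_f)  # one flat row per op

flat_w, flat_b = flat.split_with_sizes([in_f * out_f, out_f], dim=-1)
w = flat_w.view(parallel_ops, in_f, out_f)   # baddbmm layout: (P, in, out)
b = flat_b.unsqueeze(-2)                     # (P, 1, out), broadcast over batch

x = torch.randn(parallel_ops, batch, in_f)
out = torch.baddbmm(b, x, w)                 # (P, batch, out): P linears at once
print(out.shape)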
142 | 143 | class StatelessGeneralizedModule(nn.Module, abc.ABC): 144 | def __init__( 145 | self, 146 | input_features: Optional[int], 147 | output_features: Optional[int], 148 | init_module: bool = True, 149 | operation_kwargs: Optional[dict] = None 150 | ): 151 | if init_module: 152 | nn.Module.__init__(self=self,) 153 | 154 | if operation_kwargs is None: 155 | self.operation_kwargs = {} 156 | else: 157 | self.operation_kwargs = operation_kwargs 158 | 159 | self.input_features = input_features 160 | self.output_features = output_features 161 | if not hasattr(self, 'total_base_parameter_dims'): 162 | # num. of params needed for forward (for each parallel op.) 163 | self.total_base_parameter_dims = 0 164 | 165 | if not hasattr(self, 'n_base_parameters_per_layer'): 166 | self.n_base_parameters_per_layer: List[int] = [] 167 | if not hasattr(self, 'parameters_per_layer'): 168 | self.parameters_per_layer: List[torch.Tensor] = [] 169 | if not hasattr(self, 'kwargs_per_op'): 170 | self.kwargs_per_op: List[dict] = [] 171 | if not hasattr(self, 'generalized_ops'): 172 | self.generalized_ops: List[StatelessGeneralizedOperation] = [] 173 | self.parallel_operations: Optional[int] = None 174 | 175 | def get_buffer_names(self,): 176 | return [] 177 | 178 | def instantiate_and_setup_ops( 179 | self, 180 | input_features: Optional[int] = None, 181 | output_features: Optional[int] = None, 182 | preceding_module=None, 183 | default_output_features_mult: int = 1, 184 | **kwargs, 185 | ): 186 | raise NotImplementedError 187 | 188 | def setup_operations( 189 | self, 190 | operations: Union[ 191 | List[StatelessGeneralizedOperation], 192 | StatelessGeneralizedOperation], 193 | operation_kwargs: Optional[dict] = None, 194 | operation_kwargs_overrides_list: Optional[List[dict]] = None, 195 | # save list of operations as nn.ModuleList 196 | save_as_module_list: bool = True, 197 | ): 198 | 199 | if operation_kwargs is not None: 200 | self.operation_kwargs = operation_kwargs 201 | 202 | operation_kwargs = self.operation_kwargs 203 | 204 | if operation_kwargs_overrides_list is not None: 205 | assert len(operation_kwargs_overrides_list) == len(operations) 206 | elif isinstance(operations, StatelessGeneralizedOperation): 207 | operation_kwargs_overrides_list = [{}] 208 | else: 209 | operation_kwargs_overrides_list = [{} for _ in operations] 210 | 211 | if isinstance(operations, StatelessGeneralizedOperation): 212 | operations = [operations for _ in range( 213 | len(operation_kwargs_overrides_list))] 214 | 215 | 216 | self.generalized_ops = operations 217 | for op, op_kwargs_overrides in zip( 218 | self.generalized_ops, operation_kwargs_overrides_list): 219 | op: StatelessGeneralizedOperation 220 | op_kwargs = {} 221 | op_kwargs.update(operation_kwargs) 222 | op_kwargs.update(op_kwargs_overrides) 223 | self.kwargs_per_op.append(op_kwargs) 224 | op_base_params, _ = op.total_parameters( 225 | parallel_operations=None, **op_kwargs) 226 | self.n_base_parameters_per_layer.append(op_base_params) 227 | 228 | self.total_base_parameter_dims = int(sum( 229 | self.n_base_parameters_per_layer)) 230 | if save_as_module_list: 231 | self.generalized_ops_module_list = nn.ModuleList( 232 | self.generalized_ops) 233 | 234 | def instantiate_model( 235 | self, 236 | input_features: Optional[int] = None, 237 | output_features: Optional[int] = None, 238 | preceding_module=None, 239 | default_output_features_mult: int = 1, 240 | ): 241 | if self.input_features is None and input_features is not None: 242 | self.input_features = input_features 243 | elif self.input_features is None: 244 | assert preceding_module is not None 245 | assert hasattr(preceding_module, 'output_features') 246 | self.input_features: int = preceding_module.output_features 247 | print('Warning: input features not specified, setting to ' + 248 | f'{self.input_features} (output features of ' + 249 | 'preceding module)') 250 | 251 | assert isinstance(self.input_features, int) 252 | if self.output_features is 
not None: 253 | self.output_features = output_features 254 | elif self.output_features is None: 255 | self.output_features = ( 256 | self.input_features*default_output_features_mult) 257 | print('Warning: output features not specified, setting to ' + 258 | f'{self.output_features} ' + 259 | '(input features*default_output_features_mult)') 260 | 261 | def format_parameters( 262 | self, 263 | parameters: torch.Tensor, 264 | # should match first dimension of parameters 265 | parallel_operations: Optional[int] = None, 266 | ): 267 | 268 | if parallel_operations is not None: 269 | # sanity check, unneeded/redundant (trivial overhead) 270 | parameters = parameters.view( 271 | parallel_operations, self.total_base_parameter_dims) 272 | 273 | self.parallel_operations = parallel_operations 274 | 275 | # assumes parameters is a flattened tensor 276 | flat_parameters_per_layer = parameters.split_with_sizes( 277 | self.n_base_parameters_per_layer, dim=-1) 278 | 279 | parameters_per_layer: list[torch.Tensor] = [] 280 | 281 | for op, op_kwargs, layer_params in zip( 282 | self.generalized_ops, 283 | self.kwargs_per_op, 284 | flat_parameters_per_layer, 285 | ): 286 | 287 | prepared_params = op.prepare_parameters( 288 | parameters=layer_params, 289 | parallel_operations=parallel_operations, 290 | **op_kwargs, 291 | ) 292 | 293 | parameters_per_layer.append(prepared_params) 294 | return parameters_per_layer 295 | 296 | def load_parameters( 297 | self, 298 | parameters: torch.Tensor, 299 | # should match first dimension of parameters 300 | parallel_operations: Optional[int] = None, 301 | ): 302 | 303 | if parallel_operations is not None: 304 | # sanity check, unneeded/redundant (trivial overhead) 305 | parameters = parameters.view( 306 | parallel_operations, self.total_base_parameter_dims) 307 | 308 | self.parallel_operations = parallel_operations 309 | 310 | # assumes parameters is a flattened tensor 311 | flat_parameters_per_layer = parameters.split_with_sizes( 312 | self.n_base_parameters_per_layer, dim=-1) 313 | 314 | self.parameters_per_layer: list[torch.Tensor] = [] 315 | 316 | for op, op_kwargs, layer_params in zip( 317 | self.generalized_ops, 318 | self.kwargs_per_op, 319 | flat_parameters_per_layer, 320 | ): 321 | 322 | prepared_params = op.prepare_parameters( 323 | **op_kwargs, 324 | parameters=layer_params, 325 | parallel_operations=parallel_operations, 326 | ) 327 | 328 | self.parameters_per_layer.append(prepared_params) 329 | 330 | @abc.abstractmethod 331 | def forward( 332 | self, 333 | # parallel_operations x batch_size x input_features 334 | inputs: torch.Tensor, 335 | *args, 336 | n_parallel_dimensions: Optional[int] = None, 337 | **kwargs, 338 | ): 339 | raise NotImplementedError 340 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: th2 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_gnu 10 | - aom=3.9.1=hac33072_0 11 | - asttokens=2.4.1=pyhd8ed1ab_0 12 | - blas=1.0=mkl 13 | - brotli-python=1.1.0=py310hc6cd4ac_1 14 | - bzip2=1.0.8=hd590300_5 15 | - ca-certificates=2024.6.2=hbcca054_0 16 | - cairo=1.18.0=h3faef2a_0 17 | - certifi=2024.6.2=pyhd8ed1ab_0 18 | - charset-normalizer=3.3.2=pyhd8ed1ab_0 19 | - comm=0.2.2=pyhd8ed1ab_0 20 | - cuda-cudart=12.1.105=0 21 | - cuda-cupti=12.1.105=0 22 | - cuda-libraries=12.1.0=0 23 | - cuda-nvrtc=12.1.105=0 24 | - 
cuda-nvtx=12.1.105=0 25 | - cuda-opencl=12.5.39=0 26 | - cuda-runtime=12.1.0=0 27 | - cuda-version=12.5=3 28 | - dav1d=1.2.1=hd590300_0 29 | - debugpy=1.8.2=py310h76e45a6_0 30 | - decorator=5.1.1=pyhd8ed1ab_0 31 | - executing=2.0.1=pyhd8ed1ab_0 32 | - expat=2.6.2=h59595ed_0 33 | - ffmpeg=7.0.1=gpl_hb399a10_100 34 | - filelock=3.15.1=pyhd8ed1ab_0 35 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 36 | - font-ttf-inconsolata=3.000=h77eed37_0 37 | - font-ttf-source-code-pro=2.038=h77eed37_0 38 | - font-ttf-ubuntu=0.83=h77eed37_2 39 | - fontconfig=2.14.2=h14ed4e7_0 40 | - fonts-conda-ecosystem=1=0 41 | - fonts-conda-forge=1=0 42 | - freetype=2.12.1=h267a509_2 43 | - fribidi=1.0.10=h36c2ea0_0 44 | - gettext=0.22.5=h59595ed_2 45 | - gettext-tools=0.22.5=h59595ed_2 46 | - gmp=6.3.0=h59595ed_1 47 | - gmpy2=2.1.5=py310hc7909c9_1 48 | - gnutls=3.7.9=hb077bed_0 49 | - graphite2=1.3.13=h59595ed_1003 50 | - harfbuzz=8.5.0=hfac3d4d_0 51 | - icu=73.2=h59595ed_0 52 | - idna=3.7=pyhd8ed1ab_0 53 | - importlib-metadata=8.0.0=pyha770c72_0 54 | - importlib_metadata=8.0.0=hd8ed1ab_0 55 | - intel-openmp=2022.0.1=h06a4308_3633 56 | - ipykernel=6.29.5=pyh3099207_0 57 | - ipython=8.26.0=pyh707e725_0 58 | - ipywidgets=8.1.3=pyhd8ed1ab_0 59 | - jedi=0.19.1=pyhd8ed1ab_0 60 | - jinja2=3.1.4=pyhd8ed1ab_0 61 | - jupyter_client=8.6.2=pyhd8ed1ab_0 62 | - jupyter_core=5.7.2=py310hff52083_0 63 | - jupyterlab_widgets=3.0.11=pyhd8ed1ab_0 64 | - keyutils=1.6.1=h166bdaf_0 65 | - krb5=1.21.3=h659f571_0 66 | - lame=3.100=h166bdaf_1003 67 | - lcms2=2.16=hb7c19ff_0 68 | - ld_impl_linux-64=2.40=hf3520f5_7 69 | - lerc=4.0.0=h27087fc_0 70 | - libabseil=20240116.2=cxx17_h59595ed_0 71 | - libasprintf=0.22.5=h661eb56_2 72 | - libasprintf-devel=0.22.5=h661eb56_2 73 | - libass=0.17.1=h8fe9dca_1 74 | - libblas=3.9.0=16_linux64_mkl 75 | - libcblas=3.9.0=16_linux64_mkl 76 | - libcublas=12.1.0.26=0 77 | - libcufft=11.0.2.4=0 78 | - libcufile=1.10.0.4=0 79 | - libcurand=10.3.6.39=0 80 | - libcusolver=11.4.4.55=0 81 | - libcusparse=12.0.2.55=0 82 | - libdeflate=1.20=hd590300_0 83 | - libdrm=2.4.120=hd590300_0 84 | - libedit=3.1.20191231=he28a2e2_2 85 | - libexpat=2.6.2=h59595ed_0 86 | - libffi=3.4.2=h7f98852_5 87 | - libgcc-ng=13.2.0=h77fa898_10 88 | - libgettextpo=0.22.5=h59595ed_2 89 | - libgettextpo-devel=0.22.5=h59595ed_2 90 | - libglib=2.80.2=h8a4344b_1 91 | - libgomp=13.2.0=h77fa898_10 92 | - libhwloc=2.10.0=default_h5622ce7_1001 93 | - libiconv=1.17=hd590300_2 94 | - libidn2=2.3.7=hd590300_0 95 | - libjpeg-turbo=3.0.0=hd590300_1 96 | - liblapack=3.9.0=16_linux64_mkl 97 | - libnpp=12.0.2.50=0 98 | - libnsl=2.0.1=hd590300_0 99 | - libnvjitlink=12.1.105=0 100 | - libnvjpeg=12.1.1.14=0 101 | - libopenvino=2024.1.0=h2da1b83_7 102 | - libopenvino-auto-batch-plugin=2024.1.0=hb045406_7 103 | - libopenvino-auto-plugin=2024.1.0=hb045406_7 104 | - libopenvino-hetero-plugin=2024.1.0=h5c03a75_7 105 | - libopenvino-intel-cpu-plugin=2024.1.0=h2da1b83_7 106 | - libopenvino-intel-gpu-plugin=2024.1.0=h2da1b83_7 107 | - libopenvino-intel-npu-plugin=2024.1.0=he02047a_7 108 | - libopenvino-ir-frontend=2024.1.0=h5c03a75_7 109 | - libopenvino-onnx-frontend=2024.1.0=h07e8aee_7 110 | - libopenvino-paddle-frontend=2024.1.0=h07e8aee_7 111 | - libopenvino-pytorch-frontend=2024.1.0=he02047a_7 112 | - libopenvino-tensorflow-frontend=2024.1.0=h39126c6_7 113 | - libopenvino-tensorflow-lite-frontend=2024.1.0=he02047a_7 114 | - libopus=1.3.1=h7f98852_1 115 | - libpciaccess=0.18=hd590300_0 116 | - libpng=1.6.43=h2797004_0 117 | - libprotobuf=4.25.3=h08a7969_0 118 | - 
libsodium=1.0.18=h36c2ea0_1 119 | - libsqlite=3.46.0=hde9e2c9_0 120 | - libstdcxx-ng=13.2.0=hc0a3c3a_10 121 | - libtasn1=4.19.0=h166bdaf_0 122 | - libtiff=4.6.0=h1dd3fc0_3 123 | - libunistring=0.9.10=h7f98852_0 124 | - libuuid=2.38.1=h0b41bf4_0 125 | - libva=2.21.0=h4ab18f5_2 126 | - libvpx=1.14.1=hac33072_0 127 | - libwebp-base=1.4.0=hd590300_0 128 | - libxcb=1.15=h0b41bf4_0 129 | - libxcrypt=4.4.36=hd590300_1 130 | - libxml2=2.12.7=hc051c1a_1 131 | - libzlib=1.3.1=h4ab18f5_1 132 | - llvm-openmp=15.0.7=h0cdce71_0 133 | - markupsafe=2.1.5=py310h2372a71_0 134 | - matplotlib-inline=0.1.7=pyhd8ed1ab_0 135 | - mkl=2022.1.0=hc2b9512_224 136 | - mpc=1.3.1=hfe3b2da_0 137 | - mpfr=4.2.1=h9458935_1 138 | - mpmath=1.3.0=pyhd8ed1ab_0 139 | - ncurses=6.5=h59595ed_0 140 | - nest-asyncio=1.6.0=pyhd8ed1ab_0 141 | - nettle=3.9.1=h7ab15ed_0 142 | - networkx=3.3=pyhd8ed1ab_1 143 | - ocl-icd=2.3.2=hd590300_1 144 | - openh264=2.4.1=h59595ed_0 145 | - openjpeg=2.5.2=h488ebb8_0 146 | - openssl=3.3.1=h4ab18f5_0 147 | - p11-kit=0.24.1=hc5aa10d_0 148 | - packaging=24.1=pyhd8ed1ab_0 149 | - parso=0.8.4=pyhd8ed1ab_0 150 | - pcre2=10.44=h0f59acf_0 151 | - pexpect=4.9.0=pyhd8ed1ab_0 152 | - pickleshare=0.7.5=py_1003 153 | - pillow=10.3.0=py310hf73ecf8_0 154 | - pip=24.0=pyhd8ed1ab_0 155 | - pixman=0.43.2=h59595ed_0 156 | - platformdirs=4.2.2=pyhd8ed1ab_0 157 | - prompt-toolkit=3.0.47=pyha770c72_0 158 | - pthread-stubs=0.4=h36c2ea0_1001 159 | - ptyprocess=0.7.0=pyhd3deb0d_0 160 | - pugixml=1.14=h59595ed_0 161 | - pure_eval=0.2.2=pyhd8ed1ab_0 162 | - pygments=2.18.0=pyhd8ed1ab_0 163 | - pysocks=1.7.1=pyha2e5f31_6 164 | - python=3.10.14=hd12c33a_0_cpython 165 | - python_abi=3.10=4_cp310 166 | - pytorch-cuda=12.1=ha16c6d3_5 167 | - pytorch-mutex=1.0=cuda 168 | - pyyaml=6.0.1=py310h2372a71_1 169 | - pyzmq=26.0.3=py310h6883aea_0 170 | - readline=8.2=h8228510_1 171 | - requests=2.32.3=pyhd8ed1ab_0 172 | - setuptools=70.0.0=pyhd8ed1ab_0 173 | - six=1.16.0=pyh6c4a22f_0 174 | - snappy=1.2.0=hdb0a2a9_1 175 | - stack_data=0.6.2=pyhd8ed1ab_0 176 | - svt-av1=2.1.0=hac33072_0 177 | - sympy=1.12.1=pypyh2585a3b_103 178 | - tbb=2021.12.0=h297d8ca_1 179 | - tk=8.6.13=noxft_h4845f30_101 180 | - torchaudio=2.3.1=py310_cu121 181 | - torchvision=0.18.1=py310_cu121 182 | - tornado=6.4.1=py310hc51659f_0 183 | - traitlets=5.14.3=pyhd8ed1ab_0 184 | - typing_extensions=4.12.2=pyha770c72_0 185 | - urllib3=2.2.2=pyhd8ed1ab_0 186 | - wcwidth=0.2.13=pyhd8ed1ab_0 187 | - wheel=0.43.0=pyhd8ed1ab_1 188 | - widgetsnbextension=4.0.11=pyhd8ed1ab_0 189 | - x264=1!164.3095=h166bdaf_2 190 | - x265=3.5=h924138e_3 191 | - xorg-fixesproto=5.0=h7f98852_1002 192 | - xorg-kbproto=1.0.7=h7f98852_1002 193 | - xorg-libice=1.1.1=hd590300_0 194 | - xorg-libsm=1.2.4=h7391055_0 195 | - xorg-libx11=1.8.9=h8ee46fc_0 196 | - xorg-libxau=1.0.11=hd590300_0 197 | - xorg-libxdmcp=1.1.3=h7f98852_0 198 | - xorg-libxext=1.3.4=h0b41bf4_2 199 | - xorg-libxfixes=5.0.3=h7f98852_1004 200 | - xorg-libxrender=0.9.11=hd590300_0 201 | - xorg-renderproto=0.11.1=h7f98852_1002 202 | - xorg-xextproto=7.3.0=h0b41bf4_1003 203 | - xorg-xproto=7.0.31=h7f98852_1007 204 | - xz=5.2.6=h166bdaf_0 205 | - yaml=0.2.5=h7f98852_2 206 | - zeromq=4.3.5=h75354e8_4 207 | - zipp=3.19.2=pyhd8ed1ab_0 208 | - zlib=1.3.1=h4ab18f5_1 209 | - zstd=1.5.6=ha6fb4c9_0 210 | - pip: 211 | - absl-py==2.1.0 212 | - accelerate==0.31.0 213 | - aiohttp==3.9.5 214 | - aiosignal==1.3.1 215 | - annotated-types==0.7.0 216 | - antlr4-python3-runtime==4.9.3 217 | - anyio==4.4.0 218 | - async-timeout==4.0.3 219 | - attrs==23.2.0 220 | 
- bitsandbytes==0.43.1 221 | - blis==0.7.11 222 | - bottle==0.12.25 223 | - cachetools==5.3.3 224 | - catalogue==2.0.10 225 | - cattrs==22.2.0 226 | - chardet==5.2.0 227 | - click==8.1.7 228 | - cloudpathlib==0.18.1 229 | - cloudpickle==3.0.0 230 | - cmake==3.29.5.1 231 | - colorama==0.4.6 232 | - confection==0.1.5 233 | - contourpy==1.2.1 234 | - crfm-helm==0.5.2 235 | - cycler==0.12.1 236 | - cymem==2.0.8 237 | - dacite==1.8.1 238 | - dataproperty==1.0.1 239 | - datasets==2.20.0 240 | - dill==0.3.8 241 | - diskcache==5.6.3 242 | - distro==1.9.0 243 | - dnspython==2.6.1 244 | - docker-pycreds==0.4.0 245 | - einops==0.8.0 246 | - email-validator==2.1.2 247 | - evaluate==0.4.2 248 | - exceptiongroup==1.2.1 249 | - fastapi==0.111.0 250 | - fastapi-cli==0.0.4 251 | - fonttools==4.53.1 252 | - frozenlist==1.4.1 253 | - fsspec==2024.5.0 254 | - ftfy==6.2.0 255 | - fugashi==1.3.2 256 | - fuzzywuzzy==0.18.0 257 | - gitdb==4.0.11 258 | - gitpython==3.1.43 259 | - google-api-core==2.19.0 260 | - google-api-python-client==2.133.0 261 | - google-auth==2.30.0 262 | - google-auth-httplib2==0.2.0 263 | - googleapis-common-protos==1.63.1 264 | - h11==0.14.0 265 | - httpcore==1.0.5 266 | - httplib2==0.22.0 267 | - httptools==0.6.1 268 | - httpx==0.27.0 269 | - huggingface-hub==0.23.4 270 | - hydra-core==1.3.2 271 | - importlib-resources==5.13.0 272 | - interegular==0.3.3 273 | - jieba==0.42.1 274 | - joblib==1.4.2 275 | - jsonlines==4.0.0 276 | - jsonschema==4.22.0 277 | - jsonschema-specifications==2023.12.1 278 | - kiwisolver==1.4.5 279 | - langcodes==3.4.0 280 | - language-data==1.2.0 281 | - lark==1.1.9 282 | - llvmlite==0.43.0 283 | - lm-eval==0.4.2 284 | - lm-format-enforcer==0.10.1 285 | - lxml==5.2.2 286 | - mako==1.3.5 287 | - marisa-trie==1.2.0 288 | - markdown-it-py==3.0.0 289 | - matplotlib==3.9.1 290 | - mbstrdecoder==1.1.3 291 | - mdurl==0.1.2 292 | - more-itertools==10.3.0 293 | - msgpack==1.0.8 294 | - multidict==6.0.5 295 | - multiprocess==0.70.16 296 | - murmurhash==1.0.10 297 | - ninja==1.11.1.1 298 | - nltk==3.8.1 299 | - numba==0.60.0 300 | - numexpr==2.10.0 301 | - numpy==1.26.4 302 | - nvidia-cublas-cu12==12.1.3.1 303 | - nvidia-cuda-cupti-cu12==12.1.105 304 | - nvidia-cuda-nvrtc-cu12==12.1.105 305 | - nvidia-cuda-runtime-cu12==12.1.105 306 | - nvidia-cudnn-cu12==8.9.2.26 307 | - nvidia-cufft-cu12==11.0.2.54 308 | - nvidia-curand-cu12==10.3.2.106 309 | - nvidia-cusolver-cu12==11.4.5.107 310 | - nvidia-cusparse-cu12==12.1.0.106 311 | - nvidia-ml-py==12.555.43 312 | - nvidia-nccl-cu12==2.20.5 313 | - nvidia-nvjitlink-cu12==12.5.40 314 | - nvidia-nvtx-cu12==12.1.105 315 | - omegaconf==2.3.0 316 | - openai==1.34.0 317 | - orjson==3.10.5 318 | - outlines==0.0.45 319 | - pandas==2.2.2 320 | - parameterized==0.9.0 321 | - pathvalidate==3.2.0 322 | - peft==0.11.1 323 | - portalocker==2.8.2 324 | - preshed==3.0.9 325 | - prometheus-client==0.20.0 326 | - prometheus-fastapi-instrumentator==7.0.0 327 | - proto-plus==1.23.0 328 | - protobuf==4.25.3 329 | - psutil==5.9.8 330 | - py-cpuinfo==9.0.0 331 | - pyairports==2.1.1 332 | - pyarrow==16.1.0 333 | - pyarrow-hotfix==0.6 334 | - pyasn1==0.6.0 335 | - pyasn1-modules==0.4.0 336 | - pybind11==2.12.0 337 | - pycountry==24.6.1 338 | - pydantic==2.7.4 339 | - pydantic-core==2.18.4 340 | - pyext==0.7 341 | - pyhocon==0.3.61 342 | - pyparsing==3.1.2 343 | - pytablewriter==1.2.0 344 | - python-dateutil==2.9.0.post0 345 | - python-dotenv==1.0.1 346 | - python-multipart==0.0.9 347 | - pytz==2024.1 348 | - ray==2.24.0 349 | - referencing==0.35.1 350 | - 
regex==2024.5.15 351 | - retrying==1.3.4 352 | - rich==13.7.1 353 | - rouge==1.0.1 354 | - rouge-score==0.1.2 355 | - rpds-py==0.18.1 356 | - rsa==4.9 357 | - sacrebleu==2.4.2 358 | - safetensors==0.4.3 359 | - scikit-learn==1.5.0 360 | - scipy==1.13.1 361 | - seaborn==0.13.2 362 | - sentencepiece==0.2.0 363 | - sentry-sdk==2.5.1 364 | - setproctitle==1.3.3 365 | - shellingham==1.5.4 366 | - smart-open==7.0.4 367 | - smmap==5.0.1 368 | - sniffio==1.3.1 369 | - spacy==3.7.5 370 | - spacy-legacy==3.0.12 371 | - spacy-loggers==1.0.5 372 | - sqlitedict==1.7.0 373 | - srsly==2.4.8 374 | - starlette==0.37.2 375 | - tabledata==1.3.3 376 | - tabulate==0.9.0 377 | - tcolorpy==0.1.6 378 | - thinc==8.2.4 379 | - threadpoolctl==3.5.0 380 | - tiktoken==0.7.0 381 | - tokenizers==0.19.1 382 | - torch==2.3.0 383 | - tqdm==4.66.4 384 | - tqdm-multiprocess==0.0.11 385 | - transformers==4.41.2 386 | - triton==2.3.0 387 | - typepy==1.3.2 388 | - typer==0.12.3 389 | - tzdata==2024.1 390 | - ujson==5.10.0 391 | - uncertainty-calibration==0.1.4 392 | - uritemplate==4.1.1 393 | - uvicorn==0.30.1 394 | - uvloop==0.19.0 395 | - vllm==0.5.0.post1 396 | - vllm-flash-attn==2.5.9 397 | - wandb==0.17.2 398 | - wasabi==1.1.3 399 | - watchfiles==0.22.0 400 | - weasel==0.4.1 401 | - websockets==12.0 402 | - word2number==1.1 403 | - wrapt==1.16.0 404 | - xformers==0.0.26.post1 405 | - xxhash==3.4.1 406 | - yarl==1.9.4 407 | - zstandard==0.18.0 408 | - highlight-text 409 | -------------------------------------------------------------------------------- /stateless_parallel_modules/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from dataclasses import dataclass 3 | from typing import Optional, Tuple, Union, Dict, Callable 4 | 5 | 6 | import torch 7 | from torch import nn 8 | import torch.utils.checkpoint 9 | import torch.nn.functional as F 10 | 11 | from transformers.models.llama.modeling_llama import ( 12 | apply_rotary_pos_emb) 13 | 14 | from .base import ( 15 | StatelessGeneralizedOperation, GeneralizedLinear, 16 | StatelessGeneralizedModule) 17 | from utils import get_nonlinearity 18 | 19 | 20 | class RotaryEmbedding(nn.Module): 21 | '''Based on transformers.models.llama.modeling_llama.LlamaRotaryEmbedding''' 22 | 23 | def __init__( 24 | self, 25 | dim, 26 | max_position_embeddings=2048, 27 | base=10000, 28 | device=None, 29 | scaling_factor=1.0, 30 | ): 31 | super().__init__() 32 | self.scaling_factor = scaling_factor 33 | self.dim = dim 34 | self.max_position_embeddings = max_position_embeddings 35 | self.base = base 36 | inv_freq = 1.0 / (self.base ** ( 37 | torch.arange(0, self.dim, 2, 38 | dtype=torch.int64).float().to(device) / self.dim)) 39 | self.register_buffer("inv_freq", inv_freq, persistent=False) 40 | 41 | self.max_seq_len_cached = max_position_embeddings 42 | t = torch.arange( 43 | self.max_seq_len_cached, device=device, dtype=torch.int64).type_as( 44 | self.inv_freq) 45 | t = t / self.scaling_factor 46 | freqs = torch.outer(t, self.inv_freq) 47 | 48 | emb = torch.cat((freqs, freqs), dim=-1) 49 | self.register_buffer( 50 | "_cos_cached", emb.cos().to(torch.get_default_dtype()), 51 | persistent=False) 52 | self.register_buffer( 53 | "_sin_cached", emb.sin().to(torch.get_default_dtype()), 54 | persistent=False) 55 | 56 | @torch.no_grad() 57 | def forward(self, x, position_ids): 58 | inv_freq_expanded = self.inv_freq[None, :, None].float().expand( 59 | position_ids.shape[0], -1, 1) 60 | position_ids_expanded = position_ids[:, None, 
:].float() 61 | 62 | # Force float32 since bfloat16 loses precision on long contexts 63 | # See https://github.com/huggingface/transformers/pull/29285 64 | device_type = x.device.type 65 | device_type = device_type if isinstance( 66 | device_type, str) and device_type != "mps" else "cpu" 67 | with torch.autocast(device_type=device_type, enabled=False): 68 | freqs = ( 69 | inv_freq_expanded.float() @ 70 | position_ids_expanded.float()).transpose(1, 2) 71 | emb = torch.cat((freqs, freqs), dim=-1) 72 | cos = emb.cos() 73 | sin = emb.sin() 74 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) 75 | 76 | 77 | @dataclass 78 | class StatelessAttentionParams: 79 | input_dim: Optional[int] 80 | # if None, defaults to the output dim at setup 81 | hidden_dim: Optional[int] 82 | output_dim: Optional[int] 83 | num_heads: int 84 | bias: bool 85 | max_position_id: int 86 | # rope params 87 | use_rope: bool 88 | rope_theta: float 89 | # None, forward, or backward 90 | masking_strategy: Optional[str] 91 | 92 | 93 | class StatelessAttention(StatelessGeneralizedModule): 94 | def __init__(self, 95 | attention_params: StatelessAttentionParams,): 96 | self.save_configs(attention_params=attention_params) 97 | StatelessGeneralizedModule.__init__( 98 | self=self, 99 | input_features=self.input_dim, 100 | output_features=self.output_dim, 101 | init_module=True,) 102 | if (self.input_dim is not None and 103 | self.output_dim is not None and 104 | self.hidden_dim is not None): 105 | 106 | self.instantiate_and_setup_ops( 107 | input_features=self.input_dim, 108 | hidden_features=self.hidden_dim, 109 | output_features=self.output_dim, 110 | preceding_module=None, 111 | default_output_features_mult=1, 112 | ) 113 | 114 | def instantiate_and_setup_ops( 115 | self, 116 | input_features: Optional[int] = None, 117 | hidden_features: Optional[int] = None, 118 | output_features: Optional[int] = None, 119 | preceding_module=None, 120 | default_output_features_mult: int = 1, 121 | **kwargs, 122 | ): 123 | 124 | if (self.input_dim is None or 125 | self.output_dim is None or 126 | self.hidden_dim is None): 127 | 128 | self.instantiate_model( 129 | input_features=input_features, 130 | output_features=output_features, 131 | preceding_module=preceding_module, 132 | default_output_features_mult=default_output_features_mult, 133 | ) 134 | self.input_dim = self.input_features 135 | self.output_dim = self.output_features 136 | if self.hidden_dim is None and hidden_features is not None: 137 | self.hidden_dim = hidden_features 138 | elif self.hidden_dim is None: 139 | print('Warning: hidden features not specified, setting to ' + 140 | f'{self.output_features} (output features)') 141 | self.hidden_dim = self.output_features 142 | 143 | self.head_dim = self.hidden_dim // self.num_heads 144 | if (self.head_dim * self.num_heads) != self.hidden_dim: 145 | raise ValueError("hidden_dim must be divisible by num_heads") 146 | 147 | self.linear_op = GeneralizedLinear() 148 | 149 | operation_kwargs = dict( 150 | bias=self.bias, 151 | ) 152 | qkv_kwargs = dict( 153 | in_features=self.input_dim, 154 | out_features=self.hidden_dim*3, 155 | ) 156 | o_kwargs = dict( 157 | in_features=self.hidden_dim, 158 | out_features=self.output_dim, 159 | ) 160 | 161 | self._init_rope() 162 | 163 | self.setup_operations( 164 | operations=[self.linear_op, self.linear_op], 165 | operation_kwargs=operation_kwargs, 166 | operation_kwargs_overrides_list=[qkv_kwargs, o_kwargs], 167 | ) 168 | 169 | def save_configs(self, attention_params: StatelessAttentionParams,): 170 | self.config = 
attention_params 171 | self.input_dim = attention_params.input_dim 172 | self.hidden_dim = attention_params.hidden_dim 173 | self.output_dim = attention_params.output_dim 174 | self.num_heads = attention_params.num_heads 175 | 176 | self.multiple_heads = self.num_heads > 1 177 | 178 | self.bias = attention_params.bias 179 | self.max_position_id = attention_params.max_position_id 180 | self.use_rope = attention_params.use_rope 181 | self.rope_theta = attention_params.rope_theta 182 | self.masking_strategy = attention_params.masking_strategy 183 | self.apply_causal_mask = False 184 | self.backward_causal_mask = False 185 | if self.masking_strategy is not None: 186 | self.apply_causal_mask = True 187 | if self.masking_strategy == 'backward': 188 | self.backward_causal_mask = True 189 | else: 190 | assert self.masking_strategy == 'forward' 191 | 192 | def _init_rope(self): 193 | self.rotary_emb = RotaryEmbedding( 194 | dim=self.head_dim, 195 | max_position_embeddings=self.max_position_id, 196 | base=self.rope_theta, 197 | ) 198 | 199 | def forward( 200 | self, 201 | inputs: torch.Tensor, 202 | *args, 203 | n_parallel_dimensions: Optional[int] = None, 204 | attn_mask: Optional[torch.Tensor] = None, 205 | position_ids: Optional[torch.LongTensor] = None, 206 | **kwargs, 207 | ) -> torch.Tensor: 208 | *batch_dims, num_tokens, input_dim = inputs.size() 209 | 210 | weight, bias = self.parameters_per_layer[0] 211 | qkv_states = self.linear_op( 212 | input=inputs, 213 | weight=weight, 214 | bias=bias, 215 | parallel_operations=self.parallel_operations, 216 | n_parallel_dimensions=n_parallel_dimensions, 217 | ) 218 | 219 | # 3 dimensional - flatten all batch dims 220 | qkv_states = qkv_states.flatten(start_dim=0, end_dim=-3) 221 | # 2 dimensional - flatten all batch dims 222 | position_ids = position_ids.flatten(start_dim=0, end_dim=-2) 223 | 224 | if not self.multiple_heads: 225 | # add a singleton head dimension so the faster attention kernel applies 226 | qkv_states = qkv_states.unsqueeze(1) 227 | 228 | query_states, key_states, value_states = torch.chunk( 229 | qkv_states, chunks=3, dim=-1) 230 | 231 | if self.multiple_heads: 232 | total_batch_dim = qkv_states.shape[0] 233 | query_states = query_states.view( 234 | total_batch_dim, num_tokens, 235 | self.num_heads, self.head_dim).transpose(-2, -3) 236 | key_states = key_states.view( 237 | total_batch_dim, num_tokens, 238 | self.num_heads, self.head_dim).transpose(-2, -3) 239 | value_states = value_states.view( 240 | total_batch_dim, num_tokens, 241 | self.num_heads, self.head_dim).transpose(-2, -3) 242 | 243 | if self.use_rope: 244 | cos, sin = self.rotary_emb(value_states, position_ids) 245 | query_states, key_states = apply_rotary_pos_emb( 246 | query_states, key_states, cos, sin) 247 | 248 | if self.apply_causal_mask or attn_mask is not None: 249 | min_dtype = torch.finfo(inputs.dtype).min 250 | if self.apply_causal_mask: 251 | causal_mask = torch.ones( 252 | (num_tokens, num_tokens), 253 | device=key_states.device, 254 | dtype=torch.bool, 255 | ) 256 | if self.backward_causal_mask: 257 | # one for all lower diagonal tokens to be masked 258 | causal_mask = torch.tril(causal_mask, diagonal=-1) 259 | else: 260 | # one for all upper diagonal tokens to be masked 261 | causal_mask = torch.triu(causal_mask, diagonal=1) 262 | if attn_mask is not None: 263 | # In case num tokens < size of the attention mask 264 | inv_attn_mask = torch.logical_not(attn_mask[..., -num_tokens:]) 265 | 266 | if self.apply_causal_mask: 267 | causal_mask = 
torch.logical_or(inv_attn_mask, causal_mask) 268 | else: 269 | causal_mask = inv_attn_mask 270 | causal_mask = causal_mask*min_dtype 271 | elif self.apply_causal_mask: 272 | causal_mask = causal_mask*min_dtype 273 | else: 274 | causal_mask = None 275 | 276 | attn_output = F.scaled_dot_product_attention( 277 | query=query_states, 278 | key=key_states, 279 | value=value_states, 280 | attn_mask=causal_mask, 281 | is_causal=False, 282 | scale=None, 283 | ) 284 | if self.multiple_heads: 285 | attn_output = attn_output.transpose(-2, -3).contiguous() 286 | 287 | attn_output = attn_output.reshape( 288 | *batch_dims, num_tokens, self.hidden_dim) 289 | else: 290 | attn_output = attn_output.view( 291 | *batch_dims, num_tokens, self.hidden_dim) 292 | 293 | weight, bias = self.parameters_per_layer[1] 294 | attn_output = self.linear_op( 295 | input=attn_output, 296 | weight=weight, 297 | bias=bias, 298 | parallel_operations=self.parallel_operations, 299 | n_parallel_dimensions=n_parallel_dimensions, 300 | ) 301 | return attn_output 302 | 303 | 304 | class MonoHeadStatelessAttention(StatelessAttention): 305 | """Single-head specialization of StatelessAttention (num_heads == 1)""" 306 | # NOTE: hydra specification should have this as partial, since input_dim 307 | # might be unwieldy to always manually specify 308 | 309 | def __init__(self, 310 | attention_params: StatelessAttentionParams,): 311 | # TODO 312 | self.save_configs(attention_params=attention_params) 313 | StatelessAttention.__init__( 314 | self=self, attention_params=attention_params) 315 | assert self.num_heads == 1 316 | 317 | def forward( 318 | self, 319 | inputs: torch.Tensor, 320 | *args, 321 | n_parallel_dimensions: Optional[int] = None, 322 | attn_mask: Optional[torch.Tensor] = None, 323 | position_ids: Optional[torch.LongTensor] = None, 324 | **kwargs, 325 | ) -> torch.Tensor: 326 | *batch_dims, num_tokens, input_dim = inputs.size() 327 | 328 | # fused QKV projection 329 | weight, bias = self.parameters_per_layer[0] 330 | qkv_states = self.linear_op( 331 | input=inputs, 332 | weight=weight, 333 | bias=bias, 334 | parallel_operations=self.parallel_operations, 335 | # handle reshaping internally 336 | n_parallel_dimensions=n_parallel_dimensions, 337 | ) 338 | 339 | # 3 dimensional - flatten all batch dims 340 | qkv_states = qkv_states.flatten(start_dim=0, end_dim=-3) 341 | # 2 dimensional - flatten all batch dims 342 | position_ids = position_ids.flatten(start_dim=0, end_dim=-2) 343 | 344 | # total_batch_dim x num_tokens x head_dim (single head) 345 | query_states, key_states, value_states = torch.chunk( 346 | qkv_states, chunks=3, dim=-1) 347 | 348 | if self.use_rope: 349 | cos, sin = self.rotary_emb(value_states, position_ids) 350 | query_states, key_states = apply_rotary_pos_emb( 351 | query_states, key_states, cos, sin) 352 | 353 | if self.apply_causal_mask: 354 | causal_mask = torch.ones( 355 | (num_tokens, num_tokens), 356 | device=key_states.device, 357 | dtype=torch.bool, 358 | ) 359 | if self.backward_causal_mask: 360 | # one for all lower diagonal tokens to be masked 361 | causal_mask = torch.tril(causal_mask, diagonal=-1) 362 | else: 363 | # one for all upper diagonal tokens to be masked 364 | causal_mask = torch.triu(causal_mask, diagonal=1) 365 | if attn_mask is not None: 366 | # In case num tokens < size of the attention mask 367 | inv_attn_mask = torch.logical_not(attn_mask[..., -num_tokens:]) 368 | 369 | if self.apply_causal_mask: 370 | causal_mask = torch.logical_or(inv_attn_mask, causal_mask) 371 | else: 372 | causal_mask = inv_attn_mask 
373 | 374 | min_dtype = torch.finfo(inputs.dtype).min 375 | causal_mask = causal_mask*min_dtype 376 | elif self.apply_causal_mask: 377 | causal_mask = causal_mask*torch.finfo(inputs.dtype).min 378 | else: 379 | causal_mask = None 380 | 381 | attn_output = F.scaled_dot_product_attention( 382 | query=query_states, 383 | key=key_states, 384 | value=value_states, 385 | attn_mask=causal_mask, 386 | is_causal=False, 387 | scale=None, # defaults to 1/root(head_dim) 388 | ) 389 | 390 | attn_output = attn_output.view( 391 | *batch_dims, num_tokens, self.hidden_dim) 392 | 393 | weight, bias = self.parameters_per_layer[1] 394 | attn_output = self.linear_op( 395 | input=attn_output, 396 | weight=weight, 397 | bias=bias, 398 | parallel_operations=self.parallel_operations, 399 | n_parallel_dimensions=n_parallel_dimensions, 400 | ) 401 | 402 | return attn_output 403 | -------------------------------------------------------------------------------- /memory_policy/base_dynamic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import copy 4 | import math 5 | import numpy as np 6 | from dataclasses import dataclass 7 | from typing import Optional, Tuple, Union, Dict, List 8 | 9 | import abc 10 | import torch 11 | from torch import nn 12 | import torch.utils.checkpoint 13 | import torch.nn.functional as F 14 | from torch.cuda.amp import autocast 15 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 16 | from transformers import LlamaPreTrainedModel 17 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 18 | from .base import MemoryPolicy, ParamMemoryPolicy 19 | 20 | 21 | def compute_recency(position_ids, start_recency_from=1): 22 | most_recent_positions = position_ids[..., [-1]] 23 | 24 | cache_age = most_recent_positions - position_ids + start_recency_from 25 | return cache_age 26 | 27 | def compute_recency_scores(position_ids, recency_exp, recency_coeff,): 28 | 29 | cache_age = compute_recency(position_ids=position_ids, start_recency_from=1) 30 | recency_scores = torch.pow(1/cache_age, exponent=recency_exp) 31 | return recency_scores*recency_coeff, recency_scores 32 | 33 | def threshold_score_idxs( 34 | masked_full_scores, 35 | dynamic_thresh, 36 | 37 | 38 | preserve_order=True, 39 | cache_size=None, 40 | ): 41 | 42 | # returns indices of entries whose score clears dynamic_thresh, plus a 43 | # boolean mask marking which retained entries are valid per head 44 | num_samples = masked_full_scores.shape[-1] 45 | if cache_size is not None and num_samples > cache_size: 46 | sorted_scores, indices = torch.topk( 47 | masked_full_scores, k=cache_size, sorted=True, dim=-1) 48 | sorted_scores = sorted_scores.flip(dims=(-1,)) 49 | indices = indices.flip(dims=(-1,)) 50 | else: 51 | sorted_scores, indices = torch.sort( 52 | masked_full_scores, descending=False, dim=-1) 53 | 54 | 55 | 56 | thresholded_scores = sorted_scores >= dynamic_thresh 57 | 58 | # per-head count of entries strictly below the threshold 59 | first_above_thresh = torch.sum(~thresholded_scores, dim=-1, 60 | dtype=torch.long) 61 | # drop only the prefix that is below threshold in every head 62 | discard_idx = torch.min(first_above_thresh) 63 | 64 | 65 | 66 | retained_idxs = indices[..., discard_idx:] 67 | 68 | 69 | new_mask = thresholded_scores[..., discard_idx:] 70 | 71 | if preserve_order: 72 | # restore the original cache ordering of the retained entries 73 | retained_idxs, _ = retained_idxs.sort(descending=False, dim=-1,) 74 | return retained_idxs, new_mask 75 | 76 | @dataclass 77 | class RecencyParams: 78 | recency_coeff: float 79 | recency_exp: float 80 | 81 | @dataclass 82 | class AttentionParams: 83 | attn_coeff: float 84 | attn_ema_coeff: float 85 | 86 | back_attn_coeff: float = 0 87 | 88 | 89 | 90 | 91 | 92 | class DynamicMemoryPolicy(MemoryPolicy): 93 | '''Base class for policies whose retained cache size varies at runtime''' 94 | 95 | def 
__init__(self, 96 | cache_size: Optional[int] = None, 97 | init_module: bool = True, 98 | ): 99 | 100 | MemoryPolicy.__init__( 101 | self, 102 | cache_size=cache_size, 103 | init_module=init_module, 104 | ) 105 | 106 | self._record_mask_based_sparsity = False 107 | self._record_stats_per_head = False 108 | self._record_recency_stats = False 109 | 110 | @property 111 | def record_mask_based_sparsity(self,): 112 | return self._record_mask_based_sparsity 113 | 114 | @property 115 | def record_stats_per_head(self,): 116 | return self._record_stats_per_head 117 | 118 | @record_mask_based_sparsity.setter 119 | def record_mask_based_sparsity(self, value): 120 | self._record_mask_based_sparsity = value 121 | 122 | @record_stats_per_head.setter 123 | def record_stats_per_head(self, value): 124 | self._record_stats_per_head = value 125 | 126 | def is_dynamic(self,): 127 | return True 128 | 129 | def select_max_score_idxs( 130 | self, 131 | masked_full_scores, 132 | cache_size, 133 | 134 | 135 | preserve_order=True, 136 | ): 137 | 138 | 139 | 140 | 141 | # keep the cache_size highest-scoring cache entries 142 | sorted_top_scores, retained_idxs = torch.topk( 143 | input=masked_full_scores, k=cache_size, largest=True, 144 | sorted=False, 145 | ) 146 | 147 | if preserve_order: 148 | 149 | retained_idxs, _ = retained_idxs.sort(descending=False, dim=-1,) 150 | 151 | new_mask = torch.ones_like(retained_idxs) 152 | 153 | return retained_idxs, new_mask 154 | 155 | def threshold_score_idxs( 156 | self, 157 | masked_full_scores, 158 | dynamic_thresh, 159 | 160 | 161 | preserve_order=True, 162 | cache_size=None, 163 | ): 164 | 165 | retained_idxs, new_mask = threshold_score_idxs( 166 | masked_full_scores=masked_full_scores, 167 | dynamic_thresh=dynamic_thresh, 168 | preserve_order=preserve_order, 169 | cache_size=cache_size, 170 | ) 171 | return retained_idxs, new_mask 172 | 173 | def select_new_dynamic_idxs( 174 | self, 175 | masked_full_scores, 176 | dynamic_thresh, 177 | cache_size, 178 | preserve_order=True, 179 | ): 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | return self.threshold_score_idxs( 188 | masked_full_scores=masked_full_scores, 189 | dynamic_thresh=dynamic_thresh, 190 | preserve_order=preserve_order, 191 | cache_size=cache_size, 192 | ) 193 | 194 | def need_new_dynamic_idxs( 195 | self, 196 | masked_full_scores, 197 | cache_size, 198 | ): 199 | if cache_size is None: 200 | return True 201 | else: 202 | num_samples = masked_full_scores.shape[-1] 203 | if num_samples > cache_size: 204 | return True 205 | else: 206 | return False 207 | 208 | 209 | 210 | def save_recency_params(self, recency_params: RecencyParams): 211 | self.recency_coeff = recency_params.recency_coeff 212 | self.recency_exp = recency_params.recency_exp 213 | self.use_recency_scores = False 214 | if self.recency_coeff is not None: 215 | if self.recency_coeff > 0: 216 | self.use_recency_scores = True 217 | 218 | print(f'Using recency score: {self.use_recency_scores}') 219 | 220 | 221 | def save_attention_params(self, attention_params: AttentionParams): 222 | 223 | self.attn_coeff = attention_params.attn_coeff 224 | self.attn_ema_coeff = attention_params.attn_ema_coeff 225 | 226 | self.use_full_attn_scores = False 227 | if self.attn_coeff is not None: 228 | if self.attn_coeff > 0: 229 | self.use_full_attn_scores = True 230 | 231 | 232 | 233 | 234 | self.back_attn_coeff = attention_params.back_attn_coeff 235 | 236 | 237 | self.use_back_attn_scores = False 238 | 239 | if self.back_attn_coeff is not None: 240 | if self.back_attn_coeff > 0: 241 | self.use_back_attn_scores = True 
242 | 243 | 244 | 245 | 246 | 247 | self.use_attention_scores = (self.use_full_attn_scores or 248 | self.use_back_attn_scores) 249 | if self.use_full_attn_scores and self.use_back_attn_scores: 250 | self.use_both_attn_scores = True 251 | else: 252 | self.use_both_attn_scores = False 253 | 254 | 255 | print(f'Using attention score: {self.use_attention_scores}') 256 | 257 | 258 | 259 | def process_position_ids(self, position_ids, num_all_tokens, num_new_tokens, 260 | attention_mask): 261 | '''Replicates position ids for each head''' 262 | if position_ids is None: 263 | assert num_all_tokens == num_new_tokens 264 | assert attention_mask is not None 265 | 266 | position_ids = torch.cumsum(attention_mask, dim=-1) - 1 267 | position_ids = position_ids.unsqueeze(-2) 268 | return position_ids.expand(-1, self.num_heads, -1) 269 | 270 | def compute_recency_scores(self, position_ids, recency_exp, recency_coeff): 271 | return compute_recency_scores( 272 | position_ids=position_ids, 273 | recency_exp=recency_exp, 274 | recency_coeff=recency_coeff, 275 | ) 276 | 277 | def initialize_cache_masks(self,): 278 | self.cache_masks = [None for _ in range(self.num_memory_layers)] 279 | 280 | def finalize_registration(self,): 281 | MemoryPolicy.finalize_registration(self,) 282 | self.initialize_cache_masks() 283 | self.initialize_stat_objects() 284 | 285 | 286 | def initialize_stat_objects(self, initialize_mask_sparsity=True): 287 | self.dynamic_cache_sizes = [[] for _ in range( 288 | self.num_memory_layers)] 289 | self.final_dynamic_cache_sizes = [[] for _ in range( 290 | self.num_memory_layers)] 291 | 292 | if initialize_mask_sparsity: 293 | self.initialize_mask_based_sparsity() 294 | if self._record_recency_stats: 295 | self.initialize_recency_stats() 296 | 297 | def initialize_mask_based_sparsity(self, ): 298 | self.dynamic_mask_sample_sparsity = [[] for _ in range( 299 | self.num_memory_layers)] 300 | self.dynamic_mask_head_sparsity = [[] for _ in range( 301 | self.num_memory_layers)] 302 | if self.record_stats_per_head: 303 | raise NotImplementedError 304 | self.dynamic_mask_head_sparsity_dicts = [{}] 305 | 306 | def initialize_recency_stats(self,): 307 | self.recorded_final_recencies = [[] for _ in range( 308 | self.num_memory_layers)] 309 | self._record_recency_stats = True 310 | 311 | def record_recency_stats(self, layer_id, position_ids): 312 | if position_ids is not None: 313 | recency = compute_recency(position_ids=position_ids) 314 | self.recorded_final_recencies[layer_id].append( 315 | recency.float().mean().item()) 316 | 317 | def get_param_stats(self, reset=True) -> dict: 318 | stats = dict() 319 | if self.record_eval_stats: 320 | if len(self.dynamic_cache_sizes[0]): 321 | all_final_cache_sizes = [] 322 | for i in range(self.num_memory_layers): 323 | stats_key_prefix = f'mem_stats/layer_id_{i}/' 324 | stats[stats_key_prefix + 'dynamic_cache_sizes'] = np.mean( 325 | self.dynamic_cache_sizes[i]) 326 | final_cache_sizes = (self.final_dynamic_cache_sizes[i] + 327 | [self.dynamic_cache_sizes[i][-1]]) 328 | all_final_cache_sizes += final_cache_sizes 329 | stats[stats_key_prefix + 'final_dynamic_cache_sizes'] = ( 330 | np.mean(final_cache_sizes)) 331 | 332 | stats_key_prefix = 'mem_stats/overall/' 333 | stats[stats_key_prefix + 'dynamic_cache_sizes'] = np.mean( 334 | [v for vs in self.dynamic_cache_sizes for v in vs]) 335 | # overall final size: mean over every layer's final cache sizes 336 | 337 | stats[stats_key_prefix + 'final_dynamic_cache_sizes'] = ( 338
| np.mean(all_final_cache_sizes)) 339 | if self.record_eval_stats or self.record_mask_based_sparsity: 340 | if len(self.dynamic_mask_sample_sparsity[0]): 341 | for i in range(self.num_memory_layers): 342 | stats_key_prefix = f'mem_stats/layer_id_{i}/' 343 | stats[stats_key_prefix + 'unmasked_samples'] = np.mean( 344 | self.dynamic_mask_sample_sparsity[i]) 345 | stats[stats_key_prefix + 'unmasked_samples_per_head'] = ( 346 | np.mean(self.dynamic_mask_head_sparsity[i])) 347 | stats[stats_key_prefix + 'unmasked_sample_final'] = np.mean( 348 | self.dynamic_mask_sample_sparsity[i][-1]) 349 | stats[stats_key_prefix + 350 | 'unmasked_sample_per_head_final'] = np.mean( 351 | self.dynamic_mask_head_sparsity[i][-1]) 352 | 353 | 354 | 355 | stats_key_prefix = 'mem_stats/overall/' 356 | stats[stats_key_prefix + 'unmasked_samples'] = np.mean( 357 | [v for vs in self.dynamic_mask_sample_sparsity for v in vs]) 358 | stats[stats_key_prefix + 'unmasked_samples_per_head'] = np.mean( 359 | [v for vs in self.dynamic_mask_head_sparsity for v in vs]) 360 | stats[stats_key_prefix + 'unmasked_sample_final'] = np.mean( 361 | [cs[-1] for cs in self.dynamic_mask_sample_sparsity]) 362 | stats[stats_key_prefix + 'unmasked_sample_per_head_final'] = ( 363 | np.mean([cs[-1] for cs in self.dynamic_mask_head_sparsity])) 364 | 365 | 366 | 367 | if self._record_recency_stats: 368 | for i in range(self.num_memory_layers): 369 | stats_key_prefix = f'mem_stats/layer_id_{i}/' 370 | stats[stats_key_prefix + 'final_recencies'] = np.mean( 371 | self.recorded_final_recencies[i]) 372 | stats_key_prefix = 'mem_stats/overall/' 373 | stats[stats_key_prefix + 'final_recencies'] = np.mean( 374 | [v for vs in self.recorded_final_recencies for v in vs]) 375 | 376 | if reset: 377 | self.initialize_stat_objects() 378 | return stats 379 | 380 | def record_dynamic_stats(self, layer_id, cache_size, new_sequences=False): 381 | if new_sequences and len(self.dynamic_cache_sizes[layer_id]) > 0: 382 | self.final_dynamic_cache_sizes[layer_id].append( 383 | self.dynamic_cache_sizes[layer_id][-1]) 384 | self.dynamic_cache_sizes[layer_id].append(int(cache_size)) 385 | 386 | def record_mask_dynamic_stats( 387 | self, layer_id, 388 | cache_mask, 389 | ): 390 | 391 | unmasked_samples_per_head = cache_mask.to(torch.float32).sum(-1) 392 | 393 | self.dynamic_mask_sample_sparsity[layer_id].append( 394 | torch.max(unmasked_samples_per_head, dim=-1)[0].mean().item()) 395 | self.dynamic_mask_head_sparsity[layer_id].append( 396 | unmasked_samples_per_head.mean().item()) 397 | 398 | class DynamicParamMemoryPolicy(ParamMemoryPolicy, DynamicMemoryPolicy): 399 | 400 | def __init__( 401 | self, 402 | base_param_size, 403 | pop_size, 404 | per_head, 405 | per_layer, 406 | additional_shared_params=0, 407 | learnable_params: Optional[Dict[str, Union[str, tuple]]] = None, 408 | learned_params: Optional[Dict[str, bool]] = None, 409 | component_names: Optional[List[str]] = None, 410 | cache_size: Optional[int] = None, 411 | init_module: bool = True, 412 | lazy_param_num: bool = False, 413 | ): 414 | 415 | 416 | 417 | 418 | 419 | 420 | 421 | 422 | ParamMemoryPolicy.__init__( 423 | self, cache_size=cache_size, base_param_size=base_param_size, 424 | pop_size=pop_size, per_head=per_head, per_layer=per_layer, 425 | additional_shared_params=additional_shared_params, 426 | learnable_params=learnable_params, 427 | learned_params=learned_params, 428 | component_names=component_names, 429 | init_module=init_module, 430 | lazy_param_num=lazy_param_num, 431 | ) 432 | 433 | def 
is_dynamic(self,): 434 | return True 435 | 436 | def finalize_registration(self,): 437 | 438 | ParamMemoryPolicy.finalize_registration(self,) 439 | self.initialize_cache_masks() 440 | self.initialize_stat_objects() 441 | 442 | 443 | --------------------------------------------------------------------------------
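
Note on the masking_strategy option in stateless_parallel_modules/attention.py: 'forward' builds a standard causal mask (future tokens hidden), while 'backward' hides past tokens instead. The snippet below is a minimal illustrative sketch, not repository code; it only reproduces the triu/tril construction used in StatelessAttention.forward.

import torch

num_tokens = 4
ones = torch.ones((num_tokens, num_tokens), dtype=torch.bool)
# masking_strategy == 'forward': True above the diagonal (future masked)
forward_mask = torch.triu(ones, diagonal=1)
# masking_strategy == 'backward': True below the diagonal (past masked)
backward_mask = torch.tril(ones, diagonal=-1)
# True entries are scaled by torch.finfo(dtype).min before being passed as
# attn_mask to F.scaled_dot_product_attention, zeroing their attention weight
print(forward_mask.int())
print(backward_mask.int())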
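
Worked example of the scoring math in memory_policy/base_dynamic.py: compute_recency_scores assigns each cached token the score recency_coeff * (1/age)^recency_exp, and threshold_score_idxs keeps the suffix of ascending-sorted entries starting at the first index that clears the threshold in every head. The standalone sketch below uses toy values and a hypothetical helper name (toy_recency_scores); it mirrors that logic on a single-head cache rather than calling the repository classes.

import torch

def toy_recency_scores(position_ids, recency_exp, recency_coeff):
    # age is 1 for the newest cached token and grows for older ones
    cache_age = position_ids[..., [-1]] - position_ids + 1
    return recency_coeff * torch.pow(1.0 / cache_age, exponent=recency_exp)

# one head, eight cached tokens at positions 0..7
position_ids = torch.arange(8).unsqueeze(0)
scores = toy_recency_scores(position_ids, recency_exp=1.0, recency_coeff=1.0)
# scores: [1/8, 1/7, 1/6, 1/5, 1/4, 1/3, 1/2, 1]

dynamic_thresh = 0.2
sorted_scores, indices = torch.sort(scores, descending=False, dim=-1)
above = sorted_scores >= dynamic_thresh
# drop only the prefix that is below threshold in every head
discard_idx = int(torch.min(torch.sum(~above, dim=-1, dtype=torch.long)))
retained_idxs, _ = indices[..., discard_idx:].sort(dim=-1)
new_mask = above[..., discard_idx:]
print(retained_idxs)  # tensor([[3, 4, 5, 6, 7]]): only recent tokens kept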