├── train.py ├── benchmarks ├── benchmark_train.md ├── script_train.sh └── script_inference.sh ├── xmixers ├── modules │ ├── token_mixers │ │ ├── long_conv │ │ │ ├── dtu.py │ │ │ ├── __init__.py │ │ │ ├── rpe.py │ │ │ ├── gtu.py │ │ │ └── tno.py │ │ ├── chunk_linear_attention │ │ │ └── __init__.py │ │ ├── deep_memory │ │ │ ├── fast_weight │ │ │ │ ├── __init__.py │ │ │ │ └── fast_weight_glu.py │ │ │ ├── optimizer │ │ │ │ ├── __init__.py │ │ │ │ ├── utils.py │ │ │ │ ├── sgd.py │ │ │ │ └── base_optimizer.py │ │ │ ├── loss │ │ │ │ └── __init__.py │ │ │ └── utils.py │ │ ├── linear_attention │ │ │ └── __init__.py │ │ └── vanilla_attention │ │ │ ├── __init__.py │ │ │ └── utils.py │ ├── pes │ │ ├── md │ │ │ ├── __init__.py │ │ │ └── md_tpe.py │ │ ├── __init__.py │ │ ├── learnable_pe.py │ │ ├── utils.py │ │ ├── sin_cos_pe.py │ │ ├── mlp_pe.py │ │ ├── tpe.py │ │ └── lrpe.py │ ├── quantizer │ │ ├── __init__.py │ │ ├── utils.py │ │ └── finite_scalar_quantizer.py │ ├── channel_mixers │ │ ├── utils.py │ │ ├── ffn.py │ │ ├── __init__.py │ │ ├── glu.py │ │ ├── nglu.py │ │ └── alu.py │ ├── normalizations │ │ ├── l2_norm.py │ │ ├── srms_norm.py │ │ ├── offset_scale.py │ │ ├── scale_norm.py │ │ ├── group_srms_norm.py │ │ ├── __init__.py │ │ ├── group_norm.py │ │ ├── dynamic_tanh.py │ │ ├── utils.py │ │ ├── layer_norm.py │ │ ├── rms_norm.py │ │ └── group_rms_norm.py │ ├── __init__.py │ └── activations.py ├── version.py ├── models │ ├── long_conv │ │ ├── __init__.py │ │ └── tnn │ │ │ ├── __init__.py │ │ │ └── configuration_tnn.py │ ├── chunk_linear_transformer │ │ ├── __init__.py │ │ └── chunk_rnn │ │ │ ├── __init__.py │ │ │ └── configuration_chunk_rnn.py │ ├── transformer │ │ ├── __init__.py │ │ ├── gpt │ │ │ ├── __init__.py │ │ │ └── configuration_gpt.py │ │ ├── ngpt │ │ │ ├── __init__.py │ │ │ └── configuration_ngpt.py │ │ ├── llama │ │ │ ├── __init__.py │ │ │ └── configuration_llama.py │ │ └── flex_gpt │ │ │ ├── __init__.py │ │ │ └── configuration_flex_gpt.py │ ├── hybrid │ │ ├── __init__.py │ │ ├── naive_hybrid │ │ │ ├── __init__.py │ │ │ └── configuration_naive_hybrid.py │ │ └── lm_head_hybrid │ │ │ ├── __init__.py │ │ │ └── modeling_outputs.py │ ├── model.py │ ├── linear_transformer │ │ ├── gsa │ │ │ ├── __init__.py │ │ │ └── configuration_gsa.py │ │ ├── tnl │ │ │ ├── __init__.py │ │ │ └── configuration_tnl.py │ │ ├── ttt │ │ │ ├── __init__.py │ │ │ └── configuration_ttt.py │ │ ├── hgrn1 │ │ │ ├── __init__.py │ │ │ └── configuration_hgrn1.py │ │ ├── hgrn2 │ │ │ ├── __init__.py │ │ │ └── configuration_hgrn2.py │ │ ├── hgrn3 │ │ │ ├── __init__.py │ │ │ └── configuration_hgrn3.py │ │ ├── metala │ │ │ ├── __init__.py │ │ │ └── configuration_metala.py │ │ ├── mesa_net │ │ │ ├── __init__.py │ │ │ └── configuration_mesa_net.py │ │ ├── deltanet │ │ │ ├── __init__.py │ │ │ └── configuration_deltanet.py │ │ ├── lightnet │ │ │ ├── __init__.py │ │ │ └── configuration_lightnet.py │ │ ├── dense_rnn │ │ │ ├── __init__.py │ │ │ └── configuration_dense_rnn.py │ │ ├── polar_rnn │ │ │ ├── __init__.py │ │ │ └── configuration_polar_rnn.py │ │ ├── mamba2 │ │ │ ├── __init__.py │ │ │ └── configuration_mamba2.py │ │ ├── linear_transformer │ │ │ ├── __init__.py │ │ │ └── configuration_linear_transformer.py │ │ ├── decay_linear_transformer │ │ │ ├── __init__.py │ │ │ └── configuration_decay_linear_transformer.py │ │ ├── implicit_linear_transformer │ │ │ ├── __init__.py │ │ │ └── configuration_implicit_linear_transformer.py │ │ └── __init__.py │ └── __init__.py ├── ops │ ├── __init__.py │ └── long_conv_1d.py ├── monkey_patch │ ├── 
__init__.py │ └── mpa │ │ ├── __init__.py │ │ └── configuration_llama_mpa.py ├── utils │ ├── constants.py │ ├── __init__.py │ ├── mask_utils.py │ └── utils.py └── __init__.py ├── examples ├── script.sh ├── tnn.py ├── gtu.py └── llama.py ├── configs ├── transformer │ ├── mla │ │ ├── mla_86m.json │ │ └── mla_310m.json │ └── llama_half_rope │ │ ├── llama_half_rope_1_2b.json │ │ ├── llama_half_rope_2_5b.json │ │ ├── llama_half_rope_310m.json │ │ ├── llama_half_rope_7b.json │ │ ├── llama_half_rope_90m.json │ │ ├── llama_half_rope_xl.json │ │ └── llama_half_rope_xxl.json └── linear_transformer │ └── hgrn2 │ ├── hgrn2_7b.json │ ├── hgrn2_90m.json │ ├── hgrn2_xl.json │ ├── hgrn2_xxl.json │ ├── hgrn2_1_2b.json │ ├── hgrn2_2_5b.json │ ├── hgrn2_310m.json │ ├── hgrn2_h128_310m.json │ └── hgrn2_h128_90m.json ├── evals └── harness.py ├── tests ├── modules │ └── md_tpe │ │ └── test.py ├── script.sh └── train_test │ └── test.py ├── setup.py ├── scripts └── linear_decay_transformer │ ├── analysis_log_decay.py │ └── save_log_decay.py ├── README.md └── .gitignore /train.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmarks/benchmark_train.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/long_conv/dtu.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /xmixers/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.0" 2 | -------------------------------------------------------------------------------- /xmixers/models/long_conv/__init__.py: -------------------------------------------------------------------------------- 1 | from .tnn import * 2 | -------------------------------------------------------------------------------- /xmixers/modules/pes/md/__init__.py: -------------------------------------------------------------------------------- 1 | from .md_tpe import MdTpe 2 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/long_conv/__init__.py: -------------------------------------------------------------------------------- 1 | from .gtu import Gtu 2 | -------------------------------------------------------------------------------- /examples/script.sh: -------------------------------------------------------------------------------- 1 | export XMIXERS_DEBUG=True 2 | 3 | # python gtu.py 4 | python llama.py 5 | -------------------------------------------------------------------------------- /xmixers/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .long_conv_1d import long_conv_1d_op, long_conv_1d_op_naive 2 | -------------------------------------------------------------------------------- /xmixers/modules/quantizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .finite_scalar_quantizer import FiniteScalarQuantizer 2 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/chunk_linear_attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .chunk_rnn import ChunkRnn 2 | 
-------------------------------------------------------------------------------- /xmixers/monkey_patch/__init__.py: -------------------------------------------------------------------------------- 1 | from .mpa import LlamaMpaAttention, LlamaMpaConfig, LlamaMpaForCausalLM 2 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/deep_memory/fast_weight/__init__.py: -------------------------------------------------------------------------------- 1 | from .fast_weight_glu import FastWeightGLU, FastWeightHpGLU 2 | -------------------------------------------------------------------------------- /xmixers/utils/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | XMIXERS_DEBUG = eval(os.getenv("XMIXERS_DEBUG", "False")) 4 | 5 | EMBED_DIM_BASE = 256 6 | -------------------------------------------------------------------------------- /xmixers/models/chunk_linear_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .chunk_rnn import ChunkRnnConfig, ChunkRnnForCausalLM, ChunkRnnLayer, ChunkRnnModel 2 | -------------------------------------------------------------------------------- /xmixers/modules/quantizer/utils.py: -------------------------------------------------------------------------------- 1 | def round_ste(x): 2 | """Round with straight through gradients.""" 3 | xhat = x.round() 4 | return x + (xhat - x).detach() 5 | -------------------------------------------------------------------------------- /examples/tnn.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel 2 | 3 | from xmixers.models import TnnConfig 4 | 5 | config = TnnConfig() 6 | 7 | print(config) 8 | 9 | config.update({"num_layers": 12}) 10 | 11 | model = AutoModel.from_config(config) 12 | 13 | print(model) 14 | -------------------------------------------------------------------------------- /xmixers/modules/pes/__init__.py: -------------------------------------------------------------------------------- 1 | from .learnable_pe import LearnablePe 2 | from .lrpe import Lrpe 3 | from .md import MdTpe 4 | from .mlp_pe import MlpPe 5 | from .sin_cos_pe import SinCosPe 6 | from .tpe import Tpe 7 | from .utils import get_log_slopes, get_log_slopes_general 8 | -------------------------------------------------------------------------------- /xmixers/models/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .flex_gpt import FlexGPTConfig, FlexGPTForCausalLM, FlexGPTLayer, FlexGPTModel 2 | from .gpt import GPTConfig, GPTForCausalLM, GPTLayer, GPTModel 3 | from .llama import LLaMAConfig, LLaMAForCausalLM, LLaMALayer, LLaMAModel 4 | from .ngpt import nGPTConfig, nGPTForCausalLM, nGPTLayer, nGPTModel 5 | -------------------------------------------------------------------------------- /xmixers/models/hybrid/__init__.py: -------------------------------------------------------------------------------- 1 | from .lm_head_hybrid import ( 2 | LmHeadHybridConfig, 3 | LmHeadHybridForCausalLM, 4 | LmHeadHybridLayer, 5 | LmHeadHybridModel, 6 | ) 7 | from .naive_hybrid import ( 8 | NaiveHybridConfig, 9 | NaiveHybridForCausalLM, 10 | NaiveHybridLayer, 11 | NaiveHybridModel, 12 | ) 13 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/deep_memory/optimizer/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .base_optimizer import FastWeightOptimizer 2 | from .sgd import SGD 3 | 4 | 5 | def get_optimizer(name: str, **kwargs) -> FastWeightOptimizer: 6 | if name == "sgd": 7 | return SGD(**kwargs) 8 | else: 9 | raise ValueError(f"Optimizer {name} not found") 10 | -------------------------------------------------------------------------------- /xmixers/models/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class SequenceModel(nn.Module): 5 | def __init__(self, **kwargs): 6 | super().__init__() 7 | self.n_layer = kwargs.get("n_layer", 1) 8 | 9 | def get_block_config(self, **kwargs): 10 | pass 11 | 12 | def build_block(self, **kwargs): 13 | pass 14 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/deep_memory/optimizer/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_pooling_fn(pooling_method: str): 5 | if pooling_method == "mean": 6 | 7 | def f(x): 8 | return torch.mean(x, dim=1).unsqueeze(-1).unsqueeze(-1) 9 | 10 | else: 11 | raise ValueError(f"Pooling method {pooling_method} not found") 12 | 13 | return f 14 | -------------------------------------------------------------------------------- /xmixers/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import ( 2 | DenseRnnConfig, 3 | DenseRnnForCausalLM, 4 | DenseRnnModel, 5 | GPTConfig, 6 | GPTForCausalLM, 7 | GPTLayer, 8 | GPTModel, 9 | LLaMAConfig, 10 | LLaMAForCausalLM, 11 | LLaMALayer, 12 | LLaMAModel, 13 | TnlConfig, 14 | TnlForCausalLM, 15 | TnlModel, 16 | ) 17 | from .monkey_patch import LlamaMpaConfig, LlamaMpaForCausalLM 18 | -------------------------------------------------------------------------------- /xmixers/models/long_conv/tnn/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_tnn import TnnConfig 4 | from .modeling_tnn import TnnForCausalLM, TnnModel 5 | 6 | AutoConfig.register(TnnConfig.model_type, TnnConfig) 7 | AutoModel.register(TnnConfig, TnnModel) 8 | AutoModelForCausalLM.register(TnnConfig, TnnForCausalLM) 9 | 10 | __all__ = ["TnnConfig", "TnnModel", "TnnForCausalLM"] 11 | -------------------------------------------------------------------------------- /xmixers/monkey_patch/mpa/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_llama_mpa import LlamaMpaConfig 4 | from .modeling_llama_mpa import LlamaMpaAttention, LlamaMpaForCausalLM 5 | 6 | AutoConfig.register(LlamaMpaConfig.model_type, LlamaMpaConfig) 7 | AutoModelForCausalLM.register(LlamaMpaConfig, LlamaMpaForCausalLM) 8 | 9 | __all__ = ["LlamaMpaConfig", "LlamaMpaForCausalLM"] 10 | -------------------------------------------------------------------------------- /xmixers/models/transformer/gpt/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_gpt import GPTConfig 4 | from .modeling_gpt import GPTForCausalLM, GPTLayer, GPTModel 5 | 6 | AutoConfig.register(GPTConfig.model_type,
GPTConfig) 7 | AutoModel.register(GPTConfig, GPTModel) 8 | AutoModelForCausalLM.register(GPTConfig, GPTForCausalLM) 9 | 10 | __all__ = ["GPTConfig", "GPTModel", "GPTForCausalLM"] 11 | -------------------------------------------------------------------------------- /xmixers/models/transformer/ngpt/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_ngpt import nGPTConfig 4 | from .modeling_ngpt import nGPTForCausalLM, nGPTLayer, nGPTModel 5 | 6 | AutoConfig.register(nGPTConfig.model_type, nGPTConfig) 7 | AutoModel.register(nGPTConfig, nGPTModel) 8 | AutoModelForCausalLM.register(nGPTConfig, nGPTForCausalLM) 9 | 10 | __all__ = ["nGPTConfig", "nGPTModel", "nGPTForCausalLM"] 11 | -------------------------------------------------------------------------------- /xmixers/models/transformer/llama/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_llama import LLaMAConfig 4 | from .modeling_llama import LLaMAForCausalLM, LLaMALayer, LLaMAModel 5 | 6 | AutoConfig.register(LLaMAConfig.model_type, LLaMAConfig) 7 | AutoModel.register(LLaMAConfig, LLaMAModel) 8 | AutoModelForCausalLM.register(LLaMAConfig, LLaMAForCausalLM) 9 | 10 | __all__ = ["LLaMAConfig", "LLaMAModel", "LLaMAForCausalLM"] 11 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/gsa/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_gsa import GsaConfig 4 | from .modeling_gsa import GsaForCausalLM, GsaLayer, GsaModel 5 | 6 | AutoConfig.register(GsaConfig.model_type, GsaConfig) 7 | AutoModel.register(GsaConfig, GsaModel) 8 | AutoModelForCausalLM.register(GsaConfig, GsaForCausalLM) 9 | 10 | __all__ = [ 11 | "GsaConfig", 12 | "GsaModel", 13 | "GsaForCausalLM", 14 | ] 15 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/tnl/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_tnl import TnlConfig 4 | from .modeling_tnl import TnlForCausalLM, TnlLayer, TnlModel 5 | 6 | AutoConfig.register(TnlConfig.model_type, TnlConfig) 7 | AutoModel.register(TnlConfig, TnlModel) 8 | AutoModelForCausalLM.register(TnlConfig, TnlForCausalLM) 9 | 10 | __all__ = [ 11 | "TnlConfig", 12 | "TnlModel", 13 | "TnlForCausalLM", 14 | ] 15 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/ttt/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_ttt import TTTConfig 4 | from .modeling_ttt import TTTForCausalLM, TTTLayer, TTTModel 5 | 6 | AutoConfig.register(TTTConfig.model_type, TTTConfig) 7 | AutoModel.register(TTTConfig, TTTModel) 8 | AutoModelForCausalLM.register(TTTConfig, TTTForCausalLM) 9 | 10 | __all__ = [ 11 | "TTTConfig", 12 | "TTTModel", 13 | "TTTForCausalLM", 14 | ] 15 | -------------------------------------------------------------------------------- 
/configs/transformer/mla/mla_86m.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "embed_dim": 768, 5 | "glu_activation": "silu", 6 | "init_type": 1, 7 | "kv_lora_rank": 128, 8 | "lrpe_type": 1, 9 | "mid_dim": 2048, 10 | "model_type": "llama_", 11 | "num_heads": 12, 12 | "num_layers": 14, 13 | "q_lora_rank": 128, 14 | "qk_rope_head_dim": 64, 15 | "tie_word_embeddings": true, 16 | "token_mixer_init_type": 2, 17 | "token_mixer_type": "mla", 18 | "use_lrpe": true 19 | } 20 | -------------------------------------------------------------------------------- /evals/harness.py: -------------------------------------------------------------------------------- 1 | from lm_eval.__main__ import cli_evaluate 2 | from lm_eval.api.registry import register_model 3 | from lm_eval.models.huggingface import HFLM 4 | 5 | import xmixers # noqa 6 | 7 | 8 | @register_model("xmixers") 9 | class XmixersLMWrapper(HFLM): 10 | def __init__(self, **kwargs): 11 | 12 | # TODO: provide options for doing inference with different kernels 13 | 14 | super().__init__(**kwargs) 15 | 16 | 17 | if __name__ == "__main__": 18 | cli_evaluate() 19 | -------------------------------------------------------------------------------- /configs/transformer/mla/mla_310m.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "embed_dim": 1024, 5 | "glu_activation": "silu", 6 | "init_type": 1, 7 | "kv_lora_rank": 512, 8 | "lrpe_type": 1, 9 | "mid_dim": 2816, 10 | "model_type": "llama_", 11 | "num_heads": 16, 12 | "num_layers": 24, 13 | "q_lora_rank": 512, 14 | "qk_rope_head_dim": 64, 15 | "tie_word_embeddings": true, 16 | "token_mixer_init_type": 2, 17 | "token_mixer_type": "mla", 18 | "use_lrpe": true 19 | } 20 | -------------------------------------------------------------------------------- /xmixers/models/transformer/flex_gpt/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_flex_gpt import FlexGPTConfig 4 | from .modeling_flex_gpt import FlexGPTForCausalLM, FlexGPTLayer, FlexGPTModel 5 | 6 | AutoConfig.register(FlexGPTConfig.model_type, FlexGPTConfig) 7 | AutoModel.register(FlexGPTConfig, FlexGPTModel) 8 | AutoModelForCausalLM.register(FlexGPTConfig, FlexGPTForCausalLM) 9 | 10 | __all__ = ["FlexGPTConfig", "FlexGPTModel", "FlexGPTForCausalLM"] 11 | -------------------------------------------------------------------------------- /xmixers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .cache import XmixersCache 2 | from .constants import EMBED_DIM_BASE, XMIXERS_DEBUG 3 | from .init_utils import _init_weights, _initialize_weights, _post_init_weights 4 | from .loss_utils import Loss 5 | from .mask_utils import _upad_input, attn_mask_to_cu_seqlens, pad_input, unpad_input 6 | from .utils import ( 7 | endswith, 8 | logger, 9 | logging_info, 10 | next_power_of_2, 11 | pad_embed_dim, 12 | print_config, 13 | print_module, 14 | print_params, 15 | ) 16 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/hgrn1/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_hgrn1 import 
Hgrn1Config 4 | from .modeling_hgrn1 import Hgrn1ForCausalLM, Hgrn1Layer, Hgrn1Model 5 | 6 | AutoConfig.register(Hgrn1Config.model_type, Hgrn1Config) 7 | AutoModel.register(Hgrn1Config, Hgrn1Model) 8 | AutoModelForCausalLM.register(Hgrn1Config, Hgrn1ForCausalLM) 9 | 10 | __all__ = [ 11 | "Hgrn1Config", 12 | "Hgrn1Model", 13 | "Hgrn1ForCausalLM", 14 | ] 15 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/hgrn2/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_hgrn2 import Hgrn2Config 4 | from .modeling_hgrn2 import Hgrn2ForCausalLM, Hgrn2Layer, Hgrn2Model 5 | 6 | AutoConfig.register(Hgrn2Config.model_type, Hgrn2Config) 7 | AutoModel.register(Hgrn2Config, Hgrn2Model) 8 | AutoModelForCausalLM.register(Hgrn2Config, Hgrn2ForCausalLM) 9 | 10 | __all__ = [ 11 | "Hgrn2Config", 12 | "Hgrn2Model", 13 | "Hgrn2ForCausalLM", 14 | ] 15 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/hgrn3/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_hgrn3 import Hgrn3Config 4 | from .modeling_hgrn3 import Hgrn3ForCausalLM, Hgrn3Layer, Hgrn3Model 5 | 6 | AutoConfig.register(Hgrn3Config.model_type, Hgrn3Config) 7 | AutoModel.register(Hgrn3Config, Hgrn3Model) 8 | AutoModelForCausalLM.register(Hgrn3Config, Hgrn3ForCausalLM) 9 | 10 | __all__ = [ 11 | "Hgrn3Config", 12 | "Hgrn3Model", 13 | "Hgrn3ForCausalLM", 14 | ] 15 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/metala/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_metala import MetaLaConfig 4 | from .modeling_metala import MetaLaForCausalLM, MetaLaLayer, MetaLaModel 5 | 6 | AutoConfig.register(MetaLaConfig.model_type, MetaLaConfig) 7 | AutoModel.register(MetaLaConfig, MetaLaModel) 8 | AutoModelForCausalLM.register(MetaLaConfig, MetaLaForCausalLM) 9 | 10 | __all__ = [ 11 | "MetaLaConfig", 12 | "MetaLaModel", 13 | "MetaLaForCausalLM", 14 | ] 15 | -------------------------------------------------------------------------------- /xmixers/modules/channel_mixers/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from xopes.ops import gate_linear_fn 3 | 4 | 5 | class GateLinearOp(nn.Module): 6 | def __init__( 7 | self, 8 | ): 9 | super().__init__() 10 | 11 | def forward(self, x1, x2, weight, bias, act, residual=None): 12 | output = gate_linear_fn( 13 | x1=x1, 14 | x2=x2, 15 | W=weight, 16 | bias=bias, 17 | act=act, 18 | residual=residual, 19 | ) 20 | return output 21 | -------------------------------------------------------------------------------- /xmixers/modules/quantizer/finite_scalar_quantizer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from .utils import round_ste 5 | 6 | 7 | class FiniteScalarQuantizer(nn.Module): 8 | def __init__( 9 | self, 10 | center=False, 11 | **kwargs, 12 | ): 13 | super().__init__() 14 | self.center = center 15 | 16 | def 
forward(self, x): 17 | x_quant = round_ste(F.sigmoid(x)) 18 | if self.center: 19 | x_quant = 2 * x_quant - 1 20 | 21 | return x_quant 22 | -------------------------------------------------------------------------------- /configs/linear_transformer/hgrn2/hgrn2_7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "ce_type": "xopes_flce", 5 | "embed_dim": 4096, 6 | "fuse_norm_add": false, 7 | "gain": 0.01, 8 | "glu_activation": "silu", 9 | "init_type": 1, 10 | "lrpe_type": 1, 11 | "mid_dim": 11520, 12 | "model_type": "hgrn2", 13 | "norm_type": "rmsnorm", 14 | "num_heads": 32, 15 | "num_layers": 32, 16 | "rescale_type": 2, 17 | "tie_word_embeddings": false, 18 | "token_mixer_init_type": 4, 19 | "use_gate_linear": true, 20 | "use_lrpe": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/linear_transformer/hgrn2/hgrn2_90m.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "ce_type": "xopes_flce", 5 | "embed_dim": 768, 6 | "fuse_norm_add": false, 7 | "gain": 0.01, 8 | "glu_activation": "silu", 9 | "init_type": 1, 10 | "lrpe_type": 1, 11 | "mid_dim": 2048, 12 | "model_type": "hgrn2", 13 | "norm_type": "rmsnorm", 14 | "num_heads": 12, 15 | "num_layers": 12, 16 | "rescale_type": 2, 17 | "tie_word_embeddings": false, 18 | "token_mixer_init_type": 4, 19 | "use_gate_linear": true, 20 | "use_lrpe": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/linear_transformer/hgrn2/hgrn2_xl.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "ce_type": "xopes_flce", 5 | "embed_dim": 1280, 6 | "fuse_norm_add": false, 7 | "gain": 0.01, 8 | "glu_activation": "silu", 9 | "init_type": 1, 10 | "lrpe_type": 1, 11 | "mid_dim": 3584, 12 | "model_type": "hgrn2", 13 | "norm_type": "rmsnorm", 14 | "num_heads": 20, 15 | "num_layers": 36, 16 | "rescale_type": 2, 17 | "tie_word_embeddings": false, 18 | "token_mixer_init_type": 4, 19 | "use_gate_linear": true, 20 | "use_lrpe": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/linear_transformer/hgrn2/hgrn2_xxl.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "ce_type": "xopes_flce", 5 | "embed_dim": 1536, 6 | "fuse_norm_add": false, 7 | "gain": 0.01, 8 | "glu_activation": "silu", 9 | "init_type": 1, 10 | "lrpe_type": 1, 11 | "mid_dim": 4096, 12 | "model_type": "hgrn2", 13 | "norm_type": "rmsnorm", 14 | "num_heads": 24, 15 | "num_layers": 48, 16 | "rescale_type": 2, 17 | "tie_word_embeddings": false, 18 | "token_mixer_init_type": 4, 19 | "use_gate_linear": true, 20 | "use_lrpe": true 21 | } 22 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/mesa_net/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_mesa_net import MesaNetConfig 4 | from .modeling_mesa_net import MesaNetForCausalLM, MesaNetLayer, MesaNetModel 5 | 6 | AutoConfig.register(MesaNetConfig.model_type, MesaNetConfig) 7 | AutoModel.register(MesaNetConfig, MesaNetModel) 8 | 
AutoModelForCausalLM.register(MesaNetConfig, MesaNetForCausalLM) 9 | 10 | __all__ = [ 11 | "MesaNetConfig", 12 | "MesaNetModel", 13 | "MesaNetForCausalLM", 14 | ] 15 | -------------------------------------------------------------------------------- /xmixers/monkey_patch/mpa/configuration_llama_mpa.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ Llama mpa configuration""" 3 | 4 | from transformers.models.llama.configuration_llama import LlamaConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class LlamaMpaConfig(LlamaConfig): 11 | model_type = "llama_mpa" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | **kwargs, 17 | ): 18 | super().__init__( 19 | **kwargs, 20 | ) 21 | -------------------------------------------------------------------------------- /configs/linear_transformer/hgrn2/hgrn2_1_2b.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "ce_type": "xopes_flce", 5 | "embed_dim": 2048, 6 | "fuse_norm_add": false, 7 | "gain": 0.01, 8 | "glu_activation": "silu", 9 | "init_type": 1, 10 | "lrpe_type": 1, 11 | "mid_dim": 5632, 12 | "model_type": "hgrn2", 13 | "norm_type": "rmsnorm", 14 | "num_heads": 16, 15 | "num_layers": 24, 16 | "rescale_type": 2, 17 | "tie_word_embeddings": false, 18 | "token_mixer_init_type": 4, 19 | "use_gate_linear": true, 20 | "use_lrpe": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/linear_transformer/hgrn2/hgrn2_2_5b.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "ce_type": "xopes_flce", 5 | "embed_dim": 2560, 6 | "fuse_norm_add": false, 7 | "gain": 0.01, 8 | "glu_activation": "silu", 9 | "init_type": 1, 10 | "lrpe_type": 1, 11 | "mid_dim": 6912, 12 | "model_type": "hgrn2", 13 | "norm_type": "rmsnorm", 14 | "num_heads": 20, 15 | "num_layers": 32, 16 | "rescale_type": 2, 17 | "tie_word_embeddings": false, 18 | "token_mixer_init_type": 4, 19 | "use_gate_linear": true, 20 | "use_lrpe": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/linear_transformer/hgrn2/hgrn2_310m.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "ce_type": "xopes_flce", 5 | "embed_dim": 1024, 6 | "fuse_norm_add": false, 7 | "gain": 0.01, 8 | "glu_activation": "silu", 9 | "init_type": 1, 10 | "lrpe_type": 1, 11 | "mid_dim": 2816, 12 | "model_type": "hgrn2", 13 | "norm_type": "rmsnorm", 14 | "num_heads": 16, 15 | "num_layers": 24, 16 | "rescale_type": 2, 17 | "tie_word_embeddings": false, 18 | "token_mixer_init_type": 4, 19 | "use_gate_linear": true, 20 | "use_lrpe": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/linear_transformer/hgrn2/hgrn2_h128_310m.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "ce_type": "xopes_flce", 5 | "embed_dim": 1024, 6 | "fuse_norm_add": false, 7 | "gain": 0.01, 8 | "glu_activation": "silu", 9 | "init_type": 1, 10 | "lrpe_type": 1, 11 | "mid_dim": 2816, 12 | "model_type": "hgrn2", 13 | "norm_type": "rmsnorm", 14 | "num_heads": 8, 15 | "num_layers": 24, 16 | "rescale_type": 2, 17 | 
"tie_word_embeddings": false, 18 | "token_mixer_init_type": 4, 19 | "use_gate_linear": true, 20 | "use_lrpe": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/linear_transformer/hgrn2/hgrn2_h128_90m.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "ce_type": "xopes_flce", 5 | "embed_dim": 768, 6 | "fuse_norm_add": false, 7 | "gain": 0.01, 8 | "glu_activation": "silu", 9 | "init_type": 1, 10 | "lrpe_type": 1, 11 | "mid_dim": 2048, 12 | "model_type": "hgrn2", 13 | "norm_type": "rmsnorm", 14 | "num_heads": 6, 15 | "num_layers": 12, 16 | "rescale_type": 2, 17 | "tie_word_embeddings": false, 18 | "token_mixer_init_type": 4, 19 | "use_gate_linear": true, 20 | "use_lrpe": true 21 | } 22 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/deltanet/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_deltanet import DeltaNetConfig 4 | from .modeling_deltanet import DeltaNetForCausalLM, DeltaNetLayer, DeltaNetModel 5 | 6 | AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig) 7 | AutoModel.register(DeltaNetConfig, DeltaNetModel) 8 | AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM) 9 | 10 | __all__ = [ 11 | "DeltaNetConfig", 12 | "DeltaNetModel", 13 | "DeltaNetForCausalLM", 14 | ] 15 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/lightnet/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_lightnet import LightNetConfig 4 | from .modeling_lightnet import LightNetForCausalLM, LightNetLayer, LightNetModel 5 | 6 | AutoConfig.register(LightNetConfig.model_type, LightNetConfig) 7 | AutoModel.register(LightNetConfig, LightNetModel) 8 | AutoModelForCausalLM.register(LightNetConfig, LightNetForCausalLM) 9 | 10 | __all__ = [ 11 | "LightNetConfig", 12 | "LightNetModel", 13 | "LightNetForCausalLM", 14 | ] 15 | -------------------------------------------------------------------------------- /configs/transformer/llama_half_rope/llama_half_rope_1_2b.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "ce_type": "xopes_flce", 5 | "embed_dim": 2048, 6 | "fuse_norm_add": false, 7 | "gain": 0.01, 8 | "glu_activation": "silu", 9 | "init_type": 1, 10 | "lrpe_type": 3, 11 | "mid_dim": 5632, 12 | "model_type": "llama_", 13 | "norm_type": "rmsnorm", 14 | "num_heads": 16, 15 | "num_layers": 24, 16 | "rescale_type": 2, 17 | "tie_word_embeddings": false, 18 | "token_mixer_init_type": 4, 19 | "use_gate_linear": true, 20 | "use_lrpe": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/transformer/llama_half_rope/llama_half_rope_2_5b.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "ce_type": "xopes_flce", 5 | "embed_dim": 2560, 6 | "fuse_norm_add": false, 7 | "gain": 0.01, 8 | "glu_activation": "silu", 9 | "init_type": 1, 10 | "lrpe_type": 3, 11 | "mid_dim": 6912, 12 | "model_type": "llama_", 13 | "norm_type": "rmsnorm", 14 | 
"num_heads": 20, 15 | "num_layers": 32, 16 | "rescale_type": 2, 17 | "tie_word_embeddings": false, 18 | "token_mixer_init_type": 4, 19 | "use_gate_linear": true, 20 | "use_lrpe": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/transformer/llama_half_rope/llama_half_rope_310m.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "ce_type": "xopes_flce", 5 | "embed_dim": 1024, 6 | "fuse_norm_add": false, 7 | "gain": 0.01, 8 | "glu_activation": "silu", 9 | "init_type": 1, 10 | "lrpe_type": 3, 11 | "mid_dim": 2816, 12 | "model_type": "llama_", 13 | "norm_type": "rmsnorm", 14 | "num_heads": 16, 15 | "num_layers": 24, 16 | "rescale_type": 2, 17 | "tie_word_embeddings": false, 18 | "token_mixer_init_type": 4, 19 | "use_gate_linear": true, 20 | "use_lrpe": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/transformer/llama_half_rope/llama_half_rope_7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "ce_type": "xopes_flce", 5 | "embed_dim": 4096, 6 | "fuse_norm_add": false, 7 | "gain": 0.01, 8 | "glu_activation": "silu", 9 | "init_type": 1, 10 | "lrpe_type": 3, 11 | "mid_dim": 11008, 12 | "model_type": "llama_", 13 | "norm_type": "rmsnorm", 14 | "num_heads": 32, 15 | "num_layers": 32, 16 | "rescale_type": 2, 17 | "tie_word_embeddings": false, 18 | "token_mixer_init_type": 4, 19 | "use_gate_linear": true, 20 | "use_lrpe": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/transformer/llama_half_rope/llama_half_rope_90m.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "ce_type": "xopes_flce", 5 | "embed_dim": 768, 6 | "fuse_norm_add": false, 7 | "gain": 0.01, 8 | "glu_activation": "silu", 9 | "init_type": 1, 10 | "lrpe_type": 3, 11 | "mid_dim": 2048, 12 | "model_type": "llama_", 13 | "norm_type": "rmsnorm", 14 | "num_heads": 12, 15 | "num_layers": 12, 16 | "rescale_type": 2, 17 | "tie_word_embeddings": false, 18 | "token_mixer_init_type": 4, 19 | "use_gate_linear": true, 20 | "use_lrpe": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/transformer/llama_half_rope/llama_half_rope_xl.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "ce_type": "xopes_flce", 5 | "embed_dim": 1280, 6 | "fuse_norm_add": false, 7 | "gain": 0.01, 8 | "glu_activation": "silu", 9 | "init_type": 1, 10 | "lrpe_type": 3, 11 | "mid_dim": 3584, 12 | "model_type": "llama_", 13 | "norm_type": "rmsnorm", 14 | "num_heads": 20, 15 | "num_layers": 36, 16 | "rescale_type": 2, 17 | "tie_word_embeddings": false, 18 | "token_mixer_init_type": 4, 19 | "use_gate_linear": true, 20 | "use_lrpe": true 21 | } 22 | -------------------------------------------------------------------------------- /configs/transformer/llama_half_rope/llama_half_rope_xxl.json: -------------------------------------------------------------------------------- 1 | { 2 | "base": 10000, 3 | "bias": false, 4 | "ce_type": "xopes_flce", 5 | "embed_dim": 1536, 6 | "fuse_norm_add": false, 7 | "gain": 0.01, 8 | "glu_activation": "silu", 9 | "init_type": 1, 10 | "lrpe_type": 3, 11 | "mid_dim": 4096, 12 | "model_type": 
"llama_", 13 | "norm_type": "rmsnorm", 14 | "num_heads": 24, 15 | "num_layers": 48, 16 | "rescale_type": 2, 17 | "tie_word_embeddings": false, 18 | "token_mixer_init_type": 4, 19 | "use_gate_linear": true, 20 | "use_lrpe": true 21 | } 22 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/dense_rnn/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_dense_rnn import DenseRnnConfig 4 | from .modeling_dense_rnn import DenseRnnForCausalLM, DenseRnnLayer, DenseRnnModel 5 | 6 | AutoConfig.register(DenseRnnConfig.model_type, DenseRnnConfig) 7 | AutoModel.register(DenseRnnConfig, DenseRnnModel) 8 | AutoModelForCausalLM.register(DenseRnnConfig, DenseRnnForCausalLM) 9 | 10 | __all__ = [ 11 | "DenseRnnConfig", 12 | "DenseRnnModel", 13 | "DenseRnnForCausalLM", 14 | ] 15 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/polar_rnn/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_polar_rnn import PolarRnnConfig 4 | from .modeling_polar_rnn import PolarRnnForCausalLM, PolarRnnLayer, PolarRnnModel 5 | 6 | AutoConfig.register(PolarRnnConfig.model_type, PolarRnnConfig) 7 | AutoModel.register(PolarRnnConfig, PolarRnnModel) 8 | AutoModelForCausalLM.register(PolarRnnConfig, PolarRnnForCausalLM) 9 | 10 | __all__ = [ 11 | "PolarRnnConfig", 12 | "PolarRnnModel", 13 | "PolarRnnForCausalLM", 14 | ] 15 | -------------------------------------------------------------------------------- /xmixers/models/chunk_linear_transformer/chunk_rnn/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_chunk_rnn import ChunkRnnConfig 4 | from .modeling_chunk_rnn import ChunkRnnForCausalLM, ChunkRnnLayer, ChunkRnnModel 5 | 6 | AutoConfig.register(ChunkRnnConfig.model_type, ChunkRnnConfig) 7 | AutoModel.register(ChunkRnnConfig, ChunkRnnModel) 8 | AutoModelForCausalLM.register(ChunkRnnConfig, ChunkRnnForCausalLM) 9 | 10 | __all__ = [ 11 | "ChunkRnnConfig", 12 | "ChunkRnnModel", 13 | "ChunkRnnForCausalLM", 14 | ] 15 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/mamba2/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_mamba2 import Mamba2XmixersConfig 4 | from .modeling_mamba2 import Mamba2ForCausalLM, Mamba2Layer, Mamba2Model 5 | 6 | AutoConfig.register(Mamba2XmixersConfig.model_type, Mamba2XmixersConfig) 7 | AutoModel.register(Mamba2XmixersConfig, Mamba2Model) 8 | AutoModelForCausalLM.register(Mamba2XmixersConfig, Mamba2ForCausalLM) 9 | 10 | __all__ = [ 11 | "Mamba2XmixersConfig", 12 | "Mamba2Model", 13 | "Mamba2ForCausalLM", 14 | ] 15 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/deep_memory/loss/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def 
get_loss_fn(loss_name): 7 | if loss_name == "mse": 8 | 9 | def fn(y_pred, y_true): 10 | return F.mse_loss(y_pred, y_true, reduction="sum") 11 | 12 | elif loss_name == "inner_product": 13 | 14 | def fn(y_pred, y_true): 15 | return -torch.sum((y_pred * y_true).sum(dim=-1)) 16 | 17 | else: 18 | raise ValueError(f"Loss function {loss_name} not found") 19 | 20 | return fn 21 | -------------------------------------------------------------------------------- /xmixers/models/hybrid/naive_hybrid/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_naive_hybrid import NaiveHybridConfig 4 | from .modeling_naive_hybrid import ( 5 | NaiveHybridForCausalLM, 6 | NaiveHybridLayer, 7 | NaiveHybridModel, 8 | ) 9 | 10 | AutoConfig.register(NaiveHybridConfig.model_type, NaiveHybridConfig) 11 | AutoModel.register(NaiveHybridConfig, NaiveHybridModel) 12 | AutoModelForCausalLM.register(NaiveHybridConfig, NaiveHybridForCausalLM) 13 | 14 | __all__ = ["NaiveHybridConfig", "NaiveHybridModel", "NaiveHybridForCausalLM"] 15 | -------------------------------------------------------------------------------- /examples/gtu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from xmixers.modules import Gtu 4 | 5 | device = "cuda:0" 6 | b, n, m, d = 2, 12, 23, 128 7 | 8 | x = torch.randn(b, n, m, d).to(device) 9 | 10 | gtu_causal = Gtu( 11 | embed_dim=d, 12 | causal=True, 13 | dims=[1, 2], 14 | in_dim=16, 15 | ).to(device) 16 | 17 | gtu_none_causal = Gtu( 18 | embed_dim=d, 19 | causal=False, 20 | dims=[1, 2], 21 | in_dim=16, 22 | ).to(device) 23 | 24 | print(gtu_causal) 25 | print(gtu_none_causal) 26 | 27 | y1 = gtu_causal(x) 28 | y2 = gtu_none_causal(x) 29 | 30 | print(x.shape) 31 | print(y1.shape) 32 | print(y2.shape) 33 | -------------------------------------------------------------------------------- /xmixers/models/hybrid/lm_head_hybrid/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_lm_head_hybrid import LmHeadHybridConfig 4 | from .modeling_lm_head_hybrid import ( 5 | LmHeadHybridForCausalLM, 6 | LmHeadHybridLayer, 7 | LmHeadHybridModel, 8 | ) 9 | 10 | AutoConfig.register(LmHeadHybridConfig.model_type, LmHeadHybridConfig) 11 | AutoModel.register(LmHeadHybridConfig, LmHeadHybridModel) 12 | AutoModelForCausalLM.register(LmHeadHybridConfig, LmHeadHybridForCausalLM) 13 | 14 | __all__ = ["LmHeadHybridConfig", "LmHeadHybridModel", "LmHeadHybridForCausalLM"] 15 | -------------------------------------------------------------------------------- /xmixers/modules/normalizations/l2_norm.py: -------------------------------------------------------------------------------- 1 | from xopes.ops.normalize import normalize_fn 2 | 3 | 4 | def l2_norm(x, eps=1e-5): 5 | return normalize_fn( 6 | x=x, 7 | weight=None, 8 | bias=None, 9 | residual=None, 10 | c=1, 11 | eps=eps, 12 | use_mean=False, 13 | num_groups=1, 14 | ) 15 | 16 | 17 | def rms_norm_fn(x, eps=1e-5): 18 | return normalize_fn( 19 | x=x, 20 | weight=None, 21 | bias=None, 22 | residual=None, 23 | c=x.shape[-1] ** 0.5, 24 | eps=eps, 25 | use_mean=False, 26 | num_groups=1, 27 | ) 28 | -------------------------------------------------------------------------------- /examples/llama.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModel, AutoModelForCausalLM 3 | 4 | from xmixers.models import LLaMAConfig 5 | 6 | config = LLaMAConfig() 7 | 8 | config.update({"num_layers": 12}) 9 | 10 | model1 = AutoModelForCausalLM.from_config(config).to(torch.bfloat16).cuda() 11 | model2 = AutoModel.from_config(config).to(torch.bfloat16).cuda() 12 | 13 | print(config) 14 | print(model1) 15 | print(model2) 16 | 17 | b = 2 18 | n = 2048 19 | m = config.vocab_size 20 | 21 | input = torch.randint(low=0, high=m, size=(b, n)).cuda() 22 | output1 = model1(input).logits 23 | output2 = model2(input).last_hidden_state 24 | print(output1.shape) 25 | print(output2.shape) 26 | -------------------------------------------------------------------------------- /benchmarks/script_train.sh: -------------------------------------------------------------------------------- 1 | date=$(date '+%Y-%m-%d-%H:%M:%S') 2 | 3 | 4 | cfg_path=../configs/transformer/llama/llama_86m.json 5 | # cfg_path=../configs/transformer/llama/llama_310m.json 6 | name=llama 7 | 8 | # cfg_path=../configs/transformer/mla/mla_86m.json 9 | # name=mla 10 | 11 | vocab_size=64000 12 | warmup_steps=16 13 | steps=64 14 | 15 | batch_size_list=(16 8 4 2) 16 | seq_len_list=(1024 2048 4096 8192) 17 | 18 | python benchmark_training.py \ 19 | --cfg_path $cfg_path \ 20 | --name $name \ 21 | --vocab_size $vocab_size \ 22 | --batch_size_list ${batch_size_list[@]} \ 23 | --seq_len_list ${seq_len_list[@]} \ 24 | --warmup_steps $warmup_steps \ 25 | --steps $steps 26 | -------------------------------------------------------------------------------- /xmixers/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import ActLayer, get_activation_fn 2 | from .channel_mixers import ALU, FFN, GLU, get_channel_mixer, nGLU 3 | from .normalizations import get_norm_fn 4 | from .pes import ( 5 | LearnablePe, 6 | Lrpe, 7 | MlpPe, 8 | SinCosPe, 9 | get_log_slopes, 10 | get_log_slopes_general, 11 | ) 12 | from .token_mixers import ( 13 | LINEAR_TOKEN_MIXER_LIST, 14 | SOFTMAX_TOKEN_MIXER_LIST, 15 | Attention, 16 | FlexAttention, 17 | Gtu, 18 | Hgru2, 19 | Hgru3, 20 | LinearAttention, 21 | MetaLa, 22 | MultiProductAttention, 23 | PolarRnn, 24 | TnlAttention, 25 | get_token_mixer, 26 | nAttention, 27 | ) 28 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/linear_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_linear_transformer import LinearTransformerConfig 4 | from .modeling_linear_transformer import ( 5 | LinearTransformerForCausalLM, 6 | LinearTransformerLayer, 7 | LinearTransformerModel, 8 | ) 9 | 10 | AutoConfig.register(LinearTransformerConfig.model_type, LinearTransformerConfig) 11 | AutoModel.register(LinearTransformerConfig, LinearTransformerModel) 12 | AutoModelForCausalLM.register(LinearTransformerConfig, LinearTransformerForCausalLM) 13 | 14 | __all__ = [ 15 | "LinearTransformerConfig", 16 | "LinearTransformerModel", 17 | "LinearTransformerForCausalLM", 18 | ] 19 | -------------------------------------------------------------------------------- /xmixers/modules/normalizations/srms_norm.py: -------------------------------------------------------------------------------- 1 | import torch 
2 | 3 | from .utils import NormOp 4 | 5 | 6 | class SRMSNorm(torch.nn.Module): 7 | def __init__(self, dim: int, eps: float = 1e-5, **kwargs): 8 | super().__init__() 9 | self.dim = dim 10 | self.eps = eps 11 | self.op = NormOp(norm_type="srmsnorm") 12 | 13 | def extra_repr(self) -> str: 14 | return f"dim={self.dim}, eps={self.eps}" 15 | 16 | def forward(self, x, residual=None, return_residual=False): 17 | return self.op( 18 | x, 19 | None, 20 | None, 21 | residual, 22 | self.dim, 23 | self.eps, 24 | False, 25 | 1, 26 | return_residual, 27 | ) 28 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/linear_attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .decay_linear_attention import DecayLinearAttention 2 | from .delta_product_unit import DeltaProductUnit 3 | from .delta_unit import DeltaUnit 4 | from .dense_rnn import DenseRnn 5 | from .gsa import GatedSlotAttention 6 | from .hgru1 import Hgru1 7 | from .hgru2 import Hgru2 8 | from .hgru2_scalar_decay import Hgru2ScalarDecay 9 | from .hgru3 import Hgru3 10 | from .implicit_value_attention import ImplicitValueAttention 11 | from .lightnet_attention import LightNetAttention 12 | from .linear_attention import LinearAttention 13 | from .mamba2 import Mamba2 14 | from .mesa_unit import MesaUnit 15 | from .meta_la import MetaLa 16 | from .polar_rnn import PolarRnn 17 | from .tnl_attention import TnlAttention 18 | from .ttt import TTT 19 | -------------------------------------------------------------------------------- /tests/modules/md_tpe/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from xmixers.modules.pes.md import MdTpe 4 | 5 | 6 | def test_md_tpe(embed_dim, num_heads, dims, shape, device): 7 | md_tpe = MdTpe(embed_dim=embed_dim, num_heads=num_heads, dims=dims).to(device) 8 | x = torch.randn(shape, device=device).requires_grad_() 9 | print("input shape: ", x.shape) 10 | o = md_tpe(x) 11 | print("output shape: ", o.shape) 12 | o.sum().backward() 13 | print("input grad: ", x.grad.shape) 14 | 15 | 16 | if __name__ == "__main__": 17 | embed_dim = 128 18 | num_heads = 16 19 | dim = [-2, -3, -4] 20 | shape = (2, 16, 16, 16, embed_dim) 21 | device = "cuda" 22 | test_md_tpe( 23 | embed_dim=embed_dim, num_heads=num_heads, dims=dim, shape=shape, device=device 24 | ) 25 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/decay_linear_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_decay_linear_transformer import DecayLinearTransformerConfig 4 | from .modeling_decay_linear_transformer import ( 5 | DecayLinearTransformerForCausalLM, 6 | DecayLinearTransformerLayer, 7 | DecayLinearTransformerModel, 8 | ) 9 | 10 | AutoConfig.register( 11 | DecayLinearTransformerConfig.model_type, DecayLinearTransformerConfig 12 | ) 13 | AutoModel.register(DecayLinearTransformerConfig, DecayLinearTransformerModel) 14 | AutoModelForCausalLM.register( 15 | DecayLinearTransformerConfig, DecayLinearTransformerForCausalLM 16 | ) 17 | 18 | __all__ = [ 19 | "DecayLinearTransformerConfig", 20 | "DecayLinearTransformerModel", 21 | "DecayLinearTransformerForCausalLM", 22 | ] 23 | -------------------------------------------------------------------------------- 
/xmixers/modules/normalizations/offset_scale.py: -------------------------------------------------------------------------------- 1 | """ 2 | Offset scale: y = gamma * x + beta 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from xmixers.utils import XMIXERS_DEBUG, print_module, print_params 9 | 10 | 11 | class OffsetScale(nn.Module): 12 | def __init__(self, dim: int, **kwargs) -> None: 13 | super().__init__() 14 | 15 | if XMIXERS_DEBUG: 16 | # get local varables 17 | params = locals() 18 | # print params 19 | print_params(**params) 20 | 21 | self.gamma = nn.Parameter(torch.ones(dim)) 22 | self.beta = nn.Parameter(torch.zeros(dim)) 23 | nn.init.normal_(self.gamma, std=0.02) 24 | 25 | def extra_repr(self) -> str: 26 | return print_module(self) 27 | 28 | def forward(self, x): 29 | out = x * self.gamma + self.beta 30 | return out 31 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/implicit_linear_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from .configuration_implicit_linear_transformer import ImplicitLinearTransformerConfig 4 | from .modeling_implicit_linear_transformer import ( 5 | ImplicitLinearTransformerForCausalLM, 6 | ImplicitLinearTransformerLayer, 7 | ImplicitLinearTransformerModel, 8 | ) 9 | 10 | AutoConfig.register( 11 | ImplicitLinearTransformerConfig.model_type, ImplicitLinearTransformerConfig 12 | ) 13 | AutoModel.register(ImplicitLinearTransformerConfig, ImplicitLinearTransformerModel) 14 | AutoModelForCausalLM.register( 15 | ImplicitLinearTransformerConfig, ImplicitLinearTransformerForCausalLM 16 | ) 17 | 18 | __all__ = [ 19 | "ImplicitLinearTransformerConfig", 20 | "ImplicitLinearTransformerModel", 21 | "ImplicitLinearTransformerForCausalLM", 22 | ] 23 | -------------------------------------------------------------------------------- /xmixers/modules/channel_mixers/ffn.py: -------------------------------------------------------------------------------- 1 | # FFN: https://arxiv.org/pdf/2002.05202.pdf 2 | 3 | import torch.nn as nn 4 | 5 | from xmixers.modules.activations import get_activation_fn 6 | from xmixers.utils import XMIXERS_DEBUG, print_params 7 | 8 | 9 | class FFN(nn.Module): 10 | def __init__( 11 | self, embed_dim: int, mid_dim: int, activation: str, bias: bool = False 12 | ) -> None: 13 | super().__init__() 14 | 15 | if XMIXERS_DEBUG: 16 | # get local varables 17 | params = locals() 18 | # print params 19 | print_params(**params) 20 | 21 | self.w1 = nn.Linear(embed_dim, mid_dim, bias=bias) 22 | self.w2 = nn.Linear(mid_dim, embed_dim, bias=bias) 23 | self.act = get_activation_fn(activation) 24 | 25 | def forward(self, x): 26 | output = self.w2(self.act(self.w1(x))) 27 | 28 | return output 29 | -------------------------------------------------------------------------------- /xmixers/modules/pes/learnable_pe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LearnablePe(nn.Module): 7 | def __init__( 8 | self, 9 | embed_dim: int, 10 | max_sequence_length: int = 2048, 11 | ): 12 | super().__init__() 13 | self.max_sequence_length = max_sequence_length 14 | self.embed_dim = embed_dim 15 | weight = torch.randn(max_sequence_length, embed_dim) 16 | self.weight = nn.Parameter(weight, requires_grad=True) 17 | 18 | def 
extra_repr(self) -> str: 19 | s = "{max_sequence_length}, {embed_dim}" 20 | return s.format(**self.__dict__) 21 | 22 | def forward(self, x, shape=None, offset=0): 23 | n = x.shape[1] 24 | pos = torch.arange(0, n, dtype=torch.long, device=x.device) + offset 25 | pe = F.embedding(pos, self.weight) 26 | x = x + pe 27 | 28 | return x 29 | -------------------------------------------------------------------------------- /xmixers/modules/normalizations/scale_norm.py: -------------------------------------------------------------------------------- 1 | """ 2 | ScaleNorm in https://arxiv.org/pdf/2202.10447.pdf 3 | """ 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from xmixers.utils import XMIXERS_DEBUG, print_module, print_params 9 | 10 | 11 | class ScaleNorm(nn.Module): 12 | def __init__(self, d: int, eps: float = 1e-5, **kwargs) -> None: 13 | super().__init__() 14 | if XMIXERS_DEBUG: 15 | # get local variables 16 | params = locals() 17 | # print params 18 | print_params(**params) 19 | 20 | self.d = d 21 | self.eps = eps 22 | self.scala = nn.Parameter(torch.ones(1)) 23 | 24 | def extra_repr(self) -> str: 25 | return print_module(self) 26 | 27 | def forward(self, x): 28 | # TODO: add fusion here 29 | mean_square = (x**2).mean(dim=-1, keepdim=True) 30 | x = x * torch.rsqrt(mean_square + self.eps) * self.scala 31 | return x 32 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/vanilla_attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import Attention 2 | from .flex_attention import FlexAttention 3 | from .forgetting_attention import ForgettingAttention 4 | from .fsq_kv_attention import FsqKvAttention 5 | from .fsq_kv_multi_product_attention import FsqKvMultiProductAttention 6 | from .kernel_regression_attention import KernelRegressionAttention 7 | from .multi_factor_attention import MultiFactorAttention 8 | from .multi_latent_attention import MultiLatentAttention 9 | from .multi_product_attention import MultiProductAttention 10 | from .n_attention import nAttention 11 | from .naive_sparse_attention import NaiveSparseAttention 12 | from .path_attention import PathAttention 13 | from .poly_attention import PolyAttention 14 | from .simple_sparse_attention import SimpleSparseAttention 15 | from .stickbreaking_attention import StickBreakingAttention 16 | from .tensor_product_attention import TensorProductAttention 17 | -------------------------------------------------------------------------------- /benchmarks/script_inference.sh: -------------------------------------------------------------------------------- 1 | vocab_size=16384 2 | batch_size_list=(16) 3 | input_length_list=(1) 4 | max_length_list=(256) 5 | max_length_list=(576) 6 | 7 | model_type=linear_transformer 8 | name=hgrn2 9 | for config_name in hgrn2_xxl #hgrn2_90m #hgrn2_310m hgrn2_xl hgrn2_xxl 10 | 11 | # model_type=transformer 12 | # name=llama_half_rope 13 | # for config_name in llama_half_rope_90m #llama_half_rope_310m llama_half_rope_xl llama_half_rope_xxl 14 | do 15 | cfg_path=../configs/${model_type}/${name}/${config_name}.json 16 | date=$(date '+%Y-%m-%d-%H:%M:%S') 17 | python benchmark_inference.py \ 18 | --cfg_path $cfg_path \ 19 | --name $name \ 20 | --vocab_size $vocab_size \ 21 | --batch_size_list ${batch_size_list[@]} \ 22 | --input_length_list ${input_length_list[@]} \ 23 | --max_length_list ${max_length_list[@]} \ 24 | --compile \ 25 | --cg \ 26 | 2>&1 | tee -a logs/${date}-${config_name}.log 27 | #
--cg 28 | done 29 | -------------------------------------------------------------------------------- /xmixers/models/hybrid/lm_head_hybrid/modeling_outputs.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | from transformers.utils import ModelOutput 6 | 7 | 8 | @dataclass 9 | class LmHeadHybridOutputWithPast(ModelOutput): 10 | last_linear_attn_hidden_state: torch.FloatTensor = None 11 | last_softmax_attn_hidden_state: torch.FloatTensor = None 12 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 13 | hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None 14 | attentions: Optional[Tuple[torch.FloatTensor, ...]] = None 15 | 16 | 17 | @dataclass 18 | class CausalLmHeadHybridOutputWithPast(ModelOutput): 19 | linear_attn_loss: Optional[torch.FloatTensor] = None 20 | softmax_attn_loss: Optional[torch.FloatTensor] = None 21 | loss: Optional[torch.FloatTensor] = None 22 | logits: torch.FloatTensor = None 23 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 24 | hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None 25 | attentions: Optional[Tuple[torch.FloatTensor, ...]] = None 26 | -------------------------------------------------------------------------------- /xmixers/modules/normalizations/group_srms_norm.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import torch 4 | from torch import Size 5 | 6 | _shape_t = Union[int, List[int], Size] 7 | 8 | 9 | from xopes.ops.normalize import group_srms_norm_fn 10 | 11 | 12 | class GroupSRMSNorm(torch.nn.Module): 13 | def __init__( 14 | self, 15 | num_channels: int, 16 | num_groups: int, 17 | eps: float = 1e-5, 18 | device=None, 19 | dtype=None, 20 | **kwargs 21 | ) -> None: 22 | factory_kwargs = {"device": device, "dtype": dtype} 23 | super().__init__() 24 | if num_channels % num_groups != 0: 25 | raise ValueError("num_channels must be divisible by num_groups") 26 | 27 | self.num_groups = num_groups 28 | self.num_channels = num_channels 29 | self.eps = eps 30 | 31 | def forward(self, x, residual=None, return_residual=False): 32 | return group_srms_norm_fn( 33 | x=x, 34 | dim=x.shape[-1], 35 | eps=self.eps, 36 | residual=residual, 37 | return_residual=return_residual, 38 | num_groups=self.num_groups, 39 | ) 40 | 41 | def extra_repr(self) -> str: 42 | return "num_groups={num_groups}, num_channels={num_channels}, eps={eps}".format( 43 | **self.__dict__ 44 | ) 45 | -------------------------------------------------------------------------------- /xmixers/modules/channel_mixers/__init__.py: -------------------------------------------------------------------------------- 1 | from .alu import ALU 2 | from .ffn import FFN 3 | from .glu import GLU 4 | from .nglu import nGLU 5 | 6 | AUTO_CHANNEL_MIXER_MAPPING = { 7 | "ffn": FFN, 8 | "glu": GLU, 9 | "nglu": nGLU, 10 | "alu": ALU, 11 | } 12 | 13 | 14 | def get_channel_mixer(config): 15 | cls = AUTO_CHANNEL_MIXER_MAPPING[config.channel_mixer_type] 16 | if config.channel_mixer_type in ["ffn", "glu", "nglu"]: 17 | return cls( 18 | embed_dim=config.embed_dim, 19 | mid_dim=config.mid_dim, 20 | activation=config.channel_mixer_activation, 21 | bias=config.bias, 22 | use_gate_linear=config.use_gate_linear, 23 | ) 24 | elif config.channel_mixer_type in ["alu"]: 25 | return cls( 26 | embed_dim=config.embed_dim, 27 | qk_dim=config.qk_dim, 28 | v_dim=config.v_dim, 29 | 
mem_dim=config.mem_dim, 30 | num_heads=config.num_heads, 31 | activation=config.channel_mixer_activation, 32 | bias=config.bias, 33 | use_scale=config.use_scale, 34 | use_output_gate=config.use_output_gate, 35 | output_gate_activation=config.output_gate_activation, 36 | use_low_rank_output_gate=config.use_low_rank_output_gate, 37 | channel_mixer_init_type=config.channel_mixer_init_type, 38 | ) 39 | -------------------------------------------------------------------------------- /xmixers/modules/normalizations/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .dynamic_tanh import DynamicTanh, DynamicTanhFusedGate 4 | from .group_norm import GroupNorm 5 | from .group_rms_norm import GroupRMSNorm, GroupRMSNormFusedGate 6 | from .group_srms_norm import GroupSRMSNorm 7 | from .l2_norm import l2_norm, rms_norm_fn 8 | from .layer_norm import LayerNorm 9 | from .offset_scale import OffsetScale 10 | from .rms_norm import GatedRMSNorm, RMSNorm, RMSNormFusedGate 11 | from .scale_norm import ScaleNorm 12 | from .srms_norm import SRMSNorm 13 | 14 | 15 | def get_norm_fn(norm_type: str): 16 | if norm_type == "rmsnorm": 17 | return RMSNorm 18 | elif norm_type == "rmsnorm_fused_gate": 19 | return RMSNormFusedGate 20 | elif norm_type == "gatedrmsnorm": 21 | return GatedRMSNorm 22 | elif norm_type == "srmsnorm": 23 | return SRMSNorm 24 | elif norm_type == "scalenorm": 25 | return ScaleNorm 26 | elif norm_type == "groupnorm": 27 | return GroupNorm 28 | elif norm_type == "grouprmsnorm": 29 | return GroupRMSNorm 30 | elif norm_type == "grouprmsnorm_fused_gate": 31 | return GroupRMSNormFusedGate 32 | elif norm_type == "groupsrmsnorm": 33 | return GroupSRMSNorm 34 | elif norm_type == "dynamictanh": 35 | return DynamicTanh 36 | elif norm_type == "dynamictanh_fused_gate": 37 | return DynamicTanhFusedGate 38 | else: 39 | return LayerNorm 40 | -------------------------------------------------------------------------------- /xmixers/modules/channel_mixers/glu.py: -------------------------------------------------------------------------------- 1 | # GLU: https://arxiv.org/pdf/2002.05202.pdf 2 | 3 | import torch.nn as nn 4 | 5 | from xmixers.modules.activations import get_activation_fn 6 | from xmixers.utils import XMIXERS_DEBUG, print_params 7 | 8 | from .utils import GateLinearOp 9 | 10 | 11 | class GLU(nn.Module): 12 | def __init__( 13 | self, 14 | embed_dim: int, 15 | mid_dim: int, 16 | activation: str, 17 | bias: bool = False, 18 | use_gate_linear: bool = False, 19 | ) -> None: 20 | super().__init__() 21 | 22 | if XMIXERS_DEBUG: 23 | # get local varables 24 | params = locals() 25 | # print params 26 | print_params(**params) 27 | 28 | self.w1 = nn.Linear(embed_dim, mid_dim, bias=bias) 29 | self.w2 = nn.Linear(embed_dim, mid_dim, bias=bias) 30 | self.w3 = nn.Linear(mid_dim, embed_dim, bias=bias) 31 | self.act = get_activation_fn(activation) 32 | self.activation = activation 33 | self.use_gate_linear = use_gate_linear 34 | self.gate_linear_op = GateLinearOp() 35 | 36 | def forward(self, x, residual=None): 37 | if self.use_gate_linear: 38 | # since we may use forward_pre_hook, we can't use key word arguments 39 | output = self.gate_linear_op( 40 | self.w1(x), 41 | self.w2(x), 42 | self.w3.weight, 43 | self.w3.bias, 44 | self.activation, 45 | residual, 46 | ) 47 | else: 48 | output = self.w3(self.act(self.w1(x)) * self.w2(x)) 49 | 50 | return output 51 | -------------------------------------------------------------------------------- 
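Below is a minimal usage sketch for the channel mixers defined above (FFN and GLU). It is illustrative only: the batch, sequence, and embedding sizes, the `silu` activation, and the `mid_dim` value are assumed, and it presumes `xmixers` and its dependencies (PyTorch, `xopes`) are importable.

```
import torch

from xmixers.modules.channel_mixers import FFN, GLU

x = torch.randn(2, 16, 64)  # (batch, seq_len, embed_dim)

# FFN computes w2(act(w1(x))); GLU adds a multiplicative branch: w3(act(w1(x)) * w2(x))
ffn = FFN(embed_dim=64, mid_dim=128, activation="silu")
glu = GLU(embed_dim=64, mid_dim=128, activation="silu", use_gate_linear=False)

print(ffn(x).shape)  # torch.Size([2, 16, 64])
print(glu(x).shape)  # torch.Size([2, 16, 64])
```

With `use_gate_linear=True`, `GLU.forward` instead routes `w1(x)`, `w2(x)`, and the `w3` weights through `GateLinearOp`, presumably a fused gate-then-project kernel provided by `xopes`.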
/xmixers/models/linear_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .decay_linear_transformer import ( 2 | DecayLinearTransformerConfig, 3 | DecayLinearTransformerForCausalLM, 4 | DecayLinearTransformerLayer, 5 | DecayLinearTransformerModel, 6 | ) 7 | from .deltanet import DeltaNetConfig, DeltaNetForCausalLM, DeltaNetLayer, DeltaNetModel 8 | from .dense_rnn import DenseRnnConfig, DenseRnnForCausalLM, DenseRnnLayer, DenseRnnModel 9 | from .gsa import GsaConfig, GsaForCausalLM, GsaLayer, GsaModel 10 | from .hgrn1 import Hgrn1Config, Hgrn1ForCausalLM, Hgrn1Layer, Hgrn1Model 11 | from .hgrn2 import Hgrn2Config, Hgrn2ForCausalLM, Hgrn2Layer, Hgrn2Model 12 | from .hgrn3 import Hgrn3Config, Hgrn3ForCausalLM, Hgrn3Layer, Hgrn3Model 13 | from .implicit_linear_transformer import ( 14 | ImplicitLinearTransformerConfig, 15 | ImplicitLinearTransformerForCausalLM, 16 | ImplicitLinearTransformerLayer, 17 | ImplicitLinearTransformerModel, 18 | ) 19 | from .lightnet import LightNetConfig, LightNetForCausalLM, LightNetLayer, LightNetModel 20 | from .linear_transformer import ( 21 | LinearTransformerConfig, 22 | LinearTransformerForCausalLM, 23 | LinearTransformerLayer, 24 | LinearTransformerModel, 25 | ) 26 | from .mamba2 import Mamba2ForCausalLM, Mamba2Layer, Mamba2Model, Mamba2XmixersConfig 27 | from .mesa_net import MesaNetConfig, MesaNetForCausalLM, MesaNetLayer, MesaNetModel 28 | from .metala import MetaLaConfig, MetaLaForCausalLM, MetaLaLayer, MetaLaModel 29 | from .polar_rnn import PolarRnnConfig, PolarRnnForCausalLM, PolarRnnLayer, PolarRnnModel 30 | from .tnl import TnlConfig, TnlForCausalLM, TnlLayer, TnlModel 31 | from .ttt import TTTConfig, TTTForCausalLM, TTTLayer, TTTModel 32 | -------------------------------------------------------------------------------- /xmixers/modules/channel_mixers/nglu.py: -------------------------------------------------------------------------------- 1 | # nGLU: https://arxiv.org/pdf/2410.01131 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from xmixers.modules.activations import get_activation_fn 7 | from xmixers.utils import XMIXERS_DEBUG, print_module, print_params 8 | 9 | 10 | class nGLU(nn.Module): 11 | def __init__( 12 | self, embed_dim: int, mid_dim: int, activation: str, bias: bool = False 13 | ) -> None: 14 | super().__init__() 15 | 16 | if XMIXERS_DEBUG: 17 | # get local varables 18 | params = locals() 19 | # print params 20 | print_params(**params) 21 | 22 | self.w1 = nn.Linear(embed_dim, mid_dim, bias=bias) 23 | self.w2 = nn.Linear(embed_dim, mid_dim, bias=bias) 24 | self.w3 = nn.Linear(mid_dim, embed_dim, bias=bias) 25 | self.act = get_activation_fn(activation) 26 | self.embed_dim = embed_dim 27 | 28 | self.suv_init_value = 1.0 29 | self.suv_init_scaling = 1.0 30 | self.su = torch.nn.Parameter( 31 | self.suv_init_scaling * torch.ones(mid_dim, dtype=torch.float32) 32 | ) 33 | self.sv = torch.nn.Parameter( 34 | self.suv_init_scaling * torch.ones(mid_dim, dtype=torch.float32) 35 | ) 36 | 37 | def extra_repr(self): 38 | return print_module(self) 39 | 40 | def justnorm(self, x): 41 | res = x / x.norm(p=2, dim=-1, keepdim=True) 42 | return res 43 | 44 | def forward(self, x): 45 | v = self.w1(x) 46 | u = self.w2(x) 47 | v = ( 48 | self.sv 49 | * ((self.suv_init_value / self.suv_init_scaling) * (self.embed_dim**0.5)) 50 | ) * v 51 | u = ( 52 | self.su 53 | * ((self.suv_init_value / self.suv_init_scaling) * (self.embed_dim**0.5)) 54 | ) * u 55 | output = self.w3(u * self.act(v)) 56 | 57 
| return output 58 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ Setup 2 | """ 3 | from codecs import open 4 | from os import path 5 | 6 | from setuptools import find_packages, setup 7 | 8 | here = path.abspath(path.dirname(__file__)) 9 | 10 | # Get the long description from the README file 11 | with open(path.join(here, "README.md"), encoding="utf-8") as f: 12 | long_description = f.read() 13 | 14 | __version__ = "0.0.0" 15 | exec(open("xmixers/version.py").read()) 16 | setup( 17 | name="xmixers", 18 | version=__version__, 19 | description="Xmixers: A collection of SOTA efficient token/channel mixers", 20 | long_description=long_description, 21 | long_description_content_type="text/markdown", 22 | url="https://github.com/Doraemonzzz/xmixers", 23 | author="Doraemonzzz", 24 | author_email="doraemon_zzz@163.com", 25 | classifiers=[ 26 | # How mature is this project? Common values are 27 | # 3 - Alpha 28 | # 4 - Beta 29 | # 5 - Production/Stable 30 | "Development Status :: 4 - Beta", 31 | "Intended Audience :: Education", 32 | "Intended Audience :: Science/Research", 33 | "License :: OSI Approved :: Apache Software License", 34 | "Programming Language :: Python :: 3.7", 35 | "Programming Language :: Python :: 3.8", 36 | "Programming Language :: Python :: 3.9", 37 | "Programming Language :: Python :: 3.10", 38 | "Programming Language :: Python :: 3.11", 39 | "Topic :: Scientific/Engineering", 40 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 41 | "Topic :: Software Development", 42 | "Topic :: Software Development :: Libraries", 43 | "Topic :: Software Development :: Libraries :: Python Modules", 44 | ], 45 | # Note that this is a string of words separated by whitespace, not a list. 46 | keywords="", 47 | packages=find_packages(exclude=["tests", "results"]), 48 | include_package_data=True, 49 | install_requires=["torch >= 1.7"], 50 | python_requires=">=3.7", 51 | ) 52 | -------------------------------------------------------------------------------- /xmixers/modules/pes/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | # forgetting transformer: https://openreview.net/pdf?id=q2Lnyegkr8 8 | def get_log_slopes_general(d, n_min=2, n_max=256): 9 | log_n_min = math.log(n_min) 10 | log_n_max = math.log(n_max) 11 | n_list = [ 12 | math.exp((log_n_min + (log_n_max - log_n_min) * i / (d - 1))) for i in range(d) 13 | ] 14 | # exp(log_slope * n) = 1 / e => log_slope * n = -1 => log_slope = -1 / n 15 | log_slope_list = [-1 / n for n in n_list] 16 | return torch.tensor(np.array(log_slope_list), dtype=torch.float32) 17 | 18 | 19 | # alibi: https://arxiv.org/abs/2108.12409 20 | def get_log_slopes_power_of_2(d): 21 | start = 2 ** (-(2 ** -(math.log2(d) - 3))) 22 | ratio = start 23 | return [start * ratio**i for i in range(d)] 24 | 25 | 26 | # alibi: https://arxiv.org/abs/2108.12409 27 | def get_log_slopes(d): 28 | if math.log2(d).is_integer(): 29 | return torch.tensor( 30 | -np.array(get_log_slopes_power_of_2(d)), dtype=torch.float32 31 | ) # In the paper, we only train models that have 2^a heads for some a. This function has 32 | else: # some good properties that only occur when the input is a power of 2. 
To maintain that even
33 |         closest_power_of_2 = 2 ** math.floor(
34 |             math.log2(d)
35 |         )  # when the number of heads is not a power of 2, we use this workaround.
36 |         return torch.tensor(
37 |             -np.array(
38 |                 (
39 |                     get_log_slopes_power_of_2(closest_power_of_2)
40 |                     + get_log_slopes_power_of_2(2 * closest_power_of_2)[0::2][: d - closest_power_of_2]
41 |                 )
42 |             ),
43 |             dtype=torch.float32,
44 |         )
45 | 
46 | 
47 | if __name__ == "__main__":
48 |     d = 8
49 |     log_slope_alibi = get_log_slopes(d)
50 |     log_slope_general = get_log_slopes_general(d, 2, 256)
51 | 
52 |     print(log_slope_alibi)
53 |     print(log_slope_general)
54 | 
55 |     print(np.linalg.norm(log_slope_alibi - log_slope_general))
56 | 
--------------------------------------------------------------------------------
/xmixers/ops/long_conv_1d.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | 
 4 | def long_conv_1d_op_naive(x: torch.Tensor, w: torch.Tensor, dim: int) -> torch.Tensor:
 5 |     """
 6 |     x: b, n1, ... nk, d
 7 |     w: ni, d
 8 |     w0, w1, ... , w(n-1) for causal
 9 |     w0, w1, ... , w(n-1), w0, w(-(n-1)), ... , w(-1) for non causal
10 |     dim: i
11 |     """
12 |     # other dtype have numeric error
13 |     assert w.dtype == torch.float32
14 |     n = x.shape[dim]
15 |     if w.shape[0] == n:  # causal situation
16 |         w = torch.cat([w, torch.zeros_like(w).to(w)], dim=0)
17 | 
18 |     zero = w[:1]
19 |     pos = w[1:n]
20 |     neg = w[n + 1 :]
21 | 
22 |     c = torch.cat([zero, pos], dim=0)
23 |     r = torch.cat([zero, neg.flip(0)], dim=0)
24 |     vals = torch.cat([r, c[1:].flip(0)], dim=0)
25 |     i, j = torch.ones(n, n).nonzero().T
26 |     T = vals[j - i].reshape(n, n, -1)
27 | 
28 |     x = x.transpose(0, dim)
29 | 
30 |     y = (
31 |         torch.einsum("n m d, m ... d -> n ... d", T.float(), x.float())
32 |         .transpose(0, dim)
33 |         .to(x.dtype)
34 |     )
35 | 
36 |     return y
37 | 
38 | 
39 | def long_conv_1d_op(x: torch.Tensor, w: torch.Tensor, dim: int) -> torch.Tensor:
40 |     """
41 |     x: b, n1, ... nk, d
42 |     w: ni, d
43 |     w0, w1, ... , w(n-1) for causal
44 |     w0, w1, ... , w(n-1), w0, w(-(n-1)), ...
, w(-1) for non causal 45 | dim: i 46 | """ 47 | # other dtype have numeric error 48 | assert w.dtype == torch.float32 49 | m = len(x.shape) 50 | if dim < 0: 51 | dim += m 52 | n = x.shape[dim] 53 | x_fft = torch.fft.rfft(x.float(), 2 * n, dim=dim) 54 | # convert w to a shape that can be broadcast to x 55 | for _ in range(dim): 56 | w = w.unsqueeze(0) 57 | for _ in range(m - 2 - dim): 58 | w = w.unsqueeze(-2) 59 | 60 | w_fft = torch.fft.rfft(w.float(), 2 * n, dim=dim) 61 | y_fft = x_fft * w_fft 62 | 63 | index = [slice(None)] * m 64 | index[dim] = slice(0, n) 65 | y = torch.fft.irfft(y_fft, 2 * n, dim=dim)[tuple(index)].to(x.dtype) 66 | 67 | return y 68 | -------------------------------------------------------------------------------- /xmixers/modules/activations.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from xmixers.utils import XMIXERS_DEBUG, logger, print_params 8 | 9 | 10 | def get_activation_fn(activation: str) -> Callable[[torch.Tensor], torch.Tensor]: 11 | if XMIXERS_DEBUG: 12 | logger.info(f"activation: {activation}") 13 | if activation == "gelu": 14 | return F.gelu 15 | elif activation == "relu": 16 | return F.relu 17 | elif activation == "elu": 18 | return F.elu 19 | elif activation == "sigmoid": 20 | return F.sigmoid 21 | elif activation == "exp": 22 | return torch.exp 23 | elif activation == "leak": 24 | return F.leaky_relu 25 | elif activation == "relu2": 26 | 27 | def f(x): 28 | return F.relu(x) ** 2 29 | 30 | return f 31 | elif activation == "1+elu": 32 | 33 | def f(x): 34 | return 1 + F.elu(x) 35 | 36 | return f 37 | elif activation == "2+elu": 38 | 39 | def f(x): 40 | return 2 + F.elu(x) 41 | 42 | return f 43 | elif activation in ["swish", "silu"]: 44 | return F.silu 45 | elif activation == "softmax_1": 46 | 47 | def f(x): 48 | return F.softmax(x, dim=-1) 49 | 50 | return f 51 | elif activation == "softmax_2": 52 | 53 | def f(x): 54 | return F.softmax(x, dim=-2) 55 | 56 | return f 57 | else: 58 | return lambda x: x 59 | 60 | 61 | class ActLayer(nn.Module): 62 | def __init__( 63 | self, 64 | activation: str, 65 | ) -> None: 66 | super().__init__() 67 | if XMIXERS_DEBUG: 68 | # get local varables 69 | params = locals() 70 | # print params 71 | print_params(**params) 72 | 73 | self.activation = activation 74 | self.f = get_activation_fn(activation) 75 | 76 | def forward(self, x): 77 | return self.f(x) 78 | 79 | def extra_repr(self): 80 | return self.activation.lower() 81 | -------------------------------------------------------------------------------- /xmixers/modules/pes/sin_cos_pe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import pack 4 | 5 | 6 | class SinCosPe(nn.Module): 7 | def __init__( 8 | self, 9 | embed_dim: int, 10 | base: int = 10000, 11 | ): 12 | super().__init__() 13 | 14 | theta = ( 15 | base 16 | ** ( 17 | -2 / embed_dim * torch.arange(embed_dim // 2, dtype=torch.int64) 18 | ).float() 19 | ) 20 | self.register_buffer("theta", theta, persistent=False) 21 | self.embed_dim = embed_dim 22 | self.pe = torch.empty(0) 23 | 24 | def extra_repr(self): 25 | return f"embed_dim={self.embed_dim}" 26 | 27 | def get_pe(self, x, shape=None): 28 | # x: b, ... 
, d 29 | # compute index 30 | if shape is None: 31 | shape = x.shape[1:-1] 32 | m = len(shape) 33 | array = [ 34 | torch.arange(n, dtype=torch.int64, device=torch.cuda.current_device()) 35 | for n in shape 36 | ] 37 | grid = torch.meshgrid(array) 38 | index = torch.stack(grid, dim=-1) 39 | 40 | # compute theta 41 | d = self.embed_dim // 2 // m 42 | 43 | theta = [] 44 | for i in range(m): 45 | theta.append(index[..., i : i + 1] * self.theta[:d]) 46 | 47 | theta = torch.cat(theta, dim=-1) 48 | 49 | # compute pe 50 | pe = torch.cat([torch.sin(theta), torch.cos(theta)], dim=-1) 51 | 52 | if len(x.shape) == 3: 53 | # b, n, d case 54 | pe, ps = pack([pe], "* d") 55 | 56 | self.pe = pe 57 | 58 | def forward(self, x, shape=None, offset=0): 59 | if self.pe.shape[0] == 0 or self.pe.shape[0] < (offset + x.shape[-2]): 60 | assert len(x.shape) == 3, "only support 1d case" 61 | self.get_pe(x, [offset + x.shape[-2]]) 62 | start = offset 63 | end = offset + x.shape[1] 64 | x = ( 65 | x 66 | + self.pe[ 67 | start:end, 68 | ] 69 | ) 70 | 71 | return x 72 | -------------------------------------------------------------------------------- /xmixers/modules/pes/mlp_pe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import pack 4 | 5 | 6 | class MlpPe(nn.Module): 7 | def __init__( 8 | self, 9 | embed_dim: int, 10 | base: int = 10000, 11 | bias: bool = False, 12 | ): 13 | super().__init__() 14 | theta = ( 15 | base 16 | ** ( 17 | -2 / embed_dim * torch.arange(embed_dim // 2, dtype=torch.int64) 18 | ).float() 19 | ) 20 | self.linear = nn.Linear(embed_dim, embed_dim, bias=bias) 21 | self.register_buffer("theta", theta, persistent=False) 22 | self.embed_dim = embed_dim 23 | self.pe = torch.empty(0) 24 | 25 | def get_pe(self, x, shape=None): 26 | # x: b, ... 
, d 27 | # compute index 28 | if shape is None: 29 | shape = x.shape[1:-1] 30 | m = len(shape) 31 | array = [ 32 | torch.arange(n, dtype=torch.int64, device=torch.cuda.current_device()) 33 | for n in shape 34 | ] 35 | grid = torch.meshgrid(array) 36 | index = torch.stack(grid, dim=-1) 37 | 38 | # compute theta 39 | d = self.embed_dim // 2 // m 40 | 41 | theta = [] 42 | for i in range(m): 43 | theta.append(index[..., i : i + 1] * self.theta[:d]) 44 | 45 | theta = torch.cat(theta, dim=-1) 46 | 47 | # compute pe 48 | pe = torch.cat([torch.sin(theta), torch.cos(theta)], dim=-1) 49 | 50 | if len(x.shape) == 3: 51 | # b, n, d case 52 | pe, ps = pack([pe], "* d") 53 | 54 | self.pe = pe 55 | 56 | def forward(self, x, shape=None, offset=0): 57 | if self.pe.shape[0] == 0 or self.pe.shape[0] < (offset + x.shape[-2]): 58 | assert len(x.shape) == 3, "only support 1d case" 59 | self.get_pe(x, [offset + x.shape[-2]]) 60 | start = offset 61 | end = offset + x.shape[1] 62 | pe = self.linear( 63 | self.pe[ 64 | start:end, 65 | ] 66 | ) 67 | x = x + pe 68 | 69 | return x 70 | -------------------------------------------------------------------------------- /scripts/linear_decay_transformer/analysis_log_decay.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | 6 | def analyze_log_decay(base_dir, model_type, model_name): 7 | """ 8 | Analyze log decay data for all layers and generate visualizations 9 | 10 | Parameters: 11 | base_dir: Base directory containing all layer data 12 | """ 13 | print("model_type model_name layer_idx median 1_4_per 3_4_per mean max min std") 14 | total_data = [] 15 | for file in os.listdir(base_dir): 16 | if file.endswith(".npy"): 17 | layer_idx = int(file.split("_")[-1].split(".")[0]) 18 | file_path = os.path.join(base_dir, file) 19 | data = np.exp(np.load(file_path))[0] 20 | total_data.append(data) 21 | 22 | print( 23 | model_type, 24 | model_name, 25 | layer_idx, 26 | np.median(data), 27 | np.percentile(data, 25), 28 | np.percentile(data, 75), 29 | np.mean(data), 30 | np.max(data), 31 | np.min(data), 32 | np.std(data), 33 | ) 34 | 35 | total_data = np.concatenate(total_data) 36 | print( 37 | model_type, 38 | model_name, 39 | -1, 40 | np.median(total_data), 41 | np.percentile(total_data, 25), 42 | np.percentile(total_data, 75), 43 | np.mean(total_data), 44 | np.max(total_data), 45 | np.min(total_data), 46 | np.std(total_data), 47 | ) 48 | 49 | 50 | if __name__ == "__main__": 51 | import argparse 52 | 53 | parser = argparse.ArgumentParser(description="Analyze log decay data") 54 | parser.add_argument( 55 | "--base_dir", 56 | type=str, 57 | required=True, 58 | help="Base directory containing all layer data", 59 | ) 60 | parser.add_argument("--model_type", type=str, required=True, help="Model type") 61 | parser.add_argument("--model_name", type=str, required=True, help="Model name") 62 | 63 | args = parser.parse_args() 64 | 65 | stats_df = analyze_log_decay(args.base_dir, args.model_type, args.model_name) 66 | -------------------------------------------------------------------------------- /xmixers/modules/normalizations/group_norm.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | from torch import Size 7 | 8 | _shape_t = Union[int, List[int], Size] 9 | 10 | 11 | class GroupNorm(torch.nn.Module): 12 | def __init__( 13 | self, 14 | num_channels: int, 15 | 
num_groups: int, 16 | eps: float = 1e-5, 17 | affine: bool = True, 18 | bias: bool = True, 19 | device=None, 20 | dtype=None, 21 | **kwargs 22 | ) -> None: 23 | factory_kwargs = {"device": device, "dtype": dtype} 24 | super().__init__() 25 | if num_channels % num_groups != 0: 26 | raise ValueError("num_channels must be divisible by num_groups") 27 | 28 | self.num_groups = num_groups 29 | self.num_channels = num_channels 30 | self.eps = eps 31 | self.affine = affine 32 | if self.affine: 33 | self.weight = nn.Parameter(torch.empty(num_channels, **factory_kwargs)) 34 | if bias: 35 | self.bias = nn.Parameter(torch.empty(num_channels, **factory_kwargs)) 36 | else: 37 | self.register_parameter("bias", None) 38 | else: 39 | self.register_parameter("weight", None) 40 | self.register_parameter("bias", None) 41 | 42 | self._init_weights() 43 | 44 | def _init_weights(self) -> None: 45 | if self.affine: 46 | init.ones_(self.weight) 47 | if self.bias is not None: 48 | init.zeros_(self.bias) 49 | 50 | def forward(self, x, residual=None, return_residual=False): 51 | return group_norm_fn( 52 | x=x, 53 | weight=self.weight, 54 | bias=self.bias, 55 | dim=x.shape[-1], 56 | eps=self.eps, 57 | residual=residual, 58 | return_residual=return_residual, 59 | num_groups=self.num_groups, 60 | ) 61 | 62 | def extra_repr(self) -> str: 63 | return "num_groups={num_groups}, num_channels={num_channels}, eps={eps}, affine={affine}".format( 64 | **self.__dict__ 65 | ) 66 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/vanilla_attention/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from xopes.ops import cumsum_fn 7 | 8 | try: 9 | from flash_attn import flash_attn_func 10 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input 11 | except: 12 | flash_attn_func = None 13 | index_first_axis = None 14 | pad_input = None 15 | unpad_input = None 16 | 17 | _pad_input = pad_input 18 | 19 | # credit to: https://github.com/fla-org/flash-linear-attention/blob/main/fla/layers/utils.py 20 | def _unpad_input( 21 | q: torch.Tensor, 22 | states: Tuple[torch.Tensor], 23 | attention_mask: torch.Tensor, 24 | q_len: int, 25 | keepdim: bool = False, 26 | ): 27 | seqlens = attention_mask.sum(-1, dtype=torch.int32) 28 | indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 29 | max_seqlen_k = seqlens.max().item() 30 | cu_seqlens_k = F.pad(cumsum_fn(seqlens, dim=0), (1, 0)) 31 | 32 | num_heads = q.shape[-2] 33 | batch_size, seq_len, num_key_value_heads, head_dim = states[0].shape 34 | 35 | states = ( 36 | index_first_axis(rearrange(state, "b s ... -> (b s) ..."), indices_k) 37 | for state in states 38 | ) 39 | 40 | if q_len == seq_len: 41 | q = index_first_axis( 42 | q.reshape(batch_size * seq_len, num_heads, head_dim), indices_k 43 | ) 44 | cu_seqlens_q = cu_seqlens_k 45 | max_seqlen_q = max_seqlen_k 46 | indices_q = indices_k 47 | elif q_len == 1: 48 | max_seqlen_q = 1 49 | # There is a memcpy here, that is very bad. 50 | cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) 51 | indices_q = cu_seqlens_q[:-1] 52 | q = q.squeeze(1) 53 | else: 54 | # The -q_len: slice assumes left padding. 
55 | attention_mask = attention_mask[:, -q_len:] 56 | q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask) 57 | 58 | return ( 59 | q, 60 | states, 61 | indices_q, 62 | (cu_seqlens_q, cu_seqlens_k), 63 | (max_seqlen_q, max_seqlen_k), 64 | ) 65 | -------------------------------------------------------------------------------- /xmixers/utils/mask_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from xopes.ops import cumsum_fn 4 | 5 | try: 6 | from flash_attn import flash_attn_func 7 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input 8 | except: 9 | flash_attn_func = None 10 | index_first_axis = None 11 | pad_input = None 12 | unpad_input = None 13 | 14 | 15 | def attn_mask_to_cu_seqlens(attention_mask): 16 | seqlens = attention_mask.sum(-1, dtype=torch.int32) 17 | cu_seqlens = F.pad(cumsum_fn(seqlens, dim=0), (1, 0)) 18 | return cu_seqlens 19 | 20 | 21 | # credit to: https://github.com/fla-org/flash-linear-attention/blob/main/fla/layers/attn.py 22 | def _upad_input(q, k, v, attention_mask, q_len): 23 | seqlens = attention_mask.sum(-1, dtype=torch.int32) 24 | indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 25 | max_seqlen_k = seqlens.max().item() 26 | cu_seqlens_k = F.pad(cumsum_fn(seqlens, dim=0), (1, 0)) 27 | 28 | num_heads = q.shape[-2] 29 | batch_size, seq_len, num_key_value_heads, head_dim = k.shape 30 | 31 | k = index_first_axis( 32 | k.reshape(batch_size * seq_len, num_key_value_heads, -1), indices_k 33 | ) 34 | v = index_first_axis( 35 | v.reshape(batch_size * seq_len, num_key_value_heads, -1), indices_k 36 | ) 37 | if q_len == seq_len: 38 | q = index_first_axis(q.reshape(batch_size * seq_len, num_heads, -1), indices_k) 39 | cu_seqlens_q = cu_seqlens_k 40 | max_seqlen_q = max_seqlen_k 41 | indices_q = indices_k 42 | elif q_len == 1: 43 | max_seqlen_q = 1 44 | # There is a memcpy here, that is very bad. 45 | cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) 46 | indices_q = cu_seqlens_q[:-1] 47 | q = q.squeeze(1) 48 | else: 49 | # The -q_len: slice assumes left padding. 
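        # (Left padding means pad tokens sit at the start of each row and real
        # tokens at the end, so the last q_len mask columns line up with the
        # query positions; right-padded inputs would need a different slice.)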
50 | attention_mask = attention_mask[:, -q_len:] 51 | q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask) 52 | 53 | return ( 54 | q, 55 | k, 56 | v, 57 | indices_q, 58 | (cu_seqlens_q, cu_seqlens_k), 59 | (max_seqlen_q, max_seqlen_k), 60 | ) 61 | -------------------------------------------------------------------------------- /xmixers/modules/pes/md/md_tpe.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from einops import pack, unpack 6 | from transformers.cache_utils import Cache 7 | 8 | from ..tpe import Tpe 9 | 10 | 11 | class MdTpe(nn.Module): 12 | def __init__( 13 | self, 14 | embed_dim: int, 15 | num_heads: int, 16 | dims: List[int] = [-2], 17 | bias: bool = False, 18 | layer_idx: int = 0, 19 | token_mixer_norm_type: str = "rmsnorm", 20 | token_mixer_init_type: int = 4, 21 | rescale_type: int = 2, 22 | num_layers: int = 12, 23 | init_std: float = 0.02, 24 | gain: float = 0.01, 25 | **kwargs, 26 | ): 27 | super().__init__() 28 | self.dims = dims 29 | self.tpes = nn.ModuleList([]) 30 | for dim in dims: 31 | self.tpes.append( 32 | Tpe( 33 | embed_dim=embed_dim, 34 | num_heads=num_heads, 35 | dim=dim, 36 | bias=bias, 37 | layer_idx=layer_idx, 38 | token_mixer_norm_type=token_mixer_norm_type, 39 | token_mixer_init_type=token_mixer_init_type, 40 | rescale_type=rescale_type, 41 | num_layers=num_layers, 42 | init_std=init_std, 43 | gain=gain, 44 | **kwargs, 45 | ) 46 | ) 47 | 48 | def forward( 49 | self, 50 | x, 51 | attention_mask: Optional[torch.Tensor] = None, # (b, m) 52 | past_key_values: Optional[Cache] = None, 53 | use_cache: Optional[bool] = False, 54 | **kwargs, 55 | ): 56 | o = 0 57 | for i, tpe in enumerate(self.tpes): 58 | o += self.forward_tpe(x, tpe, dim=self.dims[i]) 59 | return o 60 | 61 | def forward_tpe(self, x, tpe, dim=-2): 62 | if dim != -2: 63 | x = x.transpose(dim, -2) 64 | 65 | x, ps = pack([x], "* n d") 66 | x = tpe(x)[0] 67 | x = unpack(x, ps, "* n d")[0] 68 | 69 | if dim != -2: 70 | x = x.transpose(dim, -2) 71 | 72 | return x 73 | -------------------------------------------------------------------------------- /xmixers/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .chunk_linear_transformer import ( 2 | ChunkRnnConfig, 3 | ChunkRnnForCausalLM, 4 | ChunkRnnLayer, 5 | ChunkRnnModel, 6 | ) 7 | from .hybrid import ( 8 | LmHeadHybridConfig, 9 | LmHeadHybridForCausalLM, 10 | LmHeadHybridLayer, 11 | LmHeadHybridModel, 12 | NaiveHybridConfig, 13 | NaiveHybridForCausalLM, 14 | NaiveHybridLayer, 15 | NaiveHybridModel, 16 | ) 17 | from .linear_transformer import ( 18 | DecayLinearTransformerConfig, 19 | DecayLinearTransformerForCausalLM, 20 | DecayLinearTransformerModel, 21 | DeltaNetConfig, 22 | DeltaNetForCausalLM, 23 | DeltaNetLayer, 24 | DeltaNetModel, 25 | DenseRnnConfig, 26 | DenseRnnForCausalLM, 27 | DenseRnnModel, 28 | GsaConfig, 29 | GsaForCausalLM, 30 | GsaModel, 31 | Hgrn1Config, 32 | Hgrn1ForCausalLM, 33 | Hgrn1Model, 34 | Hgrn2Config, 35 | Hgrn2ForCausalLM, 36 | Hgrn2Model, 37 | Hgrn3Config, 38 | Hgrn3ForCausalLM, 39 | Hgrn3Model, 40 | ImplicitLinearTransformerConfig, 41 | ImplicitLinearTransformerForCausalLM, 42 | ImplicitLinearTransformerLayer, 43 | ImplicitLinearTransformerModel, 44 | LightNetConfig, 45 | LightNetForCausalLM, 46 | LightNetModel, 47 | LinearTransformerConfig, 48 | LinearTransformerForCausalLM, 49 | LinearTransformerModel, 50 | 
Mamba2ForCausalLM, 51 | Mamba2Layer, 52 | Mamba2Model, 53 | Mamba2XmixersConfig, 54 | MesaNetConfig, 55 | MesaNetForCausalLM, 56 | MesaNetLayer, 57 | MesaNetModel, 58 | MetaLaConfig, 59 | MetaLaForCausalLM, 60 | MetaLaModel, 61 | PolarRnnConfig, 62 | PolarRnnForCausalLM, 63 | PolarRnnModel, 64 | TnlConfig, 65 | TnlForCausalLM, 66 | TnlModel, 67 | TTTConfig, 68 | TTTForCausalLM, 69 | TTTLayer, 70 | TTTModel, 71 | ) 72 | from .long_conv import * 73 | from .transformer import ( 74 | FlexGPTConfig, 75 | FlexGPTForCausalLM, 76 | FlexGPTLayer, 77 | FlexGPTModel, 78 | GPTConfig, 79 | GPTForCausalLM, 80 | GPTLayer, 81 | GPTModel, 82 | LLaMAConfig, 83 | LLaMAForCausalLM, 84 | LLaMALayer, 85 | LLaMAModel, 86 | nGPTConfig, 87 | nGPTForCausalLM, 88 | nGPTLayer, 89 | nGPTModel, 90 | ) 91 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/long_conv/rpe.py: -------------------------------------------------------------------------------- 1 | """ 2 | Relative Position Encoder in https://arxiv.org/pdf/2305.04749.pdf 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from xmixers.modules import get_norm_fn 9 | from xmixers.modules.activations import ActLayer 10 | from xmixers.utils import XMIXERS_DEBUG, print_params 11 | 12 | 13 | class Rpe(nn.Module): 14 | def __init__( 15 | self, 16 | in_dim: int, 17 | feature_dim: int, 18 | out_dim: int, 19 | activation: str = "silu", 20 | bias: bool = False, 21 | rpe_layers: int = 3, 22 | norm_type: str = "layernorm", 23 | *args, 24 | **kwargs, 25 | ) -> None: 26 | super().__init__() 27 | 28 | if XMIXERS_DEBUG: 29 | # get local varables 30 | params = locals() 31 | # print params 32 | print_params(**params) 33 | 34 | self.in_dim = in_dim 35 | if in_dim > 1: 36 | theta = 10000 ** (-2 / in_dim * torch.arange(in_dim // 2)).reshape(1, -1) 37 | self.register_buffer("theta", theta) 38 | 39 | self.pos_proj = nn.Linear(in_dim, feature_dim, bias=bias) 40 | self.layers = nn.ModuleList([]) 41 | for _ in range(rpe_layers): 42 | self.layers.append( 43 | nn.Sequential( 44 | get_norm_fn(norm_type)(feature_dim), 45 | ActLayer(activation), 46 | nn.Linear(feature_dim, feature_dim, bias=bias), 47 | ) 48 | ) 49 | self.out = nn.Sequential( 50 | get_norm_fn(norm_type)(feature_dim), 51 | ActLayer(activation), 52 | nn.Linear(feature_dim, out_dim, bias=bias), 53 | ) 54 | 55 | def get_feature(self, index): 56 | if self.in_dim > 1: 57 | theta = index * self.theta 58 | x = torch.cat([torch.cos(theta), torch.sin(theta)], dim=-1) 59 | else: 60 | x = index 61 | 62 | return x 63 | 64 | def forward(self, index): 65 | input = self.get_feature(index).to(self.pos_proj.weight.dtype) 66 | x = self.pos_proj(input) 67 | for m in self.layers: 68 | x = m(x) + x 69 | x = self.out(x) 70 | 71 | return x 72 | -------------------------------------------------------------------------------- /xmixers/modules/normalizations/dynamic_tanh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributed.tensor import DTensor 4 | 5 | 6 | class DynamicTanh(nn.Module): 7 | def __init__(self, normalized_shape, alpha_init_value=0.5, **kwargs): 8 | super().__init__() 9 | self.normalized_shape = normalized_shape 10 | self.alpha_init_value = alpha_init_value 11 | 12 | self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) 13 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 14 | 15 | self._init_weights() 16 | 17 | def _init_weights(self): 18 | alpha = 
torch.ones(1) * self.alpha_init_value 19 | if isinstance(self.alpha, DTensor): 20 | self.alpha.data.copy_( 21 | DTensor.from_local(alpha, device_mesh=self.alpha.device_mesh) 22 | ) 23 | else: 24 | self.alpha.data.copy_(alpha) 25 | nn.init.ones_(self.weight) 26 | 27 | def forward(self, x): 28 | # TODO: add a fusion here 29 | return torch.tanh(self.alpha * x) * self.weight 30 | 31 | def extra_repr(self): 32 | return f"normalized_shape={self.normalized_shape}, alpha_init_value={self.alpha_init_value}" 33 | 34 | 35 | class DynamicTanhFusedGate(nn.Module): 36 | def __init__(self, normalized_shape, alpha_init_value=0.5, **kwargs): 37 | super().__init__() 38 | self.normalized_shape = normalized_shape 39 | self.alpha_init_value = alpha_init_value 40 | 41 | self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) 42 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 43 | 44 | self._init_weights() 45 | 46 | def _init_weights(self): 47 | alpha = torch.ones(1) * self.alpha_init_value 48 | if isinstance(self.alpha, DTensor): 49 | self.alpha.data.copy_( 50 | DTensor.from_local(alpha, device_mesh=self.alpha.device_mesh) 51 | ) 52 | else: 53 | self.alpha.data.copy_(alpha) 54 | nn.init.ones_(self.weight) 55 | 56 | def forward(self, x, gate): 57 | # TODO: add a fusion here 58 | return torch.tanh(self.alpha * x) * self.weight * gate 59 | 60 | def extra_repr(self): 61 | return f"normalized_shape={self.normalized_shape}, alpha_init_value={self.alpha_init_value}" 62 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/gsa/configuration_gsa.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ Gsa configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class GsaConfig(PretrainedConfig): 11 | model_type = "gsa" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | ##### token mixer config 25 | token_mixer_type="gsa", 26 | embed_dim=1024, 27 | num_heads=8, 28 | bias=False, 29 | gate_act="sigmoid", 30 | gate_pos="pre", 31 | token_mixer_norm_type="rmsnorm", 32 | ##### channel mixer config 33 | channel_mixer_type="glu", 34 | mid_dim=1024, 35 | channel_mixer_activation="silu", 36 | use_gate_linear=True, 37 | ##### others 38 | max_position_embeddings=1024, 39 | num_layers=24, 40 | use_output_gate=False, 41 | norm_type="rmsnorm", 42 | q_activation="silu", 43 | k_activation="silu", 44 | num_slots=64, 45 | causal=True, 46 | use_embed_scale=False, 47 | ce_type="xopes_flce", 48 | pad_embed_dim=True, 49 | ##### init 50 | init_type=1, 51 | token_mixer_init_type=4, 52 | rescale_type=2, 53 | gain=0.01, 54 | channel_mixer_init_type=0, 55 | **kwargs, 56 | ): 57 | super().__init__( 58 | pad_token_id=pad_token_id, 59 | bos_token_id=bos_token_id, 60 | eos_token_id=eos_token_id, 61 | tie_word_embeddings=tie_word_embeddings, 62 | **kwargs, 63 | ) 64 | for key, value in locals().items(): 65 | if key not in [ 66 | "self", 67 | "kwargs", 68 | "__class__", 69 | "pad_token_id", 70 | "bos_token_id", 71 | "eos_token_id", 72 | "tie_word_embeddings", 73 | ]: 74 | setattr(self, key, value) 75 | 
-------------------------------------------------------------------------------- /xmixers/models/transformer/flex_gpt/configuration_flex_gpt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ FlexGPT configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class FlexGPTConfig(PretrainedConfig): 11 | model_type = "flex_gpt" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | ##### token mixer config 25 | token_mixer_type="flex_attn", 26 | embed_dim=1024, 27 | num_heads=8, 28 | kv_heads=-1, 29 | bias=False, 30 | window_size=-1, 31 | token_mixer_norm_type="grouprmsnorm", 32 | use_offset=False, 33 | threshold=0.99, 34 | ###### channel mixer config 35 | channel_mixer_type="glu", 36 | mid_dim=1024, 37 | channel_mixer_activation="silu", 38 | use_gate_linear=True, 39 | ##### others 40 | max_position_embeddings=1024, 41 | num_layers=24, 42 | norm_type="rmsnorm", 43 | use_embed_scale=False, 44 | ce_type="xopes_flce", 45 | fuse_norm_add=False, 46 | rpe_type=0, # 0: no rpe, 1: alibi 47 | n_min=2, 48 | n_max=256, 49 | ##### init 50 | init_type=1, 51 | token_mixer_init_type=4, 52 | rescale_type=2, 53 | gain=0.01, 54 | pad_embed_dim=True, 55 | **kwargs, 56 | ): 57 | super().__init__( 58 | pad_token_id=pad_token_id, 59 | bos_token_id=bos_token_id, 60 | eos_token_id=eos_token_id, 61 | tie_word_embeddings=tie_word_embeddings, 62 | **kwargs, 63 | ) 64 | for key, value in locals().items(): 65 | if key not in [ 66 | "self", 67 | "kwargs", 68 | "__class__", 69 | "pad_token_id", 70 | "bos_token_id", 71 | "eos_token_id", 72 | "tie_word_embeddings", 73 | ]: 74 | setattr(self, key, value) 75 | -------------------------------------------------------------------------------- /xmixers/models/transformer/ngpt/configuration_ngpt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # nGLU: https://arxiv.org/pdf/2410.01131 3 | """ nGPT configuration""" 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | from transformers.utils import logging 7 | 8 | logger = logging.get_logger(__name__) 9 | 10 | 11 | class nGPTConfig(PretrainedConfig): 12 | model_type = "ngpt" 13 | keys_to_ignore_at_inference = ["past_key_values"] 14 | 15 | def __init__( 16 | self, 17 | pad_token_id=0, 18 | bos_token_id=1, 19 | eos_token_id=2, 20 | vocab_size=64000, 21 | use_cache=True, 22 | init_std=0.02, 23 | tie_word_embeddings=False, 24 | ##### model config 25 | # attention config 26 | embed_dim=1024, 27 | num_heads=8, 28 | kv_heads=-1, 29 | bias=False, 30 | base=10000, 31 | ape_type="sincos", 32 | # glu config 33 | mid_dim=1024, 34 | glu_activation="silu", 35 | # others 36 | max_position_embeddings=1024, 37 | num_layers=24, 38 | token_mixer_init_type=0, 39 | init_type=0, 40 | use_embed_scale=False, 41 | pad_embed_dim=True, 42 | **kwargs, 43 | ): 44 | super().__init__( 45 | pad_token_id=pad_token_id, 46 | bos_token_id=bos_token_id, 47 | eos_token_id=eos_token_id, 48 | tie_word_embeddings=tie_word_embeddings, 49 | **kwargs, 50 | ) 51 | assert ape_type in ["sincos", "learnable", "mlp"] 52 | ##### hf origin 53 | self.vocab_size = vocab_size 54 | self.use_cache = 
use_cache 55 | self.init_std = init_std 56 | ##### add 57 | # attention config 58 | self.embed_dim = embed_dim 59 | self.num_heads = num_heads 60 | self.kv_heads = kv_heads 61 | self.bias = bias 62 | self.base = base 63 | self.ape_type = ape_type 64 | # glu config 65 | self.mid_dim = mid_dim 66 | self.glu_activation = glu_activation 67 | # others 68 | self.max_position_embeddings = max_position_embeddings 69 | self.num_layers = num_layers 70 | self.token_mixer_init_type = token_mixer_init_type 71 | self.init_type = init_type 72 | self.use_embed_scale = use_embed_scale 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Xmixers: A collection of SOTA efficient token/channel mixers 2 | 3 |

4 | 💬 Discord • 5 |

6 | 7 | ## Introduction 8 | This repository aims to implement SOTA efficient token/channel mixers. Any technologies related to non-Vanilla Transformer are welcome. If you are interested in this repository, please join our [Discord](https://discord.gg/ZpqcpSDE8g). 9 | 10 | ## Install 11 | 12 | Install Torch(>=2.6.0), [fla](https://github.com/fla-org/flash-linear-attention), [xopes](https://github.com/Doraemonzzz/xopes.git) first, then install xmixers: 13 | 14 | ``` 15 | git clone https://github.com/Doraemonzzz/xmixers.git 16 | cd xmixers 17 | pip install -e . 18 | ``` 19 | 20 | ## Models 21 | 22 | | Paper | Code | Config | 23 | | --------------------------------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | 24 | | Elucidating the Design Space of Decay in Linear Attention | [Link](https://github.com/Doraemonzzz/xmixers/blob/main/xmixers/modules/token_mixers/linear_attention/decay_linear_attention.py), core code: line 247 to 262. | [Core method link](https://github.com/Doraemonzzz/flame/tree/main/configs/xmixers/decay_linear_transformer/decay_parameterize/hgrn3), [Ablation link](https://github.com/Doraemonzzz/flame/tree/main/configs/xmixers/decay_linear_transformer/decay_parameterize) | 25 | | | | | 26 | | | | | 27 | 28 | 29 | 30 | ## Training 31 | 32 | To reproduce the results of the paper, we conducted training using the [flame](https://github.com/Doraemonzzz/flame?tab=readme-ov-file#training-xmixers-models) framework. First, we configured the environment in accordance with flame's requirements, then used flame's training script, and only needed to replace "config" with the corresponding name. 33 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/hgrn1/configuration_hgrn1.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ Hgrn1 configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class Hgrn1Config(PretrainedConfig): 11 | model_type = "hgrn1" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | ##### token mixer config 25 | token_mixer_type="hgru1", 26 | embed_dim=1024, 27 | head_dim=128, 28 | bias=False, 29 | gate_act="sigmoid", 30 | gate_pos="pre", 31 | token_mixer_norm_type="rmsnorm", 32 | ##### channel mixer config 33 | channel_mixer_type="glu", 34 | mid_dim=1024, 35 | channel_mixer_activation="silu", 36 | use_gate_linear=True, 37 | ##### others 38 | max_position_embeddings=1024, 39 | num_layers=24, 40 | use_output_gate=False, 41 | norm_type="rmsnorm", 42 | q_activation="silu", 43 | k_activation="silu", 44 | causal=True, 45 | use_embed_scale=False, 46 | use_dense_memory=False, 47 | beta_activation="silu", 48 | ce_type="xopes_flce", 49 | pad_embed_dim=True, 50 | ##### init 51 | init_type=1, 52 | token_mixer_init_type=4, 53 | rescale_type=2, 54 | gain=0.01, 55 | channel_mixer_init_type=0, 56 | **kwargs, 57 | ): 58 | super().__init__( 59 | pad_token_id=pad_token_id, 60 | bos_token_id=bos_token_id, 61 | eos_token_id=eos_token_id, 62 | tie_word_embeddings=tie_word_embeddings, 63 | **kwargs, 64 | ) 65 | 
for key, value in locals().items(): 66 | if key not in [ 67 | "self", 68 | "kwargs", 69 | "__class__", 70 | "pad_token_id", 71 | "bos_token_id", 72 | "eos_token_id", 73 | "tie_word_embeddings", 74 | ]: 75 | setattr(self, key, value) 76 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/hgrn2/configuration_hgrn2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ Hgrn2 configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class Hgrn2Config(PretrainedConfig): 11 | model_type = "hgrn2" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | ##### token mixer config 25 | token_mixer_type="hgru2", 26 | embed_dim=1024, 27 | num_heads=8, 28 | bias=False, 29 | gate_act="sigmoid", 30 | gate_pos="pre", 31 | token_mixer_norm_type="rmsnorm", 32 | ##### channel mixer config 33 | channel_mixer_type="glu", 34 | mid_dim=1024, 35 | channel_mixer_activation="silu", 36 | use_gate_linear=True, 37 | ##### others 38 | max_position_embeddings=1024, 39 | num_layers=24, 40 | use_output_gate=False, 41 | norm_type="rmsnorm", 42 | q_activation="silu", 43 | k_activation="silu", 44 | causal=True, 45 | use_embed_scale=False, 46 | use_dense_memory=False, 47 | beta_activation="silu", 48 | ce_type="xopes_flce", 49 | pad_embed_dim=True, 50 | ##### init 51 | init_type=1, 52 | token_mixer_init_type=4, 53 | rescale_type=2, 54 | gain=0.01, 55 | channel_mixer_init_type=0, 56 | **kwargs, 57 | ): 58 | super().__init__( 59 | pad_token_id=pad_token_id, 60 | bos_token_id=bos_token_id, 61 | eos_token_id=eos_token_id, 62 | tie_word_embeddings=tie_word_embeddings, 63 | **kwargs, 64 | ) 65 | for key, value in locals().items(): 66 | if key not in [ 67 | "self", 68 | "kwargs", 69 | "__class__", 70 | "pad_token_id", 71 | "bos_token_id", 72 | "eos_token_id", 73 | "tie_word_embeddings", 74 | ]: 75 | setattr(self, key, value) 76 | -------------------------------------------------------------------------------- /xmixers/modules/normalizations/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from xopes.ops.normalize import normalize_fn 6 | 7 | 8 | class NormOp(nn.Module): 9 | def __init__( 10 | self, 11 | norm_type: str, 12 | ): 13 | self.norm_type = norm_type 14 | super().__init__() 15 | 16 | def forward( 17 | self, 18 | x: torch.Tensor, 19 | weight: Optional[torch.Tensor] = None, 20 | bias: Optional[torch.Tensor] = None, 21 | residual: Optional[torch.Tensor] = None, 22 | dim: float = 1.0, 23 | eps: float = 1e-6, 24 | use_mean: bool = False, 25 | num_groups: int = 1, 26 | return_residual: bool = False, 27 | gate: Optional[torch.Tensor] = None, 28 | gate_act: str = "sigmoid", 29 | gate_pos: str = "pre", 30 | ): 31 | if self.norm_type == "layernorm": 32 | c = dim**0.5 33 | use_mean = True 34 | num_groups = 1 35 | elif self.norm_type == "rmsnorm": 36 | bias = None 37 | c = dim**0.5 38 | use_mean = False 39 | num_groups = 1 40 | elif self.norm_type == "srmsnorm": 41 | weight = None 42 | bias = None 43 | c = dim**0.5 44 | use_mean = False 45 | 
num_groups = 1 46 | elif self.norm_type == "groupnorm": 47 | group_size = dim // num_groups 48 | c = group_size**0.5 49 | use_mean = True 50 | elif self.norm_type == "grouprmsnorm": 51 | group_size = dim // num_groups 52 | c = group_size**0.5 53 | use_mean = False 54 | elif self.norm_type == "groupsrmsnorm": 55 | group_size = dim // num_groups 56 | c = group_size**0.5 57 | use_mean = False 58 | else: 59 | raise ValueError(f"Invalid normalization type: {self.norm_type}") 60 | 61 | return normalize_fn( 62 | x=x, 63 | weight=weight, 64 | bias=bias, 65 | residual=residual, 66 | gate=gate, 67 | gate_act=gate_act, 68 | gate_pos=gate_pos, 69 | c=c, 70 | eps=eps, 71 | use_mean=use_mean, 72 | num_groups=num_groups, 73 | return_residual=return_residual, 74 | ) 75 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/hgrn3/configuration_hgrn3.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ Hgrn3 configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class Hgrn3Config(PretrainedConfig): 11 | model_type = "hgrn3" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | ##### token mixer config 25 | embed_dim=1024, 26 | num_heads=8, 27 | bias=False, 28 | scalar_decay=False, 29 | token_mixer_type="hgru3", 30 | gate_act="sigmoid", 31 | gate_pos="pre", 32 | token_mixer_norm_type="rmsnorm", 33 | ##### channel mixer config 34 | mid_dim=1024, 35 | channel_mixer_type="glu", 36 | channel_mixer_activation="silu", 37 | use_gate_linear=True, 38 | ##### others 39 | max_position_embeddings=1024, 40 | num_layers=24, 41 | use_output_gate=True, 42 | norm_type="layernorm", 43 | q_activation="silu", 44 | k_activation="silu", 45 | threshold=0.99, 46 | causal=True, 47 | use_dense_memory=True, 48 | use_embed_scale=False, 49 | ce_type="xopes_flce", 50 | pad_embed_dim=True, 51 | ##### init 52 | init_type=1, 53 | token_mixer_init_type=4, 54 | rescale_type=2, 55 | gain=0.01, 56 | channel_mixer_init_type=0, 57 | **kwargs, 58 | ): 59 | super().__init__( 60 | pad_token_id=pad_token_id, 61 | bos_token_id=bos_token_id, 62 | eos_token_id=eos_token_id, 63 | tie_word_embeddings=tie_word_embeddings, 64 | **kwargs, 65 | ) 66 | for key, value in locals().items(): 67 | if key not in [ 68 | "self", 69 | "kwargs", 70 | "__class__", 71 | "pad_token_id", 72 | "bos_token_id", 73 | "eos_token_id", 74 | "tie_word_embeddings", 75 | ]: 76 | setattr(self, key, value) 77 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/deep_memory/optimizer/sgd.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | 5 | from .base_optimizer import FastWeightOptimizer 6 | from .utils import get_pooling_fn 7 | 8 | 9 | class SGD(FastWeightOptimizer): 10 | def __init__( 11 | self, 12 | fast_weight: Dict[str, torch.Tensor], 13 | lr: float = 1, 14 | wd: float = 0, 15 | momentum: float = 0, 16 | damping: float = 0, 17 | pooling_method: str = "mean", 18 | **kwargs, 19 | ): 20 | self.params_dict = {} 21 | for name, param in fast_weight.items(): 22 | if 
param.requires_grad:
23 |                 self.params_dict[name] = param
24 | 
25 |         self.lr = lr
26 |         self.wd = wd
27 |         self.momentum = momentum
28 |         self.damping = damping
29 |         self.use_momentum = momentum > 0
30 |         self.pooling_fn = get_pooling_fn(pooling_method)
31 | 
32 |         self.setup_state()
33 | 
34 |     def step(
35 |         self,
36 |         lr_dict=None,
37 |         wd_dict=None,
38 |         momentum_dict=None,
39 |         damping_dict=None,
40 |     ):
41 |         if lr_dict is None:
42 |             lr_dict = {}
43 |         if wd_dict is None:
44 |             wd_dict = {}
45 |         if momentum_dict is None:
46 |             momentum_dict = {}
47 |         if damping_dict is None:
48 |             damping_dict = {}
49 | 
50 |         for name, p in self.params_dict.items():
51 |             lr = self.pooling(lr_dict.get(name, self.lr))
52 |             wd = self.pooling(wd_dict.get(name, self.wd))
53 |             momentum = self.pooling(momentum_dict.get(name, self.momentum))
54 |             damping = self.pooling(damping_dict.get(name, self.damping))
55 |             lr = 1
56 | 
57 |             use_weight_decay = (not isinstance(wd, torch.Tensor)) and (wd != 0)
58 |             use_momentum = (not isinstance(momentum, torch.Tensor)) and (momentum != 0)
59 | 
60 |             grad = p.grad
61 |             if use_weight_decay:
62 |                 grad.add_(p.data * wd)
63 | 
64 |             if use_momentum:
65 |                 if self.state[name].get("momentum", None) is None:
66 |                     self.state[name]["momentum"] = torch.zeros_like(grad)
67 | 
68 |                 buf = self.state[name]["momentum"]
69 |                 buf.mul_(momentum).add_(grad * (1 - damping))
70 |                 grad = buf
71 | 
72 |             p.data.add_(-grad * lr)
73 | 
--------------------------------------------------------------------------------
/xmixers/models/linear_transformer/ttt/configuration_ttt.py:
--------------------------------------------------------------------------------
 1 | # coding=utf-8
 2 | """ TTT configuration"""
 3 | 
 4 | from transformers.configuration_utils import PretrainedConfig
 5 | from transformers.utils import logging
 6 | 
 7 | logger = logging.get_logger(__name__)
 8 | 
 9 | 
10 | class TTTConfig(PretrainedConfig):
11 |     model_type = "ttt"
12 |     keys_to_ignore_at_inference = ["past_key_values"]
13 | 
14 |     def __init__(
15 |         self,
16 |         pad_token_id=0,
17 |         bos_token_id=1,
18 |         eos_token_id=2,
19 |         vocab_size=64000,
20 |         use_cache=True,
21 |         init_std=0.02,
22 |         tie_word_embeddings=False,
23 |         ########## model config
24 |         ##### token mixer config
25 |         token_mixer_type="ttt",
26 |         embed_dim=1024,
27 |         num_heads=8,
28 |         bias=False,
29 |         token_mixer_norm_type="rmsnorm",
30 |         q_activation="silu",
31 |         k_activation="silu",
32 |         norm_k=True,
33 |         beta_activation="neg",
34 |         causal=True,
35 |         gate_act="sigmoid",
36 |         gate_pos="pre",
37 |         use_input_gate=False,
38 |         ##### channel mixer config
39 |         channel_mixer_type="glu",
40 |         mid_dim=1024,
41 |         channel_mixer_activation="silu",
42 |         use_gate_linear=True,
43 |         ##### others
44 |         use_initial_state=False,
45 |         max_position_embeddings=1024,
46 |         use_output_gate=True,
47 |         norm_type="rmsnorm",
48 |         num_layers=12,
49 |         use_embed_scale=False,
50 |         ce_type="xopes_flce",
51 |         pad_embed_dim=True,
52 |         ##### init
53 |         init_type=1,
54 |         token_mixer_init_type=4,
55 |         rescale_type=2,
56 |         gain=0.01,
57 |         channel_mixer_init_type=0,
58 |         **kwargs,
59 |     ):
60 |         super().__init__(
61 |             pad_token_id=pad_token_id,
62 |             bos_token_id=bos_token_id,
63 |             eos_token_id=eos_token_id,
64 |             tie_word_embeddings=tie_word_embeddings,
65 |             **kwargs,
66 |         )
67 |         for key, value in locals().items():
68 |             if key not in [
69 |                 "self",
70 |                 "kwargs",
71 |                 "__class__",
72 |                 "pad_token_id",
73 |                 "bos_token_id",
74 |                 "eos_token_id",
75 |                 "tie_word_embeddings",
76 |             ]:
77 |                 setattr(self, key, value)
78 |
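        # The locals() loop above mirrors every remaining constructor argument onto
        # the config instance, so adding a new field only requires extending the
        # signature. Hypothetical usage sketch (the values are illustrative overrides,
        # not recommended settings):
        #   cfg = TTTConfig(embed_dim=512, num_heads=8, num_layers=12)
        #   assert cfg.embed_dim == 512 and cfg.token_mixer_type == "ttt"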
-------------------------------------------------------------------------------- /xmixers/models/transformer/gpt/configuration_gpt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ GPT configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class GPTConfig(PretrainedConfig): 11 | model_type = "gpt" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ##### model config 24 | # attention config 25 | embed_dim=1024, 26 | num_heads=8, 27 | kv_heads=-1, 28 | bias=False, 29 | base=10000, 30 | ape_type="sincos", 31 | # ffn config 32 | mid_dim=1024, 33 | ffn_activation="silu", 34 | # others 35 | max_position_embeddings=1024, 36 | num_layers=24, 37 | norm_type="layernorm", 38 | token_mixer_init_type=0, 39 | init_type=0, 40 | rescale_type=0, 41 | use_embed_scale=False, 42 | pad_embed_dim=True, 43 | **kwargs, 44 | ): 45 | super().__init__( 46 | pad_token_id=pad_token_id, 47 | bos_token_id=bos_token_id, 48 | eos_token_id=eos_token_id, 49 | tie_word_embeddings=tie_word_embeddings, 50 | **kwargs, 51 | ) 52 | assert ape_type in ["sincos", "learnable", "mlp"] 53 | ##### hf origin 54 | self.vocab_size = vocab_size 55 | self.use_cache = use_cache 56 | self.init_std = init_std 57 | ##### add 58 | # attention config 59 | self.embed_dim = embed_dim 60 | self.num_heads = num_heads 61 | self.kv_heads = kv_heads 62 | self.bias = bias 63 | self.base = base 64 | self.ape_type = ape_type 65 | # ffn config 66 | self.mid_dim = mid_dim 67 | self.ffn_activation = ffn_activation 68 | # others 69 | self.max_position_embeddings = max_position_embeddings 70 | self.num_layers = num_layers 71 | self.norm_type = norm_type 72 | self.token_mixer_init_type = token_mixer_init_type 73 | self.init_type = init_type 74 | self.rescale_type = rescale_type 75 | self.use_embed_scale = use_embed_scale 76 | -------------------------------------------------------------------------------- /xmixers/modules/normalizations/layer_norm.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | from typing import List, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | from torch import Size 8 | 9 | _shape_t = Union[int, List[int], Size] 10 | 11 | from .utils import NormOp 12 | 13 | 14 | class LayerNorm(torch.nn.Module): 15 | def __init__( 16 | self, 17 | normalized_shape: _shape_t, 18 | eps: float = 1e-5, 19 | elementwise_affine: bool = True, 20 | bias: bool = True, 21 | device=None, 22 | dtype=None, 23 | **kwargs, 24 | ) -> None: 25 | factory_kwargs = {"device": device, "dtype": dtype} 26 | super().__init__() 27 | if isinstance(normalized_shape, numbers.Integral): 28 | # mypy error: incompatible types in assignment 29 | normalized_shape = (normalized_shape,) # type: ignore[assignment] 30 | self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] 31 | self.eps = eps 32 | self.elementwise_affine = elementwise_affine 33 | if self.elementwise_affine: 34 | self.weight = nn.Parameter( 35 | torch.empty(self.normalized_shape, **factory_kwargs) 36 | ) 37 | if bias: 38 | self.bias = nn.Parameter( 39 | torch.empty(self.normalized_shape, **factory_kwargs) 40 | ) 41 | 
else: 42 | self.register_parameter("bias", None) 43 | else: 44 | self.register_parameter("weight", None) 45 | self.register_parameter("bias", None) 46 | self.op = NormOp(norm_type="layernorm") 47 | 48 | self._init_weights() 49 | 50 | def _init_weights(self) -> None: 51 | if self.elementwise_affine: 52 | init.ones_(self.weight) 53 | if self.bias is not None: 54 | init.zeros_(self.bias) 55 | 56 | def forward(self, x, residual=None, return_residual=False): 57 | return self.op( 58 | x, 59 | self.weight, 60 | self.bias, 61 | residual, 62 | x.shape[-1], 63 | self.eps, 64 | True, 65 | 1, 66 | return_residual, 67 | ) 68 | 69 | def extra_repr(self) -> str: 70 | return ( 71 | "{normalized_shape}, eps={eps}, " 72 | "elementwise_affine={elementwise_affine}".format(**self.__dict__) 73 | ) 74 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/mamba2/configuration_mamba2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ Mamba2 configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class Mamba2XmixersConfig(PretrainedConfig): 11 | model_type = "mamba2_" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | ##### token mixer config 25 | token_mixer_type="mamba2", 26 | embed_dim=1024, 27 | d_state=64, 28 | d_conv=4, 29 | conv_init=None, 30 | expand=2, 31 | headdim=128, 32 | ngroups=1, 33 | A_init_range=(1, 16), 34 | dt_min=0.001, 35 | dt_max=0.1, 36 | dt_init_floor=1e-4, 37 | dt_limit=(0.0, float("inf")), 38 | activation="swish", 39 | bias=False, 40 | conv_bias=True, 41 | chunk_size=256, 42 | layer_idx=0, 43 | token_mixer_norm_type="rmsnorm", 44 | gate_act="sigmoid", 45 | gate_pos="pre", 46 | use_lightning=False, 47 | ##### others 48 | max_position_embeddings=1024, 49 | num_layers=24, 50 | norm_type="rmsnorm", 51 | causal=True, 52 | use_embed_scale=False, 53 | ce_type="xopes_flce", 54 | pad_embed_dim=True, 55 | ##### init 56 | init_type=1, 57 | token_mixer_init_type=4, 58 | rescale_type=2, 59 | gain=0.01, 60 | channel_mixer_init_type=0, 61 | **kwargs, 62 | ): 63 | super().__init__( 64 | pad_token_id=pad_token_id, 65 | bos_token_id=bos_token_id, 66 | eos_token_id=eos_token_id, 67 | tie_word_embeddings=tie_word_embeddings, 68 | **kwargs, 69 | ) 70 | for key, value in locals().items(): 71 | if key not in [ 72 | "self", 73 | "kwargs", 74 | "__class__", 75 | "pad_token_id", 76 | "bos_token_id", 77 | "eos_token_id", 78 | "tie_word_embeddings", 79 | ]: 80 | setattr(self, key, value) 81 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/lightnet/configuration_lightnet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ LightNet configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class LightNetConfig(PretrainedConfig): 11 | model_type = "lightnet" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | 
bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | ##### token mixer config 25 | token_mixer_type="lightnet", 26 | embed_dim=1024, 27 | num_heads=8, 28 | bias=False, 29 | use_lrpe=True, 30 | lrpe_type=1, 31 | base=10000, 32 | gate_act="sigmoid", 33 | gate_pos="pre", 34 | use_input_gate=False, 35 | token_mixer_norm_type="rmsnorm", 36 | use_tpe=True, 37 | ##### channel mixer config 38 | channel_mixer_type="glu", 39 | mid_dim=1024, 40 | channel_mixer_activation="silu", 41 | use_gate_linear=True, 42 | ##### others 43 | max_position_embeddings=1024, 44 | num_layers=24, 45 | use_output_gate=True, 46 | norm_type="rmsnorm", 47 | q_activation="silu", 48 | k_activation="silu", 49 | scalar_decay=False, 50 | use_embed_scale=False, 51 | causal=True, 52 | ce_type="xopes_flce", 53 | pad_embed_dim=True, 54 | ##### init 55 | init_type=1, 56 | token_mixer_init_type=4, 57 | rescale_type=2, 58 | gain=0.01, 59 | channel_mixer_init_type=0, 60 | **kwargs, 61 | ): 62 | super().__init__( 63 | pad_token_id=pad_token_id, 64 | bos_token_id=bos_token_id, 65 | eos_token_id=eos_token_id, 66 | tie_word_embeddings=tie_word_embeddings, 67 | **kwargs, 68 | ) 69 | for key, value in locals().items(): 70 | if key not in [ 71 | "self", 72 | "kwargs", 73 | "__class__", 74 | "pad_token_id", 75 | "bos_token_id", 76 | "eos_token_id", 77 | "tie_word_embeddings", 78 | ]: 79 | setattr(self, key, value) 80 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/mesa_net/configuration_mesa_net.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ MesaNet configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class MesaNetConfig(PretrainedConfig): 11 | model_type = "mesa_net" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | ##### token mixer config 25 | token_mixer_type="mesa_unit", 26 | embed_dim=1024, 27 | num_heads=8, 28 | bias=False, 29 | token_mixer_norm_type="rmsnorm", 30 | q_activation="silu", 31 | k_activation="silu", 32 | causal=True, 33 | gate_act="sigmoid", 34 | gate_pos="pre", 35 | threshold=0.99, 36 | lambda_initial_value=1.0, 37 | lambda_lower_bound=0.25, 38 | max_cg_step_training=30, 39 | max_cg_step_decoding=30, 40 | ##### channel mixer config 41 | channel_mixer_type="glu", 42 | mid_dim=1024, 43 | channel_mixer_activation="silu", 44 | use_gate_linear=True, 45 | ##### others 46 | max_position_embeddings=1024, 47 | use_output_gate=True, 48 | norm_type="rmsnorm", 49 | num_layers=12, 50 | use_embed_scale=False, 51 | ce_type="xopes_flce", 52 | pad_embed_dim=True, 53 | ##### init 54 | init_type=1, 55 | token_mixer_init_type=4, 56 | rescale_type=2, 57 | gain=0.01, 58 | channel_mixer_init_type=0, 59 | **kwargs, 60 | ): 61 | super().__init__( 62 | pad_token_id=pad_token_id, 63 | bos_token_id=bos_token_id, 64 | eos_token_id=eos_token_id, 65 | tie_word_embeddings=tie_word_embeddings, 66 | **kwargs, 67 | ) 68 | for key, value in locals().items(): 69 | if key not in [ 70 | "self", 71 | "kwargs", 72 | "__class__", 73 | "pad_token_id", 74 | 
"bos_token_id", 75 | "eos_token_id", 76 | "tie_word_embeddings", 77 | ]: 78 | setattr(self, key, value) 79 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/deep_memory/optimizer/base_optimizer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict, Optional 3 | 4 | import torch 5 | 6 | from .utils import get_pooling_fn 7 | 8 | 9 | class FastWeightOptimizer(ABC): 10 | """ 11 | Abstract base class for custom optimizers. 12 | 13 | This class provides common functionality for optimizers including 14 | parameter management, gradient zeroing, and pooling operations. 15 | """ 16 | 17 | @abstractmethod 18 | def __init__( 19 | self, fast_weight: Dict[str, torch.nn.Parameter], pooling_type: str = "mean" 20 | ): 21 | """ 22 | Initialize the optimizer. 23 | 24 | Args: 25 | fast_weight: Dictionary mapping parameter names to parameters 26 | pooling_type: Type of pooling to use for tensor operations 27 | """ 28 | self.fast_weight = fast_weight 29 | self.pooling_fn = get_pooling_fn(pooling_type) 30 | self.state: Dict[str, Dict[str, Any]] = {} 31 | 32 | def setup_state(self) -> None: 33 | """Initialize state dictionaries for all parameters.""" 34 | self.state = {} 35 | for name in self.params_dict.keys(): 36 | self.state[name] = {} 37 | 38 | def zero_grad(self) -> None: 39 | """Zero out gradients for all parameters.""" 40 | for param in self.params_dict.values(): 41 | if param.grad is not None: 42 | param.grad.zero_() 43 | 44 | def pooling(self, x: torch.Tensor) -> torch.Tensor: 45 | """ 46 | Apply pooling operation to tensor if it has multiple elements in dimension 1. 47 | 48 | Args: 49 | x: Input tensor 50 | 51 | Returns: 52 | Pooled tensor or original tensor if pooling not applicable 53 | """ 54 | if isinstance(x, torch.Tensor) and x.shape[1] > 1: 55 | return self.pooling_fn(x) 56 | else: 57 | return x 58 | 59 | @abstractmethod 60 | def step( 61 | self, 62 | lr_dict: Optional[Dict[str, float]] = None, 63 | wd_dict: Optional[Dict[str, float]] = None, 64 | momentum_dict: Optional[Dict[str, float]] = None, 65 | ) -> None: 66 | """ 67 | Perform a single optimization step. 
68 | 69 | Args: 70 | lr_dict: Dictionary of learning rates per parameter 71 | wd_dict: Dictionary of weight decay values per parameter 72 | momentum_dict: Dictionary of momentum values per parameter 73 | """ 74 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/linear_transformer/configuration_linear_transformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ LinearTransformer configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class LinearTransformerConfig(PretrainedConfig): 11 | model_type = "linear_transformer" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | ##### token mixer config 25 | token_mixer_type="linear_attn", 26 | embed_dim=1024, 27 | num_heads=8, 28 | kv_heads=-1, 29 | bias=False, 30 | use_lrpe=True, 31 | lrpe_type=1, 32 | base=10000, 33 | gate_act="sigmoid", 34 | gate_pos="pre", 35 | token_mixer_norm_type="rmsnorm", 36 | use_tpe=True, 37 | ##### channel mixer config 38 | channel_mixer_type="glu", 39 | mid_dim=1024, 40 | channel_mixer_activation="silu", 41 | use_gate_linear=True, 42 | ##### others 43 | max_position_embeddings=1024, 44 | num_layers=24, 45 | use_output_gate=True, 46 | norm_type="rmsnorm", 47 | linear_activation="silu", 48 | causal=True, 49 | use_ape=False, 50 | use_embed_scale=False, 51 | use_dense_memory=False, 52 | ce_type="xopes_flce", 53 | pad_embed_dim=True, 54 | ##### init 55 | init_type=1, 56 | token_mixer_init_type=4, 57 | rescale_type=2, 58 | gain=0.01, 59 | channel_mixer_init_type=0, 60 | **kwargs, 61 | ): 62 | super().__init__( 63 | pad_token_id=pad_token_id, 64 | bos_token_id=bos_token_id, 65 | eos_token_id=eos_token_id, 66 | tie_word_embeddings=tie_word_embeddings, 67 | **kwargs, 68 | ) 69 | for key, value in locals().items(): 70 | if key not in [ 71 | "self", 72 | "kwargs", 73 | "__class__", 74 | "pad_token_id", 75 | "bos_token_id", 76 | "eos_token_id", 77 | "tie_word_embeddings", 78 | ]: 79 | setattr(self, key, value) 80 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/long_conv/gtu.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gated Toeplitz Unit in https://arxiv.org/pdf/2305.04749.pdf 3 | Add support for multi dimension. 4 | For example, if the input dim is (h, w, d), 5 | then the output is g * (tno1(x) + tno2(x)). 
6 | """ 7 | 8 | from typing import List 9 | 10 | import torch.nn as nn 11 | 12 | from xmixers.modules.activations import get_activation_fn 13 | from xmixers.utils import XMIXERS_DEBUG, print_params 14 | 15 | from .tno import Tno 16 | 17 | 18 | class Gtu(nn.Module): 19 | def __init__( 20 | self, 21 | embed_dim: int, 22 | expand_ratio: int = 1, 23 | bias: bool = False, 24 | activation: str = "silu", 25 | causal: bool = True, 26 | norm_type: str = "layernorm", 27 | use_decay: bool = True, 28 | rpe_in_dim: int = 1, 29 | rpe_feature_dim: int = 32, 30 | rpe_layers: int = 3, 31 | dims: List[int] = [-2], 32 | lower_bound: float = 0.99, 33 | *args, 34 | **kwargs, 35 | ) -> None: 36 | super().__init__() 37 | 38 | if XMIXERS_DEBUG: 39 | # get local varables 40 | params = locals() 41 | # print params 42 | print_params(**params) 43 | 44 | self.embed_dim = embed_dim 45 | self.expand_ratio = expand_ratio 46 | 47 | d1 = int(self.expand_ratio * embed_dim) 48 | # linear projection 49 | self.uv_proj = nn.Linear(embed_dim, 2 * d1, bias=bias) 50 | self.o = nn.Linear(d1, embed_dim, bias=bias) 51 | self.act = get_activation_fn(activation) 52 | self.dims = dims 53 | self.tno_list = nn.ModuleList([]) 54 | for dim in dims: 55 | # tno 56 | self.tno_list.append( 57 | Tno( 58 | in_dim=rpe_in_dim, 59 | feature_dim=rpe_feature_dim, 60 | out_dim=d1, 61 | activation=activation, 62 | bias=bias, 63 | rpe_layers=rpe_layers, 64 | norm_type=norm_type, 65 | use_decay=use_decay, 66 | causal=causal, 67 | dim=dim, 68 | lower_bound=lower_bound, 69 | ) 70 | ) 71 | 72 | def forward(self, x): 73 | # x: b, n, d 74 | u, v = self.act(self.uv_proj(x)).chunk(2, dim=-1) 75 | output = 0 76 | for tno in self.tno_list: 77 | output += tno(v) 78 | output = u * output 79 | 80 | output = self.o(output) 81 | 82 | return output 83 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/implicit_linear_transformer/configuration_implicit_linear_transformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ ImplicitLinearTransformer configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class ImplicitLinearTransformerConfig(PretrainedConfig): 11 | model_type = "implicit_linear_transformer" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | ##### token mixer config 25 | embed_dim=1024, 26 | num_heads=8, 27 | bias=False, 28 | scalar_decay=False, 29 | token_mixer_type="implicit_value_attn", 30 | gate_act="sigmoid", 31 | gate_pos="pre", 32 | token_mixer_norm_type="rmsnorm", 33 | use_decay=True, 34 | ##### channel mixer config 35 | mid_dim=1024, 36 | channel_mixer_type="glu", 37 | channel_mixer_activation="silu", 38 | use_gate_linear=True, 39 | ##### others 40 | max_position_embeddings=1024, 41 | num_layers=24, 42 | use_output_gate=True, 43 | norm_type="rmsnorm", 44 | q_activation="silu", 45 | k_activation="silu", 46 | threshold=0.99, 47 | use_offset=False, 48 | causal=True, 49 | use_embed_scale=False, 50 | ce_type="xopes_flce", 51 | pad_embed_dim=True, 52 | ##### init 53 | init_type=1, 54 | token_mixer_init_type=4, 55 | rescale_type=2, 56 | gain=0.01, 57 | 
channel_mixer_init_type=0, 58 | **kwargs, 59 | ): 60 | super().__init__( 61 | pad_token_id=pad_token_id, 62 | bos_token_id=bos_token_id, 63 | eos_token_id=eos_token_id, 64 | tie_word_embeddings=tie_word_embeddings, 65 | **kwargs, 66 | ) 67 | if scalar_decay: 68 | share_decay = False 69 | for key, value in locals().items(): 70 | if key not in [ 71 | "self", 72 | "kwargs", 73 | "__class__", 74 | "pad_token_id", 75 | "bos_token_id", 76 | "eos_token_id", 77 | "tie_word_embeddings", 78 | ]: 79 | setattr(self, key, value) 80 | -------------------------------------------------------------------------------- /tests/script.sh: -------------------------------------------------------------------------------- 1 | date=$(date '+%Y-%m-%d-%H:%M:%S') 2 | 3 | folder=models 4 | file=test 5 | 6 | model_type=llama 7 | model_type=mpa 8 | # model_type=tpa 9 | # model_type=hgrn2 10 | # model_type=lightnet 11 | # model_type=lightnet_scalar_decay 12 | # model_type=lightnet_no_tpe 13 | # model_type=lightnet_no_tpe_scalar_decay 14 | # model_type=mla 15 | # model_type=tnl 16 | # model_type=tnl_state 17 | # model_type=hgrn2_scalar_decay 18 | # model_type=linear_transformer 19 | # model_type=linear_transformer_no_tpe 20 | # model_type=cosformer2 21 | # model_type=cosformer2_no_tpe 22 | # model_type=naive_deltanet 23 | # model_type=scalar_decay_deltanet 24 | # model_type=scalar_decay_lower_bound_deltanet 25 | # model_type=vector_decay_deltanet 26 | # model_type=vector_decay_lower_bound_deltanet 27 | # model_type=dense_rnn 28 | # model_type=dense_rnn_lower_bound 29 | # model_type=decay_linear_transformer_hgrn2 30 | # model_type=decay_linear_transformer_hgrn2_scalar_decay 31 | # model_type=decay_linear_transformer_mamba 32 | # model_type=decay_linear_transformer_mamba_scalar_decay 33 | # model_type=decay_linear_transformer_gla 34 | # model_type=decay_linear_transformer_gla_scalar_decay 35 | # model_type=decay_linear_transformer_lightnet 36 | # model_type=decay_linear_transformer_lightnet_share_decay 37 | # model_type=decay_linear_transformer_lightnet_scalar_decay 38 | # model_type=decay_linear_transformer_lssp 39 | # model_type=decay_linear_transformer_lssp_scalar_decay 40 | # model_type=decay_linear_transformer_tnl 41 | # model_type=decay_linear_transformer_tnl_scalar_decay 42 | # model_type=decay_linear_transformer_hgrn2_rope 43 | # model_type=decay_linear_transformer_hgrn2_rope_scalar_decay 44 | # model_type=nsa 45 | # model_type=gsa 46 | # model_type=ttt 47 | # model_type=hgrn3 48 | # model_type=hgrn3_scalar_decay 49 | # model_type=alibi 50 | model_type=fox 51 | # model_type=fox_window 52 | # model_type=sb_attn 53 | # model_type=mfa 54 | # model_type=mfa_kv_share 55 | # model_type=hgrn1 56 | # model_type=mamba2 57 | # model_type=scalar_decay_delta_product_net 58 | # model_type=implicit_value_attn 59 | # model_type=mesa_net 60 | # model_type=poly_net 61 | # model_type=kernel_regression_attn 62 | # model_type=kernel_regression_attn_no_decay 63 | # model_type=kernel_regression_attn_no_kr 64 | # model_type=path_attn 65 | # model_type=path_attn_no_decay 66 | model_type=path_attn_no_beta 67 | # model_type=fox_offset 68 | # model_type=kernel_regression_attn_qk_norm 69 | 70 | dtype=bf16 71 | # dtype=fp32 72 | 73 | mkdir -p $folder/log 74 | 75 | export CUDA_VISIBLE_DEVICES=1 76 | 77 | python $folder/${file}.py --model_type $model_type --dtype $dtype 2>&1 | tee -a $folder/log/${date}-${file}.log 78 | -------------------------------------------------------------------------------- 
/xmixers/models/linear_transformer/dense_rnn/configuration_dense_rnn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ DenseRnn configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class DenseRnnConfig(PretrainedConfig): 11 | model_type = "dense_rnn" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | ##### token mixer config 25 | token_mixer_type="dense_rnn", 26 | embed_dim=1024, 27 | num_heads=8, 28 | bias=False, 29 | token_mixer_norm_type="rmsnorm", 30 | q_activation="silu", 31 | k_activation="silu", 32 | v_activation="silu", 33 | use_beta=True, 34 | beta_activation="neg", 35 | qkv_norm_type=2, 36 | norm_q=False, 37 | norm_v=False, 38 | causal=True, 39 | gate_act="sigmoid", 40 | gate_pos="pre", 41 | threshold=0.99, 42 | use_offset=False, 43 | num_blocks=1, 44 | ##### channel mixer config 45 | channel_mixer_type="glu", 46 | mid_dim=1024, 47 | channel_mixer_activation="silu", 48 | use_gate_linear=True, 49 | ##### others 50 | use_lower_bound=False, 51 | max_position_embeddings=1024, 52 | use_output_gate=True, 53 | norm_type="rmsnorm", 54 | num_layers=12, 55 | use_embed_scale=False, 56 | ce_type="xopes_flce", 57 | pad_embed_dim=True, 58 | ##### init 59 | init_type=1, 60 | token_mixer_init_type=4, 61 | rescale_type=2, 62 | gain=0.01, 63 | channel_mixer_init_type=0, 64 | **kwargs, 65 | ): 66 | super().__init__( 67 | pad_token_id=pad_token_id, 68 | bos_token_id=bos_token_id, 69 | eos_token_id=eos_token_id, 70 | tie_word_embeddings=tie_word_embeddings, 71 | **kwargs, 72 | ) 73 | for key, value in locals().items(): 74 | if key not in [ 75 | "self", 76 | "kwargs", 77 | "__class__", 78 | "pad_token_id", 79 | "bos_token_id", 80 | "eos_token_id", 81 | "tie_word_embeddings", 82 | ]: 83 | setattr(self, key, value) 84 | -------------------------------------------------------------------------------- /xmixers/models/long_conv/tnn/configuration_tnn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ Tnn configuration""" 3 | 4 | from typing import List 5 | 6 | from transformers.configuration_utils import PretrainedConfig 7 | from transformers.utils import logging 8 | 9 | logger = logging.get_logger(__name__) 10 | 11 | 12 | class TnnConfig(PretrainedConfig): 13 | model_type = "tnn" 14 | keys_to_ignore_at_inference = ["past_key_values"] 15 | 16 | def __init__( 17 | self, 18 | pad_token_id: int = 1, 19 | bos_token_id: int = 0, 20 | eos_token_id: int = 2, 21 | vocab_size: int = 50272, 22 | use_cache: bool = True, 23 | init_std: float = 0.02, 24 | ##### model config 25 | # gtu config 26 | embed_dim: int = 768, 27 | expand_ratio: int = 1, 28 | bias: bool = False, 29 | gtu_activation: str = "silu", 30 | causal: bool = True, 31 | norm_type: str = "layernorm", 32 | use_decay: bool = True, 33 | rpe_in_dim: int = 1, 34 | rpe_feature_dim: int = 32, # for rpe in tno 35 | rpe_layers: int = 3, 36 | dims: List[int] = [-2], 37 | lower_bound: float = 0.99, 38 | # glu config 39 | mid_dim: int = 1024, 40 | glu_activation: str = "silu", 41 | # others 42 | num_layers: int = 24, 43 | add_bos_token: bool = False, 44 | max_position_embeddings: int = 
2048, 45 | initializer_range: float = 0.02, 46 | **kwargs, 47 | ) -> None: 48 | super().__init__( 49 | pad_token_id=pad_token_id, 50 | bos_token_id=bos_token_id, 51 | eos_token_id=eos_token_id, 52 | **kwargs, 53 | ) 54 | ##### hf origin 55 | self.vocab_size = vocab_size 56 | self.use_cache = use_cache 57 | self.init_std = init_std 58 | ##### add 59 | self.embed_dim = embed_dim 60 | self.expand_ratio = expand_ratio 61 | self.bias = bias 62 | 63 | self.gtu_activation = gtu_activation 64 | self.causal = causal 65 | self.norm_type = norm_type 66 | self.use_decay = use_decay 67 | self.rpe_in_dim = rpe_in_dim 68 | self.rpe_feature_dim = rpe_feature_dim 69 | self.rpe_layers = rpe_layers 70 | self.dims = dims 71 | self.lower_bound = lower_bound 72 | # glu config 73 | self.mid_dim = mid_dim 74 | self.glu_activation = glu_activation 75 | # others 76 | self.num_layers = num_layers 77 | self.add_bos_token = add_bos_token 78 | self.max_position_embeddings = max_position_embeddings 79 | self.initializer_range = initializer_range 80 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from Github's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # tests and logs 12 | tests/fixtures/cached_*_text.txt 13 | logs/ 14 | lightning_logs/ 15 | lang_code_data/ 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # Environments 98 | .env 99 | .venv 100 | env/ 101 | venv/ 102 | ENV/ 103 | env.bak/ 104 | venv.bak/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | .dmypy.json 119 | dmypy.json 120 | 121 | # Pyre type checker 122 | .pyre/ 123 | 124 | # vscode 125 | .vs 126 | .vscode 127 | 128 | # Pycharm 129 | .idea 130 | 131 | # TF code 132 | tensorflow_code 133 | 134 | # Models 135 | proc_data 136 | 137 | # examples 138 | runs 139 | /runs_old 140 | /wandb 141 | /examples/runs 142 | /examples/**/*.args 143 | /examples/rag/sweep 144 | 145 | # data 146 | /data 147 | serialization_dir 148 | 149 | # emacs 150 | *.*~ 151 | debug.env 152 | 153 | # vim 154 | .*.swp 155 | 156 | #ctags 157 | tags 158 | 159 | # pre-commit 160 | .pre-commit* 161 | 162 | # .lock 163 | *.lock 164 | 165 | # DS_Store (MacOS) 166 | .DS_Store 167 | 168 | # ruff 169 | .ruff_cache 170 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/deltanet/configuration_deltanet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ DeltaNet configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class DeltaNetConfig(PretrainedConfig): 11 | model_type = "delta_net" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | ##### token mixer config 25 | token_mixer_type="delta_unit", 26 | embed_dim=1024, 27 | num_heads=8, 28 | bias=False, 29 | token_mixer_norm_type="rmsnorm", 30 | q_activation="silu", 31 | k_activation="silu", 32 | v_activation="silu", 33 | use_beta=True, 34 | beta_activation="neg", 35 | use_decay=False, 36 | scalar_decay=False, 37 | qkv_norm_type=2, 38 | norm_q=False, 39 | norm_v=False, 40 | causal=True, 41 | gate_act="sigmoid", 42 | gate_pos="pre", 43 | use_input_gate=False, 44 | rank=2, 45 | use_offset=False, 46 | threshold=0.99, 47 | ##### channel mixer config 48 | channel_mixer_type="glu", 49 | mid_dim=1024, 50 | channel_mixer_activation="silu", 51 | use_gate_linear=True, 52 | ##### others 53 | use_lower_bound=False, 54 | max_position_embeddings=1024, 55 | use_output_gate=True, 56 | 
norm_type="rmsnorm", 57 | num_layers=12, 58 | use_embed_scale=False, 59 | ce_type="xopes_flce", 60 | pad_embed_dim=True, 61 | ##### init 62 | init_type=1, 63 | token_mixer_init_type=4, 64 | rescale_type=2, 65 | gain=0.01, 66 | channel_mixer_init_type=0, 67 | **kwargs, 68 | ): 69 | super().__init__( 70 | pad_token_id=pad_token_id, 71 | bos_token_id=bos_token_id, 72 | eos_token_id=eos_token_id, 73 | tie_word_embeddings=tie_word_embeddings, 74 | **kwargs, 75 | ) 76 | for key, value in locals().items(): 77 | if key not in [ 78 | "self", 79 | "kwargs", 80 | "__class__", 81 | "pad_token_id", 82 | "bos_token_id", 83 | "eos_token_id", 84 | "tie_word_embeddings", 85 | ]: 86 | setattr(self, key, value) 87 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/metala/configuration_metala.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ MetaLa configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class MetaLaConfig(PretrainedConfig): 11 | model_type = "metala" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ##### model config 24 | # attention config 25 | embed_dim=1024, 26 | expand_ratio=8, 27 | bias=False, 28 | # glu config 29 | mid_dim=1024, 30 | glu_activation="silu", 31 | # others 32 | max_position_embeddings=1024, 33 | num_layers=24, 34 | use_output_gate=False, 35 | non_sparse_ratio=1, 36 | num_sparse=4, 37 | norm_type="layernorm", 38 | q_activation="silu", 39 | causal=True, 40 | use_embed_scale=False, 41 | pad_embed_dim=True, 42 | # init 43 | init_type=0, 44 | token_mixer_init_type=0, 45 | rescale_type=0, 46 | channel_mixer_init_type=0, 47 | **kwargs, 48 | ): 49 | super().__init__( 50 | pad_token_id=pad_token_id, 51 | bos_token_id=bos_token_id, 52 | eos_token_id=eos_token_id, 53 | tie_word_embeddings=tie_word_embeddings, 54 | **kwargs, 55 | ) 56 | ##### hf origin 57 | self.vocab_size = vocab_size 58 | self.use_cache = use_cache 59 | self.init_std = init_std 60 | ##### add 61 | # attention config 62 | self.embed_dim = embed_dim 63 | self.expand_ratio = expand_ratio 64 | self.bias = bias 65 | # glu config 66 | self.mid_dim = mid_dim 67 | self.glu_activation = glu_activation 68 | # others 69 | self.max_position_embeddings = max_position_embeddings 70 | self.num_layers = num_layers 71 | self.use_output_gate = use_output_gate 72 | self.non_sparse_ratio = non_sparse_ratio 73 | self.num_sparse = num_sparse 74 | self.norm_type = norm_type 75 | self.q_activation = q_activation 76 | self.causal = causal 77 | self.use_embed_scale = use_embed_scale 78 | # init 79 | self.init_type = init_type 80 | self.token_mixer_init_type = token_mixer_init_type 81 | self.rescale_type = rescale_type 82 | self.channel_mixer_init_type = channel_mixer_init_type 83 | -------------------------------------------------------------------------------- /xmixers/models/chunk_linear_transformer/chunk_rnn/configuration_chunk_rnn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ ChunkRnn configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 
| 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class ChunkRnnConfig(PretrainedConfig): 11 | model_type = "chunk_rnn" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | ##### token mixer config 25 | embed_dim=1024, 26 | num_heads=8, 27 | bias=False, 28 | token_mixer_norm_type="rmsnorm", 29 | token_mixer_type="chunk_rnn", 30 | ##### chunk params 31 | chunk_type: int = 0, 32 | gradient_type: int = 0, 33 | use_initial_state: bool = False, 34 | use_scale: bool = False, 35 | chunk_size: int = 128, 36 | threshold=0.99, 37 | decay_type: str = "pos", 38 | decay_fn: str = "mean", 39 | scalar_decay: bool = False, 40 | ##### lrpe 41 | use_lrpe: bool = True, 42 | lrpe_type: int = 1, 43 | base: int = 10000, 44 | ##### channel mixer config 45 | channel_mixer_type="glu", 46 | mid_dim=1024, 47 | channel_mixer_activation="silu", 48 | use_gate_linear=True, 49 | ##### others 50 | max_position_embeddings=1024, 51 | num_layers=24, 52 | use_output_gate=True, 53 | norm_type="rmsnorm", 54 | q_activation="silu", 55 | k_activation="silu", 56 | causal=True, 57 | use_embed_scale=False, 58 | gate_act="sigmoid", 59 | gate_pos="pre", 60 | ce_type="xopes_flce", 61 | pad_embed_dim=True, 62 | ##### init 63 | init_type=1, 64 | token_mixer_init_type=4, 65 | rescale_type=2, 66 | gain=0.01, 67 | channel_mixer_init_type=0, 68 | **kwargs, 69 | ): 70 | super().__init__( 71 | pad_token_id=pad_token_id, 72 | bos_token_id=bos_token_id, 73 | eos_token_id=eos_token_id, 74 | tie_word_embeddings=tie_word_embeddings, 75 | **kwargs, 76 | ) 77 | for key, value in locals().items(): 78 | if key not in [ 79 | "self", 80 | "kwargs", 81 | "__class__", 82 | "pad_token_id", 83 | "bos_token_id", 84 | "eos_token_id", 85 | "tie_word_embeddings", 86 | ]: 87 | setattr(self, key, value) 88 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/decay_linear_transformer/configuration_decay_linear_transformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ DecayLinearTransformer configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class DecayLinearTransformerConfig(PretrainedConfig): 11 | model_type = "decay_linear_transformer" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | ##### token mixer config 25 | token_mixer_type="decay_linear_attn", 26 | embed_dim=1024, 27 | num_heads=8, 28 | bias=False, 29 | use_lrpe=False, 30 | base=10000, 31 | gate_act="sigmoid", 32 | gate_pos="pre", 33 | token_mixer_norm_type="rmsnorm", 34 | use_tpe=True, 35 | use_lightning=False, 36 | ##### channel mixer config 37 | channel_mixer_type="glu", 38 | mid_dim=1024, 39 | channel_mixer_activation="silu", 40 | use_gate_linear=True, 41 | ##### others 42 | max_position_embeddings=1024, 43 | num_layers=24, 44 | use_output_gate=True, 45 | norm_type="rmsnorm", 46 | q_activation="silu", 47 | k_activation="silu", 48 | scalar_decay=False, 49 | use_embed_scale=False, 
50 | causal=True, 51 | ce_type="xopes_flce", 52 | pad_embed_dim=True, 53 | ##### decay parameters 54 | decay_type="hgrn2", # choose from ["hgrn2", "gla", "mamba", "mamba_no_a_no_t", "mamba_no_a", "mamba_no_t", "lightnet", "tnl", "tnll", "lssp", "hgrn3"] # lssp: log sum soft plus 55 | A_init_range=(1, 16), 56 | dt_min=0.001, 57 | dt_max=0.1, 58 | dt_init_floor=1e-4, 59 | dt_limit=(0.0, float("inf")), 60 | gate_denom=16, 61 | share_decay=False, 62 | use_lower_bound=False, 63 | threshold=0.99, 64 | ##### init 65 | init_type=1, 66 | token_mixer_init_type=4, 67 | rescale_type=2, 68 | gain=0.01, 69 | channel_mixer_init_type=0, 70 | **kwargs, 71 | ): 72 | super().__init__( 73 | pad_token_id=pad_token_id, 74 | bos_token_id=bos_token_id, 75 | eos_token_id=eos_token_id, 76 | tie_word_embeddings=tie_word_embeddings, 77 | **kwargs, 78 | ) 79 | if scalar_decay: 80 | share_decay = False 81 | for key, value in locals().items(): 82 | if key not in [ 83 | "self", 84 | "kwargs", 85 | "__class__", 86 | "pad_token_id", 87 | "bos_token_id", 88 | "eos_token_id", 89 | "tie_word_embeddings", 90 | ]: 91 | setattr(self, key, value) 92 | -------------------------------------------------------------------------------- /xmixers/utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | import sys 5 | 6 | import torch.distributed as dist 7 | from torch import nn 8 | 9 | from .constants import EMBED_DIM_BASE 10 | 11 | logging.basicConfig( 12 | format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 13 | datefmt="%Y-%m-%d %H:%M:%S", 14 | level=os.environ.get("LOGLEVEL", "INFO").upper(), 15 | stream=sys.stdout, 16 | ) 17 | logger = logging.getLogger("xmixers") 18 | 19 | 20 | def is_dist_avail_and_initialized() -> bool: 21 | if not dist.is_available(): 22 | return False 23 | if not dist.is_initialized(): 24 | return False 25 | return True 26 | 27 | 28 | def get_world_size() -> int: 29 | if not is_dist_avail_and_initialized(): 30 | return 1 31 | return dist.get_world_size() 32 | 33 | 34 | def get_rank() -> int: 35 | if not is_dist_avail_and_initialized(): 36 | return 0 37 | return dist.get_rank() 38 | 39 | 40 | def is_main_process() -> bool: 41 | return get_rank() == 0 42 | 43 | 44 | def logging_info(string: str) -> None: 45 | if is_main_process(): 46 | logger.info(string) 47 | 48 | 49 | def print_params(**kwargs) -> None: 50 | if is_main_process(): 51 | logger.info(f"start print config of {kwargs['__class__']}") 52 | for key in kwargs: 53 | if key in ["__class__", "self"]: 54 | continue 55 | logger.info(f"{key}: {kwargs[key]}") 56 | logger.info(f"end print config of {kwargs['__class__']}") 57 | 58 | 59 | def print_config(config) -> None: 60 | if is_main_process(): 61 | logger.info(f"start print config of {config['__class__']}") 62 | for key in config: 63 | if key in ["__class__", "self"]: 64 | continue 65 | logger.info(f"{key}: {config[key]}") 66 | logger.info(f"end print config of {config['__class__']}") 67 | 68 | 69 | def print_module(module: nn.Module) -> str: 70 | named_modules_ = set() 71 | for p in module.named_modules(): 72 | named_modules_.update([p[0]]) 73 | named_modules = list(named_modules_) 74 | 75 | string_repr = "" 76 | for p in module.named_parameters(): 77 | name = p[0].split(".")[0] 78 | if name not in named_modules: 79 | string_repr = ( 80 | string_repr 81 | + "(" 82 | + name 83 | + "): " 84 | + "Tensor(" 85 | + str(tuple(p[1].shape)) 86 | + ", requires_grad=" 87 | + str(p[1].requires_grad) 88 | + ")\n" 89 | 
) 90 | 91 | return string_repr.rstrip("\n") 92 | 93 | 94 | def next_power_of_2(n: int) -> int: 95 | return 2 ** (math.ceil(math.log(n, 2))) 96 | 97 | 98 | def endswith(name, keyword_list): 99 | for keyword in keyword_list: 100 | if name.endswith(keyword): 101 | return True 102 | return False 103 | 104 | 105 | def pad_embed_dim(embed_dim: int) -> int: 106 | return (embed_dim + EMBED_DIM_BASE - 1) // EMBED_DIM_BASE * EMBED_DIM_BASE 107 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/deep_memory/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .loss import get_loss_fn 5 | from .optimizer import get_optimizer 6 | 7 | 8 | def get_chunk(x, i, chunk_size): 9 | if x is None: 10 | return None 11 | else: 12 | n = x.shape[1] 13 | start = i * chunk_size 14 | end = min(start + chunk_size, n) 15 | return x[:, start:end] 16 | 17 | 18 | def stack_chunk(o): 19 | return torch.cat(o, dim=1) 20 | 21 | 22 | def get_default_value(value_dict, keys, index, chunk_size=128): 23 | res_dict = {} 24 | for key in keys: 25 | if key in value_dict: 26 | res_dict[key] = get_chunk(value_dict[key], index, chunk_size) 27 | 28 | return res_dict 29 | 30 | 31 | def prepare_fast_weight(fast_weight): 32 | # make sure all fast_weight parameters are leaf tensors 33 | for name, param in fast_weight.items(): 34 | if not param.is_leaf: 35 | # create new leaf tensor 36 | new_param = param.detach().requires_grad_(True) 37 | fast_weight[name] = new_param 38 | else: 39 | param.requires_grad_(True) 40 | 41 | return fast_weight 42 | 43 | 44 | def fast_weight_train( 45 | x_val: torch.Tensor, 46 | x: torch.Tensor, 47 | y: torch.Tensor, 48 | fast_weight_model: nn.Module, 49 | hyper_params_dict: dict, 50 | # optimizer 51 | optimizer_name: str = "sgd", 52 | lr: float = 1, 53 | wd: float = 0, 54 | momentum: float = 0, 55 | damping: float = 0, 56 | pooling_method: str = "mean", 57 | # loss 58 | loss_name: str = "mse", 59 | chunk_size: int = 128, 60 | ): 61 | b, n = x_val.shape[0], x_val.shape[1] 62 | # prepare the fast weight 63 | fast_weight = fast_weight_model.init_fast_weight(b) 64 | fast_weight = prepare_fast_weight(fast_weight) 65 | # prepare the optimizer 66 | optimizer = get_optimizer( 67 | optimizer_name, 68 | fast_weight=fast_weight, 69 | ) 70 | # prepare the loss function 71 | loss_fn = get_loss_fn(loss_name) 72 | 73 | # train 74 | y_val = [] 75 | num_chunks = (n + chunk_size - 1) // chunk_size 76 | for i in range(num_chunks): 77 | x_val_chunk = get_chunk(x_val, i, chunk_size) 78 | x_chunk = get_chunk(x, i, chunk_size) 79 | y_chunk = get_chunk(y, i, chunk_size) 80 | 81 | hyper_params_chunk = { 82 | name: get_default_value( 83 | hyper_params_dict[name], fast_weight.keys(), i, chunk_size 84 | ) 85 | for name in hyper_params_dict.keys() 86 | } 87 | 88 | y_val_chunk = fast_weight_model.forward(x_val_chunk, fast_weight) 89 | y_val.append(y_val_chunk) 90 | 91 | # update the fast weight 92 | optimizer.zero_grad() 93 | 94 | with torch.enable_grad(): 95 | y_chunk_pred = fast_weight_model.forward(x_chunk, fast_weight) 96 | loss = loss_fn(y_chunk_pred, y_chunk) 97 | loss.backward(retain_graph=True) 98 | optimizer.step(**hyper_params_chunk) 99 | 100 | y_val = stack_chunk(y_val) 101 | 102 | return y_val, fast_weight 103 | -------------------------------------------------------------------------------- /xmixers/models/transformer/llama/configuration_llama.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ LLaMA configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class LLaMAConfig(PretrainedConfig): 11 | model_type = "llama_" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | ##### token mixer config 25 | token_mixer_type="attn", 26 | embed_dim=1024, 27 | num_heads=8, 28 | kv_heads=-1, 29 | bias=False, 30 | use_lrpe=True, 31 | lrpe_type=1, 32 | base=10000, 33 | mpa_type=0, 34 | mpa_activation="sigmoid", 35 | use_l2_norm=False, 36 | gate_type=0, 37 | head_dim=-1, 38 | q_rank=8, 39 | kv_rank=2, 40 | cp_activation="none", 41 | q_lora_rank=512, 42 | kv_lora_rank=512, 43 | qk_rope_head_dim=64, 44 | window_size=-1, 45 | block_size=64, 46 | chunk_size=128, 47 | token_mixer_top_k=2, 48 | share_kv=False, 49 | poly_order=4, 50 | poly_type=1, 51 | window_head_dim=128, 52 | use_decay=False, 53 | use_kernel_regression=True, 54 | scale_type=0, 55 | threshold=0.99, 56 | token_mixer_norm_type="grouprmsnorm", 57 | use_qk_norm=False, 58 | use_beta=True, 59 | ###### channel mixer config 60 | channel_mixer_type="glu", 61 | mid_dim=1024, 62 | channel_mixer_activation="silu", 63 | use_gate_linear=True, 64 | # for alu and lalu 65 | qk_dim=1024, 66 | v_dim=1024, 67 | mem_dim=1024, 68 | use_scale=0, 69 | use_output_gate=False, 70 | output_gate_activation="silu", 71 | use_low_rank_output_gate=False, 72 | channel_mixer_init_type=0, 73 | ##### others 74 | max_position_embeddings=1024, 75 | num_layers=24, 76 | norm_type="rmsnorm", 77 | use_postnorm=False, 78 | use_embed_scale=False, 79 | ce_type="xopes_flce", 80 | fuse_norm_add=False, 81 | num_bins=128, 82 | center=False, 83 | use_proj=True, 84 | share_proj=True, 85 | ##### init 86 | init_type=1, 87 | token_mixer_init_type=4, 88 | rescale_type=2, 89 | gain=0.01, 90 | pad_embed_dim=True, 91 | **kwargs, 92 | ): 93 | super().__init__( 94 | pad_token_id=pad_token_id, 95 | bos_token_id=bos_token_id, 96 | eos_token_id=eos_token_id, 97 | tie_word_embeddings=tie_word_embeddings, 98 | **kwargs, 99 | ) 100 | for key, value in locals().items(): 101 | if key not in [ 102 | "self", 103 | "kwargs", 104 | "__class__", 105 | "pad_token_id", 106 | "bos_token_id", 107 | "eos_token_id", 108 | "tie_word_embeddings", 109 | ]: 110 | setattr(self, key, value) 111 | -------------------------------------------------------------------------------- /xmixers/models/hybrid/naive_hybrid/configuration_naive_hybrid.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ Naive Hybrid configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class NaiveHybridConfig(PretrainedConfig): 11 | model_type = "naive_hybrid" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | token_mixer_type_list=[], 25 | causal=True, 26 | # 
attention config 27 | token_mixer_type="attn", 28 | embed_dim=1024, 29 | kv_heads=-1, 30 | bias=False, 31 | softmax_use_lrpe=True, 32 | softmax_lrpe_type=1, 33 | softmax_base=10000, 34 | mpa_type=0, 35 | mpa_activation="none", 36 | softmax_head_dim=-1, 37 | softmax_num_heads=8, 38 | q_rank=-1, 39 | kv_rank=2, 40 | cp_activation="none", 41 | q_lora_rank=512, 42 | kv_lora_rank=512, 43 | qk_rope_head_dim=64, 44 | window_size=-1, 45 | # linear attention config 46 | linear_num_heads=8, 47 | use_output_gate=True, 48 | q_activation="silu", 49 | k_activation="silu", 50 | v_activation="silu", 51 | q_norm=False, 52 | k_norm=False, 53 | v_norm=False, 54 | use_initial_state=False, 55 | gate_act="sigmoid", 56 | gate_pos="pre", 57 | token_mixer_norm_type="rmsnorm", 58 | beta_activation="silu", 59 | use_dense_memory=True, 60 | n_min=2, 61 | n_max=256, 62 | linear_use_lrpe=False, 63 | linear_lrpe_type=1, 64 | linear_base=10000, 65 | # channel mixer config 66 | channel_mixer_type="glu", 67 | mid_dim=1024, 68 | channel_mixer_activation="silu", 69 | use_gate_linear=True, 70 | # for alu and lalu 71 | qk_dim=1024, 72 | v_dim=1024, 73 | mem_dim=1024, 74 | use_scale=0, 75 | output_gate_activation="silu", 76 | use_low_rank_output_gate=False, 77 | # others 78 | max_position_embeddings=1024, 79 | num_layers=24, 80 | norm_type="rmsnorm", 81 | use_postnorm=False, 82 | use_embed_scale=False, 83 | ce_type="xopes_flce", 84 | # init 85 | init_type=1, 86 | token_mixer_init_type=4, 87 | rescale_type=2, 88 | gain=0.01, 89 | channel_mixer_init_type=0, 90 | pad_embed_dim=True, 91 | **kwargs, 92 | ): 93 | super().__init__( 94 | pad_token_id=pad_token_id, 95 | bos_token_id=bos_token_id, 96 | eos_token_id=eos_token_id, 97 | tie_word_embeddings=tie_word_embeddings, 98 | **kwargs, 99 | ) 100 | for key, value in locals().items(): 101 | if key not in [ 102 | "self", 103 | "kwargs", 104 | "__class__", 105 | "pad_token_id", 106 | "bos_token_id", 107 | "eos_token_id", 108 | "tie_word_embeddings", 109 | ]: 110 | setattr(self, key, value) 111 | self.num_heads = self.linear_num_heads 112 | -------------------------------------------------------------------------------- /scripts/linear_decay_transformer/save_log_decay.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | 7 | import xmixers # noqa 8 | 9 | 10 | def save_log_f( 11 | model, tokenizer, text, save_dir, save_name, max_length=2048, device="cuda" 12 | ): 13 | """ 14 | Run model inference and save log_f values. 
15 | 16 | Args: 17 | model: The model to extract log_f from 18 | tokenizer: Tokenizer for input processing 19 | text: Input text for inference 20 | save_dir: Directory to save log_f values 21 | save_name: Filename for saved data 22 | max_length: Maximum sequence length for tokenization 23 | device: Device to run inference on 24 | """ 25 | # Create save directory if it doesn't exist 26 | os.makedirs(save_dir, exist_ok=True) 27 | os.path.join(save_dir, f"{save_name}.npy") 28 | 29 | # Tokenize input text with truncation 30 | inputs = tokenizer( 31 | text, return_tensors="pt", truncation=True, max_length=max_length 32 | ).to(device) 33 | inputs["save_decay"] = True 34 | inputs["save_dir"] = save_dir 35 | inputs["save_name"] = save_name 36 | print(f"Input sequence length: {inputs['input_ids'].shape[1]}") 37 | 38 | # Run inference 39 | with torch.inference_mode(): 40 | model( 41 | **inputs, 42 | ) 43 | 44 | 45 | def main(): 46 | parser = argparse.ArgumentParser( 47 | description="Save log_f values from decay linear attention models" 48 | ) 49 | parser.add_argument( 50 | "--model_path", type=str, required=True, help="Path to the pretrained model" 51 | ) 52 | parser.add_argument( 53 | "--text_file", type=str, required=True, help="Path to text file for input" 54 | ) 55 | parser.add_argument( 56 | "--save_dir", type=str, required=True, help="Directory to save log_f values" 57 | ) 58 | parser.add_argument( 59 | "--save_name", 60 | type=str, 61 | required=True, 62 | help="Filename for saved data (without extension)", 63 | ) 64 | parser.add_argument( 65 | "--max_length", 66 | type=int, 67 | default=2048, 68 | help="Maximum sequence length for tokenization", 69 | ) 70 | parser.add_argument( 71 | "--dtype", 72 | type=str, 73 | default="bf16", 74 | choices=["bf16", "fp32"], 75 | help="Data type for model", 76 | ) 77 | args = parser.parse_args() 78 | 79 | # Setup device and dtype 80 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 81 | dtype_map = {"bf16": torch.bfloat16, "fp32": torch.float32} 82 | dtype = dtype_map[args.dtype] 83 | 84 | # Load model and tokenizer 85 | print(f"Loading model from {args.model_path}") 86 | model = AutoModelForCausalLM.from_pretrained(args.model_path).to(device).to(dtype) 87 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 88 | if tokenizer.pad_token is None: 89 | tokenizer.pad_token = tokenizer.eos_token 90 | 91 | model.eval() 92 | 93 | # Read input text from file 94 | print(f"Reading text from {args.text_file}") 95 | with open(args.text_file, "r", encoding="utf-8") as f: 96 | text = f.read() 97 | 98 | # Extract and save log_f values 99 | print(f"Processing text (max length: {args.max_length})") 100 | save_log_f( 101 | model, tokenizer, text, args.save_dir, args.save_name, args.max_length, device 102 | ) 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/long_conv/tno.py: -------------------------------------------------------------------------------- 1 | """ 2 | Toeplitz Neural Operator in https://arxiv.org/pdf/2305.04749.pdf 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from xmixers.modules import get_norm_fn 10 | from xmixers.ops import long_conv_1d_op 11 | from xmixers.utils import XMIXERS_DEBUG, next_power_of_2, print_module, print_params 12 | 13 | from .rpe import Rpe 14 | 15 | 16 | class Tno(nn.Module): 17 | def __init__( 18 | self, 19 | in_dim: int, 20 | 
feature_dim: int, 21 | out_dim: int, 22 | activation: str = "silu", 23 | bias: bool = False, 24 | rpe_layers: int = 3, 25 | norm_type: str = "layernorm", 26 | use_decay: bool = True, 27 | causal: bool = True, 28 | dim: int = 1, 29 | lower_bound: float = 0.99, 30 | *args, 31 | **kwargs, 32 | ) -> None: 33 | super().__init__() 34 | 35 | if XMIXERS_DEBUG: 36 | # get local varables 37 | params = locals() 38 | # print params 39 | print_params(**params) 40 | 41 | self.rpe = Rpe( 42 | in_dim=in_dim, 43 | feature_dim=feature_dim, 44 | out_dim=out_dim, 45 | activation=activation, 46 | bias=bias, 47 | rpe_layers=rpe_layers, 48 | norm_type=norm_type, 49 | ) 50 | self.norm = get_norm_fn(norm_type)(out_dim) 51 | 52 | if use_decay: 53 | self.gamma = nn.Parameter(torch.randn(1, out_dim) * 0.1, requires_grad=True) 54 | # self.lower_bound = lower_bound 55 | # gamma = 1 / (torch.arange(1, out_dim + 1)) 56 | # self.gamma = nn.Parameter(gamma.reshape(1, -1), requires_grad=True) 57 | 58 | self.use_decay = use_decay 59 | self.dim = dim 60 | self.zero = torch.empty(0) 61 | self.pos = torch.empty(0) 62 | self.neg = torch.empty(0) 63 | self.cache_size = 0 64 | self.causal = causal 65 | 66 | def extra_repr(self): 67 | return print_module(self) 68 | 69 | def get_w(self, x): 70 | n = x.shape[self.dim] 71 | m = next_power_of_2(n) 72 | if self.cache_size < n: 73 | self.cache_size = m 74 | self.zero = torch.zeros(1).unsqueeze(-1).to(x.device) 75 | self.pos = torch.arange(1, m).unsqueeze(-1).to(x.device) 76 | if not self.causal: 77 | self.neg = -torch.arange(1, m).unsqueeze(-1).flip(0).to(x.device) 78 | 79 | if self.causal: 80 | self.index = torch.cat([self.zero, self.pos[:n]], dim=0) 81 | else: 82 | self.index = torch.cat( 83 | [self.zero, self.pos[:n], self.zero, self.neg[-n:]], dim=0 84 | ) 85 | 86 | return self.index 87 | 88 | def get_gamma(self, x): 89 | n = x.shape[self.dim] 90 | # gamma = self.lower_bound + (1 - self.lower_bound) * torch.clamp(self.gamma, min=0, max=1).float() 91 | # gamma_zero = torch.exp(self.zero * torch.log(self.gamma)) 92 | # gamma_pos = torch.exp(self.pos[:n] * torch.log(self.gamma)) 93 | 94 | gamma_zero = torch.exp(self.zero * F.logsigmoid(self.gamma)) 95 | gamma_pos = torch.exp(self.pos[:n] * F.logsigmoid(self.gamma)) 96 | 97 | if self.causal: 98 | gamma = torch.cat([gamma_zero, gamma_pos], dim=0) 99 | else: 100 | gamma = torch.cat( 101 | [gamma_zero, gamma_pos, gamma_zero, gamma_pos.flip(0)], dim=0 102 | ) 103 | # print(gamma) 104 | return gamma 105 | 106 | def forward(self, x): 107 | index = self.get_w(x) 108 | w = self.rpe(index).to(torch.float32) 109 | if self.use_decay: 110 | w = self.get_gamma(x) * w 111 | y = self.norm(long_conv_1d_op(x, w, self.dim)) 112 | 113 | return y 114 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/polar_rnn/configuration_polar_rnn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ PolarRnn configuration""" 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.utils import logging 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class PolarRnnConfig(PretrainedConfig): 11 | model_type = "polar_rnn" 12 | keys_to_ignore_at_inference = ["past_key_values"] 13 | 14 | def __init__( 15 | self, 16 | pad_token_id=0, 17 | bos_token_id=1, 18 | eos_token_id=2, 19 | vocab_size=64000, 20 | use_cache=True, 21 | init_std=0.02, 22 | tie_word_embeddings=False, 23 | ########## model config 24 | ##### 
token mixer config 25 | token_mixer_type="polar_rnn", 26 | embed_dim=1024, 27 | num_heads=8, 28 | bias=False, 29 | ##### channel mixer config 30 | channel_mixer_type="glu", 31 | mid_dim=1024, 32 | channel_mixer_activation="silu", 33 | use_gate_linear=False, 34 | ##### others 35 | max_position_embeddings=1024, 36 | use_output_gate=True, 37 | norm_type="layernorm", 38 | q_activation="silu", 39 | k_activation="silu", 40 | v_activation="silu", 41 | use_gamma=True, 42 | gamma_activation="pos", 43 | use_decay=True, 44 | scalar_decay=True, 45 | qkv_norm_type=2, 46 | norm_q=False, 47 | norm_v=False, 48 | causal=True, 49 | num_layers=12, 50 | use_embed_scale=False, 51 | pad_embed_dim=True, 52 | ##### init 53 | init_type=0, 54 | token_mixer_init_type=0, 55 | rescale_type=0, 56 | channel_mixer_init_type=0, 57 | gain=0.02, 58 | fuse_norm_add=True, 59 | ce_type="xopes_flce", 60 | debug=0, 61 | use_l2_norm=False, 62 | **kwargs, 63 | ): 64 | super().__init__( 65 | pad_token_id=pad_token_id, 66 | bos_token_id=bos_token_id, 67 | eos_token_id=eos_token_id, 68 | tie_word_embeddings=tie_word_embeddings, 69 | **kwargs, 70 | ) 71 | ##### hf origin 72 | self.vocab_size = vocab_size 73 | self.use_cache = use_cache 74 | self.init_std = init_std 75 | ##### add 76 | # attention config 77 | self.token_mixer_type = token_mixer_type 78 | self.embed_dim = embed_dim 79 | self.num_heads = num_heads 80 | self.bias = bias 81 | # channel mixer config 82 | self.channel_mixer_type = channel_mixer_type 83 | self.mid_dim = mid_dim 84 | self.channel_mixer_activation = channel_mixer_activation 85 | self.use_gate_linear = use_gate_linear 86 | # others 87 | self.max_position_embeddings = max_position_embeddings 88 | self.use_output_gate = use_output_gate 89 | self.norm_type = norm_type 90 | self.q_activation = q_activation 91 | self.k_activation = k_activation 92 | self.v_activation = v_activation 93 | self.use_gamma = use_gamma 94 | self.gamma_activation = gamma_activation 95 | self.use_decay = use_decay 96 | self.scalar_decay = scalar_decay 97 | self.qkv_norm_type = qkv_norm_type 98 | self.norm_q = norm_q 99 | self.norm_v = norm_v 100 | self.causal = causal 101 | self.use_embed_scale = use_embed_scale 102 | self.num_layers = num_layers 103 | # init 104 | self.init_type = init_type 105 | self.token_mixer_init_type = token_mixer_init_type 106 | self.rescale_type = rescale_type 107 | self.channel_mixer_init_type = channel_mixer_init_type 108 | self.gain = gain 109 | self.fuse_norm_add = fuse_norm_add 110 | self.ce_type = ce_type 111 | self.debug = debug 112 | self.use_l2_norm = use_l2_norm 113 | -------------------------------------------------------------------------------- /xmixers/modules/pes/tpe.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tpe in https://arxiv.org/abs/2405.21022 3 | """ 4 | 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from einops import rearrange 11 | from transformers.cache_utils import Cache 12 | from xopes.ops import lightning_attn_func 13 | 14 | from xmixers.modules.normalizations import get_norm_fn 15 | from xmixers.utils import _initialize_weights, print_module 16 | 17 | 18 | class Tpe(nn.Module): 19 | def __init__( 20 | self, 21 | embed_dim: int, 22 | num_heads: int, 23 | bias: bool = False, 24 | layer_idx: int = 0, 25 | token_mixer_norm_type: str = "rmsnorm", 26 | token_mixer_init_type: int = 4, 27 | rescale_type: int = 2, 28 | num_layers: int = 12, 29 | init_std: float = 0.02, 30 
| gain: float = 0.01, 31 | **kwargs, 32 | ): 33 | super().__init__() 34 | self.embed_dim = embed_dim 35 | self.head_dim = embed_dim // num_heads 36 | self.bias = bias 37 | self.q = nn.Parameter(torch.randn(embed_dim)) 38 | self.k = nn.Parameter(torch.randn(embed_dim)) 39 | self.log_decay = nn.Parameter(torch.randn(num_heads)) 40 | self.norm = get_norm_fn(token_mixer_norm_type)( 41 | embed_dim, bias=bias, num_groups=num_heads 42 | ) 43 | self.o_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 44 | self.layer_idx = layer_idx 45 | 46 | self.token_mixer_init_type = token_mixer_init_type 47 | self.rescale_type = rescale_type 48 | self.num_layers = num_layers 49 | self.embed_dim = embed_dim 50 | self.init_std = init_std 51 | self.gain = gain 52 | self._init_weights() 53 | 54 | def _init_weights(self): 55 | self.apply(self._initialize_weights) 56 | 57 | def _initialize_weights(self, module): 58 | return _initialize_weights(self, module) 59 | 60 | def extra_repr(self): 61 | return print_module(self) 62 | 63 | def forward( 64 | self, 65 | x, 66 | attention_mask: Optional[torch.Tensor] = None, # (b, m) 67 | past_key_values: Optional[Cache] = None, 68 | use_cache: Optional[bool] = False, 69 | **kwargs, 70 | ): 71 | b, n, d = x.shape 72 | q, k, v = map( 73 | lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_dim), 74 | [self.q, self.k, x], 75 | ) 76 | log_decay = F.logsigmoid(self.log_decay) 77 | 78 | recurrent_state = None 79 | q_offset = 0 80 | if past_key_values is not None and len(past_key_values) > self.layer_idx: 81 | recurrent_state = past_key_values[self.layer_idx]["recurrent_state"][0] 82 | q_offset = past_key_values.get_seq_length(self.layer_idx) 83 | 84 | use_attn_mask = ( 85 | attention_mask is not None and not attention_mask.all() and (n > 1) 86 | ) 87 | 88 | # TODO: update this later 89 | # left padding 90 | if use_attn_mask: 91 | start = q_offset 92 | attention_mask_ = attention_mask[:, start:].unsqueeze(-1).unsqueeze(-1) 93 | v = v.masked_fill(attention_mask_ == 0, 0) 94 | 95 | output, recurrent_state = lightning_attn_func( 96 | q=q, 97 | k=k, 98 | v=v, 99 | ld=log_decay, 100 | initial_state=recurrent_state, 101 | decay_type="positional", 102 | ) 103 | 104 | if past_key_values is not None: 105 | past_key_values.update( 106 | recurrent_state=[recurrent_state], 107 | layer_idx=self.layer_idx, 108 | offset=n, 109 | ) 110 | 111 | # reshape 112 | output = rearrange(output, "... n h d -> ... 
n (h d)") 113 | 114 | # normalize 115 | output = self.norm(output) 116 | 117 | # outproj 118 | output = self.o_proj(output) 119 | 120 | return output, past_key_values 121 | -------------------------------------------------------------------------------- /xmixers/models/linear_transformer/tnl/configuration_tnl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Tnl: https://arxiv.org/pdf/2405.17381 3 | """ Tnl configuration""" 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | from transformers.utils import logging 7 | 8 | logger = logging.get_logger(__name__) 9 | 10 | 11 | class TnlConfig(PretrainedConfig): 12 | model_type = "tnl" 13 | keys_to_ignore_at_inference = ["past_key_values"] 14 | 15 | def __init__( 16 | self, 17 | pad_token_id=0, 18 | bos_token_id=1, 19 | eos_token_id=2, 20 | vocab_size=64000, 21 | use_cache=True, 22 | init_std=0.02, 23 | tie_word_embeddings=False, 24 | ########## model config 25 | ##### token mixer config 26 | token_mixer_type="tnl_attn", 27 | embed_dim=1024, 28 | num_heads=8, 29 | bias=False, 30 | use_lrpe=True, 31 | lrpe_type=1, 32 | base=10000, 33 | use_output_gate=True, 34 | norm_type="rmsnorm", 35 | q_activation="silu", 36 | k_activation="silu", 37 | v_activation="silu", 38 | q_norm=False, 39 | k_norm=False, 40 | v_norm=False, 41 | causal=True, 42 | use_initial_state=False, 43 | gate_act="sigmoid", 44 | gate_pos="pre", 45 | token_mixer_norm_type="rmsnorm", 46 | ##### channel mixer config 47 | channel_mixer_type="glu", 48 | mid_dim=1024, 49 | channel_mixer_activation="silu", 50 | use_gate_linear=True, 51 | ##### others 52 | max_position_embeddings=1024, 53 | num_layers=24, 54 | use_lrpe_list=[False], 55 | n_min=2, 56 | n_max=256, 57 | use_embed_scale=False, 58 | ce_type="xopes_flce", 59 | pad_embed_dim=True, 60 | ##### init 61 | init_type=1, 62 | token_mixer_init_type=4, 63 | rescale_type=2, 64 | gain=0.01, 65 | channel_mixer_init_type=0, 66 | **kwargs, 67 | ): 68 | super().__init__( 69 | pad_token_id=pad_token_id, 70 | bos_token_id=bos_token_id, 71 | eos_token_id=eos_token_id, 72 | tie_word_embeddings=tie_word_embeddings, 73 | **kwargs, 74 | ) 75 | ##### hf origin 76 | self.vocab_size = vocab_size 77 | self.use_cache = use_cache 78 | self.init_std = init_std 79 | ##### add 80 | # token mixer config 81 | self.embed_dim = embed_dim 82 | self.num_heads = num_heads 83 | self.bias = bias 84 | self.use_lrpe = use_lrpe 85 | self.lrpe_type = lrpe_type 86 | self.base = base 87 | self.use_output_gate = use_output_gate 88 | self.norm_type = norm_type 89 | self.q_activation = q_activation 90 | self.k_activation = k_activation 91 | self.v_activation = v_activation 92 | self.q_norm = q_norm 93 | self.k_norm = k_norm 94 | self.v_norm = v_norm 95 | self.causal = causal 96 | self.use_initial_state = use_initial_state 97 | self.token_mixer_type = token_mixer_type 98 | self.gate_act = gate_act 99 | self.gate_pos = gate_pos 100 | self.token_mixer_norm_type = token_mixer_norm_type 101 | # channel mixer config 102 | self.channel_mixer_type = channel_mixer_type 103 | self.mid_dim = mid_dim 104 | self.channel_mixer_activation = channel_mixer_activation 105 | self.use_gate_linear = use_gate_linear 106 | # others 107 | self.max_position_embeddings = max_position_embeddings 108 | self.num_layers = num_layers 109 | self.use_lrpe_list = use_lrpe_list 110 | self.n_min = n_min 111 | self.n_max = n_max 112 | self.use_embed_scale = use_embed_scale 113 | self.ce_type = ce_type 114 | self.token_mixer_init_type = 
token_mixer_init_type 115 | self.channel_mixer_init_type = channel_mixer_init_type 116 | self.init_type = init_type 117 | self.rescale_type = rescale_type 118 | self.gain = gain 119 | -------------------------------------------------------------------------------- /tests/train_test/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader, TensorDataset 5 | from tqdm import tqdm 6 | 7 | from xmixers.modules.token_mixers.deep_memory.deep_memory_unit import DeepMemoryUnit 8 | 9 | 10 | def create_test_training_loop(): 11 | # Set device 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | # Model parameters 15 | embed_dim = 256 16 | num_heads = 8 17 | batch_size = 4 18 | seq_len = 256 19 | num_epochs = 100 20 | 21 | # Initialize model 22 | model = DeepMemoryUnit( 23 | embed_dim=embed_dim, 24 | num_heads=num_heads, 25 | ).to(device) 26 | 27 | print(model) 28 | 29 | # Create synthetic data 30 | # Input data: random tensors 31 | X = torch.randn(batch_size, seq_len, embed_dim) 32 | # Target data: for testing, we can use the same data as target (autoencoding task) 33 | # or create some transformed data 34 | Y = X + 0.1 * torch.randn_like(X) # Add some noise as target 35 | 36 | # Create dataset and dataloader 37 | dataset = TensorDataset(X, Y) 38 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 39 | 40 | # Loss function and optimizer 41 | criterion = nn.MSELoss() 42 | optimizer = optim.Adam(model.parameters(), lr=1e-4) 43 | 44 | # Training loop 45 | model.train() 46 | for epoch in tqdm(range(num_epochs)): 47 | total_loss = 0.0 48 | num_batches = 0 49 | 50 | for batch_idx, (inputs, targets) in enumerate(dataloader): 51 | inputs = inputs.to(device) 52 | targets = targets.to(device) 53 | 54 | # Zero gradients 55 | optimizer.zero_grad() 56 | 57 | # Forward pass 58 | outputs, _ = model(inputs) 59 | 60 | # Calculate loss 61 | loss = criterion(outputs, targets) 62 | 63 | # Backward pass 64 | loss.backward() 65 | 66 | # Gradient clipping (optional) 67 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) 68 | 69 | # Update parameters 70 | optimizer.step() 71 | 72 | total_loss += loss.item() 73 | num_batches += 1 74 | 75 | # Print progress 76 | if batch_idx % 5 == 0: 77 | print( 78 | f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}], Loss: {loss.item():.6f}" 79 | ) 80 | 81 | avg_loss = total_loss / num_batches if num_batches > 0 else 0 82 | print(f"Epoch [{epoch+1}/{num_epochs}] completed. 
Average Loss: {avg_loss:.6f}") 83 | print("-" * 50) 84 | 85 | print("Training completed!") 86 | 87 | # Evaluation mode 88 | model.eval() 89 | with torch.no_grad(): 90 | test_input = torch.randn(1, seq_len, embed_dim).to(device) 91 | test_output, _ = model(test_input) 92 | print(f"Test input shape: {test_input.shape}") 93 | print(f"Test output shape: {test_output.shape}") 94 | 95 | 96 | def test_single_forward_pass(): 97 | """Test single forward pass""" 98 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 99 | 100 | embed_dim = 256 101 | num_heads = 8 102 | seq_len = 256 103 | 104 | model = DeepMemoryUnit( 105 | embed_dim=embed_dim, 106 | num_heads=num_heads, 107 | ).to(device) 108 | 109 | print(model) 110 | 111 | # Test input 112 | x = torch.randn(2, seq_len, embed_dim).to(device) 113 | 114 | model.eval() 115 | with torch.no_grad(): 116 | output, _ = model(x) 117 | print("✓ Forward pass successful!") 118 | print(f"Input shape: {x.shape}") 119 | print(f"Output shape: {output.shape}") 120 | print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") 121 | 122 | 123 | if __name__ == "__main__": 124 | print("Testing single forward pass...") 125 | test_single_forward_pass() 126 | print("\n" + "=" * 50 + "\n") 127 | print("Starting training loop...") 128 | create_test_training_loop() 129 | -------------------------------------------------------------------------------- /xmixers/modules/normalizations/rms_norm.py: -------------------------------------------------------------------------------- 1 | """ 2 | SimpleRMSNorm in https://arxiv.org/abs/2307.14995 3 | RMSNorm in https://arxiv.org/pdf/1910.07467.pdf 4 | GatedRMSNorm in https://arxiv.org/pdf/2104.07012.pdf 5 | 6 | Reference: 7 | https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py 8 | https://github.com/bzhangGo/zero/blob/master/modules/rela.py 9 | """ 10 | import torch 11 | import torch.nn as nn 12 | from fla.modules import RMSNorm 13 | 14 | from xmixers.utils import XMIXERS_DEBUG, print_module, print_params 15 | 16 | from .utils import NormOp 17 | 18 | 19 | class RMSNorm(torch.nn.Module): 20 | def __init__(self, dim: int, eps: float = 1e-5, **kwargs): 21 | super().__init__() 22 | if XMIXERS_DEBUG: 23 | # get local variables 24 | params = locals() 25 | # print params 26 | print_params(**params) 27 | 28 | self.eps = eps 29 | self.dim = dim 30 | self.weight = nn.Parameter(torch.ones(dim)) 31 | self.op = NormOp(norm_type="rmsnorm") 32 | 33 | self._init_weights() 34 | 35 | def _init_weights(self): 36 | nn.init.ones_(self.weight) 37 | 38 | def extra_repr(self) -> str: 39 | return f"dim={self.dim}, eps={self.eps}" 40 | 41 | def forward(self, x, residual=None, return_residual=False): 42 | return self.op( 43 | x, 44 | self.weight, 45 | None, 46 | residual, 47 | self.dim, 48 | self.eps, 49 | False, 50 | 1, 51 | return_residual, 52 | ) 53 | 54 | 55 | 56 | 57 | class RMSNormFusedGate(torch.nn.Module): 58 | def __init__( 59 | self, dim: int, eps: float = 1e-5, gate_act="sigmoid", gate_pos="pre", **kwargs 60 | ): 61 | super().__init__() 62 | if XMIXERS_DEBUG: 63 | # get local variables 64 | params = locals() 65 | # print params 66 | print_params(**params) 67 | 68 | self.eps = eps 69 | self.dim = dim 70 | self.weight = nn.Parameter(torch.ones(dim)) 71 | self.gate_act = gate_act 72 | self.gate_pos = gate_pos 73 | self.op = NormOp(norm_type="rmsnorm") 74 | 75 | self._init_weights() 76 | 77 | def _init_weights(self): 78 | nn.init.ones_(self.weight) 79 | 80 | def extra_repr(self) -> str: 81 | return 
f"dim={self.dim}, eps={self.eps}, gate_act={self.gate_act}, gate_pos={self.gate_pos}" 82 | 83 | def forward(self, x, gate): 84 | return self.op( 85 | x, 86 | self.weight, 87 | None, 88 | None, 89 | self.dim, 90 | self.eps, 91 | False, 92 | 1, 93 | False, 94 | gate, 95 | self.gate_act, 96 | self.gate_pos, 97 | ) 98 | 99 | return o 100 | 101 | 102 | class GatedRMSNorm(nn.Module): 103 | def __init__(self, d: int, eps: float = 1e-5, bias: bool = False, **kwargs) -> None: 104 | super().__init__() 105 | 106 | if XMIXERS_DEBUG: 107 | # get local varables 108 | params = locals() 109 | # print params 110 | print_params(**params) 111 | 112 | self.eps = eps 113 | self.d = d 114 | self.bias = bias 115 | 116 | self.scale = nn.Parameter(torch.ones(d)) 117 | self.register_parameter("scale", self.scale) 118 | self.gate = nn.Parameter(torch.ones(d)) 119 | self.register_parameter("gate", self.scale) 120 | 121 | self._init_weights() 122 | 123 | def _init_weights(self): 124 | nn.init.ones_(self.scale) 125 | nn.init.ones_(self.gate) 126 | 127 | def extra_repr(self) -> str: 128 | return print_module(self) 129 | 130 | def forward(self, x): 131 | # TODO: add fusion here 132 | norm_x = x.norm(2, dim=-1, keepdim=True) 133 | d_x = self.d 134 | 135 | rms_x = norm_x * d_x ** (-1.0 / 2) 136 | x_normed = x / (rms_x + self.eps) 137 | 138 | return self.scale * x_normed * torch.sigmoid(self.gate * x) 139 | -------------------------------------------------------------------------------- /xmixers/modules/channel_mixers/alu.py: -------------------------------------------------------------------------------- 1 | # Attention Linear Unit: https://arxiv.org/pdf/1706.03762v3 2 | 3 | import torch 4 | import torch.nn as nn 5 | from einops import rearrange 6 | 7 | from xmixers.modules.activations import get_activation_fn 8 | from xmixers.utils import XMIXERS_DEBUG, print_module, print_params 9 | 10 | 11 | class ALU(nn.Module): 12 | def __init__( 13 | self, 14 | embed_dim: int, 15 | qk_dim: int, 16 | v_dim: int, 17 | mem_dim: int, 18 | num_heads: int, 19 | activation: str, 20 | bias: bool = False, 21 | use_scale: int = 0, 22 | use_output_gate: bool = False, 23 | output_gate_activation: str = "silu", 24 | use_low_rank_output_gate: bool = False, 25 | channel_mixer_init_type: int = 0, 26 | ) -> None: 27 | super().__init__() 28 | 29 | if XMIXERS_DEBUG: 30 | # get local varables 31 | params = locals() 32 | # print params 33 | print_params(**params) 34 | 35 | self.q_proj = nn.Linear(embed_dim, qk_dim, bias=bias) 36 | self.k_weight = nn.Parameter(torch.randn(mem_dim, qk_dim), requires_grad=True) 37 | self.v_weight = nn.Parameter(torch.randn(mem_dim, v_dim), requires_grad=True) 38 | self.out_proj = nn.Linear(v_dim, embed_dim, bias=bias) 39 | self.use_output_gate = use_output_gate 40 | if self.use_output_gate: 41 | if use_low_rank_output_gate: 42 | mid_dim = embed_dim // num_heads 43 | self.output_gate = nn.Sequential( 44 | nn.Linear(embed_dim, mid_dim, bias=bias), 45 | nn.Linear(mid_dim, v_dim, bias=bias), 46 | ) 47 | else: 48 | self.output_gate = nn.Linear(embed_dim, v_dim, bias=bias) 49 | self.output_gate_act = get_activation_fn(output_gate_activation) 50 | 51 | self.num_heads = num_heads 52 | self.use_scale = use_scale 53 | self.act = get_activation_fn(activation) 54 | self.channel_mixer_init_type = channel_mixer_init_type 55 | 56 | self._initialize_weights() 57 | 58 | def _initialize_weights(self): 59 | if self.channel_mixer_init_type == 0: 60 | nn.init.normal_(self.k_weight, mean=0.0, std=0.05) 61 | nn.init.normal_(self.v_weight, 
mean=0.0, std=0.05) 62 | elif self.channel_mixer_init_type == 1: # fla init 63 | nn.init.xavier_uniform_(self.k_weight, gain=2**-2.5) 64 | nn.init.xavier_uniform_(self.v_weight, gain=2**-2.5) 65 | elif self.channel_mixer_init_type == 2: # fairseq init 66 | nn.init.xavier_uniform_(self.k_weight, gain=2**-0.5) 67 | nn.init.xavier_uniform_(self.v_weight, gain=2**-0.5) 68 | elif self.channel_mixer_init_type == 3: 69 | nn.init.normal_(self.k_weight, mean=0.0, std=0.2) 70 | nn.init.normal_(self.v_weight, mean=0.0, std=0.2) 71 | 72 | def extra_repr(self): 73 | return print_module(self) 74 | 75 | def forward(self, x): 76 | q = self.q_proj(x) 77 | k = self.k_weight 78 | v = self.v_weight 79 | q, k, v = map( 80 | lambda x: rearrange(x, "... n (h d) -> ... h n d", h=self.num_heads), 81 | [q, k, v], 82 | ) 83 | 84 | if self.use_scale == 1: 85 | scale = q.shape[-1] ** 0.5 86 | elif self.use_scale == 2: 87 | scale = x.shape[-1] ** 0.5 88 | else: 89 | scale = 1 90 | 91 | # v1 92 | # energy = self.act(torch.einsum("... n d, ... m d -> ... n m", q, k) * scale) 93 | 94 | # v2 95 | k = k * scale 96 | v = v * scale 97 | energy = self.act(torch.einsum("... n d, ... m d -> ... n m", q, k)) 98 | output = torch.einsum("... n m, ... m d -> ... n d", energy, v) 99 | 100 | # reshape 101 | output = rearrange(output, "... h n d -> ... n (h d)") 102 | 103 | if self.use_output_gate: 104 | output_gate = self.output_gate_act(self.output_gate(x)) 105 | output = output * output_gate 106 | 107 | # outproj 108 | output = self.out_proj(output) 109 | 110 | return output 111 | -------------------------------------------------------------------------------- /xmixers/modules/pes/lrpe.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lrpe in https://openreview.net/forum?id=xoLyps2qWc 3 | """ 4 | 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | from xopes.ops import lrpe_fn 10 | 11 | from xmixers.utils import XMIXERS_DEBUG, logging_info, print_params 12 | 13 | 14 | class Lrpe(nn.Module): 15 | def __init__( 16 | self, 17 | head_dim: int = 128, 18 | num_heads: int = 8, 19 | lrpe_type: int = 1, 20 | base: int = 10000, 21 | act: str = "none", 22 | act_dim: Optional[int] = None, 23 | ): 24 | """ 25 | lrpe_type: 1 for standard rope, 2 for mix rope (rope half head dim), 3 for complex version(cosformer style) 26 | """ 27 | super().__init__() 28 | if XMIXERS_DEBUG: 29 | # get local varables 30 | params = locals() 31 | # print params 32 | print_params(**params) 33 | 34 | self.head_dim = head_dim 35 | self.num_heads = num_heads 36 | self.base = base 37 | self.act = act 38 | self.act_dim = act_dim 39 | 40 | d = self.head_dim 41 | if lrpe_type in [1, 2, 3]: 42 | self.lrpe_type = "rotate" 43 | elif lrpe_type in [4, 5, 6]: 44 | self.lrpe_type = "cosine" 45 | else: 46 | raise ValueError(f"lrpe_type: {lrpe_type} has not been support!") 47 | 48 | # init parameters 49 | self.register_buffer("theta", torch.empty(0), persistent=False) 50 | self.lrpe_type_ = lrpe_type 51 | self.base = base 52 | self.d = d 53 | self._init_weights() 54 | 55 | def _init_weights(self): 56 | lrpe_type = self.lrpe_type_ 57 | base = self.base 58 | d = self.d 59 | if lrpe_type == 1: 60 | logging_info("lrpe rotate, i.e, rope") 61 | theta = base ** ( 62 | -2 / d * torch.arange(d // 2, dtype=torch.int64) 63 | ).float().reshape(1, -1) 64 | elif lrpe_type == 2: # result much worse than 3 65 | logging_info("lrpe mix rotate, rotate half head dim, use low freq") 66 | theta = ( 67 | base 68 | ** (-2 / d * 
torch.arange(d // 2, dtype=torch.int64)) 69 | .float() 70 | .reshape(1, -1)[:, d // 4 :] 71 | ) 72 | theta = torch.cat([torch.zeros_like(theta), theta], dim=-1) 73 | elif lrpe_type == 3: 74 | logging_info("lrpe mix rotate, rotate half head dim, use high freq") 75 | theta = ( 76 | base 77 | ** (-2 / d * torch.arange(d // 2, dtype=torch.int64)) 78 | .float() 79 | .reshape(1, -1)[:, : d // 4] 80 | ) 81 | theta = torch.cat([theta, torch.zeros_like(theta)], dim=-1) 82 | elif lrpe_type == 4: 83 | logging_info("lrpe cosine") 84 | theta = base ** ( 85 | -2 / d * torch.arange(d, dtype=torch.int64) 86 | ).float().reshape(1, -1) 87 | elif lrpe_type == 5: # result much worse than 6 88 | logging_info("lrpe cosine, cosine half head dim, use low freq") 89 | theta = ( 90 | base 91 | ** (-2 / d * torch.arange(d, dtype=torch.int64)) 92 | .float() 93 | .reshape(1, -1)[:, d // 2 :] 94 | ) 95 | theta = torch.cat([torch.zeros_like(theta), theta], dim=-1) 96 | elif lrpe_type == 6: 97 | logging_info("lrpe cosine, cosine half head dim, use high freq") 98 | theta = ( 99 | base 100 | ** (-2 / d * torch.arange(d, dtype=torch.int64)) 101 | .float() 102 | .reshape(1, -1)[:, : d // 2] 103 | ) 104 | theta = torch.cat( 105 | [ 106 | theta, 107 | torch.zeros_like(theta), 108 | ], 109 | dim=-1, 110 | ) 111 | self.theta = theta.to(self.theta.device) 112 | self._is_hf_initialized = True 113 | 114 | def forward(self, x, offset=0): 115 | return lrpe_fn( 116 | x=x, 117 | theta=self.theta, 118 | offset=offset, 119 | act=self.act, 120 | dim=self.act_dim, 121 | lrpe_type=self.lrpe_type, 122 | ) 123 | -------------------------------------------------------------------------------- /xmixers/modules/normalizations/group_rms_norm.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | from torch import Size 7 | 8 | _shape_t = Union[int, List[int], Size] 9 | 10 | 11 | from .utils import NormOp 12 | 13 | 14 | class GroupRMSNorm(torch.nn.Module): 15 | def __init__( 16 | self, 17 | num_channels: int, 18 | num_groups: int, 19 | eps: float = 1e-5, 20 | affine: bool = True, 21 | bias: bool = True, 22 | device=None, 23 | dtype=None, 24 | **kwargs 25 | ) -> None: 26 | factory_kwargs = {"device": device, "dtype": dtype} 27 | super().__init__() 28 | if num_channels % num_groups != 0: 29 | raise ValueError("num_channels must be divisible by num_groups") 30 | 31 | self.num_groups = num_groups 32 | self.num_channels = num_channels 33 | self.eps = eps 34 | self.affine = affine 35 | if self.affine: 36 | self.weight = nn.Parameter(torch.empty(num_channels, **factory_kwargs)) 37 | if bias: 38 | self.bias = nn.Parameter(torch.empty(num_channels, **factory_kwargs)) 39 | else: 40 | self.register_parameter("bias", None) 41 | else: 42 | self.register_parameter("weight", None) 43 | self.register_parameter("bias", None) 44 | 45 | self.op = NormOp(norm_type="grouprmsnorm") 46 | 47 | self._init_weights() 48 | 49 | def _init_weights(self) -> None: 50 | if self.affine: 51 | init.ones_(self.weight) 52 | if self.bias is not None: 53 | init.zeros_(self.bias) 54 | 55 | def forward(self, x, residual=None, return_residual=False): 56 | return self.op( 57 | x, 58 | self.weight, 59 | None, 60 | None, 61 | self.num_channels, 62 | self.eps, 63 | False, 64 | self.num_groups, 65 | False, 66 | ) 67 | 68 | def extra_repr(self) -> str: 69 | return "num_groups={num_groups}, num_channels={num_channels}, eps={eps}, 
affine={affine}".format( 70 | **self.__dict__ 71 | ) 72 | 73 | 74 | class GroupRMSNormFusedGate(torch.nn.Module): 75 | def __init__( 76 | self, 77 | num_channels: int, 78 | num_groups: int, 79 | eps: float = 1e-5, 80 | affine: bool = True, 81 | bias: bool = True, 82 | device=None, 83 | dtype=None, 84 | gate_act="sigmoid", 85 | gate_pos="pre", 86 | **kwargs 87 | ) -> None: 88 | factory_kwargs = {"device": device, "dtype": dtype} 89 | super().__init__() 90 | if num_channels % num_groups != 0: 91 | raise ValueError("num_channels must be divisible by num_groups") 92 | 93 | self.num_groups = num_groups 94 | self.num_channels = num_channels 95 | self.eps = eps 96 | self.affine = affine 97 | if self.affine: 98 | self.weight = nn.Parameter(torch.empty(num_channels, **factory_kwargs)) 99 | if bias: 100 | self.bias = nn.Parameter(torch.empty(num_channels, **factory_kwargs)) 101 | else: 102 | self.register_parameter("bias", None) 103 | else: 104 | self.register_parameter("weight", None) 105 | self.register_parameter("bias", None) 106 | self.gate_act = gate_act 107 | self.gate_pos = gate_pos 108 | self.op = NormOp(norm_type="grouprmsnorm") 109 | 110 | self._init_weights() 111 | 112 | def _init_weights(self) -> None: 113 | if self.affine: 114 | init.ones_(self.weight) 115 | if self.bias is not None: 116 | init.zeros_(self.bias) 117 | 118 | def forward(self, x, gate): 119 | return self.op( 120 | x, 121 | self.weight, 122 | None, 123 | None, 124 | self.num_channels, 125 | self.eps, 126 | False, 127 | self.num_groups, 128 | False, 129 | gate, 130 | self.gate_act, 131 | self.gate_pos, 132 | ) 133 | 134 | def extra_repr(self) -> str: 135 | return "num_groups={num_groups}, num_channels={num_channels}, eps={eps}, affine={affine}".format( 136 | **self.__dict__ 137 | ) 138 | -------------------------------------------------------------------------------- /xmixers/modules/token_mixers/deep_memory/fast_weight/fast_weight_glu.py: -------------------------------------------------------------------------------- 1 | # GLU: https://arxiv.org/pdf/2002.05202.pdf 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import repeat 7 | 8 | from xmixers.modules.activations import get_activation_fn 9 | from xmixers.utils import XMIXERS_DEBUG, print_module, print_params 10 | 11 | 12 | class FastWeightGLU(nn.Module): 13 | def __init__( 14 | self, 15 | num_heads: int, 16 | fw_embed_dim: int, 17 | fw_mid_dim: int, 18 | fw_activation: str = "silu", 19 | bias: bool = False, 20 | ) -> None: 21 | super().__init__() 22 | 23 | if XMIXERS_DEBUG: 24 | # get local varables 25 | params = locals() 26 | # print params 27 | print_params(**params) 28 | 29 | self.w1 = nn.Parameter(torch.zeros(num_heads, fw_embed_dim, fw_mid_dim)) 30 | self.w2 = nn.Parameter(torch.zeros(num_heads, fw_embed_dim, fw_mid_dim)) 31 | self.w3 = nn.Parameter(torch.zeros(num_heads, fw_mid_dim, fw_embed_dim)) 32 | self.act = get_activation_fn(fw_activation) 33 | 34 | self._init_weights() 35 | 36 | def extra_repr(self): 37 | return print_module(self) 38 | 39 | def _init_weights(self, gain=0.01): 40 | if getattr(self, "_is_hf_initialized", False): 41 | return 42 | 43 | nn.init.xavier_uniform_(self.w1, gain=gain) 44 | nn.init.xavier_uniform_(self.w2, gain=gain) 45 | nn.init.xavier_uniform_(self.w3, gain=gain) 46 | 47 | self._is_hf_initialized = True 48 | 49 | def init_fast_weight(self, b): 50 | with torch.enable_grad(): 51 | self.w1.requires_grad = True 52 | self.w2.requires_grad = True 53 | self.w3.requires_grad = True 54 | 55 
| return { 56 | "w1": repeat(self.w1, "h d e -> b h d e", b=b).contiguous(), 57 | "w2": repeat(self.w2, "h d e -> b h d e", b=b).contiguous(), 58 | "w3": repeat(self.w3, "h e d -> b h e d", b=b).contiguous(), 59 | } 60 | 61 | def forward(self, x, fast_weight): 62 | w1 = fast_weight["w1"] 63 | w2 = fast_weight["w2"] 64 | w3 = fast_weight["w3"] 65 | 66 | x1 = self.act(torch.einsum("b h d e, b n h d -> b n h e", w1, x)) 67 | x2 = torch.einsum("b h d e, b n h d -> b n h e", w2, x) 68 | output = torch.einsum("b h e d, b n h e -> b n h d", w3, x1 * x2) 69 | 70 | return output 71 | 72 | 73 | class FastWeightHpGLU(nn.Module): 74 | def __init__( 75 | self, 76 | embed_dim: int, 77 | num_heads: int, 78 | use_lr: bool = True, 79 | use_wd: bool = False, 80 | use_momentum: bool = False, 81 | bias: bool = False, 82 | ) -> None: 83 | super().__init__() 84 | 85 | if XMIXERS_DEBUG: 86 | # get local varables 87 | params = locals() 88 | # print params 89 | print_params(**params) 90 | 91 | self.use_lr = use_lr 92 | self.use_wd = use_wd 93 | self.use_momentum = use_momentum 94 | 95 | if self.use_lr: 96 | self.w1_lr = nn.Linear(embed_dim, num_heads, bias=bias) 97 | self.w2_lr = nn.Linear(embed_dim, num_heads, bias=bias) 98 | self.w3_lr = nn.Linear(embed_dim, num_heads, bias=bias) 99 | 100 | if self.use_wd: 101 | self.w1_wd = nn.Linear(embed_dim, num_heads, bias=bias) 102 | self.w2_wd = nn.Linear(embed_dim, num_heads, bias=bias) 103 | self.w3_wd = nn.Linear(embed_dim, num_heads, bias=bias) 104 | 105 | if self.use_momentum: 106 | self.w1_momentum = nn.Linear(embed_dim, num_heads, bias=bias) 107 | self.w2_momentum = nn.Linear(embed_dim, num_heads, bias=bias) 108 | self.w3_momentum = nn.Linear(embed_dim, num_heads, bias=bias) 109 | 110 | def forward(self, x): 111 | lr_dict = {} 112 | if self.use_lr: 113 | lr_dict["w1"] = F.sigmoid(self.w1_lr(x)) 114 | lr_dict["w2"] = F.sigmoid(self.w2_lr(x)) 115 | lr_dict["w3"] = F.sigmoid(self.w3_lr(x)) 116 | 117 | wd_dict = {} 118 | if self.use_wd: 119 | wd_dict["w1"] = F.sigmoid(self.w1_wd(x)) 120 | wd_dict["w2"] = F.sigmoid(self.w2_wd(x)) 121 | wd_dict["w3"] = F.sigmoid(self.w3_wd(x)) 122 | 123 | momentum_dict = {} 124 | if self.use_momentum: 125 | momentum_dict["w1"] = F.sigmoid(self.w1_momentum(x)) 126 | momentum_dict["w2"] = F.sigmoid(self.w2_momentum(x)) 127 | momentum_dict["w3"] = F.sigmoid(self.w3_momentum(x)) 128 | 129 | return {"lr_dict": lr_dict, "wd_dict": wd_dict, "momentum_dict": momentum_dict} 130 | --------------------------------------------------------------------------------
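A minimal usage sketch for the FastWeightGLU module above, assuming the package import path implied by the repository tree; the sizes are illustrative, not taken from any config in this repository. init_fast_weight(b) expands the shared per-head weights into per-batch fast weights, and forward expects an input of shape (batch, seq, num_heads, fw_embed_dim).

import torch

from xmixers.modules.token_mixers.deep_memory.fast_weight.fast_weight_glu import (
    FastWeightGLU,
)

# Illustrative sizes: fw_embed_dim is the per-head feature size consumed by the fast weights.
glu = FastWeightGLU(num_heads=8, fw_embed_dim=64, fw_mid_dim=128)

# Per-batch copies of w1 / w2 / w3, e.g. fast_weight["w1"].shape == (2, 8, 64, 128).
fast_weight = glu.init_fast_weight(b=2)

# (batch, seq, heads, fw_embed_dim) in, same shape out.
x = torch.randn(2, 16, 8, 64)
y = glu(x, fast_weight)
print(y.shape)  # torch.Size([2, 16, 8, 64])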