├── .github ├── license-check │ ├── config.json │ ├── header_c.txt │ └── header_python.txt └── workflows │ ├── flake8.yml │ ├── interrogate.yml │ ├── license-header-check.yml │ ├── pre-commit.yml │ ├── pylint.yml │ ├── pytest.yml │ └── sphinx.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── LICENSE ├── NOTICE ├── README.md ├── configs ├── __init__.py ├── checkpointing │ └── model_checkpointing.yaml ├── config.yaml ├── data │ ├── bigcode_stack_arrayrecord.yaml │ ├── bigcode_stack_arrayrecord_train.yaml │ ├── bigcode_stack_snapshot_arrayrecord.yaml │ ├── bigcode_stack_snapshot_arrayrecord_train.yaml │ ├── dclm_arrayrecord.yaml │ ├── dclm_arrayrecord_eval.yaml │ ├── dclm_arrayrecord_eval_preprocessed.yaml │ ├── dclm_arrayrecord_train.yaml │ ├── default_arrayrecord_eval.yaml │ ├── default_arrayrecord_train.yaml │ ├── default_huggingface_ds_eval.yaml │ ├── default_huggingface_ds_train.yaml │ ├── dolmino_mix_arrayrecord.yaml │ ├── dolmino_mix_arrayrecord_train.yaml │ ├── fineweb_edu_arrayrecord.yaml │ ├── fineweb_edu_arrayrecord_train.yaml │ ├── large_sft_datasets_arrayrecord.yaml │ ├── large_sft_datasets_arrayrecord_train.yaml │ ├── large_sft_datasets_pp2_arrayrecord.yaml │ ├── large_sft_datasets_pp2_arrayrecord_train.yaml │ ├── open_web_math_arrayrecord.yaml │ ├── open_web_math_arrayrecord_train.yaml │ ├── proofpile2_arrayrecord.yaml │ ├── proofpile2_arrayrecord_train.yaml │ ├── slimpajama_627B_arrayrecord.yaml │ ├── slimpajama_627B_arrayrecord_eval.yaml │ ├── slimpajama_627B_arrayrecord_eval_preprocessed.yaml │ ├── slimpajama_627B_arrayrecord_eval_preprocessed_gpt2.yaml │ ├── slimpajama_627B_arrayrecord_train.yaml │ ├── slimpajama_627B_huggingface_ds.yaml │ ├── slimpajama_627B_huggingface_ds_eval.yaml │ ├── slimpajama_627B_huggingface_ds_train.yaml │ ├── slimpajama_6B_arrayrecord.yaml │ ├── slimpajama_6B_arrayrecord_eval.yaml │ ├── slimpajama_6B_arrayrecord_eval_preprocessed.yaml │ ├── slimpajama_6B_arrayrecord_eval_preprocessed_gpt2.yaml │ ├── slimpajama_6B_arrayrecord_train.yaml │ ├── slimpajama_6B_huggingface_ds.yaml │ ├── slimpajama_6B_huggingface_ds_eval.yaml │ ├── slimpajama_6B_huggingface_ds_train.yaml │ ├── small_sft_datasets_arrayrecord.yaml │ ├── small_sft_datasets_arrayrecord_train.yaml │ ├── small_sft_datasets_extended_arrayrecord.yaml │ ├── small_sft_datasets_extended_arrayrecord_train.yaml │ ├── smol_cosmopedia_arrayrecord.yaml │ ├── smol_cosmopedia_arrayrecord_train.yaml │ ├── smol_fineweb_edu_arrayrecord.yaml │ ├── smol_fineweb_edu_arrayrecord_train.yaml │ ├── smoltalk_arrayrecord.yaml │ ├── smoltalk_arrayrecord_train.yaml │ ├── smoltalk_magpie_arrayrecord.yaml │ ├── smoltalk_magpie_arrayrecord_train.yaml │ ├── smoltalk_metamathqa_arrayrecord.yaml │ ├── smoltalk_metamathqa_arrayrecord_train.yaml │ ├── smoltalk_numina_arrayrecord.yaml │ ├── smoltalk_numina_arrayrecord_train.yaml │ ├── smoltalk_openhermes_arrayrecord.yaml │ ├── smoltalk_openhermes_arrayrecord_train.yaml │ ├── synthetic.yaml │ ├── synthetic_eval.yaml │ ├── synthetic_train.yaml │ ├── zyda2.yaml │ ├── zyda2_dolmacc_arrayrecord.yaml │ ├── zyda2_dolmacc_arrayrecord_train.yaml │ ├── zyda2_zyda_arrayrecord.yaml │ └── zyda2_zyda_arrayrecord_train.yaml ├── experiment │ ├── benchmark_mLSTMv1_7B.yaml │ ├── data_for_unit_test.yaml │ ├── synthetic_experiment.yaml │ ├── synthetic_experiment_slurm.yaml │ ├── synthetic_experiment_slurm_sweep.yaml │ ├── tiny_experiment_for_unit_testing.yaml │ ├── train_llama1.3B_dclm.yaml │ ├── train_llama1.3B_slimpajama627b.yaml │ ├── train_llama165M_slimpajama627b.yaml │ ├── train_llama7B_dclm.yaml │ ├── train_llama7B_dclm_cosine.yaml │ ├── train_mLSTM1.3B_dclm_cosine.yaml │ ├── train_mLSTM1.3B_slimpajama627b.yaml │ ├── train_mLSTM120M_slimpajama627b.yaml │ ├── train_mLSTM165M_slimpajama627b.yaml │ ├── train_mLSTM165M_slimpajama6b.yaml │ ├── train_mLSTM7B_slimpajama627b.yaml │ ├── train_mLSTM_small_slimpajama6b_verify_slurm.yaml │ ├── train_mLSTMv1_1.3B_dclm.yaml │ ├── train_mLSTMv1_1.3B_dclm_cosine.yaml │ ├── train_mLSTMv1_1.3B_dclm_long_run.yaml │ ├── train_mLSTMv1_1.3B_fineweb_edu.yaml │ ├── train_mLSTMv1_1.3B_slimpajama627b.yaml │ ├── train_mLSTMv1_1.3B_zyda2.yaml │ ├── train_mLSTMv1_165M_slimpajama627b.yaml │ ├── train_mLSTMv1_2.7B_dclm_cosine.yaml │ ├── train_mLSTMv1_7B_dclm.yaml │ ├── train_mLSTMv1_7B_dclm_cosine.yaml │ ├── train_mLSTMv1_7B_dclm_cosine_igate_init.yaml │ ├── train_mLSTMv1_7B_dclm_cosine_no_softcap.yaml │ ├── train_mLSTMv1_7B_dclm_long_run.yaml │ ├── train_mLSTMv1_7B_dclm_long_run_finetune.yaml │ ├── train_mLSTMv1_7B_dclm_long_run_wu8k.yaml │ └── train_mLSTMv1_7B_slimpajama627b.yaml ├── hydra │ └── launcher │ │ └── slurm_launcher.yaml ├── logger │ └── default_logger.yaml ├── lr_monitor │ └── lr_monitor_config.yaml ├── model │ ├── llama1.3B.yaml │ ├── llama165M.yaml │ ├── llama7B.yaml │ ├── llama_default.yaml │ ├── mLSTM1.3B.yaml │ ├── mLSTM120M.yaml │ ├── mLSTM165M.yaml │ ├── mLSTM7B.yaml │ ├── mLSTM_default.yaml │ ├── mLSTMv1_1.3B.yaml │ ├── mLSTMv1_165M.yaml │ ├── mLSTMv1_2.7B.yaml │ ├── mLSTMv1_7B.yaml │ └── mLSTMv1_default.yaml ├── optimizer │ ├── adamw.yaml │ └── sgd.yaml ├── parallel │ ├── default.yaml │ ├── llama1.3B.yaml │ ├── llama165M.yaml │ ├── llama7B.yaml │ ├── mLSTM1.3B.yaml │ ├── mLSTM120M.yaml │ ├── mLSTM165M.yaml │ ├── mLSTM7B.yaml │ ├── mLSTMv1_1.3B.yaml │ ├── mLSTMv1_165M.yaml │ ├── mLSTMv1_2.7B.yaml │ ├── mLSTMv1_7B.yaml │ └── synthetic.yaml ├── profiling │ └── jax_profiling.yaml ├── scheduler │ ├── cosine_decay.yaml │ └── exponential_decay.yaml └── trainer │ ├── default_llm_trainer.yaml │ └── default_trainer.yaml ├── docs ├── Makefile ├── _static │ ├── custom.css │ ├── nxai_logo_dark.svg │ └── nxai_logo_light.svg ├── conf.py ├── configuration_with_hydra.md ├── dataset_preparation.md ├── distributed_training.md ├── example_training.md ├── index.rst ├── installation.md ├── make.bat └── requirements.txt ├── envs ├── environment_jax_0.4.32_cpu_python_3.11.yml └── environment_python_3.11_jax_0.4.34_cuda_12.6.yml ├── pyproject.toml ├── pytest.ini ├── readthedocs.yaml ├── scripts ├── __init__.py ├── check_config.py ├── checkpoint_conversion │ ├── __init__.py │ ├── convert_mlstm_checkpoint_jax_to_torch_simple.py │ ├── run_checkpoint_conversion.py │ └── run_for_new_checkpoints.py ├── data_processing │ ├── __init__.py │ ├── hf_to_arrayrecord.py │ ├── preprocess_ar_dataset.py │ └── split_array_records_dataset.py ├── evaluation │ ├── __init__.py │ ├── run_huggingface_evaluation.py │ └── run_lmeval.py ├── internal │ ├── __init__.py │ ├── create_sft_dataset_symlinks.py │ ├── generate_all_sft_datasets.py │ └── run_dataset_stats.py ├── run_pytorch_kernels.py ├── speed_benchmark │ ├── __init__.py │ ├── benchmark_with_hydra.py │ ├── run_benchmark.py │ ├── run_jax_kernel_comparison.py │ └── run_mlstm_backend_benchmarks.py └── training │ ├── __init__.py │ ├── get_cli_command_to_resume_training.py │ ├── resume_training_with_hydra.py │ ├── run_train_llama_slimpajama.py │ ├── run_train_slimpajama.py │ ├── run_train_synthetic.py │ ├── run_train_wikitext103.py │ └── train_with_hydra.py ├── slurm ├── slurm_benchmark.job ├── slurm_evaluation.job ├── slurm_evaluation_llama.job ├── slurm_llama_slimpajama.job ├── slurm_slimpajama.job └── slurm_wikitext103.job ├── tests ├── __init__.py ├── config │ ├── test_config_parser.py │ ├── test_equivalence_hydra_nonhydra.py │ ├── test_hydra_configs.py │ ├── test_hydra_model_trainer.py │ ├── test_hydra_trainer_continuation.py │ └── test_xla_flag_mesh_init.py ├── conftest.py ├── dataset │ ├── __init__.py │ ├── test_batch.py │ ├── test_grain_data_processing.py │ ├── test_grain_iterator.py │ ├── test_grain_transforms.py │ ├── test_hf_data_processing.py │ ├── test_lmeval_preprocessing.py │ └── test_synthetic_dataloading.py ├── distributed │ ├── __init__.py │ └── test_data_parallel.py ├── exception_handling │ └── test_exception_handling.py ├── models │ ├── __init__.py │ ├── conftest.py │ ├── llama │ │ ├── test_attention.py │ │ └── test_llama.py │ ├── mlstm_simple │ │ ├── __init__.py │ │ ├── load_pretrained_jax_model.ipynb │ │ ├── load_pretrained_torch_model.ipynb │ │ ├── test_components.py │ │ ├── test_xlstm_jax_parallel_huggingface_equivalent.py │ │ └── test_xlstm_jax_parallel_mlstm_simple_equivalent.py │ ├── xlstm_clean │ │ ├── test_xlstm_jax_pytorch_equivalent.py │ │ └── test_xlstm_single_device.py │ └── xlstm_parallel │ │ ├── test_xlstm_backend.py │ │ ├── test_xlstm_block.py │ │ ├── test_xlstm_causal.py │ │ ├── test_xlstm_data_parallel.py │ │ ├── test_xlstm_init.py │ │ ├── test_xlstm_kernels.py │ │ ├── test_xlstm_layer.py │ │ └── test_xlstm_tensor_parallel.py └── trainer │ ├── __init__.py │ ├── base │ ├── __init__.py │ └── test_trainer.py │ ├── callbacks │ ├── __init__.py │ ├── test_callback.py │ ├── test_checkpoint.py │ ├── test_extended_evaluation.py │ ├── test_load_model_config.py │ ├── test_lr_monitor.py │ └── test_profiler.py │ ├── conftest.py │ ├── eval │ ├── __init__.py │ ├── test_lmeval.py │ └── test_lmeval_evaluation.py │ ├── llm │ ├── __init__.py │ └── test_trainer.py │ ├── logger │ ├── __init__.py │ ├── test_base_logger.py │ ├── test_file_logger.py │ ├── test_tensorboard_logger.py │ └── test_wandb_logger.py │ ├── optimizer │ ├── __init__.py │ ├── test_optimizer.py │ └── test_scheduler.py │ └── test_metrics.py └── xlstm_jax ├── __init__.py ├── common_types.py ├── configs.py ├── dataset ├── README.md ├── __init__.py ├── batch.py ├── configs.py ├── grain_batch_rampup.py ├── grain_data_processing.py ├── grain_iterator.py ├── grain_transforms.py ├── hf_tokenizer.py ├── input_pipeline_interface.py ├── lmeval_dataset.py ├── lmeval_pipeline.py ├── multihost_dataloading.py └── synthetic_dataloading.py ├── define_hydra_schemas.py ├── distributed ├── __init__.py ├── array_utils.py ├── data_parallel.py ├── mesh_utils.py ├── pipeline_parallel.py ├── single_gpu.py ├── tensor_parallel.py └── xla_utils.py ├── import_utils.py ├── main_train.py ├── models ├── __init__.py ├── configs.py ├── llama │ ├── __init__.py │ ├── attention.py │ ├── feedforward.py │ └── llama.py ├── shared │ ├── __init__.py │ ├── init.py │ ├── lm_head.py │ └── utils.py ├── xlstm_clean │ ├── __init__.py │ ├── blocks │ │ ├── __init__.py │ │ ├── mlstm │ │ │ ├── __init__.py │ │ │ ├── backend │ │ │ │ ├── __init__.py │ │ │ │ ├── config.py │ │ │ │ ├── config_utils.py │ │ │ │ ├── layer_factory.py │ │ │ │ └── simple.py │ │ │ ├── block.py │ │ │ ├── cell.py │ │ │ └── layer.py │ │ └── xlstm_block.py │ ├── components │ │ ├── __init__.py │ │ ├── conv.py │ │ ├── feedforward.py │ │ ├── init.py │ │ ├── linear_headwise.py │ │ └── ln.py │ ├── utils.py │ ├── xlstm_block_stack.py │ └── xlstm_lm_model.py ├── xlstm_parallel │ ├── __init__.py │ ├── benchmark.py │ ├── blocks │ │ ├── __init__.py │ │ ├── mlstm │ │ │ ├── __init__.py │ │ │ ├── backend │ │ │ │ ├── __init__.py │ │ │ │ ├── attention.py │ │ │ │ ├── config.py │ │ │ │ ├── config_utils.py │ │ │ │ ├── fwbw.py │ │ │ │ ├── layer_factory.py │ │ │ │ ├── recurrent.py │ │ │ │ ├── recurrent_triton.py │ │ │ │ ├── simple.py │ │ │ │ └── triton_kernels.py │ │ │ ├── backend_utils.py │ │ │ ├── block.py │ │ │ ├── cell.py │ │ │ ├── layer.py │ │ │ └── layer_v1.py │ │ └── xlstm_block.py │ ├── checkpointing.py │ ├── components │ │ ├── __init__.py │ │ ├── conv.py │ │ ├── feedforward.py │ │ ├── init.py │ │ ├── linear_headwise.py │ │ └── normalization.py │ ├── training.py │ ├── utils.py │ ├── xlstm_block_stack.py │ └── xlstm_lm_model.py └── xlstm_pytorch │ ├── __init__.py │ ├── blocks │ ├── __init__.py │ ├── mlstm │ │ ├── __init__.py │ │ ├── backend │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── config_utils.py │ │ │ ├── fwbw.py │ │ │ ├── layer_factory.py │ │ │ ├── simple.py │ │ │ ├── tl_utils.py │ │ │ └── triton_chunk.py │ │ ├── block.py │ │ ├── cell.py │ │ └── layer.py │ ├── slstm │ │ ├── __init__.py │ │ ├── block.py │ │ ├── cell.py │ │ ├── layer.py │ │ └── src │ │ │ ├── __init__.py │ │ │ ├── cuda │ │ │ ├── lstm_pointwise.cu │ │ │ ├── slstm.cc │ │ │ ├── slstm.h │ │ │ ├── slstm_backward.cu │ │ │ ├── slstm_backward_cut.cu │ │ │ ├── slstm_forward.cu │ │ │ ├── slstm_pointwise.cu │ │ │ └── slstm_pointwise.cuh │ │ │ ├── cuda_init.py │ │ │ ├── util │ │ │ ├── blas.cu │ │ │ ├── blas.h │ │ │ ├── cuda_error.cu │ │ │ ├── cuda_error.h │ │ │ ├── device_assert.h │ │ │ ├── inline_ops.cuh │ │ │ ├── inline_ops_2bf16.cuh │ │ │ ├── inline_ops_2fp16.cuh │ │ │ ├── inline_ops_bf16.cuh │ │ │ ├── inline_ops_fp16.cuh │ │ │ ├── inline_print.cuh │ │ │ ├── support.h │ │ │ └── util.h │ │ │ └── vanilla │ │ │ ├── __init__.py │ │ │ ├── lstm.py │ │ │ └── slstm.py │ └── xlstm_block.py │ ├── components │ ├── __init__.py │ ├── conv.py │ ├── feedforward.py │ ├── init.py │ ├── linear_headwise.py │ ├── ln.py │ └── util.py │ ├── utils.py │ ├── xlstm_block_stack.py │ └── xlstm_lm_model.py ├── resume_training.py ├── start_training.py ├── train_init_fns.py ├── trainer ├── __init__.py ├── base │ ├── __init__.py │ ├── param_utils.py │ └── trainer.py ├── callbacks │ ├── __init__.py │ ├── callback.py │ ├── checkpointing.py │ ├── extended_evaluation.py │ ├── lr_monitor.py │ └── profiler.py ├── data_module.py ├── eval │ └── lmeval_extended_evaluation.py ├── llm │ ├── __init__.py │ ├── sampling.py │ └── trainer.py ├── logger │ ├── __init__.py │ ├── base_logger.py │ ├── cmd_logging.py │ ├── file_logger.py │ ├── tensorboard_logger.py │ └── wandb_logger.py ├── metrics.py └── optimizer │ ├── __init__.py │ ├── ademamix.py │ ├── optimizer.py │ └── scheduler.py └── utils ├── __init__.py ├── error_logging_utils.py ├── model_param_handling ├── __init__.py ├── convert_checkpoint.py ├── convert_state_dict.py ├── handle_mlstm_simple.py ├── load.py └── store.py └── pytree_utils.py /.github/license-check/config.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/.github/license-check/config.json -------------------------------------------------------------------------------- /.github/license-check/header_c.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/.github/license-check/header_c.txt -------------------------------------------------------------------------------- /.github/license-check/header_python.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/.github/license-check/header_python.txt -------------------------------------------------------------------------------- /.github/workflows/flake8.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/.github/workflows/flake8.yml -------------------------------------------------------------------------------- /.github/workflows/interrogate.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/.github/workflows/interrogate.yml -------------------------------------------------------------------------------- /.github/workflows/license-header-check.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/.github/workflows/license-header-check.yml -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/.github/workflows/pre-commit.yml -------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/.github/workflows/pylint.yml -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/.github/workflows/pytest.yml -------------------------------------------------------------------------------- /.github/workflows/sphinx.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/.github/workflows/sphinx.yml -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/.gitignore -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/.pre-commit-config.yaml -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/LICENSE -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/NOTICE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/README.md -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/__init__.py -------------------------------------------------------------------------------- /configs/checkpointing/model_checkpointing.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/checkpointing/model_checkpointing.yaml -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/config.yaml -------------------------------------------------------------------------------- /configs/data/bigcode_stack_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/bigcode_stack_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/bigcode_stack_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/bigcode_stack_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/bigcode_stack_snapshot_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/bigcode_stack_snapshot_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/bigcode_stack_snapshot_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/bigcode_stack_snapshot_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/dclm_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/dclm_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/dclm_arrayrecord_eval.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/dclm_arrayrecord_eval.yaml -------------------------------------------------------------------------------- /configs/data/dclm_arrayrecord_eval_preprocessed.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/dclm_arrayrecord_eval_preprocessed.yaml -------------------------------------------------------------------------------- /configs/data/dclm_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/dclm_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/default_arrayrecord_eval.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/default_arrayrecord_eval.yaml -------------------------------------------------------------------------------- /configs/data/default_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/default_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/default_huggingface_ds_eval.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/default_huggingface_ds_eval.yaml -------------------------------------------------------------------------------- /configs/data/default_huggingface_ds_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/default_huggingface_ds_train.yaml -------------------------------------------------------------------------------- /configs/data/dolmino_mix_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/dolmino_mix_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/dolmino_mix_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/dolmino_mix_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/fineweb_edu_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/fineweb_edu_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/fineweb_edu_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/fineweb_edu_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/large_sft_datasets_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/large_sft_datasets_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/large_sft_datasets_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/large_sft_datasets_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/large_sft_datasets_pp2_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/large_sft_datasets_pp2_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/large_sft_datasets_pp2_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/large_sft_datasets_pp2_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/open_web_math_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/open_web_math_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/open_web_math_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/open_web_math_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/proofpile2_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/proofpile2_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/proofpile2_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/proofpile2_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/slimpajama_627B_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/slimpajama_627B_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/slimpajama_627B_arrayrecord_eval.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/slimpajama_627B_arrayrecord_eval.yaml -------------------------------------------------------------------------------- /configs/data/slimpajama_627B_arrayrecord_eval_preprocessed.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/slimpajama_627B_arrayrecord_eval_preprocessed.yaml -------------------------------------------------------------------------------- /configs/data/slimpajama_627B_arrayrecord_eval_preprocessed_gpt2.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/slimpajama_627B_arrayrecord_eval_preprocessed_gpt2.yaml -------------------------------------------------------------------------------- /configs/data/slimpajama_627B_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/slimpajama_627B_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/slimpajama_627B_huggingface_ds.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/slimpajama_627B_huggingface_ds.yaml -------------------------------------------------------------------------------- /configs/data/slimpajama_627B_huggingface_ds_eval.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/slimpajama_627B_huggingface_ds_eval.yaml -------------------------------------------------------------------------------- /configs/data/slimpajama_627B_huggingface_ds_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/slimpajama_627B_huggingface_ds_train.yaml -------------------------------------------------------------------------------- /configs/data/slimpajama_6B_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/slimpajama_6B_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/slimpajama_6B_arrayrecord_eval.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/slimpajama_6B_arrayrecord_eval.yaml -------------------------------------------------------------------------------- /configs/data/slimpajama_6B_arrayrecord_eval_preprocessed.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/slimpajama_6B_arrayrecord_eval_preprocessed.yaml -------------------------------------------------------------------------------- /configs/data/slimpajama_6B_arrayrecord_eval_preprocessed_gpt2.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/slimpajama_6B_arrayrecord_eval_preprocessed_gpt2.yaml -------------------------------------------------------------------------------- /configs/data/slimpajama_6B_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/slimpajama_6B_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/slimpajama_6B_huggingface_ds.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/slimpajama_6B_huggingface_ds.yaml -------------------------------------------------------------------------------- /configs/data/slimpajama_6B_huggingface_ds_eval.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/slimpajama_6B_huggingface_ds_eval.yaml -------------------------------------------------------------------------------- /configs/data/slimpajama_6B_huggingface_ds_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/slimpajama_6B_huggingface_ds_train.yaml -------------------------------------------------------------------------------- /configs/data/small_sft_datasets_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/small_sft_datasets_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/small_sft_datasets_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/small_sft_datasets_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/small_sft_datasets_extended_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/small_sft_datasets_extended_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/small_sft_datasets_extended_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/small_sft_datasets_extended_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/smol_cosmopedia_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/smol_cosmopedia_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/smol_cosmopedia_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/smol_cosmopedia_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/smol_fineweb_edu_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/smol_fineweb_edu_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/smol_fineweb_edu_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/smol_fineweb_edu_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/smoltalk_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/smoltalk_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/smoltalk_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/smoltalk_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/smoltalk_magpie_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/smoltalk_magpie_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/smoltalk_magpie_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/smoltalk_magpie_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/smoltalk_metamathqa_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/smoltalk_metamathqa_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/smoltalk_metamathqa_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/smoltalk_metamathqa_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/smoltalk_numina_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/smoltalk_numina_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/smoltalk_numina_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/smoltalk_numina_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/smoltalk_openhermes_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/smoltalk_openhermes_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/smoltalk_openhermes_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/smoltalk_openhermes_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/synthetic.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/synthetic.yaml -------------------------------------------------------------------------------- /configs/data/synthetic_eval.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/synthetic_eval.yaml -------------------------------------------------------------------------------- /configs/data/synthetic_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/synthetic_train.yaml -------------------------------------------------------------------------------- /configs/data/zyda2.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/zyda2.yaml -------------------------------------------------------------------------------- /configs/data/zyda2_dolmacc_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/zyda2_dolmacc_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/zyda2_dolmacc_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/zyda2_dolmacc_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/data/zyda2_zyda_arrayrecord.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/zyda2_zyda_arrayrecord.yaml -------------------------------------------------------------------------------- /configs/data/zyda2_zyda_arrayrecord_train.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/data/zyda2_zyda_arrayrecord_train.yaml -------------------------------------------------------------------------------- /configs/experiment/benchmark_mLSTMv1_7B.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/benchmark_mLSTMv1_7B.yaml -------------------------------------------------------------------------------- /configs/experiment/data_for_unit_test.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/data_for_unit_test.yaml -------------------------------------------------------------------------------- /configs/experiment/synthetic_experiment.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/synthetic_experiment.yaml -------------------------------------------------------------------------------- /configs/experiment/synthetic_experiment_slurm.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/synthetic_experiment_slurm.yaml -------------------------------------------------------------------------------- /configs/experiment/synthetic_experiment_slurm_sweep.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/synthetic_experiment_slurm_sweep.yaml -------------------------------------------------------------------------------- /configs/experiment/tiny_experiment_for_unit_testing.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/tiny_experiment_for_unit_testing.yaml -------------------------------------------------------------------------------- /configs/experiment/train_llama1.3B_dclm.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_llama1.3B_dclm.yaml -------------------------------------------------------------------------------- /configs/experiment/train_llama1.3B_slimpajama627b.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_llama1.3B_slimpajama627b.yaml -------------------------------------------------------------------------------- /configs/experiment/train_llama165M_slimpajama627b.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_llama165M_slimpajama627b.yaml -------------------------------------------------------------------------------- /configs/experiment/train_llama7B_dclm.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_llama7B_dclm.yaml -------------------------------------------------------------------------------- /configs/experiment/train_llama7B_dclm_cosine.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_llama7B_dclm_cosine.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTM1.3B_dclm_cosine.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTM1.3B_dclm_cosine.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTM1.3B_slimpajama627b.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTM1.3B_slimpajama627b.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTM120M_slimpajama627b.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTM120M_slimpajama627b.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTM165M_slimpajama627b.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTM165M_slimpajama627b.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTM165M_slimpajama6b.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTM165M_slimpajama6b.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTM7B_slimpajama627b.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTM7B_slimpajama627b.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTM_small_slimpajama6b_verify_slurm.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTM_small_slimpajama6b_verify_slurm.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTMv1_1.3B_dclm.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTMv1_1.3B_dclm.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTMv1_1.3B_dclm_cosine.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTMv1_1.3B_dclm_cosine.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTMv1_1.3B_dclm_long_run.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTMv1_1.3B_dclm_long_run.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTMv1_1.3B_fineweb_edu.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTMv1_1.3B_fineweb_edu.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTMv1_1.3B_slimpajama627b.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTMv1_1.3B_slimpajama627b.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTMv1_1.3B_zyda2.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTMv1_1.3B_zyda2.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTMv1_165M_slimpajama627b.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTMv1_165M_slimpajama627b.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTMv1_2.7B_dclm_cosine.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTMv1_2.7B_dclm_cosine.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTMv1_7B_dclm.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTMv1_7B_dclm.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTMv1_7B_dclm_cosine.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTMv1_7B_dclm_cosine.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTMv1_7B_dclm_cosine_igate_init.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTMv1_7B_dclm_cosine_igate_init.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTMv1_7B_dclm_cosine_no_softcap.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTMv1_7B_dclm_cosine_no_softcap.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTMv1_7B_dclm_long_run.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTMv1_7B_dclm_long_run.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTMv1_7B_dclm_long_run_finetune.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTMv1_7B_dclm_long_run_finetune.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTMv1_7B_dclm_long_run_wu8k.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTMv1_7B_dclm_long_run_wu8k.yaml -------------------------------------------------------------------------------- /configs/experiment/train_mLSTMv1_7B_slimpajama627b.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/experiment/train_mLSTMv1_7B_slimpajama627b.yaml -------------------------------------------------------------------------------- /configs/hydra/launcher/slurm_launcher.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/hydra/launcher/slurm_launcher.yaml -------------------------------------------------------------------------------- /configs/logger/default_logger.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/logger/default_logger.yaml -------------------------------------------------------------------------------- /configs/lr_monitor/lr_monitor_config.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/lr_monitor/lr_monitor_config.yaml -------------------------------------------------------------------------------- /configs/model/llama1.3B.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/model/llama1.3B.yaml -------------------------------------------------------------------------------- /configs/model/llama165M.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/model/llama165M.yaml -------------------------------------------------------------------------------- /configs/model/llama7B.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/model/llama7B.yaml -------------------------------------------------------------------------------- /configs/model/llama_default.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/model/llama_default.yaml -------------------------------------------------------------------------------- /configs/model/mLSTM1.3B.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/model/mLSTM1.3B.yaml -------------------------------------------------------------------------------- /configs/model/mLSTM120M.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/model/mLSTM120M.yaml -------------------------------------------------------------------------------- /configs/model/mLSTM165M.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/model/mLSTM165M.yaml -------------------------------------------------------------------------------- /configs/model/mLSTM7B.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/model/mLSTM7B.yaml -------------------------------------------------------------------------------- /configs/model/mLSTM_default.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/model/mLSTM_default.yaml -------------------------------------------------------------------------------- /configs/model/mLSTMv1_1.3B.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/model/mLSTMv1_1.3B.yaml -------------------------------------------------------------------------------- /configs/model/mLSTMv1_165M.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/model/mLSTMv1_165M.yaml -------------------------------------------------------------------------------- /configs/model/mLSTMv1_2.7B.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/model/mLSTMv1_2.7B.yaml -------------------------------------------------------------------------------- /configs/model/mLSTMv1_7B.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/model/mLSTMv1_7B.yaml -------------------------------------------------------------------------------- /configs/model/mLSTMv1_default.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/model/mLSTMv1_default.yaml -------------------------------------------------------------------------------- /configs/optimizer/adamw.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/optimizer/adamw.yaml -------------------------------------------------------------------------------- /configs/optimizer/sgd.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/optimizer/sgd.yaml -------------------------------------------------------------------------------- /configs/parallel/default.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/parallel/default.yaml -------------------------------------------------------------------------------- /configs/parallel/llama1.3B.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/parallel/llama1.3B.yaml -------------------------------------------------------------------------------- /configs/parallel/llama165M.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/parallel/llama165M.yaml -------------------------------------------------------------------------------- /configs/parallel/llama7B.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/parallel/llama7B.yaml -------------------------------------------------------------------------------- /configs/parallel/mLSTM1.3B.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/parallel/mLSTM1.3B.yaml -------------------------------------------------------------------------------- /configs/parallel/mLSTM120M.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/parallel/mLSTM120M.yaml -------------------------------------------------------------------------------- /configs/parallel/mLSTM165M.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/parallel/mLSTM165M.yaml -------------------------------------------------------------------------------- /configs/parallel/mLSTM7B.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/parallel/mLSTM7B.yaml -------------------------------------------------------------------------------- /configs/parallel/mLSTMv1_1.3B.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/parallel/mLSTMv1_1.3B.yaml -------------------------------------------------------------------------------- /configs/parallel/mLSTMv1_165M.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/parallel/mLSTMv1_165M.yaml -------------------------------------------------------------------------------- /configs/parallel/mLSTMv1_2.7B.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/parallel/mLSTMv1_2.7B.yaml -------------------------------------------------------------------------------- /configs/parallel/mLSTMv1_7B.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/parallel/mLSTMv1_7B.yaml -------------------------------------------------------------------------------- /configs/parallel/synthetic.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/parallel/synthetic.yaml -------------------------------------------------------------------------------- /configs/profiling/jax_profiling.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/profiling/jax_profiling.yaml -------------------------------------------------------------------------------- /configs/scheduler/cosine_decay.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/scheduler/cosine_decay.yaml -------------------------------------------------------------------------------- /configs/scheduler/exponential_decay.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/scheduler/exponential_decay.yaml -------------------------------------------------------------------------------- /configs/trainer/default_llm_trainer.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/trainer/default_llm_trainer.yaml -------------------------------------------------------------------------------- /configs/trainer/default_trainer.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/configs/trainer/default_trainer.yaml -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/docs/Makefile -------------------------------------------------------------------------------- /docs/_static/custom.css: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/docs/_static/custom.css -------------------------------------------------------------------------------- /docs/_static/nxai_logo_dark.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/docs/_static/nxai_logo_dark.svg -------------------------------------------------------------------------------- /docs/_static/nxai_logo_light.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/docs/_static/nxai_logo_light.svg -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/docs/conf.py -------------------------------------------------------------------------------- /docs/configuration_with_hydra.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/docs/configuration_with_hydra.md -------------------------------------------------------------------------------- /docs/dataset_preparation.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/docs/dataset_preparation.md -------------------------------------------------------------------------------- /docs/distributed_training.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/docs/distributed_training.md -------------------------------------------------------------------------------- /docs/example_training.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/docs/example_training.md -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/docs/index.rst -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/docs/installation.md -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/docs/make.bat -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/docs/requirements.txt -------------------------------------------------------------------------------- /envs/environment_jax_0.4.32_cpu_python_3.11.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/envs/environment_jax_0.4.32_cpu_python_3.11.yml -------------------------------------------------------------------------------- /envs/environment_python_3.11_jax_0.4.34_cuda_12.6.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/envs/environment_python_3.11_jax_0.4.34_cuda_12.6.yml -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/pyproject.toml -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/pytest.ini -------------------------------------------------------------------------------- /readthedocs.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/readthedocs.yaml -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/check_config.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/check_config.py -------------------------------------------------------------------------------- /scripts/checkpoint_conversion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/checkpoint_conversion/__init__.py -------------------------------------------------------------------------------- /scripts/checkpoint_conversion/convert_mlstm_checkpoint_jax_to_torch_simple.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/checkpoint_conversion/convert_mlstm_checkpoint_jax_to_torch_simple.py -------------------------------------------------------------------------------- /scripts/checkpoint_conversion/run_checkpoint_conversion.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/checkpoint_conversion/run_checkpoint_conversion.py -------------------------------------------------------------------------------- /scripts/checkpoint_conversion/run_for_new_checkpoints.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/checkpoint_conversion/run_for_new_checkpoints.py -------------------------------------------------------------------------------- /scripts/data_processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/data_processing/__init__.py -------------------------------------------------------------------------------- /scripts/data_processing/hf_to_arrayrecord.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/data_processing/hf_to_arrayrecord.py -------------------------------------------------------------------------------- /scripts/data_processing/preprocess_ar_dataset.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/data_processing/preprocess_ar_dataset.py -------------------------------------------------------------------------------- /scripts/data_processing/split_array_records_dataset.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/data_processing/split_array_records_dataset.py -------------------------------------------------------------------------------- /scripts/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/evaluation/__init__.py -------------------------------------------------------------------------------- /scripts/evaluation/run_huggingface_evaluation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/evaluation/run_huggingface_evaluation.py -------------------------------------------------------------------------------- /scripts/evaluation/run_lmeval.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/evaluation/run_lmeval.py -------------------------------------------------------------------------------- /scripts/internal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/internal/__init__.py -------------------------------------------------------------------------------- /scripts/internal/create_sft_dataset_symlinks.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/internal/create_sft_dataset_symlinks.py -------------------------------------------------------------------------------- /scripts/internal/generate_all_sft_datasets.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/internal/generate_all_sft_datasets.py -------------------------------------------------------------------------------- /scripts/internal/run_dataset_stats.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/internal/run_dataset_stats.py -------------------------------------------------------------------------------- /scripts/run_pytorch_kernels.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/run_pytorch_kernels.py -------------------------------------------------------------------------------- /scripts/speed_benchmark/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/speed_benchmark/__init__.py -------------------------------------------------------------------------------- /scripts/speed_benchmark/benchmark_with_hydra.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/speed_benchmark/benchmark_with_hydra.py -------------------------------------------------------------------------------- /scripts/speed_benchmark/run_benchmark.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/speed_benchmark/run_benchmark.py -------------------------------------------------------------------------------- /scripts/speed_benchmark/run_jax_kernel_comparison.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/speed_benchmark/run_jax_kernel_comparison.py -------------------------------------------------------------------------------- /scripts/speed_benchmark/run_mlstm_backend_benchmarks.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/speed_benchmark/run_mlstm_backend_benchmarks.py -------------------------------------------------------------------------------- /scripts/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/training/__init__.py -------------------------------------------------------------------------------- /scripts/training/get_cli_command_to_resume_training.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/training/get_cli_command_to_resume_training.py -------------------------------------------------------------------------------- /scripts/training/resume_training_with_hydra.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/training/resume_training_with_hydra.py -------------------------------------------------------------------------------- /scripts/training/run_train_llama_slimpajama.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/training/run_train_llama_slimpajama.py -------------------------------------------------------------------------------- /scripts/training/run_train_slimpajama.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/training/run_train_slimpajama.py -------------------------------------------------------------------------------- /scripts/training/run_train_synthetic.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/training/run_train_synthetic.py -------------------------------------------------------------------------------- /scripts/training/run_train_wikitext103.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/training/run_train_wikitext103.py -------------------------------------------------------------------------------- /scripts/training/train_with_hydra.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/scripts/training/train_with_hydra.py -------------------------------------------------------------------------------- /slurm/slurm_benchmark.job: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/slurm/slurm_benchmark.job -------------------------------------------------------------------------------- /slurm/slurm_evaluation.job: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/slurm/slurm_evaluation.job -------------------------------------------------------------------------------- /slurm/slurm_evaluation_llama.job: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/slurm/slurm_evaluation_llama.job -------------------------------------------------------------------------------- /slurm/slurm_llama_slimpajama.job: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/slurm/slurm_llama_slimpajama.job -------------------------------------------------------------------------------- /slurm/slurm_slimpajama.job: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/slurm/slurm_slimpajama.job -------------------------------------------------------------------------------- /slurm/slurm_wikitext103.job: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/slurm/slurm_wikitext103.job -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/__init__.py -------------------------------------------------------------------------------- /tests/config/test_config_parser.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/config/test_config_parser.py -------------------------------------------------------------------------------- /tests/config/test_equivalence_hydra_nonhydra.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/config/test_equivalence_hydra_nonhydra.py -------------------------------------------------------------------------------- /tests/config/test_hydra_configs.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/config/test_hydra_configs.py -------------------------------------------------------------------------------- /tests/config/test_hydra_model_trainer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/config/test_hydra_model_trainer.py -------------------------------------------------------------------------------- /tests/config/test_hydra_trainer_continuation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/config/test_hydra_trainer_continuation.py -------------------------------------------------------------------------------- /tests/config/test_xla_flag_mesh_init.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/config/test_xla_flag_mesh_init.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/conftest.py -------------------------------------------------------------------------------- /tests/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/dataset/__init__.py -------------------------------------------------------------------------------- /tests/dataset/test_batch.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/dataset/test_batch.py -------------------------------------------------------------------------------- /tests/dataset/test_grain_data_processing.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/dataset/test_grain_data_processing.py -------------------------------------------------------------------------------- /tests/dataset/test_grain_iterator.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/dataset/test_grain_iterator.py -------------------------------------------------------------------------------- /tests/dataset/test_grain_transforms.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/dataset/test_grain_transforms.py -------------------------------------------------------------------------------- /tests/dataset/test_hf_data_processing.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/dataset/test_hf_data_processing.py -------------------------------------------------------------------------------- /tests/dataset/test_lmeval_preprocessing.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/dataset/test_lmeval_preprocessing.py -------------------------------------------------------------------------------- /tests/dataset/test_synthetic_dataloading.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/dataset/test_synthetic_dataloading.py -------------------------------------------------------------------------------- /tests/distributed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/distributed/__init__.py -------------------------------------------------------------------------------- /tests/distributed/test_data_parallel.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/distributed/test_data_parallel.py -------------------------------------------------------------------------------- /tests/exception_handling/test_exception_handling.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/exception_handling/test_exception_handling.py -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/models/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/conftest.py -------------------------------------------------------------------------------- /tests/models/llama/test_attention.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/llama/test_attention.py -------------------------------------------------------------------------------- /tests/models/llama/test_llama.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/llama/test_llama.py -------------------------------------------------------------------------------- /tests/models/mlstm_simple/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/mlstm_simple/__init__.py -------------------------------------------------------------------------------- /tests/models/mlstm_simple/load_pretrained_jax_model.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/mlstm_simple/load_pretrained_jax_model.ipynb -------------------------------------------------------------------------------- /tests/models/mlstm_simple/load_pretrained_torch_model.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/mlstm_simple/load_pretrained_torch_model.ipynb -------------------------------------------------------------------------------- /tests/models/mlstm_simple/test_components.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/mlstm_simple/test_components.py -------------------------------------------------------------------------------- /tests/models/mlstm_simple/test_xlstm_jax_parallel_huggingface_equivalent.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/mlstm_simple/test_xlstm_jax_parallel_huggingface_equivalent.py -------------------------------------------------------------------------------- /tests/models/mlstm_simple/test_xlstm_jax_parallel_mlstm_simple_equivalent.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/mlstm_simple/test_xlstm_jax_parallel_mlstm_simple_equivalent.py -------------------------------------------------------------------------------- /tests/models/xlstm_clean/test_xlstm_jax_pytorch_equivalent.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/xlstm_clean/test_xlstm_jax_pytorch_equivalent.py -------------------------------------------------------------------------------- /tests/models/xlstm_clean/test_xlstm_single_device.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/xlstm_clean/test_xlstm_single_device.py -------------------------------------------------------------------------------- /tests/models/xlstm_parallel/test_xlstm_backend.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/xlstm_parallel/test_xlstm_backend.py -------------------------------------------------------------------------------- /tests/models/xlstm_parallel/test_xlstm_block.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/xlstm_parallel/test_xlstm_block.py -------------------------------------------------------------------------------- /tests/models/xlstm_parallel/test_xlstm_causal.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/xlstm_parallel/test_xlstm_causal.py -------------------------------------------------------------------------------- /tests/models/xlstm_parallel/test_xlstm_data_parallel.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/xlstm_parallel/test_xlstm_data_parallel.py -------------------------------------------------------------------------------- /tests/models/xlstm_parallel/test_xlstm_init.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/xlstm_parallel/test_xlstm_init.py -------------------------------------------------------------------------------- /tests/models/xlstm_parallel/test_xlstm_kernels.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/xlstm_parallel/test_xlstm_kernels.py -------------------------------------------------------------------------------- /tests/models/xlstm_parallel/test_xlstm_layer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/xlstm_parallel/test_xlstm_layer.py -------------------------------------------------------------------------------- /tests/models/xlstm_parallel/test_xlstm_tensor_parallel.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/models/xlstm_parallel/test_xlstm_tensor_parallel.py -------------------------------------------------------------------------------- /tests/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/__init__.py -------------------------------------------------------------------------------- /tests/trainer/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/base/__init__.py -------------------------------------------------------------------------------- /tests/trainer/base/test_trainer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/base/test_trainer.py -------------------------------------------------------------------------------- /tests/trainer/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/callbacks/__init__.py -------------------------------------------------------------------------------- /tests/trainer/callbacks/test_callback.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/callbacks/test_callback.py -------------------------------------------------------------------------------- /tests/trainer/callbacks/test_checkpoint.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/callbacks/test_checkpoint.py -------------------------------------------------------------------------------- /tests/trainer/callbacks/test_extended_evaluation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/callbacks/test_extended_evaluation.py -------------------------------------------------------------------------------- /tests/trainer/callbacks/test_load_model_config.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/callbacks/test_load_model_config.py -------------------------------------------------------------------------------- /tests/trainer/callbacks/test_lr_monitor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/callbacks/test_lr_monitor.py -------------------------------------------------------------------------------- /tests/trainer/callbacks/test_profiler.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/callbacks/test_profiler.py -------------------------------------------------------------------------------- /tests/trainer/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/conftest.py -------------------------------------------------------------------------------- /tests/trainer/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/eval/__init__.py -------------------------------------------------------------------------------- /tests/trainer/eval/test_lmeval.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/eval/test_lmeval.py -------------------------------------------------------------------------------- /tests/trainer/eval/test_lmeval_evaluation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/eval/test_lmeval_evaluation.py -------------------------------------------------------------------------------- /tests/trainer/llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/llm/__init__.py -------------------------------------------------------------------------------- /tests/trainer/llm/test_trainer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/llm/test_trainer.py -------------------------------------------------------------------------------- /tests/trainer/logger/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/logger/__init__.py -------------------------------------------------------------------------------- /tests/trainer/logger/test_base_logger.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/logger/test_base_logger.py -------------------------------------------------------------------------------- /tests/trainer/logger/test_file_logger.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/logger/test_file_logger.py -------------------------------------------------------------------------------- /tests/trainer/logger/test_tensorboard_logger.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/logger/test_tensorboard_logger.py -------------------------------------------------------------------------------- /tests/trainer/logger/test_wandb_logger.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/logger/test_wandb_logger.py -------------------------------------------------------------------------------- /tests/trainer/optimizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/optimizer/__init__.py -------------------------------------------------------------------------------- /tests/trainer/optimizer/test_optimizer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/optimizer/test_optimizer.py -------------------------------------------------------------------------------- /tests/trainer/optimizer/test_scheduler.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/optimizer/test_scheduler.py -------------------------------------------------------------------------------- /tests/trainer/test_metrics.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/tests/trainer/test_metrics.py -------------------------------------------------------------------------------- /xlstm_jax/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/common_types.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/common_types.py -------------------------------------------------------------------------------- /xlstm_jax/configs.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/configs.py -------------------------------------------------------------------------------- /xlstm_jax/dataset/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/dataset/README.md -------------------------------------------------------------------------------- /xlstm_jax/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/dataset/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/dataset/batch.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/dataset/batch.py -------------------------------------------------------------------------------- /xlstm_jax/dataset/configs.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/dataset/configs.py -------------------------------------------------------------------------------- /xlstm_jax/dataset/grain_batch_rampup.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/dataset/grain_batch_rampup.py -------------------------------------------------------------------------------- /xlstm_jax/dataset/grain_data_processing.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/dataset/grain_data_processing.py -------------------------------------------------------------------------------- /xlstm_jax/dataset/grain_iterator.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/dataset/grain_iterator.py -------------------------------------------------------------------------------- /xlstm_jax/dataset/grain_transforms.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/dataset/grain_transforms.py -------------------------------------------------------------------------------- /xlstm_jax/dataset/hf_tokenizer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/dataset/hf_tokenizer.py -------------------------------------------------------------------------------- /xlstm_jax/dataset/input_pipeline_interface.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/dataset/input_pipeline_interface.py -------------------------------------------------------------------------------- /xlstm_jax/dataset/lmeval_dataset.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/dataset/lmeval_dataset.py -------------------------------------------------------------------------------- /xlstm_jax/dataset/lmeval_pipeline.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/dataset/lmeval_pipeline.py -------------------------------------------------------------------------------- /xlstm_jax/dataset/multihost_dataloading.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/dataset/multihost_dataloading.py -------------------------------------------------------------------------------- /xlstm_jax/dataset/synthetic_dataloading.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/dataset/synthetic_dataloading.py -------------------------------------------------------------------------------- /xlstm_jax/define_hydra_schemas.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/define_hydra_schemas.py -------------------------------------------------------------------------------- /xlstm_jax/distributed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/distributed/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/distributed/array_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/distributed/array_utils.py -------------------------------------------------------------------------------- /xlstm_jax/distributed/data_parallel.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/distributed/data_parallel.py -------------------------------------------------------------------------------- /xlstm_jax/distributed/mesh_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/distributed/mesh_utils.py -------------------------------------------------------------------------------- /xlstm_jax/distributed/pipeline_parallel.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/distributed/pipeline_parallel.py -------------------------------------------------------------------------------- /xlstm_jax/distributed/single_gpu.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/distributed/single_gpu.py -------------------------------------------------------------------------------- /xlstm_jax/distributed/tensor_parallel.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/distributed/tensor_parallel.py -------------------------------------------------------------------------------- /xlstm_jax/distributed/xla_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/distributed/xla_utils.py -------------------------------------------------------------------------------- /xlstm_jax/import_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/import_utils.py -------------------------------------------------------------------------------- /xlstm_jax/main_train.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/main_train.py -------------------------------------------------------------------------------- /xlstm_jax/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/configs.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/configs.py -------------------------------------------------------------------------------- /xlstm_jax/models/llama/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/llama/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/llama/attention.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/llama/attention.py -------------------------------------------------------------------------------- /xlstm_jax/models/llama/feedforward.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/llama/feedforward.py -------------------------------------------------------------------------------- /xlstm_jax/models/llama/llama.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/llama/llama.py -------------------------------------------------------------------------------- /xlstm_jax/models/shared/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/shared/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/shared/init.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/shared/init.py -------------------------------------------------------------------------------- /xlstm_jax/models/shared/lm_head.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/shared/lm_head.py -------------------------------------------------------------------------------- /xlstm_jax/models/shared/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/shared/utils.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/blocks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/blocks/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/blocks/mlstm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/blocks/mlstm/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/blocks/mlstm/backend/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/blocks/mlstm/backend/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/blocks/mlstm/backend/config.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/blocks/mlstm/backend/config.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/blocks/mlstm/backend/config_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/blocks/mlstm/backend/config_utils.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/blocks/mlstm/backend/layer_factory.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/blocks/mlstm/backend/layer_factory.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/blocks/mlstm/backend/simple.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/blocks/mlstm/backend/simple.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/blocks/mlstm/block.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/blocks/mlstm/block.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/blocks/mlstm/cell.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/blocks/mlstm/cell.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/blocks/mlstm/layer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/blocks/mlstm/layer.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/blocks/xlstm_block.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/blocks/xlstm_block.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/components/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/components/conv.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/components/conv.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/components/feedforward.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/components/feedforward.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/components/init.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/components/init.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/components/linear_headwise.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/components/linear_headwise.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/components/ln.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/components/ln.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/utils.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/xlstm_block_stack.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/xlstm_block_stack.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_clean/xlstm_lm_model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_clean/xlstm_lm_model.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/benchmark.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/benchmark.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/mlstm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/mlstm/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/attention.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/attention.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/config.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/config.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/config_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/config_utils.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/fwbw.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/fwbw.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/layer_factory.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/layer_factory.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/recurrent.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/recurrent.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/recurrent_triton.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/recurrent_triton.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/simple.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/simple.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/triton_kernels.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend/triton_kernels.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/mlstm/backend_utils.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/mlstm/block.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/mlstm/block.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/mlstm/cell.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/mlstm/cell.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/mlstm/layer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/mlstm/layer.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/mlstm/layer_v1.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/mlstm/layer_v1.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/blocks/xlstm_block.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/blocks/xlstm_block.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/checkpointing.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/checkpointing.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/components/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/components/conv.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/components/conv.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/components/feedforward.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/components/feedforward.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/components/init.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/components/init.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/components/linear_headwise.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/components/linear_headwise.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/components/normalization.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/components/normalization.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/training.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/training.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/utils.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/xlstm_block_stack.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/xlstm_block_stack.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_parallel/xlstm_lm_model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_parallel/xlstm_lm_model.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/mlstm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/mlstm/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/mlstm/backend/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/mlstm/backend/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/mlstm/backend/config.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/mlstm/backend/config.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/mlstm/backend/config_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/mlstm/backend/config_utils.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/mlstm/backend/fwbw.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/mlstm/backend/fwbw.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/mlstm/backend/layer_factory.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/mlstm/backend/layer_factory.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/mlstm/backend/simple.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/mlstm/backend/simple.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/mlstm/backend/tl_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/mlstm/backend/tl_utils.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/mlstm/backend/triton_chunk.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/mlstm/backend/triton_chunk.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/mlstm/block.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/mlstm/block.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/mlstm/cell.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/mlstm/cell.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/mlstm/layer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/mlstm/layer.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/block.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/block.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/cell.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/cell.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/layer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/layer.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda/lstm_pointwise.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda/lstm_pointwise.cu -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda/slstm.cc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda/slstm.cc -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda/slstm.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda/slstm.h -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda/slstm_backward.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda/slstm_backward.cu -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda/slstm_backward_cut.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda/slstm_backward_cut.cu -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda/slstm_forward.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda/slstm_forward.cu -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda/slstm_pointwise.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda/slstm_pointwise.cu -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda/slstm_pointwise.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda/slstm_pointwise.cuh -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda_init.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/cuda_init.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/blas.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/blas.cu -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/blas.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/blas.h -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/cuda_error.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/cuda_error.cu -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/cuda_error.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/cuda_error.h -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/device_assert.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/device_assert.h -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/inline_ops.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/inline_ops.cuh -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/inline_ops_2bf16.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/inline_ops_2bf16.cuh -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/inline_ops_2fp16.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/inline_ops_2fp16.cuh -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/inline_ops_bf16.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/inline_ops_bf16.cuh -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/inline_ops_fp16.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/inline_ops_fp16.cuh -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/inline_print.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/inline_print.cuh -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/support.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/support.h -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/util.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/util/util.h -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/vanilla/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/vanilla/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/vanilla/lstm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/vanilla/lstm.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/vanilla/slstm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/slstm/src/vanilla/slstm.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/blocks/xlstm_block.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/blocks/xlstm_block.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/components/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/components/conv.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/components/conv.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/components/feedforward.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/components/feedforward.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/components/init.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/components/init.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/components/linear_headwise.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/components/linear_headwise.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/components/ln.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/components/ln.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/components/util.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/components/util.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/utils.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/xlstm_block_stack.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/xlstm_block_stack.py -------------------------------------------------------------------------------- /xlstm_jax/models/xlstm_pytorch/xlstm_lm_model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/models/xlstm_pytorch/xlstm_lm_model.py -------------------------------------------------------------------------------- /xlstm_jax/resume_training.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/resume_training.py -------------------------------------------------------------------------------- /xlstm_jax/start_training.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/start_training.py -------------------------------------------------------------------------------- /xlstm_jax/train_init_fns.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/train_init_fns.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/base/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/base/param_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/base/param_utils.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/base/trainer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/base/trainer.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/callbacks/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/callbacks/callback.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/callbacks/callback.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/callbacks/checkpointing.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/callbacks/checkpointing.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/callbacks/extended_evaluation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/callbacks/extended_evaluation.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/callbacks/lr_monitor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/callbacks/lr_monitor.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/callbacks/profiler.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/callbacks/profiler.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/data_module.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/data_module.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/eval/lmeval_extended_evaluation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/eval/lmeval_extended_evaluation.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/llm/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/llm/sampling.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/llm/sampling.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/llm/trainer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/llm/trainer.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/logger/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/logger/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/logger/base_logger.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/logger/base_logger.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/logger/cmd_logging.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/logger/cmd_logging.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/logger/file_logger.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/logger/file_logger.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/logger/tensorboard_logger.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/logger/tensorboard_logger.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/logger/wandb_logger.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/logger/wandb_logger.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/metrics.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/metrics.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/optimizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/optimizer/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/optimizer/ademamix.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/optimizer/ademamix.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/optimizer/optimizer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/optimizer/optimizer.py -------------------------------------------------------------------------------- /xlstm_jax/trainer/optimizer/scheduler.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/trainer/optimizer/scheduler.py -------------------------------------------------------------------------------- /xlstm_jax/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/utils/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/utils/error_logging_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/utils/error_logging_utils.py -------------------------------------------------------------------------------- /xlstm_jax/utils/model_param_handling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/utils/model_param_handling/__init__.py -------------------------------------------------------------------------------- /xlstm_jax/utils/model_param_handling/convert_checkpoint.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/utils/model_param_handling/convert_checkpoint.py -------------------------------------------------------------------------------- /xlstm_jax/utils/model_param_handling/convert_state_dict.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/utils/model_param_handling/convert_state_dict.py -------------------------------------------------------------------------------- /xlstm_jax/utils/model_param_handling/handle_mlstm_simple.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/utils/model_param_handling/handle_mlstm_simple.py -------------------------------------------------------------------------------- /xlstm_jax/utils/model_param_handling/load.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/utils/model_param_handling/load.py -------------------------------------------------------------------------------- /xlstm_jax/utils/model_param_handling/store.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/utils/model_param_handling/store.py -------------------------------------------------------------------------------- /xlstm_jax/utils/pytree_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/xlstm-jax/HEAD/xlstm_jax/utils/pytree_utils.py --------------------------------------------------------------------------------