├── .github └── workflows │ └── publish.yml ├── .gitignore ├── .legacy ├── evaluate.py ├── generate.py └── tests │ └── logits_processing │ ├── penalize_frequency.py │ └── penalize_presence.py ├── .vscode ├── extensions.json └── settings.json ├── LICENSE ├── README.md ├── assets └── llama.png ├── compute_accuracy.py ├── determine_max_length.py ├── generate.py ├── hf_evaluate.py ├── hf_generate.py ├── len_dist.png ├── lib ├── __init__.py ├── array_utils │ └── __init__.py ├── data │ └── __init__.py ├── dataloader │ ├── LlamaDataLoader.py │ └── __init__.py ├── generation │ └── __init__.py ├── gsm_data │ ├── GSMDataset.py │ ├── __init__.py │ └── gsm_collate_fn.py ├── llama │ ├── ModelConfig.py │ ├── __init__.py │ ├── attention.py │ ├── decoder.py │ ├── decoder_block.py │ ├── dropout.py │ ├── embedding.py │ ├── kv_cache.py │ ├── llama.py │ ├── llama_model.py │ ├── rms_norm.py │ └── rotary_embedding.py ├── llama_params │ ├── __init__.py │ ├── convert_back_params.py │ └── convert_params.py ├── logits_processing │ └── __init__.py ├── loss │ └── __init__.py ├── multihost_utils │ ├── __init__.py │ ├── shard_array.py │ └── shard_model_params.py ├── param_utils │ ├── __init__.py │ ├── check_params_equal.py │ ├── load_params.py │ └── save_params.py ├── proc_init_utils │ ├── __init__.py │ └── initialisation.py ├── rand_utils │ └── __init__.py ├── seeding │ └── __init__.py └── tree_utils │ └── __init__.py ├── mypy.ini ├── podrun ├── requirements.txt ├── scripts ├── convert_params_runner.py ├── determine_params.py ├── run_tests.py └── sanity_check.py ├── tests ├── gsm_data │ └── GSMDataloader.py ├── llama │ ├── attention.py │ ├── decoder_block.py │ ├── embedding.py │ ├── llama_model.py │ ├── rms_norm.py │ └── rotary_embedding.py ├── muitihost_utils │ └── test.py ├── param_utils │ ├── check_params_equal.py │ └── save_params_bytes.py └── standalone │ ├── np_array_is_not_array.py │ ├── shape_is_tuple.py │ ├── static_argnum_in_nested_jit.py │ └── type_checking.py └── train.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/.github/workflows/publish.yml -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/.gitignore -------------------------------------------------------------------------------- /.legacy/evaluate.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/.legacy/evaluate.py -------------------------------------------------------------------------------- /.legacy/generate.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/.legacy/generate.py -------------------------------------------------------------------------------- /.legacy/tests/logits_processing/penalize_frequency.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/.legacy/tests/logits_processing/penalize_frequency.py -------------------------------------------------------------------------------- /.legacy/tests/logits_processing/penalize_presence.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/.legacy/tests/logits_processing/penalize_presence.py -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/.vscode/extensions.json -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/.vscode/settings.json -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/LICENSE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/README.md -------------------------------------------------------------------------------- /assets/llama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/assets/llama.png -------------------------------------------------------------------------------- /compute_accuracy.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/compute_accuracy.py -------------------------------------------------------------------------------- /determine_max_length.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/determine_max_length.py -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/generate.py -------------------------------------------------------------------------------- /hf_evaluate.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/hf_evaluate.py -------------------------------------------------------------------------------- /hf_generate.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/hf_generate.py -------------------------------------------------------------------------------- /len_dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/len_dist.png -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/array_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/array_utils/__init__.py -------------------------------------------------------------------------------- /lib/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/data/__init__.py -------------------------------------------------------------------------------- /lib/dataloader/LlamaDataLoader.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/dataloader/LlamaDataLoader.py -------------------------------------------------------------------------------- /lib/dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/dataloader/__init__.py -------------------------------------------------------------------------------- /lib/generation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/generation/__init__.py -------------------------------------------------------------------------------- /lib/gsm_data/GSMDataset.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/gsm_data/GSMDataset.py -------------------------------------------------------------------------------- /lib/gsm_data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/gsm_data/__init__.py -------------------------------------------------------------------------------- /lib/gsm_data/gsm_collate_fn.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/gsm_data/gsm_collate_fn.py -------------------------------------------------------------------------------- /lib/llama/ModelConfig.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/llama/ModelConfig.py -------------------------------------------------------------------------------- /lib/llama/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/llama/__init__.py -------------------------------------------------------------------------------- /lib/llama/attention.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/llama/attention.py -------------------------------------------------------------------------------- /lib/llama/decoder.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/llama/decoder.py -------------------------------------------------------------------------------- /lib/llama/decoder_block.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/llama/decoder_block.py -------------------------------------------------------------------------------- /lib/llama/dropout.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/llama/dropout.py -------------------------------------------------------------------------------- /lib/llama/embedding.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/llama/embedding.py -------------------------------------------------------------------------------- /lib/llama/kv_cache.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/llama/kv_cache.py -------------------------------------------------------------------------------- /lib/llama/llama.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/llama/llama.py -------------------------------------------------------------------------------- /lib/llama/llama_model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/llama/llama_model.py -------------------------------------------------------------------------------- /lib/llama/rms_norm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/llama/rms_norm.py -------------------------------------------------------------------------------- /lib/llama/rotary_embedding.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/llama/rotary_embedding.py -------------------------------------------------------------------------------- /lib/llama_params/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/llama_params/__init__.py -------------------------------------------------------------------------------- /lib/llama_params/convert_back_params.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/llama_params/convert_back_params.py -------------------------------------------------------------------------------- /lib/llama_params/convert_params.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/llama_params/convert_params.py -------------------------------------------------------------------------------- /lib/logits_processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/logits_processing/__init__.py -------------------------------------------------------------------------------- /lib/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/loss/__init__.py -------------------------------------------------------------------------------- /lib/multihost_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/multihost_utils/__init__.py -------------------------------------------------------------------------------- /lib/multihost_utils/shard_array.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/multihost_utils/shard_array.py -------------------------------------------------------------------------------- /lib/multihost_utils/shard_model_params.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/multihost_utils/shard_model_params.py -------------------------------------------------------------------------------- /lib/param_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/param_utils/__init__.py -------------------------------------------------------------------------------- /lib/param_utils/check_params_equal.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/param_utils/check_params_equal.py -------------------------------------------------------------------------------- /lib/param_utils/load_params.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/param_utils/load_params.py -------------------------------------------------------------------------------- /lib/param_utils/save_params.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/param_utils/save_params.py -------------------------------------------------------------------------------- /lib/proc_init_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/proc_init_utils/__init__.py -------------------------------------------------------------------------------- /lib/proc_init_utils/initialisation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/proc_init_utils/initialisation.py -------------------------------------------------------------------------------- /lib/rand_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/rand_utils/__init__.py -------------------------------------------------------------------------------- /lib/seeding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/seeding/__init__.py -------------------------------------------------------------------------------- /lib/tree_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/lib/tree_utils/__init__.py -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/mypy.ini -------------------------------------------------------------------------------- /podrun: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/podrun -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/requirements.txt -------------------------------------------------------------------------------- /scripts/convert_params_runner.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/scripts/convert_params_runner.py -------------------------------------------------------------------------------- /scripts/determine_params.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/scripts/determine_params.py -------------------------------------------------------------------------------- /scripts/run_tests.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/scripts/run_tests.py -------------------------------------------------------------------------------- /scripts/sanity_check.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/scripts/sanity_check.py -------------------------------------------------------------------------------- /tests/gsm_data/GSMDataloader.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/tests/gsm_data/GSMDataloader.py -------------------------------------------------------------------------------- /tests/llama/attention.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/tests/llama/attention.py -------------------------------------------------------------------------------- /tests/llama/decoder_block.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/tests/llama/decoder_block.py -------------------------------------------------------------------------------- /tests/llama/embedding.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/tests/llama/embedding.py -------------------------------------------------------------------------------- /tests/llama/llama_model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/tests/llama/llama_model.py -------------------------------------------------------------------------------- /tests/llama/rms_norm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/tests/llama/rms_norm.py -------------------------------------------------------------------------------- /tests/llama/rotary_embedding.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/tests/llama/rotary_embedding.py -------------------------------------------------------------------------------- /tests/muitihost_utils/test.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/tests/muitihost_utils/test.py -------------------------------------------------------------------------------- /tests/param_utils/check_params_equal.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/tests/param_utils/check_params_equal.py -------------------------------------------------------------------------------- /tests/param_utils/save_params_bytes.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/tests/param_utils/save_params_bytes.py -------------------------------------------------------------------------------- /tests/standalone/np_array_is_not_array.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/tests/standalone/np_array_is_not_array.py -------------------------------------------------------------------------------- /tests/standalone/shape_is_tuple.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/tests/standalone/shape_is_tuple.py -------------------------------------------------------------------------------- /tests/standalone/static_argnum_in_nested_jit.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/tests/standalone/static_argnum_in_nested_jit.py -------------------------------------------------------------------------------- /tests/standalone/type_checking.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/tests/standalone/type_checking.py -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayaka14732/llama-2-jax/HEAD/train.py --------------------------------------------------------------------------------