├── tests ├── __init__.py ├── test_dataset_basic.py ├── test_make_wds_manifest.py ├── test_generate_kv_cache_time.py ├── test_param_parsing.py ├── test_tiny_generate_kv_cache_equal.py ├── test_training_tokens.py ├── test_save_load_from_main.py ├── test_grad_accum.py ├── test_generate_load_kv_cache_equal.py ├── test_custom_attention.py ├── test_tokenize_shuffle.py ├── shared.py ├── test_dataset_deterministic.py ├── test_loss_masking.py └── test_attention_masking.py ├── open_lm ├── __init__.py ├── utils │ ├── __init__.py │ ├── transformers │ │ ├── __init__.py │ │ ├── generation.py │ │ ├── hf_wrapper.py │ │ ├── convert_to_hf.py │ │ └── hf_config.py │ ├── llm_foundry_wrapper.py │ ├── update_manifest.py │ ├── make_wds_manifest.py │ ├── convert_llama.py │ ├── averaging_utils.py │ └── verify_converted_llama.ipynb ├── model_configs │ ├── __init__.py │ ├── aphid_neox.json │ ├── atom_neox.json │ ├── open_lm_11m.json │ ├── open_lm_11m_v2.json │ ├── open_lm_25m.json │ ├── open_lm_41m.json │ ├── open_lm_79m_v2.json │ ├── open_lm_87m.json │ ├── quark_neox.json │ ├── open_lm_154m_v2.json │ ├── open_lm_411m_v2.json │ ├── ant_neox.json │ ├── g3b_neox.json │ ├── l7b_neox.json │ ├── m1b_neox.json │ ├── marmot_neox.json │ ├── open_lm_160m.json │ ├── open_lm_1b.json │ ├── open_lm_3b.json │ ├── open_lm_7b.json │ ├── open_lm_test_tiny.json │ ├── potato_neox.json │ ├── m1b_tiktoken.json │ ├── open_lm_410m.json │ ├── open_lm_830m.json │ ├── mamba_130m.json │ ├── mamba_1b.json │ ├── mamba_7b.json │ ├── llama2_7b.json │ ├── open_lm_1b_old.json │ ├── mistral_7b.json │ ├── linear_1b.json │ ├── linear_7b.json │ ├── linear_tiny.json │ └── mistral_7b_linear.json ├── datapreprocess │ ├── __init__.py │ ├── ray │ │ ├── __init__.py │ │ ├── token_counter.py │ │ ├── readme.md │ │ └── ray_cluster_configs │ │ │ └── cluster_west.yaml │ ├── metadata │ │ └── rpj_lm_data.yaml │ ├── wiki_download.py │ ├── docs │ │ └── ray_cluster_setup.md │ └── make_assistant_data.py ├── positional_embedding │ ├── __init__.py │ ├── none.py │ ├── head_rotary.py │ ├── rotary.py │ └── llama_rotary.py ├── open_lm_hf │ ├── __init__.py │ ├── configuration_openlm.py │ ├── tokenization_openlm.py │ └── modeling_openlm.py ├── run_bench.sh ├── precision.py ├── losses.py ├── logger.py ├── scheduler.py ├── meters.py ├── tests │ └── test_accumulation.py ├── distributed.py ├── norms.py └── evaluate.py ├── eval ├── local_data │ └── .gitignore ├── in_memory_hf_eval.yaml └── eval_openlm_ckpt.py ├── sagemaker_train ├── .dockerignore ├── Dockerfile_update ├── cfg_sample.yaml └── Dockerfile ├── MANIFEST.in ├── .dockerignore ├── plots ├── fig1.png ├── logo.png ├── interpolation.png └── interpolation.py ├── requirements_test.txt ├── environment.yml ├── environment-tests.yml ├── pyproject.toml ├── requirements.txt ├── .pre-commit-config.yaml ├── scripts ├── train_example.sh ├── generate.py └── generate_without_hf.py ├── Makefile ├── LICENSE ├── .github └── workflows │ └── ci.yml ├── setup.py ├── .gitignore ├── AVERAGE.md └── MOE.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /open_lm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /open_lm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /eval/local_data/.gitignore: -------------------------------------------------------------------------------- 1 | !* 2 | -------------------------------------------------------------------------------- /open_lm/model_configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sagemaker_train/.dockerignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /open_lm/utils/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /open_lm/datapreprocess/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /open_lm/datapreprocess/ray/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /open_lm/positional_embedding/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include open_lm/model_configs/*.json -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | venv 2 | wandb 3 | logs 4 | checkpoints 5 | -------------------------------------------------------------------------------- /plots/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/linear_open_lm/HEAD/plots/fig1.png -------------------------------------------------------------------------------- /plots/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/linear_open_lm/HEAD/plots/logo.png -------------------------------------------------------------------------------- /plots/interpolation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/linear_open_lm/HEAD/plots/interpolation.png -------------------------------------------------------------------------------- /open_lm/positional_embedding/none.py: -------------------------------------------------------------------------------- 1 | def identity_with_cast(q, k, v, offset=0): 2 | return q.to(v.dtype), k.to(v.dtype), v 3 | -------------------------------------------------------------------------------- /requirements_test.txt: -------------------------------------------------------------------------------- 1 | black==23.11.0 2 | pytest-cov==3.0.0 3 | pytest-xdist==2.5.0 4 | pytest==7.0.1 5 | tensorboard==2.14.1 6 | llm-foundry>=0.4.0 7 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: open_lm 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.10 6 | - pip 7 | 
- pip: 8 | - -r requirements.txt 9 | - -e . -------------------------------------------------------------------------------- /open_lm/open_lm_hf/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration_openlm import OpenLMConfig 2 | from .modeling_openlm import OpenLMForCausalLM 3 | from .tokenization_openlm import OpenLMTokenizerFast 4 | -------------------------------------------------------------------------------- /environment-tests.yml: -------------------------------------------------------------------------------- 1 | name: open_lm_tests 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.10 6 | - pip 7 | - pip: 8 | - -r requirements.txt 9 | - pytest 10 | -------------------------------------------------------------------------------- /open_lm/model_configs/aphid_neox.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 488, 3 | "n_layers": 12, 4 | "n_heads": 12, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } -------------------------------------------------------------------------------- /open_lm/model_configs/atom_neox.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 192, 3 | "n_layers": 12, 4 | "n_heads": 12, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } -------------------------------------------------------------------------------- /open_lm/model_configs/open_lm_11m.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 96, 3 | "n_layers": 12, 4 | "n_heads": 12, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } -------------------------------------------------------------------------------- /open_lm/model_configs/open_lm_11m_v2.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 96, 3 | "n_layers": 8, 4 | "n_heads": 4, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } -------------------------------------------------------------------------------- /open_lm/model_configs/open_lm_25m.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 192, 3 | "n_layers": 12, 4 | "n_heads": 12, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } -------------------------------------------------------------------------------- /open_lm/model_configs/open_lm_41m.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 288, 3 | "n_layers": 12, 4 | "n_heads": 12, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } -------------------------------------------------------------------------------- /open_lm/model_configs/open_lm_79m_v2.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 512, 3 | "n_layers": 8, 4 | "n_heads": 4, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } -------------------------------------------------------------------------------- /open_lm/model_configs/open_lm_87m.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 488, 3 | "n_layers": 12, 4 | "n_heads": 12, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } -------------------------------------------------------------------------------- /open_lm/model_configs/quark_neox.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 96, 3 | "n_layers": 12, 4 | "n_heads": 12, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } -------------------------------------------------------------------------------- /open_lm/model_configs/open_lm_154m_v2.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 576, 3 | "n_layers": 24, 4 | "n_heads": 8, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } -------------------------------------------------------------------------------- /open_lm/model_configs/open_lm_411m_v2.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 1024, 3 | "n_layers": 24, 4 | "n_heads": 8, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } -------------------------------------------------------------------------------- /open_lm/model_configs/ant_neox.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 768, 3 | "n_layers": 12, 4 | "n_heads": 12, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } 10 | -------------------------------------------------------------------------------- /open_lm/model_configs/g3b_neox.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 2560, 3 | "n_layers": 32, 4 | "n_heads": 32, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } 10 | -------------------------------------------------------------------------------- /open_lm/model_configs/l7b_neox.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 4096, 3 | "n_layers": 32, 4 | "n_heads": 32, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } 10 | -------------------------------------------------------------------------------- /open_lm/model_configs/m1b_neox.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 2048, 3 | "n_layers": 24, 4 | "n_heads": 16, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } 10 | -------------------------------------------------------------------------------- /open_lm/model_configs/marmot_neox.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 1536, 3 | "n_layers": 24, 4 | "n_heads": 16, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } 10 | -------------------------------------------------------------------------------- /open_lm/model_configs/open_lm_160m.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 768, 3 | 
"n_layers": 12, 4 | "n_heads": 12, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } 10 | -------------------------------------------------------------------------------- /open_lm/model_configs/open_lm_1b.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 2048, 3 | "n_layers": 24, 4 | "n_heads": 16, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } 10 | -------------------------------------------------------------------------------- /open_lm/model_configs/open_lm_3b.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 2560, 3 | "n_layers": 32, 4 | "n_heads": 32, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } 10 | -------------------------------------------------------------------------------- /open_lm/model_configs/open_lm_7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 4096, 3 | "n_layers": 32, 4 | "n_heads": 32, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } 10 | -------------------------------------------------------------------------------- /open_lm/model_configs/open_lm_test_tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 16, 3 | "n_layers": 2, 4 | "n_heads": 2, 5 | "seq_len": 16, 6 | "vocab_size": 16, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } 10 | -------------------------------------------------------------------------------- /open_lm/model_configs/potato_neox.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 1024, 3 | "n_layers": 24, 4 | "n_heads": 16, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } 10 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | 4 | [tool.pytest.ini_options] 5 | markers = [ 6 | "slow: marks tests as slow", 7 | "gpu: marks tests as requiring gpus", 8 | "s3: marks tests as requiring s3 creds", 9 | ] 10 | -------------------------------------------------------------------------------- /open_lm/model_configs/m1b_tiktoken.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 2048, 3 | "n_layers": 24, 4 | "n_heads": 16, 5 | "seq_len": 2048, 6 | "vocab_size": 50304, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } 10 | -------------------------------------------------------------------------------- /open_lm/model_configs/open_lm_410m.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 1024, 3 | "n_layers": 24, 4 | "n_heads": 16, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } 10 | -------------------------------------------------------------------------------- /open_lm/model_configs/open_lm_830m.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 1536, 3 | "n_layers": 24, 4 | "n_heads": 16, 5 | "seq_len": 2048, 6 | "vocab_size": 
50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false 9 | } 10 | -------------------------------------------------------------------------------- /open_lm/model_configs/mamba_130m.json: -------------------------------------------------------------------------------- 1 | { 2 | "d_model": 768, 3 | "n_layer": 12, 4 | "vocab_size": 50432, 5 | "seq_len": 2048, 6 | "ssm_cfg": {}, 7 | "rms_norm": true, 8 | "residual_in_fp32": true, 9 | "fused_add_norm": true, 10 | "pad_vocab_size_multiple": 8 11 | } -------------------------------------------------------------------------------- /open_lm/model_configs/mamba_1b.json: -------------------------------------------------------------------------------- 1 | { 2 | "d_model": 2048, 3 | "n_layer": 36, 4 | "vocab_size": 50432, 5 | "seq_len": 2048, 6 | "ssm_cfg": {}, 7 | "rms_norm": true, 8 | "residual_in_fp32": true, 9 | "fused_add_norm": true, 10 | "pad_vocab_size_multiple": 8 11 | } -------------------------------------------------------------------------------- /open_lm/model_configs/mamba_7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "d_model": 4096, 3 | "n_layer": 64, 4 | "vocab_size": 50432, 5 | "seq_len": 2048, 6 | "ssm_cfg": {}, 7 | "rms_norm": true, 8 | "residual_in_fp32": true, 9 | "fused_add_norm": true, 10 | "pad_vocab_size_multiple": 8 11 | } -------------------------------------------------------------------------------- /open_lm/model_configs/llama2_7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 4096, 3 | "n_layers": 32, 4 | "n_heads": 32, 5 | "seq_len": 4096, 6 | "vocab_size": 32000, 7 | "post_embed_norm": false, 8 | "weight_tying": false, 9 | "positional_embedding_type": "llama_rotary", 10 | "model_norm": "rms_norm", 11 | "ffn_type": "swiglu" 12 | } -------------------------------------------------------------------------------- /open_lm/model_configs/open_lm_1b_old.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 2048, 3 | "n_layers": 24, 4 | "n_heads": 16, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false, 9 | "qk_norm": false, 10 | "ffn_type": "swiglu", 11 | "model_norm": "default_layer_norm", 12 | "positional_embedding_type": "head_rotary" 13 | } -------------------------------------------------------------------------------- /open_lm/model_configs/mistral_7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 4096, 3 | "intermediate_dim_ffn": 14336, 4 | "n_layers": 32, 5 | "n_heads": 32, 6 | "n_heads_kv": 8, 7 | "seq_len": 2048, 8 | "vocab_size": 32000, 9 | "post_embed_norm": false, 10 | "weight_tying": false, 11 | "qk_norm": false, 12 | "model_norm": "rms_norm", 13 | "positional_embedding_type": "rotary", 14 | "ffn_type": "swiglu", 15 | "attn_name": "xformers_attn" 16 | } -------------------------------------------------------------------------------- /open_lm/model_configs/linear_1b.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 2048, 3 | "n_layers": 24, 4 | "n_heads": 16, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false, 9 | "qk_norm": true, 10 | "model_norm": "gain_only_lp_layer_norm", 11 | "positional_embedding_type": "rotary", 12 | "attn_name": "linear_attn", 13 | "use_decay": true, 14 | "use_retnet_slopes": false, 15 
| "decay_start": null 16 | } -------------------------------------------------------------------------------- /open_lm/model_configs/linear_7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 4096, 3 | "n_layers": 32, 4 | "n_heads": 32, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false, 9 | "qk_norm": false, 10 | "model_norm": "gain_only_lp_layer_norm", 11 | "positional_embedding_type": "rotary", 12 | "attn_name": "linear_attn", 13 | "use_decay": true, 14 | "use_retnet_slopes": false, 15 | "decay_start": null 16 | } -------------------------------------------------------------------------------- /open_lm/model_configs/linear_tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 256, 3 | "n_layers": 6, 4 | "n_heads": 8, 5 | "seq_len": 2048, 6 | "vocab_size": 50432, 7 | "post_embed_norm": false, 8 | "weight_tying": false, 9 | "qk_norm": true, 10 | "model_norm": "gain_only_lp_layer_norm", 11 | "positional_embedding_type": "rotary", 12 | "attn_name": "linear_attn", 13 | "use_decay": true, 14 | "use_retnet_slopes": false, 15 | "decay_start": null 16 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | xformers>=0.0.22 3 | tiktoken 4 | wandb 5 | webdataset 6 | pandas==2.1.4 7 | fsspec 8 | tqdm 9 | jsonlines 10 | boto3==1.26.90 11 | Pillow 12 | zstandard 13 | pysimdjson 14 | cloudpathlib 15 | datasets 16 | multiprocess>=0.70.11 17 | dill 18 | huggingface_hub 19 | pre-commit 20 | ray[all] 21 | loguru 22 | jsonlines 23 | transformers 24 | s3fs 25 | wikipedia 26 | ipython 27 | mosaicml 28 | lightning_attn @ git+https://github.com/OpenNLPLab/lightning-attention.git 29 | -------------------------------------------------------------------------------- /sagemaker_train/Dockerfile_update: -------------------------------------------------------------------------------- 1 | ARG BASE_DOCKER 2 | # Dockerfile that updates the container with new code. 3 | # SageMaker PyTorch image 4 | FROM ${BASE_DOCKER} 5 | 6 | # /opt/ml and all subdirectories are utilized by SageMaker, use the /code subdirectory to store your user code. 7 | COPY . /opt/ml/code/ 8 | 9 | # RUN pip install -e /opt/ml/code/ 10 | 11 | # Prevent sagemaker from installing requirements again. 
12 | RUN rm /opt/ml/code/setup.py 13 | RUN rm /opt/ml/code/requirements.txt 14 | 15 | ENV SAGEMAKER_PROGRAM open_lm/main.py 16 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # Using this mirror lets us use mypyc-compiled black, which is about 2x faster 3 | - repo: https://github.com/psf/black-pre-commit-mirror 4 | rev: 23.11.0 5 | hooks: 6 | - id: black 7 | # It is recommended to specify the latest version of Python 8 | # supported by your project here, or alternatively use 9 | # pre-commit's default_language_version, see 10 | # https://pre-commit.com/#top_level-default_language_version 11 | language_version: python3.11 12 | args: [--line-length=120] -------------------------------------------------------------------------------- /open_lm/run_bench.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BATCHSIZE=1 4 | MODEL="large2048" 5 | EXP_NAME="benchmark-$MODEL" 6 | 7 | torchrun --nproc-per-node 1 -m benchmark.main \ 8 | --train-data "pipe:aws s3 cp s3://s-laion/redpajama-tars/8192-v1/{0..7}/shard-{0000000..0000300}.tar -" \ 9 | --train-num-samples 30720 \ 10 | --workers 6 \ 11 | --precision amp_bfloat16 \ 12 | --grad-checkpointing \ 13 | --grad-clip-norm 1 \ 14 | --log-every-n-steps 1 \ 15 | --fsdp \ 16 | --profile \ 17 | --batch-size $BATCHSIZE \ 18 | --model $MODEL \ 19 | --name $EXP_NAME \ 20 | -------------------------------------------------------------------------------- /open_lm/precision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from contextlib import suppress 3 | 4 | 5 | def get_autocast(precision): 6 | if precision == "amp": 7 | return torch.cuda.amp.autocast if torch.cuda.is_available() else torch.cpu.amp.autocast 8 | elif precision == "amp_bfloat16" or precision == "amp_bf16": 9 | # amp_bfloat16 is more stable than amp float16 for clip training 10 | autocast_fn = torch.cuda.amp.autocast if torch.cuda.is_available() else torch.cpu.amp.autocast 11 | return lambda: autocast_fn(dtype=torch.bfloat16) 12 | else: 13 | return suppress 14 | -------------------------------------------------------------------------------- /open_lm/model_configs/mistral_7b_linear.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dim": 4096, 3 | "intermediate_dim_ffn": 14336, 4 | "n_layers": 32, 5 | "n_heads": 32, 6 | "n_heads_kv": 8, 7 | "seq_len": 2048, 8 | "vocab_size": 32000, 9 | "post_embed_norm": false, 10 | "weight_tying": false, 11 | "qk_norm": false, 12 | "qk_head_dim": 128, 13 | "v_head_dim": 128, 14 | "model_norm": "rms_norm", 15 | "positional_embedding_type": "rotary", 16 | "ffn_type": "swiglu", 17 | "attn_name": "linear_attn", 18 | "use_decay": true, 19 | "use_retnet_slopes": false, 20 | "decay_start": null 21 | } -------------------------------------------------------------------------------- /open_lm/open_lm_hf/configuration_openlm.py: -------------------------------------------------------------------------------- 1 | # Follows OLMo's HF template 2 | 3 | """ 4 | OpenLM configuration 5 | """ 6 | 7 | from transformers import AutoConfig, PretrainedConfig 8 | from transformers.utils import logging 9 | 10 | from open_lm.model import Params 11 | 12 | logger = logging.get_logger(__name__) 13 | 14 | 15 | class OpenLMConfig(PretrainedConfig): 16 | model_type = "openlm" 
17 | 18 | def __init__(self, **kwargs): 19 | kwargs["architectures"] = ["OpenLMForCausalLM"] 20 | super().__init__(**kwargs) 21 | 22 | 23 | # Register the config class so that it is available for transformer pipelines, auto-loading etc. 24 | AutoConfig.register("openlm", OpenLMConfig) 25 | -------------------------------------------------------------------------------- /open_lm/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import CrossEntropyLoss 4 | 5 | 6 | class CrossEntropyLossWithZLoss(CrossEntropyLoss): 7 | def __init__( 8 | self, 9 | eps: float = 1e-4, 10 | weight: Tensor = None, 11 | size_average=None, 12 | ignore_index: int = -100, 13 | reduce=None, 14 | reduction: str = "mean", 15 | label_smoothing: float = 0, 16 | ) -> None: 17 | super().__init__(weight, size_average, ignore_index, reduce, reduction, label_smoothing) 18 | self.eps = eps 19 | 20 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 21 | return super().forward(input, target) + self.eps * torch.square(torch.logsumexp(input, dim=-1)).mean() 22 | -------------------------------------------------------------------------------- /open_lm/open_lm_hf/tokenization_openlm.py: -------------------------------------------------------------------------------- 1 | # Follows OLMo's HF template 2 | 3 | from transformers import AutoTokenizer, PreTrainedTokenizerFast 4 | 5 | from open_lm.open_lm_hf.configuration_openlm import OpenLMConfig 6 | 7 | 8 | class OpenLMTokenizerFast(PreTrainedTokenizerFast): 9 | # Note: OpenLM's tokenizer is already a wrapper around huggingface. This is potentially unnecessary. 10 | pass 11 | 12 | # def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: 13 | # # This is required to make the implementation complete. 14 | # pass 15 | 16 | 17 | # Register the tokenizer class so that it is available for transformer pipelines, auto-loading etc. 
18 | AutoTokenizer.register(OpenLMConfig, fast_tokenizer_class=OpenLMTokenizerFast)
19 |
--------------------------------------------------------------------------------
/tests/test_dataset_basic.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from open_lm.data import get_wds_dataset, sample_chunk
4 | from tests.shared import MockDataArgs
5 |
6 |
7 | def test_dataloader_no_crash():
8 |     # basic test to make sure the dataloader does not crash
9 |     args = MockDataArgs()
10 |     di = get_wds_dataset(args, True)
11 |
12 |     for _ in di.dataloader:
13 |         pass
14 |
15 |     assert True
16 |
17 |
18 | def test_dataloader_shape():
19 |     # basic test to make sure the dataloader yields batches with the expected sequence length
20 |     args = MockDataArgs()
21 |     di = get_wds_dataset(args, True)
22 |
23 |     batch = next(iter(di.dataloader))
24 |     (texts,) = batch
25 |     inputs, targets = sample_chunk(torch.LongTensor(texts), args)
26 |     assert inputs.shape[-1] == args.seq_len
27 |     assert targets.shape[-1] == args.seq_len
28 |
--------------------------------------------------------------------------------
/open_lm/datapreprocess/metadata/rpj_lm_data.yaml:
--------------------------------------------------------------------------------
1 | sources:
2 |   - source: "LMDATA"
3 |     markers: ["lmdata"]
4 |   - source: "COMMON_CRAWL"
5 |     markers: ["common_crawl"]
6 |   - source: "C4"
7 |     markers: ["c4"]
8 |   - source: "GITHUB"
9 |     markers: ["github"]
10 |   - source: "WIKIPEDIA"
11 |     markers: ["wikipedia"]
12 |   - source: "BOOKS"
13 |     markers: ["book"]
14 |   - source: "ARXIV"
15 |     markers: ["arxiv"]
16 |   - source: "STACKEXCHANGE"
17 |     markers: ["stackexchange"]
18 |   - source: "UNKNOWN"
19 |     markers: [] # No specific markers for UNKNOWN
20 |
21 | sampling_frequencies:
22 |   COMMON_CRAWL: 0.9233485194
23 |   C4: 1.037142857
24 |   GITHUB: 0.9228813559
25 |   WIKIPEDIA: 2.26875
26 |   BOOKS: 2.094230769
27 |   ARXIV: 1.080357143
28 |   STACKEXCHANGE: 1.21
29 |   LMDATA: 1.0
30 |   UNKNOWN: 0
31 |
32 |
--------------------------------------------------------------------------------
/open_lm/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def setup_logging(log_file, level, include_host=False):
5 |     if include_host:
6 |         import socket
7 |
8 |         hostname = socket.gethostname()
9 |         formatter = logging.Formatter(
10 |             f"%(asctime)s | {hostname} | %(levelname)s | %(message)s",
11 |             datefmt="%Y-%m-%d,%H:%M:%S",
12 |         )
13 |     else:
14 |         formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d,%H:%M:%S")
15 |
16 |     logging.root.setLevel(level)
17 |
18 |     stream_handler = logging.StreamHandler()
19 |     stream_handler.setFormatter(formatter)
20 |     logging.root.addHandler(stream_handler)
21 |
22 |     if log_file:
23 |         file_handler = logging.FileHandler(filename=log_file)
24 |         file_handler.setFormatter(formatter)
25 |         logging.root.addHandler(file_handler)
26 |
--------------------------------------------------------------------------------
/scripts/train_example.sh:
--------------------------------------------------------------------------------
1 | # TORCH_DISTRIBUTED_DEBUG="DETAIL" torchrun --nproc-per-node 3 -m open_lm.main \
2 | CUDA_VISIBLE_DEVICES=1 CUDA_LAUNCH_BLOCKING=1 python open_lm/main.py \
3 |     --model linear_tiny \
4 |     --dataset-manifest s3://tri-ml-datasets/openlm/dcnlp/datasets/rpj_lmdata_do_sample_weight_fix_try2/manifest.jsonl \
5 |     --train-num-samples 1_000_000 \
6 |     --precision "amp_bfloat16" \
7 |     --fsdp-amp \
8 |     --fsdp-pure-bf16 \
9 | --workers 1 \ 10 | --global-batch-size 9 \ 11 | --log-every-n-steps 100 \ 12 | --grad-clip-norm 1 \ 13 | --data-key json.gz \ 14 | --lr 3e-4 \ 15 | --accum-freq 1 \ 16 | --warmup 10 \ 17 | --wd 0.1 \ 18 | --beta2 0.98 \ 19 | --epochs 10 \ 20 | --report-to wandb \ 21 | --wandb-project-name open_lm \ 22 | --name open_lm_ex_$RANDOM \ 23 | --resume latest \ 24 | --logs logs \ 25 | --z-loss-coefficient 1e-4 \ 26 | --load-not-strict \ 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /open_lm/utils/transformers/generation.py: -------------------------------------------------------------------------------- 1 | from utils.transformers.hf_model import OpenLMforCausalLM 2 | from transformers import GPTNeoXTokenizerFast 3 | import argparse 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--checkpoint", type=str) 8 | parser.add_argument("--prompt", type=str, default="I enjoy walking with my cute dog") 9 | args = parser.parse_args() 10 | model = OpenLMforCausalLM.from_pretrained(args.checkpoint) 11 | model = model.cuda() 12 | # hardcoded to neox tokenizer 13 | tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b") 14 | input_ids = tokenizer.encode(args.prompt, return_tensors="pt") 15 | greedy_output = model.generate(input_ids.to(0), max_length=500, do_sample=True, top_p=0.9) 16 | print("Output:\n" + 100 * "-") 17 | print(tokenizer.decode(greedy_output[0], skip_special_tokens=True)) 18 | -------------------------------------------------------------------------------- /eval/in_memory_hf_eval.yaml: -------------------------------------------------------------------------------- 1 | epoch: 1.25T 2 | dataset: bigdata 3 | num_params: 1B 4 | max_seq_len: 2048 5 | seed: 1 6 | precision: fp32 7 | 8 | # Tokenizer 9 | tokenizer: 10 | # name: [Add name from memory] 11 | pretrained_model_name_or_path: 12 | kwargs: 13 | model_max_length: 2048 14 | 15 | model: 16 | name: open_lm 17 | # pretrained_model_name_or_path: [add name from memory] 18 | init_device: cpu 19 | pretrained: true 20 | 21 | load_path: # Add your (optional) Composer checkpoint path here! 
22 | 23 | device_eval_batch_size: 8 24 | 25 | # FSDP config for model sharding 26 | fsdp_config: 27 | sharding_strategy: FULL_SHARD 28 | mixed_precision: FULL 29 | 30 | 31 | icl_tasks: 32 | - 33 | label: mmlu 34 | dataset_uri: local_data/mmlu.jsonl # ADD YOUR OWN DATASET URI 35 | num_fewshot: [0] 36 | icl_task_type: multiple_choice 37 | continuation_delimiter: 'Answer: ' # this separates questions from answers 38 | has_categories: true 39 | -------------------------------------------------------------------------------- /open_lm/utils/transformers/hf_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import AutoModelForCausalLM 4 | 5 | 6 | class HfWrapper(nn.Module): 7 | def __init__(self, args) -> None: 8 | super().__init__() 9 | self.model = AutoModelForCausalLM.from_pretrained( 10 | args.hf_model, 11 | torch_dtype=torch.bfloat16, 12 | ) 13 | 14 | self.params = self.model.config 15 | self.vocab_size = self.model.config.vocab_size 16 | self.seq_len = args.hf_seq_len 17 | 18 | @torch.jit.ignore 19 | def set_grad_checkpointing(self, enable=True): 20 | if enable: 21 | self.model.gradient_checkpointing_enable() 22 | else: 23 | self.model.gradient_checkpointing_disable() 24 | 25 | def forward(self, input): 26 | return self.model(input_ids=input)[0], None 27 | 28 | 29 | def create_wrapped_hf_model(hf_model_name): 30 | return HfWrapper(hf_model_name) 31 | -------------------------------------------------------------------------------- /sagemaker_train/cfg_sample.yaml: -------------------------------------------------------------------------------- 1 | accum-freq: 4 2 | beta1: 0.9 3 | beta2: 0.95 4 | data-key: "json" 5 | dataset-resampled: True 6 | # delete-previous-checkpoint: False 7 | # Total 25B * 40 = 1T tokens 8 | epochs: 40 9 | fsdp: True 10 | fsdp-limit-all-gathers: True 11 | # grad-checkpointing: False 12 | grad-clip-norm: 1 13 | log-every-n-steps: 20 14 | model: "open_lm_7b" 15 | name: "sample_7b" 16 | precision: "amp_bfloat16" 17 | report-to: "wandb" 18 | seed: 124 19 | train-data-mix-weights: [0.725, 0.275] 20 | train-data: ["TODO"] 21 | train-num-samples: 25_000_000_000 22 | wandb-project-name: "lm1" 23 | workers: 4 24 | logs: /opt/ml/checkpoints/ 25 | 26 | # Some important parameters, double checked with Mitchell: 27 | batch-size: 16 28 | ffn-type: swiglu 29 | # fsdp-amp: False 30 | fsdp-pure-bf16: True 31 | fsdp-backward-prefetch: True 32 | lr: 3.e-4 33 | lr-cooldown-end: 3.e-5 34 | model-norm: "gain_only_lp_layer_norm" 35 | qk-norm: True 36 | warmup: 5000 37 | wd: 0.1 38 | z-loss-coefficient: 1.e-4 39 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | install: ## [Local development] Upgrade pip, install requirements, install package. 2 | python -m pip install -U pip 3 | python -m pip install -e . 4 | 5 | install-dev: ## [Local development] Install test requirements 6 | python -m pip install -r requirements_test.txt 7 | 8 | lint: ## [Local development] Run mypy, pylint and black 9 | python -m black --check -l 120 . 10 | 11 | black: ## [Local development] Auto-format python code using black 12 | python3 -m black -l 120 . 
13 |
14 | TEST_ARGS = tests ## set default to run all tests
15 | test: ## [Local development] Run unit tests
16 | 	python -m pytest -x -s -v $(TEST_ARGS) -m "not gpu and not s3"
17 |
18 | test-gpu: ## [Local development] Run GPU unit tests
19 | 	python -m pytest -x -s -v $(TEST_ARGS) -m gpu
20 |
21 | .PHONY: help
22 |
23 | help: # Run `make help` to get help on the make commands
24 | 	@grep -E '^[0-9a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
25 |
--------------------------------------------------------------------------------
/open_lm/datapreprocess/wiki_download.py:
--------------------------------------------------------------------------------
1 | """
2 | Quick wikipedia download script from huggingface for quickstart purposes.
3 | Just downloads the 20220301 english wikipedia from huggingface and
4 | does no extra preprocessing.
5 |
6 | """
7 |
8 | import argparse
9 | from datasets import load_dataset  # huggingface
10 | import os
11 |
12 |
13 | def main(output_dir):
14 |     os.makedirs(output_dir, exist_ok=True)
15 |     data = load_dataset("wikipedia", "20220301.en")
16 |
17 |     for split, dataset in data.items():
18 |         print("Processing split: %s" % split)
19 |         output_file = os.path.join(output_dir, "wiki_en_20220301_%s.jsonl" % (split))
20 |         dataset.to_json(output_file)
21 |
22 |
23 | if __name__ == "__main__":
24 |     parser = argparse.ArgumentParser()
25 |     parser.add_argument(
26 |         "--output-dir",
27 |         type=str,
28 |         required=True,
29 |         help="Where to store the wikipedia .jsonl file",
30 |     )
31 |
32 |     args = parser.parse_args()
33 |     main(args.output_dir)
34 |
--------------------------------------------------------------------------------
/open_lm/utils/transformers/convert_to_hf.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from utils.transformers.hf_model import OpenLMModel
3 | from transformers import GPTNeoXTokenizerFast
4 | from utils.transformers.hf_config import OpenLMConfig
5 | import torch
6 | import json
7 |
8 | if __name__ == "__main__":
9 |     parser = argparse.ArgumentParser()
10 |     parser.add_argument("--checkpoint")
11 |     parser.add_argument("--model-config")
12 |     parser.add_argument("--out-dir")
13 |     args = parser.parse_args()
14 |     checkpoint = torch.load(args.checkpoint)
15 |     with open(args.model_config, "r") as f:
16 |         config = json.load(f)
17 |     openlm_config = OpenLMConfig(**config)
18 |     open_lm = OpenLMModel(openlm_config)
19 |     # hardcoded to NeoX Tokenizer
20 |     tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
21 |     state_dict = checkpoint["state_dict"]
22 |     state_dict = {x.replace("module.", ""): y for x, y in state_dict.items()}
23 |     open_lm.model.load_state_dict(state_dict)
24 |     open_lm.save_pretrained(args.out_dir)
25 |     tokenizer.save_pretrained(args.out_dir)
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 mlfoundations
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject
to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /sagemaker_train/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG AWS_REGION 2 | 3 | # SageMaker PyTorch image 4 | FROM 763104351884.dkr.ecr.${AWS_REGION}.amazonaws.com/pytorch-training:2.1.0-gpu-py310-cu121-ubuntu20.04-sagemaker 5 | 6 | # Run custom installation of libraries 7 | # RUN pip install xxx 8 | # RUN apt-get update && apt-get install -y xxx 9 | # ENV 10 | # etc.... 11 | 12 | # Remove the conda installed symlink for libcurl, which causes an error with curl. 13 | # Fixes the following error: 14 | # curl: /opt/conda/lib/libcurl.so.4: no version information available (required by curl) 15 | RUN rm /opt/conda/lib/libcurl.so.4 16 | 17 | ENV PATH="/opt/ml/code:${PATH}" 18 | 19 | # this environment variable is used by the SageMaker PyTorch container to determine our user code directory. 20 | ENV SAGEMAKER_SUBMIT_DIRECTORY /opt/ml/code 21 | 22 | # /opt/ml and all subdirectories are utilized by SageMaker, use the /code subdirectory to store your user code. 23 | COPY . /opt/ml/code/ 24 | RUN rm /opt/ml/code/setup.py 25 | 26 | RUN pip install -r /opt/ml/code/requirements.txt 27 | RUN pip uninstall flash-attn -y 28 | RUN pip install flash-attn>=2.2 29 | # # Prevent sagemaker from installing requirements again. 
30 | # RUN rm /opt/ml/code/setup.py 31 | RUN rm /opt/ml/code/requirements.txt 32 | 33 | # Defines a script entrypoint 34 | ENV SAGEMAKER_PROGRAM open_lm/main.py 35 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Continuous integration 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | lint: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Set up Python 3.10 17 | uses: actions/setup-python@v2 18 | with: 19 | python-version: "3.10" 20 | - name: Install 21 | run: | 22 | sudo apt-get update 23 | python3 -m venv .env 24 | source .env/bin/activate 25 | python -m pip install -U pip 26 | make install-dev 27 | - name: Lint 28 | run: | 29 | source .env/bin/activate 30 | make lint 31 | tests: 32 | runs-on: ubuntu-latest 33 | strategy: 34 | matrix: 35 | python-version: ["3.10"] 36 | 37 | steps: 38 | - uses: actions/checkout@v2 39 | - name: Set up Python ${{ matrix.python-version }} 40 | uses: actions/setup-python@v2 41 | with: 42 | python-version: ${{ matrix.python-version }} 43 | - name: Install 44 | run: | 45 | mkdir $HOME/.aws/ 46 | wget https://gist.githubusercontent.com/Vaishaal/f109bfab6a194a93040ae2a19b6be251/raw/7d8026ae234d77ba1ca29b1f9d114c6780308ae4/dummy_creds -O $HOME/.aws/credentials 47 | sudo apt-get update 48 | python3 -m venv .env 49 | source .env/bin/activate 50 | make install 51 | make install-dev 52 | - name: Unit tests 53 | run: | 54 | source .env/bin/activate 55 | make test 56 | -------------------------------------------------------------------------------- /open_lm/utils/transformers/hf_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import fields 2 | from typing import List, Optional, Dict 3 | from argparse import Namespace 4 | 5 | from transformers import PretrainedConfig 6 | 7 | from open_lm.model import Params, create_params 8 | 9 | 10 | class OpenLMConfig(PretrainedConfig): 11 | model_type = "openlm" 12 | 13 | def __init__( 14 | self, 15 | params: Optional[Params] = None, 16 | params_args: Optional[Namespace] = None, 17 | params_args_dict: Optional[Dict] = None, 18 | **kwargs 19 | ): 20 | """ 21 | Initialize the HFConfig class. Any of the three arguments can be used to initialize the class. 22 | Note that the instance can get serialized when passing in either params_args or params_args_dict. 23 | 24 | Args: 25 | params (Optional[Params]): The parameters object. 26 | params_args (Optional[Namespace]): The namespace object containing the parameters arguments. 27 | params_args_dict (Optional[Dict]): The dictionary containing the parameters arguments. 28 | **kwargs: Additional keyword arguments. 
29 | """ 30 | # Used by huggingface transformers 31 | super().__init__(**kwargs) 32 | 33 | if params_args is not None: 34 | params_args_dict = vars(params_args) 35 | 36 | self.params_args_dict = params_args_dict 37 | 38 | if params is not None: 39 | self.params = params 40 | 41 | def set_params(self, params: Params): 42 | self.tie_word_embeddings = params.weight_tying 43 | for field in fields(Params): 44 | setattr(self, field.name, getattr(params, field.name)) 45 | -------------------------------------------------------------------------------- /tests/test_make_wds_manifest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from open_lm.utils import make_wds_manifest as mwm 3 | from tests.utils import download_dl_test_data 4 | import os 5 | import json 6 | 7 | """ Test strategy: 8 | Given some tarfiles with "correct" manifests on huggingface, let's assert 9 | that we can recreate them 10 | """ 11 | 12 | 13 | @pytest.mark.parametrize("source_dir", ["source_1", "source_2"]) 14 | def test_make_manifest_from_source(source_dir): 15 | download_dl_test_data("tests/assets") 16 | 17 | MOCK_MANIFEST = "tests/assets/%s/mock_manifest.jsonl" % source_dir 18 | if os.path.exists(MOCK_MANIFEST): 19 | os.unlink(MOCK_MANIFEST) 20 | 21 | args = ["--data-dir", "tests/assets/%s" % source_dir, "--manifest-filename", "mock_manifest.jsonl"] 22 | mwm.main(args) 23 | 24 | true_manifest = "tests/assets/%s/manifest.jsonl" % source_dir 25 | with open(true_manifest, "r") as true_file: 26 | with open(MOCK_MANIFEST, "r") as mock_file: 27 | assert true_file.read() == mock_file.read() 28 | 29 | if os.path.exists(MOCK_MANIFEST): 30 | os.unlink(MOCK_MANIFEST) 31 | 32 | 33 | def test_make_toplevel_manifest(): 34 | download_dl_test_data("tests/assets") 35 | 36 | MOCK_MANIFEST = "tests/assets/mock_manifest.jsonl" 37 | if os.path.exists(MOCK_MANIFEST): 38 | os.unlink(MOCK_MANIFEST) 39 | 40 | args = ["--data-dir", "tests/assets/", "--manifest-filename", "mock_manifest.jsonl"] 41 | mwm.main(args) 42 | 43 | lines = [json.loads(_) for _ in open(MOCK_MANIFEST, "r").readlines()] 44 | assert lines == [{"shard": "shard_00000000", "num_sequences": 120}] 45 | 46 | if os.path.exists(MOCK_MANIFEST): 47 | os.unlink(MOCK_MANIFEST) 48 | -------------------------------------------------------------------------------- /open_lm/datapreprocess/ray/token_counter.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import os 3 | import tarfile 4 | import glob 5 | import shutil 6 | import json 7 | import random 8 | 9 | s3 = boto3.client("s3") 10 | bucket_name, prefix = "dcnlp-hub", "C4_V3_tokenized/" 11 | paginator = s3.get_paginator("list_objects_v2") 12 | pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix) 13 | all_files = [obj["Key"] for objects in pages for obj in objects.get("Contents", [])] 14 | random.shuffle(all_files) 15 | all_files = all_files[6:100] 16 | 17 | 18 | total_tokens = 0 19 | num_errors = 0 20 | for i, file in enumerate(all_files): 21 | try: 22 | os.makedirs(f"tmp/token_jsons_{i}", exist_ok=True) 23 | output_path = f'tmp/{file.split("/")[-1]}' 24 | s3.download_file(bucket_name, file, output_path) 25 | 26 | with tarfile.open(output_path, "r") as tar: 27 | tar.extractall(path=f"tmp/token_jsons_{i}/", numeric_owner=True) 28 | 29 | num_tokens = len(glob.glob(f"tmp/token_jsons_{i}/*.json")) * 2048 30 | for tokens_file in glob.glob(f"tmp/token_jsons_{i}/*.json"): 31 | with open(tokens_file, "r") as file: 32 | tokens = 
json.load(file) 33 | assert len(tokens) == 2048, "Token length is wrong" 34 | except: 35 | print("Error on file:", file) 36 | num_tokens = 0 37 | num_errors += 1 38 | 39 | # os.rmdir("/tmp/token_jsons/") 40 | shutil.rmtree(f"tmp/token_jsons_{i}", ignore_errors=True) 41 | os.remove(output_path) 42 | total_tokens += num_tokens 43 | print(f"Reached tar {i}, Num Tokens = {num_tokens}, Total = {total_tokens/1e9} Billion") 44 | 45 | 46 | print("Total tokens", total_tokens) 47 | print("Num errors", num_errors) 48 | -------------------------------------------------------------------------------- /open_lm/utils/llm_foundry_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MosaicML LLM Foundry authors 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | """Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`.""" 5 | 6 | from typing import Mapping, Union 7 | 8 | from composer.metrics.nlp import ( 9 | InContextLearningLMAccuracy, 10 | InContextLearningLMExpectedCalibrationError, 11 | InContextLearningMCExpectedCalibrationError, 12 | InContextLearningMultipleChoiceAccuracy, 13 | InContextLearningQAAccuracy, 14 | InContextLearningCodeEvalAccuracy, 15 | LanguageCrossEntropy, 16 | LanguagePerplexity, 17 | ) 18 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 19 | 20 | from composer.models.huggingface import HuggingFaceModel 21 | 22 | 23 | __all__ = ["ComposerOpenLMCausalLM", "SimpleComposerOpenLMCausalLM"] 24 | 25 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 26 | 27 | TRAIN_METRICS = [ 28 | LanguageCrossEntropy(), 29 | LanguagePerplexity(), 30 | ] 31 | EVAL_METRICS = [ 32 | LanguageCrossEntropy(), 33 | LanguagePerplexity(), 34 | InContextLearningLMAccuracy(), 35 | InContextLearningMultipleChoiceAccuracy(), 36 | InContextLearningQAAccuracy(), 37 | InContextLearningLMExpectedCalibrationError(), 38 | InContextLearningMCExpectedCalibrationError(), 39 | InContextLearningCodeEvalAccuracy(), 40 | ] 41 | 42 | 43 | class SimpleComposerOpenLMCausalLM(HuggingFaceModel): 44 | def __init__(self, model, tokenizer): 45 | super().__init__( 46 | model=model, 47 | tokenizer=tokenizer, 48 | metrics=TRAIN_METRICS, 49 | eval_metrics=EVAL_METRICS, 50 | shift_labels=True, 51 | ) 52 | 53 | def generate(self, input_ids=None, inputs_embeds=None, **kwargs): 54 | return super().generate(input_ids=input_ids, **kwargs) 55 | -------------------------------------------------------------------------------- /open_lm/datapreprocess/ray/readme.md: -------------------------------------------------------------------------------- 1 | 2 | # Ray Cluster Setup and Execution Guide 3 | 4 | ## Quick Commands 5 | 6 | 1. **Spin up the ray cluster**: 7 | ``` 8 | ray up ray_cluster_configs/cluster_west.yaml 9 | ``` 10 | 11 | 2. **Access the ray cluster**: 12 | ``` 13 | ray attach ray_cluster_configs/cluster_west.yaml 14 | ``` 15 | 16 | 3. **Transfer the `tokenize_shuffle.py` script to the cluster**: 17 | ``` 18 | ray rsync_up ray_cluster_configs/cluster_west.yaml tokenize_shuffle.py /home/ubuntu 19 | ``` 20 | 21 | 5. **Tokenize with shuffling**: 22 | ``` 23 | python tokenize_shuffle.py --input “s3://dcnlp-data/redpajamas-raw/c4-train.{00000..00063}-of-01024.jsonl” --output s3://dcnlp-data/tokenize-shuffle-test/ 24 | ``` 25 | 26 | > **Note**: Ensure that the paths specified above are in the same AWS region as the one mentioned in the ray yaml file (currently set to `us-west-2`). 27 | 28 | 6. 
**Exit and re-enter the cluster as required**. 29 | 30 | ## Detailed Workflow 31 | 32 | 1. **Configure AWS**: 33 | Start by setting up your AWS credentials: 34 | ``` 35 | aws configure 36 | ``` 37 | 38 | 2. **Initialize the cluster**: 39 | ``` 40 | ray up ray_cluster_configs/cluster_west.yaml 41 | ``` 42 | 43 | 3. **Copy the script to the cluster**: 44 | ``` 45 | ray rsync_up ray_cluster_configs/cluster_west.yaml tokenize_shuffle.py /home/ubuntu 46 | ``` 47 | Copy the `default_dataset_yaml` as well if used. 48 | 49 | 4. **SSH into the cluster**: 50 | ``` 51 | ray attach ray_cluster_configs/cluster_west.yaml 52 | ``` 53 | 54 | 5. **Enter tmux and execute the job**: 55 | ``` 56 | tmux new-session -d -s ray_tokenize_shuffle 'python tokenize_shuffle.py' 57 | ``` 58 | 59 | > **Heads up**: This is version 0 of this script. The user interface will be improved in future versions. Currently, objects are being spilled to `dcnlp-hub`. 60 | -------------------------------------------------------------------------------- /open_lm/utils/update_manifest.py: -------------------------------------------------------------------------------- 1 | """Convert manifests to the new format. 2 | 3 | This file converts existing manifest files to the new format (changing the "num_chunks" field to "num_sequences"). 4 | """ 5 | 6 | import argparse 7 | import re 8 | import shutil 9 | import simdjson 10 | import sys 11 | import multiprocessing as mp 12 | from pathlib import Path 13 | from cloudpathlib import CloudPath 14 | from tqdm import tqdm 15 | 16 | 17 | def path_or_cloudpath(s): 18 | if re.match(r"^\w+://", s): 19 | return CloudPath(s) 20 | return Path(s) 21 | 22 | 23 | def parse_args(args): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument( 26 | "--manifest-path", 27 | type=path_or_cloudpath, 28 | required=True, 29 | help="Manifest file to update.", 30 | ) 31 | parser.add_argument("--tmp-dir", type=str, default="/tmp", help="Temporary directory.") 32 | args = parser.parse_args(args) 33 | return args 34 | 35 | 36 | def main(args): 37 | args = parse_args(args) 38 | 39 | tmp_dir = Path(args.tmp_dir) 40 | 41 | temp_manifest_filename = tmp_dir / args.manifest_path.name 42 | 43 | with args.manifest_path.open("rb") as f: 44 | data = f.read() 45 | 46 | jsons = [simdjson.loads(o) for o in data.decode("utf-8").split("\n")[:-1]] 47 | 48 | with temp_manifest_filename.open("w") as f: 49 | for data in tqdm(jsons): 50 | new_data = {} 51 | new_data["shard"] = data["shard"] 52 | new_data["num_sequences"] = data["num_chunks"] 53 | f.write(simdjson.dumps(new_data)) 54 | f.write("\n") 55 | 56 | if isinstance(args.manifest_path, CloudPath): 57 | args.manifest_path.upload_from(temp_manifest_filename) 58 | else: 59 | shutil.copy(temp_manifest_filename, args.manifest_path) 60 | 61 | 62 | if __name__ == "__main__": 63 | main(sys.argv[1:]) 64 | -------------------------------------------------------------------------------- /open_lm/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def const_lr(optimizer, base_lr, warmup_length): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | lr = base_lr 19 | assign_learning_rate(optimizer, lr) 
20 | return lr 21 | 22 | return _lr_adjuster 23 | 24 | 25 | def const_lr_cooldown( 26 | optimizer, 27 | base_lr, 28 | warmup_length, 29 | steps, 30 | cooldown_steps, 31 | cooldown_power=1.0, 32 | cooldown_end_lr=0.0, 33 | ): 34 | def _lr_adjuster(step): 35 | start_cooldown_step = steps - cooldown_steps 36 | if step < warmup_length: 37 | lr = _warmup_lr(base_lr, warmup_length, step) 38 | else: 39 | if step < start_cooldown_step: 40 | lr = base_lr 41 | else: 42 | e = step - start_cooldown_step 43 | es = steps - start_cooldown_step 44 | # linear decay if power == 1; polynomial decay otherwise; 45 | decay = (1 - (e / es)) ** cooldown_power 46 | lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr 47 | assign_learning_rate(optimizer, lr) 48 | return lr 49 | 50 | return _lr_adjuster 51 | 52 | 53 | def cosine_lr(optimizer, base_lr, warmup_length, steps, min_lr, force_min_lr): 54 | def _lr_adjuster(step): 55 | if step < warmup_length: 56 | lr = _warmup_lr(base_lr, warmup_length, step) 57 | else: 58 | e = step - warmup_length 59 | es = steps - warmup_length 60 | lr = min_lr + 0.5 * (1 + np.cos(np.pi * e / es)) * (base_lr - min_lr) 61 | lr = max(lr, force_min_lr) 62 | assign_learning_rate(optimizer, lr) 63 | return lr 64 | 65 | return _lr_adjuster 66 | -------------------------------------------------------------------------------- /open_lm/datapreprocess/ray/ray_cluster_configs/cluster_west.yaml: -------------------------------------------------------------------------------- 1 | # An unique identifier for the head node and workers of this cluster. 2 | cluster_name: ray-shuffle-tokenize 3 | max_workers: 10 4 | upscaling_speed: 0.0 5 | available_node_types: 6 | ray.head.default: 7 | resources: {} 8 | node_config: 9 | SubnetIds: [subnet-0aa43f554b080bc71, subnet-0788d34e592b0adf3, subnet-0d7f60196af561fab] 10 | ImageId: ami-0efcece6bed30fd98 # ray us-west-2 11 | InstanceType: i4i.4xlarge 12 | IamInstanceProfile: 13 | Arn: arn:aws:iam::753985720788:instance-profile/ray-autoscaler-v1 14 | ray.worker.default: 15 | min_workers: 10 16 | max_workers: 10 17 | node_config: 18 | SubnetIds: [subnet-0aa43f554b080bc71, subnet-0788d34e592b0adf3, subnet-0d7f60196af561fab] 19 | ImageId: ami-0efcece6bed30fd98 # ray us-west-2 20 | InstanceType: i4i.4xlarge 21 | IamInstanceProfile: 22 | Arn: arn:aws:iam::753985720788:instance-profile/ray-autoscaler-v1 23 | 24 | # Cloud-provider specific configuration. 
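# Only AWS is configured here: instances launch in us-west-2 and are terminated rather than stopped on teardown (cache_stopped_nodes: False).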
25 | provider: 26 | type: aws 27 | region: us-west-2 28 | cache_stopped_nodes: False 29 | 30 | 31 | 32 | setup_commands: 33 | # - sudo apt-get update 34 | - wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Linux-x86_64.sh -O miniconda.sh 35 | - bash ~/miniconda.sh -f -b -p miniconda3/ 36 | - echo 'export PATH="$HOME/miniconda3/bin/:$PATH"' >> ~/.bashrc 37 | # if you have AWS CREDS fill them out here 38 | - pip install --upgrade pip setuptools wheel 39 | # - pip install ray 40 | - pip install -U "ray[default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl" 41 | - pip install boto3==1.26.90 42 | - pip install s3fs==2022.11.0 43 | - pip install psutil 44 | - pip install pysimdjson 45 | - pip install pandas 46 | - pip install pyarrow 47 | - pip install webdataset 48 | - pip install transformers 49 | - pip install jsonlines 50 | - pip install loguru 51 | - sudo mkfs -t xfs /dev/nvme1n1 52 | - sudo mount /dev/nvme1n1 /tmp 53 | - sudo chown -R $USER /tmp 54 | 55 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | from setuptools import find_packages 3 | from os import path 4 | 5 | here = path.abspath(path.dirname(__file__)) 6 | 7 | # Get the long description from the README file 8 | with open(path.join(here, "README.md"), encoding="utf-8") as f: 9 | long_description = f.read() 10 | 11 | 12 | def _read_reqs(relpath): 13 | fullpath = path.join(path.dirname(__file__), relpath) 14 | with open(fullpath) as f: 15 | return [s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))] 16 | 17 | 18 | REQUIREMENTS = _read_reqs("requirements.txt") 19 | 20 | setuptools.setup( 21 | name="open_lm", 22 | version="0.0.34", 23 | author=[ 24 | "Suchin Gururangan*", 25 | "Mitchell Wortsman*", 26 | "Samir Yitzhak Gadre", 27 | "Achal Dave", 28 | "Maciej Kilian", 29 | "Weijia Shi", 30 | "Georgios Smyrnis", 31 | "Gabriel Ilharco", 32 | "Matt Jordan", 33 | "Ali Farhadi", 34 | "Ludwig Schmidt", 35 | ], 36 | author_email="sg01@cs.washington.edu", 37 | description="OpenLM", 38 | classifiers=[ 39 | # How mature is this project? 
Common values are 40 | # 3 - Alpha 41 | # 4 - Beta 42 | # 5 - Production/Stable 43 | "Development Status :: 3 - Alpha", 44 | "Intended Audience :: Education", 45 | "Intended Audience :: Science/Research", 46 | "License :: OSI Approved :: Apache Software License", 47 | "Programming Language :: Python :: 3.7", 48 | "Programming Language :: Python :: 3.8", 49 | "Programming Language :: Python :: 3.9", 50 | "Programming Language :: Python :: 3.10", 51 | "Topic :: Scientific/Engineering", 52 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 53 | "Topic :: Software Development", 54 | "Topic :: Software Development :: Libraries", 55 | "Topic :: Software Development :: Libraries :: Python Modules", 56 | ], 57 | install_requires=REQUIREMENTS, 58 | long_description=long_description, 59 | long_description_content_type="text/markdown", 60 | url="https://github.com/mlfoundations/open_lm", 61 | license="MIT", 62 | packages=find_packages(), 63 | include_package_data=True, 64 | ) 65 | -------------------------------------------------------------------------------- /open_lm/positional_embedding/head_rotary.py: -------------------------------------------------------------------------------- 1 | # NOTE: 08/31/23, this class is copied from xformers as there is currently a bug related to which channel dim the rotary embedding is applied to. 2 | # when the upstream issue is fixed, this file should be deleted. To track progress, see this issue: https://github.com/facebookresearch/xformers/issues/841 3 | 4 | # taken from: https://github.com/facebookresearch/xformers/blob/748c159096d4f9fcfe3eaf22801e5aed4777210b/xformers/components/positional_embedding/rotary.py 5 | 6 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 7 | # 8 | # This source code is licensed under the BSD license found in the 9 | # LICENSE file in the root directory of this source tree. 10 | 11 | 12 | # CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox 13 | # NOTE: Almost the same right now, moving parts to Triton is the next step 14 | 15 | from typing import Tuple 16 | 17 | import torch 18 | 19 | from open_lm.positional_embedding.rotary import apply_rotary_pos_emb, RotaryEmbedding 20 | 21 | 22 | class HeadRotaryEmbedding(RotaryEmbedding): 23 | """ 24 | The rotary position embeddings used in the first version of OpenLM. 25 | It is only kept for compatibility, RotaryEmbedding should be used instead. 26 | """ 27 | 28 | def __init__(self, dim_model: int, seq_len: int, *_, **__): 29 | super().__init__(dim_model, seq_len) 30 | self._has_warned = False 31 | 32 | def forward(self, q: torch.Tensor, k: torch.Tensor, offset=0) -> Tuple[torch.Tensor, torch.Tensor]: 33 | self._update_cos_sin_tables(k.shape[2], device=k.device, dtype=k.dtype) 34 | 35 | if not self._has_warned and (offset != 0): 36 | print("Warning. 
HeadRotaryEmbedding does not support offset, I am not applying it.") 37 | self._has_warned = True 38 | 39 | out_q = apply_rotary_pos_emb(q.transpose(1, 2), self._cos_cached, self._sin_cached).transpose(1, 2) 40 | out_k = apply_rotary_pos_emb(k.transpose(1, 2), self._cos_cached, self._sin_cached).transpose(1, 2) 41 | return out_q, out_k 42 | 43 | 44 | class HeadRotaryWithCast(HeadRotaryEmbedding): 45 | # NOTE: this version has the bug, but we trained the 7B model with it so it's default 46 | def forward(self, q, k, v, offset=0): 47 | q, k = super().forward(q, k, offset) 48 | return q.to(v.dtype), k.to(v.dtype), v 49 | -------------------------------------------------------------------------------- /tests/test_generate_kv_cache_time.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pytest 3 | 4 | from transformers import GPTNeoXTokenizerFast 5 | 6 | from open_lm.utils.transformers.hf_model import OpenLMforCausalLM 7 | from open_lm.utils.transformers.hf_config import OpenLMConfig 8 | from open_lm.model import create_params 9 | from tests.shared import MockTrainArgs 10 | from .utils import run_model 11 | 12 | 13 | @pytest.mark.gpu 14 | @pytest.mark.slow 15 | @pytest.mark.parametrize("wiki_page", ["Soil steam sterilization", "The Triumph of Death"]) 16 | @pytest.mark.parametrize("context_len", [256]) 17 | @pytest.mark.parametrize("max_gen_len", [1024, 1792]) 18 | def test_generate_kv_cache(wiki_page, context_len, max_gen_len): 19 | """Test that the model generates faster with cache than without.""" 20 | args = MockTrainArgs( 21 | model="open_lm_160m", 22 | **{ 23 | # Generation params: 24 | "input_text": "random", 25 | "max_gen_len": max_gen_len, 26 | "context_len": context_len, 27 | "temperature": 0.0, 28 | "top_p": 1.0, 29 | "use_cache": False, 30 | # Model params that might not be in config: 31 | "model_norm": "gain_only_layer_norm", 32 | "qk_norm": False, 33 | "positional_embedding_type": "rotary", 34 | "ffn_type": "swiglu", 35 | "moe_num_experts": None, 36 | "moe_freq": 0, 37 | "moe_weight_parallelism": False, 38 | "moe_expert_model_parallelism": False, 39 | "moe_capacity_factor": 1.25, 40 | "moe_loss_weight": 0.1, 41 | "moe_top_k": 2, 42 | "num_beams": 1, 43 | } 44 | ) 45 | 46 | open_lm = OpenLMforCausalLM(OpenLMConfig(create_params(args))) 47 | 48 | tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b") 49 | 50 | open_lm.model.eval() 51 | 52 | start_time = time.time() 53 | args.use_cache = False 54 | run_model(open_lm, tokenizer, args, wiki_page=wiki_page, start_index=0) 55 | end_time = time.time() 56 | time_without_cache = end_time - start_time 57 | 58 | start_time = time.time() 59 | args.use_cache = True 60 | run_model(open_lm, tokenizer, args, wiki_page=wiki_page, start_index=0) 61 | end_time = time.time() 62 | time_with_cache = end_time - start_time 63 | 64 | assert time_with_cache < time_without_cache 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | slurm*.out 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 
| *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | wandb 134 | logs* 135 | *.out 136 | eval/*.jsonl 137 | eval/*.jsonl_tmp 138 | weights* 139 | out* 140 | tests/assets/*.tar 141 | tests/assets/source_*/* 142 | .vscode/ 143 | secrets.env 144 | checkpoints/ 145 | experiments/ 146 | -------------------------------------------------------------------------------- /open_lm/utils/make_wds_manifest.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | import simdjson 4 | import sys 5 | import subprocess 6 | import multiprocessing as mp 7 | from pathlib import Path 8 | from cloudpathlib import CloudPath 9 | from tqdm import tqdm 10 | 11 | 12 | def path_or_cloudpath(s): 13 | if re.match(r"^\w+://", s): 14 | return CloudPath(s) 15 | return Path(s) 16 | 17 | 18 | def parse_args(args): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument( 21 | "--data-dir", 22 | type=path_or_cloudpath, 23 | required=True, 24 | help="Directory containing a dataset in webdataset format.", 25 | ) 26 | parser.add_argument( 27 | "--manifest-filename", 28 | type=str, 29 | default="manifest.jsonl", 30 | help="Filename for the manifest that will be stored in the webdataset directory.", 31 | ) 32 | parser.add_argument("--tmp-dir", type=str, default=None, help="Temporary directory.") 33 | parser.add_argument("--num-workers", type=int, default=2, help="Number of workers.") 34 | args = parser.parse_args(args) 35 | return args 36 | 37 | 38 | def count_samples(shard_path, tmp_dir): 39 | if 
isinstance(shard_path, CloudPath): 40 | temp_shard_path = Path(tmp_dir) / shard_path.name 41 | shard_path.download_to(temp_shard_path) 42 | else: 43 | temp_shard_path = shard_path 44 | 45 | count = int(subprocess.check_output(f"tar tf {temp_shard_path} | wc -l", shell=True)) 46 | 47 | if isinstance(shard_path, CloudPath): 48 | temp_shard_path.unlink() 49 | 50 | return count 51 | 52 | 53 | def worker_fn(input_data): 54 | basename, data_dir, tmp_dir = input_data 55 | shard_path = data_dir / basename 56 | return ( 57 | basename, 58 | { 59 | "shard": basename.split(".")[0], 60 | "num_sequences": count_samples(shard_path, tmp_dir), 61 | }, 62 | ) 63 | 64 | 65 | def main(args): 66 | args = parse_args(args) 67 | 68 | shards = sorted([x for x in args.data_dir.iterdir() if x.name.endswith(".tar")]) 69 | input_data = [(shard.name, args.data_dir, args.tmp_dir) for shard in shards] 70 | 71 | print(f"Shards to process: {len(shards)}") 72 | print("Creating pool.") 73 | with mp.Pool(args.num_workers) as pool: 74 | data = [] 75 | for worker_data in tqdm(pool.imap_unordered(worker_fn, input_data)): 76 | data.append(worker_data) 77 | 78 | data = sorted(data) 79 | data = [item[1] for item in data] 80 | manifest_path = args.data_dir / args.manifest_filename 81 | with manifest_path.open("w") as fp: 82 | for item in data: 83 | simdjson.dump(item, fp) 84 | fp.write("\n") 85 | 86 | 87 | if __name__ == "__main__": 88 | main(sys.argv[1:]) 89 | -------------------------------------------------------------------------------- /tests/test_param_parsing.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tempfile 3 | import pytest 4 | import yaml 5 | from contextlib import contextmanager 6 | 7 | from open_lm.main import main 8 | from open_lm.params import parse_args 9 | 10 | 11 | @contextmanager 12 | def create_config(config_dict, file_type="json"): 13 | assert file_type in ("json", "yaml") 14 | with tempfile.NamedTemporaryFile(mode="w", suffix="." 
+ file_type) as f: 15 | if file_type == "json": 16 | json.dump(config_dict, f) 17 | elif file_type == "yaml": 18 | yaml.safe_dump(config_dict, f) 19 | f.seek(0) 20 | yield f 21 | 22 | 23 | def get_cmdline_config1(): 24 | samples = 1000 25 | batch_size = 2 26 | # fmt: off 27 | cmdline = [ 28 | "--train-num-samples", str(samples), 29 | "--global-batch-size", str(batch_size), 30 | "--dataset-type", "synthetic", 31 | "--model", "open_lm_test_tiny", 32 | "--epochs", "1", 33 | ] 34 | config_dict = { 35 | "train-num-samples": samples, 36 | "global-batch-size": batch_size, 37 | "dataset-type": "synthetic", 38 | "model": "open_lm_test_tiny", 39 | "epochs": 1, 40 | } 41 | # fmt: on 42 | return cmdline, config_dict 43 | 44 | 45 | @pytest.mark.parametrize("filetype", ["json", "yaml"]) 46 | def test_config_params1(filetype): 47 | cmdline, config_dict = get_cmdline_config1() 48 | cmdline_args = parse_args(cmdline) 49 | with create_config(config_dict, filetype) as f: 50 | config_args = parse_args(["--config", f.name]) 51 | assert vars(cmdline_args) == vars(config_args), "Config and command line match failed" 52 | 53 | 54 | @pytest.mark.parametrize("filetype", ["json", "yaml"]) 55 | def test_wrong_type_throws(filetype): 56 | config_dict = {"train-num-samples": "100"} 57 | with create_config(config_dict, filetype) as f: 58 | try: 59 | parse_args(["--config", f.name]) 60 | except ValueError as e: 61 | assert "Type mismatch" in str(e) 62 | 63 | 64 | @pytest.mark.parametrize("filetype", ["json", "yaml"]) 65 | def test_extra_config_key_throws(filetype): 66 | config_dict = {"this-key-should-not-exist": "100"} 67 | with create_config(config_dict, filetype) as f: 68 | try: 69 | parse_args(["--config", f.name]) 70 | except ValueError as e: 71 | assert "Unknown config" in str(e) 72 | 73 | 74 | @pytest.mark.parametrize("filetype", ["json", "yaml"]) 75 | def test_extra_arg_after_config_throws(filetype): 76 | config_dict = {"this-key-should-not-exist": "100"} 77 | with create_config(config_dict, filetype) as f: 78 | try: 79 | parse_args(["--config", f.name, "--train-data", "foo"]) 80 | except AssertionError as e: 81 | assert "--config is provided" in str(e) 82 | -------------------------------------------------------------------------------- /plots/interpolation.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | import numpy as np 4 | import os 5 | import json 6 | import torch 7 | import matplotlib.gridspec as gridspec 8 | from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes, mark_inset 9 | from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar 10 | 11 | 12 | def get_perplexity(filename): 13 | with open(filename, "r") as file: 14 | lines = file.readlines() 15 | 16 | # iterate over the lines from the end 17 | for line in reversed(lines): 18 | if "evaluation perplexity:" in line: 19 | _, perplexity = line.split("evaluation perplexity:") 20 | return float(perplexity) 21 | 22 | return None 23 | 24 | 25 | if __name__ == "__main__": 26 | kernel_size = 40 27 | min_loss = 14 28 | max_scaler = 1 29 | log_level = 1 # + len(modules) 30 | 31 | fig = plt.figure(figsize=(6 * 3, 5 * 3)) # , layout='tight') 32 | gs = gridspec.GridSpec(3, 3) 33 | 34 | exp_dir = "/fsx/home-mitchellw/experimetns/lm/" 35 | 36 | ax = fig.add_subplot(gs[0, 0]) 37 | 38 | for j, base in enumerate( 39 | [ 40 | #'/fsx/home-mitchellw/experimetns/lmtune/instruction-tune-1b-2e-5-6', 41 | 
"/fsx/home-mitchellw/experimetns/lmtune/instruction-tune-3b-2e-5-6", 42 | ] 43 | ): 44 | xs, ys, colors = [], [], [] 45 | for alpha in np.arange(0, 1.01, 0.05): 46 | chat_eval = f"{base}/checkpoints/chat-eval-interpolate-{alpha:.2f}-epoch_6.pt" 47 | base_eval = f"{base}/checkpoints/base-eval-interpolate-{alpha:.2f}-epoch_6.pt" 48 | if os.path.exists(chat_eval) and os.path.exists(base_eval): 49 | chat_y = get_perplexity(chat_eval) 50 | base_y = get_perplexity(base_eval) 51 | if chat_y is None or base_y is None: 52 | continue 53 | print(alpha) 54 | xs.append(base_y) 55 | ys.append(chat_y) 56 | colors.append(1 - alpha) # add alpha to the color list 57 | 58 | scatter = ax.scatter( 59 | xs, 60 | ys, 61 | c=colors, 62 | cmap="cool", 63 | marker="d" if "3B" in base else "o", 64 | label="OpenLM-1B" if "3B" in base else "OpenLM-3B", 65 | ) 66 | 67 | ax.set_xlabel("Base evaluation set (perplexity)", fontsize=12) 68 | ax.set_ylabel("Chat evaluation set (perplexity)", fontsize=12) 69 | 70 | ax.tick_params(axis="x", labelsize=11) 71 | ax.tick_params(axis="y", labelsize=11) 72 | ax.grid() 73 | 74 | ax.legend(fontsize=12) 75 | 76 | # Add a colorbar 77 | cbar = plt.colorbar(scatter) 78 | cbar.set_label( 79 | "Interpolation coefficient when interpolating\nbetween base and chat models", 80 | labelpad=10, 81 | ) 82 | 83 | plt.savefig("plots/interpolation.png", bbox_inches="tight") 84 | -------------------------------------------------------------------------------- /open_lm/utils/convert_llama.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script converts the weights from LLAMA to OpenLM compatible weights. 3 | Usage: `python convert_llama_to_openlm.py ` 4 | """ 5 | 6 | import torch 7 | import sys 8 | 9 | 10 | def convert(llama_state_dict: dict) -> dict: 11 | openlm_state_dict = {} 12 | 13 | n_layer = len(set([key.split(".")[1] for key in llama_state_dict if "layers." 
in key])) 14 | print(f"n_layer: {n_layer}") 15 | 16 | for key in ["tok_embeddings.weight", "norm.weight", "output.weight"]: 17 | value = llama_state_dict[key] 18 | assert key not in openlm_state_dict 19 | openlm_state_dict[key] = value 20 | 21 | for i in range(n_layer): 22 | src_key_1, src_key_2, src_key_3 = ( 23 | f"layers.{i}.attention.wq.weight", 24 | f"layers.{i}.attention.wk.weight", 25 | f"layers.{i}.attention.wv.weight", 26 | ) 27 | tgt_key = f"layers.{i}.attention.in_proj.weight" 28 | assert tgt_key not in openlm_state_dict 29 | openlm_state_dict[tgt_key] = torch.cat( 30 | [ 31 | llama_state_dict[src_key_1], 32 | llama_state_dict[src_key_2], 33 | llama_state_dict[src_key_3], 34 | ], 35 | dim=0, 36 | ) 37 | 38 | src_key = f"layers.{i}.attention.wo.weight" 39 | tgt_key = f"layers.{i}.attention.out_proj.weight" 40 | assert tgt_key not in openlm_state_dict 41 | openlm_state_dict[tgt_key] = llama_state_dict[src_key] 42 | 43 | src_key_1, src_key_2 = ( 44 | f"layers.{i}.feed_forward.w1.weight", 45 | f"layers.{i}.feed_forward.w3.weight", 46 | ) 47 | tgt_key = f"layers.{i}.feed_forward.w12.weight" 48 | assert tgt_key not in openlm_state_dict 49 | openlm_state_dict[tgt_key] = torch.cat([llama_state_dict[src_key_1], llama_state_dict[src_key_2]], dim=0) 50 | 51 | src_key = f"layers.{i}.feed_forward.w2.weight" 52 | tgt_key = f"layers.{i}.feed_forward.w3.weight" 53 | assert tgt_key not in openlm_state_dict 54 | openlm_state_dict[tgt_key] = llama_state_dict[src_key] 55 | 56 | tgt_key = f"layers.{i}.attention_norm.weight" 57 | assert tgt_key not in openlm_state_dict 58 | openlm_state_dict[tgt_key] = llama_state_dict[tgt_key] 59 | 60 | tgt_key = f"layers.{i}.ffn_norm.weight" 61 | assert tgt_key not in openlm_state_dict 62 | openlm_state_dict[tgt_key] = llama_state_dict[tgt_key] 63 | 64 | return openlm_state_dict 65 | 66 | 67 | if __name__ == "__main__": 68 | if len(sys.argv) != 3: 69 | print("Usage: `python convert_llama_to_openlm.py `") 70 | sys.exit(1) 71 | llama_state_dict = torch.load(sys.argv[1]) 72 | openlm_state_dict = {"state_dict": convert(llama_state_dict)} 73 | torch.save(openlm_state_dict, sys.argv[2]) 74 | -------------------------------------------------------------------------------- /open_lm/meters.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.distributed as dist 4 | import logging 5 | 6 | from open_lm.distributed import is_master 7 | 8 | 9 | class AverageMeter(object): 10 | """Computes and stores the average and current value""" 11 | 12 | def __init__(self): 13 | self.reset() 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | self.val = val 23 | self.sum += val * n 24 | self.count += n 25 | self.avg = self.sum / self.count 26 | 27 | 28 | class ConfidenceIntervalMeter(object): 29 | def __init__(self): 30 | self.reset() 31 | 32 | def reset(self): 33 | self.points = [] 34 | self.points_array = None 35 | 36 | def update(self, val): 37 | self.points.append(val) 38 | 39 | def compute_bootstrap_ci(self, max_population, num_iterations, interval=95): 40 | lower = None 41 | upper = None 42 | 43 | self.points_array = np.concatenate(self.points) 44 | 45 | num_points = self.points_array.shape[0] 46 | 47 | population_size = self.points_array.shape[0] 48 | if max_population is not None: 49 | population_size = min(max_population, population_size) 50 | 51 | estimates = [] 52 | for _ in range(num_iterations): 53 | 
i = np.random.choice(num_points, size=population_size) 54 | estimate = np.sum(self.points_array[i]) / population_size 55 | estimates.append(estimate.item()) 56 | 57 | half = (100 - interval) / 2 58 | 59 | lower = np.percentile(estimates, half).item() 60 | upper = np.percentile(estimates, 100 - half).item() 61 | 62 | return lower, upper 63 | 64 | 65 | def combine_average_meters(meter_list): 66 | combined_meter = AverageMeter() 67 | 68 | # arbitarily get latest val as the val from the last 69 | combined_meter.val = meter_list[-1].val 70 | combined_meter.sum = sum([m.sum for m in meter_list]) 71 | combined_meter.count = sum([m.count for m in meter_list]) 72 | combined_meter.avg = combined_meter.sum / combined_meter.count 73 | 74 | return combined_meter 75 | 76 | 77 | def combine_ci_meters(meter_list): 78 | combined_meter = ConfidenceIntervalMeter() 79 | for m in meter_list: 80 | combined_meter.points.extend(m.points) 81 | 82 | return combined_meter 83 | 84 | 85 | def gather_meters(meters, args): 86 | out_meters = [] 87 | for m in meters: 88 | combine_fn = None 89 | if isinstance(m, AverageMeter): 90 | combine_fn = combine_average_meters 91 | if isinstance(m, ConfidenceIntervalMeter): 92 | combine_fn = combine_ci_meters 93 | 94 | # buffer for a gather on all meters 95 | if is_master(args): 96 | # no need to gather unless its on master 97 | ms = [None for _ in range(args.world_size)] 98 | dist.gather_object(m, ms) 99 | out_meters.append(combine_fn(ms)) 100 | else: 101 | # meters on all others are assumed to be local 102 | dist.gather_object(m) 103 | out_meters.append(m) 104 | 105 | dist.barrier() 106 | 107 | return out_meters 108 | -------------------------------------------------------------------------------- /tests/test_tiny_generate_kv_cache_equal.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from open_lm.utils.transformers.hf_model import OpenLMforCausalLM 4 | from open_lm.utils.transformers.hf_config import OpenLMConfig 5 | from open_lm.model import create_params 6 | from tests.shared import MockTrainArgs 7 | from tests.utils import run_model, CharacterTokenizer 8 | 9 | 10 | # Download the checkpoint from HuggingFace Hub if it doesn't exist and set the args 11 | @pytest.mark.gpu 12 | @pytest.mark.slow 13 | @pytest.fixture(scope="module") 14 | def args(): 15 | args = MockTrainArgs( 16 | model="open_lm_test_tiny", 17 | **{ 18 | # Generation params: 19 | "input_text": "random", 20 | "max_gen_len": None, 21 | "context_len": None, 22 | "temperature": 0.0, 23 | "top_p": 1.0, 24 | "use_cache": False, 25 | "num_beams": 1, 26 | # Model params that might not be in config: 27 | "model_norm": "default_layer_norm", 28 | "qk_norm": False, 29 | "positional_embedding_type": "rotary", 30 | "ffn_type": "swiglu", 31 | "moe_num_experts": None, 32 | "moe_freq": 0, 33 | "moe_weight_parallelism": False, 34 | "moe_expert_model_parallelism": False, 35 | "moe_capacity_factor": 1.25, 36 | "moe_loss_weight": 0.1, 37 | "moe_top_k": 2, 38 | } 39 | ) 40 | return args 41 | 42 | 43 | @pytest.fixture(scope="module") 44 | def tiny_open_lm(args): 45 | tiny_open_lm = OpenLMforCausalLM(OpenLMConfig(create_params(args))) 46 | tiny_open_lm.model.eval() 47 | return tiny_open_lm 48 | 49 | 50 | # Create a mock tokenizer with a tiny vocab 51 | @pytest.fixture(scope="module") 52 | def tiny_tokenizer(): 53 | # The tiny model has a vocab size of 16, there are 7 special tokens, so we add 9 more 54 | tokenizer = CharacterTokenizer(["a", "b", "c", "d", "e", "f", "g", "h", 
"i"]) 55 | return tokenizer 56 | 57 | 58 | @pytest.mark.parametrize("wiki_page", ["Soil steam sterilization", "The Triumph of Death"]) 59 | @pytest.mark.parametrize("context_len", [4, 8]) 60 | @pytest.mark.parametrize("max_gen_len", [4, 8]) 61 | @pytest.mark.parametrize("num_beams", [1, 4]) 62 | def test_tiny_generate_kv_cache(tiny_open_lm, tiny_tokenizer, args, wiki_page, context_len, max_gen_len, num_beams): 63 | """ 64 | This test checks that the results of the generation are the same with and without cache. 65 | """ 66 | args.max_gen_len = max_gen_len 67 | args.context_len = context_len 68 | args.num_beams = num_beams 69 | 70 | if max_gen_len + context_len > tiny_open_lm.model.seq_len: 71 | pytest.skip("The model cannot generate sequences that long") 72 | 73 | args.use_cache = False 74 | result_no_cache1 = run_model(tiny_open_lm, tiny_tokenizer, args, wiki_page=wiki_page, start_index=0) 75 | result_no_cache2 = run_model(tiny_open_lm, tiny_tokenizer, args, wiki_page=wiki_page, start_index=0) 76 | 77 | # Check that the results are the same without cache (would fail if the sampling was not deterministic) 78 | assert result_no_cache1 == result_no_cache2 79 | 80 | args.use_cache = True 81 | result_with_cache = run_model(tiny_open_lm, tiny_tokenizer, args, wiki_page=wiki_page, start_index=0) 82 | 83 | # Check that the results are the same as without cache 84 | assert result_no_cache1 == result_with_cache 85 | -------------------------------------------------------------------------------- /tests/test_training_tokens.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from open_lm.data import get_wds_dataset 4 | from open_lm.file_utils import get_string_for_epoch 5 | from open_lm.train import train_one_epoch 6 | from tests.shared import create_train_fixtures 7 | from tests.utils import download_dl_test_data 8 | from torch.cuda.amp import GradScaler 9 | 10 | SOURCE_MANIFEST = ["tests/assets/source_3/manifest.jsonl"] 11 | 12 | 13 | @pytest.mark.gpu 14 | @pytest.mark.parametrize( 15 | "test_case", 16 | [ 17 | (100, 2, 1000, 4, [20, 40]), # Easy case. 18 | (100, 2, 1200, 4, [20, 40, 48]), # End before consuming all in a shard. 19 | (100, 2, 1500, 4, [20, 40, 54, 60]), # One of the shards here is smaller. 54 instead of 56 because of workers. 20 | (85, 2, 1000, 4, [22, 44, 47]), # Batch weirdness, total_steps = 1000 * 4 // 85 = 47, 21 | # steps_epoch = 2000 // (85 * 2) * 2 = 22 22 | ], 23 | ) 24 | def test_token_count(test_case): 25 | """Test if the correct number of steps are performed. 26 | 27 | Run training several times, and make sure that the expected number of steps is done each time. 28 | Having the same number of steps guarantees that the same number of tokens/samples are seen. 29 | 30 | TODO: this test seems to break for some reason, if test_training_simple.py is run along with it. 31 | It works fine when run by itself and if the other tests pass, and it does not affect CI, so it is fine for now. 
32 | """ 33 | batch_size, workers, desired_sequences_per_epoch, desired_epochs, expected_result = test_case 34 | 35 | download_dl_test_data() 36 | args, model, _, optimizer, scheduler, loss = create_train_fixtures("open_lm_11m") 37 | args.global_batch_size = batch_size 38 | args.per_gpu_batch_size = args.global_batch_size // args.world_size 39 | args.workers = workers 40 | args.train_data = None 41 | args.dataset_manifest = SOURCE_MANIFEST 42 | args.epochs = desired_epochs 43 | args.train_num_samples = desired_sequences_per_epoch 44 | args.scaler = None if args.precision != "amp" else GradScaler() 45 | 46 | total_samples = desired_sequences_per_epoch * desired_epochs 47 | total_steps = total_samples // (args.global_batch_size) 48 | global_step = 0 49 | next_shard_per_source = [0] 50 | epoch = 0 51 | data = None 52 | 53 | while True: 54 | if data is not None: 55 | del data 56 | 57 | shard_string_for_epoch, num_samples_per_source, next_shard_per_source = get_string_for_epoch( 58 | args.train_num_samples, 59 | next_shard_per_source, 60 | SOURCE_MANIFEST, 61 | weights=None, 62 | num_workers_per_gpu=args.workers, 63 | world_size=args.world_size, 64 | ) 65 | args.train_data = shard_string_for_epoch 66 | print(args.train_data) 67 | data = {} 68 | data["train"] = get_wds_dataset( 69 | args, True, epoch, floor=True, force_num_samples=num_samples_per_source, data_key=args.data_key 70 | ) 71 | 72 | success, global_step = train_one_epoch( 73 | model, data, loss, epoch, global_step, optimizer, args.scaler, scheduler, total_steps, args 74 | ) 75 | 76 | assert success 77 | 78 | assert global_step == expected_result[epoch] 79 | 80 | epoch += 1 81 | 82 | if global_step == total_steps: 83 | break 84 | 85 | assert epoch == len(expected_result) 86 | -------------------------------------------------------------------------------- /open_lm/datapreprocess/docs/ray_cluster_setup.md: -------------------------------------------------------------------------------- 1 | # Complete Guide to Setting Up a Ray Cluster with AWS Instance Profile 2 | 3 | This guide provides a comprehensive walkthrough on setting up a Ray cluster configuration with an AWS Instance Profile, including steps to create the Instance Profile if it doesn't exist already. 4 | 5 | ## Part 1: Setting Up Ray Cluster Configuration 6 | 7 | ### Step 1: Basic Configuration 8 | Start by defining the basic parameters of your Ray cluster in the configuration file: 9 | 10 | ```yaml 11 | cluster_name: ray-shuffle-tokenize 12 | max_workers: 25 13 | upscaling_speed: 0.0 14 | provider: 15 | type: aws 16 | region: us-west-2 17 | cache_stopped_nodes: False 18 | ``` 19 | 20 | ### Step 2: Node Configuration 21 | Configure the node types, specifying the instance types, image IDs, and most importantly, the IAM Instance Profile ARN: 22 | 23 | ```yaml 24 | available_node_types: 25 | ray.head.default: 26 | resources: {} 27 | node_config: 28 | SubnetIds: [subnet-xxx, subnet-yyy, subnet-zzz] 29 | ImageId: ami-xxxxxxx # Example AMI ID 30 | InstanceType: i4i.8xlarge 31 | IamInstanceProfile: 32 | Arn: [Your-Instance-Profile-ARN] 33 | ray.worker.default: 34 | min_workers: 25 35 | max_workers: 25 36 | node_config: 37 | SubnetIds: [subnet-xxx, subnet-yyy, subnet-zzz] 38 | ImageId: ami-xxxxxxx # Example AMI ID 39 | InstanceType: i4i.8xlarge 40 | IamInstanceProfile: 41 | Arn: [Your-Instance-Profile-ARN] 42 | ``` 43 | Replace `[Your-Instance-Profile-ARN]` with the actual ARN of your instance profile. 
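If the instance profile already exists, you can also look up its ARN from the command line instead of the console. A minimal sketch, assuming the profile is named `RayClusterRole` as in Part 2 below:

```bash
# Print the ARN of an existing instance profile (profile name is an assumption; adjust to yours).
aws iam get-instance-profile \
  --instance-profile-name RayClusterRole \
  --query "InstanceProfile.Arn" \
  --output text
```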
44 | 45 | ### Step 3: Setup Commands 46 | Define any setup commands necessary for your environment: 47 | 48 | ```yaml 49 | setup_commands: 50 | - wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Linux-x86_64.sh -O miniconda.sh 51 | # ... other setup commands ... 52 | ``` 53 | 54 | ### Step 4: Security Best Practices 55 | **Important**: Avoid hardcoding AWS credentials in your scripts or files. Using an IAM role through an Instance Profile is a more secure and recommended approach. 56 | 57 | ## Part 2: Creating an AWS Instance Profile (If Not Existing) 58 | 59 | ### Step 1: Create an IAM Role 60 | 1. **Open IAM in AWS Console**: Log into the AWS Management Console and navigate to the IAM (Identity and Access Management) service. 61 | 2. **Create a New Role**: Go to "Roles" > "Create role". 62 | 3. **Select EC2 as the Trust Entity**: Choose "AWS service" for the type of trusted entity and select "EC2". 63 | 4. **Attach Permissions**: Select `AmazonEC2FullAccess` and `AmazonS3FullAccess` policies for comprehensive EC2 and S3 access. 64 | 5. **Name and Create the Role**: Provide a name (e.g., `RayClusterRole`) and create the role. 65 | 66 | ### Step 2: Create the Instance Profile 67 | 1. **Navigate to the Role**: In IAM roles, find the newly created role. 68 | 2. **Create Instance Profile**: Under the "Role actions" menu, select "Add role to instance profile". 69 | 3. **Name the Instance Profile**: Give the instance profile the same name as the role for consistency. 70 | 71 | ### Step 3: Retrieve the Instance Profile ARN 72 | 1. **Open the Role Details**: Click on the role you just created. 73 | 2. **Copy the Instance Profile ARN**: In the summary section, you'll find the ARN which looks like `arn:aws:iam::[aws-account-id]:instance-profile/RayClusterRole`. 74 | 75 | ### Step 4: Update Ray Cluster Config 76 | Replace `[Your-Instance-Profile-ARN]` in your Ray cluster configuration with the ARN you just copied. 
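For reference, the console steps in Part 2 can also be scripted with the AWS CLI. A minimal sketch, assuming both the role and the instance profile are named `RayClusterRole`:

```bash
# Trust policy that lets EC2 instances assume the role (role/profile names are assumptions).
cat > trust-policy.json <<'EOF'
{
  "Version": "2012-10-17",
  "Statement": [
    {"Effect": "Allow", "Principal": {"Service": "ec2.amazonaws.com"}, "Action": "sts:AssumeRole"}
  ]
}
EOF

# Step 1: create the IAM role and attach the EC2/S3 policies used above.
aws iam create-role --role-name RayClusterRole --assume-role-policy-document file://trust-policy.json
aws iam attach-role-policy --role-name RayClusterRole --policy-arn arn:aws:iam::aws:policy/AmazonEC2FullAccess
aws iam attach-role-policy --role-name RayClusterRole --policy-arn arn:aws:iam::aws:policy/AmazonS3FullAccess

# Step 2: create the instance profile and add the role to it.
aws iam create-instance-profile --instance-profile-name RayClusterRole
aws iam add-role-to-instance-profile --instance-profile-name RayClusterRole --role-name RayClusterRole

# Step 3: print the instance profile ARN to paste into the cluster config.
aws iam get-instance-profile --instance-profile-name RayClusterRole --query "InstanceProfile.Arn" --output text
```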
77 | 78 | -------------------------------------------------------------------------------- /scripts/generate.py: -------------------------------------------------------------------------------- 1 | """Script to generate text from a trained model using HuggingFace wrappers.""" 2 | 3 | import argparse 4 | import json 5 | import builtins as __builtin__ 6 | 7 | import torch 8 | 9 | from composer.utils import dist, get_device 10 | from open_lm.utils.transformers.hf_model import OpenLMforCausalLM 11 | from open_lm.utils.transformers.hf_config import OpenLMConfig 12 | from open_lm.utils.llm_foundry_wrapper import SimpleComposerOpenLMCausalLM 13 | from open_lm.model import create_params 14 | from open_lm.params import add_model_args 15 | from open_lm.file_utils import pt_load 16 | from transformers import GPTNeoXTokenizerFast, AutoTokenizer 17 | import numpy as np 18 | 19 | 20 | builtin_print = __builtin__.print 21 | 22 | 23 | @torch.inference_mode() 24 | def run_model(open_lm: OpenLMforCausalLM, tokenizer, args): 25 | dist.initialize_dist(get_device(None), timeout=600) 26 | input = tokenizer(args.input_text) 27 | input = {k: torch.tensor(v).unsqueeze(0).cuda() for k, v in input.items()} 28 | composer_model = SimpleComposerOpenLMCausalLM(open_lm, tokenizer) 29 | composer_model = composer_model.cuda() 30 | 31 | generate_args = { 32 | "do_sample": args.temperature > 0, 33 | "pad_token_id": 50282, 34 | "max_new_tokens": args.max_gen_len, 35 | "use_cache": args.use_cache, 36 | "num_beams": args.num_beams, 37 | } 38 | # If these are set when temperature is 0, they will trigger a warning and be ignored 39 | if args.temperature > 0: 40 | generate_args["temperature"] = args.temperature 41 | generate_args["top_p"] = args.top_p 42 | 43 | if args.seed is not None: 44 | np.random.seed(args.seed) 45 | torch.manual_seed(args.seed) 46 | if torch.cuda.is_available(): 47 | torch.cuda.manual_seed_all(args.seed) 48 | 49 | output = composer_model.generate( 50 | input["input_ids"], 51 | **generate_args, 52 | ) 53 | output = tokenizer.decode(output[0].cpu().numpy()) 54 | print("-" * 50) 55 | print("\t\t Model output:") 56 | print("-" * 50) 57 | print(output) 58 | 59 | 60 | def main(): 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument("--checkpoint") 63 | parser.add_argument("--model", type=str, default="open_lm_1b", help="Name of the model to use") 64 | 65 | parser.add_argument("--input-text", required=True) 66 | parser.add_argument("--max-gen-len", default=200, type=int) 67 | parser.add_argument("--temperature", default=0.8, type=float) 68 | parser.add_argument("--top-p", default=0.95, type=float) 69 | parser.add_argument("--use-cache", default=False, action="store_true") 70 | parser.add_argument("--tokenizer", default="EleutherAI/gpt-neox-20b", type=str) 71 | parser.add_argument("--num-beams", default=1, type=int) 72 | parser.add_argument("--seed", default=None, type=int) 73 | 74 | add_model_args(parser) 75 | args = parser.parse_args() 76 | print("Loading model into the right classes...") 77 | open_lm = OpenLMforCausalLM(OpenLMConfig(create_params(args))) 78 | 79 | if "gpt-neox-20b" in args.tokenizer: 80 | tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b") 81 | else: 82 | # mistralai/Mistral-7B-v0.1, meta-llama/Llama-2-7b-chat-hf, 83 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) 84 | 85 | if args.checkpoint is not None: 86 | print("Loading checkpoint from disk...") 87 | checkpoint = pt_load(args.checkpoint) 88 | state_dict = checkpoint["state_dict"] 89 | 
state_dict = {x.replace("module.", ""): y for x, y in state_dict.items()} 90 | open_lm.model.load_state_dict(state_dict) 91 | open_lm.model.eval() 92 | 93 | run_model(open_lm, tokenizer, args) 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /AVERAGE.md: -------------------------------------------------------------------------------- 1 | 2 | # Instruction tuning and weight averaging 3 | 4 | Note that some of these stpes may be out of date, but the general flow should remain. 5 | 6 | We downloaded the data from https://huggingface.co/datasets/timdettmers/openassistant-guanaco then ran `python datapreprocess/make_assistant_data.py --input-files /fsx/home-mitchellw/openassistant_best_replies_train.jsonl --output-dir /fsx/home-mitchellw/tmp --num-workers 1 --num-consumers 1`. Note that we changed shard size so there would be at least 8 shards. 7 | 8 | ``` 9 | torchrun --nproc-per-node 8 -m open_lm.main \ 10 | --train-data "pipe:aws s3 cp s3:///lmdata/assistant_data/train/shard-{0000000..0000008}.tar -" \ 11 | --train-num-samples 4382720 \ 12 | --workers 1 \ 13 | --precision amp_bfloat16 \ 14 | --batch-size 8 \ 15 | --grad-checkpointing \ 16 | --log-every-n-steps 1 \ 17 | --grad-clip-norm 1 \ 18 | --lr 2e-5 \ 19 | --model g3b_neox \ 20 | --fsdp --fsdp-amp \ 21 | --warmup 100 \ 22 | --wd 0.1 \ 23 | --beta2 0.95 \ 24 | --epochs 6 \ 25 | --disable-buffer \ 26 | --lr-cooldown-end 5e-6 \ 27 | --report-to wandb \ 28 | --wandb-project-name lmtune \ 29 | --pretrained /fsx/home-mitchellw/experimetns/lm/1p5T-bigdata-neox-g3b_neox-10-1e-3-0.1-nodes48-bs10-v0/checkpoints/epoch_24.pt \ 30 | --name instruction-tune-3b-2e-5-6 \ 31 | --logs /fsx/home-mitchellw/experimetns/lmtune 32 | ``` 33 | 34 | Now we want to interpolate between the base and fine-tuned model with different coefficients alpha. We can do so with this bash script. 35 | 36 | ``` 37 | BASEMODEL=/fsx/home-mitchellw/experimetns/lm/1p5T-bigdata-neox-g3b_neox-10-1e-3-0.1-nodes48-bs10-v0/checkpoints/epoch_24.pt 38 | FINALMODEL=/fsx/home-mitchellw/experimetns/lmtune/instruction-tune-3b-2e-5-6/checkpoints/epoch_6.pt 39 | MODEL=g3b_neox 40 | 41 | for alpha in $(seq 0 0.05 1) 42 | do 43 | 44 | #echo $model 45 | save_path_1="$(dirname $FINALMODEL)/chat-eval-interpolate-$alpha-$(basename $FINALMODEL)" 46 | save_path_2="$(dirname $FINALMODEL)/base-eval-interpolate-$alpha-$(basename $FINALMODEL)" 47 | 48 | echo $save_path_1 49 | echo $save_path_2 50 | 51 | 52 | if [ -f "$save_path_1" ]; then 53 | echo "$save_path_1 exists." 54 | else 55 | # first do the chat eval. 
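# Each eval run below loads both checkpoints, averages their weights with coefficients alpha (base) and 1-alpha (fine-tuned),
# and redirects the log, including the evaluation perplexity later parsed by plots/interpolation.py, to the save path.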
56 | torchrun --nproc-per-node 4 -m open_lm.main \ 57 | --val-data "pipe:aws s3 cp s3:///lmdata/assistant_data/val.tar -" \ 58 | --workers 6 \ 59 | --precision amp_bfloat16 \ 60 | --batch-size 8 \ 61 | --grad-checkpointing \ 62 | --log-every-n-steps 1 \ 63 | --model $MODEL \ 64 | --fsdp --fsdp-amp \ 65 | --train-num-samples 1000000000 \ 66 | --name $RANDOM \ 67 | --average $BASEMODEL $FINALMODEL \ 68 | --average-coefficients $alpha $(echo "1-$alpha" | bc -l) \ 69 | --logs /fsx/home-mitchellw/experimetns/lmdebug > $save_path_1 70 | 71 | # now do the base eval 72 | torchrun --nproc-per-node 4 -m open_lm.main \ 73 | --val-data "pipe:aws s3 cp s3:///lmdata/validation_data_tokenized/open_lm//shard_00000000.tar -" \ 74 | --workers 6 \ 75 | --precision amp_bfloat16 \ 76 | --batch-size 8 \ 77 | --grad-checkpointing \ 78 | --log-every-n-steps 1 \ 79 | --model $MODEL \ 80 | --data-key json \ 81 | --fsdp --fsdp-amp \ 82 | --train-num-samples 1000000000 \ 83 | --name $RANDOM \ 84 | --average $BASEMODEL $FINALMODEL \ 85 | --average-coefficients $alpha $(echo "1-$alpha" | bc -l) \ 86 | --logs /fsx/home-mitchellw/experimetns/lmdebug > $save_path_2 87 | fi 88 | done 89 | ``` 90 | 91 | Then you can make a plot with `python plots/interpolation.py` which results in the following plot. 92 | 93 | ![](plots/interpolation.png) -------------------------------------------------------------------------------- /tests/test_save_load_from_main.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | import shutil 4 | 5 | import torch.multiprocessing as mp 6 | 7 | from open_lm.main import main 8 | 9 | 10 | def tiny_save_load(fsdp=False, distributed=False): 11 | """ 12 | This test checks that the model can be saved and loaded without changing the parameters. 13 | """ 14 | name = "test_tiny_save_load" 15 | # fmt: off 16 | logdir = "tests/assets/" 17 | args = [ 18 | "--train-num-samples", 64 * 16, # seq_len is 16 for open_lm_test_tiny 19 | "--global-batch-size", 4, 20 | "--name", name, 21 | "--model", "open_lm_test_tiny", 22 | "--dataset-type", "synthetic", 23 | "--logs", logdir, 24 | ] 25 | args = [str(x) for x in args] 26 | # fmt: on 27 | 28 | if fsdp: 29 | args += ["--fsdp", "--fsdp-amp", "--precision", "amp_bf16"] 30 | assert distributed 31 | 32 | if distributed: 33 | args += ["--force-distributed"] 34 | os.environ["RANK"] = "0" 35 | os.environ["WORLD_SIZE"] = "1" 36 | os.environ["MASTER_ADDR"] = "127.0.0.1" 37 | os.environ["MASTER_PORT"] = "12301" 38 | 39 | try: 40 | # Train for one epoch, load the model, then train for another epoch. 41 | main(args + ["--epochs", "1"]) 42 | 43 | # Loading saved tiny model 44 | resume_args = args + ["--resume", "latest", "--epochs", "2"] 45 | main(resume_args) 46 | finally: 47 | shutil.rmtree(f"{logdir}{name}", ignore_errors=True) 48 | 49 | 50 | def tiny_save_load_different_seed(fsdp=False, distributed=False): 51 | """ 52 | This test checks that the model can be saved and loaded without changing the parameters. 
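Unlike tiny_save_load above, the resume here passes a different --seed, so it is expected to fail with an
assertion stating that the shard-shuffling seed must match the one stored with the checkpoint.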
53 | """ 54 | name = "test_tiny_save_load" 55 | # fmt: off 56 | logdir = "tests/assets/" 57 | args = [ 58 | "--train-num-samples", 64 * 16, # seq_len is 16 for open_lm_test_tiny 59 | "--global-batch-size", 4, 60 | "--name", name, 61 | "--model", "open_lm_test_tiny", 62 | "--dataset-type", "synthetic", 63 | "--logs", logdir, 64 | ] 65 | args = [str(x) for x in args] 66 | # fmt: on 67 | 68 | if fsdp: 69 | args += ["--fsdp", "--fsdp-amp", "--precision", "amp_bf16"] 70 | assert distributed 71 | 72 | if distributed: 73 | args += ["--force-distributed"] 74 | os.environ["RANK"] = "0" 75 | os.environ["WORLD_SIZE"] = "1" 76 | os.environ["MASTER_ADDR"] = "127.0.0.1" 77 | os.environ["MASTER_PORT"] = "12301" 78 | 79 | try: 80 | # Train for one epoch, load the model, then train for another epoch. 81 | main(args + ["--epochs", "1"]) 82 | 83 | # Loading saved tiny model 84 | resume_args = args + ["--resume", "latest", "--epochs", "2", "--seed", "42"] 85 | main(resume_args) 86 | raise RuntimeError( 87 | "This checkpoint resuming should have failed due to different seeds, but the model resumed normally." 88 | ) 89 | except AssertionError as e: 90 | assert ( 91 | str(e) 92 | == "This checkpoint was trained with a random seed of 0. Since this seed affects shard shuffling, resuming training must use the same seed." 93 | ) 94 | finally: 95 | shutil.rmtree(f"{logdir}{name}", ignore_errors=True) 96 | 97 | 98 | def _save_load_helper_dist(rank, fsdp): 99 | tiny_save_load(fsdp=fsdp, distributed=True) 100 | 101 | 102 | def test_tiny_save_load_no_distributed(): 103 | tiny_save_load(fsdp=False, distributed=False) 104 | 105 | 106 | @pytest.mark.gpu 107 | @pytest.mark.parametrize("fsdp", [False, True]) 108 | def test_tiny_save_load_dist_fsdp(fsdp): 109 | mp.spawn(_save_load_helper_dist, args=(fsdp,), nprocs=1, join=True) 110 | 111 | 112 | if __name__ == "__main__": 113 | pytest.main([__file__]) 114 | -------------------------------------------------------------------------------- /tests/test_grad_accum.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import pytest 3 | 4 | import torch 5 | import torch.multiprocessing as mp 6 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 7 | from open_lm.model import create_model 8 | 9 | from open_lm.train import train_one_epoch 10 | from open_lm.main import random_seed 11 | from tests.shared import create_train_fixtures 12 | 13 | 14 | def _grad_acc_helper(test_fsdp, accs=[1, 2], threshold=1e-7): 15 | if test_fsdp: 16 | world_size = 1 17 | mp.spawn( 18 | _grad_acc_helper_fsdp, 19 | args=(world_size, accs, threshold), 20 | nprocs=world_size, 21 | join=True, 22 | ) 23 | else: 24 | _grad_acc_helper_single(test_fsdp=False, accs=accs, threshold=threshold) 25 | 26 | 27 | def _grad_acc_helper_fsdp(rank, world_size, accs, threshold): 28 | # Initialize distributed training 29 | torch.distributed.init_process_group( 30 | backend="nccl" if torch.cuda.is_available() else "gloo", 31 | init_method="tcp://127.0.0.1:29501", 32 | rank=rank, 33 | world_size=world_size, 34 | ) 35 | _grad_acc_helper_single(test_fsdp=True, accs=accs, threshold=threshold) 36 | torch.distributed.destroy_process_group() 37 | 38 | 39 | def _grad_acc_helper_single(test_fsdp, accs=[2, 1], threshold=1e-7): 40 | random_seed() 41 | # List of tuples with (args, model, data, optimizer, scheduler, loss) 42 | fixtures = [] 43 | for _ in accs: 44 | random_seed() 45 | (args, model, data, optimizer, scheduler, loss) = create_train_fixtures() 46 | 47 | # HACK: 
Currently, AdamW optimizer leads to different results with gradient accumulation. 48 | optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=args.lr) 49 | 50 | if test_fsdp: 51 | args.fsdp = True 52 | args.fsdp_amp = True 53 | # Required to force distributed mode on 1 gpu. 54 | args.distributed = True 55 | fixtures.append((args, model, data, optimizer, scheduler, loss)) 56 | 57 | model1 = fixtures[0][1] 58 | for fixture in fixtures[1:]: 59 | model2 = fixture[1] 60 | for p1, p2 in zip(model1.parameters(), model2.parameters()): 61 | assert torch.allclose(p1, p2, atol=threshold), "Parameter mismatch at init" 62 | 63 | # train on mock data with/without grad accumulation for one epoch 64 | for fixture, accum_freq in zip(fixtures, accs): 65 | args, model, data, optimizer, scheduler, loss = fixture 66 | if test_fsdp: 67 | model = FSDP(model) 68 | args.accum_freq = accum_freq 69 | random_seed() 70 | train_one_epoch( 71 | model=model, 72 | data=data, 73 | loss=loss, 74 | epoch=0, 75 | step=0, 76 | optimizer=optimizer, 77 | scaler=None, 78 | scheduler=scheduler, 79 | total_steps=10, 80 | args=args, 81 | ) 82 | 83 | model1 = fixtures[0][1] 84 | failed_grad = [] 85 | failed_weight = [] 86 | for fixture in fixtures[1:]: 87 | model2 = fixture[1] 88 | for (n1, p1), (n2, p2) in zip(model1.named_parameters(), model2.named_parameters()): 89 | if not torch.allclose(p1.grad, p2.grad, atol=threshold): 90 | failed_grad.append(n1) 91 | print(f"Gradient mismatch at {n1}, {n2}") 92 | 93 | if not torch.allclose(p1, p2, atol=threshold): 94 | failed_weight.append(n1) 95 | print(f"Weight mismatch at {n1}, {n2}") 96 | assert not failed_grad, f"Failed gradient checks at: {failed_grad}" 97 | assert not failed_weight, f"Failed weight checks at: {failed_weight}" 98 | 99 | 100 | def test_no_accumulation_matches(): 101 | _grad_acc_helper(test_fsdp=False, accs=[1, 1]) 102 | 103 | 104 | def test_grad_acc(): 105 | _grad_acc_helper(test_fsdp=False, accs=[1, 2]) 106 | 107 | 108 | @pytest.mark.gpu 109 | def test_grad_acc_fsdp(): 110 | _grad_acc_helper(test_fsdp=True, accs=[1, 2]) 111 | -------------------------------------------------------------------------------- /open_lm/tests/test_accumulation.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from math import ceil 3 | 4 | import torch 5 | import unittest 6 | from torch.cuda.amp import GradScaler 7 | from torch.utils.data import DataLoader, Dataset 8 | from torch import nn 9 | 10 | from open_lm.train import train_one_epoch 11 | 12 | 13 | # Dummy model 14 | class SimpleModel(torch.nn.Module): 15 | def __init__(self, vocab_size, dim=3): 16 | super(SimpleModel, self).__init__() 17 | self.tok_embeddings = nn.Embedding(vocab_size, dim) 18 | self.fc = torch.nn.Linear(dim, vocab_size) 19 | 20 | def forward(self, x): 21 | out = self.fc(self.tok_embeddings(x)) 22 | return out, None, None 23 | 24 | 25 | # Dummy dataset 26 | class DummyDataset(Dataset): 27 | def __init__(self, seq_len, vocab_size): 28 | self.vocab_size = vocab_size 29 | self.seq_len = seq_len 30 | 31 | def __len__(self): 32 | return 198 33 | 34 | def __getitem__(self, idx): 35 | generator = torch.Generator().manual_seed(idx) 36 | return ((torch.rand(self.seq_len + 1, generator=generator) * self.vocab_size).long(),) 37 | 38 | 39 | # Unit test 40 | class TestGradientAccumulation(unittest.TestCase): 41 | def test_accumulation(self): 42 | args = { 43 | "device": "cpu", 44 | "precision": "fp16", 45 | "accum_freq": 1, 46 | 
"seq_len": 9, 47 | "vocab_size": 10, 48 | "batch_size": 16, 49 | "log_logit_mean": False, 50 | "grad_clip_norm": 1.0, 51 | "skip_scheduler": True, 52 | "rank": 0, 53 | "local_rank": 0, 54 | "world_size": 1, 55 | "wandb": False, 56 | "log_every_n_steps": 1, 57 | "target_mask_left": None, 58 | "target_mask_individual": None, 59 | } 60 | 61 | model1 = SimpleModel(vocab_size=args["vocab_size"]) 62 | model2 = SimpleModel(vocab_size=args["vocab_size"]) 63 | model2.load_state_dict(model1.state_dict()) 64 | 65 | # Check if the weights are similar 66 | for p1, p2 in zip(model1.parameters(), model2.parameters()): 67 | self.assertTrue( 68 | torch.allclose(p1, p2, atol=1e-7), 69 | "Weights differ between accumulation modes.", 70 | ) 71 | 72 | optimizer1 = torch.optim.SGD(model1.parameters(), lr=0.001) 73 | optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.001) 74 | 75 | loss_fn = torch.nn.CrossEntropyLoss() 76 | dataset = DummyDataset(seq_len=args["seq_len"], vocab_size=args["vocab_size"]) 77 | dataloader = DataLoader(dataset, batch_size=args["batch_size"], shuffle=False) 78 | dataloader.num_batches = len(dataloader) 79 | dataloader.num_samples = len(dataloader) * args["batch_size"] 80 | # Train model1 without accumulation 81 | args["accum_freq"] = 2 82 | scaler = None # GradScaler() 83 | data = Namespace(dataloader=dataloader, set_epoch=lambda x: None) 84 | 85 | train_one_epoch( 86 | model=model1, 87 | data={"train": data}, 88 | loss=loss_fn, 89 | step=0, 90 | epoch=0, 91 | optimizer=optimizer1, 92 | scaler=scaler, 93 | scheduler=None, 94 | total_steps=-1, 95 | args=Namespace(**args), 96 | ) 97 | # Train model2 with accumulation 98 | args["accum_freq"] = 1 99 | train_one_epoch( 100 | model=model2, 101 | data={"train": data}, 102 | loss=loss_fn, 103 | step=0, 104 | epoch=0, 105 | optimizer=optimizer2, 106 | scaler=scaler, 107 | scheduler=None, 108 | total_steps=-1, 109 | args=Namespace(**args), 110 | ) 111 | # Check if the weights are similar 112 | for p1, p2 in zip(model1.parameters(), model2.parameters()): 113 | self.assertTrue( 114 | torch.allclose(p1, p2, atol=1e-7), 115 | "Weights differ between accumulation modes.", 116 | ) 117 | 118 | 119 | if __name__ == "__main__": 120 | unittest.main() 121 | -------------------------------------------------------------------------------- /tests/test_generate_load_kv_cache_equal.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | 4 | import pytest 5 | import torch 6 | 7 | from huggingface_hub import hf_hub_download 8 | from transformers import GPTNeoXTokenizerFast 9 | 10 | from open_lm.utils.transformers.hf_model import OpenLMforCausalLM 11 | from open_lm.utils.transformers.hf_config import OpenLMConfig 12 | from open_lm.model import create_params 13 | from tests.shared import MockTrainArgs 14 | from .utils import run_model 15 | 16 | 17 | # Download the checkpoint from HuggingFace Hub if it doesn't exist and set the args 18 | @pytest.mark.gpu 19 | @pytest.mark.slow 20 | @pytest.fixture(scope="module") 21 | def args(): 22 | if not os.path.exists("checkpoints/open_lm_1b_old.pt"): 23 | if not os.path.exists("checkpoints"): 24 | os.makedirs("checkpoints") 25 | print("Downloading checkpoint from HuggingFace Hub...") 26 | model_path = hf_hub_download("mlfoundations/open_lm_1B", filename="open_lm_1b.pt") 27 | shutil.copy2(model_path, "checkpoints/open_lm_1b_old.pt") 28 | 29 | args = MockTrainArgs( 30 | model="open_lm_1b_old", 31 | **{ 32 | # Generation params: 33 | "input_text": "random", 34 | 
"max_gen_len": None, 35 | "context_len": None, 36 | "temperature": 0.0, 37 | "top_p": 1.0, 38 | "use_cache": False, 39 | "checkpoint": "checkpoints/open_lm_1b_old.pt", 40 | # Model params that might not be in config: 41 | "model_norm": "default_layer_norm", 42 | "qk_norm": False, 43 | "positional_embedding_type": "head_rotary", 44 | "ffn_type": "swiglu", 45 | "moe_num_experts": None, 46 | "moe_freq": 0, 47 | "moe_weight_parallelism": False, 48 | "moe_expert_model_parallelism": False, 49 | "moe_capacity_factor": 1.25, 50 | "moe_loss_weight": 0.1, 51 | "moe_top_k": 2, 52 | "num_beams": 1, 53 | } 54 | ) 55 | return args 56 | 57 | 58 | # Set the tokenizer 59 | @pytest.mark.gpu 60 | @pytest.mark.slow 61 | @pytest.fixture(scope="module") 62 | def tokenizer(): 63 | return GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b") 64 | 65 | 66 | # Create the OpenLM model and load the weights only once 67 | @pytest.mark.gpu 68 | @pytest.mark.slow 69 | @pytest.fixture(scope="module") 70 | def open_lm(args): 71 | open_lm = OpenLMforCausalLM(OpenLMConfig(create_params(args))) 72 | 73 | if args.checkpoint is not None: 74 | print("Loading checkpoint from disk...") 75 | checkpoint = torch.load(args.checkpoint) 76 | state_dict = checkpoint["state_dict"] 77 | state_dict = {x.replace("module.", ""): y for x, y in state_dict.items()} 78 | open_lm.model.load_state_dict(state_dict) 79 | 80 | open_lm.model.eval() 81 | return open_lm 82 | 83 | 84 | @pytest.mark.gpu 85 | @pytest.mark.slow 86 | @pytest.mark.parametrize("wiki_page", ["Soil steam sterilization", "The Triumph of Death"]) 87 | @pytest.mark.parametrize("context_len", [128, 256]) 88 | @pytest.mark.parametrize("max_gen_len", [128, 256]) 89 | def test_generate_kv_cache(open_lm, tokenizer, args, wiki_page, context_len, max_gen_len): 90 | """ 91 | This test checks that the results of the generation are the same with and without cache. 92 | """ 93 | args.max_gen_len = max_gen_len 94 | args.context_len = context_len 95 | if max_gen_len + context_len > open_lm.model.seq_len: 96 | pytest.skip("The model cannot generate sequences that long") 97 | 98 | args.use_cache = False 99 | result_no_cache1 = run_model(open_lm, tokenizer, args, wiki_page=wiki_page, start_index=0) 100 | result_no_cache2 = run_model(open_lm, tokenizer, args, wiki_page=wiki_page, start_index=0) 101 | 102 | # Check that the results are the same without cache (would fail if the sampling was not deterministic) 103 | assert result_no_cache1 == result_no_cache2 104 | 105 | args.use_cache = True 106 | result_with_cache = run_model(open_lm, tokenizer, args, wiki_page=wiki_page, start_index=0) 107 | 108 | # Check that the results are the same as without cache 109 | assert result_no_cache1 == result_with_cache 110 | -------------------------------------------------------------------------------- /open_lm/distributed.py: -------------------------------------------------------------------------------- 1 | # This is from open_clip. 
2 | import os 3 | import logging 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | def is_global_master(args): 9 | return args.rank == 0 10 | 11 | 12 | def is_local_master(args): 13 | return args.local_rank == 0 14 | 15 | 16 | def is_master(args, local=False): 17 | return is_local_master(args) if local else is_global_master(args) 18 | 19 | 20 | def is_using_distributed(): 21 | if "WORLD_SIZE" in os.environ: 22 | return int(os.environ["WORLD_SIZE"]) > 1 23 | if "SLURM_NTASKS" in os.environ: 24 | return int(os.environ["SLURM_NTASKS"]) > 1 25 | return False 26 | 27 | 28 | def world_info_from_env(): 29 | local_rank = 0 30 | for v in ( 31 | "LOCAL_RANK", 32 | "MPI_LOCALRANKID", 33 | "SLURM_LOCALID", 34 | "OMPI_COMM_WORLD_LOCAL_RANK", 35 | ): 36 | if v in os.environ: 37 | local_rank = int(os.environ[v]) 38 | break 39 | global_rank = 0 40 | for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"): 41 | if v in os.environ: 42 | global_rank = int(os.environ[v]) 43 | break 44 | world_size = 1 45 | for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"): 46 | if v in os.environ: 47 | world_size = int(os.environ[v]) 48 | break 49 | 50 | return local_rank, global_rank, world_size 51 | 52 | 53 | def init_distributed_device(args): 54 | # Distributed training = training on more than one GPU. 55 | # Works in both single and multi-node scenarios. 56 | args.distributed = False 57 | args.world_size = 1 58 | args.rank = 0 # global rank 59 | args.local_rank = 0 60 | # For testing, allow forcing distributed mode to test distributed code path even on one gpu. 61 | if is_using_distributed() or args.force_distributed: 62 | if "SLURM_PROCID" in os.environ: 63 | # DDP via SLURM 64 | args.local_rank, args.rank, env_world_size = world_info_from_env() 65 | if args.preset_world_size is None: 66 | args.world_size = env_world_size 67 | else: 68 | args.world_size = args.preset_world_size 69 | if args.rank >= args.world_size: 70 | logging.info(f"Rank {args.rank} not needed with world size {args.world_size}. Exiting.") 71 | exit(0) 72 | 73 | # SLURM var -> torch.distributed vars in case needed 74 | os.environ["LOCAL_RANK"] = str(args.local_rank) 75 | os.environ["RANK"] = str(args.rank) 76 | os.environ["WORLD_SIZE"] = str(args.world_size) 77 | torch.distributed.init_process_group( 78 | backend=args.dist_backend, 79 | init_method=args.dist_url, 80 | world_size=args.world_size, 81 | rank=args.rank, 82 | ) 83 | else: 84 | # DDP via torchrun, torch.distributed.launch 85 | # Note that this currently assumes that the world size is all gpus in a node. 86 | assert args.preset_world_size is None, "--preset_world_size with torchrun is not currently supported." 
87 | args.local_rank, _, _ = world_info_from_env() 88 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url) 89 | args.world_size = torch.distributed.get_world_size() 90 | args.rank = torch.distributed.get_rank() 91 | args.distributed = True 92 | 93 | if torch.cuda.is_available(): 94 | if args.distributed and not args.no_set_device_rank: 95 | device = "cuda:%d" % args.local_rank 96 | else: 97 | device = "cuda:0" 98 | torch.cuda.set_device(device) 99 | else: 100 | device = "cpu" 101 | args.device = device 102 | device = torch.device(device) 103 | return device 104 | 105 | 106 | def broadcast_object(args, obj, src=0): 107 | if args.rank == src: 108 | objects = [obj] 109 | else: 110 | objects = [None] 111 | dist.broadcast_object_list(objects, src=src) 112 | return objects[0] 113 | 114 | 115 | def all_gather_object(args, obj, dst=0): 116 | # gather a pickle-able python object across all ranks 117 | objects = [None for _ in range(args.world_size)] 118 | dist.all_gather_object(objects, obj) 119 | return objects 120 | -------------------------------------------------------------------------------- /tests/test_custom_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from open_lm.attention import torch_attn, custom_attn, xformers_attn, ATTN_ACTIVATIONS, ATTN_SEQ_SCALARS 4 | from open_lm.model import SwiGLUTorch 5 | from open_lm.precision import get_autocast 6 | from xformers.ops import SwiGLU 7 | 8 | 9 | def test_custom_attn_matches_softmax_attn(threshold=1e-7): 10 | for bs, q_seq_len, k_seq_len, h, d in [ 11 | [10, 1024, 2048, 8, 128], 12 | [10, 2048, 1024, 8, 128], 13 | [10, 2048, 2048, 8, 128], 14 | [1, 1024, 2048, 8, 128], 15 | ]: 16 | queries = torch.rand(bs, q_seq_len, h, d) 17 | keys = torch.rand(bs, k_seq_len, h, d) 18 | values = torch.rand(bs, k_seq_len, h, d) 19 | 20 | for is_causal in [True, False]: 21 | torch_out = torch_attn(queries.cpu(), keys.cpu(), values.cpu(), is_causal=is_causal) 22 | 23 | my_out = custom_attn( 24 | queries.cpu(), 25 | keys.cpu(), 26 | values.cpu(), 27 | attn_activation="softmax", 28 | attn_seq_scalar="none", 29 | alpha=1.0, 30 | is_causal=is_causal, 31 | ) 32 | 33 | assert torch.allclose( 34 | torch_out, my_out, atol=threshold 35 | ), "custom_attn incorrectly implements softmax attention" 36 | 37 | if torch.cuda.is_available(): 38 | # also test xformers attention 39 | torch_out = torch_attn(queries.cuda(), keys.cuda(), values.cuda(), is_causal=is_causal) 40 | xformers_out = xformers_attn(queries.cuda(), keys.cuda(), values.cuda(), is_causal=is_causal) 41 | my_out = custom_attn( 42 | queries.cuda(), 43 | keys.cuda(), 44 | values.cuda(), 45 | attn_activation="softmax", 46 | attn_seq_scalar="none", 47 | alpha=1.0, 48 | is_causal=is_causal, 49 | ) 50 | 51 | assert torch.allclose( 52 | torch_out, my_out, atol=threshold 53 | ), "custom_attn incorrectly implements softmax attention" 54 | 55 | assert torch.allclose( 56 | xformers_out, my_out, atol=threshold 57 | ), "custom_attn incorrectly implements softmax attention" 58 | 59 | 60 | def test_no_failure(): 61 | for nl in ATTN_ACTIVATIONS: 62 | for os in ATTN_SEQ_SCALARS: 63 | for bs, q_seq_len, k_seq_len, h, d in [ 64 | [2, 64, 64, 1, 32], 65 | [2, 64, 16, 1, 32], 66 | [2, 16, 64, 1, 32], 67 | ]: 68 | queries = torch.rand(bs, q_seq_len, h, d) 69 | keys = torch.rand(bs, k_seq_len, h, d) 70 | values = torch.rand(bs, k_seq_len, h, d) 71 | 72 | for is_causal in [True, False]: 73 | custom_attn( 74 | queries, 75 | keys, 
76 | values, 77 | attn_activation=nl, 78 | attn_seq_scalar=os, 79 | alpha=1.0, 80 | is_causal=is_causal, 81 | ) 82 | 83 | assert True 84 | 85 | 86 | def test_swiglu_torch(threshold=1e-7): 87 | bsz = 5 88 | in_feats = 10 89 | hidden_feats = 30 90 | out_feats = 10 91 | num_tries = 5 92 | 93 | xops_swiglu = SwiGLU(in_features=in_feats, hidden_features=hidden_feats, out_features=out_feats) 94 | torch_swiglu = SwiGLUTorch(in_dim=in_feats, hidden_dim=hidden_feats, out_dim=out_feats) 95 | 96 | # Copy state dict from one swiglu to the other so that they have the same weights 97 | 98 | state_dict = xops_swiglu.state_dict() 99 | new_state_dict = { 100 | "w12.weight": state_dict["w12.weight"], 101 | "w3.weight": state_dict["w3.weight"], 102 | "w12.bias": state_dict["w12.bias"], 103 | "w3.bias": state_dict["w3.bias"], 104 | } 105 | torch_swiglu.load_state_dict(new_state_dict) 106 | 107 | with torch.no_grad(): 108 | for _ in range(num_tries): 109 | random_in = torch.rand((bsz, in_feats)) 110 | torch_out = torch_swiglu(random_in) 111 | xops_out = xops_swiglu(random_in) 112 | assert torch.allclose(torch_out, xops_out, atol=threshold) 113 | -------------------------------------------------------------------------------- /MOE.md: -------------------------------------------------------------------------------- 1 | # Mixture of Experts Language Models 2 | 3 | ## Dependencies 4 | 5 | Our implementation of mixture of experts depends on [megablocks](https://github.com/stanford-futuredata/megablocks) and the version of xformers which is compatible with torch 2.1: 6 | 7 | ``` 8 | pip install megablocks 9 | pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121 10 | ``` 11 | 12 | ## Train MoE 13 | 14 | To train an MoE, add the `--moe-X` related arguments to the training command: 15 | 16 | ``` 17 | torchrun --nproc-per-node 8 -m open_lm.main \ 18 | --train-num-samples 10000000000 \ 19 | --workers 2 \ 20 | --dataset-manifest "s3://laion-west/rpj_tokenized_upsampled_eleutherai/manifest.jsonl" "s3://laion-west/2T_no_rpj_tokenized_upsampled_25k_shards/manifest.jsonl" \ 21 | --train-data-mix-weights 0.725 0.275 \ 22 | --precision amp_bfloat16 \ 23 | --batch-size 8 \ 24 | --accum-freq 4 \ 25 | --log-every-n-steps 20 \ 26 | --grad-clip-norm 1 \ 27 | --lr 5e-4 \ 28 | --warmup 200 \ 29 | --model open_lm_41m \ 30 | --wd 0.1 \ 31 | --beta2 0.95 \ 32 | --epochs 50 \ 33 | --report-to wandb \ 34 | --moe-freq 2 \ 35 | --moe-num-experts 8 \ 36 | --moe-top-k 2 \ 37 | --moe-capacity-factor 1.25 --moe-loss-weight 0.1 \ 38 | --disable-meta-device \ 39 | --wandb-project-name moe \ 40 | --name test$RANDOM \ 41 | --logs /fsx/home-$USER/experiments/moe \ 42 | --resume latest \ 43 | --seed 124 \ 44 | --data-key 'json' \ 45 | --fsdp --fsdp-amp \ 46 | --model-norm gain_only_layer_norm \ 47 | --lr-scheduler cosine \ 48 | --lr-cooldown-end 0.00001 49 | ``` 50 | 51 | The above command will add an MoE FFN layer to every other Transformer block. You can use an arbitrary number of experts; you are only limited by total RAM across all GPUs. 52 | 53 | 54 | You can also add the `moe_expert_model_parallelism` which will distribute experts across different GPUs. However, if the number of GPUs is larger than number of experts, an additional num_gpu/num_expert tensor parallelism is applied. Currently this is not eval-friendly though, so I would not recommend using it yet. 
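As a minimal sketch (assuming the option is exposed on the command line as `--moe-expert-model-parallelism`, mirroring the other `--moe-*` flags above), enabling expert model parallelism only requires appending the flag to the training command:

```
torchrun --nproc-per-node 8 -m open_lm.main \
  <other training arguments as above> \
  --moe-freq 2 \
  --moe-num-experts 8 \
  --moe-top-k 2 \
  --moe-capacity-factor 1.25 --moe-loss-weight 0.1 \
  --moe-expert-model-parallelism
```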
55 | 56 | You can evaluate the MoE in the same way as dense models: 57 | 58 | ``` 59 | torchrun --nproc-per-node 8 -m open_lm.main \ 60 | --val-data "pipe:aws s3 cp s3://laion-west/lmdata/validation_data_tokenized/open_lm//shard_00000000.tar -" \ 61 | --workers 6 \ 62 | --precision amp_bfloat16 \ 63 | --batch-size 8 \ 64 | --log-every-n-steps 1 \ 65 | --model open_lm_41m \ 66 | --fsdp --fsdp-amp \ 67 | --moe-num-experts 64 --moe-freq 2 \ 68 | --data-key json \ 69 | --train-num-samples 1000000000 \ 70 | --model-norm gain_only_layer_norm \ 71 | --name $RANDOM \ 72 | --resume /fsx/home-suching/experiments/mix_wo/test8086/checkpoints/epoch_1.pt \ 73 | --logs /fsx/home-$USER/experiments/eval 74 | ``` 75 | 76 | 77 | ## Benchmarking 78 | 79 | To benchmark your results, here are perplexities we obtain with our implementation across a number of compute budgets and model sizes on our A100 cluster: 80 | 81 | ### Compute budgets 82 | 83 | | Compute type | 41M | 87M | 160M | 410M | 830M | 84 | |--------------|------|------|------|------|------| 85 | | Number of nodes | 1 | 1 | 1 | 2 | 4 | 86 | | Number of tokens | 20.0B | 20.0B | 20.0B | 20.0B | 20.0B | 87 | 88 | ### Perplexity 89 | | Number of Experts | 41M | 87M | 160M | 410M | 830M | 90 | |--------------|------|------|------|------|------| 91 | | 1 | 27.61 | 18.68 | 14.87 | 10.54 | 9.39 | 92 | | 8 | 19.85 | 14.66 | 12.26 | 9.82 | 8.84 | 93 | | 32 | 20.55 | 15.28 |14.62 | | | 94 | 95 | 96 | ### Tokens/sec/GPU 97 | 98 | | Number of Experts | 41M | 87M | 160M | 410M | 830M | 99 | |--------------|------|------|------|------|------| 100 | | 1 | 141.2K | 106.0K | 95.5K | 30.3K | 16.0K | 101 | | 8 | 69.5K | 66.6K | 66.2K | 18.5K | 9.2K | 102 | 103 | ### Training Parameters 104 | 105 | | Number of Experts | 41M | 87M | 160M | 410M | 830M | 106 | |--------------|------|------|------|------|------| 107 | | 8 experts | 68.9M | 165.4M | 360.6M | 1.1B | 2.4B | 108 | | 32 experts | 164.5M | 439.9M | 1.0B | 3.5B | 7.9B | 109 | 110 | ### Inference Parameters 111 | 112 | | Number of Experts | 41M | 87M | 160M | 410M | 830M | 113 | |--------------|------|------|------|------|------| 114 | | 2 experts | 45.0M | 96.8M | 190.7M | 509.2M | 1.1B | -------------------------------------------------------------------------------- /open_lm/positional_embedding/rotary.py: -------------------------------------------------------------------------------- 1 | # NOTE: 08/31/23, this class is copied from xformers as there is currently a bug related to which channel dim the rotary embedding is applied to. 2 | # when the upstream issue is fixed, this file should be deleted. 
To track progress, see this issue: https://github.com/facebookresearch/xformers/issues/841 3 | 4 | # taken from: https://github.com/facebookresearch/xformers/blob/748c159096d4f9fcfe3eaf22801e5aed4777210b/xformers/components/positional_embedding/rotary.py 5 | from typing import Tuple 6 | 7 | import torch 8 | 9 | 10 | def rotate_half(x): 11 | x1, x2 = x.chunk(2, dim=-1) 12 | return torch.cat((-x2, x1), dim=-1) 13 | 14 | 15 | def apply_rotary_pos_emb(x, cos, sin, offset=0): 16 | # NOTE: This could probably be moved to Triton 17 | assert ( 18 | cos.shape[1] >= offset + x.shape[1] 19 | ), f"Offset and/or input sequence is too large,\ 20 | \n offset: {offset}, seq_len: {x.shape[1]}, max: {cos.shape[1]}" 21 | 22 | # Handle a possible sequence length mismatch in between q and k 23 | cos_out = cos[:, offset : offset + x.shape[1], :, :] 24 | sin_out = sin[:, offset : offset + x.shape[1], :, :] 25 | 26 | return (x * cos_out) + (rotate_half(x) * sin_out) 27 | 28 | 29 | class RotaryEmbedding(torch.nn.Module): 30 | """ 31 | The rotary position embeddings from RoFormer_ (Su et. al). 32 | A crucial insight from the method is that the query and keys are 33 | transformed by rotation matrices which depend on the relative positions. 34 | 35 | Other implementations are available in the Rotary Transformer repo_ and in 36 | GPT-NeoX_, GPT-NeoX was an inspiration 37 | 38 | .. _RoFormer: https://arxiv.org/abs/2104.09864 39 | .. _repo: https://github.com/ZhuiyiTechnology/roformer 40 | .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox 41 | 42 | 43 | .. warning: Please note that this embedding is not registered on purpose, as it is transformative 44 | (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis 45 | """ 46 | 47 | def __init__(self, dim_model: int, seq_len: int, frequency: int = 10000, scale: float = 1.0, *_, **__): 48 | super().__init__() 49 | # Generate and save the inverse frequency buffer (non trainable) 50 | self.dim_model = dim_model 51 | self.inv_freq = torch.zeros(self.dim_model // 2) 52 | 53 | self._cos_cached = None 54 | self._sin_cached = None 55 | self._seq_len_cached = 0 56 | self.seq_len = seq_len 57 | self.frequency = frequency 58 | self.scale = scale 59 | self.reset_parameters() 60 | 61 | def reset_parameters(self): 62 | self.inv_freq = 1.0 / ( 63 | self.scale * self.frequency ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model) 64 | ) 65 | self._update_cos_sin_tables(self.seq_len) 66 | 67 | def _update_cos_sin_tables(self, seq_len: int = None, device: torch.device = None, dtype: torch.dtype = None): 68 | # If no seq_len is provided, use the cached one 69 | # If the seq_len is smaller than the cached one it is included in the cached one so no need to update 70 | if seq_len is None or seq_len < self._seq_len_cached: 71 | seq_len = self._seq_len_cached 72 | 73 | # Reset the tables if the sequence length has increased, 74 | # or if we're on a new device (possibly due to tracing for instance) 75 | if seq_len > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype: 76 | self._seq_len_cached = seq_len 77 | t = torch.arange(seq_len, device=device, dtype=torch.float32) 78 | freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(dtype).to(device)) 79 | emb = torch.cat((freqs, freqs), dim=-1).to(device) 80 | 81 | self._cos_cached = emb.cos()[None, :, None, :].to(dtype) 82 | self._sin_cached = emb.sin()[None, :, None, :].to(dtype) 83 | 84 | def forward(self, q: torch.Tensor, k: torch.Tensor, offset=0) -> 
Tuple[torch.Tensor, torch.Tensor]: 85 | if isinstance(offset, torch.Tensor): 86 | offset_max = offset.max().item() 87 | else: 88 | offset_max = offset 89 | self._update_cos_sin_tables(k.shape[1] + offset_max, device=k.device, dtype=k.dtype) 90 | return ( 91 | apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, offset), 92 | apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, offset), 93 | ) 94 | 95 | 96 | class RotaryWithCast(RotaryEmbedding): 97 | def forward(self, q, k, v, offset=0): 98 | q, k = super().forward(q, k, offset) 99 | return q.to(v.dtype), k.to(v.dtype), v 100 | -------------------------------------------------------------------------------- /tests/test_tokenize_shuffle.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pytest 4 | import webdataset as wds 5 | 6 | 7 | @pytest.fixture(autouse=True) 8 | def run_around_tests(): 9 | yield 10 | os.system("rm -rf test_output/") 11 | os.system("rm -rf test_input/") 12 | os.system("aws s3 rm --recursive s3://dcnlp-west-test/tokenize_shuffle_test_output/simple/") 13 | 14 | 15 | def test_tokenize_shuffle_simple(): 16 | content_len = 2048 17 | NUM_TOKENS = 86058 18 | exit_value = os.system( 19 | f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input s3://dcnlp-west-test/tokenize_shuffle_test/C4_V3_tiny/ --content_key content --output test_output/ --seqlen {content_len}" 20 | ) 21 | assert exit_value == 0 22 | ds = wds.WebDataset("test_output/00000001.tar").decode() 23 | total = 0 24 | for x in ds: 25 | assert len(x["json.gz"]) == content_len + 1 26 | total += len(x["json.gz"]) 27 | # assert total == NUM_TOKENS 28 | 29 | with open("test_output/manifest.jsonl", "rb") as f: 30 | out = f.read() 31 | out = [json.loads(o) for o in out.decode("utf-8").split("\n")[:-1]] 32 | 33 | # assert out[0]["shard"] == "00000001" 34 | # assert out[0]["num_sequences"] == NUM_TOKENS // (content_len + 1) 35 | 36 | 37 | @pytest.mark.parametrize("content_key,NUM_TOKENS", [("npy", 4860228), ("txt", 24588), ("json:duration", 8196)]) 38 | def test_tokenize_shuffle_tar(content_key, NUM_TOKENS): 39 | content_len = 2048 40 | 41 | params = f"--content_key {content_key}" 42 | if content_key == "npy": 43 | params += " --vocab_size 16384" 44 | 45 | exit_value = os.system( 46 | f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input s3://dcnlp-west-test/tokenize_shuffle_test/webvid_tiny/ {params} --output test_output/ --seqlen {content_len}" 47 | ) 48 | assert exit_value == 0 49 | ds = wds.WebDataset("test_output/00000001.tar").decode() 50 | total = 0 51 | for x in ds: 52 | assert len(x["json.gz"]) == content_len + 1 53 | total += len(x["json.gz"]) 54 | assert total == NUM_TOKENS 55 | 56 | 57 | def test_tokenize_shuffle_simple_do_sample(): 58 | content_len = 2048 59 | NUM_TOKENS = 32784 60 | exit_value = os.system( 61 | f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input s3://dcnlp-west-test/tokenize_shuffle_test/C4_V3_tiny/ --content_key content --output test_output/ --seqlen {content_len} --do_sample" 62 | ) 63 | assert exit_value == 0 64 | ds = wds.WebDataset("test_output/00000001.tar").decode() 65 | total = 0 66 | for x in ds: 67 | assert len(x["json.gz"]) == content_len + 1 68 | total += len(x["json.gz"]) 69 | assert total == NUM_TOKENS 70 | 71 | 72 | @pytest.mark.s3 73 | def test_tokenize_shuffle_s3_write(): 74 | content_len = 2048 75 | NUM_TOKENS = 86058 76 | exit_value = os.system( 77 | f"python open_lm/datapreprocess/ray/tokenize_shuffle.py 
--input s3://dcnlp-west-test/tokenize_shuffle_test/C4_V3_tiny/ --content_key content --seqlen {content_len} --output s3://dcnlp-west-test/tokenize_shuffle_test_output/simple/" 78 | ) 79 | os.system("aws s3 sync s3://dcnlp-west-test/tokenize_shuffle_test_output/simple/ test_output/") 80 | ds = wds.WebDataset("test_output/00000001.tar").decode() 81 | total = 0 82 | for x in ds: 83 | assert len(x["json.gz"]) == content_len + 1 84 | total += len(x["json.gz"]) 85 | assert total == NUM_TOKENS 86 | assert exit_value == 0 87 | 88 | with open("test_output/manifest.jsonl", "rb") as f: 89 | out = f.read() 90 | out = [json.loads(o) for o in out.decode("utf-8").split("\n")[:-1]] 91 | 92 | assert out[0]["shard"] == "00000001" 93 | assert out[0]["num_sequences"] == NUM_TOKENS // (content_len + 1) 94 | 95 | 96 | def test_tokenize_shuffle_local_read_local_write(): 97 | content_len = 2048 98 | NUM_TOKENS = 24508089 99 | # download a small test json file and store at ./test_input 100 | os.system("mkdir test_input") 101 | os.system("mkdir test_output") 102 | os.system( 103 | "wget -O ./test_input/wikipedia_sample.jsonl https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample/resolve/main/wikipedia_sample.jsonl" 104 | ) 105 | # run tokenize script 106 | exit_value = os.system( 107 | f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input ./test_input --content_key text --seqlen {content_len} --output ./test_output/" 108 | ) 109 | tars = [os.path.join("test_output", fname) for fname in os.listdir("test_output") if fname.endswith(".tar")] 110 | total = 0 111 | for tar in tars: 112 | ds = wds.WebDataset(tar).decode() 113 | for x in ds: 114 | assert len(x["json.gz"]) == content_len + 1 115 | total += len(x["json.gz"]) 116 | assert total == NUM_TOKENS 117 | assert exit_value == 0 118 | -------------------------------------------------------------------------------- /eval/eval_openlm_ckpt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import builtins as __builtin__ 3 | import time 4 | from typing import List 5 | 6 | import torch 7 | from composer.loggers import InMemoryLogger, LoggerDestination 8 | from composer.trainer import Trainer 9 | from composer.utils import dist, get_device, reproducibility 10 | 11 | try: 12 | from llmfoundry.utils.builders import build_icl_evaluators, build_logger 13 | except ImportError: 14 | import logging 15 | 16 | logging.warning("llmfoundry not installed. 
Please install llmfoundry `pip install llm-foundry` to run this script.") 17 | 18 | from omegaconf import OmegaConf as om 19 | from transformers import GPTNeoXTokenizerFast, LlamaTokenizerFast 20 | 21 | from open_lm.model import create_params 22 | from open_lm.params import add_model_args 23 | from open_lm.utils.llm_foundry_wrapper import SimpleComposerOpenLMCausalLM 24 | from open_lm.utils.transformers.hf_config import OpenLMConfig 25 | from open_lm.utils.transformers.hf_model import OpenLMforCausalLM 26 | 27 | builtin_print = __builtin__.print 28 | 29 | 30 | def setup_for_distributed(is_master): 31 | def print(*args, **kwargs): 32 | force = kwargs.pop("force", False) 33 | if is_master or force: 34 | builtin_print(*args, **kwargs) 35 | 36 | __builtin__.print = print 37 | 38 | 39 | @torch.no_grad() 40 | def evaluate(model, tokenizer, cfg): 41 | cfg.dist_timeout = cfg.get("dist_timeout", 600.0) 42 | 43 | reproducibility.seed_all(cfg.seed) 44 | dist.initialize_dist(get_device(None), timeout=cfg.dist_timeout) 45 | setup_for_distributed(dist.get_global_rank() == 0) 46 | 47 | composer_model = SimpleComposerOpenLMCausalLM(model, tokenizer) 48 | 49 | evaluators, logger_keys = build_icl_evaluators( 50 | cfg.icl_tasks, tokenizer, cfg.max_seq_len, cfg.device_eval_batch_size 51 | ) 52 | 53 | in_memory_logger = InMemoryLogger() # track metrics in the in_memory_logger 54 | loggers: List[LoggerDestination] = [ 55 | build_logger(name, logger_cfg) for name, logger_cfg in (cfg.get("loggers") or {}).items() 56 | ] 57 | loggers.append(in_memory_logger) 58 | 59 | fsdp_config = cfg.get("fsdp_config", None) 60 | fsdp_config = om.to_container(fsdp_config, resolve=True) if fsdp_config is not None else None 61 | 62 | load_path = cfg.get("load_path", None) 63 | 64 | trainer = Trainer( 65 | model=composer_model, 66 | loggers=loggers, 67 | precision=cfg.precision, 68 | fsdp_config=fsdp_config, # type: ignore 69 | load_path=load_path, 70 | load_weights_only=True, 71 | progress_bar=False, 72 | log_to_console=True, 73 | dist_timeout=cfg.dist_timeout, 74 | ) 75 | 76 | if torch.cuda.is_available(): 77 | torch.cuda.synchronize() 78 | a = time.time() 79 | trainer.eval(eval_dataloader=evaluators) 80 | if torch.cuda.is_available(): 81 | torch.cuda.synchronize() 82 | b = time.time() 83 | 84 | print(f"Ran eval in: {b-a} seconds") 85 | 86 | for key in logger_keys: 87 | if key in in_memory_logger.data: 88 | result = in_memory_logger.data[key][0][1].item() 89 | print(f"{key}: {result}") 90 | 91 | 92 | def main(): 93 | """ 94 | Usage: 95 | python eval_openlm_ckpt.py --checkpoint --model --eval-yaml --tokenizer 96 | example: 97 | cd eval 98 | python eval_openlm_ckpt.py --checkpoint ../checkpoints/llama2_7b.pt --model llama2_7b.json --eval-yaml in_memory_hf_eval.yaml --tokenizer 99 | multi-gpu example: 100 | cd eval 101 | torchrun --nproc_per_node 3 python eval_openlm_ckpt.py --checkpoint ../checkpoints/llama2_7b.pt --model llama2_7b.json --eval-yaml in_memory_hf_eval.yaml --tokenizer 102 | """ 103 | parser = argparse.ArgumentParser() 104 | parser.add_argument("--checkpoint") 105 | parser.add_argument("--model", type=str, default="m1b_neox", help="Name of the model to use.") 106 | parser.add_argument("--eval-yaml") 107 | parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b") 108 | add_model_args(parser) 109 | args = parser.parse_args() 110 | 111 | with open(args.eval_yaml) as f: 112 | eval_cfg = om.load(f) 113 | 114 | print("Loading checkpoint from disk") 115 | checkpoint = torch.load(args.checkpoint) 116 | 
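# The training checkpoint is a dict whose model weights live under the "state_dict" key; keys may
# carry a "module." prefix when saved from a wrapped (e.g. DDP/FSDP) model, which is stripped below
# before loading into the HF-wrapped OpenLM model.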
117 | print("Loading model into the right classes") 118 | open_lm = OpenLMforCausalLM(OpenLMConfig(create_params(args))) 119 | if "gpt-neox-20b" in args.tokenizer: 120 | tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b") 121 | elif "llama" in args.tokenizer: 122 | tokenizer = LlamaTokenizerFast.from_pretrained(args.tokenizer) 123 | 124 | state_dict = checkpoint["state_dict"] 125 | state_dict = {x.replace("module.", ""): y for x, y in state_dict.items()} 126 | open_lm.model.load_state_dict(state_dict) 127 | open_lm.model.eval() 128 | 129 | evaluate(open_lm, tokenizer, eval_cfg) 130 | 131 | 132 | if __name__ == "__main__": 133 | main() 134 | -------------------------------------------------------------------------------- /open_lm/utils/averaging_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import logging 4 | from copy import deepcopy 5 | 6 | 7 | def unwrap_model(model): 8 | return model.module if hasattr(model, "module") else model 9 | 10 | 11 | class ModelAverager(object): 12 | def __init__(self, model, methods: str): 13 | self.model = model 14 | self.avgs_dict = {} 15 | for method in methods.split(","): 16 | args = method.split("_") 17 | # method_name = args[0][:-1] if args[0].endswith('_') else args[0] 18 | # freq = int(args[1]) if len(args) > 1 else 1 19 | self.avgs_dict[method] = Averager(model, args) 20 | 21 | def step(self): 22 | for avg in self.avgs_dict.values(): 23 | avg.step() 24 | 25 | 26 | class Averager(object): 27 | def __init__(self, model, args): 28 | self.model = model 29 | self.method = args[0] 30 | self.update_counter = 1 31 | self.step_counter = 1 32 | self.freq = 1 if (len(args) <= 2) else int(args[2]) 33 | if self.method == "none": 34 | self.av_model = model 35 | return 36 | else: 37 | self.av_model = deepcopy(unwrap_model(model)) 38 | 39 | if self.method == "poly": 40 | self.eta = 0.0 if len(args) <= 1 else float(args[1]) 41 | elif self.method == "ema": 42 | self.gamma = 0.99 if len(args) <= 1 else float(args[1]) 43 | 44 | elif self.method == "cosine": 45 | pass 46 | else: 47 | print(f"Unknown averaging method {self.method}") 48 | 49 | def step(self): 50 | if self.update_counter != self.freq: 51 | pass 52 | else: 53 | self.update() 54 | self.update_counter += 1 55 | if self.update_counter > self.freq: 56 | self.update_counter = 1 57 | return 58 | 59 | def update(self): 60 | method = self.method 61 | if method == "none": 62 | return 63 | t = self.step_counter 64 | # model_sd is the current model state dict 65 | # av_sd is the averaged model state dict 66 | model_sd = self.model.state_dict() 67 | av_sd = self.av_model.state_dict() 68 | if self.method == "cosine" or self.method == "degree": 69 | pass 70 | first_k_av_sd = list(av_sd.keys())[0] 71 | for k in model_sd.keys(): 72 | av_sd_k = k 73 | if k.startswith("module") and not first_k_av_sd.startswith("module"): 74 | av_sd_k = k[len("module.") :] 75 | if isinstance(av_sd[av_sd_k], (torch.LongTensor, torch.cuda.LongTensor)): 76 | # these are buffers that store how many batches batch norm has seen so far 77 | av_sd[av_sd_k].copy_(model_sd[k]) 78 | continue 79 | if method == "poly": 80 | # the update rule is: new_average = (1 - (eta + 1) / (eta + t)) * old_average + (eta + 1) / (eta + t) * current_model 81 | # which is eq(10) in https://arxiv.org/pdf/1212.1824.pdf 82 | av_sd[av_sd_k].mul_(1 - ((self.eta + 1) / (self.eta + t))).add_( 83 | model_sd[k], alpha=(self.eta + 1) / (self.eta + t) 84 | ) 85 | if method 
== "ema": 86 | # the update rule is: new_average = (1 - gamma) * old_average + gamma * current_model 87 | av_sd[av_sd_k].mul_(self.gamma).add_(model_sd[k], alpha=1 - self.gamma) 88 | self.step_counter += 1 89 | 90 | def reset(self): 91 | self.step_counter = 2 92 | 93 | @property 94 | def averaged_model(self): 95 | return self.av_model 96 | 97 | def get_state_dict_avg(self): 98 | state_dict = { 99 | "update_counter": self.update_counter, 100 | "step_counter": self.step_counter, 101 | "freq": self.freq, 102 | "av_model_sd": unwrap_model(self.av_model).state_dict(), 103 | "method": self.method, 104 | "eta": self.eta if hasattr(self, "eta") else None, 105 | "gamma": self.gamma if hasattr(self, "gamma") else None, 106 | "suffix_steps": self.suffix_steps if hasattr(self, "suffix_steps") else None, 107 | "power": self.power if hasattr(self, "power") else None, 108 | "start": self.start if hasattr(self, "start") else None, 109 | } 110 | return state_dict 111 | 112 | def load_state_dict_avg(self, state_dict): 113 | self.update_counter = state_dict["update_counter"] 114 | self.step_counter = state_dict["step_counter"] 115 | self.freq = state_dict["freq"] 116 | self.method = state_dict["method"] 117 | self.av_model.load_state_dict(state_dict["av_model_sd"]) 118 | if hasattr(self, "eta"): 119 | self.eta = state_dict["eta"] 120 | if hasattr(self, "gamma"): 121 | self.gamma = state_dict["gamma"] 122 | if hasattr(self, "suffix_steps"): 123 | self.suffix_steps = state_dict["suffix_steps"] 124 | if hasattr(self, "power"): 125 | self.power = state_dict["power"] 126 | if hasattr(self, "start"): 127 | self.start = state_dict["start"] 128 | -------------------------------------------------------------------------------- /tests/shared.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim 3 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 4 | 5 | from open_lm.data import get_data 6 | from open_lm.distributed import init_distributed_device 7 | from open_lm.main import random_seed 8 | from open_lm.model import create_model 9 | from open_lm.params import parse_args 10 | from open_lm.scheduler import cosine_lr 11 | from tests.utils import download_val_data 12 | 13 | 14 | class MockTrainArgs: 15 | def __init__(self, model, **kwargs): 16 | data_path = download_val_data("shard_00000000.tar", "./tests/assets/") 17 | 18 | # fmt: off 19 | args = parse_args([ 20 | "--model", model, 21 | "--model-norm", "gain_only_layer_norm", 22 | "--train-data", data_path, 23 | "--precision", "fp32", 24 | "--wd", "0.033", 25 | "--lr", "3e-3", 26 | "--warmup", "2", 27 | "--global-batch-size", "8", 28 | "--accum", "1", 29 | "--name", "test_model_name", 30 | "--logs", "./tests/assets/", 31 | "--workers", "1", 32 | "--data-key", "json", 33 | "--seed", "1", 34 | ]) 35 | # fmt: off 36 | for k, v in vars(args).items(): 37 | setattr(self, k, v) 38 | 39 | self.device = "cuda:0" if torch.cuda.is_available() else "cpu" 40 | self.rank = 0 41 | self.local_rank = 0 42 | self.world_size = 1 43 | self.vocab_size = 50432 44 | self.seq_len = 300 45 | self.wandb = False 46 | self.fsdp = False 47 | self.fsdp_amp = False 48 | self.positional_embedding_type = "rotary" 49 | self.dist_backend = "nccl" 50 | self.dist_url = "env://" 51 | self.dataset_manifest = None 52 | self.target_mask_left = None 53 | self.target_mask_individual = None 54 | self.ignore_parse_errors = False 55 | self.moe_num_experts = None 56 | self.moe_freq = 0 57 | self.moe_weight_parallelism = 
False 58 | self.moe_expert_model_parallelism = False 59 | self.moe_capacity_factor = 1.25 60 | self.moe_loss_weight = 0.1 61 | self.moe_top_k = 2 62 | self.distributed = False 63 | self.per_gpu_batch_size = self.global_batch_size // self.world_size 64 | 65 | for k, v in kwargs.items(): 66 | setattr(self, k, v) 67 | 68 | # Recalculate batch size if overwritten. 69 | if "global_batch_size" in kwargs: 70 | self.per_gpu_batch_size = self.global_batch_size // self.world_size 71 | 72 | 73 | class MockDataArgs(object): 74 | def __init__(self): 75 | data_path = download_val_data("shard_00000000.tar", "./tests/assets/") 76 | 77 | self.train_data = [ 78 | data_path, 79 | ] 80 | self.dataset_resampled = True 81 | self.train_data_mix_weights = None 82 | self.val_num_samples = 0 83 | self.train_data_upsampling_factors = None 84 | self.train_num_samples = 512 85 | self.disable_buffer = True 86 | self.seq_len = 300 87 | self.vocab_size = 50432 88 | self.global_batch_size = 64 89 | self.world_size = 1 90 | self.rank = 0 91 | self.workers = 2 92 | self.seed = 42 93 | self.dataset_manifest = None 94 | self.target_mask_left = None 95 | self.target_mask_individual = None 96 | self.ignore_parse_errors = False 97 | self.per_gpu_batch_size = self.global_batch_size // self.world_size 98 | 99 | 100 | def create_train_fixtures(model="open_lm_11m", fsdp=False, **kwargs): 101 | # Setup data, optimizer, and other basic settings 102 | args = MockTrainArgs(model, **kwargs) 103 | args.fsdp = fsdp 104 | 105 | # only want to look at one batch 106 | args.train_num_samples = args.global_batch_size 107 | 108 | # increase learning rate and remove warmup for maximize change to model weights 109 | args.lr = 1e-3 110 | args.warmup = 0 111 | 112 | # create base models 113 | random_seed() 114 | if fsdp: 115 | model = create_model(args) 116 | model = FSDP(model) 117 | else: 118 | model = create_model(args) 119 | model.reset_parameters() 120 | model = model.to(args.device) 121 | 122 | # create dataloader 123 | data = get_data( 124 | args, 125 | epoch=0, 126 | tokenizer=None, 127 | skip_train=False, 128 | ) 129 | 130 | # create optimizer 131 | named_parameters = list(model.named_parameters()) 132 | params = [p for _, p in named_parameters if p.requires_grad] 133 | optimizer = optim.AdamW( 134 | [ 135 | {"params": params, "weight_decay": args.wd}, 136 | ], 137 | lr=args.lr, 138 | betas=(args.beta1, args.beta2), 139 | eps=args.eps, 140 | ) 141 | 142 | # create scheduler 143 | scheduler = cosine_lr( 144 | optimizer, 145 | args.lr, 146 | args.warmup, 147 | 10, 148 | args.lr_cooldown_end, 149 | args.force_min_lr, 150 | ) 151 | 152 | # create loss 153 | loss = torch.nn.CrossEntropyLoss() 154 | 155 | return args, model, data, optimizer, scheduler, loss 156 | -------------------------------------------------------------------------------- /scripts/generate_without_hf.py: -------------------------------------------------------------------------------- 1 | """Script to generate text from a trained model, without HuggingFace wrappers. 2 | 3 | This script is useful for simple generation, and to debug any issues with HuggingFace integration. 4 | The output of this script should match that of generate.py when `--temperature 0` is passed. 5 | """ 6 | 7 | # Thanks to Gabriel for this code. 
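# Example invocation (hypothetical paths, for illustration only; the flags match the argparse
# definitions below):
#   python scripts/generate_without_hf.py \
#       --checkpoint /path/to/checkpoints/epoch_1.pt \
#       --params /path/to/params.txt \
#       --input-text "The quick brown fox" \
#       --max-gen-len 100 --temperature 0.0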
8 | import argparse 9 | import os 10 | import glob 11 | import yaml 12 | from dataclasses import dataclass 13 | from typing import List 14 | from yaml import Loader 15 | 16 | import torch 17 | from transformers import GPTNeoXTokenizerFast 18 | 19 | from open_lm.model import Transformer, create_model 20 | 21 | 22 | @dataclass 23 | class GenerationArgs: 24 | max_gen_len: int = 200 25 | temperature: float = 0.8 26 | top_p: float = 0.95 27 | 28 | 29 | class Generator: 30 | def __init__(self, model: Transformer): 31 | self.model = model 32 | self.tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b") 33 | self.pad_token_id = 50282 34 | self.seq_len = 2048 35 | 36 | @torch.inference_mode() 37 | def generate( 38 | self, 39 | prompts: List[str], 40 | gen_args: GenerationArgs = GenerationArgs(), 41 | ) -> List[str]: 42 | bsz = len(prompts) 43 | 44 | prompt_tokens = [self.tokenizer.encode(x) for x in prompts] 45 | 46 | min_prompt_size = min([len(t) for t in prompt_tokens]) 47 | max_prompt_size = max([len(t) for t in prompt_tokens]) 48 | 49 | total_len = min(self.seq_len, gen_args.max_gen_len + max_prompt_size) 50 | 51 | tokens = torch.full((bsz, total_len), self.pad_token_id).cuda().long() 52 | for k, t in enumerate(prompt_tokens): 53 | tokens[k, : len(t)] = torch.tensor(t).long() 54 | input_text_mask = tokens != self.pad_token_id 55 | start_pos = min_prompt_size 56 | prev_pos = 0 57 | for cur_pos in range(start_pos, total_len): 58 | last_logits = self.model(tokens[:, prev_pos:cur_pos].clone())[0][:, -1, :] 59 | if gen_args.temperature > 0: 60 | probs = torch.softmax(last_logits / gen_args.temperature, dim=-1) 61 | next_token = sample_top_p(probs, gen_args.top_p) 62 | else: 63 | next_token = torch.argmax(last_logits, dim=-1) 64 | next_token = next_token.reshape(-1) 65 | # only replace token if prompt has already been generated 66 | next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token) 67 | tokens[:, cur_pos] = next_token 68 | 69 | # TODO: enable caching again for inference 70 | # prev_pos = cur_pos 71 | 72 | decoded = [] 73 | for i, t in enumerate(tokens.tolist()): 74 | t = t[: len(prompt_tokens[i]) + gen_args.max_gen_len] 75 | decoded_i = self.tokenizer.decode(t) 76 | 77 | decoded = [] 78 | for t in decoded_i: 79 | decoded.append(t) 80 | 81 | return decoded 82 | 83 | 84 | def sample_top_p(probs, p): 85 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 86 | probs_sum = torch.cumsum(probs_sort, dim=-1) 87 | mask = probs_sum - probs_sort > p 88 | probs_sort[mask] = 0.0 89 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 90 | next_token = torch.multinomial(probs_sort, num_samples=1) 91 | next_token = torch.gather(probs_idx, -1, next_token) 92 | return next_token 93 | 94 | 95 | class ModelArgs: 96 | def __init__(self, path: str): 97 | with open(path, "r") as f: 98 | params = yaml.load(f, Loader=Loader) 99 | for k, v in params.items(): 100 | setattr(self, k, v) 101 | 102 | 103 | def main(): 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument("--checkpoint", default="") 106 | # TODO: Make this take as input --model-config, similar to generate.py 107 | parser.add_argument("--params", default="") 108 | parser.add_argument("--wandb-dir", default="") 109 | parser.add_argument("--input-text", required=True) 110 | parser.add_argument("--max-gen-len", default=200, type=int) 111 | parser.add_argument("--temperature", default=0.8, type=float) 112 | parser.add_argument("--top-p", default=0.95, type=float) 113 | 114 | args 
= parser.parse_args() 115 | 116 | if args.wandb_dir != "": 117 | if args.params == "": 118 | args.params = os.path.join(args.wandb_dir, "params.txt") 119 | if args.checkpoint == "": 120 | chkpt_dir = os.path.join(args.wandb_dir, "checkpoints", "epoch_*.pt") 121 | list_of_files = glob.glob(chkpt_dir) 122 | latest_file = max(list_of_files, key=os.path.getctime) 123 | args.checkpoint = latest_file 124 | else: 125 | assert args.params != "", "Must provide params file or a wandb directory." 126 | assert args.checkpoint != "", "Must provide checkpoint file or a wandb directory." 127 | 128 | checkpoint = torch.load(args.checkpoint) 129 | open_lm = create_model(ModelArgs(args.params)).half() 130 | 131 | state_dict = checkpoint["state_dict"] 132 | state_dict = {x.replace("module.", ""): y for x, y in state_dict.items()} 133 | open_lm.load_state_dict(state_dict) 134 | open_lm.eval().cuda() 135 | generator = Generator(open_lm) 136 | input_text = [ 137 | args.input_text, 138 | ] 139 | output = generator.generate( 140 | input_text, 141 | GenerationArgs(args.max_gen_len, args.temperature, args.top_p), 142 | ) 143 | print("".join(output)) 144 | 145 | 146 | if __name__ == "__main__": 147 | main() 148 | -------------------------------------------------------------------------------- /open_lm/norms.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | from functools import partial 3 | from typing import Union, List 4 | 5 | import torch 6 | from torch import Tensor, Size 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.nn.parameter import Parameter 10 | 11 | 12 | class LayerNorm(nn.Module): 13 | # NOTE: taken from official pytorch implementation and modified 14 | # to allow revoval of gain and bias independently 15 | 16 | def __init__( 17 | self, 18 | normalized_shape: Union[int, List[int], Size], 19 | eps: float = 0.00001, 20 | elementwise_gain: bool = True, 21 | elementwise_bias: bool = True, 22 | device=None, 23 | dtype=None, 24 | ) -> None: 25 | factory_kwargs = {"device": device, "dtype": dtype} 26 | super().__init__() 27 | 28 | if isinstance(normalized_shape, numbers.Integral): 29 | # mypy error: incompatible types in assignment 30 | normalized_shape = (normalized_shape,) # type: ignore[assignment] 31 | self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] 32 | self.eps = eps 33 | self.elementwise_gain = elementwise_gain 34 | self.elementwise_bias = elementwise_bias 35 | 36 | if self.elementwise_gain: 37 | self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) 38 | else: 39 | self.register_parameter("weight", None) 40 | 41 | if self.elementwise_bias: 42 | self.bias = Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) 43 | else: 44 | self.register_parameter("bias", None) 45 | 46 | self.reset_parameters() 47 | 48 | def reset_parameters(self) -> None: 49 | if self.elementwise_gain: 50 | with torch.no_grad(): 51 | self.weight.fill_(1.0) 52 | 53 | if self.elementwise_bias: 54 | with torch.no_grad(): 55 | self.bias.zero_() 56 | 57 | def forward(self, input: Tensor) -> Tensor: 58 | return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) 59 | 60 | def extra_repr(self) -> str: 61 | return ( 62 | "{normalized_shape}, eps={eps}, " 63 | "elementwise_gain={elementwise_gain}, " 64 | "elementwise_bias={elementwise_bias}".format(**self.__dict__) 65 | ) 66 | 67 | 68 | class LPLayerNorm(LayerNorm): 69 | """From MosaicML composer. 
70 | 71 | See: https://github.com/mosaicml/composer/blob/6acca4c70425455be7280a5459dbf02e1ac5591d/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py#L63 72 | """ 73 | 74 | def forward(self, x): 75 | module_device = x.device 76 | downcast_x = _cast_if_autocast_enabled(x) 77 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 78 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 79 | with torch.autocast(enabled=False, device_type=module_device.type): 80 | return F.layer_norm( 81 | downcast_x, 82 | self.normalized_shape, 83 | downcast_weight, 84 | downcast_bias, 85 | self.eps, 86 | ) 87 | 88 | 89 | def _cast_if_autocast_enabled(tensor): 90 | if torch.is_autocast_enabled(): 91 | if tensor.device.type == "cuda": 92 | dtype = torch.get_autocast_gpu_dtype() 93 | elif tensor.device.type == "cpu": 94 | dtype = torch.get_autocast_cpu_dtype() 95 | else: 96 | raise NotImplementedError() 97 | return tensor.to(dtype=dtype) 98 | return tensor 99 | 100 | 101 | class RmsNorm(nn.Module): 102 | def __init__( 103 | self, 104 | normalized_shape: Union[int, List[int], Size], 105 | eps: float = 1e-6, 106 | device=None, 107 | dtype=None, 108 | ) -> None: 109 | factory_kwargs = {"device": device, "dtype": dtype} 110 | super().__init__() 111 | 112 | if isinstance(normalized_shape, numbers.Integral): 113 | # mypy error: incompatible types in assignment 114 | normalized_shape = (normalized_shape,) # type: ignore[assignment] 115 | self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] 116 | self.eps = eps 117 | self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) 118 | self.reset_parameters() 119 | 120 | def _norm(self, x): 121 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 122 | 123 | def forward(self, x): 124 | output = self._norm(x.float()).type_as(x) 125 | 126 | return output * self.weight 127 | 128 | def reset_parameters(self) -> None: 129 | with torch.no_grad(): 130 | self.weight.fill_(1.0) 131 | 132 | def extra_repr(self) -> str: 133 | return "{normalized_shape}, eps={eps} ".format(**self.__dict__) 134 | 135 | 136 | def get_norm_class(model_norm): 137 | if model_norm == "default_layer_norm": 138 | return torch.nn.LayerNorm 139 | elif model_norm == "lp_layer_norm": 140 | return LPLayerNorm 141 | elif model_norm == "gain_only_lp_layer_norm": 142 | return partial(LPLayerNorm, elementwise_gain=True, elementwise_bias=False) 143 | elif model_norm == "gain_only_layer_norm": 144 | return partial(LayerNorm, elementwise_gain=True, elementwise_bias=False) 145 | 146 | elif model_norm == "no_wb_layer_norm": 147 | return partial(LayerNorm, elementwise_gain=False, elementwise_bias=False) 148 | 149 | elif model_norm == "rms_norm": 150 | return RmsNorm 151 | 152 | else: 153 | raise ValueError(f"Unsupported model-norm: {model_norm}") 154 | -------------------------------------------------------------------------------- /tests/test_dataset_deterministic.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import argparse 3 | import random 4 | import os 5 | import webdataset as wds 6 | import glob 7 | from open_lm.model import _MODEL_CONFIGS 8 | from open_lm.main import random_seed 9 | from open_lm.data import get_wds_dataset 10 | from open_lm.file_utils import ( 11 | get_string_for_epoch, 12 | get_metadata_file, 13 | get_shards_for_chunk, 14 | ) 15 | from open_lm.params import parse_args 16 | from 
pathlib import Path 17 | from tests.utils import download_dl_test_data 18 | from time import sleep 19 | 20 | NUM_SAMPLES = 1000 21 | NUM_SAMPLES_TO_CHECK = 5 22 | 23 | # Update this to two data sources with webdataset, each with their own manifest. 24 | INPUT_PATHS = [ 25 | "tests/assets/source_1/manifest.jsonl", 26 | "tests/assets/source_2/manifest.jsonl", 27 | ] 28 | 29 | 30 | def retrieve_dataset(epoch, next_shard, weights, seed, disable_buffer, min_shards_needed=2): 31 | args = parse_args("") 32 | 33 | train_data_string_per_source, num_samples_per_source, _ = get_string_for_epoch( 34 | NUM_SAMPLES, [next_shard, next_shard], INPUT_PATHS, weights, min_shards_needed, world_size=1 35 | ) 36 | args.train_num_samples = NUM_SAMPLES 37 | args.train_data = train_data_string_per_source 38 | args.workers = 2 39 | args.global_batch_size = 2 40 | args.seed = seed 41 | args.dataset_resampled = False 42 | args.disable_buffer = disable_buffer 43 | args.vocab_size = _MODEL_CONFIGS[args.model]["vocab_size"] 44 | args.seq_len = _MODEL_CONFIGS[args.model]["seq_len"] 45 | args.world_size = 1 46 | args.rank = 0 47 | args.per_gpu_batch_size = 2 48 | data = get_wds_dataset(args, is_train=True, epoch=epoch, force_num_samples=num_samples_per_source) 49 | dl = data.dataloader 50 | 51 | return dl 52 | 53 | 54 | def retrieve_dataset_resampled(epoch, next_shard, weights, seed, min_shards_needed=2): 55 | args = parse_args("") 56 | train_data_string_per_source, _, _ = get_string_for_epoch( 57 | NUM_SAMPLES, [next_shard, next_shard], INPUT_PATHS, weights, min_shards_needed, world_size=1 58 | ) 59 | args.train_num_samples = NUM_SAMPLES 60 | args.train_data = train_data_string_per_source 61 | args.num_workers = 2 62 | args.global_batch_size = 2 63 | args.seed = seed 64 | args.dataset_resampled = True 65 | args.vocab_size = _MODEL_CONFIGS[args.model]["vocab_size"] 66 | args.seq_len = _MODEL_CONFIGS[args.model]["seq_len"] 67 | args.world_size = 1 68 | args.rank = 0 69 | args.per_gpu_batch_size = 2 70 | data = get_wds_dataset(args, is_train=True, epoch=epoch) 71 | dl = data.dataloader 72 | 73 | return dl 74 | 75 | 76 | @pytest.mark.parametrize("next_shard", [0, 2]) 77 | @pytest.mark.parametrize("weights", [[0.5, 0.5], [0.9, 0.1]]) 78 | @pytest.mark.parametrize("seed", [0, 17]) 79 | def test_deterministic_no_buffer(next_shard, weights, seed): 80 | download_dl_test_data("tests/assets") 81 | disable_buffer = True 82 | random_seed(seed) 83 | dl1 = retrieve_dataset(0, next_shard, weights, seed, disable_buffer) 84 | dl2 = retrieve_dataset(0, next_shard, weights, seed, disable_buffer) 85 | 86 | iter1 = iter(dl1) 87 | iter2 = iter(dl2) 88 | 89 | for _ in range(NUM_SAMPLES_TO_CHECK): 90 | item1 = next(iter1) 91 | item2 = next(iter2) 92 | assert item1 == item2 93 | sleep(0.001) 94 | 95 | 96 | @pytest.mark.parametrize("next_shard", [0, 2]) 97 | @pytest.mark.parametrize("weights", [[0.5, 0.5], [0.9, 0.1]]) 98 | @pytest.mark.parametrize("seed", [0, 17]) 99 | def test_deterministic_with_buffer(next_shard, weights, seed): 100 | download_dl_test_data("tests/assets") 101 | disable_buffer = False 102 | random_seed(seed) 103 | dl1 = retrieve_dataset(0, next_shard, weights, seed, disable_buffer) 104 | dl2 = retrieve_dataset(0, next_shard, weights, seed, disable_buffer) 105 | 106 | iter1 = iter(dl1) 107 | iter2 = iter(dl2) 108 | 109 | for _ in range(NUM_SAMPLES_TO_CHECK): 110 | item1 = next(iter1) 111 | item2 = next(iter2) 112 | assert item1 == item2 113 | sleep(0.001) 114 | 115 | 116 | @pytest.mark.parametrize("next_shard", [0, 2]) 117 | 
@pytest.mark.parametrize("weights", [[0.5, 0.5], [0.9, 0.1]]) 118 | @pytest.mark.parametrize("seed", [0, 17]) 119 | def test_deterministic_resampled(next_shard, weights, seed): 120 | download_dl_test_data("tests/assets") 121 | random_seed(seed) 122 | dl1 = retrieve_dataset_resampled(0, next_shard, weights, seed) 123 | dl2 = retrieve_dataset_resampled(0, next_shard, weights, seed) 124 | 125 | iter1 = iter(dl1) 126 | iter2 = iter(dl2) 127 | 128 | for _ in range(NUM_SAMPLES_TO_CHECK): 129 | item1 = next(iter1) 130 | item2 = next(iter2) 131 | assert item1 == item2 132 | sleep(0.001) 133 | 134 | 135 | @pytest.mark.parametrize("next_shard", [0, 2]) 136 | @pytest.mark.parametrize("weights", [[0.5, 0.5], [0.6, 0.4]]) 137 | @pytest.mark.parametrize("min_shards_needed", [2, 4]) 138 | def test_min_shards(next_shard, weights, min_shards_needed): 139 | download_dl_test_data("tests/assets") 140 | shard_strings, _, _ = get_string_for_epoch( 141 | NUM_SAMPLES, [next_shard, next_shard], INPUT_PATHS, weights, min_shards_needed, world_size=1 142 | ) 143 | for item in shard_strings: 144 | num_shards = len(item.split(",")) 145 | assert num_shards >= min_shards_needed 146 | 147 | 148 | def test_count_manifest(): 149 | download_dl_test_data("tests/assets") 150 | manifest_path = INPUT_PATHS[0] 151 | metadata = get_metadata_file(manifest_path) 152 | idx = random.randint(0, len(metadata)) 153 | item = metadata[idx] 154 | shard_path = os.path.join(str(Path(INPUT_PATHS[0]).parent), item["shard"] + ".tar") 155 | shard_ds = wds.WebDataset(str(shard_path)) 156 | count = 0 157 | for _ in iter(shard_ds): 158 | count += 1 159 | assert count == item["num_sequences"] 160 | -------------------------------------------------------------------------------- /tests/test_loss_masking.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from torch import Tensor, equal 4 | 5 | from open_lm.data import sample_chunk 6 | from tests.shared import create_train_fixtures 7 | 8 | 9 | def test_target_mask_left(): 10 | args, _, _, _, _, _ = create_train_fixtures() 11 | 12 | special_token_left = 42 13 | batched_tokens = Tensor( 14 | [ 15 | [0, 1, 2, 42], 16 | [0, 1, 42, 2], 17 | [0, 42, 1, 2], 18 | [42, 0, 1, 2], 19 | [0, 1, 2, 3], 20 | [42, 0, 42, 1], 21 | ] 22 | ) 23 | 24 | args.target_mask_left = special_token_left 25 | args.seq_len = 3 26 | 27 | input, target = sample_chunk(batched_tokens, args) 28 | 29 | real_input = Tensor( 30 | [ 31 | [0, 1, 2], 32 | [0, 1, 42], 33 | [0, 42, 1], 34 | [42, 0, 1], 35 | [0, 1, 2], 36 | [42, 0, 42], 37 | ] 38 | ) 39 | 40 | real_target = Tensor( 41 | [ 42 | [-100, -100, -100], 43 | [-100, -100, 2], 44 | [-100, 1, 2], 45 | [0, 1, 2], 46 | [1, 2, 3], 47 | [-100, -100, 1], 48 | ] 49 | ) 50 | 51 | assert equal(input, real_input) 52 | assert equal(target, real_target) 53 | 54 | 55 | def test_target_mask_individual(): 56 | args, _, _, _, _, _ = create_train_fixtures() 57 | special_token_individual = 2 58 | 59 | batched_tokens = Tensor( 60 | [ 61 | [0, 1, 2, 42], 62 | [0, 1, 42, 2], 63 | [0, 42, 1, 2], 64 | [42, 0, 1, 2], 65 | [0, 1, 2, 3], 66 | [42, 0, 42, 1], 67 | ] 68 | ) 69 | 70 | args.target_mask_individual = special_token_individual 71 | args.seq_len = 3 72 | 73 | input, target = sample_chunk(batched_tokens, args) 74 | 75 | real_input = Tensor( 76 | [ 77 | [0, 1, 2], 78 | [0, 1, 42], 79 | [0, 42, 1], 80 | [42, 0, 1], 81 | [0, 1, 2], 82 | [42, 0, 42], 83 | ] 84 | ) 85 | 86 | real_target = Tensor( 87 | [ 88 | [1, -100, 42], 89 | [1, 42, -100], 90 | [42, 1, 
-100], 91 | [0, 1, -100], 92 | [1, -100, 3], 93 | [0, 42, 1], 94 | ] 95 | ) 96 | 97 | assert equal(input, real_input) 98 | assert equal(target, real_target) 99 | 100 | 101 | def test_target_mask_left_individual(): 102 | args, _, _, _, _, _ = create_train_fixtures() 103 | special_token_left = 42 104 | special_token_individual = 2 105 | 106 | batched_tokens = Tensor( 107 | [ 108 | [0, 1, 2, 42], 109 | [0, 1, 42, 2], 110 | [0, 42, 1, 2], 111 | [42, 0, 1, 2], 112 | [0, 1, 2, 3], 113 | [42, 0, 42, 1], 114 | [42, 42, 42, 1], 115 | ] 116 | ) 117 | 118 | args.target_mask_left = special_token_left 119 | args.target_mask_individual = special_token_individual 120 | args.seq_len = 3 121 | 122 | input, target = sample_chunk(batched_tokens, args) 123 | 124 | real_input = Tensor( 125 | [ 126 | [0, 1, 2], 127 | [0, 1, 42], 128 | [0, 42, 1], 129 | [42, 0, 1], 130 | [0, 1, 2], 131 | [42, 0, 42], 132 | [42, 42, 42], 133 | ] 134 | ) 135 | 136 | real_target = Tensor( 137 | [ 138 | [-100, -100, -100], 139 | [-100, -100, -100], 140 | [-100, 1, -100], 141 | [0, 1, -100], 142 | [1, -100, 3], 143 | [-100, -100, 1], 144 | [-100, -100, 1], 145 | ] 146 | ) 147 | 148 | assert equal(input, real_input) 149 | assert equal(target, real_target) 150 | 151 | 152 | def test_target_mask_left_individual_squash(): 153 | args, _, _, _, _, _ = create_train_fixtures() 154 | special_token_left = 42 155 | special_token_individual = 2 156 | 157 | batched_tokens = Tensor( 158 | [ 159 | [0, 1, 2, 42], 160 | [0, 1, 42, 2], 161 | [0, 42, 1, 2], 162 | [42, 0, 1, 2], 163 | [0, 1, 2, 3], 164 | [42, 0, 42, 1], 165 | ] 166 | ) 167 | 168 | args.target_mask_left = special_token_left 169 | args.target_mask_individual = special_token_individual 170 | args.seq_len = 3 171 | args.squash_mask_left = True 172 | 173 | input, target = sample_chunk(batched_tokens, args) 174 | 175 | real_input = Tensor( 176 | [ 177 | [0, 1, 2], 178 | [0, 1, 2], 179 | [0, 1, 2], 180 | [0, 1, 2], 181 | [0, 1, 2], 182 | [0, 2, 2], 183 | ] 184 | ) 185 | 186 | real_target = Tensor( 187 | [ 188 | [-100, -100, -100], 189 | [-100, -100, -100], 190 | [1, -100, -100], 191 | [0, 1, -100], 192 | [1, -100, 3], 193 | [-100, 1, -100], 194 | ] 195 | ) 196 | 197 | assert equal(input, real_input) 198 | assert equal(target, real_target) 199 | 200 | 201 | def test_target_mask_left_individual_squash_real_data(): 202 | data = None 203 | with open("tests/assets/2049_span_pad.json", "r") as f: 204 | data = json.load(f) 205 | 206 | args, _, _, _, _, _ = create_train_fixtures() 207 | 208 | args.target_mask_left = 50300 209 | args.target_mask_individual = 50400 210 | args.seq_len = 2048 211 | args.squash_mask_left = True 212 | 213 | # skip the pad left token 214 | real_input = data[:65] + data[66:72] 215 | 216 | # right pad with the target_mask_individual token 217 | real_input += (2048 - len(real_input)) * [50400] 218 | real_input = Tensor([real_input]) 219 | 220 | input, target = sample_chunk(Tensor([data]), args) 221 | # print(input.shape) 222 | 223 | # skip the pad left token and mask out the prefix with -100 224 | real_target = len(data[1:65]) * [-100] + data[66:72] 225 | 226 | # right pad with the ignore xent token (-100) 227 | real_target += (2048 - len(real_target)) * [-100] 228 | real_target = Tensor([real_target]) 229 | 230 | assert equal(input, real_input) 231 | assert equal(target, real_target) 232 | -------------------------------------------------------------------------------- /tests/test_attention_masking.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from open_lm.attention import xformers_attn, torch_attn 5 | 6 | 7 | @pytest.mark.gpu 8 | def test_attention_masking_xformers(): 9 | n, d = 8, 4 10 | queries = torch.rand((1, n, 1, d)).cuda() 11 | keys = torch.rand((1, n, 1, d)).cuda() 12 | values = torch.rand((1, n, 1, d)).cuda() 13 | 14 | attention_mask = torch.ones((1, n)).cuda() 15 | # Ignore first elements 16 | attention_mask[:, :4] = 0 17 | 18 | # Run with only last 4 elements of the sequence, no attention mask 19 | output_no_mask = xformers_attn( 20 | queries[:, 4:], 21 | keys[:, 4:], 22 | values[:, 4:], 23 | is_causal=True, 24 | attention_mask=None, 25 | ) 26 | 27 | # Run with only last 4 elements of the sequence 28 | output_dummy_mask = xformers_attn( 29 | queries[:, 4:], 30 | keys[:, 4:], 31 | values[:, 4:], 32 | is_causal=True, 33 | attention_mask=attention_mask[:, 4:], 34 | ) 35 | 36 | # Run with all elements but mask the first 4 elements 37 | output_mask_initial = xformers_attn(queries, keys, values, is_causal=True, attention_mask=attention_mask) 38 | 39 | # Run with the output of attention again and ensure it looks good (e.g., we don't run into NaNs. This happened in 40 | # initial implementations where the output had NaNs for certain types of masks. 41 | output3 = xformers_attn( 42 | output_mask_initial, output_mask_initial, output_mask_initial, is_causal=True, attention_mask=attention_mask 43 | ) 44 | assert not output3.isnan().any() 45 | 46 | assert torch.allclose(output_no_mask, output_dummy_mask) 47 | assert torch.allclose(output_no_mask, output_mask_initial[:, 4:]) 48 | 49 | 50 | def test_attention_masking_torchattn(): 51 | n, d = 8, 4 52 | queries = torch.rand((1, n, 1, d)) 53 | keys = torch.rand((1, n, 1, d)) 54 | values = torch.rand((1, n, 1, d)) 55 | 56 | attention_mask = torch.ones((1, n)) 57 | # Ignore first elements 58 | attention_mask[:, :4] = 0 59 | 60 | # Run with only last 4 elements of the sequence, no attention mask 61 | output_no_mask = torch_attn( 62 | queries[:, 4:], 63 | keys[:, 4:], 64 | values[:, 4:], 65 | is_causal=True, 66 | attention_mask=None, 67 | ) 68 | 69 | # Run with only last 4 elements of the sequence 70 | output_dummy_mask = torch_attn( 71 | queries[:, 4:], 72 | keys[:, 4:], 73 | values[:, 4:], 74 | is_causal=True, 75 | attention_mask=attention_mask[:, 4:], 76 | ) 77 | 78 | # Run with all elements but mask the first 4 elements 79 | output_mask_initial = torch_attn(queries, keys, values, is_causal=True, attention_mask=attention_mask) 80 | 81 | output3 = torch_attn( 82 | output_mask_initial, output_mask_initial, output_mask_initial, is_causal=True, attention_mask=attention_mask 83 | ) 84 | assert not output3.isnan().any() 85 | 86 | assert torch.allclose(output_no_mask, output_dummy_mask) 87 | assert torch.allclose(output_no_mask, output_mask_initial[:, 4:]) 88 | 89 | 90 | @pytest.mark.gpu 91 | def test_attention_masking_torchattn_vs_xformers(): 92 | n, d = 8, 4 93 | queries = torch.rand((1, n, 1, d)) 94 | keys = torch.rand((1, n, 1, d)) 95 | values = torch.rand((1, n, 1, d)) 96 | 97 | attention_mask = torch.ones((1, n)) 98 | # Ignore first elements 99 | attention_mask[:, :4] = 0 100 | 101 | # Run with only last 4 elements of the sequence, no attention mask 102 | output_no_mask_torch = torch_attn( 103 | queries[:, 4:], 104 | keys[:, 4:], 105 | values[:, 4:], 106 | is_causal=True, 107 | attention_mask=None, 108 | ) 109 | 110 | # Run with only last 4 elements of the sequence 
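# (Added note, hedged: attention_mask[:, 4:] is all ones here, since only the first 4 positions were
#  zeroed above, so it should act as a no-op mask; in these tests a value of 1 marks a position to
#  attend to and 0 marks a position to ignore.)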
111 | output_dummy_mask_torch = torch_attn( 112 | queries[:, 4:], 113 | keys[:, 4:], 114 | values[:, 4:], 115 | is_causal=True, 116 | attention_mask=attention_mask[:, 4:].clone(), 117 | ) 118 | 119 | # Run with all elements but mask the first 4 elements 120 | output_mask_initial_torch = torch_attn(queries, keys, values, is_causal=True, attention_mask=attention_mask.clone()) 121 | output_mask_initial_fewq_torch = torch_attn( 122 | queries[:, : n - 2], keys, values, is_causal=True, attention_mask=attention_mask.clone() 123 | ) 124 | 125 | output3_torch = torch_attn( 126 | output_mask_initial_torch, 127 | output_mask_initial_torch, 128 | output_mask_initial_torch, 129 | is_causal=True, 130 | attention_mask=attention_mask.clone(), 131 | ) 132 | assert not output3_torch.isnan().any() 133 | 134 | queries = queries.cuda() 135 | keys = keys.cuda() 136 | values = values.cuda() 137 | attention_mask = attention_mask.cuda() 138 | 139 | # Run with only last 4 elements of the sequence, no attention mask 140 | output_no_mask_xformers = xformers_attn( 141 | queries[:, 4:], 142 | keys[:, 4:], 143 | values[:, 4:], 144 | is_causal=True, 145 | attention_mask=None, 146 | ) 147 | 148 | # Run with only last 4 elements of the sequence 149 | output_dummy_mask_xformers = xformers_attn( 150 | queries[:, 4:], 151 | keys[:, 4:], 152 | values[:, 4:], 153 | is_causal=True, 154 | attention_mask=attention_mask[:, 4:].clone(), 155 | ) 156 | 157 | # Run with all elements but mask the first 4 elements 158 | output_mask_initial_xformers = xformers_attn( 159 | queries, keys, values, is_causal=True, attention_mask=attention_mask.clone() 160 | ) 161 | output_mask_initial_fewq_xformers = xformers_attn( 162 | queries[:, : n - 2], keys, values, is_causal=True, attention_mask=attention_mask.clone() 163 | ) 164 | 165 | output3_xformers = xformers_attn( 166 | output_mask_initial_xformers, 167 | output_mask_initial_xformers, 168 | output_mask_initial_xformers, 169 | is_causal=True, 170 | attention_mask=attention_mask.clone(), 171 | ) 172 | assert not output3_xformers.isnan().any() 173 | 174 | assert torch.allclose(output_no_mask_torch, output_no_mask_xformers.cpu()) 175 | assert torch.allclose(output_dummy_mask_torch, output_dummy_mask_xformers.cpu()) 176 | assert torch.allclose(output_mask_initial_torch, output_mask_initial_xformers.cpu()) 177 | assert torch.allclose(output3_torch, output3_xformers.cpu()) 178 | assert torch.allclose(output_mask_initial_fewq_torch, output_mask_initial_fewq_xformers.cpu()) 179 | -------------------------------------------------------------------------------- /open_lm/evaluate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import time 4 | import copy 5 | 6 | import torch 7 | import torch.distributed as dist 8 | 9 | try: 10 | import wandb 11 | except ImportError: 12 | wandb = None 13 | 14 | from open_lm.data import sample_chunk 15 | from open_lm.distributed import is_master 16 | from open_lm.precision import get_autocast 17 | from open_lm.meters import ( 18 | AverageMeter, 19 | ConfidenceIntervalMeter, 20 | gather_meters, 21 | ) 22 | 23 | 24 | @torch.inference_mode() 25 | def evaluate(model, data, start_epoch, args, writer): 26 | """ 27 | evaluates perplexity on validation data 28 | """ 29 | if is_master(args): 30 | print("=> begin evaluation") 31 | device = torch.device(args.device) 32 | autocast = get_autocast(args.precision) 33 | 34 | model.eval() 35 | 36 | data["val"].set_epoch(start_epoch) # set epoch in process safe manner 
via sampler or shared_epoch 37 | dataloader = data["val"].dataloader 38 | 39 | # NOTE: dataloader.num_batches = 0 corresponds to exhausting iterator by convention 40 | exhaust_loader = dataloader.num_batches == 0 41 | 42 | losses_m = AverageMeter() 43 | batch_time_m = AverageMeter() 44 | data_time_m = AverageMeter() 45 | sps_m = AverageMeter() 46 | spspg_m = AverageMeter() 47 | losses_seq_ci_m = ConfidenceIntervalMeter() 48 | losses_tok_ci_m = ConfidenceIntervalMeter() 49 | 50 | end = time.time() 51 | loss = torch.nn.CrossEntropyLoss(reduction="none") 52 | 53 | # by default the dataloader will be exhausted 54 | for i, batch in enumerate(dataloader): 55 | if i == dataloader.num_batches and not exhaust_loader: 56 | break 57 | 58 | (texts,) = batch 59 | texts = torch.LongTensor(texts).to(device) 60 | 61 | data_time_m.update(time.time() - end) 62 | 63 | with autocast(): 64 | inputs, targets = sample_chunk(texts, args) 65 | 66 | out, _, _ = model(inputs) # [per_gpu_bs, seq_len, vocab_size] 67 | 68 | bs, seq_len = targets.shape 69 | 70 | targets = targets.reshape(-1) 71 | total_loss = loss(out.reshape(-1, args.vocab_size), targets) # [bs * seq_len] 72 | 73 | # cross entropy ignores -100 values in loss computation 74 | mask = targets != -100 75 | 76 | # reshape and average for sequence losses 77 | sum_loss_per_seq = torch.sum(total_loss.reshape(bs, seq_len), -1) 78 | num_toks_per_seq = torch.sum(mask.reshape(bs, seq_len), -1).float() 79 | losses_seq_ci_m.update((sum_loss_per_seq / num_toks_per_seq).cpu().numpy()) 80 | 81 | # individual token losses 82 | losses_tok_ci_m.update(total_loss[mask].cpu().numpy()) 83 | 84 | # compute average loss for the mini-batch 85 | total_loss = total_loss[mask].mean() 86 | losses_m.update(total_loss.item(), n=inputs.shape[0]) 87 | 88 | batch_time_m.update(time.time() - end) 89 | end = time.time() 90 | 91 | sps_m.update(inputs.numel() * args.world_size / batch_time_m.val) 92 | spspg_m.update(inputs.numel() / batch_time_m.val) 93 | 94 | if args.distributed: 95 | dist.barrier() 96 | 97 | if args.world_size > 1: 98 | # in this case we need to gather the loss. for simplicity we gather the meters only on main proc 99 | # Save eval loss / etc. 100 | meters = [ 101 | losses_m, 102 | batch_time_m, 103 | data_time_m, 104 | sps_m, 105 | spspg_m, 106 | losses_seq_ci_m, 107 | losses_tok_ci_m, 108 | ] 109 | 110 | # meters on master will become global meters, other meters will remain local 111 | losses_m, batch_time_m, data_time_m, sps_m, spspg_m, losses_seq_ci_m, losses_tok_ci_m = gather_meters( 112 | meters, args 113 | ) 114 | 115 | if args.distributed: 116 | dist.barrier() 117 | 118 | lower_seq, upper_seq, lower_tok, upper_tok = -1.0, -1.0, -1.0, -1.0 119 | if args.val_seq_ci: 120 | lower_seq, upper_seq = losses_seq_ci_m.compute_bootstrap_ci(args.val_max_pop_ci, args.val_iter_ci) 121 | 122 | if args.val_tok_ci: 123 | lower_tok, upper_tok = losses_tok_ci_m.compute_bootstrap_ci(args.val_max_pop_ci, args.val_iter_ci) 124 | 125 | num_seqs = sum([len(p) for p in losses_seq_ci_m.points]) 126 | num_toks = sum([len(p) for p in losses_tok_ci_m.points]) 127 | 128 | # Save eval loss / etc. 
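# (Added note, hedged: the perplexity reported further below is exp(mean token loss); for example,
#  a mean loss of 2.3026 nats per token corresponds to a perplexity of exp(2.3026) ~= 10.)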
129 | log_data = { 130 | "loss": losses_m.avg, 131 | "data_time": data_time_m.avg, 132 | "batch_time": batch_time_m.avg, 133 | "samples_per_second": sps_m.avg, 134 | "samples_per_second_per_gpu": spspg_m.avg, 135 | "loss_sequences_lower_95": lower_seq, 136 | "loss_sequences_upper_95": upper_seq, 137 | "loss_tokens_lower_95": lower_tok, 138 | "loss_tokens_upper_95": upper_tok, 139 | "sequences": num_seqs, 140 | "tokens": num_toks, 141 | } 142 | if args.train_num_samples is not None: 143 | log_data["train_tokens"] = start_epoch * args.train_num_samples * args.seq_len 144 | 145 | for name, val in log_data.items(): 146 | name = "valid/" + name 147 | if writer is not None: 148 | writer.add_scalar(name, val, start_epoch) 149 | if args.wandb and is_master(args): 150 | assert wandb is not None, "Please install wandb." 151 | wandb.log({name: val, "epoch": start_epoch, "tokens": log_data["tokens"]}) 152 | 153 | if is_master(args): 154 | # meters on masters should be global 155 | print(f"evaluation on: {args.val_data}") 156 | print(f"evaluation loss: {losses_m.avg}") 157 | print(f"num loss point evaluations {losses_m.count}") 158 | print(f"evaluation perplexity: {math.exp(losses_m.avg)}") 159 | print(f"num seqs: {num_seqs}") 160 | print(f"num tokens: {num_toks}") 161 | 162 | log_data["checkpoint_path"] = args.resume 163 | log_data["val_data"] = args.val_data 164 | log_data["model"] = args.hf_model if args.hf_model else args.model 165 | 166 | return log_data 167 | 168 | 169 | def evaluate_loop(model, data_list, start_epoch, args, writer): 170 | log_data_list = [] 171 | for i, data in enumerate(data_list): 172 | args_copy = copy.deepcopy(args) 173 | args_copy.val_data = [args.val_data[i]] 174 | args_copy.val_data_key = args.val_data_key[i] 175 | 176 | if args.distributed: 177 | dist.barrier() 178 | 179 | log_data_list.append(evaluate(model, data, start_epoch, args_copy, writer)) 180 | 181 | if args.distributed: 182 | dist.barrier() 183 | 184 | return log_data_list 185 | -------------------------------------------------------------------------------- /open_lm/positional_embedding/llama_rotary.py: -------------------------------------------------------------------------------- 1 | # NOTE: 08/31/23, this class is copied from xformers as there is currently a bug related to which channel dim the rotary embedding is applied to. 2 | # when the upstream issue is fixed, this file should be deleted. To track progress, see this issue: https://github.com/facebookresearch/xformers/issues/841 3 | 4 | # taken from: https://github.com/facebookresearch/xformers/blob/748c159096d4f9fcfe3eaf22801e5aed4777210b/xformers/components/positional_embedding/rotary.py 5 | from typing import Tuple 6 | 7 | import torch 8 | 9 | 10 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, scale: float = 1.0): 11 | """ 12 | Precompute the frequency tensor for complex exponentials (cis) with given dimensions. 13 | 14 | This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' 15 | and the end index 'end'. The 'theta' parameter scales the frequencies. 16 | The returned tensor contains complex values in complex64 data type. 17 | 18 | Args: 19 | dim (int): Dimension of the frequency tensor. 20 | end (int): End index for precomputing frequencies. 21 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. 22 | 23 | Returns: 24 | torch.Tensor: Precomputed frequency tensor with complex exponentials. 
25 | 26 | 27 | 28 | 29 | """ 30 | freqs = 1.0 / (scale * theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 31 | t = torch.arange(end, device=freqs.device) # type: ignore 32 | freqs = torch.outer(t, freqs).float() # type: ignore 33 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 34 | return freqs_cis 35 | 36 | 37 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 38 | """ 39 | Reshape frequency tensor for broadcasting it with another tensor. 40 | 41 | This function reshapes the frequency tensor to have the same shape as the target tensor 'x' 42 | for the purpose of broadcasting the frequency tensor during element-wise operations. 43 | 44 | Args: 45 | freqs_cis (torch.Tensor): Frequency tensor to be reshaped. 46 | x (torch.Tensor): Target tensor for broadcasting compatibility. 47 | 48 | Returns: 49 | torch.Tensor: Reshaped frequency tensor. 50 | 51 | Raises: 52 | AssertionError: If the frequency tensor doesn't match the expected shape. 53 | AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. 54 | """ 55 | ndim = x.ndim 56 | assert 0 <= 1 < ndim 57 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 58 | # dynamic shape could be more general but torchscript doesn't support it 59 | # shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 60 | shape = [1, x.shape[1], 1, x.shape[-1]] 61 | return freqs_cis.view(shape) 62 | 63 | 64 | @torch.jit.script 65 | def apply_llama_rotary_pos_emb( 66 | x: torch.Tensor, 67 | freqs_cis: torch.Tensor, 68 | ) -> torch.Tensor: 69 | """ 70 | Apply llama rotary embeddings to input tensors using the given frequency tensor. 71 | 72 | This function applies rotary embeddings to the given 'x' tensors using the provided 73 | frequency tensor 'freqs_cis'. The input tensor is reshaped as complex numbers, and the frequency tensor 74 | is reshaped for broadcasting compatibility. The resulting tensor contain rotary embeddings and is 75 | returned as real tensors. 76 | 77 | Args: 78 | x (torch.Tensor): tensor to apply rotary embeddings. 79 | freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. 80 | 81 | Returns: 82 | torch.Tensor: Tuple of modified query tensor and key tensor with rotary embeddings. 83 | 84 | """ 85 | x_even = x[..., ::2] 86 | x_odd = x[..., 1::2] 87 | 88 | # Stack them along the last dimension to make it [N, ..., D, 2] 89 | x_ = torch.stack([x_even, x_odd], dim=-1) 90 | x_ = torch.view_as_complex(x_.float()) 91 | 92 | freqs_cis = reshape_for_broadcast(freqs_cis, x_) 93 | x_out = torch.view_as_real(x_ * freqs_cis).flatten(3) 94 | return x_out.type_as(x) 95 | 96 | 97 | class LLaMARotaryEmbedding(torch.nn.Module): 98 | """ 99 | The rotary position embeddings from RoFormer_ (Su et. al). 100 | A crucial insight from the method is that the query and keys are 101 | transformed by rotation matrices which depend on the relative positions. 102 | 103 | Other implementations are available in the Rotary Transformer repo_ and in 104 | GPT-NeoX_, GPT-NeoX was an inspiration 105 | 106 | .. _RoFormer: https://arxiv.org/abs/2104.09864 107 | .. _repo: https://github.com/ZhuiyiTechnology/roformer 108 | .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox 109 | 110 | 111 | .. 
warning: Please note that this embedding is not registered on purpose, as it is transformative 112 | (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis 113 | """ 114 | 115 | def __init__(self, head_dim: int, num_heads: int, seq_len: int, *_, **__): 116 | super().__init__() 117 | # Generate and save the inverse frequency buffer (non trainable) 118 | self.freqs_cis = precompute_freqs_cis( 119 | # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096. 120 | # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning. 121 | head_dim, 122 | seq_len * 2, 123 | ) 124 | 125 | def reset_parameters(self): 126 | pass 127 | 128 | def forward(self, q: torch.Tensor, k: torch.Tensor, offset=0) -> Tuple[torch.Tensor, torch.Tensor]: 129 | assert ( 130 | q.shape[1] + offset <= self.freqs_cis.shape[0] 131 | ), f"offset {offset} or query sequence length {q.shape[1]}\ 132 | \n is too large for the precomputed frequency tensor {self.freqs_cis.shape}" 133 | assert ( 134 | k.shape[1] + offset <= self.freqs_cis.shape[0] 135 | ), f"offset {offset} or key sequence length {k.shape[1]}\ 136 | \n is too large for the precomputed frequency tensor {self.freqs_cis.shape}" 137 | q_seq_len = q.shape[1] 138 | k_seq_len = k.shape[1] 139 | self.freqs_cis = self.freqs_cis.to(q.device) 140 | q_freqs_cis = self.freqs_cis[offset : offset + q_seq_len] 141 | k_freqs_cis = self.freqs_cis[offset : offset + k_seq_len] 142 | 143 | return ( 144 | apply_llama_rotary_pos_emb(q, q_freqs_cis), 145 | apply_llama_rotary_pos_emb(k, k_freqs_cis), 146 | ) 147 | 148 | 149 | class LLaMARotaryWithCast(LLaMARotaryEmbedding): 150 | def forward(self, q, k, v, offset=0): 151 | q, k = super().forward(q, k, offset) 152 | return q.to(v.dtype), k.to(v.dtype), v 153 | -------------------------------------------------------------------------------- /open_lm/utils/verify_converted_llama.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "ename": "", 10 | "evalue": "", 11 | "output_type": "error", 12 | "traceback": [ 13 | "\u001b[1;31mRunning cells with 'openlm' requires the ipykernel package.\n", 14 | "\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. 
\n", 15 | "\u001b[1;31mCommand: 'conda install -n openlm ipykernel --update-deps --force-reinstall'" 16 | ] 17 | } 18 | ], 19 | "source": [ 20 | "import torch\n", 21 | "import torch.nn as nn\n", 22 | "import sys\n", 23 | "import json\n", 24 | "from dataclasses import dataclass\n", 25 | "\n", 26 | "sys.path.append(\"../../../open_lm\")\n", 27 | "from open_lm.model import Transformer\n", 28 | "from open_lm.norms import RmsNorm\n", 29 | "\n", 30 | "device = \"cuda:0\"\n", 31 | "cfg = json.load(open(\"../model_configs/llama2_7b.json\"))\n", 32 | "\n", 33 | "\n", 34 | "@dataclass\n", 35 | "class Params:\n", 36 | " dim: int\n", 37 | " n_layers: int\n", 38 | " n_heads: int\n", 39 | " vocab_size: int\n", 40 | " norm_eps: float\n", 41 | " seq_len: int\n", 42 | " post_embed_norm: bool\n", 43 | " weight_tying: bool\n", 44 | " norm_type: nn.Module = RmsNorm # Make sure to use RmsNorm for LLaMA\n", 45 | " apply_qk_norm: bool = False\n", 46 | " positional_embedding_type: str = \"llama_rotary\" # Make sure to set this for LLaMA\n", 47 | " ffn_type: str = \"swiglu\"\n", 48 | "\n", 49 | "\n", 50 | "args = Params(\n", 51 | " dim=cfg[\"hidden_dim\"],\n", 52 | " n_layers=cfg[\"n_layers\"],\n", 53 | " n_heads=cfg[\"n_heads\"],\n", 54 | " seq_len=cfg[\"seq_len\"],\n", 55 | " vocab_size=cfg[\"vocab_size\"],\n", 56 | " post_embed_norm=cfg[\"post_embed_norm\"],\n", 57 | " weight_tying=cfg[\"weight_tying\"],\n", 58 | " norm_eps=1e-5,\n", 59 | ")\n", 60 | "\n", 61 | "model = Transformer(args)\n", 62 | "state_dict = torch.load(\"./LLAMA2/llama-2-7b/consolidated.00.converted.pth\")\n", 63 | "model.load_state_dict(state_dict, strict=True)\n", 64 | "model = model.eval().to(device)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "ename": "", 74 | "evalue": "", 75 | "output_type": "error", 76 | "traceback": [ 77 | "\u001b[1;31mRunning cells with 'openlm' requires the ipykernel package.\n", 78 | "\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. 
\n", 79 | "\u001b[1;31mCommand: 'conda install -n openlm ipykernel --update-deps --force-reinstall'" 80 | ] 81 | } 82 | ], 83 | "source": [ 84 | "sys.path.append(\"./LLAMA2/llama\")\n", 85 | "from llama.tokenizer import Tokenizer\n", 86 | "\n", 87 | "tokenizer = Tokenizer(\"./LLAMA2/tokenizer.model\")\n", 88 | "\n", 89 | "\n", 90 | "def sample_top_p(probs, p):\n", 91 | " probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)\n", 92 | " probs_sum = torch.cumsum(probs_sort, dim=-1)\n", 93 | " mask = probs_sum - probs_sort > p\n", 94 | " probs_sort[mask] = 0.0\n", 95 | " probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))\n", 96 | " next_token = torch.multinomial(probs_sort, num_samples=1)\n", 97 | " next_token = torch.gather(probs_idx, -1, next_token)\n", 98 | " return next_token\n", 99 | "\n", 100 | "\n", 101 | "def generate_top_p_language(prefix: str, temperature: float = 0.6, top_p: float = 0.9, max_len: int = 128):\n", 102 | " input_tokens = tokenizer.encode(prefix, bos=True, eos=False)\n", 103 | " tokens = torch.tensor(input_tokens).unsqueeze(0).to(device)\n", 104 | "\n", 105 | " for i in range(max_len):\n", 106 | " with torch.no_grad():\n", 107 | " logits, _, _ = model(tokens)\n", 108 | " if temperature > 0:\n", 109 | " probs = torch.softmax(logits[:, -1] / temperature, dim=-1)\n", 110 | " next_token = sample_top_p(probs, top_p)\n", 111 | " else:\n", 112 | " next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)\n", 113 | " tokens = torch.cat([tokens, next_token], dim=-1)\n", 114 | "\n", 115 | " generation = tokenizer.decode(tokens[0].cpu().numpy().tolist())\n", 116 | " return generation\n", 117 | "\n", 118 | "\n", 119 | "prompts = [\n", 120 | " # For these prompts, the expected answer is the natural continuation of the prompt\n", 121 | " \"I believe the meaning of life is\",\n", 122 | " \"Simply put, the theory of relativity states that \",\n", 123 | " \"\"\"A brief message congratulating the team on the launch:\n", 124 | "\n", 125 | " Hi everyone,\n", 126 | " \n", 127 | " I just \"\"\",\n", 128 | " # Few shot prompt (providing a few examples before asking model to complete more);\n", 129 | " \"\"\"Translate English to French:\n", 130 | " \n", 131 | " sea otter => loutre de mer\n", 132 | " peppermint => menthe poivrée\n", 133 | " plush girafe => girafe peluche\n", 134 | " cheese =>\"\"\",\n", 135 | " \"\"\"He -> Him, She -> Her, They ->\"\"\",\n", 136 | " \"\"\"Who is Donald Trump?\"\"\",\n", 137 | "]\n", 138 | "\n", 139 | "for prompt in prompts:\n", 140 | " print(\"====================================\")\n", 141 | " generated_text = generate_top_p_language(prompt)\n", 142 | " print(prompt)\n", 143 | " print(generated_text)\n", 144 | " print(\"====================================\")" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [ 152 | { 153 | "ename": "", 154 | "evalue": "", 155 | "output_type": "error", 156 | "traceback": [ 157 | "\u001b[1;31mRunning cells with 'openlm' requires the ipykernel package.\n", 158 | "\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. 
\n", 159 | "\u001b[1;31mCommand: 'conda install -n openlm ipykernel --update-deps --force-reinstall'" 160 | ] 161 | } 162 | ], 163 | "source": [] 164 | } 165 | ], 166 | "metadata": { 167 | "kernelspec": { 168 | "display_name": "modeldiff", 169 | "language": "python", 170 | "name": "python3" 171 | }, 172 | "language_info": { 173 | "codemirror_mode": { 174 | "name": "ipython", 175 | "version": 3 176 | }, 177 | "file_extension": ".py", 178 | "mimetype": "text/x-python", 179 | "name": "python", 180 | "nbconvert_exporter": "python", 181 | "pygments_lexer": "ipython3", 182 | "version": "3.11.5" 183 | } 184 | }, 185 | "nbformat": 4, 186 | "nbformat_minor": 2 187 | } 188 | -------------------------------------------------------------------------------- /open_lm/open_lm_hf/modeling_openlm.py: -------------------------------------------------------------------------------- 1 | # Follows OLMo's HF template 2 | 3 | import logging 4 | from dataclasses import fields 5 | from typing import List, Optional, Tuple, Union 6 | 7 | import torch 8 | from transformers import PreTrainedModel 9 | from transformers.cache_utils import Cache 10 | from transformers.modeling_outputs import CausalLMOutputWithPast 11 | from transformers.models.auto import AutoModelForCausalLM 12 | 13 | from open_lm.model import Params, Transformer 14 | from open_lm.norms import get_norm_class 15 | 16 | from .configuration_openlm import OpenLMConfig 17 | 18 | log = logging.getLogger(__name__) 19 | 20 | 21 | def create_model_config_from_pretrained_config(config: OpenLMConfig): 22 | """ 23 | Utility function 24 | """ 25 | 26 | kwargs = {} 27 | for field in fields(Params): 28 | if hasattr(config, field.name): 29 | kwargs[field.name] = getattr(config, field.name) 30 | 31 | model_config = Params(**kwargs) 32 | 33 | if hasattr(config, "norm_type"): 34 | model_config.norm_type = get_norm_class(config.norm_type) 35 | 36 | return model_config 37 | 38 | 39 | class OpenLMForCausalLM(PreTrainedModel): 40 | """ 41 | Extremely barebones HF model wrapper. 42 | """ 43 | 44 | config_class = OpenLMConfig 45 | base_model_prefix = "model" 46 | 47 | def __init__(self, config: OpenLMConfig, model: Optional[Transformer] = None): 48 | super().__init__(config) 49 | 50 | if not model: 51 | self.model_config = create_model_config_from_pretrained_config(config) 52 | # Initialize model (always on CPU to start with so we don't run out of GPU memory). 
53 | self.model_config.init_device = "cpu" 54 | self.model = Transformer(self.model_config) 55 | 56 | else: 57 | self.model = model 58 | 59 | def forward( 60 | self, 61 | input_ids: torch.LongTensor = None, 62 | inputs_embeds: Optional[torch.FloatTensor] = None, 63 | attention_mask: Optional[torch.Tensor] = None, 64 | attention_bias: Optional[torch.Tensor] = None, 65 | past_key_values: Optional[List[torch.FloatTensor]] = None, 66 | labels: Optional[torch.LongTensor] = None, 67 | use_cache: Optional[bool] = None, 68 | output_attentions: Optional[bool] = None, 69 | output_hidden_states: Optional[bool] = None, 70 | return_dict: Optional[bool] = None, 71 | cache_position: Optional[ 72 | Cache 73 | ] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426 74 | ) -> Union[Tuple, CausalLMOutputWithPast]: 75 | 76 | if inputs_embeds is not None: 77 | log.warning("inputs_embeds is set but OpenLM does not support it yet") 78 | if attention_bias is not None: 79 | log.warning("attention_bias is set but OpenLM does not support it yet") 80 | if use_cache is None: 81 | use_cache = True 82 | if output_attentions: 83 | raise ValueError("output_attentions is not yet supported in OpenLM") 84 | if output_hidden_states: 85 | raise ValueError("output_hidden_states is not yet supported in OpenLM") 86 | 87 | # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn) 88 | # print("outer past_key_values: ", type(past_key_values)) 89 | # if past_key_values is not None: 90 | # print(len(past_key_values), type(past_key_values[0])) 91 | outputs = self.model.forward( 92 | input_ids=input_ids, 93 | attention_mask=attention_mask, 94 | past_key_values=past_key_values, 95 | use_cache=use_cache, 96 | ) 97 | 98 | logits = outputs[0] 99 | past_key_values = outputs[2] 100 | hidden_states = None 101 | 102 | loss = None 103 | if labels is not None: 104 | # Shift so that tokens < n predict n 105 | shift_logits = logits[..., :-1, :].contiguous() 106 | shift_labels = labels[..., 1:].contiguous() 107 | # Flatten the tokens 108 | loss_fct = torch.nn.CrossEntropyLoss() 109 | shift_logits = shift_logits.view(-1, self.model_config.vocab_size) 110 | shift_labels = shift_labels.view(-1) 111 | # Enable model parallelism 112 | shift_labels = shift_labels.to(shift_logits.device) 113 | loss = loss_fct(shift_logits, shift_labels) 114 | 115 | return CausalLMOutputWithPast( 116 | loss=loss, 117 | logits=logits, 118 | past_key_values=past_key_values, 119 | hidden_states=hidden_states, 120 | ) 121 | 122 | def can_generate(self) -> bool: 123 | return True 124 | 125 | def prepare_inputs_for_generation( 126 | self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs 127 | ): 128 | if past_key_values is not None: 129 | if isinstance(past_key_values[0][1], int): 130 | # This assumes that the second item of past key values is the length of the past (this is the case for linear attention) 131 | past_length = past_key_values[0][1] 132 | else: 133 | # This assumes that the first item of past key values is a list of all the past keys, thus the shape 1 is the length of the past (this is the case for attention without window) 134 | past_length = past_key_values[0][0].shape[1] 135 | 136 | # Some generation methods already pass only the last input ID 137 | if input_ids.shape[1] > past_length: 138 | remove_prefix_length = past_length 139 | else: 140 | # Default to old behavior: keep only final ID 141 | remove_prefix_length = 
input_ids.shape[1] - 1 142 | 143 | input_ids = input_ids[:, remove_prefix_length:] 144 | 145 | model_inputs = { 146 | "input_ids": input_ids, 147 | "past_key_values": past_key_values, 148 | "use_cache": kwargs.pop("use_cache", True), 149 | } 150 | return model_inputs 151 | 152 | def get_input_embeddings(self) -> torch.nn.Module: 153 | return self.model.tok_embeddings 154 | 155 | def set_input_embeddings(self, value: torch.nn.Module): 156 | self.model.tok_embeddings = value 157 | 158 | def get_output_embeddings(self): 159 | if self.model_config.weight_tying: 160 | return self.model.tok_embeddings 161 | else: 162 | return self.model.output 163 | 164 | def set_output_embeddings(self, value: torch.nn.Module): 165 | if self.model_config.weight_tying: 166 | self.model.tok_embeddings = value 167 | else: 168 | self.model.output = value 169 | 170 | def tie_weights(self): 171 | """ 172 | Copied from OLMo (description below). I removed it and the results just became garbage, so this pass is needed. 173 | This function is intentionally left as a no-op. 174 | Weight tying is handled as follows: 175 | - When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration. 176 | See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`. 177 | - When computing logits, the `wte` weights are used directly if `weight_tying` is enabled. 178 | See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method. 179 | Therefore, there is no need to explicitly tie the weights in this function. 180 | """ 181 | pass 182 | 183 | def resize_token_embeddings( 184 | self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None 185 | ) -> torch.nn.Embedding: 186 | raise NotImplementedError 187 | 188 | 189 | # Register the model so that it is available for transformer pipelines, auto-loading, etc. 
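# A hypothetical usage sketch (assumes a checkpoint directory whose config resolves to OpenLMConfig,
# e.g. one produced by the conversion utilities in open_lm/utils/transformers/; the path is illustrative):
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("/path/to/openlm_hf_checkpoint")
#   output_ids = model.generate(input_ids, max_new_tokens=32)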
190 | AutoModelForCausalLM.register(OpenLMConfig, OpenLMForCausalLM) 191 | -------------------------------------------------------------------------------- /open_lm/datapreprocess/make_assistant_data.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | import glob 3 | import tiktoken 4 | import os 5 | import threading 6 | from webdataset import ShardWriter 7 | import random 8 | import time 9 | import boto3 10 | import io 11 | import zstandard as zstd 12 | from contextlib import contextmanager 13 | import argparse 14 | from pathlib import Path 15 | from transformers import GPTNeoXTokenizerFast 16 | 17 | 18 | QUEUE_MAX = 10000 19 | BUFFER_MIN = 1000 20 | BUFFER_MAX = 200000 21 | CHUNK_SIZE = 2048 + 1 22 | SHARD_SIZE = 267 23 | SLEEP_TIME = 1 24 | S3_BUCKET = "s-laion" 25 | S3_SUFFIX = "validation_data_tokenized/" 26 | S3_BASE = f"s3://" 27 | 28 | eot_token = "<|endoftext|>" 29 | pad_token = "<|pad|>" 30 | 31 | 32 | def write_to_shard(chunks, shard_writer): 33 | for idx, chunk in enumerate(chunks): 34 | shard_writer.write({"__key__": f"{idx:12d}", "txt": str(chunk)}) 35 | 36 | 37 | def upload_to_s3_and_remove(fname): 38 | fname_split = fname.split("/") 39 | s3_path = S3_BASE + fname_split[-2] + "/" + fname_split[-1] 40 | cmd = f"aws s3 cp {fname} {s3_path} && rm {fname}" 41 | print("COMMAND:", cmd) 42 | os.system(cmd) 43 | 44 | 45 | @contextmanager 46 | def get_item_reader(file_name): 47 | if file_name.endswith(".jsonl"): 48 | with jsonlines.open(file_name) as reader: 49 | yield reader 50 | else: 51 | dctx = zstd.ZstdDecompressor() 52 | with open(file_name, "rb") as compressed_file: 53 | with dctx.stream_reader(compressed_file) as reader: 54 | with io.TextIOWrapper(reader, encoding="utf-8") as text_reader: 55 | with jsonlines.Reader(text_reader) as jsonl_reader: 56 | yield jsonl_reader 57 | 58 | 59 | def process_files(file_list, buffer, enc, buffer_lock): 60 | remaining_tokens = [] 61 | queue = [] 62 | 63 | def dump_queue_to_buffer(): 64 | with buffer_lock: 65 | while queue: 66 | buffer.append(queue.pop(0)) 67 | 68 | for file_name in file_list: 69 | print("Processing", file_name) 70 | 71 | with get_item_reader(file_name) as item_reader: 72 | for item in item_reader: 73 | string = item["text"] 74 | try: 75 | tokens = remaining_tokens + enc(string) + [eot_token] 76 | remaining_tokens = [] 77 | except: 78 | print("Failed to encode string.") 79 | continue 80 | 81 | for i in range(0, len(tokens), CHUNK_SIZE): 82 | chunk = tokens[i : i + CHUNK_SIZE] 83 | if len(chunk) < CHUNK_SIZE: 84 | remaining_tokens = chunk 85 | else: 86 | if len(buffer) > BUFFER_MAX: 87 | time.sleep(1) 88 | continue 89 | 90 | if buffer_lock.locked(): 91 | if len(queue) < QUEUE_MAX: 92 | queue.append(chunk) 93 | else: 94 | time.sleep(1) 95 | else: 96 | if queue: 97 | dump_queue_to_buffer() 98 | with buffer_lock: 99 | buffer.append(chunk) 100 | 101 | 102 | def consumer(my_id, output_dir, threads, buffer, buffer_lock, num_consumers, upload_to_s3=False): 103 | output_directory = f"{output_dir}/{CHUNK_SIZE - 1}-v1/{my_id}" 104 | os.makedirs(output_directory, exist_ok=True) 105 | shard_writer = ShardWriter(os.path.join(output_directory, "shard-%07d.tar"), maxcount=SHARD_SIZE) 106 | 107 | chunks = [] 108 | 109 | start_time = time.time() 110 | 111 | while any(t.is_alive() for t in threads): 112 | time.sleep(SLEEP_TIME) 113 | with buffer_lock: 114 | lenb = len(buffer) 115 | print("Length of buffer", lenb) 116 | if lenb >= BUFFER_MIN: 117 | while buffer and len(chunks) < 
SHARD_SIZE: 118 | random_index = random.randint(0, len(buffer) - 1) 119 | chunks.append(buffer[random_index]) 120 | buffer.pop(random_index) # Remove the selected element 121 | 122 | if len(chunks) == SHARD_SIZE: 123 | print(f"I am {my_id} and I am writing a shard.", len(buffer)) 124 | write_to_shard(chunks, shard_writer) 125 | # print("FNAME", shard_writer.fname) 126 | chunks = [] 127 | time_for_shard = time.time() - start_time 128 | print("shards / s", num_consumers / time_for_shard) 129 | print("tokens / s", num_consumers * SHARD_SIZE * CHUNK_SIZE / time_for_shard) 130 | print( 131 | "hours req for 1.2T tokens", 132 | 1_200_000_000_000 / (num_consumers * SHARD_SIZE * CHUNK_SIZE / time_for_shard) / 3600, 133 | ) 134 | 135 | start_time = time.time() 136 | 137 | # Process the remaining items in the buffer after all threads have completed 138 | while buffer: 139 | with buffer_lock: 140 | while buffer and len(chunks) < SHARD_SIZE: 141 | random_index = random.randint(0, len(buffer) - 1) 142 | chunks.append(buffer[random_index]) 143 | buffer.pop(random_index) # Remove the selected element 144 | 145 | write_to_shard(chunks, shard_writer) 146 | chunks = [] 147 | 148 | 149 | def tokenize_eleutherai(tokenizer, string): 150 | return tokenizer(string).input_ids 151 | 152 | 153 | def main( 154 | input_files, 155 | output_dir, 156 | tokenizer="EleutherAI/gpt-neox-20b", 157 | num_workers=32, 158 | num_consumers=8, 159 | upload_to_s3=False, 160 | ): 161 | os.makedirs(f"{output_dir}/tars-{CHUNK_SIZE - 1}-v1", exist_ok=True) 162 | 163 | input_files = [glob.glob(input_file) for input_file in input_files] 164 | input_files = [x for y in input_files for x in y] 165 | 166 | # Shuffle the input files 167 | random.shuffle(input_files) 168 | 169 | print("Input files", input_files) 170 | 171 | enc = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b") 172 | 173 | tokenize = lambda x: tokenize_eleutherai(enc, x) 174 | buffer = [] # Use list instead of queue.Queue 175 | buffer_lock = threading.Lock() 176 | 177 | files_per_worker = len(input_files) // num_workers 178 | threads = [] 179 | for i in range(num_workers): 180 | start = i * files_per_worker 181 | end = (i + 1) * files_per_worker if i < num_workers - 1 else len(input_files) 182 | t = threading.Thread( 183 | target=process_files, 184 | args=(input_files[start:end], buffer, tokenize, buffer_lock), 185 | ) 186 | t.start() 187 | threads.append(t) 188 | 189 | consumer_threads = [] 190 | for i in range(num_consumers): 191 | t = threading.Thread( 192 | target=consumer, 193 | args=( 194 | i, 195 | output_dir, 196 | threads, 197 | buffer, 198 | buffer_lock, 199 | num_consumers, 200 | upload_to_s3, 201 | ), 202 | ) 203 | t.start() 204 | consumer_threads.append(t) 205 | 206 | 207 | if __name__ == "__main__": 208 | parser = argparse.ArgumentParser() 209 | parser.add_argument("--input-files", type=str, nargs="+") 210 | parser.add_argument("--output-dir", type=Path) 211 | parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b") 212 | parser.add_argument("--num-workers", type=int, default=32) 213 | parser.add_argument("--num-consumers", type=int, default=8) 214 | parser.add_argument("--upload-to-s3", action="store_true") 215 | 216 | args = parser.parse_args() 217 | 218 | main( 219 | args.input_files, 220 | args.output_dir, 221 | args.tokenizer, 222 | args.num_workers, 223 | args.num_consumers, 224 | args.upload_to_s3, 225 | ) 226 | --------------------------------------------------------------------------------
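# A hypothetical invocation sketch for open_lm/datapreprocess/make_assistant_data.py; the input glob,
# output path, and worker counts below are illustrative, not taken from the repository:
#
#   python open_lm/datapreprocess/make_assistant_data.py \
#       --input-files "data/raw/*.jsonl.zst" \
#       --output-dir data/tokenized \
#       --num-workers 32 \
#       --num-consumers 8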