├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── build_jaxlib.yml │ ├── ci.yml │ ├── docs.yml │ ├── release_alpa.yml │ └── release_jaxlib.yml ├── .gitignore ├── .gitmodules ├── .pylintrc ├── .style.yapf ├── LICENSE ├── README.md ├── alpa ├── __init__.py ├── api.py ├── collective │ ├── __init__.py │ ├── collective.py │ ├── collective_group │ │ ├── __init__.py │ │ ├── base_collective_group.py │ │ ├── cuda_stream.py │ │ ├── gloo_collective_group.py │ │ ├── gloo_util.py │ │ ├── nccl_collective_group.py │ │ ├── nccl_util.py │ │ ├── xla_nccl_collective_group.py │ │ └── xla_nccl_util.py │ ├── const.py │ ├── requirements.txt │ ├── types.py │ ├── util.py │ ├── worker_nccl_util.py │ ├── worker_nccl_util_cupy.py │ └── worker_nccl_util_xla.py ├── create_state_parallel.py ├── data_loader.py ├── device_mesh.py ├── follow_parallel.py ├── global_env.py ├── mesh_executable.py ├── mesh_profiling.py ├── model │ ├── __init__.py │ ├── bert_model.py │ ├── conformer.py │ ├── gpt_model.py │ ├── model_util.py │ ├── moe.py │ ├── unet_2d.py │ └── wide_resnet.py ├── monkey_patch.py ├── parallel_method.py ├── parallel_plan.py ├── pipeline_parallel │ ├── __init__.py │ ├── apply_grad.py │ ├── compile_executable.py │ ├── computation.py │ ├── cross_mesh_resharding.py │ ├── layer_construction.py │ ├── layer_stats.py │ ├── local_pipeline.py │ ├── pipeshard_executable.py │ ├── primitive_def.py │ ├── resharding_tensor.py │ ├── runtime_emitter.py │ ├── schedules.py │ ├── stage_construction.py │ └── stage_profiling.py ├── serialization.py ├── serve │ ├── __init__.py │ ├── controller.py │ ├── http_util.py │ └── run.py ├── shard_parallel │ ├── __init__.py │ ├── auto_sharding.py │ ├── compile_executable.py │ └── manual_sharding.py ├── test_install.py ├── testing.py ├── timer.py ├── torch │ ├── __init__.py │ ├── nn │ │ ├── __init__.py │ │ └── utils.py │ ├── ops │ │ ├── __init__.py │ │ └── mapping.py │ ├── optim │ │ ├── __init__.py │ │ └── adam.py │ ├── tensor_utils.py │ └── trainer.py ├── util.py ├── version.py └── wrapped_hlo.py ├── benchmark ├── alpa │ ├── README.md │ ├── benchmark.py │ ├── benchmark_one_case.py │ ├── benchmark_one_case_gpt_bert.py │ ├── benchmark_one_case_gpt_bert_inference.py │ ├── benchmark_one_case_moe.py │ ├── benchmark_one_case_moe_inference.py │ ├── benchmark_one_case_unet.py │ ├── benchmark_one_case_wresnet.py │ ├── benchmark_parallel_utils.py │ ├── gather_gpu_stat.py │ ├── gen_prof_database.py │ ├── gen_serving_database.py │ ├── inspect_prof_database.py │ ├── resharding │ │ ├── README.md │ │ ├── benchmark.py │ │ ├── benchmark_cross_mesh_resharding.py │ │ └── suite.py │ ├── run_exp.py │ ├── suite_auto_gpt.py │ ├── suite_auto_moe.py │ ├── suite_inference_gpt.py │ ├── suite_inference_moe.py │ ├── suite_manual_gpt.py │ ├── suite_manual_moe.py │ ├── suite_unet.py │ ├── suite_wresnet.py │ └── util.py ├── cupy │ ├── profile_communication.py │ └── profile_matmul.py ├── deepspeed │ ├── README.md │ ├── benchmark_gpt2.py │ ├── benchmark_moe.py │ ├── ds_zero_stage_2_config.json │ ├── ds_zero_stage_2_moe_config.json │ ├── ds_zero_stage_3_config.json │ ├── hostfile │ ├── killall_python.sh │ ├── patch │ │ ├── gpt2_model.py │ │ ├── training.py │ │ └── transformer.py │ ├── pretrain_gpt2.py │ ├── pretrain_gpt2_moe.py │ ├── training.py │ └── util.py └── megatron │ ├── README.md │ ├── benchmark_gpt_bert.py │ ├── benchmark_gpt_bert_one_case.py │ ├── benchmark_mlp.py │ ├── benchmark_mlp_one_case.py │ ├── benchmark_transformer_layer.py │ ├── benchmark_transformer_layer_one_case.py 
│ └── util.py ├── build_jaxlib ├── .bazelrc ├── .bazelversion ├── WORKSPACE ├── build │ ├── BUILD.bazel │ ├── LICENSE.txt │ ├── build.py │ └── build_wheel.py ├── jax ├── jaxlib ├── release │ ├── README.md │ ├── generate_pypi_index.py │ └── wheel_upload.py ├── third_party └── update_build_scripts.patch ├── docker ├── README.md ├── build_alpa.Dockerfile ├── build_doc.Dockerfile ├── build_jaxlib.Dockerfile ├── coreweave │ ├── README.md │ ├── cluster.yaml │ └── run_alpa_infiniband.Dockerfile ├── run_alpa.Dockerfile ├── scripts │ ├── build_alpa.sh │ ├── build_doc.sh │ ├── build_jaxlib_docker_entrypoint.sh │ ├── install_cuda.sh │ ├── install_torch.sh │ └── test_alpa_docker_entrypoint.sh └── unittest.Dockerfile ├── docs ├── Makefile ├── README.md ├── architecture │ ├── alpa-arch.png │ ├── alpa_compiler_walk_through.rst │ ├── cluster-mesh.png │ ├── intra_op_solver.rst │ ├── mesh-worker.png │ ├── overview.rst │ └── parallelism-view-and-rationale.rst ├── benchmark │ ├── bench-paper.png │ └── benchmark.rst ├── cluster_setup.md ├── conf.py ├── developer │ └── developer_guide.rst ├── gallery │ └── tutorials │ │ ├── README.rst │ │ ├── advanced_api_usage.py_disable │ │ ├── alpa_vs_pmap.py │ │ ├── pipeshard_parallelism.py │ │ └── quickstart.py ├── index.rst ├── install.rst ├── logo │ ├── alpa-logo-cropped.png │ ├── alpa-logo-cropped.svg │ ├── alpa-logo-no-word.ico │ ├── alpa-logo-no-word.png │ ├── alpa-logo.ico │ ├── alpa-logo.jpg │ ├── alpa-logo.pdf │ ├── alpa-logo.png │ └── alpa-logo.psd ├── make.bat ├── publications │ └── publications.rst ├── publish.py └── tutorials │ ├── alpa_on_slurm.rst │ ├── icml_big_model_tutorial.rst │ ├── opt_serving.rst │ └── perf_tuning_guide.rst ├── examples ├── ViT │ ├── README.md │ └── run_image_classification.py ├── __init__.py ├── gpt2 │ ├── README.md │ ├── create_config.py │ ├── run_clm_flax.py │ └── train_tokenizer.py ├── imagenet │ ├── README.md │ ├── configs │ │ ├── default.py │ │ ├── fake_data_benchmark.py │ │ ├── tpu.py │ │ ├── v100_x8.py │ │ └── v100_x8_mixed_precision.py │ ├── input_pipeline.py │ ├── main.py │ ├── models.py │ └── train.py ├── llm_serving │ ├── README.rst │ ├── __init__.py │ ├── benchmark │ │ ├── benchmark_1d.py │ │ ├── benchmark_step_func.py │ │ └── benchmark_text_gen.py │ ├── client.py │ ├── codegen.py │ ├── generator.py │ ├── launch_model_worker.py │ ├── launch_website.py │ ├── log_config.yaml │ ├── model │ │ ├── __init__.py │ │ ├── bloom_model.py │ │ ├── codegen_model.py │ │ ├── opt_model.py │ │ ├── opt_model_1d.py │ │ ├── opt_utils.py │ │ ├── test_cache.py │ │ ├── wrapper.py │ │ └── wrapper_1d.py │ ├── scripts │ │ ├── step_2_consolidate_992_shards_to_singleton.py │ │ ├── step_3_convert_to_numpy_weights.py │ │ └── utils.py │ ├── service │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── recaptcha.py │ │ ├── scheduler.py │ │ ├── static │ │ │ ├── img.png │ │ │ └── index.html │ │ └── utils.py │ ├── test_completions.py │ ├── test_logprobs.py │ ├── test_textgen.sh │ ├── textgen.py │ └── textgen_1d.py ├── mnist │ ├── README.md │ ├── configs │ │ └── default.py │ ├── main.py │ ├── requirements.txt │ ├── train.py │ └── train_ray.py ├── opt_finetune │ ├── README.md │ ├── run_125m_shard.sh │ ├── run_2.7b_pipe.sh │ ├── run_2.7b_shard.sh │ └── run_clm_flax.py ├── setup.py └── slurm_script_examples │ ├── test_cuda.sh │ ├── test_prerequisites.sh │ ├── test_ray_multinode.sh │ ├── textgen_alpa_test.sh │ └── textgen_pt_test.sh ├── format.sh ├── playground ├── alpa_micro_benchmark │ ├── benchmark_dist_save_load.py │ ├── test_export_hlo.py │ └── test_shard_array.py 
├── auto_sharding_solver │ ├── README.md │ ├── cluster_env.py │ ├── common.py │ ├── hlo.py │ ├── run_all.sh │ ├── solver.py │ ├── test_cost.py │ ├── test_sharding_spec.py │ ├── test_solver_attention.py │ └── test_solver_mlp.py ├── jax_basic │ ├── slice_jaxpr.ipynb │ ├── test_device_put.py │ ├── test_flop_count.py │ ├── test_jit.py │ ├── test_matmul_pmap.py │ ├── test_memory_allocator.py │ ├── test_mixed_precision.py │ ├── test_pjit.py │ ├── test_pmap.py │ ├── test_scan.py │ ├── test_sharding_spec.py │ ├── test_tuple_args.py │ ├── test_while.py │ ├── test_xmap.py │ └── util.py ├── other │ ├── input_pipeline.py │ ├── test_cupy_partial_transfer.py │ ├── test_ray_dataloader.py │ ├── test_ray_put.py │ ├── test_remote_call_cost.py │ ├── test_torch_ddp.py │ └── test_torch_trace.py ├── pipeline │ ├── auto_pipeline_slicing_dp.ipynb │ ├── jax_array_slicing.py │ ├── mesh_slicing.ipynb │ ├── profile_compilation.py │ ├── test_acc_grad.py │ ├── test_compile_and_profile.py │ ├── test_distributed_compile.py │ ├── test_generate_schedule.py │ ├── test_pipeline_mlp_distributed.py │ └── test_ray_jax_array.py └── xla_builder │ ├── test_multi_host.py │ └── test_xla_builder.py ├── setup.py ├── tests ├── README.md ├── __init__.py ├── killall_python.sh ├── pipeline_parallel │ ├── test_bert.py │ ├── test_cross_mesh_resharding.py │ ├── test_dynamic_programming.py │ ├── test_global_norm.py │ ├── test_inference_auto.py │ ├── test_inference_only.py │ ├── test_layer_construction.py │ ├── test_manual_sharding.py │ ├── test_mlp.py │ ├── test_multi_graph.py │ ├── test_old_dp_vs_new_dp.py │ ├── test_pipeline_marker.py │ ├── test_reduce_scatter.py │ ├── test_remat.py │ ├── test_scatter_gather.py │ ├── test_schedules.py │ ├── test_set_input_shard.py │ ├── test_stage_construction.py │ ├── test_stage_construction_slow.py │ ├── test_stage_construction_util.py │ └── test_tied_embedding.py ├── run_all.py ├── runtime │ ├── test_create_state.py │ ├── test_cross_mesh_communicator.py │ ├── test_data_loader.py │ ├── test_debug_info.py │ ├── test_device_mesh.py │ ├── test_dist_save_load.py │ ├── test_follow_parallel.py │ ├── test_install.py │ ├── test_memory_leak.py │ ├── test_parallel_plan.py │ ├── test_random_seed.py │ ├── test_save_load.py │ ├── test_tracing.py │ └── test_xla_nccl.py ├── serve │ └── test_controller.py ├── shard_parallel │ ├── test_basic.py │ ├── test_bert.py │ ├── test_conv.py │ ├── test_gradient_accumulation.py │ ├── test_manual.py │ ├── test_mixed_2d.py │ ├── test_mlp.py │ ├── test_moe.py │ └── test_numerical_correctness.py ├── torch_frontend │ ├── test_dict_input.py │ ├── test_reshape.py │ ├── test_simple.py │ └── test_zhen.py ├── tpu │ ├── test_create_state_parallel.py │ ├── test_follow_parallel.py │ └── test_shard_parallel.py └── util │ ├── test_hlo_cost_model.py │ └── test_ordered_set.py └── update_version.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve Alpa 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Please describe the bug** 11 | 12 | **Please describe the expected behavior** 13 | 14 | **System information and environment** 15 | - OS Platform and Distribution (e.g., Linux Ubuntu 16.04, docker): 16 | - Python version: 17 | - CUDA version: 18 | - NCCL version: 19 | - cupy version: 20 | - GPU model and memory: 21 | - Alpa version: 22 | - TensorFlow version: 23 | - JAX version: 24 | 25 | **To Reproduce** 26 | Steps to reproduce the 
behavior: 27 | 1. 28 | 2. 29 | 3. 30 | 4. See error 31 | 32 | **Screenshots** 33 | If applicable, add screenshots to help explain your problem. 34 | 35 | **Code snippet to reproduce the problem** 36 | 37 | **Additional information** 38 | Add any other context about the problem here or include any logs that would be helpful to diagnose the problem. -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest a new feature for Alpa 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **System information** 11 | - Alpa version: 12 | - Are you willing to contribute it (Yes/No): 13 | 14 | **Describe the new feature and the current behavior/state** 15 | 16 | **Will this change the current API? How?** 17 | 18 | **Describe alternatives you've considered** 19 | 20 | **Additional context** -------------------------------------------------------------------------------- /.github/workflows/build_jaxlib.yml: -------------------------------------------------------------------------------- 1 | name: Build Jaxlib 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | tensorflow: 7 | description: 'TensorFlow-alpa branch to build' 8 | required: true 9 | default: 'master' 10 | 11 | 12 | env: 13 | TF_BRANCH: ${{ github.event.inputs.tensorflow }} 14 | 15 | 16 | jobs: 17 | build_jaxlib: 18 | name: Build JaxLib wheels 19 | runs-on: [self-hosted] 20 | # change the following to build with 21 | # Python: 3.7, 3.8. 3.9 22 | # CUDA 11.1, 11.2, 11.3 23 | # Using github matrix 24 | 25 | steps: 26 | - name: Cancel previous 27 | uses: styfle/cancel-workflow-action@0.9.1 28 | with: 29 | access_token: ${{ secrets.PAT_TOKEN }} 30 | if: ${{github.ref != 'refs/head/main'}} 31 | 32 | # checkout repo 33 | - uses: actions/checkout@v3 34 | 35 | - name: clean up images 36 | run: | 37 | docker image prune -f 38 | 39 | - name: build image 40 | run: | 41 | docker build -t build-jaxlib-image -f docker/build_jaxlib.Dockerfile docker/ 42 | 43 | - name: Compile Jaxlib 44 | run: | 45 | mkdir -p dist 46 | docker run --gpus all --tmpfs /build:exec \ 47 | --rm -v $(pwd)/dist:/dist build-jaxlib-image \ 48 | 3.8 cuda 11.1 main ${TF_BRANCH##*/} 49 | 50 | # change this to publishing to pypi 51 | - name: Publish to local 52 | run: | 53 | echo "Move the Jaxlib binary" 54 | mv dist/*.whl /data/alpa-dist/jaxlib-alpa-ci/ 55 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | workflow_run: 5 | workflows: [Build Jaxlib and Jax] 6 | types: 7 | - completed 8 | workflow_dispatch: 9 | push: 10 | branches: [main] 11 | pull_request: 12 | branches: [main] 13 | 14 | jobs: 15 | yapf: 16 | runs-on: ubuntu-latest 17 | strategy: 18 | matrix: 19 | python-version: ["3.7"] 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install yapf==0.32.0 30 | - name: Running yapf 31 | run: | 32 | yapf --diff --style .style.yapf --recursive alpa && yapf --diff --style .style.yapf --recursive tests 33 | 34 | pylint: 35 | runs-on: ubuntu-latest 36 | strategy: 37 | matrix: 
38 | python-version: ["3.7"] 39 | steps: 40 | - uses: actions/checkout@v2 41 | - name: Set up Python ${{ matrix.python-version }} 42 | uses: actions/setup-python@v2 43 | with: 44 | python-version: ${{ matrix.python-version }} 45 | - name: Install dependencies 46 | run: | 47 | python -m pip install --upgrade pip 48 | pip install pylint==2.14.0 49 | - name: Analysing the code with pylint 50 | run: | 51 | pylint alpa 52 | 53 | Unittest: 54 | runs-on: [self-hosted, gpu] 55 | needs: [yapf, pylint] 56 | steps: 57 | - name: Cancel previous 58 | uses: styfle/cancel-workflow-action@0.9.1 59 | with: 60 | access_token: ${{ secrets.PAT_TOKEN }} 61 | if: | 62 | github.event_name =='pull_request' && 63 | github.event.pull_request.head.repo.full_name == github.repository 64 | 65 | - uses: actions/checkout@v3 66 | 67 | - name: clean up images 68 | run: | 69 | docker image prune -f 70 | 71 | - name: build test image 72 | run: | 73 | docker build -t test-alpa-image -f docker/unittest.Dockerfile docker/ 74 | 75 | - name: Test 76 | run: | 77 | ALPA_BRANCH=${{ github.ref }} 78 | echo "${ALPA_BRANCH}" 79 | 80 | docker run --gpus all --tmpfs /build:exec --rm \ 81 | -v /data/alpa-dist:/alpa-dist \ 82 | --shm-size=10.24gb test-alpa-image 3.8 ${ALPA_BRANCH} 83 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | # This workflow will generate docs for alpa. 2 | 3 | name: Docs 4 | 5 | on: 6 | workflow_dispatch: 7 | 8 | jobs: 9 | build_docs: 10 | runs-on: [self-hosted, alpa] 11 | 12 | steps: 13 | - uses: actions/checkout@v3 14 | 15 | - name: Set up Python 3.8 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: 3.8 19 | 20 | - name: build doc-building image 21 | run: | 22 | docker build -t build-alpa-doc -f docker/build_doc.Dockerfile docker/ 23 | 24 | - name: Build docs 25 | run: | 26 | docker run --gpus all --tmpfs /build:exec --rm \ 27 | -v /data/alpa-dist:/alpa-dist \ 28 | --shm-size=10.24gb \ 29 | build-alpa-doc 30 | 31 | - name: Deploy 32 | uses: peaceiris/actions-gh-pages@v3 33 | with: 34 | personal_token: ${{ secrets.PAT_TOKEN }} 35 | external_repository: alpa-projects/alpa-projects.github.io 36 | publish_branch: master 37 | publish_dir: /data/alpa-dist/docs 38 | keep_files: true 39 | -------------------------------------------------------------------------------- /.github/workflows/release_alpa.yml: -------------------------------------------------------------------------------- 1 | name: Release Alpa 2 | 3 | on: 4 | release: 5 | types: [created] 6 | workflow_dispatch: 7 | 8 | env: 9 | TWINE_USERNAME: "__token__" 10 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 11 | 12 | jobs: 13 | 14 | build-image: 15 | runs-on: [self-hosted] 16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | 20 | - name: clean up images 21 | run: | 22 | docker image prune -f 23 | 24 | - name: build docker image 25 | run: | 26 | docker build -t build-alpa-image -f docker/build_alpa.Dockerfile docker/ 27 | 28 | release-alpa: 29 | runs-on: [self-hosted] 30 | needs: [build-image] 31 | 32 | steps: 33 | - uses: actions/checkout@v3 34 | 35 | - name: Build Alpa wheels 36 | run: | 37 | mkdir -p dist 38 | docker run --gpus all --tmpfs /build:exec \ 39 | --rm -v $(pwd)/dist:/dist --entrypoint /build_alpa.sh \ 40 | build-alpa-image 3.8 ${ALPA_BRANCH} 41 | env: 42 | ALPA_BRANCH: ${{ github.ref }} 43 | 44 | - name: Set up Python 3.8 45 | uses: actions/setup-python@v3 46 | with: 47 | 
python-version: 3.8 48 | 49 | - name: Install dependencies 50 | run: | 51 | python -m pip install --upgrade pip 52 | pip install twine 53 | 54 | - name: Publish to Pypi 55 | run: | 56 | echo "Publish to PyPI" 57 | ls -ltr dist/ 58 | python -m twine upload --verbose dist/* 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python cache 2 | __pycache__ 3 | *.pyc 4 | dist 5 | *.egg-info 6 | .cache 7 | *env 8 | 9 | # NFS temp files 10 | .nfs* 11 | 12 | # Vim 13 | *.swp 14 | 15 | # pycharm 16 | .idea 17 | 18 | # vscode 19 | *vscode* 20 | 21 | # Build files 22 | alpa/pipeline_parallel/xla_custom_call_marker/build 23 | build/lib 24 | build/bdist* 25 | build_jaxlib/build/bazel* 26 | build_jaxlib/bazel-* 27 | build_jaxlib/.jax_configure.bazelrc 28 | build_jaxlib/dist 29 | 30 | # Examples build and tmp files 31 | examples/build/ 32 | examples/imagenet/imagenet 33 | examples/llm_serving/dataset/*.so 34 | examples/llm_serving/dataset/*.c 35 | examples/llm_serving/dataset/*.cpp 36 | examples/llm_serving/weblogs 37 | examples/llm_serving/keys_file.json 38 | examples/llm_serving/benchmark/tmp* 39 | examples/llm_serving/tmp* 40 | examples/opt_finetune/output/ 41 | examples/gpt2/norwegian-gpt2/ 42 | alpa_debug_info 43 | 44 | # Analysis temp files 45 | *.nvprof 46 | *.prof 47 | *.tsv 48 | *.hlo 49 | *.pkl 50 | benchmark/alpa/tmp* 51 | benchmark/alpa/chrome_trace 52 | *.log 53 | 54 | # Tests temp files 55 | tests/tmp 56 | tests/*/tmp 57 | 58 | # Dataset 59 | benchmark/deepspeed/data 60 | 61 | # plots 62 | benchmark/*.pdf 63 | 64 | # Numpy cache 65 | *.npy 66 | 67 | # Documentation website build 68 | docs/_build 69 | docs/tutorials 70 | 71 | # macOS temp files 72 | .DS_Store 73 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/jax"] 2 | path = third_party/jax 3 | url = https://github.com/google/jax.git 4 | [submodule "third_party/tensorflow-alpa"] 5 | path = third_party/tensorflow-alpa 6 | url = https://github.com/alpa-projects/tensorflow-alpa.git 7 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = google 3 | -------------------------------------------------------------------------------- /alpa/__init__.py: -------------------------------------------------------------------------------- 1 | """Alpa is a system for training large-scale neural networks.""" 2 | # Import all public packages 3 | from . import api 4 | from . import collective 5 | from . import create_state_parallel 6 | from . import data_loader 7 | from . import device_mesh 8 | from . import follow_parallel 9 | from . import global_env 10 | from . import mesh_executable 11 | from . import mesh_profiling 12 | from . import monkey_patch 13 | from . import parallel_method 14 | from . import parallel_plan 15 | from . import pipeline_parallel 16 | from . import shard_parallel 17 | from . import timer 18 | from . import util 19 | from . import version 20 | from . 
import wrapped_hlo 21 | 22 | # Short cuts 23 | from alpa.api import (init, shutdown, parallelize, grad, value_and_grad, 24 | clear_executable_cache) 25 | from alpa.data_loader import DataLoader, MeshDriverDataLoader 26 | from alpa.device_mesh import ( 27 | DeviceCluster, PhysicalDeviceMesh, LocalPhysicalDeviceMesh, 28 | DistributedPhysicalDeviceMesh, DistributedArray, prefetch, 29 | get_global_cluster, get_global_physical_mesh, 30 | get_global_virtual_physical_mesh, set_global_virtual_physical_mesh, 31 | set_seed, get_global_num_devices) 32 | from alpa.global_env import global_config 33 | from alpa.mesh_profiling import ProfilingResultDatabase 34 | from alpa.parallel_method import (ShardParallel, DataParallel, Zero2Parallel, 35 | Zero3Parallel, PipeshardParallel, 36 | CreateStateParallel, FollowParallel, 37 | get_3d_parallel_method) 38 | from alpa.parallel_plan import plan_to_method 39 | from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary 40 | from alpa.pipeline_parallel.layer_construction import (manual_remat, 41 | automatic_remat, 42 | ManualLayerOption, 43 | AutoLayerOption) 44 | from alpa.pipeline_parallel.stage_construction import (ManualStageOption, 45 | AutoStageOption, 46 | UniformStageOption) 47 | from alpa.shard_parallel.auto_sharding import AutoShardingOption 48 | from alpa.shard_parallel.manual_sharding import ManualShardingOption 49 | from alpa.serialization import save_checkpoint, restore_checkpoint 50 | from alpa.timer import timers 51 | from alpa.version import __version__ 52 | -------------------------------------------------------------------------------- /alpa/collective/__init__.py: -------------------------------------------------------------------------------- 1 | """Alpa's wrapper for NCCL collective operations.""" 2 | 3 | from alpa.collective.collective import ( 4 | nccl_available, gloo_available, is_group_initialized, init_collective_group, 5 | destroy_collective_group, create_collective_group, get_rank, 6 | get_collective_group_size, allreduce, allreduce_multigpu, barrier, reduce, 7 | reduce_multigpu, broadcast, broadcast_partialgpu, broadcast_multigpu, 8 | allgather, allgather_multigpu, reducescatter, reducescatter_multigpu, send, 9 | send_multigpu, recv, recv_multigpu, check_and_get_group, record_events, 10 | wait_events, comm_wait_compute, compute_wait_comm) 11 | 12 | __all__ = [ 13 | "nccl_available", "gloo_available", "is_group_initialized", 14 | "init_collective_group", "destroy_collective_group", 15 | "create_collective_group", "get_rank", "get_collective_group_size", 16 | "allreduce", "allreduce_multigpu", "barrier", "reduce", "reduce_multigpu", 17 | "broadcast", "broadcast_partialgpu", "broadcast_multigpu", "allgather", 18 | "allgather_multigpu", "reducescatter", "reducescatter_multigpu", "send", 19 | "send_multigpu", "recv", "recv_multigpu", "check_and_get_group", 20 | "record_events", "wait_events", "comm_wait_compute", "compute_wait_comm" 21 | ] 22 | -------------------------------------------------------------------------------- /alpa/collective/collective_group/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/alpa/collective/collective_group/__init__.py -------------------------------------------------------------------------------- /alpa/collective/collective_group/xla_nccl_util.py: -------------------------------------------------------------------------------- 1 | """Code to wrap NCCL API 
calls from XLA extension.""" 2 | from jax._src.lib import xla_extension as xe 3 | 4 | 5 | def get_nccl_runtime_version(): 6 | return xe.nccl_get_version() 7 | 8 | 9 | def get_nccl_unique_id(): 10 | return xe.nccl_get_unique_id() 11 | -------------------------------------------------------------------------------- /alpa/collective/const.py: -------------------------------------------------------------------------------- 1 | """ 2 | Constants. 3 | 4 | Contains constants used to setup collective groups. 5 | """ 6 | import hashlib 7 | import os 8 | from enum import Enum, auto 9 | 10 | 11 | def get_store_name(group_name): 12 | """Generate the unique name for the NCCLUniqueID store (named actor). 13 | 14 | Args: 15 | group_name (str): unique user name for the store. 16 | Return: 17 | str: MD5-hexlified name for the store. 18 | """ 19 | if not group_name: 20 | raise ValueError("group_name is None.") 21 | hexlified_name = hashlib.md5(group_name.encode()).hexdigest() 22 | return hexlified_name 23 | 24 | 25 | class ENV(Enum): 26 | """Environment variables.""" 27 | 28 | NCCL_USE_MULTISTREAM = auto(), lambda v: (v or "True") == "True" 29 | 30 | @property 31 | def val(self): 32 | """Return the output of the lambda against the system's env value.""" 33 | _, default_fn = self.value # pylint: disable=unpacking-non-sequence 34 | return default_fn(os.getenv(self.name)) 35 | -------------------------------------------------------------------------------- /alpa/collective/requirements.txt: -------------------------------------------------------------------------------- 1 | cupy-cuda111 -------------------------------------------------------------------------------- /alpa/collective/types.py: -------------------------------------------------------------------------------- 1 | """Types conversion between different backends.""" 2 | from enum import Enum 3 | from dataclasses import dataclass 4 | from datetime import timedelta 5 | 6 | _NUMPY_AVAILABLE = True 7 | _TORCH_AVAILABLE = False 8 | _CUPY_AVAILABLE = True 9 | 10 | try: 11 | import cupy as cp # pylint: disable=unused-import 12 | except ImportError: 13 | _CUPY_AVAILABLE = False 14 | 15 | 16 | def cupy_available(): 17 | return _CUPY_AVAILABLE 18 | 19 | 20 | def torch_available(): 21 | return _TORCH_AVAILABLE 22 | 23 | 24 | class Backend: 25 | """A class to represent different backends.""" 26 | NCCL = "nccl" 27 | MPI = "mpi" 28 | GLOO = "gloo" 29 | UNRECOGNIZED = "unrecognized" 30 | 31 | def __new__(cls, name: str): 32 | backend = getattr(Backend, name.upper(), Backend.UNRECOGNIZED) 33 | if backend == Backend.UNRECOGNIZED: 34 | raise ValueError(f"Unrecognized backend: '{name}'. 
" 35 | "Only NCCL is supported") 36 | if backend == Backend.MPI: 37 | raise RuntimeError("Ray does not support MPI backend.") 38 | return backend 39 | 40 | 41 | class ReduceOp(Enum): 42 | SUM = 0 43 | PRODUCT = 1 44 | MIN = 2 45 | MAX = 3 46 | 47 | 48 | unset_timeout_ms = timedelta(milliseconds=-1) 49 | 50 | 51 | @dataclass 52 | class AllReduceOptions: 53 | reduce_op = ReduceOp.SUM 54 | timeout_ms = unset_timeout_ms 55 | 56 | 57 | @dataclass 58 | class BarrierOptions: 59 | timeout_ms = unset_timeout_ms 60 | 61 | 62 | @dataclass 63 | class ReduceOptions: 64 | reduce_op = ReduceOp.SUM 65 | root_rank = 0 66 | root_tensor = 0 # index for multi-gpu reduce operations 67 | timeout_ms = unset_timeout_ms 68 | 69 | 70 | @dataclass 71 | class AllGatherOptions: 72 | timeout_ms = unset_timeout_ms 73 | 74 | 75 | # 76 | # @dataclass 77 | # class GatherOptions: 78 | # root_rank = 0 79 | # timeout = unset_timeout 80 | 81 | 82 | @dataclass 83 | class BroadcastOptions: 84 | comm_key = "" 85 | world_size = 0 86 | devices_ids = [] 87 | devices_global_rank = [] 88 | n_elements = 0 89 | timeout_ms = unset_timeout_ms 90 | local_start_pos_list = [] 91 | 92 | 93 | @dataclass 94 | class ReduceScatterOptions: 95 | reduce_op = ReduceOp.SUM 96 | timeout_ms = unset_timeout_ms 97 | 98 | 99 | @dataclass 100 | class SendOptions: 101 | dst_rank = 0 102 | dst_gpu_index = 0 103 | n_elements = 0 104 | timeout_ms = unset_timeout_ms 105 | start_pos = 0 106 | 107 | 108 | @dataclass 109 | class RecvOptions: 110 | src_rank = 0 111 | src_gpu_index = 0 112 | n_elements = 0 113 | unset_timeout_ms = unset_timeout_ms 114 | start_pos = 0 115 | -------------------------------------------------------------------------------- /alpa/collective/util.py: -------------------------------------------------------------------------------- 1 | """Some utility class for Collectives.""" 2 | import logging 3 | import ray 4 | 5 | logger = logging.getLogger(__name__) 6 | logger.setLevel(logging.DEBUG) 7 | 8 | 9 | @ray.remote 10 | class NCCLUniqueIDStore: 11 | """NCCLUniqueID Store as a named actor class. 12 | 13 | Args: 14 | name (str): the unique name for this named actor. 15 | 16 | Attributes: 17 | name (str): the unique name for this named actor. 18 | nccl_id (str): the NCCLUniqueID held in this store. 19 | """ 20 | 21 | def __init__(self, name): 22 | self.name = name 23 | self.nccl_id = None 24 | 25 | # A counter for this actor to auto-destory itself. 26 | self.access_counter = 1 27 | 28 | def set_id(self, uid): 29 | """ 30 | Initialize the NCCL unique ID for this store. 31 | 32 | Args: 33 | uid (str): the unique ID generated via the NCCL get_unique_id API. 34 | 35 | Returns: 36 | None 37 | """ 38 | self.nccl_id = uid 39 | return self.nccl_id 40 | 41 | def get_id(self): 42 | """Get the NCCL unique ID held in this store.""" 43 | if not self.nccl_id: 44 | logger.debug("The NCCL ID has not been set yet " 45 | f"for store {self.name} by rank-0 process.") 46 | return None 47 | else: 48 | self.access_counter += 1 49 | return self.nccl_id 50 | 51 | def get_access_counter(self): 52 | return self.access_counter 53 | 54 | 55 | @ray.remote 56 | class Info: 57 | """Store the group information created via `create_collective_group`. 58 | 59 | Note: Should be used as a NamedActor. 
60 | """ 61 | 62 | def __init__(self): 63 | self.ids = None 64 | self.world_size = -1 65 | self.rank = -1 66 | self.backend = None 67 | self.access_counter = 0 68 | 69 | def set_info(self, ids, world_size, rank, backend): 70 | """Store collective information.""" 71 | self.ids = ids 72 | self.world_size = world_size 73 | self.rank = rank 74 | self.backend = backend 75 | 76 | def get_info(self): 77 | """Get previously stored collective information.""" 78 | self.access_counter += 1 79 | return self.ids, self.world_size, self.rank, self.backend 80 | 81 | def get_access_counter(self): 82 | return self.access_counter 83 | -------------------------------------------------------------------------------- /alpa/collective/worker_nccl_util.py: -------------------------------------------------------------------------------- 1 | """Unified Nccl APIs for cross-mesh resharding.""" 2 | from typing import Sequence 3 | 4 | import alpa.collective.worker_nccl_util_cupy as cupy_impl 5 | import alpa.collective.worker_nccl_util_xla as xla_impl 6 | from alpa.global_env import global_config 7 | 8 | 9 | def _switch_impl(cupy_fn, xla_fn, *args): 10 | if global_config.nccl_mode == "cupy": 11 | return cupy_fn(*args) 12 | elif global_config.nccl_mode == "xla_extension": 13 | return xla_fn(*args) 14 | else: 15 | raise ValueError(f"nccl mode {global_config.nccl_mode} is illegal") 16 | 17 | 18 | def send_tile(worker, uuid: int, device_id: int, offset: Sequence[slice], 19 | dst_rank: int, dst_gpu_idx: int, group_name: str): 20 | return _switch_impl(cupy_impl.send_tile, xla_impl.send_tile, worker, uuid, 21 | device_id, offset, dst_rank, dst_gpu_idx, group_name) 22 | 23 | 24 | def recv_tile(worker, uuid: int, device_id: int, 25 | indices_in_dst_tile: Sequence[slice], src_rank: int, 26 | src_gpu_idx: int, group_name: str): 27 | return _switch_impl(cupy_impl.recv_tile, xla_impl.recv_tile, worker, uuid, 28 | device_id, indices_in_dst_tile, src_rank, src_gpu_idx, 29 | group_name) 30 | 31 | 32 | def broadcast(worker, uuid: int, comm_key: str, world_size: int, 33 | devices_ids: Sequence[int], devices_global_rank: Sequence[int], 34 | tensor_slices: Sequence[Sequence[slice]], group_name: str): 35 | return _switch_impl(cupy_impl.broadcast, xla_impl.broadcast, worker, uuid, 36 | comm_key, world_size, devices_ids, devices_global_rank, 37 | tensor_slices, group_name) 38 | 39 | 40 | def allgather(worker, uuid: int, device_ids: Sequence[int], 41 | tensor_slices: Sequence[Sequence[slice]], output_slice): 42 | return _switch_impl(cupy_impl.allgather, xla_impl.allgather, worker, uuid, 43 | device_ids, tensor_slices, output_slice) 44 | 45 | 46 | def to_signal_buffer(jax_tensor): 47 | return _switch_impl(cupy_impl.to_signal_buffer, xla_impl.to_signal_buffer, 48 | jax_tensor) 49 | -------------------------------------------------------------------------------- /alpa/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/alpa/model/__init__.py -------------------------------------------------------------------------------- /alpa/parallel_plan.py: -------------------------------------------------------------------------------- 1 | """ 2 | The data strcutures to save all configurations/strategies of 3 | a parallel execution plan. 
4 | """ 5 | from dataclasses import dataclass 6 | from typing import Sequence, Tuple 7 | 8 | import numpy as np 9 | from jax.core import ShapedArray 10 | from jax.interpreters import pxla 11 | 12 | 13 | @dataclass 14 | class PlacementSpec: 15 | """Specify how a tensor is stored distributedly.""" 16 | aval: ShapedArray 17 | mesh_ids: Sequence[int] 18 | sharding_specs: Sequence[pxla.ShardingSpec] 19 | 20 | 21 | @dataclass 22 | class StagePlan: 23 | """The parallel plan for a single sharded stage.""" 24 | build_random_seed: int 25 | logical_mesh_shape: Tuple[int] 26 | all_gather_threshold: int 27 | all_reduce_threshold: int 28 | auto_sharding_option: "AutoShardingOption" 29 | auto_sharding_solution_vector: np.ndarray 30 | auto_sharding_objective: int 31 | 32 | 33 | @dataclass 34 | class PipelinePlan: 35 | """The parallel plan for a pipeline.""" 36 | pipeline_schedule: str 37 | layer_option: "LayerOption" 38 | manual_stage_option: "ManualStageOption" 39 | 40 | 41 | @dataclass 42 | class ClusterInfo: 43 | num_hosts: int 44 | num_devices_per_host: int 45 | 46 | 47 | @dataclass 48 | class ParallelPlan: 49 | """The global parallel plan.""" 50 | cluster_info: ClusterInfo 51 | num_micro_batches: int 52 | auto_sharding_option: "AutoShardingOption" 53 | pipeline_plan: PipelinePlan 54 | input_placement_specs: Sequence[PlacementSpec] 55 | 56 | 57 | def plan_to_method(plan: ParallelPlan) -> "ParallelMethod": 58 | """Convert a parallel plan to a parallel method.""" 59 | # pylint: disable=import-outside-toplevel 60 | from alpa.parallel_method import ShardParallel, PipeshardParallel 61 | 62 | if plan.pipeline_plan is None: 63 | return ShardParallel(num_micro_batches=plan.num_micro_batches, 64 | auto_sharding_option=plan.auto_sharding_option) 65 | else: 66 | return PipeshardParallel( 67 | num_micro_batches=plan.num_micro_batches, 68 | default_auto_sharding_option=plan.auto_sharding_option, 69 | pipeline_schedule=plan.pipeline_plan.pipeline_schedule, 70 | layer_option=plan.pipeline_plan.layer_option, 71 | stage_option=plan.pipeline_plan.manual_stage_option) 72 | -------------------------------------------------------------------------------- /alpa/pipeline_parallel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/alpa/pipeline_parallel/__init__.py -------------------------------------------------------------------------------- /alpa/serve/__init__.py: -------------------------------------------------------------------------------- 1 | """Alpa serving backend""" 2 | from alpa.serve.controller import CONTROLLER_NAME, run_controller 3 | -------------------------------------------------------------------------------- /alpa/serve/run.py: -------------------------------------------------------------------------------- 1 | """Run a controller.""" 2 | import argparse 3 | 4 | import ray 5 | 6 | from alpa.serve.controller import run_controller 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--host", type=str, default="localhost") 11 | parser.add_argument("--port", type=int) 12 | parser.add_argument("--root-path", type=str, default="/") 13 | args = parser.parse_args() 14 | 15 | ray.init(address="auto", namespace="alpa_serve") 16 | controller = run_controller(args.host, args.port, args.root_path) 17 | 18 | while True: 19 | pass 20 | -------------------------------------------------------------------------------- 
/alpa/shard_parallel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/alpa/shard_parallel/__init__.py -------------------------------------------------------------------------------- /alpa/test_install.py: -------------------------------------------------------------------------------- 1 | """Some basic tests to test installation.""" 2 | import os 3 | import unittest 4 | 5 | from alpa import (init, parallelize, ShardParallel, PipeshardParallel, 6 | AutoLayerOption, prefetch) 7 | from alpa.device_mesh import get_global_cluster 8 | from alpa.testing import assert_allclose, get_mlp_train_state_and_step 9 | 10 | 11 | class InstallationTest(unittest.TestCase): 12 | 13 | def setUp(self): 14 | os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" 15 | 16 | def test_1_shard_parallel(self): 17 | state, batch, train_step = get_mlp_train_state_and_step(batch_size=128, 18 | hidden_size=128, 19 | num_layers=4) 20 | 21 | # Serial execution 22 | expected_output = train_step(state, batch) 23 | 24 | # Parallel execution 25 | p_train_step = parallelize(train_step, 26 | method=ShardParallel(num_micro_batches=2)) 27 | actual_output = p_train_step(state, batch) 28 | 29 | # Check results 30 | assert_allclose(expected_output, actual_output) 31 | 32 | def test_2_pipeline_parallel(self): 33 | init(cluster="ray") 34 | 35 | state, batch, train_step = get_mlp_train_state_and_step(batch_size=128, 36 | hidden_size=128, 37 | num_layers=6) 38 | 39 | # Serial execution 40 | expected_output = train_step(state, batch) 41 | 42 | # Parallel execution 43 | layer_num = min(get_global_cluster().num_devices, 2) 44 | p_train_step = parallelize( 45 | train_step, 46 | method=PipeshardParallel( 47 | num_micro_batches=2, 48 | layer_option=AutoLayerOption(layer_num=layer_num))) 49 | actual_output = p_train_step(state, batch) 50 | 51 | # Check results 52 | prefetch(actual_output) 53 | assert_allclose(expected_output, actual_output) 54 | 55 | 56 | def suite(): 57 | s = unittest.TestSuite() 58 | s.addTest(InstallationTest("test_1_shard_parallel")) 59 | s.addTest(InstallationTest("test_2_pipeline_parallel")) 60 | return s 61 | 62 | 63 | if __name__ == "__main__": 64 | runner = unittest.TextTestRunner() 65 | runner.run(suite()) 66 | -------------------------------------------------------------------------------- /alpa/timer.py: -------------------------------------------------------------------------------- 1 | """Global timer for profiling.""" 2 | from collections import namedtuple 3 | import time 4 | from typing import Callable, Any 5 | 6 | 7 | class _Timer: 8 | """An internal timer.""" 9 | 10 | def __init__(self, name: str): 11 | self.name = name 12 | self.started = False 13 | self.start_time = None 14 | 15 | # start-stop timestamp pairs 16 | self.start_times = [] 17 | self.stop_times = [] 18 | self.costs = [] 19 | 20 | def start(self, sync_func: Callable = None): 21 | """Start the timer.""" 22 | assert not self.started, f"timer {self.name} has already been started." 23 | if sync_func: 24 | sync_func() 25 | 26 | self.start_time = time.time() 27 | self.start_times.append(self.start_time) 28 | self.started = True 29 | 30 | def stop(self, sync_func: Callable = None): 31 | """Stop the timer.""" 32 | assert self.started, f"timer {self.name} is not started." 
33 | if sync_func: 34 | sync_func() 35 | 36 | stop_time = time.time() 37 | self.costs.append(stop_time - self.start_time) 38 | self.stop_times.append(stop_time) 39 | self.started = False 40 | 41 | def reset(self): 42 | """Reset timer.""" 43 | self.started = False 44 | self.start_time = None 45 | self.start_times = [] 46 | self.stop_times = [] 47 | self.costs = [] 48 | 49 | def elapsed(self, mode: str = "average"): 50 | """Calculate the elapsed time.""" 51 | if not self.costs: 52 | return 0.0 53 | if mode == "average": 54 | return sum(self.costs) / len(self.costs) 55 | elif mode == "sum": 56 | return sum(self.costs) 57 | else: 58 | raise RuntimeError("Supported mode is: average | sum") 59 | 60 | 61 | class Timers: 62 | """A group of timers.""" 63 | 64 | def __init__(self): 65 | self.timers = {} 66 | 67 | def __call__(self, name: str): 68 | if name not in self.timers: 69 | self.timers[name] = _Timer(name) 70 | return self.timers[name] 71 | 72 | def __contains__(self, name: str): 73 | return name in self.timers 74 | 75 | 76 | timers = Timers() 77 | 78 | Event = namedtuple("Event", ("tstamp", "name", "info")) 79 | 80 | 81 | class Tracer: 82 | """An activity tracer.""" 83 | 84 | def __init__(self): 85 | self.events = [] 86 | 87 | def log(self, name: str, info: Any, sync_func: Callable = None): 88 | if sync_func: 89 | sync_func() 90 | 91 | self.events.append(Event(time.time(), name, info)) 92 | 93 | 94 | tracer = Tracer() 95 | -------------------------------------------------------------------------------- /alpa/torch/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/alpa/torch/ops/__init__.py -------------------------------------------------------------------------------- /alpa/torch/optim/__init__.py: -------------------------------------------------------------------------------- 1 | """Optimizers 2 | """ 3 | from .adam import adam 4 | -------------------------------------------------------------------------------- /alpa/torch/optim/adam.py: -------------------------------------------------------------------------------- 1 | """Adam optimizer""" 2 | import copy 3 | 4 | import torch 5 | 6 | 7 | def adam(lr=1e-4): 8 | """torchoptim.adam(**adam_config)(params) 9 | Factory that generates functional version of Adam optimizer. 10 | Implementation has no in-place op and no data-dependent control flow. 11 | 12 | Returns: 13 | - `optim_func`: a function that: 14 | - takes (`params`, `optim_state`, `params_grad`) as input 15 | - returns (`params`, `optim_state`) 16 | after applying Adam algorithm 17 | - `optim_state_init_func`: a function that: 18 | - takes `optim_state` as input 19 | - returns `optim_state` which is Adam optimizer state 20 | - `optim_state`: tracked state (shape-only) of Adam optimizer. 
21 | """ 22 | 23 | # TODO FIXME: properly implement Adam optimizer 24 | 25 | def optim_gen(params): 26 | 27 | def optim_func(params, optim_state, params_grad): 28 | for k in params: 29 | params[k] = params[k] + params_grad[k] * lr 30 | optim_state[k] = optim_state[k] + params_grad[k] 31 | return params, optim_state 32 | 33 | optim_state = copy.deepcopy(params) 34 | 35 | def optim_state_init_func(optim_state): 36 | new_state = {} 37 | for k, v in optim_state.items(): 38 | new_state[k] = torch.full_like(v, 0.0) 39 | return new_state 40 | 41 | return optim_func, optim_state_init_func, optim_state 42 | 43 | return optim_gen 44 | -------------------------------------------------------------------------------- /alpa/wrapped_hlo.py: -------------------------------------------------------------------------------- 1 | """A class that wraps HloModule and records whether the module runs AutoSharding 2 | and SPMD Partitioner or not. 3 | """ 4 | from enum import Enum, auto 5 | from typing import Union 6 | 7 | from jax._src.lib import xla_extension as xe 8 | from jax.interpreters import mlir 9 | 10 | 11 | class HloStatus(Enum): 12 | """ 13 | The status of an HloModule. 14 | See also the docstring at the beginning of shard_parallel/auto_sharding.py. 15 | """ 16 | UNOPTIMIZED = auto() 17 | SHARDING_ANNOTATED = auto() 18 | SPMD_PARTITIONED = auto() 19 | FULLY_OPTIMIZED = auto() 20 | 21 | 22 | class WrappedHlo: 23 | """Wrapped HloModule with HloStatus.""" 24 | 25 | def __init__(self, 26 | module: Union[xe.HloModule, xe.XlaComputation, bytes], 27 | status: HloStatus = HloStatus.UNOPTIMIZED): 28 | if isinstance(module, xe.HloModule): 29 | self.module = module 30 | elif isinstance(module, xe.XlaComputation): 31 | self.module = module.get_hlo_module() 32 | else: 33 | assert isinstance(module, bytes) 34 | self.module = xe.XlaComputation(module).get_hlo_module() 35 | self.name = self.module.name 36 | self.status = status 37 | self.is_manually_annotated = False 38 | 39 | def get_computation(self) -> xe.XlaComputation: 40 | return xe.XlaComputation(self.module.as_serialized_hlo_module_proto()) 41 | 42 | def get_mhlo(self): 43 | xla_computation = self.get_computation() 44 | module_str = xe.mlir.xla_computation_to_mlir_module(xla_computation) 45 | with mlir.make_ir_context(): 46 | mhlo = mlir.ir.Module.parse(module_str) 47 | return mhlo 48 | 49 | def get_module(self) -> xe.HloModule: 50 | return self.module 51 | 52 | def get_hlo_proto(self): 53 | return self.module.as_serialized_hlo_module_proto() 54 | 55 | def program_shape(self): 56 | return self.module.program_shape() 57 | 58 | def set_input_shardings(self, sharding_protos): 59 | assert self.is_sharding_annotated() or self.is_unoptimized() 60 | xe.set_hlo_module_input_shardings(self.module, sharding_protos) 61 | 62 | def set_output_shardings(self, sharding_protos): 63 | assert self.is_sharding_annotated() or self.is_unoptimized() 64 | xe.set_hlo_module_output_shardings(self.module, sharding_protos) 65 | 66 | def is_unoptimized(self): 67 | return self.status == HloStatus.UNOPTIMIZED 68 | 69 | def is_sharding_annotated(self): 70 | return self.status == HloStatus.SHARDING_ANNOTATED 71 | 72 | def is_spmd_partitioned(self): 73 | return self.status == HloStatus.SPMD_PARTITIONED 74 | 75 | def to_string(self): 76 | return self.module.to_string() 77 | 78 | def __getstate__(self): 79 | return (self.get_hlo_proto(), self.status) 80 | 81 | def __setstate__(self, bytes_and_status): 82 | b, s = bytes_and_status 83 | self.__init__(b, s) 84 | 
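# A minimal usage sketch of WrappedHlo (hypothetical example, not part of the original
# file; `some_xla_computation` is an assumed, pre-built xe.XlaComputation):
#
#   wrapped = WrappedHlo(some_xla_computation)    # status defaults to HloStatus.UNOPTIMIZED
#   assert wrapped.is_unoptimized()
#   proto = wrapped.get_hlo_proto()               # serialized HloModuleProto bytes
#   restored = WrappedHlo(proto, wrapped.status)  # bytes round-trip, mirroring __setstate__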
-------------------------------------------------------------------------------- /benchmark/alpa/gather_gpu_stat.py: -------------------------------------------------------------------------------- 1 | """Gather gpu utilization from all nodes.""" 2 | 3 | import os 4 | import tempfile 5 | 6 | import gpustat 7 | import ray 8 | 9 | 10 | def call_nvidia_smi(): 11 | gpus = gpustat.new_query().gpus 12 | return [g.utilization for g in gpus] 13 | 14 | 15 | if __name__ == "__main__": 16 | ray.init(address="auto") 17 | 18 | host_info = [] 19 | for node in ray.nodes(): 20 | for key in node["Resources"]: 21 | if key.startswith("node:"): 22 | host_info.append(node) 23 | 24 | results = [] 25 | for i in range(len(host_info)): 26 | # Launch a ray actor 27 | node_resource = "node:" + host_info[i]["NodeManagerAddress"] 28 | func = ray.remote(resources={node_resource: 1e-3})(call_nvidia_smi) 29 | results.append(func.remote()) 30 | results = ray.get(results) 31 | 32 | for i in range(len(host_info)): 33 | print(host_info[i]["NodeManagerAddress"]) 34 | print(results[i]) 35 | -------------------------------------------------------------------------------- /benchmark/alpa/gen_serving_database.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 run_exp.py gpt_inference 4 | python3 gen_serving_database.py 5 | """ 6 | 7 | import argparse 8 | 9 | from alpa_serve.profiling import ProfilingDatabase 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--input", type=str, default="inference_prof_res.tsv") 14 | parser.add_argument("--output", type=str, default="profiling_result.pkl") 15 | parser.add_argument("--new", action="store_true") 16 | args = parser.parse_args() 17 | 18 | database = ProfilingDatabase(args.output, args.new) 19 | database.update_from_csv(args.input) 20 | database.materialize() 21 | -------------------------------------------------------------------------------- /benchmark/alpa/inspect_prof_database.py: -------------------------------------------------------------------------------- 1 | """Inspect and edit a profiling database.""" 2 | import argparse 3 | 4 | from alpa import DeviceCluster, ProfilingResultDatabase 5 | from alpa.util import run_cmd 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--filename", type=str, default="prof_database.pkl") 10 | args = parser.parse_args() 11 | 12 | prof_database = ProfilingResultDatabase() 13 | prof_database.load(args.filename) 14 | 15 | # Do some editing 16 | #prof_database.insert_dummy_mesh_result("default", (8, 8)) 17 | #prof_database.save(args.filename) 18 | 19 | # Print results 20 | print("Meshes:") 21 | print(list(prof_database.data.keys())) 22 | print() 23 | 24 | mesh_result = prof_database.query("default", (2, 8)) 25 | print(mesh_result) 26 | -------------------------------------------------------------------------------- /benchmark/alpa/run_exp.py: -------------------------------------------------------------------------------- 1 | """Run search experiments with mutliple cluster settings.""" 2 | import argparse 3 | from datetime import datetime 4 | import os 5 | import subprocess 6 | import sys 7 | 8 | from benchmark import benchmark_suite 9 | 10 | 11 | def run_exp(exp_name, cluster_settings, suite_name, benchmark_settings=None): 12 | os.environ["PYTHONUNBUFFERED"] = "1" 13 | now = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 14 | 15 | tee = subprocess.Popen(["tee", f"{now}_{suite_name}.log"], 
16 | stdin=subprocess.PIPE) 17 | os.dup2(tee.stdin.fileno(), sys.stdout.fileno()) 18 | os.dup2(tee.stdin.fileno(), sys.stderr.fileno()) 19 | 20 | benchmark_settings = benchmark_settings or {} 21 | 22 | for num_hosts, num_devices_per_host in cluster_settings: 23 | num_gpus = num_hosts * num_devices_per_host 24 | if exp_name is None: 25 | exp_name = f"{now}_{suite_name}_{num_gpus}_gpus" 26 | benchmark_suite(suite_name, 27 | num_hosts, 28 | num_devices_per_host, 29 | exp_name=exp_name, 30 | disable_tqdm=True, 31 | **benchmark_settings) 32 | 33 | 34 | model_search_suites = { 35 | "gpt": ("gpt.grid_search_auto", {}), 36 | "moe": ("moe.grid_search_auto", {}), 37 | "wresnet": ("wresnet.grid_search_auto", {}), 38 | "gpt_inference": ("gpt_inference.profile", { 39 | "niter": 10, 40 | "profile_stage_execution_time": True 41 | }), 42 | "moe_inference": ("moe_inference.profile", { 43 | "niter": 10, 44 | "profile_stage_execution_time": True 45 | }), 46 | "gpt_no_embedding_inference": ("gpt_no_embedding_inference.profile", {}), 47 | "gpt_inference_streaming": ("gpt_inference.profile", { 48 | "profile_driver_time": True 49 | }), 50 | } 51 | cluster_settings = [(8, 8), (4, 8), (3, 8), (2, 8), (1, 8), (1, 4), (1, 2), 52 | (1, 1)] 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument("suite", type=str, choices=model_search_suites.keys()) 57 | parser.add_argument("--exp-name", type=str, default=None) 58 | args = parser.parse_args() 59 | run_exp(args.exp_name, cluster_settings, *model_search_suites[args.suite]) 60 | -------------------------------------------------------------------------------- /benchmark/alpa/suite_auto_moe.py: -------------------------------------------------------------------------------- 1 | """Benchmark suites for moe with auto parallelization.""" 2 | from suite_manual_moe import moe_specs 3 | # Share parallel options with the GPT suite 4 | from suite_auto_gpt import (get_search_cases, get_solution_case, force_dp_dict) 5 | 6 | # Temporary debug suite 7 | tmp_suite = {} 8 | 9 | # Performance test with search solutions found for p3.16xlarge 10 | perf_test_suite = { 11 | 1: 12 | get_solution_case(moe_specs["380M"], 512, 1, [[0]], [(1, 1)], [(1, 1)], 13 | [{}]), 14 | 2: 15 | get_solution_case(moe_specs["690M"], 32, 8, [[0, 1, 2, 3, 4, 5, 6, 7]], 16 | [(1, 2)], [(2, 1)], [force_dp_dict]), 17 | 4: 18 | get_solution_case(moe_specs["1.3B"], 32, 8, 19 | [[0, 1, 2, 3], [4, 5, 6, 7]], [(1, 2)] * 2, 20 | [(2, 1)] * 2, [force_dp_dict] * 2), 21 | 8: 22 | get_solution_case(moe_specs["2.4B"], 32, 8, 23 | [[0, 1, 2, 3], [4, 5, 6, 7]], [(1, 4)] * 2, 24 | [(4, 1)] * 2, [force_dp_dict] * 2), 25 | 16: 26 | get_solution_case(moe_specs["10B"], 16, 8, [[0, 1, 2, 3], [4, 5, 6, 7]], 27 | [(1, 8)] * 2, [(8, 1)] * 2, [{}] * 2), 28 | 32: 29 | get_solution_case(moe_specs["27B"], 128, 8, 30 | [[0], [1], [2], [3], [4], [5], [6], [7]], 31 | [(1, 4)] * 8, [(4, 1)] * 8, [{}] * 8), 32 | 64: 33 | get_solution_case(moe_specs["70B"], 64, 8, 34 | [[0], [1], [2], [3], [4], [5], [6], [7]], 35 | [(1, 8)] * 8, [(8, 1)] * 8, [{}] * 8), 36 | } 37 | 38 | # Grid search on hyperparameters 39 | grid_search_suite = { 40 | 2: (get_search_cases(moe_specs["690M"], [16, 32, 64], [8])), 41 | 4: (get_search_cases(moe_specs["1.3B"], [16, 32, 64], [8])), 42 | 8: (get_search_cases(moe_specs["2.4B"], [16, 32, 64], [8])), 43 | 16: (get_search_cases(moe_specs["10B"], [16, 32, 64], [8])), 44 | 32: (get_search_cases(moe_specs["27B"], [32, 64, 128], [4, 8, 16])), 45 | 64: 
(get_search_cases(moe_specs["70B"], [64], [8, 16, 32])), 46 | # submesh_choices_mode: "small_power_of_two", max num_cpus = 20 47 | } 48 | -------------------------------------------------------------------------------- /benchmark/alpa/suite_inference_gpt.py: -------------------------------------------------------------------------------- 1 | """Benchmark suites for gpt with auto parallelization.""" 2 | from suite_manual_gpt import gpt_specs 3 | from benchmark_parallel_utils import (BenchmarkCase, UniformParallelArgs) 4 | 5 | prefer_reduce_scatter = True 6 | force_batch_dim_mapping = True 7 | use_remat = False 8 | 9 | profile_suite = {} 10 | force_dp_dict = {"force_batch_dim_to_mesh_dim": 0} 11 | 12 | 13 | def get_config(model_config, 14 | pp_list, 15 | dp_list, 16 | op_list, 17 | num_micro_batch_config, 18 | batch_size_config, 19 | ignore_one_device_case=False): 20 | for pp in pp_list: 21 | for dp in dp_list: 22 | for op in op_list: 23 | num_gpus = pp * dp * op 24 | if ignore_one_device_case and num_gpus == 1: 25 | continue 26 | for bs in batch_size_config: 27 | for nb in num_micro_batch_config: 28 | total_bs = bs * nb 29 | if num_gpus not in profile_suite: 30 | profile_suite[num_gpus] = [] 31 | parallel_args = UniformParallelArgs( 32 | prefer_reduce_scatter, use_remat, dp, op, pp, 33 | force_batch_dim_mapping) 34 | case = BenchmarkCase(total_bs, model_config, nb, 35 | "uniform", parallel_args) 36 | profile_suite[num_gpus].append(case) 37 | 38 | 39 | ## general examples: 40 | #get_config(gpt_specs["350M"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16]) 41 | #get_config(gpt_specs["760M"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16]) 42 | #get_config(gpt_specs["1.3B"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16]) 43 | #get_config(gpt_specs["2.6B"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16]) 44 | #get_config(gpt_specs["6.7B"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16]) 45 | #get_config(gpt_specs["15B"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16]) 46 | 47 | ## benchmark specific parallel method: 48 | #get_config(gpt_specs["6.7B"], [1], [1], [1, 2, 4, 8], [1, 256], [1, 4, 16, 64]) 49 | #get_config(gpt_specs["6.7B"], [1], [1, 2, 4, 8], [1], [1, 256], [1, 4, 16, 64], 50 | # ignore_one_device_case=True) 51 | #get_config(gpt_specs["6.7B"], [1, 2, 4, 8], [1], [1], [1, 256], [1, 4, 16, 64], 52 | # ignore_one_device_case=True) 53 | 54 | ## generate inference profiling results 55 | get_config(gpt_specs["1.3B"], [1, 2, 4, 8], [1], [1, 2, 4, 8], [1], 56 | [1, 2, 4, 8, 16]) 57 | get_config(gpt_specs["2.6B"], [1, 2, 4, 8, 16, 32], [1], [1, 2, 4, 8], [1], 58 | [1, 2, 4, 8, 16]) 59 | get_config(gpt_specs["6.7B"], [1, 2, 4, 8, 16, 32], [1], [1, 2, 4, 8], [1], 60 | [1, 2, 4, 8, 16]) 61 | get_config(gpt_specs["15B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], 62 | [1, 2, 4, 8, 16]) 63 | -------------------------------------------------------------------------------- /benchmark/alpa/suite_inference_moe.py: -------------------------------------------------------------------------------- 1 | """Benchmark suites for gpt with auto parallelization.""" 2 | from suite_manual_moe import moe_specs 3 | from benchmark_parallel_utils import (BenchmarkCase, UniformParallelArgs) 4 | 5 | prefer_reduce_scatter = True 6 | force_batch_dim_mapping = True 7 | use_remat = False 8 | 9 | profile_suite = {} 10 | force_dp_dict = {"force_batch_dim_to_mesh_dim": 0} 11 | 12 | 13 | def get_config(model_config, 14 | pp_list, 15 | dp_list, 16 | op_list, 17 | num_micro_batch_config, 18 | batch_size_config, 19 | ignore_one_device_case=False): 20 | for pp in pp_list: 
21 | for dp in dp_list: 22 | for op in op_list: 23 | num_gpus = pp * dp * op 24 | if ignore_one_device_case and num_gpus == 1: 25 | continue 26 | for bs in batch_size_config: 27 | for nb in num_micro_batch_config: 28 | total_bs = bs * nb 29 | if num_gpus not in profile_suite: 30 | profile_suite[num_gpus] = [] 31 | parallel_args = UniformParallelArgs( 32 | prefer_reduce_scatter, use_remat, dp, op, pp, 33 | force_batch_dim_mapping) 34 | case = BenchmarkCase(total_bs, model_config, nb, 35 | "uniform", parallel_args) 36 | profile_suite[num_gpus].append(case) 37 | 38 | 39 | ## generate inference profiling results 40 | get_config(moe_specs["1.3B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], 41 | [1, 2, 4, 8, 16]) 42 | get_config(moe_specs["2.4B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], 43 | [1, 2, 4, 8, 16]) 44 | get_config(moe_specs["7.1B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], 45 | [1, 2, 4, 8, 16]) 46 | get_config(moe_specs["10B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], 47 | [1, 2, 4, 8, 16]) 48 | -------------------------------------------------------------------------------- /benchmark/cupy/profile_matmul.py: -------------------------------------------------------------------------------- 1 | """Profile peak TFLOPS on matrix multiplications.""" 2 | import time 3 | import cupy as cp 4 | 5 | def benchmark(n, k, m, dtype, init_method="ones"): 6 | warmup = 5 7 | number = 50 8 | 9 | if init_method == "zeros": 10 | a = cp.zeros((n, k), dtype) 11 | b = cp.zeros((k, m), dtype) 12 | elif init_method == "full": 13 | a = cp.full((n, k), 1e-7, dtype) 14 | b = cp.full((k, m), 1e-7, dtype) 15 | elif init_method == "nans": 16 | a = cp.full((n, k), cp.nan, dtype) 17 | b = cp.full((k, m), cp.nan, dtype) 18 | elif init_method == "ones": 19 | a = cp.ones((n, k), dtype) 20 | b = cp.ones((k, m), dtype) 21 | elif init_method == "ones+randn": 22 | a = cp.ones((n, k), dtype) 23 | b = cp.ones((k, m), dtype) 24 | ratio = 2 25 | a[0:n//ratio, :] = cp.random.randn(n//ratio, k).astype(dtype) 26 | b[0:k//ratio, :] = cp.random.randn(k//ratio, m).astype(dtype) 27 | elif init_method == "randn": 28 | a = cp.random.randn(n, k).astype(dtype) 29 | b = cp.random.randn(k, m).astype(dtype) 30 | elif init_method == "uniform": 31 | a = cp.random.uniform(-1, 1, (n, k)).astype(dtype) 32 | b = cp.random.uniform(-1, 1, (k, m)).astype(dtype) 33 | elif init_method == "uniform+": 34 | a = cp.random.uniform(0, 1, (n, k)).astype(dtype) 35 | b = cp.random.uniform(0, 1, (k, m)).astype(dtype) 36 | else: 37 | raise ValueError(f"Invalid method: {init_method}") 38 | for i in range(warmup): 39 | c = a @ b 40 | 41 | cp.cuda.Device(0).synchronize() 42 | tic = time.time() 43 | for i in range(number): 44 | cp.dot(a, b, c) 45 | cp.cuda.Device(0).synchronize() 46 | toc = time.time() 47 | 48 | total_flops = 2 * n * k * m 49 | cost = (toc - tic) / number 50 | shape = (n, k, m, dtype) 51 | 52 | print(f"shape: {shape}, init_method: {init_method:>8}, " 53 | f"TFLOP: {total_flops / 1e12:.2f}, " 54 | f"cost: {cost:3f}, " 55 | f"TFLOPS : {total_flops / cost / 1e12:.2f}""") 56 | 57 | 58 | for n in [8192]: 59 | for init_method in ["nans", "full", "zeros", "ones", 60 | "randn", "uniform", "uniform+", "ones+randn"]: 61 | benchmark(n, n, n, "float16", init_method) 62 | -------------------------------------------------------------------------------- /benchmark/deepspeed/README.md: -------------------------------------------------------------------------------- 1 | # Benchmark Deepspeed 2 | 3 | ## Requirements 4 | 1. 
Install dependencies 5 | ``` 6 | # torch 7 | pip3 install torch==1.8.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html 8 | pip3 install nltk pandas sentencepiece boto3 pybind11 python-config 9 | 10 | # Adafactor optimizer 11 | pip3 install torch-optimizer 12 | 13 | # pdsh 14 | sudo apt-get update 15 | sudo apt-get install pdsh 16 | 17 | # Apex 18 | git clone https://github.com/NVIDIA/apex 19 | cd apex 20 | # Comment out the raised RuntimeError in setup.py if you get errors running the following command. 21 | pip3 install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 22 | ``` 23 | 24 | 2. Install DeepSpeed and the DeepSpeed examples 25 | ``` 26 | pip3 install deepspeed==0.5.4 27 | git clone --recursive https://github.com/microsoft/DeepSpeed.git 28 | echo 'export DEEPSPEED_PATH=~/efs/DeepSpeed' >> ~/.bashrc # use your own path 29 | source ~/.bashrc 30 | 31 | # Replace source files (use your own path) 32 | cp alpa/benchmark/deepspeed/patch/training.py DeepSpeed/DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py 33 | cp alpa/benchmark/deepspeed/patch/gpt2_model.py DeepSpeed/DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py 34 | cp alpa/benchmark/deepspeed/patch/transformer.py DeepSpeed/DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/megatron/model/transformer.py 35 | ``` 36 | 37 | 3. Download the dataset 38 | ``` 39 | wget deepspeed_dataset.zip # ask Lianmin to get the file 40 | tar xzf deepspeed_dataset.zip 41 | cd deepspeed_dataset/ 42 | ln -s $(pwd) ~/efs/alpa/benchmark/deepspeed/data # use your own path 43 | ``` 44 | 45 | ## Run 46 | ### Single Node 47 | ``` 48 | # GPT 49 | python3 benchmark_gpt2.py --nproc_per_node 8 50 | # MOE 51 | python3 benchmark_moe.py --nproc_per_node 8 52 | ``` 53 | 54 | ### Multiple Nodes 55 | - Modify the [hostfile](https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node) and set up the SSH connections.
56 | ``` 57 | python3 benchmark_gpt2.py --nnodes 2 --nproc_per_node 8 58 | ``` 59 | -------------------------------------------------------------------------------- /benchmark/deepspeed/ds_zero_stage_2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": 8192, 3 | "gradient_accumulation_steps": 4, 4 | "steps_per_print": 1, 5 | "zero_optimization": { 6 | "stage": 2, 7 | "allgather_partitions": true, 8 | "reduce_scatter": true, 9 | "allgather_bucket_size": 5e8, 10 | "reduce_bucket_size": 5e8, 11 | "overlap_comm": true, 12 | "contiguous_gradients": true 13 | }, 14 | "optimizer": { 15 | "type": "Adam", 16 | "params": { 17 | "lr": 0.00015, 18 | "max_grad_norm": 1.0, 19 | "betas": [0.9, 0.95] 20 | } 21 | }, 22 | "gradient_clipping": 1.0, 23 | "fp16": { 24 | "enabled": true, 25 | "loss_scale": 1.0, 26 | "loss_scale_window": 1000, 27 | "hysteresis": 2, 28 | "min_loss_scale": 1 29 | }, 30 | "wall_clock_breakdown": false, 31 | "zero_allow_untested_optimizer": false 32 | } 33 | -------------------------------------------------------------------------------- /benchmark/deepspeed/ds_zero_stage_2_moe_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": 8192, 3 | "gradient_accumulation_steps": 4, 4 | "steps_per_print": 1, 5 | "zero_optimization": { 6 | "stage": 2, 7 | "allgather_partitions": true, 8 | "reduce_scatter": true, 9 | "allgather_bucket_size": 5e8, 10 | "reduce_bucket_size": 5e8, 11 | "overlap_comm": true, 12 | "contiguous_gradients": true 13 | }, 14 | "gradient_clipping": 1.0, 15 | "fp16": { 16 | "enabled": true, 17 | "loss_scale": 1.0, 18 | "loss_scale_window": 1000, 19 | "hysteresis": 2, 20 | "min_loss_scale": 1 21 | }, 22 | "wall_clock_breakdown": false, 23 | "zero_allow_untested_optimizer": true 24 | } 25 | -------------------------------------------------------------------------------- /benchmark/deepspeed/ds_zero_stage_3_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": 8192, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 1, 5 | "zero_optimization": { 6 | "stage": 3, 7 | "stage3_max_live_parameters": 1e9, 8 | "stage3_max_reuse_distance": 1e9, 9 | "stage3_prefetch_bucket_size": 1e7, 10 | "stage3_param_persitence_threshold": 1e5, 11 | "reduce_bucket_size": 1e7, 12 | "contiguous_gradients": true 13 | }, 14 | "optimizer": { 15 | "type": "Adam", 16 | "params": { 17 | "lr": 0.00015, 18 | "max_grad_norm": 1.0, 19 | "betas": [0.9, 0.95] 20 | } 21 | }, 22 | "gradient_clipping": 1.0, 23 | "fp16": { 24 | "enabled": true, 25 | "loss_scale": 1.0, 26 | "loss_scale_window": 1000, 27 | "hysteresis": 2, 28 | "min_loss_scale": 1 29 | }, 30 | "wall_clock_breakdown": false, 31 | "zero_allow_untested_optimizer": false 32 | } 33 | -------------------------------------------------------------------------------- /benchmark/deepspeed/hostfile: -------------------------------------------------------------------------------- 1 | 172.31.19.47 slots=8 2 | 172.31.27.46 slots=8 3 | -------------------------------------------------------------------------------- /benchmark/deepspeed/killall_python.sh: -------------------------------------------------------------------------------- 1 | kill -9 $(ps aux | grep 'python3' | grep -v 'grep' | awk '{print $2}') 2 | -------------------------------------------------------------------------------- /benchmark/deepspeed/util.py: 
-------------------------------------------------------------------------------- 1 | ../alpa/util.py -------------------------------------------------------------------------------- /benchmark/megatron/README.md: -------------------------------------------------------------------------------- 1 | # Benchmark Megatron-LM 2 | 3 | ## Requirements 4 | ``` 5 | # torch 1.8.0 and CUDA 11.1 6 | pip3 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html 7 | 8 | pip3 install ninja 9 | 10 | # Install Megatron 11 | git clone https://github.com/NVIDIA/Megatron-LM.git 12 | cd Megatron-LM 13 | echo 'export PYTHONPATH=$PYTHONPATH:~/efs/Megatron-LM' >> ~/.bashrc # use your own path 14 | source ~/.bashrc 15 | 16 | # Install Apex 17 | git clone https://github.com/NVIDIA/apex 18 | cd apex 19 | # Comment out the raised RuntimeError in setup.py if you get errors running the following command. 20 | pip3 install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 21 | ``` 22 | 23 | ## Instructions 24 | ### Single Node 25 | ``` 26 | # MLP 27 | python3 benchmark_mlp.py --nproc_per_node 4 28 | # Transformer layer 29 | python3 benchmark_transformer_layer.py --nproc_per_node 4 30 | # GPT 31 | python3 benchmark_gpt_bert.py --nproc_per_node 1 --suite gpt.tmp 32 | python3 benchmark_gpt_bert.py --nproc_per_node 8 --suite gpt.tmp 33 | ``` 34 | 35 | ### Multiple Nodes 36 | ``` 37 | # on node 0 38 | python3 benchmark_gpt_bert.py --suite gpt.tmp --nproc_per_node 8 --nnodes 2 --node_rank 0 --master_port 11000 --master_addr 172.31.16.139 39 | # on node 1 40 | python3 benchmark_gpt_bert.py --suite gpt.tmp --nproc_per_node 8 --nnodes 2 --node_rank 1 --master_port 11000 --master_addr 172.31.16.139 41 | ``` 42 | 43 | For other models, replace `benchmark_gpt_bert.py` with the corresponding filenames.
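The two per-node commands above differ only in `--node_rank`. For runs with more nodes, a small helper like the sketch below can print the command for each node. It is hypothetical (not part of this repo), and the IP address, port, and suite name simply reuse the example values above:

```python
# Sketch: print the benchmark command to run on each node of a multi-node job.
NNODES = 2
MASTER_ADDR = "172.31.16.139"   # example address from above; use your node-0 IP
MASTER_PORT = 11000
SUITE = "gpt.tmp"

for node_rank in range(NNODES):
    cmd = (f"python3 benchmark_gpt_bert.py --suite {SUITE} "
           f"--nproc_per_node 8 --nnodes {NNODES} --node_rank {node_rank} "
           f"--master_port {MASTER_PORT} --master_addr {MASTER_ADDR}")
    print(f"# on node {node_rank}\n{cmd}")
```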
44 | 45 | ### With nvprof 46 | ``` 47 | nvprof --profile-child-processes python3 benchmark_mlp.py --nproc_per_node 4 &> megatron.prof 48 | ``` 49 | -------------------------------------------------------------------------------- /benchmark/megatron/benchmark_gpt_bert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | 4 | from util import run_cmd 5 | 6 | from benchmark.alpa import suite_manual_gpt 7 | 8 | benchmark_suites = { 9 | "gpt.tmp": suite_manual_gpt.tmp_suite, 10 | #"gpt.grid_search_manual": suite_manual_gpt.grid_search_manual, 11 | } 12 | 13 | def benchmark_all(args): 14 | num_gpus = args.nproc_per_node * args.nnodes 15 | 16 | try: 17 | _ = benchmark_suites[args.suite][num_gpus] 18 | except KeyError: 19 | print(f"No available benchmark suite for {args.suite} with {num_gpus} GPUs.") 20 | exit() 21 | output_name = args.exp_name + "-" + datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 22 | model = args.suite.split(".")[0] 23 | 24 | for case in benchmark_suites[args.suite][num_gpus]: 25 | case = tuple(tuple(x) if isinstance(x, tuple) else x for x in case) 26 | case_str = str((model,) + case) 27 | 28 | if args.nnodes == 1: 29 | # Single node 30 | ret = run_cmd('python3 -m torch.distributed.launch ' 31 | f'--nproc_per_node {args.nproc_per_node} ' 32 | 'benchmark_gpt_bert_one_case.py ' 33 | f'"{case_str}" ' 34 | f'{output_name}') 35 | else: 36 | # Multiple nodes 37 | ret = run_cmd('python3 -m torch.distributed.launch ' 38 | f'--nproc_per_node {args.nproc_per_node} ' 39 | f'--nnodes {args.nnodes} ' 40 | f'--node_rank {args.node_rank} ' 41 | f'--master_addr {args.master_addr} ' 42 | f'--master_port {args.master_port} ' 43 | 'benchmark_gpt_bert_one_case.py ' 44 | f'"{case_str}" ' 45 | f'{output_name}') 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("--nproc_per_node", type=int, required=True) 50 | parser.add_argument("--nnodes", type=int, default=1) 51 | parser.add_argument("--node_rank", type=int) 52 | parser.add_argument("--master_addr", type=str) 53 | parser.add_argument("--master_port", type=str) 54 | parser.add_argument("--suite", type=str, default="gpt.tmp") 55 | parser.add_argument("--exp_name", type=str, default="") 56 | args = parser.parse_args() 57 | 58 | benchmark_all(args) 59 | -------------------------------------------------------------------------------- /benchmark/megatron/benchmark_mlp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from util import run_cmd 4 | 5 | # B = batch_size, S = seq_len, H = hidden_size, L = num_layers, 6 | # #head = num_heads, DP = dp_size, TMP = tensor_mp_size, DPI = ddp_implementation, 7 | 8 | benchmark_suite_4_gpu = [ 9 | # B, S, H, L, #head, DP, TMP, DPI 10 | (32, 1024, 2304, 4, 2304//96, 4, 1, 1), 11 | (32, 1024, 2304, 4, 2304//96, 2, 2, 1), 12 | (32, 1024, 2304, 4, 2304//96, 1, 4, 1), 13 | 14 | # B, S, H, L, #head, DP, TMP, DPI 15 | (8, 256, 5760, 4, 5760//96, 4, 1, 1), 16 | (8, 256, 5760, 4, 5760//96, 2, 2, 1), 17 | (8, 256, 5760, 4, 5760//96, 1, 4, 1), 18 | ] 19 | 20 | 21 | def benchmark_all(): 22 | for case in benchmark_suite_4_gpu: 23 | nproc_per_node = 4 24 | case_str = str(case) 25 | ret = run_cmd('python3 -m torch.distributed.launch ' 26 | f'--nproc_per_node {nproc_per_node} ' 27 | 'benchmark_mlp_one_case.py ' 28 | f'"{case_str}"') 29 | if ret != 0: 30 | return 31 | 32 | if __name__ == "__main__": 33 | benchmark_all() 34 | 35 | 
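The tuples in `benchmark_suite_4_gpu` above are positional and easy to misread. The snippet below is purely illustrative (the `MLPCase` name is not part of the repo); it names the fields following the comment at the top of `benchmark_mlp.py`:

```python
from collections import namedtuple

# Field names follow the comment at the top of benchmark_mlp.py:
# B, S, H, L, #head, DP, TMP, DPI
MLPCase = namedtuple(
    "MLPCase",
    ["batch_size", "seq_len", "hidden_size", "num_layers",
     "num_heads", "dp_size", "tensor_mp_size", "ddp_impl"])

case = MLPCase(32, 1024, 2304, 4, 2304 // 96, 4, 1, 1)
# dp_size * tensor_mp_size should equal the number of GPUs (4 here).
assert case.dp_size * case.tensor_mp_size == 4
print(tuple(case))  # the positional form used by benchmark_suite_4_gpu
```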
-------------------------------------------------------------------------------- /benchmark/megatron/util.py: -------------------------------------------------------------------------------- 1 | ../alpa/util.py -------------------------------------------------------------------------------- /build_jaxlib/.bazelversion: -------------------------------------------------------------------------------- 1 | 5.1.1 2 | -------------------------------------------------------------------------------- /build_jaxlib/WORKSPACE: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 2 | 3 | # To update TensorFlow to a new revision, 4 | # a) update URL and strip_prefix to the new git commit hash 5 | # b) get the sha256 hash of the commit by running: 6 | # curl -L https://github.com/tensorflow/tensorflow/archive/.tar.gz | sha256sum 7 | # and update the sha256 with the result. 8 | http_archive( 9 | name = "org_tensorflow", 10 | sha256 = "9a7a7a87356bdeef5874fae135de380466482b593469035be3609a9cd2c153c4", 11 | strip_prefix = "tensorflow-cb946f223b9b3fa04efdbb7a0e6a9dabb22a7057", 12 | urls = [ 13 | "https://github.com/tensorflow/tensorflow/archive/cb946f223b9b3fa04efdbb7a0e6a9dabb22a7057.tar.gz", 14 | ], 15 | ) 16 | 17 | # For development, one often wants to make changes to the TF repository as well 18 | # as the JAX repository. You can override the pinned repository above with a 19 | # local checkout by either: 20 | # a) overriding the TF repository on the build.py command line by passing a flag 21 | # like: 22 | # python build/build.py --bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow 23 | # or 24 | # b) by commenting out the http_archive above and uncommenting the following: 25 | # local_repository( 26 | # name = "org_tensorflow", 27 | # path = "/path/to/tensorflow", 28 | # ) 29 | 30 | load("//third_party/ducc:workspace.bzl", ducc = "repo") 31 | ducc() 32 | 33 | # Initialize TensorFlow's external dependencies. 34 | load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3") 35 | tf_workspace3() 36 | 37 | load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2") 38 | tf_workspace2() 39 | 40 | load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1") 41 | tf_workspace1() 42 | 43 | load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0") 44 | tf_workspace0() 45 | -------------------------------------------------------------------------------- /build_jaxlib/build/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The JAX Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # JAX is Autograd and XLA 16 | 17 | load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") 18 | load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") 19 | load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") 20 | load("//jaxlib:jax.bzl", "if_windows") 21 | 22 | licenses(["notice"]) # Apache 2 23 | 24 | package(default_visibility = ["//visibility:public"]) 25 | 26 | bool_flag( 27 | name = "enable_remote_tpu", 28 | build_setting_default = False, 29 | ) 30 | 31 | config_setting( 32 | name = "remote_tpu_enabled", 33 | flag_values = { 34 | ":enable_remote_tpu": "True", 35 | }, 36 | ) 37 | 38 | py_binary( 39 | name = "build_wheel", 40 | srcs = ["build_wheel.py"], 41 | data = [ 42 | "LICENSE.txt", 43 | "//jaxlib", 44 | "//jaxlib:README.md", 45 | "//jaxlib:setup.py", 46 | "//jaxlib:setup.cfg", 47 | "@org_tensorflow//tensorflow/compiler/xla/python:xla_client", 48 | ] + if_windows([ 49 | "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", 50 | ]) + select({ 51 | ":remote_tpu_enabled": ["@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client"], 52 | "//conditions:default": [], 53 | }) + if_cuda([ 54 | "//jaxlib/cuda:cuda_gpu_support", 55 | "@local_config_cuda//cuda:cuda-nvvm", 56 | ]) + if_rocm([ 57 | "//jaxlib/rocm:rocm_gpu_support", 58 | ]), 59 | deps = ["@bazel_tools//tools/python/runfiles"], 60 | ) 61 | -------------------------------------------------------------------------------- /build_jaxlib/jax: -------------------------------------------------------------------------------- 1 | ../third_party/jax/jax -------------------------------------------------------------------------------- /build_jaxlib/jaxlib: -------------------------------------------------------------------------------- 1 | ../third_party/jax/jaxlib -------------------------------------------------------------------------------- /build_jaxlib/release/README.md: -------------------------------------------------------------------------------- 1 | # How to Release JaxLib and generate a PyPI Index 2 | 3 | 1. Upload jaxlib wheels as assets under a release tag. 4 | ```shell 5 | GITHUB_TOKEN=[ADMIN_TOKEN] python wheel_upload.py --tag [TAG] --path [PATH_TO_WHEELS] 6 | ``` 7 | 8 | 2. Generate a html index page and commit it to the master branch of Alpa doc repository. 9 | ```shell 10 | GITHUB_TOKEN=[ADMIN_TOKEN] python generate_pypi_index.py --tag [TAG] 11 | ``` 12 | All wheel assets under `[TAG]` will be included in a html index page appeared in the doc repo. 13 | 14 | Please make sure the TAG is aligned in Step 1 and Step 2. 
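Since both steps must use the same tag, it can be convenient to drive them from one place. Below is a minimal sketch, assuming both scripts are run from this directory with `GITHUB_TOKEN` already exported; the tag and wheel path are placeholders:

```python
# Sketch: run both release steps with a single, shared tag.
import os
import subprocess

TAG = "v0.2.2"                  # placeholder; use your release tag
WHEEL_DIR = "/path/to/wheels"   # placeholder; use your wheel directory

env = dict(os.environ)          # GITHUB_TOKEN must already be set here
subprocess.run(["python", "wheel_upload.py", "--tag", TAG, "--path", WHEEL_DIR],
               env=env, check=True)
subprocess.run(["python", "generate_pypi_index.py", "--tag", TAG],
               env=env, check=True)
```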
15 | -------------------------------------------------------------------------------- /build_jaxlib/release/wheel_upload.py: -------------------------------------------------------------------------------- 1 | """Update the wheels page, prune old nightly builds if necessary (source from tlcpack).""" 2 | import github3 3 | import github3.session as session 4 | import os 5 | import logging 6 | import argparse 7 | 8 | 9 | def upload(args, path): 10 | # gh = github3.login(token=os.environ["GITHUB_TOKEN"]) 11 | gh = github3.GitHub(token=os.environ["GITHUB_TOKEN"], 12 | session=session.GitHubSession(default_connect_timeout=100, default_read_timeout=100)) 13 | repo = gh.repository(*args.repo.split("/")) 14 | release = repo.release_from_tag(args.tag) 15 | name = os.path.basename(path) 16 | content_bytes = open(path, "rb").read() 17 | 18 | for asset in release.assets(): 19 | if asset.name == name: 20 | if not args.dry_run: 21 | asset.delete() 22 | print(f"Remove duplicated file {name}") 23 | print(f"Start to upload {path} to {args.repo}, this can take a while...") 24 | if not args.dry_run: 25 | release.upload_asset("application/octet-stream", name, content_bytes) 26 | print(f"Finish uploading {path}") 27 | 28 | 29 | def main(): 30 | logging.basicConfig(level=logging.WARNING) 31 | parser = argparse.ArgumentParser(description="Upload wheel as an asset of a tag.") 32 | parser.add_argument("--tag", type=str) 33 | parser.add_argument("--repo", type=str, default="alpa-projects/alpa") 34 | parser.add_argument("--dry-run", action="store_true") 35 | parser.add_argument("--path", type=str) 36 | 37 | if "GITHUB_TOKEN" not in os.environ: 38 | raise RuntimeError("need GITHUB_TOKEN") 39 | args = parser.parse_args() 40 | if os.path.isdir(args.path): 41 | for name in os.listdir(args.path): 42 | if name.endswith(".whl"): 43 | upload(args, os.path.join(args.path, name)) 44 | else: 45 | upload(args, args.path) 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /build_jaxlib/third_party: -------------------------------------------------------------------------------- 1 | ../third_party/jax/third_party/ -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Alpa Docker 2 | This directory contains Alpa's docker infrastructure. Alpa uses docker to provide environment to build and release Python wheels and to perform unit tests. 3 | Most docker files in this directory depend on [nvidia-docker](https://github.com/NVIDIA/nvidia-docker/). 4 | 5 | Below we provide instructions on 6 | - How to build Alpa-modified jaxlib in a docker container 7 | - How to run Alpa in a docker container 8 | 9 | More docker examples can be found in the directory of [Alpa CI/CD](../.github/workflows). 10 | 11 | ## Build Jaxlib-alpa wheels using Docker 12 | We provide a Docker image to build the Alpa-modified jaxlib wheels inside a container. 13 | 14 | 15 | ### Steps 16 | First, figure out the CUDA and Python versions you want to use to build jaxlib. Current we support the following versions: 17 | - CUDA: 11.1, 11.2, 11.3 18 | - Python: 3.7, 3.8, 3.9 19 | 20 | Suppose we want to build the jaxlib-alpa with CUDA 11.1 and Python 3.8. 
21 | #### Build the docker image 22 | ```bash 23 | # create a folder to save the output wheels 24 | cd alpa/docker && mkdir -p dist 25 | 26 | # build the image using the chosen CUDA version 27 | docker build -t build-jaxlib-image -f build_jaxlib.Dockerfile . --build-arg JAX_CUDA_VERSION=11.1 28 | ``` 29 | 30 | #### Build the wheels inside a container 31 | ```bash 32 | # create a subfolder for the specific wheel version. 33 | mkdir -p dist/cuda111 34 | 35 | # build the wheel in a container using the selected Python and CUDA versions 36 | docker run --tmpfs /build:exec --rm -v $(pwd)/dist:/dist build-jaxlib-image 3.8 cuda 11.1 main 37 | 38 | # Move the output wheel 39 | mv -f dist/*.whl dist/cuda111/ 40 | ``` 41 | Check out the wheel under the folder ``alpa/docker/dist/cuda111/``. 42 | 43 | ## Run Alpa in a docker container 44 | You can run Alpa inside a docker container. Below are the steps to run Alpa in a docker container in interactive mode. 45 | 46 | First, build a docker image based on the provided dockerfile: 47 | ```bash 48 | docker build -t run-alpa-image -f run_alpa.Dockerfile . 49 | ``` 50 | 51 | For cloud providers with InfiniBand (such as CoreWeave), we need to include additional dependencies: 52 | ```bash 53 | docker build -t run-alpa-image -f run_alpa_infiniband.Dockerfile . 54 | ``` 55 | 56 | Second, build a container from the image and enter the container's interactive shell: 57 | ```bash 58 | docker run --gpus all --rm --shm-size=10.24gb -it run-alpa-image 59 | ``` 60 | 61 | Third, check that the Alpa installation is correct: 62 | ```bash 63 | conda activate alpa 64 | # Start ray: 65 | ray start --head 66 | # Test Alpa can run correctly: 67 | python -m alpa.test_install 68 | ``` 69 | 70 | Alternatively, you can skip the interactive shell and pass commands or job scripts to the container via the `docker run` command.
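As a lighter-weight alternative to `alpa.test_install`, the short script below is a sketch of a quick sanity check, assuming the `alpa` conda environment is active and Ray has been started as above:

```python
# Sketch: quick sanity check inside the container.
import jax
import alpa

print("visible devices:", jax.devices())  # should list the container's GPUs
alpa.init(cluster="ray")                  # attach to the Ray cluster started above
print("alpa initialized")
```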
71 | -------------------------------------------------------------------------------- /docker/build_alpa.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM quay.io/pypa/manylinux2014_x86_64 2 | 3 | WORKDIR / 4 | SHELL ["/bin/bash", "-c"] 5 | RUN yum-config-manager --add-repo http://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo 6 | RUN yum --enablerepo=epel -y install cuda-11-1 7 | 8 | COPY scripts/build_alpa.sh /build_alpa.sh 9 | RUN chmod +x /build_alpa.sh 10 | 11 | WORKDIR /build 12 | ENV TEST_TMPDIR /build 13 | ENTRYPOINT ["/build_alpa.sh"] 14 | -------------------------------------------------------------------------------- /docker/build_doc.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM gcr.io/tensorflow-testing/nosla-cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython 2 | 3 | WORKDIR / 4 | SHELL ["/bin/bash", "-c"] 5 | RUN rm -f /etc/apt/sources.list.d/jonathonf-ubuntu-python-3_6-xenial.list 6 | RUN apt-get update 7 | RUN apt-get install -y coinor-cbc glpk-utils python3-virtualenv 8 | 9 | RUN virtualenv --python=python3.8 python3.8-env 10 | RUN source python3.8-env/bin/activate && pip install --upgrade pip \ 11 | && pip install numpy==1.20 setuptools wheel six auditwheel \ 12 | sphinx sphinx-rtd-theme sphinx-gallery matplotlib 13 | COPY scripts/build_doc.sh /build_doc.sh 14 | RUN chmod +x build_doc.sh 15 | ENTRYPOINT ["/build_doc.sh"] 16 | -------------------------------------------------------------------------------- /docker/build_jaxlib.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM gcr.io/tensorflow-testing/nosla-cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython 2 | 3 | WORKDIR / 4 | SHELL ["/bin/bash", "-c"] 5 | RUN sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub 6 | RUN sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub 7 | RUN rm -f /etc/apt/sources.list.d/jonathonf-ubuntu-python-3_6-xenial.list 8 | RUN apt-get update 9 | RUN apt-get install -y python3-virtualenv 10 | 11 | RUN virtualenv --python=python3.7 python3.7-env 12 | RUN virtualenv --python=python3.8 python3.8-env 13 | RUN virtualenv --python=python3.9 python3.9-env 14 | 15 | # We pin numpy to the minimum permitted version to avoid compatibility issues. 16 | RUN source python3.7-env/bin/activate && pip install --upgrade pip && pip install numpy==1.20 setuptools wheel six auditwheel 17 | RUN source python3.8-env/bin/activate && pip install --upgrade pip && pip install numpy==1.20 setuptools wheel six auditwheel 18 | RUN source python3.9-env/bin/activate && pip install --upgrade pip && pip install numpy==1.20 setuptools wheel six auditwheel 19 | 20 | # Change the CUDA version if it doesn't match the installed version in the base image 21 | # which is 10.0 22 | ARG JAX_CUDA_VERSION=11.1 23 | COPY scripts/install_cuda.sh /install_cuda.sh 24 | RUN chmod +x /install_cuda.sh 25 | RUN /bin/bash -c 'if [[ ! 
"$CUDA_VERSION" =~ ^$JAX_CUDA_VERSION.*$ ]]; then \ 26 | /install_cuda.sh $JAX_CUDA_VERSION; \ 27 | fi' 28 | 29 | 30 | WORKDIR / 31 | COPY scripts/build_jaxlib_docker_entrypoint.sh /build_jaxlib_docker_entrypoint.sh 32 | RUN chmod +x /build_jaxlib_docker_entrypoint.sh 33 | 34 | WORKDIR /build 35 | ENV TEST_TMPDIR /build 36 | ENTRYPOINT ["/build_jaxlib_docker_entrypoint.sh"] 37 | -------------------------------------------------------------------------------- /docker/run_alpa.Dockerfile: -------------------------------------------------------------------------------- 1 | # base docker image 2 | FROM nvidia/cuda:11.3.0-cudnn8-devel-ubuntu20.04 3 | 4 | # init workdir 5 | RUN mkdir -p /build 6 | WORKDIR /build 7 | 8 | # install common tool & conda 9 | RUN apt update && \ 10 | apt install wget -y && \ 11 | apt install git -y && \ 12 | apt install vim -y && \ 13 | wget --quiet https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh -O ~/anaconda.sh && \ 14 | /bin/bash ~/anaconda.sh -b -p /opt/conda && \ 15 | rm ~/anaconda.sh && \ 16 | mkdir -p /opt/conda/envs/alpa && \ 17 | ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ 18 | echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ 19 | echo "conda activate base" >> ~/.bashrc 20 | 21 | # install conda alpa env 22 | RUN . /opt/conda/etc/profile.d/conda.sh && \ 23 | conda create --name alpa python=3.8 -y && \ 24 | conda activate alpa && \ 25 | apt install coinor-cbc -y && \ 26 | pip3 install --upgrade pip && \ 27 | pip3 install cupy-cuda113 && \ 28 | pip3 install alpa && \ 29 | pip3 install jaxlib==0.3.22+cuda113.cudnn820 -f https://alpa-projects.github.io/wheels.html 30 | -------------------------------------------------------------------------------- /docker/scripts/build_alpa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -xev 3 | if [ ! 
-d "/dist" ] 4 | then 5 | echo "/dist must be mounted to produce output" 6 | exit 1 7 | fi 8 | 9 | usage() { 10 | echo "usage: ${0##*/} [3.7|3.8|3.9] [alpa-branch]" 11 | exit 1 12 | } 13 | 14 | if [[ $# -lt 2 ]] 15 | then 16 | usage 17 | fi 18 | 19 | export PY_VERSION=$1 20 | 21 | if [ $PY_VERSION = "3.7" ]; then 22 | #alias python="/opt/python/cp37-cp37m/bin/python" 23 | ln -fs /opt/python/cp37-cp37m/bin/python /usr/bin/python3 24 | python3 -m ensurepip --upgrade 25 | python3 -m pip install cmake auditwheel pybind11 26 | ln -fs /opt/python/cp37-cp37m/bin/pybind11-config /usr/bin/pybind11-config 27 | elif [ $PY_VERSION = "3.8" ]; then 28 | #alias python="/opt/python/cp38-cp38/bin/python" 29 | ln -fs /opt/python/cp38-cp38/bin/python /usr/bin/python3 30 | python3 -m ensurepip --upgrade 31 | python3 -m pip install cmake auditwheel pybind11 32 | ln -fs /opt/python/cp38-cp38/bin//pybind11-config /usr/bin/pybind11-config 33 | elif [ $PY_VERSION = "3.9" ]; then 34 | #alias python="/opt/python/cp39-cp39/bin/python" 35 | ln -fs /opt/python/cp39-cp39/bin/python /usr/bin/python3 36 | python3 -m ensurepip --upgrade 37 | python3 -m pip install cmake auditwheel pybind11 38 | ln -fs /opt/python/cp39-cp39/bin/pybind11-config /usr/bin/pybind11-config 39 | else 40 | echo "Unsupported Python version: $PY_VERSION" 41 | exit 1 42 | fi 43 | 44 | ALPA_BRANCH="$2" 45 | 46 | # switch to the merge commit 47 | git clone https://github.com/alpa-projects/alpa.git 48 | cd alpa 49 | git fetch origin +${ALPA_BRANCH} 50 | git checkout -qf FETCH_HEAD 51 | 52 | # install jaxlib and jax 53 | python3 update_version.py --git-describe 54 | python3 setup.py bdist_wheel sdist 55 | 56 | #if ! python3 -m auditwheel show dist/alpa-*.whl | egrep 'platform tag: "(manylinux2014_x86_64|manylinux_2_17_x86_64)"' > /dev/null; then 57 | # # Print output for debugging 58 | # python3 -m auditwheel show dist/alpa-*.whl 59 | # echo "jaxlib wheel is not manylinux2014 compliant" 60 | # exit 1 61 | #fi 62 | 63 | #rename 'linux' manylinux2014 dist/*.whl 64 | cp -r dist/*whl /dist/ 65 | -------------------------------------------------------------------------------- /docker/scripts/build_doc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -xev 4 | 5 | if [ ! 
-d "/alpa-dist" ] 6 | then 7 | echo "/alpa-dist must be mounted to produce output" 8 | exit 1 9 | fi 10 | 11 | source /python3.8-env/bin/activate 12 | pip install /alpa-dist/jaxlib-alpa-ci/jaxlib-0.3.5+cuda111.cudnn805-cp38-none-manylinux2010_x86_64.whl 13 | pip install jax==0.3.5 14 | 15 | git clone https://github.com/alpa-projects/alpa.git 16 | cd alpa 17 | pip install cupy-cuda111 18 | python -m cupyx.tools.install_library --library nccl --cuda 11.1 19 | pip install -e .[doc] 20 | cd /alpa/docs 21 | make html 22 | cp -r _build/html/* /alpa-dist/docs/ 23 | -------------------------------------------------------------------------------- /docker/scripts/install_cuda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -xe 3 | 4 | CUDA_VERSION=$1 5 | 6 | LIBCUDNN=libcudnn7 7 | if [ $CUDA_VERSION = "10.0" ]; then 8 | CUBLAS=libcublas10 9 | CUBLAS_DEV=libcublas-dev 10 | elif [ $CUDA_VERSION = "10.1" ]; then 11 | # Have to pin to libcublas10=10.2.1.243-1 due to bug in TF, see 12 | # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257 13 | CUBLAS=libcublas10=10.2.1.243-1 14 | CUBLAS_DEV=libcublas-dev=10.2.1.243-1 15 | elif [ $CUDA_VERSION = "10.2" ]; then 16 | CUBLAS=libcublas10 17 | CUBLAS_DEV=libcublas-dev 18 | CUDNN_VERSION=7.6.5.32 19 | elif [ $CUDA_VERSION = "11.0" ]; then 20 | CUBLAS=libcublas-11-0 21 | CUBLAS_DEV=libcublas-dev-11-0 22 | CUDNN_VERSION=8.0.5.39 23 | LIBCUDNN=libcudnn8 24 | elif [ $CUDA_VERSION = "11.1" ]; then 25 | CUBLAS=libcublas-11-1 26 | CUBLAS_DEV=libcublas-dev-11-1 27 | CUDNN_VERSION=8.0.5.39 28 | LIBCUDNN=libcudnn8 29 | elif [ $CUDA_VERSION = "11.2" ]; then 30 | CUBLAS=libcublas-11-2 31 | CUBLAS_DEV=libcublas-dev-11-2 32 | CUDNN_VERSION=8.1.0.77 33 | LIBCUDNN=libcudnn8 34 | elif [ $CUDA_VERSION = "11.3" ]; then 35 | CUBLAS=libcublas-11-3 36 | CUBLAS_DEV=libcublas-dev-11-3 37 | CUDNN_VERSION=8.2.0.53 38 | LIBCUDNN=libcudnn8 39 | elif [ $CUDA_VERSION = "11.4" ]; then 40 | CUBLAS=libcublas-11-4 41 | CUBLAS_DEV=libcublas-dev-11-4 42 | CUDNN_VERSION=8.2.2.26 43 | LIBCUDNN=libcudnn8 44 | else 45 | echo "Unsupported CUDA version: $CUDA_VERSION" 46 | exit 1 47 | fi 48 | 49 | echo "Installing cuda version: $CUDA_VERSION" 50 | echo "cudnn version: $CUDNN_VERSION" 51 | 52 | apt-key adv --keyserver keyserver.ubuntu.com --recv-keys A4B469963BF863CC 53 | apt-get update 54 | apt-get remove -y --allow-change-held-packages -f cuda-license-10-0 libnccl-dev libcudnn7 libcudnn8 libnccl2 55 | apt-get install -y --no-install-recommends --allow-downgrades \ 56 | $CUBLAS \ 57 | $CUBLAS_DEV \ 58 | cuda-nvml-dev-$CUDA_VERSION \ 59 | cuda-command-line-tools-$CUDA_VERSION \ 60 | cuda-libraries-dev-$CUDA_VERSION \ 61 | cuda-minimal-build-$CUDA_VERSION \ 62 | $LIBCUDNN=$CUDNN_VERSION-1+cuda$CUDA_VERSION \ 63 | $LIBCUDNN-dev=$CUDNN_VERSION-1+cuda$CUDA_VERSION 64 | rm -f /usr/local/cuda 65 | ln -s /usr/local/cuda-$CUDA_VERSION /usr/local/cuda 66 | -------------------------------------------------------------------------------- /docker/scripts/install_torch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -xe 3 | 4 | install_torch_deps() { 5 | # NOTE: functorch is pinned to the last commit that works with PyTorch 1.12 6 | pip install --extra-index-url https://download.pytorch.org/whl/cpu torch==1.12 torchdistx && \ 7 | ([ -d "functorch" ] || git clone https://github.com/pytorch/functorch) && \ 8 | pushd functorch && git checkout 
76976db8412b60d322c680a5822116ba6f2f762a && python setup.py install && popd 9 | } 10 | 11 | install_torch_deps 12 | -------------------------------------------------------------------------------- /docker/scripts/test_alpa_docker_entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -xev 3 | if [ ! -d "/alpa-dist" ] 4 | then 5 | echo "/alpa-dist must be mounted to produce output" 6 | exit 1 7 | fi 8 | 9 | usage() { 10 | echo "usage: ${0##*/} [3.7|3.8|3.9] [alpa-branch]" 11 | exit 1 12 | } 13 | 14 | if [[ $# -lt 2 ]] 15 | then 16 | usage 17 | fi 18 | 19 | export PY_VERSION=$1 20 | ALPA_BRANCH="$2" 21 | 22 | # Enter python env 23 | source /python${PY_VERSION}-env/bin/activate 24 | # switch to the merge commit 25 | git clone https://github.com/alpa-projects/alpa.git 26 | cd /build/alpa 27 | git fetch origin +${ALPA_BRANCH} 28 | git checkout -qf FETCH_HEAD 29 | 30 | # install jaxlib and jax 31 | pip install /alpa-dist/jaxlib-alpa-ci/jaxlib-0.3.22+cuda111.cudnn805-cp38-cp38-manylinux2014_x86_64.whl 32 | pip install jax==0.3.22 33 | 34 | # install cupy 35 | pip install cupy-cuda111 36 | python -m cupyx.tools.install_library --library nccl --cuda 11.1 37 | pip install -e .[dev] 38 | ray start --head 39 | cd tests 40 | python run_all.py 41 | -------------------------------------------------------------------------------- /docker/unittest.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM gcr.io/tensorflow-testing/nosla-cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython 2 | 3 | WORKDIR / 4 | SHELL ["/bin/bash", "-c"] 5 | RUN rm -f /etc/apt/sources.list.d/jonathonf-ubuntu-python-3_6-xenial.list 6 | # Fetch latest pub key so apt-get works. 7 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub 8 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub 9 | RUN apt-get update 10 | RUN apt-get install -y python3-virtualenv 11 | RUN virtualenv --python=python3.7 python3.7-env 12 | RUN virtualenv --python=python3.8 python3.8-env 13 | RUN virtualenv --python=python3.9 python3.9-env 14 | 15 | # We pin numpy to the minimum permitted version to avoid compatibility issues. 
16 | RUN source python3.7-env/bin/activate && pip install --upgrade pip \ 17 | && pip install numpy==1.20 setuptools wheel six auditwheel \ 18 | tqdm scipy numba pulp tensorstore prospector yapf coverage cmake \ 19 | pybind11 ray[default] matplotlib transformers uvicorn fastapi 20 | RUN source python3.8-env/bin/activate && pip install --upgrade pip \ 21 | && pip install numpy==1.20 setuptools wheel six auditwheel \ 22 | tqdm scipy numba pulp tensorstore prospector yapf coverage cmake \ 23 | pybind11 ray[default] matplotlib transformers uvicorn fastapi 24 | RUN source python3.9-env/bin/activate && pip install --upgrade pip \ 25 | && pip install numpy==1.20 setuptools wheel six auditwheel \ 26 | tqdm scipy numba pulp tensorstore prospector yapf coverage cmake \ 27 | pybind11 ray[default] matplotlib transformers uvicorn fastapi 28 | 29 | # Install PyTorch dependencies 30 | WORKDIR / 31 | COPY scripts/install_torch.sh /install_torch.sh 32 | RUN chmod +x /install_torch.sh 33 | RUN source python3.7-env/bin/activate && /install_torch.sh 34 | RUN source python3.8-env/bin/activate && /install_torch.sh 35 | RUN source python3.9-env/bin/activate && /install_torch.sh 36 | 37 | # We determine the CUDA version at `docker build ...` phase 38 | ARG JAX_CUDA_VERSION=11.1 39 | COPY scripts/install_cuda.sh /install_cuda.sh 40 | RUN chmod +x /install_cuda.sh 41 | RUN /bin/bash -c 'if [[ ! "$CUDA_VERSION" =~ ^$JAX_CUDA_VERSION.*$ ]]; then \ 42 | /install_cuda.sh $JAX_CUDA_VERSION; \ 43 | fi' 44 | 45 | # Install cupy 46 | RUN source python3.7-env/bin/activate && pip install cupy-cuda${JAX_CUDA_VERSION//.} 47 | RUN source python3.8-env/bin/activate && pip install cupy-cuda${JAX_CUDA_VERSION//.} 48 | RUN source python3.9-env/bin/activate && pip install cupy-cuda${JAX_CUDA_VERSION//.} 49 | 50 | WORKDIR / 51 | COPY scripts/test_alpa_docker_entrypoint.sh /test_alpa_docker_entrypoint.sh 52 | RUN chmod +x /test_alpa_docker_entrypoint.sh 53 | 54 | WORKDIR /build 55 | ENV TEST_TMPDIR /build 56 | ENTRYPOINT ["/test_alpa_docker_entrypoint.sh"] 57 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | clean: 23 | rm -rf $(BUILDDIR)/* 24 | rm -rf tutorials/ 25 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Alpa Documentation 2 | 3 | ## Build the documentation website 4 | 5 | ### Dependency 6 | ``` 7 | pip3 install sphinx sphinx-rtd-theme sphinx-gallery matplotlib 8 | ``` 9 | 10 | ### Build 11 | ``` 12 | make html 13 | ``` 14 | 15 | The build process will execute all tutorial scripts to generate the gallery. 
16 | This may cause failures if the build machine does not have necessary environment. 17 | This may also result in a very long build time. 18 | You can set `ALPA_TUTORIAL_EXEC_PATTERN` to only execute the files that match the regular expression pattern. 19 | For example, to build one specific file, do 20 | ``` 21 | export ALPA_TUTORIAL_EXEC_PATTERN=filename.py 22 | make html 23 | ``` 24 | To skip execution of all tutorials, do 25 | ``` 26 | export ALPA_TUTORIAL_EXEC_PATTERN=none 27 | make html 28 | ``` 29 | 30 | ### Clean 31 | To remove all generated files: 32 | ``` 33 | make clean 34 | ``` 35 | 36 | ### Serve 37 | Run an HTTP server and visit http://localhost:8000 in your browser. 38 | ``` 39 | python3 -m http.server --d _build/html 40 | ``` 41 | 42 | ### Publish 43 | Clone [alpa-projects.github.io](https://github.com/alpa-projects/alpa-projects.github.io) and make sure you have write access. 44 | 45 | ```bash 46 | export ALPA_SITE_PATH=~/efs/alpa-projects.github.io # update this with your path 47 | ./publish.py 48 | ``` 49 | 50 | ## Add new documentations 51 | Alpa uses [Sphinx](https://www.sphinx-doc.org/en/master/index.html) to generate static documentation website and use [Sphinx-gallery](https://sphinx-gallery.github.io/stable/index.html) to generate gallery examples. 52 | 53 | Your new example should be created under `docs/gallery`. 54 | 55 | ### Define the Order of Tutorials 56 | You can define the order of tutorials with `subsection_order` and 57 | `within_subsection_order` in [`conf.py`](conf.py). 58 | By default, the tutorials within one subsection are sorted by filename. 59 | -------------------------------------------------------------------------------- /docs/architecture/alpa-arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/docs/architecture/alpa-arch.png -------------------------------------------------------------------------------- /docs/architecture/cluster-mesh.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/docs/architecture/cluster-mesh.png -------------------------------------------------------------------------------- /docs/architecture/intra_op_solver.rst: -------------------------------------------------------------------------------- 1 | ===================================== 2 | Code Structure of the Intra-op Solver 3 | ===================================== 4 | 5 | The specific code of the intra-op solver (a.k.a auto-sharding) is scattered 6 | in various files of the project. 7 | This page contains some pointers to key components of the intra-op solver and 8 | help you navigate the complicated code base. 9 | 10 | .. 
note:: 11 | 12 | All the links below are based on alpa v0.2.2 13 | 14 | 15 | Key Pointers 16 | ============ 17 | 18 | - Main entrance: 19 | - python entrance (``run_auto_sharding_pass``): https://github.com/alpa-projects/alpa/blob/181de4f5577a72c9b30525ed3da09e5b2138cc2c/alpa/shard_parallel/auto_sharding.py#L172 20 | - c++ entrance: https://github.com/alpa-projects/tensorflow-alpa/blob/cd865615b9b518bc507fbdc71dc44c7cc76618ac/tensorflow/compiler/xla/service/spmd/auto_sharding.cc#L2124 21 | 22 | - Where the possible sharding strategies are registred: 23 | - for matmul: https://github.com/alpa-projects/tensorflow-alpa/blob/cd865615b9b518bc507fbdc71dc44c7cc76618ac/tensorflow/compiler/xla/service/spmd/auto_sharding_dot_handler.cc#L327-L408 24 | - for elementwise operators: https://github.com/alpa-projects/tensorflow-alpa/blob/cd865615b9b518bc507fbdc71dc44c7cc76618ac/tensorflow/compiler/xla/service/spmd/auto_sharding.cc#L967-L1016 25 | 26 | - Where the ILP solver is called: 27 | - c++ side: https://github.com/alpa-projects/tensorflow-alpa/blob/cd865615b9b518bc507fbdc71dc44c7cc76618ac/tensorflow/compiler/xla/service/spmd/auto_sharding.cc#L2259 28 | - python side: https://github.com/alpa-projects/alpa/blob/181de4f5577a72c9b30525ed3da09e5b2138cc2c/alpa/shard_parallel/auto_sharding.py#L588 29 | 30 | 31 | How to Read and Learn the Code 32 | ============================== 33 | .. _learn-intra-op-solver: 34 | 35 | Run some simple examples 36 | ~~~~~~~~~~~~~~~~~~~~~~~~ 37 | You can run the unit tests under https://github.com/alpa-projects/alpa/tree/v0.2.2/tests/shard_parallel and set break points in the python entrance ``run_auto_sharding_pass``. 38 | You can start from the most basic ones in ``test_basic.py``. 39 | 40 | Inspect the sharding strategy 41 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 42 | You can print the HLO before and after the ``run_auto_sharding_pass``. 43 | 44 | 45 | How to Debug 46 | ============ 47 | - Set global environment variable ``ALPA_DEBUG_PRINT_AS_STRATEGY=1``. This will print the choosen sharding strategy for each instruction and edge costs in a prettier way. 48 | - Check batch dim analysis https://github.com/alpa-projects/tensorflow-alpa/blob/721260d122f096040762b2d226b37e8ab23f74b8/tensorflow/compiler/xla/service/spmd/auto_sharding_util.cc#L857 49 | -------------------------------------------------------------------------------- /docs/architecture/mesh-worker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/docs/architecture/mesh-worker.png -------------------------------------------------------------------------------- /docs/architecture/parallelism-view-and-rationale.rst: -------------------------------------------------------------------------------- 1 | .. 
_rationale: 2 | 3 | Rationale 4 | ========= 5 | test 6 | -------------------------------------------------------------------------------- /docs/benchmark/bench-paper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/docs/benchmark/bench-paper.png -------------------------------------------------------------------------------- /docs/benchmark/benchmark.rst: -------------------------------------------------------------------------------- 1 | Performance Benchmark 2 | ===================== 3 | 4 | The figure below shows the scaling efficiency of Alpa on training models with billions of parameters on an AWS cluster. 5 | The instructions to reproduce the benchmark results are in this `README.md `_. 6 | The explanation of the results can be found in Section 8.1 of the `Alpa paper `_. 7 | 8 | .. figure:: bench-paper.png 9 | :align: center 10 | 11 | .. raw:: html 12 | 13 |

14 | -------------------------------------------------------------------------------- /docs/cluster_setup.md: -------------------------------------------------------------------------------- 1 | # AWS Cluster Setup Guide 2 | 3 | 1. Create a [placement group](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/placement-groups.html) on the AWS Management Console. Choose the `Cluster` placement strategy. This can make sure the interconnection bandwidth among different nodes in the cluster are high. 4 | 2. Create a securiy group on the AWS Management Console (EC2 -> Network & Security -> Security Groups). 5 | 3. Create an [EFS](https://console.aws.amazon.com/efs). This is used as an NFS for all nodes in the cluster. Please add the security group ID of the node you just started (can be found on the AWS Management Console) to the EFS to make sure your node can access the EFS. After that, you need to install the [efs-utils](https://docs.aws.amazon.com/efs/latest/ug/installing-other-distro.html) to mount the EFS on the node: 6 | ```bash 7 | git clone https://github.com/aws/efs-utils 8 | cd efs-utils 9 | ./build-deb.sh 10 | sudo apt-get -y install ./build/amazon-efs-utils*deb 11 | ``` 12 | You can try to mount the EFS on the node by: 13 | ```bash 14 | mkdir -p ~/efs 15 | sudo mount -t efs {Your EFS file system ID}:/ ~/efs 16 | sudo chmod 777 ~/efs 17 | ``` 18 | If this takes forever, make sure you configure the sercurity groups right. 19 | 20 | 21 | Clone the git repos under `~/efs`. 22 | -------------------------------------------------------------------------------- /docs/gallery/tutorials/README.rst: -------------------------------------------------------------------------------- 1 | Alpa Tutorials 2 | ============== 3 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Alpa Documentation 2 | ================== 3 | .. raw:: html 4 | 5 | Star 6 | Fork 7 | 8 |

9 | 10 | Alpa is a system for training and serving large-scale neural networks. 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | :caption: Getting Started 15 | 16 | install.rst 17 | tutorials/quickstart.rst 18 | 19 | .. toctree:: 20 | :maxdepth: 1 21 | :caption: Tutorials 22 | 23 | tutorials/pipeshard_parallelism.rst 24 | tutorials/alpa_vs_pmap.rst 25 | tutorials/opt_serving.rst 26 | tutorials/perf_tuning_guide.rst 27 | tutorials/icml_big_model_tutorial.rst 28 | tutorials/alpa_on_slurm.rst 29 | tutorials/faq.rst 30 | 31 | .. toctree:: 32 | :maxdepth: 1 33 | :caption: Architecture 34 | 35 | architecture/overview.rst 36 | architecture/alpa_compiler_walk_through.rst 37 | architecture/intra_op_solver.rst 38 | 39 | .. toctree:: 40 | :maxdepth: 1 41 | :caption: Benchmark 42 | 43 | benchmark/benchmark.rst 44 | 45 | .. toctree:: 46 | :maxdepth: 1 47 | :caption: Publications 48 | 49 | publications/publications.rst 50 | 51 | .. toctree:: 52 | :maxdepth: 1 53 | :caption: Developer Guide 54 | 55 | developer/developer_guide.rst 56 | -------------------------------------------------------------------------------- /docs/logo/alpa-logo-cropped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/docs/logo/alpa-logo-cropped.png -------------------------------------------------------------------------------- /docs/logo/alpa-logo-cropped.svg: -------------------------------------------------------------------------------- 1 | Alpa -------------------------------------------------------------------------------- /docs/logo/alpa-logo-no-word.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/docs/logo/alpa-logo-no-word.ico -------------------------------------------------------------------------------- /docs/logo/alpa-logo-no-word.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/docs/logo/alpa-logo-no-word.png -------------------------------------------------------------------------------- /docs/logo/alpa-logo.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/docs/logo/alpa-logo.ico -------------------------------------------------------------------------------- /docs/logo/alpa-logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/docs/logo/alpa-logo.jpg -------------------------------------------------------------------------------- /docs/logo/alpa-logo.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/docs/logo/alpa-logo.pdf -------------------------------------------------------------------------------- /docs/logo/alpa-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/docs/logo/alpa-logo.png -------------------------------------------------------------------------------- /docs/logo/alpa-logo.psd: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/docs/logo/alpa-logo.psd -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/publications/publications.rst: -------------------------------------------------------------------------------- 1 | Publications 2 | ============ 3 | 4 | Alpa is developed as a research project with collaborators from multiple institutions. 5 | This page includes references to publications describing the ideas behind Alpa. 6 | 7 | | `Alpa: Automating Inter- and Intra-Operator Parallelism for Distributed Deep Learning `_ 8 | | Lianmin Zheng*, Zhuohan Li*, Hao Zhang*, Yonghao Zhuang, Zhifeng Chen, Yanping Huang, Yida Wang, Yuanzhong Xu, Danyang Zhuo, Eric P. Xing, Joseph E. Gonzalez, Ion Stoica 9 | | *OSDI 2022* 10 | | 11 | | `On Optimizing the Communication of Model Parallelism `_ 12 | | Yonghao Zhuang*, Hexu Zhao*, Lianmin Zheng, Zhuohan Li, Eric P. Xing, Qirong Ho, Joseph E. Gonzalez, Ion Stoica, Hao Zhang 13 | | *MLSys 2023* 14 | | 15 | | `AlpaServe: Statistical Multiplexing with Model Parallelism for Deep Learning Serving `_ 16 | | Zhuohan Li*, Lianmin Zheng*, Yinmin Zhong*, Vincent Liu, Ying Sheng, Xin Jin, Yanping Huang, Zhifeng Chen, Hao Zhang, Joseph E. Gonzalez, Ion Stoica 17 | | *OSDI 2023* 18 | -------------------------------------------------------------------------------- /docs/publish.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import os 4 | from datetime import datetime 5 | 6 | 7 | def run_cmd(cmd): 8 | print(cmd) 9 | os.system(cmd) 10 | 11 | 12 | run_cmd(f"cd $ALPA_SITE_PATH; git pull") 13 | 14 | # (Optional) Remove old files 15 | # run_cmd("rm -rf $ALPA_SITE_PATH/*") 16 | 17 | run_cmd("cp -r _build/html/* $ALPA_SITE_PATH") 18 | 19 | cmd_message = f"Archive {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" 20 | run_cmd( 21 | f"cd $ALPA_SITE_PATH; git add .; git commit -m '{cmd_message}'; git push origin master" 22 | ) 23 | -------------------------------------------------------------------------------- /docs/tutorials/icml_big_model_tutorial.rst: -------------------------------------------------------------------------------- 1 | ICML'22 Big Model Tutorial 2 | ========================== 3 | 4 | Alpa team ran a tutorial on training big models at ICML 2022. 
This tutorial covers background, concepts, and advanced research topics, which will be very helpful for understanding the design of Alpa. 5 | 6 | Recordings, slides, and demos are available at https://sites.google.com/view/icml-2022-big-model 7 | -------------------------------------------------------------------------------- /docs/tutorials/opt_serving.rst: -------------------------------------------------------------------------------- 1 | ../../examples/llm_serving/README.rst -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/examples/__init__.py -------------------------------------------------------------------------------- /examples/gpt2/create_config.py: -------------------------------------------------------------------------------- 1 | from transformers import GPT2Config 2 | 3 | config = GPT2Config.from_pretrained("gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=50256) 4 | config.save_pretrained("./norwegian-gpt2") 5 | -------------------------------------------------------------------------------- /examples/gpt2/train_tokenizer.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer 3 | 4 | # load dataset 5 | dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train") 6 | 7 | # Instantiate tokenizer 8 | tokenizer = ByteLevelBPETokenizer() 9 | 10 | def batch_iterator(batch_size=1000): 11 | for i in range(0, len(dataset), batch_size): 12 | yield dataset[i: i + batch_size]["text"] 13 | 14 | # Customized training 15 | tokenizer.train_from_iterator(batch_iterator(), vocab_size=50256, min_frequency=2, special_tokens=[ 16 | "<s>", 17 | "<pad>", 18 | "</s>", 19 | "<unk>", 20 | "<mask>", 21 | ]) 22 | 23 | # Save files to disk 24 | tokenizer.save("./norwegian-gpt2/tokenizer.json") 25 | -------------------------------------------------------------------------------- /examples/imagenet/configs/default.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2021 The Flax Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | """Default Hyperparameter configuration.""" 29 | 30 | import ml_collections 31 | 32 | 33 | def get_config(): 34 | """Get the default hyperparameter configuration.""" 35 | config = ml_collections.ConfigDict() 36 | 37 | # As defined in the `models` module. 38 | config.model = 'ResNet50' 39 | # `name` argument of tensorflow_datasets.builder() 40 | config.dataset = 'imagenet2012:5.*.*' 41 | 42 | config.learning_rate = 0.1 43 | config.warmup_epochs = 5.0 44 | config.momentum = 0.9 45 | config.batch_size = 128 46 | 47 | config.num_epochs = 100.0 48 | config.log_every_steps = 50 49 | 50 | config.cache = True 51 | config.half_precision = False 52 | 53 | # If num_train_steps==-1 then the number of training steps is calculated from 54 | # num_epochs using the entire dataset. Similarly for steps_per_eval. 55 | config.num_train_steps = -1 56 | config.steps_per_eval = -1 57 | return config 58 | -------------------------------------------------------------------------------- /examples/imagenet/configs/fake_data_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Hyperparameter configuration for Fake data benchmark.""" 16 | 17 | import jax 18 | 19 | from configs import default as default_lib 20 | 21 | 22 | def get_config(): 23 | """Get the hyperparameter configuration for Fake data benchmark.""" 24 | # Override default configuration to avoid duplication of field definition. 25 | config = default_lib.get_config() 26 | config.batch_size = 256 * jax.device_count() 27 | config.half_precision = True 28 | config.num_epochs = 5 29 | 30 | # Previously the input pipeline computed: 31 | # `steps_per_epoch` as input_pipeline.TRAIN_IMAGES // batch_size 32 | config.num_train_steps = 1024 // config.batch_size 33 | # and `steps_per_eval` as input_pipeline.EVAL_IMAGES // batch_size 34 | config.steps_per_eval = 512 // config.batch_size 35 | 36 | return config 37 | -------------------------------------------------------------------------------- /examples/imagenet/configs/tpu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2021 The Flax Authors. 
16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | """Hyperparameter configuration to run the example on TPUs.""" 29 | 30 | import ml_collections 31 | 32 | 33 | def get_config(): 34 | """Get the hyperparameter configuration to train on TPUs.""" 35 | config = ml_collections.ConfigDict() 36 | 37 | # As defined in the `models` module. 38 | config.model = 'ResNet50' 39 | # `name` argument of tensorflow_datasets.builder() 40 | config.dataset = 'imagenet2012:5.*.*' 41 | 42 | config.learning_rate = 0.1 43 | config.warmup_epochs = 5.0 44 | config.momentum = 0.9 45 | 46 | config.num_epochs = 100.0 47 | config.log_every_steps = 100 48 | 49 | # If num_train_steps==-1 then the number of training steps is calculated from 50 | # num_epochs using the entire dataset. Similarly for steps_per_eval. 51 | config.num_train_steps = -1 52 | config.steps_per_eval = -1 53 | 54 | # Consider setting the batch size to max(tpu_chips * 256, 8 * 1024) if you 55 | # train on a larger pod slice. 56 | config.batch_size = 1024 57 | config.cache = True 58 | config.half_precision = True 59 | 60 | return config 61 | -------------------------------------------------------------------------------- /examples/imagenet/configs/v100_x8.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Hyperparameter configuration to run the example on 8 x Nvidia V100 GPUs.""" 16 | 17 | from configs import default as default_lib 18 | 19 | 20 | def get_config(): 21 | """Get the hyperparameter configuration to train on 8 x Nvidia V100 GPUs.""" 22 | # Override default configuration to avoid duplication of field definition. 23 | config = default_lib.get_config() 24 | 25 | config.batch_size = 512 26 | config.cache = True 27 | 28 | return config 29 | -------------------------------------------------------------------------------- /examples/imagenet/configs/v100_x8_mixed_precision.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Hyperparameter configuration to run the example on 8 x Nvidia V100 GPUs.""" 16 | 17 | from configs import default as default_lib 18 | 19 | 20 | def get_config(): 21 | """Get the hyperparameter configuration to train on 8 x Nvidia V100 GPUs.""" 22 | # Override default configuration to avoid duplication of field definition. 23 | config = default_lib.get_config() 24 | 25 | config.batch_size = 2048 26 | config.cache = True 27 | config.half_precision = True 28 | 29 | return config 30 | -------------------------------------------------------------------------------- /examples/imagenet/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Main file for running the ImageNet example. 16 | 17 | This file is intentionally kept short. The majority for logic is in libraries 18 | that can be easily tested and imported in Colab. 19 | """ 20 | 21 | from absl import app 22 | from absl import flags 23 | from absl import logging 24 | from clu import platform 25 | import jax 26 | from ml_collections import config_flags 27 | import tensorflow as tf 28 | 29 | import train 30 | 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_string('workdir', None, 'Directory to store model data.') 35 | config_flags.DEFINE_config_file( 36 | 'config', 37 | None, 38 | 'File path to the training hyperparameter configuration.', 39 | lock_config=True) 40 | 41 | 42 | def main(argv): 43 | if len(argv) > 1: 44 | raise app.UsageError('Too many command-line arguments.') 45 | 46 | # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make 47 | # it unavailable to JAX. 48 | tf.config.experimental.set_visible_devices([], 'GPU') 49 | 50 | #logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) 51 | #logging.info('JAX local devices: %r', jax.local_devices()) 52 | 53 | # Add a note so that we can tell which task is which JAX host. 
54 | # (Depending on the platform task 0 is not guaranteed to be host 0) 55 | #platform.work_unit().set_task_status(f'process_index: {jax.process_index()}, ' 56 | # f'process_count: {jax.process_count()}') 57 | platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, 58 | FLAGS.workdir, 'workdir') 59 | 60 | train.train_and_evaluate(FLAGS.config, FLAGS.workdir) 61 | 62 | 63 | if __name__ == '__main__': 64 | flags.mark_flags_as_required(['config', 'workdir']) 65 | app.run(main) 66 | -------------------------------------------------------------------------------- /examples/llm_serving/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/examples/llm_serving/__init__.py -------------------------------------------------------------------------------- /examples/llm_serving/codegen.py: -------------------------------------------------------------------------------- 1 | """Use huggingface/transformers interface and Alpa backend for distributed inference.""" 2 | import argparse 3 | 4 | import numpy as np 5 | from transformers import AutoTokenizer 6 | 7 | from llm_serving.model.wrapper import get_model 8 | 9 | def main(args): 10 | # Load the tokenizer. 11 | if "codegen" in args.model: 12 | name = args.model.replace("alpa", "Salesforce")\ 13 | .replace("jax", "Salesforce") 14 | tokenizer = AutoTokenizer.from_pretrained(name, padding_side = "left") 15 | tokenizer.pad_token = 50256 16 | generate_params = { 17 | "do_sample": args.do_sample, 18 | "num_beams": args.num_beams, 19 | "num_return_sequences": args.num_return_sequences 20 | } 21 | 22 | # Load the model 23 | model = get_model(model_name=args.model, 24 | path="~/codegen_weights", 25 | batch_size=args.n_prompts, 26 | **generate_params) 27 | 28 | # Generate 29 | prompts = [ 30 | "# This function prints hello world.\n", 31 | "def fib(k):\n # Returns the k-th Fibonacci number.\n", 32 | "def is_prime(n):\n # Return whether n is a prime number.\n", 33 | "def return_len(s):\n # Return the length of s.\n", 34 | ] 35 | prompts = prompts[:args.n_prompts] 36 | 37 | input_ids = tokenizer(prompts, return_tensors="pt", padding="longest").input_ids 38 | 39 | output_ids = model.generate(input_ids=input_ids, 40 | max_length=64, 41 | **generate_params) 42 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, 43 | truncate_before_pattern=[r"\n\n^#", "^'''", "\n\n\n"]) 44 | 45 | # Print results 46 | print("Outputs:\n" + 100 * '-') 47 | for i, output in enumerate(outputs): 48 | print(f"{i}: {output}") 49 | print(100 * '-') 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--model", type=str, default="alpa/codegen-2B-mono") 55 | # help: see https://github.com/salesforce/CodeGen for a list of available models. 
56 | parser.add_argument('--do-sample', action='store_true') 57 | parser.add_argument('--num-beams', type=int, default=1) 58 | parser.add_argument('--num-return-sequences', type=int, default=1) 59 | parser.add_argument('--n-prompts', type=int, default=4) 60 | args = parser.parse_args() 61 | 62 | main(args) 63 | -------------------------------------------------------------------------------- /examples/llm_serving/log_config.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | formatters: 3 | simple: 4 | format: "%(asctime)s | %(levelname)s | %(name)s | %(message)s" 5 | datefmt: "%Y-%m-%d %H:%M:%S" 6 | handlers: 7 | console: 8 | class : logging.StreamHandler 9 | formatter: simple 10 | level : INFO 11 | stream : ext://sys.stdout 12 | file: 13 | class : logging.handlers.TimedRotatingFileHandler 14 | filename: weblogs/llm_serving.website.log 15 | when: "D" 16 | utc: True 17 | formatter: simple 18 | level : INFO 19 | root: 20 | level: INFO 21 | handlers: [console, file] 22 | -------------------------------------------------------------------------------- /examples/llm_serving/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/examples/llm_serving/model/__init__.py -------------------------------------------------------------------------------- /examples/llm_serving/model/opt_utils.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import jax 4 | from jax import xla, jit 5 | from jax.core import Primitive 6 | from jax._src.lib import xla_client as xc 7 | from transformers.generation_utils import dataclass 8 | 9 | 10 | def sync(device_id=0): 11 | jax.devices()[device_id].synchronize_all_activity() 12 | return 13 | 14 | 15 | @dataclass 16 | class TransformerModelConfig: 17 | # hidden size 18 | H: int = 768 19 | # number of layers 20 | L: int = 12 21 | # number of attention heads 22 | n_head: int = 12 23 | seq_len: int = 2048 24 | vocab_size: int = 50272 25 | 26 | 27 | def compute_gpt_tflops_inference_with_padding(batch_size, gen_len, seq_len, 28 | num_layers, hidden_size, 29 | vocab_size, num_gpus, latency): 30 | """This calculation assumes that each code decoded attend to seq_len number tokens.""" 31 | factor = 24 32 | total_flop = factor * batch_size * gen_len * (hidden_size ** 2) * num_layers * \ 33 | (1 + seq_len / (6 * hidden_size)) \ 34 | + 2 * batch_size * gen_len * hidden_size * vocab_size 35 | # Note (Hao): it should be 4 here because of input embedding, but we will 36 | # respect Deepak's eq. instead. 
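    # A rough breakdown of the estimate above (per generated token, per layer):
    #   - the "factor * hidden_size**2" part (24 * H^2) counts the dense matmuls:
    #     QKV + output projections ~ 8 * H^2 flops, the two MLP matmuls ~ 16 * H^2;
    #   - the "(1 + seq_len / (6 * hidden_size))" term adds the attention
    #     score/value matmuls, ~ 4 * H * seq_len flops per token per layer;
    #   - the trailing "2 * batch_size * gen_len * hidden_size * vocab_size" term
    #     is the final projection onto the vocabulary logits.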
37 | tflops = total_flop / latency / num_gpus / 1e12 38 | return tflops 39 | 40 | 41 | def is_power_of_two(n): 42 | return (n != 0) and (n & (n-1) == 0) 43 | 44 | 45 | index_select_p = Primitive("index-select") 46 | 47 | 48 | @partial(jit, static_argnums=(2,)) 49 | def jax_index_select(input, index, dim=0): 50 | return index_select_p.bind(input, index, dim=dim) 51 | 52 | 53 | def _index_select_eval(input, index, dim): 54 | return input 55 | 56 | 57 | def _index_select_translation(c, input, index, dim): 58 | return xc.ops.IndexSelect(input, index, dim) 59 | 60 | 61 | index_select_p.def_abstract_eval(_index_select_eval) 62 | index_select_p.def_impl(partial(xla.apply_primitive, index_select_p)) 63 | xla.translations[index_select_p] = _index_select_translation 64 | -------------------------------------------------------------------------------- /examples/llm_serving/scripts/step_3_convert_to_numpy_weights.py: -------------------------------------------------------------------------------- 1 | """Convert Metaseq's OPT model weights into Alpa numpy weights.""" 2 | import time 3 | 4 | import argparse 5 | import os 6 | 7 | import numpy as np 8 | from llm_serving.scripts.utils import torch_load_cpu 9 | 10 | 11 | def save_numpy(weight_dict, to_folder): 12 | os.makedirs(to_folder, exist_ok=True) 13 | for tensor_name, tensor in weight_dict.items(): 14 | print(f"- Writing tensor {tensor_name} with shape {tensor.shape}") 15 | t = tensor.cpu().detach().numpy() 16 | with open(to_folder + "/" + tensor_name, "wb") as g: 17 | np.save(g, t) 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--ckpt-path", type=str, default="/home/ubuntu/consolidated") 23 | parser.add_argument("--output-folder", type=str, default="/home/ubuntu/opt-175b-np") 24 | args = parser.parse_args() 25 | start_time = time.time() 26 | print("- Reading the weight into memory") 27 | state = torch_load_cpu(args.ckpt_path) 28 | print(f"Done with reading: {time.time() - start_time} seconds") 29 | save_numpy(state["model"], args.output_folder) 30 | -------------------------------------------------------------------------------- /examples/llm_serving/scripts/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from omegaconf.dictconfig import DictConfig 3 | 4 | 5 | def recursively_cast_dictconfigs(cfg): 6 | if isinstance(cfg, DictConfig): 7 | return {k2: recursively_cast_dictconfigs(v2) for k2, v2 in cfg.items()} 8 | else: 9 | return cfg 10 | 11 | 12 | def torch_load_cpu(path): 13 | state = torch.load(path, map_location=torch.device("cpu")) 14 | # If model was trained with fp16, model from loaded state_dict can be moved to fp16 15 | if not isinstance(state, dict): 16 | return state 17 | if "cfg" in state: 18 | state["cfg"] = recursively_cast_dictconfigs(state["cfg"]) 19 | if ( 20 | state["cfg"]["common"]["fp16"] 21 | or state["cfg"]["common"]["memory_efficient_fp16"] 22 | ): 23 | state["model"] = {k: v.half() for k, v in state["model"].items()} 24 | 25 | return state 26 | 27 | 28 | def load_and_pop_last_optimizer_state(pth): 29 | st = torch_load_cpu(pth) 30 | st.pop("last_optimizer_state", None) 31 | return st 32 | -------------------------------------------------------------------------------- /examples/llm_serving/service/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/examples/llm_serving/service/__init__.py -------------------------------------------------------------------------------- /examples/llm_serving/service/constants.py: -------------------------------------------------------------------------------- 1 | """Hyper params for serving Meta's OPT model.""" 2 | from enum import Enum 3 | 4 | # Alpa serve url 5 | ALPA_SERVE_PORT = 20001 6 | ALPA_SERVE_URL = f"window.location.protocol + '//' + window.location.hostname + ':{ALPA_SERVE_PORT}/completions'" 7 | #ALPA_SERVE_URL = f'"completions"' 8 | 9 | # Generation params 10 | NUM_BEAMS = 1 11 | NUM_RETURN_SEQ = 1 12 | 13 | # Authentication params 14 | USE_RECAPTCHA = False 15 | USE_API_KEYS = False 16 | ALLOW_NON_KEY_ACCESS = True 17 | KEYS_FILENAME = "/home/ubuntu/efs/alpa/examples/llm_serving/keys_file.json" 18 | 19 | # Scheduler params 20 | class AuthGroups(Enum): 21 | RECAPTCHA_USER = 1 22 | API_KEY_USER = 2 23 | NON_KEY_USER = 3 24 | 25 | AUTH_GROUP_WEIGHTS = { 26 | AuthGroups.RECAPTCHA_USER: 300, 27 | AuthGroups.API_KEY_USER: 10, 28 | AuthGroups.NON_KEY_USER: 1 29 | } 30 | AUTH_GROUP_SCHEDULER_SCALE = 300 31 | API_KEY_SCHEDULER_SCALE = 100 32 | API_KEY_DEFAULT_WEIGHT = 10 33 | LOGPROBS_PRIORITY_TIME_LIMIT_S = 15 34 | 35 | # Logging params 36 | LOGDIR = "weblogs" 37 | -------------------------------------------------------------------------------- /examples/llm_serving/service/static/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/examples/llm_serving/service/static/img.png -------------------------------------------------------------------------------- /examples/llm_serving/service/utils.py: -------------------------------------------------------------------------------- 1 | """Adapted from Metaseq.""" 2 | import datetime 3 | import logging 4 | import logging.handlers 5 | import os 6 | import sys 7 | 8 | from llm_serving.service.constants import LOGDIR 9 | 10 | 11 | handler = None 12 | 13 | 14 | def build_logger(): 15 | global handler 16 | 17 | formatter = logging.Formatter( 18 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 19 | datefmt="%Y-%m-%d %H:%M:%S", 20 | ) 21 | 22 | # Set the format of root handlers 23 | if not logging.getLogger().handlers: 24 | logging.basicConfig(level=logging.INFO) 25 | logging.getLogger().handlers[0].setFormatter(formatter) 26 | 27 | # Redirect stdout and stderr to loggers 28 | stdout_logger = logging.getLogger("stdout") 29 | stdout_logger.setLevel(logging.INFO) 30 | sl = StreamToLogger(stdout_logger, logging.INFO) 31 | sys.stdout = sl 32 | 33 | stderr_logger = logging.getLogger("stderr") 34 | stderr_logger.setLevel(logging.ERROR) 35 | sl = StreamToLogger(stderr_logger, logging.ERROR) 36 | sys.stderr = sl 37 | 38 | # Get logger 39 | logger = logging.getLogger("alpa.llm_serving") 40 | logger.setLevel(logging.INFO) 41 | 42 | # Add a file handler for all loggers 43 | if handler is None: 44 | os.makedirs(LOGDIR, exist_ok=True) 45 | filename = os.path.join(LOGDIR, f"llm_serving.worker.log") 46 | handler = logging.handlers.TimedRotatingFileHandler( 47 | filename, when='D', utc=True) 48 | handler.setFormatter(formatter) 49 | 50 | for name, item in logging.root.manager.loggerDict.items(): 51 | if isinstance(item, logging.Logger): 52 | item.addHandler(handler) 53 | 54 | return logger 55 | 56 | 57 | class StreamToLogger(object): 58 | """ 
59 | Fake file-like stream object that redirects writes to a logger instance. 60 | """ 61 | def __init__(self, logger, log_level=logging.INFO): 62 | self.terminal = sys.stdout 63 | self.logger = logger 64 | self.log_level = log_level 65 | self.linebuf = '' 66 | 67 | def __getattr__(self, attr): 68 | return getattr(self.terminal, attr) 69 | 70 | def write(self, buf): 71 | temp_linebuf = self.linebuf + buf 72 | self.linebuf = '' 73 | for line in temp_linebuf.splitlines(True): 74 | # From the io.TextIOWrapper docs: 75 | # On output, if newline is None, any '\n' characters written 76 | # are translated to the system default line separator. 77 | # By default sys.stdout.write() expects '\n' newlines and then 78 | # translates them so this is still cross platform. 79 | if line[-1] == '\n': 80 | self.logger.log(self.log_level, line.rstrip()) 81 | else: 82 | self.linebuf += line 83 | 84 | def flush(self): 85 | if self.linebuf != '': 86 | self.logger.log(self.log_level, self.linebuf.rstrip()) 87 | self.linebuf = '' 88 | -------------------------------------------------------------------------------- /examples/llm_serving/test_completions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | 4 | python3 test_completions.py --url http://localhost:20001 5 | python3 test_completions.py --url https://api.alpa.ai --api-key YOUR_KEY 6 | """ 7 | import argparse 8 | 9 | from client import Client 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--url", type=str) 15 | parser.add_argument("--api-key", type=str) 16 | parser.add_argument("--model", type=str, default="default") 17 | args = parser.parse_args() 18 | 19 | client = Client(args.url, api_key=args.api_key, default_model=args.model) 20 | ret = client.completions( 21 | ["Paris is the capital city of", 22 | "Computer science is the study of"] 23 | ) 24 | print(ret) 25 | -------------------------------------------------------------------------------- /examples/llm_serving/test_logprobs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | 4 | python3 test_logprobs.py --url http://localhost:20001 5 | python3 test_logprobs.py --url https://api.alpa.ai --api-key YOUR_KEY 6 | """ 7 | import argparse 8 | import time 9 | 10 | import numpy as np 11 | from scipy.special import softmax 12 | from transformers import AutoTokenizer 13 | 14 | from client import Client 15 | 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--url", type=str) 20 | parser.add_argument("--api-key", type=str) 21 | args = parser.parse_args() 22 | 23 | client = Client(args.url, api_key=args.api_key) 24 | tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False) 25 | tokenizer.add_bos_token = False 26 | 27 | prompts = [ 28 | "Paris is the capital city of France", 29 | "Computer science is the", 30 | ] 31 | 32 | input_ids = tokenizer(prompts, padding="longest").input_ids 33 | top_k = 50 34 | 35 | output = client.logprobs(input_ids, top_k=top_k) 36 | 37 | tic = time.time() 38 | num_tokens = 40 39 | for i in range(num_tokens): 40 | print("=" * 20 + f" Step {i} " + "=" * 20) 41 | for j in range(len(input_ids)): 42 | distribution = np.full((tokenizer.vocab_size + 10), -1e8, dtype=np.float32) 43 | for idx, logprob in zip(output['indices'][j], output['logprobs'][j]): 44 | distribution[idx] = logprob 45 | # distribution = softmax(distribution) 46 | # token = 
np.random.choice(np.arange(len(distribution)), p=distribution) 47 | token = distribution.argmax() 48 | input_ids[j].append(int(token)) 49 | print(tokenizer.decode(input_ids[j], skip_special_tokens=True)) 50 | print("-" * 20) 51 | output = client.logprobs(input_ids, top_k=top_k, cache_id=output["cache_id"]) 52 | time_cost = time.time() - tic 53 | print(f"Generation throughput: {len(prompts) * num_tokens/time_cost:.2f} token/s") 54 | -------------------------------------------------------------------------------- /examples/llm_serving/test_textgen.sh: -------------------------------------------------------------------------------- 1 | # Test the correctness of textgen.py 2 | set -x 3 | 4 | python3 textgen.py --model bigscience/bloom-560m 5 | python3 textgen.py --model jax/bloom-560m 6 | python3 textgen.py --model alpa/bloom-560m 7 | 8 | python3 textgen.py --model facebook/opt-1.3b 9 | python3 textgen.py --model jax/opt-1.3b 10 | python3 textgen.py --model alpa/opt-1.3b 11 | -------------------------------------------------------------------------------- /examples/llm_serving/textgen.py: -------------------------------------------------------------------------------- 1 | """Use huggingface/transformers interface and Alpa backend for distributed inference.""" 2 | import argparse 3 | 4 | import numpy as np 5 | from transformers import AutoTokenizer 6 | 7 | from llm_serving.model.wrapper import get_model 8 | 9 | def main(args): 10 | # Load the tokenizer. 11 | if "opt" in args.model: 12 | # We have to use the 30B version because other versions have some issues. 13 | # The 30B version works for all OPT models. 14 | tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b") 15 | tokenizer.add_bos_token = False 16 | elif "bloom" in args.model: 17 | name = args.model.replace("alpa", "bigscience")\ 18 | .replace("jax", "bigscience") 19 | tokenizer = AutoTokenizer.from_pretrained(name) 20 | 21 | generate_params = { 22 | "do_sample": args.do_sample, 23 | "num_beams": args.num_beams, 24 | "num_return_sequences": args.num_return_sequences 25 | } 26 | 27 | # Load the model 28 | model = get_model(model_name=args.model, 29 | path=args.path, 30 | batch_size=args.n_prompts, 31 | **generate_params) 32 | 33 | # Generate 34 | prompts = [ 35 | "Paris is the capital city of", 36 | "Today is a good day and I'd like to", 37 | "Computer Science studies the area of", 38 | "University of California Berkeley is a public university" 39 | ] 40 | prompts = prompts[:args.n_prompts] 41 | input_ids = tokenizer(prompts, return_tensors="pt", padding="longest").input_ids 42 | output_ids = model.generate(input_ids=input_ids, 43 | max_length=64, 44 | **generate_params) 45 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) 46 | 47 | # Print results 48 | print("Outputs:\n" + 100 * '-') 49 | for i, output in enumerate(outputs): 50 | print(f"{i}: {output}") 51 | print(100 * '-') 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument('--model', type=str, default='alpa/opt-1.3b') 57 | parser.add_argument('--path', type=str, default='~/opt_weights') 58 | parser.add_argument('--do-sample', action='store_true') 59 | parser.add_argument('--num-beams', type=int, default=1) 60 | parser.add_argument('--num-return-sequences', type=int, default=1) 61 | parser.add_argument('--n-prompts', type=int, default=4) 62 | args = parser.parse_args() 63 | 64 | main(args) 65 | -------------------------------------------------------------------------------- 
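As exercised by test_textgen.sh above, the backend used for inference is selected by the prefix of the model name passed to `get_model`. A minimal sketch of the convention (the model size and weight path below are illustrative only):

    from llm_serving.model.wrapper import get_model

    # "facebook/..." / "bigscience/..." -> HuggingFace PyTorch reference backend
    # "jax/..."                         -> single-process JAX backend
    # "alpa/..."                        -> distributed Alpa backend (needs a running Ray cluster)
    model = get_model(model_name="alpa/opt-1.3b", path="~/opt_weights", batch_size=4)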
/examples/llm_serving/textgen_1d.py: -------------------------------------------------------------------------------- 1 | """Use huggingface/transformers interface and Alpa backend for distributed inference.""" 2 | import argparse 3 | import time 4 | 5 | import numpy as np 6 | from transformers import AutoTokenizer 7 | 8 | from llm_serving.model.wrapper_1d import get_model 9 | from llm_serving.model.opt_utils import sync 10 | from alpa.timer import timers 11 | 12 | 13 | def main(args): 14 | # Load the tokenizer. We have to use the 30B version because 15 | # other versions have some issues. The 30B version works for all OPT models. 16 | tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False) 17 | tokenizer.add_bos_token = False 18 | 19 | generate_params = { 20 | "do_sample": args.do_sample, 21 | "max_new_tokens": 128, 22 | # "max_length": 128 23 | } 24 | 25 | # Load the model 26 | model = get_model(model_name=args.model, 27 | path="~/opt_weights", 28 | batch_size=32, 29 | cache_size=4096) 30 | 31 | prompts = [ 32 | "Computer science is the study of computation and", 33 | "Ion Stoica is a Romanian-American computer scientist specializing in", 34 | "The University of California, Berkeley is a public", 35 | "Today is a good day and I want to", 36 | "What is the valuation of Databricks?", 37 | "Paris is the capital city of", 38 | "Which country has the most population?", 39 | "What do you think about the future of Cryptocurrency?", 40 | "What do you think about the meaning of life?", 41 | "Donald Trump is the president of", 42 | "GPT-3 is a large language model that is capable of" 43 | ] 44 | 45 | input_ids = tokenizer(prompts, return_tensors="np", padding="longest").input_ids 46 | 47 | n_warmup = 10 48 | for i in range(n_warmup): 49 | sync() 50 | tic = time.time() 51 | output_ids, latency = model.generate(input_ids, **generate_params) 52 | sync() 53 | elapsed = time.time() - tic 54 | print(f"- It takes {elapsed}, latency: {latency}") 55 | 56 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) 57 | if False: 58 | print("Outputs:\n" + 100 * '-') 59 | for i, output in enumerate(outputs): 60 | print(output_ids[i]) 61 | print(f"{i + 1}: {output}") 62 | print(100 * '-') 63 | 64 | 65 | if __name__ == "__main__": 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument("--model", type=str, default="alpa/opt-1d-1.3b") 68 | parser.add_argument('--do-sample', action='store_true') 69 | args = parser.parse_args() 70 | 71 | main(args) 72 | -------------------------------------------------------------------------------- /examples/mnist/README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 2 | 3 | Adopted from https://github.com/google/flax/tree/main/examples/mnist. 4 | 5 | Use `alpa.parallelize` to parallelize the training loop. 6 | 7 | 1. Run training with all local GPUs in a single machine. 8 | ``` 9 | python3 main.py --workdir=/tmp/mnist --config=configs/default.py --config.batch_size 8192 10 | ``` 11 | See `train.py` for a minimal example of using alpa on a single machine. 12 | 13 | 2. Run training with all GPUs in a ray cluster 14 | ``` 15 | ray start --head 16 | python3 main.py --workdir=/tmp/mnist --config=configs/default.py --config.batch_size 8192 --use_ray 17 | ``` 18 | See `train_ray.py` for a minimal example of using alpa on a ray cluster. 
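In both cases, the core change is decorating the training step with `alpa.parallelize`. A simplified sketch of the pattern (the real loss function and state handling live in `train.py` / `train_ray.py`):

```python
import alpa
import jax
import jax.numpy as jnp
import optax

# alpa.init(cluster="ray")   # only needed for the Ray-cluster variant

@alpa.parallelize
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, batch["image"])
        labels = jax.nn.one_hot(batch["label"], 10)
        return jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=labels))
    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)
```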
19 | 20 | -------------------------------------------------------------------------------- 21 | 22 | ## MNIST classification 23 | 24 | Trains a simple convolutional network on the MNIST dataset. 25 | 26 | You can run this code and even modify it directly in Google Colab, no 27 | installation required: 28 | 29 | https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mnist.ipynb 30 | 31 | ### Requirements 32 | * TensorFlow dataset `mnist` will be downloaded and prepared automatically, if necessary 33 | 34 | ### Example output 35 | 36 | | Name | Epochs | Walltime | Top-1 accuracy | Metrics | Workdir | 37 | | :------ | -----: | :------- | :------------- | :---------- | :---------------------------------------- | 38 | | default | 10 | 7.7m | 99.17% | [tfhub.dev] | [gs://flax_public/examples/mnist/default] | 39 | 40 | [tfhub.dev]: https://tensorboard.dev/experiment/1G9SvrW5RQyojRtMKNmMuQ/#scalars&_smoothingWeight=0®exInput=default 41 | [gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default 42 | 43 | ``` 44 | I0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69 45 | I0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 99.14 46 | ``` 47 | 48 | ### How to run 49 | 50 | `python main.py --workdir=/tmp/mnist --config=configs/default.py` 51 | 52 | #### Overriding Hyperparameter configurations 53 | 54 | MNIST example allows specifying a hyperparameter configuration by the means of 55 | setting `--config` flag. Configuration flag is defined using 56 | [config_flags](https://github.com/google/ml_collections/tree/master#config-flags). 57 | `config_flags` allows overriding configuration fields. This can be done as 58 | follows: 59 | 60 | ```shell 61 | python main.py \ 62 | --workdir=/tmp/mnist --config=configs/default.py \ 63 | --config.learning_rate=0.05 --config.num_epochs=5 64 | ``` 65 | -------------------------------------------------------------------------------- /examples/mnist/configs/default.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Default Hyperparameter configuration.""" 16 | 17 | import ml_collections 18 | 19 | 20 | def get_config(): 21 | """Get the default hyperparameter configuration.""" 22 | config = ml_collections.ConfigDict() 23 | 24 | config.learning_rate = 0.1 25 | config.momentum = 0.9 26 | config.batch_size = 128 27 | config.num_epochs = 10 28 | return config 29 | -------------------------------------------------------------------------------- /examples/mnist/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Main file for running the MNIST example. 16 | 17 | This file is intentionally kept short. The majority of logic is in libraries 18 | than can be easily tested and imported in Colab. 19 | """ 20 | 21 | from absl import app 22 | from absl import flags 23 | from absl import logging 24 | from clu import platform 25 | import jax 26 | from ml_collections import config_flags 27 | import tensorflow as tf 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | flags.DEFINE_string('workdir', None, 'Directory to store model data.') 32 | flags.DEFINE_boolean('use_ray', False, 'Whether to use Ray cluster.') 33 | config_flags.DEFINE_config_file( 34 | 'config', 35 | None, 36 | 'File path to the training hyperparameter configuration.', 37 | lock_config=True) 38 | 39 | 40 | def main(argv): 41 | if len(argv) > 1: 42 | raise app.UsageError('Too many command-line arguments.') 43 | 44 | # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make 45 | # it unavailable to JAX. 46 | tf.config.experimental.set_visible_devices([], 'GPU') 47 | 48 | if FLAGS.use_ray: 49 | import train_ray as train 50 | else: 51 | import train 52 | 53 | train.train_and_evaluate(FLAGS.config, FLAGS.workdir) 54 | 55 | 56 | if __name__ == '__main__': 57 | flags.mark_flags_as_required(['config', 'workdir']) 58 | app.run(main) 59 | -------------------------------------------------------------------------------- /examples/mnist/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | clu==0.0.6 3 | flax==0.3.6 4 | jax==0.2.21 5 | --find-links https://storage.googleapis.com/jax-releases/jax_releases.html 6 | jaxlib==0.1.70+cuda110 # Make sure CUDA version matches the base image. 7 | ml-collections==0.1.0 8 | numpy==1.21.4 9 | optax==0.1.0 10 | tensorflow==2.7.0 11 | tensorflow-datasets==4.4.0 12 | -------------------------------------------------------------------------------- /examples/opt_finetune/README.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | # Fine-tuning OPT Language Models 18 | 19 | ## Instructions 20 | 21 | ### Launch a Ray cluster 22 | 23 | 1. Use the command below to launch ray on a head node 24 | ```ray start --head``` 25 | 2. (Optional) If you have more nodes, connect them to the head node. The command should look like this, but with the ip address and password printed by the previous command. 26 | ```ray start --address='172.31.34.216:6379' --redis-password='5241590000000000'``` 27 | 28 | ### Run training 29 | 30 | **Note**: The command below is tested on AWS p3.16xlarge instances with 8 x 16GB V100 GPUs. 31 | To run on other clusters, please tune the arguments `per_device_train_batch_size/num_micro_batches/operator_parallel/pipeline_parallel` to avoid out-of-memory and achieve a good throughput. 
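As a rough sanity check for these knobs (assuming the script derives the data-parallel degree from the GPUs left over after operator and pipeline parallelism; see `run_clm_flax.py` for the exact rule), the settings below on a single 8-GPU node imply:

```python
num_gpus = 8
operator_parallel = 4    # intra-operator (tensor) parallel degree
pipeline_parallel = 1    # number of pipeline stages
data_parallel = num_gpus // (operator_parallel * pipeline_parallel)   # -> 2
# num_micro_batches = 4 splits each mini-batch into 4 micro-batches,
# trading peak activation memory for throughput.
```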
32 | ``` 33 | python3 run_clm_flax.py \ 34 | --output_dir="./output" \ 35 | --model_name_or_path="facebook/opt-2.7b" \ 36 | --dataset_name="wikitext" \ 37 | --dataset_config_name="wikitext-2-raw-v1" \ 38 | --do_train --do_eval \ 39 | --block_size="1024" \ 40 | --per_device_train_batch_size="20" \ 41 | --per_device_eval_batch_size="20" \ 42 | --num_micro_batches 4 \ 43 | --operator_parallel 4 \ 44 | --pipeline_parallel 1 \ 45 | --dtype="float16" \ 46 | --learning_rate="5e-4" --warmup_steps="2000" \ 47 | --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ 48 | --overwrite_output_dir \ 49 | --num_train_epochs="8" \ 50 | --logging_steps="16" \ 51 | --save_steps="2500" \ 52 | --eval_steps="2500" 53 | ``` 54 | 55 | More documentation coming soon. 56 | 57 | 58 | # Acknowledgement 59 | Adopted from https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling 60 | -------------------------------------------------------------------------------- /examples/opt_finetune/run_125m_shard.sh: -------------------------------------------------------------------------------- 1 | python3 run_clm_flax.py \ 2 | --output_dir="./output" \ 3 | --model_name_or_path="facebook/opt-125m" \ 4 | --dataset_name="wikitext" \ 5 | --dataset_config_name="wikitext-2-raw-v1" \ 6 | --do_train --do_eval \ 7 | --block_size="1024" \ 8 | --per_device_train_batch_size="20" \ 9 | --per_device_eval_batch_size="20" \ 10 | --num_micro_batches 4 \ 11 | --operator_parallel 4 \ 12 | --pipeline_parallel 1 \ 13 | --dtype="float16" \ 14 | --learning_rate="5e-4" --warmup_steps="2000" \ 15 | --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ 16 | --overwrite_output_dir \ 17 | --num_train_epochs="8" \ 18 | --logging_steps="16" \ 19 | --save_steps="32" \ 20 | --eval_steps="32" 21 | -------------------------------------------------------------------------------- /examples/opt_finetune/run_2.7b_pipe.sh: -------------------------------------------------------------------------------- 1 | python3 run_clm_flax.py \ 2 | --output_dir="./output" \ 3 | --model_name_or_path="facebook/opt-2.7b" \ 4 | --dataset_name="wikitext" \ 5 | --dataset_config_name="wikitext-2-raw-v1" \ 6 | --do_train --do_eval \ 7 | --block_size="1024" \ 8 | --per_device_train_batch_size="64" \ 9 | --per_device_eval_batch_size="64" \ 10 | --num_micro_batches 64 \ 11 | --operator_parallel 1 \ 12 | --pipeline_parallel 2 \ 13 | --dtype="float16" \ 14 | --learning_rate="5e-4" --warmup_steps="2000" \ 15 | --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ 16 | --overwrite_output_dir \ 17 | --num_train_epochs="10" \ 18 | --logging_steps="5" \ 19 | --save_steps="40" \ 20 | --eval_steps="25" 21 | -------------------------------------------------------------------------------- /examples/opt_finetune/run_2.7b_shard.sh: -------------------------------------------------------------------------------- 1 | python3 run_clm_flax.py \ 2 | --output_dir="./output" \ 3 | --model_name_or_path="facebook/opt-2.7b" \ 4 | --dataset_name="wikitext" \ 5 | --dataset_config_name="wikitext-2-raw-v1" \ 6 | --do_train --do_eval \ 7 | --block_size="1024" \ 8 | --per_device_train_batch_size="20" \ 9 | --per_device_eval_batch_size="20" \ 10 | --num_micro_batches 4 \ 11 | --operator_parallel 4 \ 12 | --pipeline_parallel 1 \ 13 | --dtype="float16" \ 14 | --learning_rate="5e-4" --warmup_steps="2000" \ 15 | --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ 16 | --overwrite_output_dir \ 17 | --num_train_epochs="8" \ 18 | --logging_steps="16" \ 19 | 
--save_steps="2500" \ 20 | --eval_steps="2500" 21 | -------------------------------------------------------------------------------- /examples/setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from setuptools import find_packages, setup 3 | 4 | setup(name="llm_serving", 5 | packages=find_packages()) 6 | -------------------------------------------------------------------------------- /examples/slurm_script_examples/test_cuda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=test_cuda 3 | #SBATCH -N 1 4 | #SBATCH -p GPU-shared 5 | #SBATCH -t 1:00 6 | #SBATCH --gpus=v100-16:1 7 | 8 | #import modules 9 | module purge 10 | module load cuda 11 | module load nvhpc 12 | 13 | #check environments 14 | echo $CUDA_HOME 15 | nvcc --version 16 | 17 | #exit 18 | -------------------------------------------------------------------------------- /examples/slurm_script_examples/test_prerequisites.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=test_alpa_prerequisites 3 | #SBATCH -p GPU-shared 4 | #SBATCH -t 1:00 5 | #SBATCH --gpus=v100-16:1 6 | 7 | module load cuda 8 | module load cudnn 9 | module load nvhpc 10 | 11 | nvcc --version 12 | -------------------------------------------------------------------------------- /examples/slurm_script_examples/test_ray_multinode.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=ray_multinode_test 3 | #SBATCH --cpus-per-task=16 4 | #SBATCH --mem-per-cpu=1GB 5 | #SBATCH --ntasks-per-node=1 6 | gpus_per_node=0 7 | # load modules 8 | module purge 9 | conda init bash 10 | source ~/.bashrc 11 | # start conda 12 | conda activate alpa_environment 13 | # environment activated, check environment 14 | python3 -V 15 | python3 -c "from cupy.cuda import nccl" 16 | # Getting the node names 17 | nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") 18 | nodes_array=($nodes) 19 | 20 | head_node=${nodes_array[0]} 21 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) 22 | 23 | # if we detect a space character in the head node IP, we'll 24 | # convert it to an ipv4 address. This step is optional. 25 | if [[ "$head_node_ip" == *" "* ]]; then 26 | IFS=' ' read -ra ADDR <<<"$head_node_ip" 27 | if [[ ${#ADDR[0]} -gt 16 ]]; then 28 | head_node_ip=${ADDR[1]} 29 | else 30 | head_node_ip=${ADDR[0]} 31 | fi 32 | echo "IPV6 address detected. 
We split the IPV4 address as $head_node_ip" 33 | fi 34 | 35 | # start head node 36 | port=6789 37 | ip_head=$head_node_ip:$port 38 | export ip_head 39 | 40 | srun --nodes=1 --ntasks=1 -w "$head_node" \ 41 | ray start --head --node-ip-address="$head_node_ip" --port=$port \ 42 | --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus $gpus_per_node --block & 43 | 44 | # start worker nodes 45 | # number of nodes other than the head node 46 | worker_num=$((SLURM_JOB_NUM_NODES - 1)) 47 | 48 | for ((i = 1; i <= worker_num; i++)); do 49 | node_i=${nodes_array[$i]} 50 | echo "Starting WORKER $i at $node_i" 51 | srun --nodes=1 --ntasks=1 -w "$node_i" \ 52 | ray start --address "$ip_head" --num-cpus "${SLURM_CPUS_PER_TASK}" \ 53 | --num-gpus $gpus_per_node --block & 54 | sleep 5 55 | done 56 | # try ray 57 | echo "test ray status" 58 | ray list nodes --address "$ip_head" 59 | ray list nodes 60 | ray list actors 61 | ray summary tasks 62 | # end ray 63 | ray stop 64 | # exit environment 65 | conda deactivate 66 | exit 67 | -------------------------------------------------------------------------------- /examples/slurm_script_examples/textgen_alpa_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=ray_singlenode_test 3 | # load modules 4 | module purge 5 | module load cuda 6 | module load nvhpc 7 | conda init bash 8 | source ~/.bashrc 9 | # test nvcc 10 | nvcc --version 11 | # start environment using conda 12 | conda activate alpa_environment 13 | # start ray on head 14 | ray start --head 15 | # start alpa textgen.py 16 | python3 alpa/examples/llm_serving/textgen.py --model alpa/bloom-560m --n-prompts 1 --path $PROJECT/alpa_weights 17 | # end ray 18 | ray stop 19 | # exit environment 20 | conda deactivate 21 | exit 22 | -------------------------------------------------------------------------------- /examples/slurm_script_examples/textgen_pt_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=ray_singlenode_test 3 | # load modules 4 | module purge 5 | module load cuda 6 | module load nvhpc 7 | conda init bash 8 | source ~/.bashrc 9 | # test nvcc 10 | nvcc --version 11 | # start environment using conda 12 | conda activate alpa_environment 13 | # start ray on head 14 | ray start --head 15 | # start alpa textgen.py 16 | python3 alpa/examples/llm_serving/textgen.py --model facebook/opt-125m --n-prompts 1 --path $PROJECT/alpa_weights 17 | # end ray 18 | ray stop 19 | # exit environment 20 | conda deactivate 21 | exit 22 | -------------------------------------------------------------------------------- /playground/alpa_micro_benchmark/test_shard_array.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.interpreters import pxla 4 | from jax.interpreters.pxla import (ShardingSpec, 5 | NoSharding, Replicated, Chunked, ShardedAxis) 6 | import numpy as np 7 | import ray 8 | 9 | import alpa 10 | 11 | def benchmark(physical_mesh, shape, sharding_spec): 12 | avals = [] 13 | shard_indices = [] 14 | sharding_specs = [] 15 | donated_invars = [] 16 | args = [] 17 | 18 | number = 2 19 | 20 | for i in range(number): 21 | array = jnp.ones(shape, jnp.float32) 22 | indices = sharding_spec.indices(array.shape) 23 | 24 | avals.append(jax.ShapedArray(array.shape, array.dtype)) 25 | sharding_specs.append(sharding_spec) 26 | shard_indices.append(indices.flatten()) 27 | donated_invars.append(True) 
28 | args.append(array) 29 | 30 | print(sharding_spec) 31 | buffers = physical_mesh.shard_args_to_bufs(shard_indices, donated_invars, args) 32 | 33 | return buffers 34 | 35 | 36 | if __name__ == "__main__": 37 | ray.init(address="auto") 38 | 39 | cluster = alpa.DeviceCluster() 40 | physical_mesh = cluster.get_physical_mesh() 41 | 42 | shape = (8192, 8192) 43 | 44 | sharding_specs = [ 45 | ShardingSpec( 46 | sharding=[NoSharding(), NoSharding(),], 47 | mesh_mapping=[Replicated(8),]), 48 | ShardingSpec( 49 | sharding=[Chunked([8]), NoSharding(),], 50 | mesh_mapping=[ShardedAxis(0),]), 51 | ShardingSpec( 52 | sharding=[NoSharding(), Chunked([8])], 53 | mesh_mapping=[ShardedAxis(0),]), 54 | ShardingSpec( 55 | sharding=[Chunked([2]), Chunked([4])], 56 | mesh_mapping=[ShardedAxis(0), ShardedAxis(1)]), 57 | ] 58 | 59 | for spec in sharding_specs: 60 | benchmark(physical_mesh, shape, spec) 61 | 62 | -------------------------------------------------------------------------------- /playground/auto_sharding_solver/README.md: -------------------------------------------------------------------------------- 1 | # A Prototype of Auto-sharding Solver 2 | 3 | This is only a prototype in python. It is not used by alpa. 4 | 5 | ## Requirements 6 | ``` 7 | pip3 install pulp 8 | ``` 9 | 10 | ## Examples 11 | ``` 12 | python3 test_solver_mlp.py 13 | ``` 14 | -------------------------------------------------------------------------------- /playground/auto_sharding_solver/common.py: -------------------------------------------------------------------------------- 1 | """Common Utilities""" 2 | 3 | import numpy as np 4 | 5 | 6 | def append_flatten_elements(result, array, indices, cur_depth, cur_indices): 7 | """Append elements of `array` to `result`. The `indices` is a generalized 8 | multi-dimensional index that can index a whole row (use -1 to indicate this)""" 9 | if cur_depth == len(array.shape) - 1: 10 | result.append(array[tuple(cur_indices)]) 11 | else: 12 | next_depth = cur_depth + 1 13 | index = indices[next_depth] 14 | 15 | if index == -1: 16 | for i in range(array.shape[next_depth]): 17 | cur_indices[next_depth] = i 18 | append_flatten_elements(result, array, indices, next_depth, cur_indices) 19 | else: 20 | cur_indices[next_depth] = index 21 | append_flatten_elements(result, array, indices, next_depth, cur_indices) 22 | 23 | 24 | def get_dim_last_value(array, dim): 25 | """Get the value of the last element in a dimension""" 26 | indices = tuple(0 if i != dim else array.shape[dim] - 1 for i in range(len(array.shape))) 27 | return array[indices] 28 | 29 | 30 | def transpose_flatten(array, shape, dimensions): 31 | """Transpose a flatten array""" 32 | array = np.array(array) 33 | return np.array(np.transpose(array.reshape(shape), dimensions)).flatten() 34 | 35 | 36 | def reshape_flatten(array, shape, new_shape): 37 | """Reshape a flatten array""" 38 | array = np.array(array) 39 | return np.array(array.reshape(shape)).flatten() 40 | 41 | 42 | def compute_bytes(shape): 43 | return np.prod(shape) * 4 44 | 45 | -------------------------------------------------------------------------------- /playground/auto_sharding_solver/run_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 -m unittest -bv *.py 4 | 5 | -------------------------------------------------------------------------------- /playground/auto_sharding_solver/test_cost.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 
from cluster_env import ClusterEnvironment 4 | 5 | def s(*shape): 6 | return np.prod(shape) * 4 7 | 8 | env = ClusterEnvironment(np.ones((8, 1)), [1, 1], [0.02, 0.02], 0) 9 | 10 | a = env.all_reduce_cost(s(16, 14, 14, 8192)) + env.all_reduce_cost(s(16, 28, 28, 2048)) + \ 11 | env.all_to_all_cost(s(16, 28, 28, 4096)) 12 | 13 | print(a) 14 | 15 | 16 | b = env.all_gather_cost(s(16, 28, 28, 4096)) + env.all_gather_cost(s(1, 1, 4096, 8192)) 17 | print(b) 18 | 19 | 20 | -------------------------------------------------------------------------------- /playground/jax_basic/test_flop_count.py: -------------------------------------------------------------------------------- 1 | import jax, jax.numpy as jnp 2 | 3 | def func(a, b): 4 | c = jnp.asarray(a, jnp.int32) @ jnp.asarray(b, jnp.int32) 5 | #c = a @ b 6 | c = c.transpose() 7 | c += a 8 | return c 9 | 10 | a = jnp.ones((100, 100)) 11 | b = jnp.ones((100, 100)) 12 | 13 | m = jax.xla_computation(func)(a, b).as_hlo_module() 14 | print(m.to_string()) 15 | r = jax.lib.xla_client._xla.hlo_module_count_flop_dot_conv_only(m) 16 | print(r) 17 | 18 | -------------------------------------------------------------------------------- /playground/jax_basic/test_jit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax 3 | from jax import numpy as jnp 4 | 5 | def test_jit_cache(): 6 | 7 | @jax.jit 8 | def add_one(x): 9 | return x + 1 10 | 11 | a = jnp.ones(10) 12 | 13 | print(add_one(a)) 14 | print(add_one(a)) 15 | print(add_one(a)) 16 | 17 | 18 | def test_cache_closure(): 19 | outer_scope = [0] 20 | 21 | @jax.jit 22 | def add_one(x): 23 | print('call add_one') 24 | return x + outer_scope[0] 25 | 26 | a = jnp.ones(10) 27 | 28 | print(add_one(a)) 29 | print(add_one(a)) 30 | outer_scope[0] = 1 31 | print(add_one(a)) 32 | 33 | 34 | 35 | def test_non_jit(): 36 | a = jnp.array(np.ones(10)) 37 | b = jnp.array(np.ones(10)) 38 | c = a + b 39 | c = a + c 40 | c = a + c 41 | 42 | print(c) 43 | 44 | 45 | if __name__ == "__main__": 46 | #test_jit_cache() 47 | test_cache_closure() 48 | #test_non_jit() 49 | 50 | -------------------------------------------------------------------------------- /playground/jax_basic/test_memory_allocator.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import jax 4 | from jax import numpy as jnp 5 | 6 | def run_cmd(x): 7 | os.system(x) 8 | 9 | def test_platform_allocator(): 10 | os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" 11 | #os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 12 | 13 | a = jnp.ones(1 << 30) 14 | 15 | run_cmd("nvidia-smi") 16 | 17 | a = None 18 | 19 | run_cmd("nvidia-smi") 20 | 21 | 22 | if __name__ == "__main__": 23 | test_platform_allocator() 24 | 25 | -------------------------------------------------------------------------------- /playground/jax_basic/test_pmap.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import jax 3 | from jax import lax 4 | import jax.numpy as jnp 5 | 6 | 7 | def debug_pmap(): 8 | @jax.pmap 9 | def func(x, w): 10 | return x @ w 11 | 12 | y = func(jnp.ones((2, 4)), jnp.ones((2, 4))) 13 | print(y, type(y)) 14 | 15 | 16 | def test_nested_pmap(): 17 | @partial(jax.pmap, axis_name='a0', in_axes=(0, None), out_axes=0) 18 | def add(a, b): 19 | # a.shape = (32, 64) 20 | # b.shape = (64, 2, 32) 21 | @partial(jax.pmap, axis_name='a1', in_axes=(None, 1), out_axes=1) 22 | def add_inner(x, y): 23 | # 
x.shape = (32, 64) 24 | # y.shape = (64, 32) 25 | return x @ y 26 | 27 | # ret.shape = (32, 2, 32) 28 | ret = add_inner(a, b) 29 | return ret 30 | 31 | a = jnp.ones((2, 32, 64)) 32 | b = jnp.ones((64, 2, 32)) 33 | 34 | #jaxpr = jax.make_jaxpr(add)(a, b) 35 | #print(jaxpr) 36 | #print(jaxpr.jaxpr.outvars[0].aval.shape) 37 | 38 | c = add(a, b) 39 | print(c) 40 | 41 | 42 | def test_allreduce_sum(): 43 | @partial(jax.pmap, axis_name='i') 44 | def normalize(x): 45 | return x / lax.psum(x, 'i') 46 | 47 | print(normalize(jnp.arange(2))) 48 | 49 | 50 | if __name__ == "__main__": 51 | #debug_pmap() 52 | #test_nested_pmap() 53 | 54 | test_allreduce_sum() 55 | 56 | -------------------------------------------------------------------------------- /playground/jax_basic/test_scan.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import jax 3 | import jax.numpy as jnp 4 | import numpy as np 5 | from flax import linen as nn 6 | from flax import optim 7 | 8 | batch_size = 32 9 | hidden_size = 128 10 | 11 | class Layer(nn.Module): 12 | @nn.compact 13 | def __call__(self, x): 14 | return 15 | 16 | class Model(nn.Module): 17 | def __call__(self, x): 18 | cell = nn.scan( 19 | nn.Dense, 20 | variable_broadcast="params", 21 | in_axes=1, 22 | out_axes=1, 23 | split_rngs={"params": False}, 24 | ) 25 | 26 | @partial(jax.jit, static_argnums=(2,)) 27 | def train_step(optimizer, batch, apply_fn): 28 | def loss_func(params): 29 | out = apply_fn(params, batch["x"]) 30 | return jnp.mean((out - batch["y"]) ** 2) 31 | 32 | grad = jax.grad(loss_func)(optimizer.target) 33 | new_optimizer = optimizer.apply_gradient(grad) 34 | return new_optimizer 35 | 36 | x = jnp.ones((batch_size, hidden_size)) 37 | y = jnp.ones((batch_size, hidden_size)) 38 | 39 | # Init model and optimizer 40 | model = Model() 41 | rngkey = jax.random.PRNGKey(0) 42 | params = model.init(rngkey, x) 43 | optimizer = optim.GradientDescent(1e-2).create(params) 44 | 45 | # JIT compile 46 | optimizer = train_step(optimizer, {"x": x, "y": y}, model.apply) 47 | 48 | -------------------------------------------------------------------------------- /playground/jax_basic/test_tuple_args.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax import numpy as jnp 3 | 4 | 5 | @jax.pmap 6 | def many_args(*args): 7 | x = 0 8 | for i in range(len(args)): 9 | x += args[i] 10 | return x 11 | 12 | N = 110 13 | 14 | args = [ 15 | jnp.ones((4, 10)) for _ in range(N) 16 | ] 17 | 18 | out = many_args(*args) 19 | print(out) 20 | 21 | -------------------------------------------------------------------------------- /playground/jax_basic/test_while.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import jax 3 | import jax.numpy as jnp 4 | import numpy as np 5 | from flax import linen as nn 6 | from flax import optim 7 | 8 | batch_size = 32 9 | hidden_size = 128 10 | 11 | class Model(nn.Module): 12 | def setup(self): 13 | self.weight = self.param("weight", 14 | jax.nn.initializers.zeros, (hidden_size, hidden_size)) 15 | 16 | def __call__(self, x): 17 | def cond_func(args): 18 | counter = args[0] 19 | return counter < 5 20 | 21 | def body_func(args): 22 | counter, x = args 23 | return [counter + 1, x @ self.weight] 24 | 25 | return jax.lax.while_loop(cond_func, body_func, [0, x])[1] 26 | 27 | @partial(jax.jit, static_argnums=(2,)) 28 | def train_step(optimizer, batch, apply_fn): 29 | def 
loss_func(params): 30 | out = apply_fn(params, batch["x"]) 31 | return jnp.mean((out - batch["y"]) ** 2) 32 | 33 | grad = jax.grad(loss_func)(optimizer.target) 34 | new_optimizer = optimizer.apply_gradient(grad) 35 | return new_optimizer 36 | 37 | x = jnp.ones((batch_size, hidden_size)) 38 | y = jnp.ones((batch_size, hidden_size)) 39 | 40 | # Init model and optimizer 41 | model = Model() 42 | rngkey = jax.random.PRNGKey(0) 43 | params = model.init(rngkey, x) 44 | optimizer = optim.GradientDescent(1e-2).create(params) 45 | 46 | # JIT compile 47 | optimizer = train_step(optimizer, {"x": x, "y": y}, model.apply) 48 | 49 | -------------------------------------------------------------------------------- /playground/jax_basic/util.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | 5 | def benchmark_func(func, warmup=1, repeat=3): 6 | for i in range(warmup): 7 | func() 8 | 9 | costs = [] 10 | for i in range(repeat): 11 | tic = time.time() 12 | func() 13 | costs.append(time.time() - tic) 14 | 15 | return np.array(costs) 16 | 17 | -------------------------------------------------------------------------------- /playground/other/test_ray_dataloader.py: -------------------------------------------------------------------------------- 1 | import ray 2 | import jax 3 | 4 | import input_pipeline 5 | 6 | @ray.remote 7 | class Worker: 8 | def __init__(self): 9 | self.generator = None 10 | 11 | def register_generator(self, func): 12 | self.generator = iter(func()) 13 | 14 | def get_next(self): 15 | return next(self.generator) 16 | 17 | 18 | def make_generator(): 19 | import tensorflow as tf 20 | import tensorflow_datasets as tfds 21 | 22 | dataset_builder = tfds.builder('imagenet2012:5.*.*') 23 | batch_size = 64 24 | image_size = 224 25 | dtype = tf.float32 26 | train = True 27 | cache = True 28 | 29 | ds = input_pipeline.create_split( 30 | dataset_builder, batch_size, image_size=image_size, dtype=dtype, 31 | train=train, cache=cache) 32 | it = map(lambda xs: jax.tree_map(lambda x: x._numpy(), xs), ds) 33 | return it 34 | 35 | 36 | if __name__ == "__main__": 37 | ray.init(address="auto") 38 | 39 | worker = Worker.remote() 40 | 41 | worker.register_generator.remote(make_generator) 42 | 43 | x = ray.get(worker.get_next.remote()) 44 | print(x.keys()) 45 | print(x['image'].shape) 46 | -------------------------------------------------------------------------------- /playground/other/test_ray_put.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import jax 4 | import ray 5 | import numpy as np 6 | 7 | MB = 1024**2 8 | GB = 1024**3 9 | 10 | 11 | def benchmark_ray(x): 12 | array = np.ones((x,), dtype=np.float32) 13 | warmup = 0 14 | number = 1 15 | 16 | # warm up 17 | for i in range(warmup): 18 | ray.put(array) 19 | 20 | # benchmark 21 | tic = time.time() 22 | for i in range(number): 23 | ray.put(array) 24 | cost = time.time() - tic 25 | 26 | size = np.prod(array.shape) * array.dtype.itemsize 27 | bandwidth = size / (cost / number) 28 | print(f"size: {size/MB:.2f} MB, bandwidth: {bandwidth/MB:.2f} MB") 29 | 30 | 31 | def benchmark_jax_put(x): 32 | batch = np.ones((x,), dtype=np.float32) 33 | 34 | # warm up 35 | for i in range(2): 36 | tmp = jax.device_put(batch) 37 | tmp.block_until_ready() 38 | 39 | # benchmark 40 | tic = time.time() 41 | y = [None] * 10 42 | for i in range(10): 43 | y[i] = jax.device_put(batch) 44 | #y[i] = None 45 | #y[i].block_until_ready() 46 | 
print(f"size: {x}, time: {time.time() - tic:.2f}") 47 | 48 | 49 | for i in [1, 64, 128, 512, 1024]: 50 | benchmark_ray(i * MB) 51 | for i in [1, 64, 128, 512, 1024]: 52 | benchmark_ray(i * MB) 53 | for i in [1, 64, 128, 512, 1024]: 54 | benchmark_ray(i * MB) 55 | 56 | #for i in range(10): 57 | # benchmark_jax_put(8192 * 28 * 28 * 1) 58 | -------------------------------------------------------------------------------- /playground/other/test_remote_call_cost.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from alpa.device_mesh import Mesh 4 | import numpy as np 5 | import ray 6 | 7 | ray.init(address="auto") 8 | worker = ray.remote(num_gpus=1)(Worker).remote() 9 | 10 | latencies = [] 11 | for i in range(1000): 12 | tic = time.time() 13 | ray.get(worker.check_alive.remote()) 14 | latency = time.time() - tic 15 | print(f"{i}, latency: {latency * 1e3:.2f} ms") 16 | -------------------------------------------------------------------------------- /playground/other/test_torch_ddp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m torch.distributed.launch --nproc_per_node 1 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 11000 test_torch_ddp.py 4 | """ 5 | import torch 6 | import torch.optim as optim 7 | from torch import nn 8 | from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP 9 | #from torch.nn.parallel import DataParallel as torchDDP 10 | 11 | class Net(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | self.net1 = nn.Linear(1 << 10, 1 << 19) 16 | self.net2 = nn.Linear(1 << 19, 1) 17 | 18 | def forward(self, x): 19 | return self.net2(self.net1(x)) 20 | 21 | 22 | GB = 1024 ** 3 23 | 24 | def get_memory_usage(print_info=False): 25 | """Get accurate gpu memory usage by querying torch runtime""" 26 | rank = torch.distributed.get_rank() 27 | device = rank % torch.cuda.device_count() 28 | allocated = torch.cuda.memory_allocated(device) 29 | reserved = torch.cuda.memory_reserved(device) 30 | if print_info: 31 | print("allocated: %.2f GB" % (allocated / GB), flush=True) 32 | print("reserved: %.2f GB" % (reserved / GB), flush=True) 33 | return allocated 34 | 35 | torch.distributed.init_process_group(backend="nccl", world_size=1) 36 | 37 | raw_model = Net().cuda() 38 | 39 | print("After init model", get_memory_usage() / GB) 40 | model = torchDDP(raw_model, device_ids=[0], output_device=0, gradient_as_bucket_view=True) 41 | optimizer = optim.SGD(model.parameters(), lr=0.001) 42 | 43 | print("After torchDDP", get_memory_usage() / GB) 44 | 45 | data = torch.ones((1, 1<<10)).cuda() 46 | label = torch.ones((1,)).cuda() 47 | 48 | optimizer.zero_grad() 49 | loss = torch.square(model(data) - label).sum() 50 | loss.backward() 51 | optimizer.step() 52 | 53 | print("After first backward", get_memory_usage() / GB) 54 | 55 | optimizer.zero_grad() 56 | loss = torch.square(model(data) - label).sum() 57 | loss.backward() 58 | optimizer.step() 59 | print("After second backward", get_memory_usage() / GB) 60 | 61 | -------------------------------------------------------------------------------- /playground/other/test_torch_trace.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | N = 2 4 | H = 4 5 | 6 | loss_func = torch.nn.MSELoss() 7 | model = torch.nn.Linear(H, H) 8 | 9 | def func(data, target, *params): 10 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 11 | 12 | y 
= model(data) 13 | loss = loss_func(y, target) 14 | 15 | print(y) 16 | 17 | loss.backward() 18 | return loss 19 | 20 | data = torch.ones((N, H)) 21 | target = torch.ones((N, H)) 22 | 23 | model_params = tuple(model.parameters()) 24 | func(*((data, target,) + model_params)) 25 | model_grads = tuple(x.grad for x in model_params) 26 | 27 | graph, output = torch.jit._get_trace_graph(func, (data, target) + model_params + model_grads) 28 | -------------------------------------------------------------------------------- /playground/pipeline/jax_array_slicing.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy 3 | from jax import core, xla 4 | from jax._src.util import (partial, unzip3) 5 | from jax.abstract_arrays import array_types 6 | from jax.interpreters import pxla 7 | from jax.interpreters.pxla import (ShardingSpec, Chunked, NoSharding, Replicated, 8 | ShardedAxis, _as_slice_indices, _hashable_index, ShardedDeviceArray) 9 | import numpy as np 10 | from jax.lib import xla_client, xla_bridge 11 | import jax.numpy as jnp 12 | from alpa.util import jax_buffer_set, jax_buffer_set_v2 13 | 14 | 15 | offset = [0, 4] 16 | m = jnp.zeros([10, 10], dtype=np.float32) 17 | print(m.__cuda_array_interface__) 18 | n = jnp.ones([2, 2], dtype=np.float32) 19 | print(n.__cuda_array_interface__) 20 | k = jax_buffer_set_v2(m, n, tuple(offset)) 21 | print(k.__cuda_array_interface__) 22 | print(k) 23 | -------------------------------------------------------------------------------- /playground/pipeline/test_ray_jax_array.py: -------------------------------------------------------------------------------- 1 | # check gpu devices 2 | import os 3 | 4 | import jax.numpy as jnp 5 | import ray 6 | 7 | 8 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" 9 | ray.init(num_gpus=2, num_cpus=4) 10 | 11 | 12 | @ray.remote(num_gpus=1, num_cpus=2) 13 | class Runner: 14 | def __init__(self, name): 15 | print("ray.get_gpu_ids(): {}".format(ray.get_gpu_ids())) 16 | print("CUDA_VISIBLE_DEVICES: {}".format(os.environ["CUDA_VISIBLE_DEVICES"])) 17 | self.name = name 18 | self.a = None 19 | self.b = None 20 | 21 | def compute(self): 22 | print(type(self.a)) 23 | print(type(self.b)) 24 | c = jnp.matmul(self.a, self.b) 25 | print(type(c)) 26 | return c 27 | 28 | def set(self, refs): 29 | arrays = ray.get(refs) 30 | print(arrays) 31 | # a = ray.get(a_ref) 32 | # print(a) 33 | # print(type(a)) 34 | self.a = jnp.asarray(arrays[0]) 35 | # b = ray.get(b_ref) 36 | # print(b) 37 | # print(type(b)) 38 | self.b = jnp.asarray(arrays[1]) 39 | 40 | 41 | workers = [] 42 | workers.append(Runner.remote(name="0")) 43 | workers.append(Runner.remote(name="1")) 44 | 45 | a = jnp.ones([3, 4]) 46 | b = jnp.ones([4, 5]) 47 | a_ref = ray.put(a) 48 | b_ref = ray.put(b) 49 | worker = workers[0] 50 | worker.set.remote([a_ref, b_ref]) 51 | c_ref = worker.compute.remote() 52 | c_result = ray.get(c_ref) 53 | 54 | worker = workers[1] 55 | worker.set.remote([a_ref, b_ref]) 56 | c_ref = worker.compute.remote() 57 | c_result = ray.get(c_ref) 58 | print(c_result) 59 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Unit test 2 | 3 | ## Requirement 4 | A machine with at least 4 gpus. 5 | 6 | ## Run all test cases 7 | 8 | 1. Start a ray cluster 9 | ``` 10 | ray start --head 11 | ``` 12 | 13 | 2. 
Run all tests 14 | ``` 15 | python3 run_all.py 16 | ``` 17 | 18 | ## Run specific files 19 | 20 | - For debug usage: 21 | ``` 22 | python3 shard_parallel/test_basic.py 23 | ``` 24 | 25 | - More similar to how CI runs files 26 | ``` 27 | # Run one file 28 | python3 run_all.py --run-pattern shard_parallel/test_basic.py 29 | 30 | # Run a folder 31 | python3 run_all.py --run-pattern shard_parallel 32 | ``` 33 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alpa-projects/alpa/b8078a9f75cb4c90cabb4550ee48c99ef394e209/tests/__init__.py -------------------------------------------------------------------------------- /tests/killall_python.sh: -------------------------------------------------------------------------------- 1 | kill -9 $(ps aux | grep 'python3' | grep -v 'grep' | awk '{print $2}') 2 | -------------------------------------------------------------------------------- /tests/pipeline_parallel/test_dynamic_programming.py: -------------------------------------------------------------------------------- 1 | """Test dynamic programming.""" 2 | 3 | import numpy as np 4 | import unittest 5 | 6 | import alpa 7 | from alpa.pipeline_parallel.stage_construction import (training_dp as 8 | stage_construction_dp, 9 | get_submesh_choices) 10 | from alpa.testing import assert_allclose 11 | 12 | 13 | class DynamicProgrammingTest(unittest.TestCase): 14 | """Test dynamic programming.""" 15 | 16 | def test_stage_construction(self): 17 | """Test stage construction.""" 18 | num_layers = 8 19 | num_hosts = 1 20 | num_devices_per_host = 8 21 | num_devices = num_hosts * num_devices_per_host 22 | num_micro_batches = 16 23 | num_autosharding_configs = 1 24 | for i in range(1, num_devices + 1): 25 | if num_devices % i == 0: 26 | num_autosharding_configs += 1 27 | submesh_choices = get_submesh_choices(num_hosts, num_devices_per_host, 28 | "all") 29 | num_submesh_choices = len(submesh_choices) 30 | np.random.seed(42) 31 | compute_cost = np.random.rand(num_layers, num_layers, 32 | num_submesh_choices, 33 | num_autosharding_configs) 34 | max_n_succ_stages = np.full( 35 | (num_layers, num_layers, num_submesh_choices, 36 | num_autosharding_configs), 4096) 37 | alpa.util._DISABLE_NUMBA = False 38 | numba_cost, _ = stage_construction_dp(num_layers, num_devices, 39 | num_micro_batches, 40 | submesh_choices, 41 | num_autosharding_configs, 42 | compute_cost, max_n_succ_stages) 43 | alpa.util._DISABLE_NUMBA = True 44 | no_numba_cost, _ = stage_construction_dp( 45 | num_layers, num_devices, num_micro_batches, submesh_choices, 46 | num_autosharding_configs, compute_cost, max_n_succ_stages) 47 | assert_allclose(numba_cost, no_numba_cost) 48 | # Note(zhuohan): The profiling here suggest that the numba jitted 49 | # version is ~250x faster than the non-jitted version. Therefore, 50 | # we highly recommend to use the numba version, but for smaller 51 | # problem sizes, the non-jitted version is also acceptable. 
52 | 53 | 54 | def suite(): 55 | suite = unittest.TestSuite() 56 | suite.addTest(unittest.makeSuite(DynamicProgrammingTest)) 57 | return suite 58 | 59 | 60 | if __name__ == "__main__": 61 | runner = unittest.TextTestRunner() 62 | runner.run(suite()) 63 | -------------------------------------------------------------------------------- /tests/pipeline_parallel/test_global_norm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax 4 | from jax import numpy as jnp, lax 5 | from jax._src.tree_util import tree_map 6 | from optax import global_norm 7 | 8 | from alpa import grad 9 | from alpa.testing import PipelineBasicTest 10 | 11 | 12 | class GlobalNormTest(PipelineBasicTest): 13 | 14 | def test_global_norm(self): 15 | hlos = self.run_n_layer_bert(num_layers=2, 16 | manual_pipeline_layer=False, 17 | clip_by_global_norm=True) 18 | for x in hlos[-2:]: 19 | assert "CrossMeshAllReduce" in x 20 | 21 | @unittest.skip("No data to test efficiently.") 22 | def test_dynamic_scale(self): 23 | hlos = self.run_n_layer_bert(num_layers=2, 24 | manual_pipeline_layer=False, 25 | use_dynamic_scale=True) 26 | 27 | @unittest.skip("No data to test efficiently.") 28 | def test_global_norm_dynamic_scale(self): 29 | hlos = self.run_n_layer_bert(num_layers=2, 30 | manual_pipeline_layer=False, 31 | clip_by_global_norm=True, 32 | use_dynamic_scale=True) 33 | 34 | def test_glob_norm_and_all_le(self): 35 | 36 | def train_step(state, batch): 37 | 38 | def loss_func(params): 39 | out = state.apply_fn(params, batch["x"], 40 | batch["attention_mask"]) 41 | loss = jnp.mean((out - batch["y"])**2) 42 | return loss 43 | 44 | grads = grad(loss_func)(state.params) 45 | glob_norm = global_norm(grads) 46 | new_grads = tree_map(lambda g: g / glob_norm, grads) 47 | new_state = state.apply_gradients(grads=new_grads) 48 | 49 | ls_1 = jnp.array(True) 50 | for g in jax.tree_util.tree_leaves(grads): 51 | ls_1 &= jnp.all(lax.le(g, 1.)) 52 | return new_state, (new_grads, ls_1) 53 | 54 | hlos = self.run_n_layer_bert(num_layers=2, inject_train_step=train_step) 55 | for x in hlos[-2:]: 56 | assert 'backend_config="SUM;' in x 57 | assert 'backend_config="AND;' in x 58 | assert x.count("CrossMeshAllReduce") == 2 59 | 60 | 61 | def suite(): 62 | suite = unittest.TestSuite() 63 | suite.addTest(GlobalNormTest("test_global_norm")) 64 | suite.addTest(GlobalNormTest("test_dynamic_scale")) 65 | suite.addTest(GlobalNormTest("test_global_norm_dynamic_scale")) 66 | suite.addTest(GlobalNormTest("test_glob_norm_and_all_le")) 67 | return suite 68 | 69 | 70 | if __name__ == '__main__': 71 | runner = unittest.TextTestRunner() 72 | runner.run(suite()) 73 | -------------------------------------------------------------------------------- /tests/pipeline_parallel/test_layer_construction.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax 4 | from alpa.testing import PipelineBasicTest 5 | 6 | 7 | class LayerConstructionTest(PipelineBasicTest): 8 | 9 | def test_mlp_layer_construction(self): 10 | self.run_mlp(manual_pipeline_layer=False) 11 | 12 | def test_2_layer_bert_layer_construction(self): 13 | self.run_n_layer_bert(num_layers=2, manual_pipeline_layer=False) 14 | 15 | @unittest.skipIf(jax.device_count('gpu') < 8, "no enough device") 16 | def test_8_layer_bert_layer_construction(self): 17 | self.run_n_layer_bert(num_layers=8, manual_pipeline_layer=False) 18 | 19 | 20 | def suite(): 21 | suite = unittest.TestSuite() 22 | 
suite.addTest(LayerConstructionTest('test_mlp_layer_construction')) 23 | suite.addTest(LayerConstructionTest('test_2_layer_bert_layer_construction')) 24 | suite.addTest(LayerConstructionTest('test_8_layer_bert_layer_construction')) 25 | return suite 26 | 27 | 28 | if __name__ == "__main__": 29 | runner = unittest.TextTestRunner() 30 | runner.run(suite()) 31 | -------------------------------------------------------------------------------- /tests/pipeline_parallel/test_multi_graph.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import unittest 5 | 6 | from alpa import init, parallelize, global_config, PipeshardParallel 7 | from alpa.testing import assert_allclose, get_mlp_train_state_and_step 8 | 9 | 10 | class MultipleGraphRuntimeTest(unittest.TestCase): 11 | 12 | def setUp(self): 13 | init(cluster="ray") 14 | 15 | def run_2_mlp(self, use_value_and_grad=False, stage_option="uniform"): 16 | 17 | def test_one_mlp(method, batch_size=64, hidden_size=16): 18 | state, batch, train_step = get_mlp_train_state_and_step( 19 | batch_size=batch_size, 20 | hidden_size=hidden_size, 21 | add_manual_pipeline_marker=True) 22 | 23 | # Compile 24 | serial_train_step = train_step 25 | parallel_train_step = parallelize(train_step, method=method) 26 | executable = parallel_train_step.get_executable(state, batch) 27 | 28 | # Run and check 29 | expected_new_state, expected_val = serial_train_step(state, batch) 30 | actual_new_state, actual_val = parallel_train_step(state, batch) 31 | 32 | assert_allclose(expected_new_state.params, actual_new_state.params, 33 | 1e-3, 1e-3) 34 | assert_allclose(expected_val, actual_val, 1e-3, 1e-3) 35 | 36 | return executable 37 | 38 | method = PipeshardParallel(num_micro_batches=2, 39 | stage_option=stage_option, 40 | layer_option="manual") 41 | executable = test_one_mlp(method) 42 | executable_2 = test_one_mlp(method) 43 | 44 | assert executable != executable_2 45 | 46 | def test_2_mlp(self): 47 | self.run_2_mlp() 48 | 49 | 50 | def suite(): 51 | suite = unittest.TestSuite() 52 | suite.addTest(MultipleGraphRuntimeTest('test_2_mlp')) 53 | return suite 54 | 55 | 56 | if __name__ == "__main__": 57 | runner = unittest.TextTestRunner() 58 | runner.run(suite()) 59 | -------------------------------------------------------------------------------- /tests/pipeline_parallel/test_remat.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax 4 | from alpa.testing import PipelineBasicTest 5 | 6 | 7 | class PipelineRematTest(PipelineBasicTest): 8 | 9 | def test_mlp_remat(self): 10 | self.run_mlp(use_remat=True) 11 | 12 | def test_2_layer_bert_remat(self): 13 | self.run_n_layer_bert(num_layers=2, use_remat=True) 14 | 15 | def test_2_layer_bert_auto_layer_slicing_remat(self): 16 | self.run_n_layer_bert(num_layers=2, 17 | manual_pipeline_layer=False, 18 | use_remat=True) 19 | 20 | @unittest.skipIf(jax.local_device_count("gpu") < 8, "no enough device") 21 | def test_8_layer_bert_auto_layer_slicing_remat(self): 22 | self.run_n_layer_bert(num_layers=8, 23 | manual_pipeline_layer=False, 24 | use_remat=True) 25 | 26 | 27 | def suite(): 28 | suite = unittest.TestSuite() 29 | suite.addTest(PipelineRematTest('test_mlp_remat')) 30 | suite.addTest(PipelineRematTest('test_2_layer_bert_remat')) 31 | suite.addTest( 32 | PipelineRematTest('test_2_layer_bert_auto_layer_slicing_remat')) 33 | suite.addTest( 34 | 
PipelineRematTest('test_8_layer_bert_auto_layer_slicing_remat')) 35 | return suite 36 | 37 | 38 | if __name__ == "__main__": 39 | runner = unittest.TextTestRunner() 40 | runner.run(suite()) 41 | -------------------------------------------------------------------------------- /tests/pipeline_parallel/test_scatter_gather.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from alpa.device_mesh import (get_global_cluster, 4 | set_global_virtual_physical_mesh) 5 | from alpa.pipeline_parallel.stage_construction import ManualStageOption 6 | from alpa.testing import PipelineBasicTest 7 | 8 | 9 | class ScatterGatherTest(PipelineBasicTest): 10 | 11 | def test_2_layer_bert(self): 12 | virtual_mesh = get_global_cluster().get_virtual_physical_mesh([0], 4) 13 | set_global_virtual_physical_mesh(virtual_mesh) 14 | 15 | stage_option = ManualStageOption( 16 | forward_stage_layer_ids=[[0], [1]], 17 | submesh_physical_shapes=[(1, 2), (1, 2)], 18 | submesh_logical_shapes=[(1, 2), (2, 1)], 19 | submesh_autosharding_option_dicts=[ 20 | dict(force_batch_dim_to_mesh_dim=0), {} 21 | ]) 22 | 23 | self.run_n_layer_bert(num_layers=2, 24 | batch_size=4, 25 | seq_len=4, 26 | hidden_size=4, 27 | num_heads=1, 28 | stage_option=stage_option) 29 | 30 | 31 | def suite(): 32 | suite = unittest.TestSuite() 33 | suite.addTest(ScatterGatherTest('test_2_layer_bert')) 34 | return suite 35 | 36 | 37 | if __name__ == "__main__": 38 | runner = unittest.TextTestRunner() 39 | runner.run(suite()) 40 | -------------------------------------------------------------------------------- /tests/pipeline_parallel/test_stage_construction.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from alpa.pipeline_parallel.stage_construction import AutoStageOption 4 | from alpa.testing import PipelineBasicTest 5 | 6 | 7 | def auto_stage(): 8 | return AutoStageOption(submesh_physical_shape_space="small_power_of_two", 9 | submesh_logical_shape_space="same_as_physical") 10 | 11 | 12 | class StageConstructionTest(PipelineBasicTest): 13 | 14 | def test_mlp_stage_construction(self): 15 | self.run_mlp(stage_option=auto_stage()) 16 | 17 | def test_mlp_layer_and_stage(self): 18 | self.run_mlp(manual_pipeline_layer=False, stage_option=auto_stage()) 19 | 20 | 21 | def suite(): 22 | suite = unittest.TestSuite() 23 | suite.addTest(StageConstructionTest('test_mlp_stage_construction')) 24 | suite.addTest(StageConstructionTest('test_mlp_layer_and_stage')) 25 | return suite 26 | 27 | 28 | if __name__ == "__main__": 29 | runner = unittest.TextTestRunner() 30 | runner.run(suite()) 31 | -------------------------------------------------------------------------------- /tests/pipeline_parallel/test_stage_construction_slow.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from alpa.pipeline_parallel.stage_construction import AutoStageOption 4 | from alpa.testing import PipelineBasicTest 5 | 6 | 7 | def auto_stage(): 8 | return AutoStageOption(submesh_physical_shape_space="small_power_of_two", 9 | submesh_logical_shape_space="same_as_physical") 10 | 11 | 12 | class StageConstructionSlowTest(PipelineBasicTest): 13 | 14 | def test_mlp_stage_construction(self): 15 | self.run_mlp(stage_option=auto_stage()) 16 | 17 | def test_mlp_layer_and_stage(self): 18 | self.run_mlp(manual_pipeline_layer=False, stage_option=auto_stage()) 19 | 20 | def test_2_layer_bert_stage_construction(self): 21 | 
self.run_n_layer_bert(num_layers=2, stage_option=auto_stage()) 22 | 23 | def test_2_layer_bert_layer_and_stage(self): 24 | self.run_n_layer_bert(num_layers=2, 25 | manual_pipeline_layer=False, 26 | stage_option=auto_stage()) 27 | 28 | def test_8_layer_bert_stage_construction(self): 29 | self.run_n_layer_bert(num_layers=8, stage_option=auto_stage()) 30 | 31 | def test_8_layer_bert_layer_and_stage(self): 32 | self.run_n_layer_bert(num_layers=8, 33 | manual_pipeline_layer=False, 34 | stage_option=auto_stage()) 35 | 36 | 37 | def suite(): 38 | suite = unittest.TestSuite() 39 | suite.addTest(StageConstructionSlowTest('test_mlp_stage_construction')) 40 | suite.addTest(StageConstructionSlowTest('test_mlp_layer_and_stage')) 41 | suite.addTest( 42 | StageConstructionSlowTest('test_2_layer_bert_stage_construction')) 43 | suite.addTest( 44 | StageConstructionSlowTest('test_2_layer_bert_layer_and_stage')) 45 | suite.addTest( 46 | StageConstructionSlowTest('test_8_layer_bert_stage_construction')) 47 | suite.addTest( 48 | StageConstructionSlowTest('test_8_layer_bert_layer_and_stage')) 49 | return suite 50 | 51 | 52 | if __name__ == "__main__": 53 | runner = unittest.TextTestRunner() 54 | runner.run(suite()) 55 | -------------------------------------------------------------------------------- /tests/pipeline_parallel/test_tied_embedding.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | 4 | from flax import linen as nn 5 | import jax 6 | import jax.numpy as jnp 7 | import optax 8 | 9 | from alpa import (init, parallelize, mark_pipeline_boundary, grad, 10 | PipeshardParallel) 11 | from alpa.model.model_util import TrainState 12 | from alpa.testing import assert_allclose 13 | 14 | 15 | class PipelineTiedEmbeddingTest(unittest.TestCase): 16 | 17 | def setUp(self): 18 | init(cluster="ray") 19 | 20 | def train_tied_embedding(self, method): 21 | vocab_size = 256 22 | hidden_size = 16 23 | batch_size = 8 24 | seq_len = 8 25 | 26 | class Model(nn.Module): 27 | """Tied input and output embedding.""" 28 | 29 | def setup(self): 30 | self.embed = nn.Embed(vocab_size, hidden_size) 31 | 32 | def __call__(self, x): 33 | x = self.embed(x) 34 | mark_pipeline_boundary() 35 | embed = self.embed.variables["params"]["embedding"] 36 | x = x @ embed.T 37 | return x 38 | 39 | def train_step(state, batch): 40 | 41 | def loss_func(params): 42 | out = state.apply_fn(params, batch["x"]) 43 | y_ = jax.nn.one_hot(batch["y"], out.shape[-1]) 44 | loss = -jnp.sum(y_ * jax.nn.log_softmax(out, axis=-1), 45 | axis=-1).sum() 46 | return loss 47 | 48 | grads = grad(loss_func)(state.params) 49 | return state.apply_gradients(grads=grads) 50 | 51 | x = jnp.ones((batch_size, seq_len), jnp.int32) 52 | y = jnp.ones((batch_size, seq_len), jnp.int32) 53 | 54 | # Init model and optimizer 55 | model = Model() 56 | rngkey = jax.random.PRNGKey(0) 57 | params = model.init(rngkey, x) 58 | tx = optax.adam(learning_rate=1e-2) 59 | state = TrainState.create(apply_fn=model.apply, 60 | params=params, 61 | tx=tx, 62 | dynamic_scale=None) 63 | 64 | # Run and check results 65 | p_train_step = parallelize(train_step, method=method) 66 | batch = {"x": x, "y": y} 67 | expected_new_state = train_step(state, batch) 68 | actual_new_state = p_train_step(state, batch) 69 | assert_allclose(actual_new_state.params, expected_new_state.params) 70 | 71 | def test_tied_embedding_pipeshard_parallel(self): 72 | method = PipeshardParallel(num_micro_batches=2, layer_option="manual") 73 | 
self.train_tied_embedding(method) 74 | 75 | 76 | def suite(): 77 | suite = unittest.TestSuite() 78 | suite.addTest( 79 | PipelineTiedEmbeddingTest("test_tied_embedding_pipeshard_parallel")) 80 | return suite 81 | 82 | 83 | if __name__ == '__main__': 84 | runner = unittest.TextTestRunner() 85 | runner.run(suite()) 86 | -------------------------------------------------------------------------------- /tests/runtime/test_cross_mesh_communicator.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import ray 4 | from alpa import init 5 | from alpa.device_mesh import ( 6 | create_and_record_cross_mesh_collective_communicators, get_global_cluster) 7 | from alpa.pipeline_parallel.stage_construction import get_sliced_virtual_submeshes 8 | from alpa.util import mesh_ids_hash 9 | 10 | 11 | class CrossMeshCollectiveCommunicatorTest(unittest.TestCase): 12 | 13 | def setUp(self) -> None: 14 | init("ray") 15 | 16 | def test_create_and_set(self): 17 | virtual_mesh = get_global_cluster().get_virtual_physical_mesh( 18 | host_ids=[0], num_devices_per_host=4) 19 | submesh_shapes = [(1, 2)] * 2 20 | sliced_virtual_meshes = get_sliced_virtual_submeshes( 21 | virtual_mesh, submesh_shapes) 22 | virtual_mesh.get_physical_mesh_group(sliced_virtual_meshes) 23 | mesh_group = virtual_mesh.launched_physical_mesh_group 24 | meshes = mesh_group.meshes 25 | key = mesh_ids_hash([0, 1]) 26 | ray.get( 27 | create_and_record_cross_mesh_collective_communicators(meshes, key)) 28 | 29 | 30 | def suite(): 31 | suite = unittest.TestSuite() 32 | suite.addTest(CrossMeshCollectiveCommunicatorTest("test_create_and_set")) 33 | 34 | return suite 35 | 36 | 37 | if __name__ == "__main__": 38 | runner = unittest.TextTestRunner() 39 | runner.run(suite()) 40 | -------------------------------------------------------------------------------- /tests/runtime/test_debug_info.py: -------------------------------------------------------------------------------- 1 | """Test the debug information dummping.""" 2 | import os 3 | import unittest 4 | 5 | from alpa import (init, parallelize, ShardParallel, PipeshardParallel, 6 | AutoLayerOption, global_config) 7 | from alpa.pipeline_parallel.stage_construction import get_last_dp_result 8 | from alpa.device_mesh import get_global_cluster 9 | from alpa.testing import assert_allclose, get_mlp_train_state_and_step 10 | 11 | 12 | class DebugInfoTest(unittest.TestCase): 13 | 14 | def setUp(self): 15 | os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" 16 | 17 | def test_1_debug_shard_parallel(self): 18 | state, batch, train_step = get_mlp_train_state_and_step(batch_size=128, 19 | hidden_size=128, 20 | num_layers=4) 21 | 22 | # Print auto-sharding intermidiate results 23 | os.environ["ALPA_DEBUG_PRINT_AS_STRATEGY"] = "1" 24 | 25 | p_train_step = parallelize(train_step, 26 | method=ShardParallel(num_micro_batches=2)) 27 | actual_output = p_train_step(state, batch) 28 | executable = p_train_step.get_last_executable() 29 | executable.sync() 30 | 31 | # Dump final HLO and other debug info 32 | executable.dump_debug_info("alpa_debug_info") 33 | 34 | def test_2_debug_pipeline_parallel(self): 35 | init(cluster="ray") 36 | state, batch, train_step = get_mlp_train_state_and_step(batch_size=128, 37 | hidden_size=128, 38 | num_layers=6) 39 | 40 | # Print auto-sharding intermidiate results 41 | global_config.pipeline_distributed_compile = False 42 | os.environ["ALPA_DEBUG_PRINT_AS_STRATEGY"] = "1" 43 | 44 | layer_num = min(get_global_cluster().num_devices, 2) 45 
| p_train_step = parallelize( 46 | train_step, 47 | method=PipeshardParallel( 48 | num_micro_batches=2, 49 | layer_option=AutoLayerOption(layer_num=layer_num))) 50 | actual_output = p_train_step(state, batch) 51 | executable = p_train_step.get_last_executable() 52 | executable.sync() 53 | 54 | # Dump final HLO and other debug info 55 | executable.dump_debug_info("alpa_debug_info") 56 | 57 | # Print auto-stage dynamic programming results if use auto stage partition 58 | print(get_last_dp_result()) 59 | 60 | 61 | def suite(): 62 | s = unittest.TestSuite() 63 | s.addTest(DebugInfoTest("test_1_debug_shard_parallel")) 64 | s.addTest(DebugInfoTest("test_2_debug_pipeline_parallel")) 65 | return s 66 | 67 | 68 | if __name__ == "__main__": 69 | runner = unittest.TextTestRunner() 70 | runner.run(suite()) 71 | -------------------------------------------------------------------------------- /tests/runtime/test_install.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from alpa.test_install import suite 4 | 5 | if __name__ == "__main__": 6 | runner = unittest.TextTestRunner() 7 | runner.run(suite()) 8 | -------------------------------------------------------------------------------- /tests/runtime/test_memory_leak.py: -------------------------------------------------------------------------------- 1 | """Test whether there is any memory leak for distributed arrays and remote buffers.""" 2 | import unittest 3 | 4 | import ray 5 | 6 | from alpa import (init, shutdown, parallelize, global_config, ShardParallel, 7 | PipeshardParallel) 8 | from alpa.device_mesh import get_global_cluster 9 | from alpa.test_install import get_mlp_train_state_and_step 10 | 11 | 12 | class MemoryLeakTest(unittest.TestCase): 13 | 14 | def setUp(self): 15 | init() 16 | global_config.delete_remote_arrays_threshold = 0 17 | 18 | def tearDown(self): 19 | shutdown() 20 | 21 | def test_shard_parallel(self): 22 | state, batch, train_step = get_mlp_train_state_and_step(batch_size=128, 23 | hidden_size=128) 24 | train_step = parallelize(train_step, 25 | method=ShardParallel(num_micro_batches=2)) 26 | 27 | for i in range(2): 28 | state, loss = train_step(state, batch) 29 | del loss 30 | del state 31 | 32 | # Assert all buffers are freed 33 | executable = train_step.get_last_executable() 34 | for w in executable.physical_mesh.workers: 35 | # One loss array cannot be deleted due to python's GC behavior 36 | assert len(ray.get(w.get_live_buffer_uuids.remote())) <= 1 37 | 38 | def test_pipeline_parallel(self): 39 | state, batch, train_step = get_mlp_train_state_and_step( 40 | batch_size=128, hidden_size=128, add_manual_pipeline_marker=True) 41 | 42 | layer_num = min(get_global_cluster().num_devices, 2) 43 | train_step = parallelize( 44 | train_step, 45 | method=PipeshardParallel(num_micro_batches=2, 46 | layer_option="manual")) 47 | 48 | for i in range(2): 49 | state, loss = train_step(state, batch) 50 | del loss 51 | del state 52 | 53 | # Assert all buffers are freed 54 | executable = train_step.get_last_executable() 55 | for mesh in executable.mesh_group: 56 | for w in mesh.workers: 57 | assert len(ray.get(w.get_live_buffer_uuids.remote())) == 0 58 | 59 | 60 | def suite(): 61 | suite = unittest.TestSuite() 62 | suite.addTest(MemoryLeakTest("test_shard_parallel")) 63 | suite.addTest(MemoryLeakTest("test_pipeline_parallel")) 64 | return suite 65 | 66 | 67 | if __name__ == "__main__": 68 | runner = unittest.TextTestRunner() 69 | runner.run(suite()) 70 | 
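The memory-leak test above also doubles as a recipe for checking buffer release in user code. Below is a minimal, hedged sketch that distills the same pattern into a standalone script; it assumes a local Ray/Alpa setup with at least one GPU and relies on the internal `get_live_buffer_uuids` worker method used by the test, which may differ across Alpa versions.

```
import ray
from alpa import init, shutdown, parallelize, global_config, ShardParallel
from alpa.test_install import get_mlp_train_state_and_step

init()
global_config.delete_remote_arrays_threshold = 0  # free remote arrays eagerly

state, batch, train_step = get_mlp_train_state_and_step(batch_size=128,
                                                        hidden_size=128)
train_step = parallelize(train_step, method=ShardParallel(num_micro_batches=2))

for _ in range(2):
    state, loss = train_step(state, batch)
del loss, state

# As in the test, at most one loss buffer may linger due to Python GC behavior.
executable = train_step.get_last_executable()
for w in executable.physical_mesh.workers:
    assert len(ray.get(w.get_live_buffer_uuids.remote())) <= 1

shutdown()
```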
-------------------------------------------------------------------------------- /tests/runtime/test_save_load.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import time 3 | from tempfile import TemporaryFile 4 | 5 | import ray 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import pickle 10 | import flax 11 | 12 | from alpa import init, parallelize, PipeshardParallel, util 13 | from alpa.testing import get_mlp_train_state_and_step, assert_allclose 14 | 15 | 16 | class SaveLoadTest(unittest.TestCase): 17 | 18 | def setUp(self): 19 | init(cluster="ray") 20 | 21 | def test_mlp_state_load(self): 22 | # Init model 23 | state, batch, train_step = get_mlp_train_state_and_step( 24 | batch_size=128, hidden_size=128, add_manual_pipeline_marker=True) 25 | 26 | # Compile 27 | method = PipeshardParallel(num_micro_batches=2, layer_option="manual") 28 | serial_train_step = train_step 29 | parallel_train_step = parallelize(train_step, method=method) 30 | executable = parallel_train_step.get_executable(state, batch) 31 | 32 | serial_state = state 33 | parallel_state = state 34 | serial_state = serial_train_step(serial_state, batch)[0] 35 | parallel_state = parallel_train_step(parallel_state, batch)[0] 36 | assert_allclose(serial_state.params, parallel_state.params, 1e-3, 1e-3) 37 | 38 | # Save model to a temporary file 39 | outfile = TemporaryFile() 40 | parallel_state_dict = flax.serialization.to_state_dict(parallel_state) 41 | pickle.dump(util.map_to_nparray(parallel_state_dict), outfile) 42 | 43 | # Load model from the temporary file 44 | outfile.seek(0) 45 | loaded_state_dict = pickle.load(outfile) 46 | loaded_state = flax.serialization.from_state_dict( 47 | state, loaded_state_dict) 48 | outfile.close() 49 | 50 | # Compare the loaded state with the original state 51 | assert_allclose(loaded_state.params, serial_state.params, 1e-3, 1e-3) 52 | assert_allclose(loaded_state.params, parallel_state.params, 1e-3, 1e-3) 53 | 54 | # Take a step with the loaded state on both serial and parallel version 55 | serial_state = serial_train_step(serial_state, batch)[0] 56 | parallel_state = parallel_train_step(parallel_state, batch)[0] 57 | serial_loaded_state = serial_train_step(loaded_state, batch)[0] 58 | parallel_loaded_state = parallel_train_step(loaded_state, batch)[0] 59 | 60 | # All the states should be the same 61 | assert_allclose(serial_state.params, parallel_state.params, 1e-3, 1e-3) 62 | assert_allclose(serial_state.params, serial_loaded_state.params, 1e-3, 63 | 1e-3) 64 | assert_allclose(serial_state.params, parallel_loaded_state.params, 1e-3, 65 | 1e-3) 66 | 67 | 68 | def suite(): 69 | suite = unittest.TestSuite() 70 | suite.addTest(SaveLoadTest('test_mlp_state_load')) 71 | return suite 72 | 73 | 74 | if __name__ == "__main__": 75 | runner = unittest.TextTestRunner() 76 | runner.run(suite()) 77 | -------------------------------------------------------------------------------- /tests/runtime/test_tracing.py: -------------------------------------------------------------------------------- 1 | """Test activity tracing.""" 2 | import unittest 3 | 4 | from alpa import (init, shutdown, parallelize, global_config, PipeshardParallel) 5 | from alpa.global_env import global_config 6 | from alpa.device_mesh import get_global_cluster 7 | from alpa.test_install import get_mlp_train_state_and_step 8 | 9 | 10 | class TracingTest(unittest.TestCase): 11 | 12 | def setUp(self): 13 | global_config.collect_trace = True 14 | init() 15 | 16 
| def tearDown(self): 17 | shutdown() 18 | 19 | def test_trace_pipeshard_execuable(self): 20 | state, batch, train_step = get_mlp_train_state_and_step( 21 | batch_size=128, hidden_size=128, add_manual_pipeline_marker=True) 22 | 23 | layer_num = min(get_global_cluster().num_devices, 2) 24 | train_step = parallelize( 25 | train_step, 26 | method=PipeshardParallel(num_micro_batches=2, 27 | layer_option="manual")) 28 | 29 | for i in range(2): 30 | state, _ = train_step(state, batch) 31 | 32 | executable = train_step.get_last_executable() 33 | stage_exec_info = executable.get_stage_execution_info() 34 | 35 | assert len(stage_exec_info) == 6 # 6 stages 36 | assert len(stage_exec_info[0]) == 4 # 4 invocations 37 | 38 | 39 | def suite(): 40 | suite = unittest.TestSuite() 41 | suite.addTest(TracingTest("test_trace_pipeshard_execuable")) 42 | return suite 43 | 44 | 45 | if __name__ == "__main__": 46 | runner = unittest.TextTestRunner() 47 | runner.run(suite()) 48 | -------------------------------------------------------------------------------- /tests/runtime/test_xla_nccl.py: -------------------------------------------------------------------------------- 1 | """Test cross-mesh resharding.""" 2 | import unittest 3 | 4 | import numpy as np 5 | import ray 6 | 7 | from alpa import init 8 | from alpa.device_mesh import get_global_virtual_physical_mesh, next_array_uuids 9 | from alpa.global_env import global_config 10 | 11 | 12 | class XLANCCLTest(unittest.TestCase): 13 | 14 | def setUp(self): 15 | init(cluster="ray") 16 | 17 | @unittest.skip("manually calling allgather is deprecated") 18 | def test_xla_nccl_allgather(self): 19 | backup_nccl_mode = global_config.nccl_mode 20 | global_config.nccl_mode = "xla_extension" 21 | 22 | mesh_shape = (1, 4) 23 | size = (4, 4) 24 | virtual_mesh = get_global_virtual_physical_mesh() 25 | mesh = virtual_mesh.slice_2d(range(mesh_shape[0]), 26 | [range(mesh_shape[1])] * 27 | mesh_shape[0]).get_physical_mesh() 28 | worker = mesh.workers[0] 29 | device_ids = np.arange(mesh.num_devices_per_host) 30 | 31 | # Put buffers 32 | ary_uuid = next_array_uuids(1)[0] 33 | shard_len = size[0] // mesh.num_devices_per_host 34 | shards = [] 35 | for i in range(mesh.num_devices_per_host): 36 | data = np.zeros(size, dtype=int) 37 | data[i * shard_len:(i + 1) * shard_len, :] = i 38 | shards.append(data) 39 | ray.get(worker.put_buffers.remote(ary_uuid, shards, 1, 0)) 40 | 41 | # Put allgather task 42 | output_slice = [slice(0, size[0], None), slice(0, size[1], None)] 43 | tensor_slices = [] 44 | for i in range(mesh.num_devices_per_host): 45 | tensor_slices.append([ 46 | slice(i * shard_len, (i + 1) * shard_len, None), 47 | slice(0, size[1], None) 48 | ]) 49 | ray.get( 50 | worker.put_resharding_allgather_task.remote( 51 | 0, (ReshardingAllGatherSpec(device_ids, tensor_slices, 52 | output_slice),))) 53 | 54 | # Run allgather task 55 | ray.get(worker.run_allgather_task.remote(0, ary_uuid)) 56 | refs = ray.get(worker.get_buffers.remote(ary_uuid)) 57 | for i in range(4): 58 | for j in range(4): 59 | assert refs[i][j * shard_len, 0] == j 60 | 61 | global_config.nccl_mode = backup_nccl_mode 62 | 63 | 64 | def suite(): 65 | suite = unittest.TestSuite() 66 | suite.addTest(XLANCCLTest("test_xla_nccl_allgather")) 67 | return suite 68 | 69 | 70 | if __name__ == '__main__': 71 | runner = unittest.TextTestRunner() 72 | runner.run(suite()) 73 | -------------------------------------------------------------------------------- /tests/shard_parallel/test_numerical_correctness.py: 
-------------------------------------------------------------------------------- 1 | """Test the numerical correctness of shard parallel.""" 2 | import unittest 3 | 4 | from flax import linen as nn 5 | import jax 6 | import jax.numpy as jnp 7 | import optax 8 | import ray 9 | 10 | import alpa 11 | from alpa import parallelize, LocalPhysicalDeviceMesh 12 | from alpa.model.bert_model import BertConfig, FlaxBertLayer, TrainState 13 | from alpa.testing import (assert_allclose, create_train_state, 14 | get_bert_layer_train_state_and_step) 15 | 16 | 17 | class AutoShardingCorrectnessTest(unittest.TestCase): 18 | 19 | def test_2_layer_bert_shard_parallel(self): 20 | physical_mesh = LocalPhysicalDeviceMesh(jax.local_devices()[:4]) 21 | logical_mesh = physical_mesh.get_logical_mesh([2, 2]) 22 | 23 | # Init model 24 | state, batch, train_step = get_bert_layer_train_state_and_step( 25 | batch_size=16, 26 | seq_len=8, 27 | num_layers=2, 28 | hidden_size=256, 29 | num_heads=8, 30 | clip_by_global_norm=False, 31 | use_dynamic_scale=False, 32 | add_manual_pipeline_marker=False) 33 | 34 | # Train one step 35 | p_train_step = parallelize(train_step) 36 | expected_state, expected_grads = train_step(state, batch) 37 | actual_state, actual_grads = p_train_step(state, batch) 38 | 39 | #print(expected_state) 40 | #print(actual_state) 41 | 42 | # print("group 1:") 43 | # print("expected param example: ", jax.tree_util.tree_flatten(expected_params.params)[0][0][0:10]) 44 | # print("actual param example: ", jax.tree_util.tree_flatten(actual_params.params)[0][0]._value[0:10]) 45 | # print("expected grad example: ", jax.tree_util.tree_flatten(expected_grads)[0][0][0:10]) 46 | # print("actual grad example: ", jax.tree_util.tree_flatten(actual_grads)[0][0]._value[0:10]) 47 | 48 | # print("group 2:") 49 | # print("expected param example: ", jax.tree_util.tree_flatten(expected_params.params)[0][-1][0:100]) 50 | # print("actual param example: ", jax.tree_util.tree_flatten(actual_params.params)[0][-1]._value[0:100]) 51 | # print("expected grad example: ", jax.tree_util.tree_flatten(expected_grads)[0][-1][0:100]) 52 | # print("actual grad example: ", jax.tree_util.tree_flatten(actual_grads)[0][-1]._value[0:100]) 53 | 54 | assert_allclose(expected_state, actual_state, rtol=5e-4, atol=5e-4) 55 | 56 | 57 | def suite(): 58 | suite = unittest.TestSuite() 59 | suite.addTest( 60 | AutoShardingCorrectnessTest("test_2_layer_bert_shard_parallel")) 61 | return suite 62 | 63 | 64 | if __name__ == "__main__": 65 | runner = unittest.TextTestRunner() 66 | runner.run(suite()) 67 | -------------------------------------------------------------------------------- /tests/torch_frontend/test_dict_input.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import alpa.torch.optim as torchoptim 5 | import alpa 6 | from alpa.torch.trainer import train_torch_module 7 | 8 | 9 | class MyModule(torch.nn.Module): 10 | 11 | def __init__(self): 12 | super().__init__() 13 | self.linear1 = torch.nn.Linear(16, 16) 14 | self.linear2 = torch.nn.Linear(16, 16) 15 | self.linear3 = torch.nn.Linear(16, 16) 16 | self.linear4 = torch.nn.Linear(16, 16) 17 | 18 | def forward(self, input_dict): 19 | x = input_dict["x"] 20 | y = input_dict["dict2"]["y"] 21 | x = self.linear1(x) + y 22 | # do some debugging when in local mode 23 | if getattr(torch, "local_mode", True): 24 | print(x) 25 | x = self.linear2(x) 26 | x = self.linear3(x) 27 | x = self.linear4(x) 28 | return x 29 | 30 | 31 | def 
weight_init_func(pt_module, name_map, params, bufs): 32 | for k, m in pt_module.named_modules(): 33 | if isinstance(m, torch.nn.Linear): 34 | params[name_map[f"{k}.weight"]] = torch.nn.init.xavier_uniform( 35 | params[name_map[f"{k}.weight"]]) 36 | params[name_map[f"{k}.bias"]] = torch.nn.init.normal( 37 | params[name_map[f"{k}.bias"]], std=1e-6) 38 | return params, bufs 39 | 40 | 41 | class TorchDictInputTest(unittest.TestCase): 42 | 43 | def setUp(self): 44 | torch.manual_seed(123) 45 | alpa.set_seed(123) 46 | 47 | def test_dict_input(self): 48 | pt_module_gen = lambda: MyModule() 49 | 50 | dataloader = [ 51 | ({ 52 | "x": torch.randn(8, 16), 53 | "dict2": { 54 | "y": torch.randn(8, 16) 55 | } 56 | }, torch.randn(8, 16)), 57 | ({ 58 | "x": torch.randn(8, 16), 59 | "dict2": { 60 | "y": torch.randn(8, 16) 61 | } 62 | }, torch.randn(8, 16)), 63 | ] 64 | loss_func = lambda *args, **kwargs: torch.nn.functional.mse_loss( 65 | *args, **kwargs) 66 | optim_gen = torchoptim.adam(lr=1e-3) 67 | parallel_method = alpa.ShardParallel() 68 | 69 | train_torch_module(pt_module_gen, weight_init_func, dataloader, 70 | loss_func, optim_gen, parallel_method) 71 | 72 | 73 | def suite(): 74 | suite = unittest.TestSuite() 75 | suite.addTest(TorchDictInputTest("test_dict_input")) 76 | return suite 77 | 78 | 79 | if __name__ == '__main__': 80 | runner = unittest.TextTestRunner() 81 | runner.run(suite()) 82 | -------------------------------------------------------------------------------- /tests/torch_frontend/test_reshape.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import alpa.torch.optim as torchoptim 5 | import alpa 6 | from alpa.torch.trainer import train_torch_module 7 | 8 | 9 | class MyModule(torch.nn.Module): 10 | 11 | def __init__(self): 12 | super().__init__() 13 | self.linear1 = torch.nn.Linear(16, 16) 14 | self.linear2 = torch.nn.Linear(16, 16) 15 | 16 | def forward(self, x): 17 | x = self.linear1(x) 18 | x = self.linear2(x) 19 | x = x.reshape(x.shape[0], 2, -1) 20 | x = x.reshape(x.shape[0], -1, 2) 21 | x = x.reshape(x.shape[0], 16) 22 | return x 23 | 24 | 25 | def weight_init_func(pt_module, name_map, params, bufs): 26 | # for k, m in pt_module.named_modules(): 27 | # if isinstance(m, torch.nn.Linear): 28 | # params[name_map[f"{k}.weight"]] = torch.nn.init.xavier_uniform(params[name_map[f"{k}.weight"]]) 29 | # params[name_map[f"{k}.bias"]] = torch.nn.init.normal(params[name_map[f"{k}.bias"]], std=1e-6) 30 | return params, bufs 31 | 32 | 33 | class TorchReshapeTest(unittest.TestCase): 34 | 35 | def setUp(self): 36 | torch.manual_seed(123) 37 | alpa.set_seed(123) 38 | 39 | def test_reshape(self): 40 | B = 64 41 | 42 | pt_module_gen = lambda: MyModule() 43 | 44 | dataloader = [ 45 | (torch.randn(B, 16), torch.randn(B, 16)), 46 | (torch.randn(B, 16), torch.randn(B, 16)), 47 | ] 48 | loss_func = lambda *args, **kwargs: torch.nn.functional.mse_loss( 49 | *args, **kwargs) 50 | optim_gen = torchoptim.adam(lr=1e-3) 51 | parallel_method = alpa.ShardParallel() 52 | 53 | train_torch_module(pt_module_gen, weight_init_func, dataloader, 54 | loss_func, optim_gen, parallel_method) 55 | 56 | 57 | def suite(): 58 | suite = unittest.TestSuite() 59 | suite.addTest(TorchReshapeTest("test_reshape")) 60 | return suite 61 | 62 | 63 | if __name__ == '__main__': 64 | runner = unittest.TextTestRunner() 65 | runner.run(suite()) 66 | -------------------------------------------------------------------------------- 
/tests/tpu/test_create_state_parallel.py: -------------------------------------------------------------------------------- 1 | """Test CreateStateParallel on TPU.""" 2 | import unittest 3 | 4 | from alpa import global_config 5 | 6 | import tests.runtime.test_create_state as test_create_state 7 | from tests.tpu.test_shard_parallel import has_tpu 8 | 9 | 10 | class TpuCreateStateTest(test_create_state.CreateStateTest): 11 | 12 | def setUp(self): 13 | global_config.backend = "tpu" 14 | 15 | def tearDown(self): 16 | return 17 | 18 | @unittest.skip("unsupported yet.") 19 | def test_shard_parallel_grad_acc(self): 20 | super().test_shard_parallel_grad_acc() 21 | 22 | @unittest.skip("unsupported yet.") 23 | def test_pipeshard_parallel(self): 24 | super().test_pipeshard_parallel() 25 | 26 | 27 | def suite(): 28 | suite = unittest.TestSuite() 29 | if not has_tpu(): 30 | return suite 31 | 32 | suite.addTest(TpuCreateStateTest("test_shard_parallel")) 33 | return suite 34 | 35 | 36 | if __name__ == "__main__": 37 | runner = unittest.TextTestRunner() 38 | runner.run(suite()) -------------------------------------------------------------------------------- /tests/tpu/test_follow_parallel.py: -------------------------------------------------------------------------------- 1 | """Test FollowParallel on TPU.""" 2 | import unittest 3 | 4 | from alpa import global_config 5 | 6 | import tests.runtime.test_follow_parallel as test_follow_parallel 7 | from tests.tpu.test_shard_parallel import has_tpu 8 | 9 | 10 | class TpuFollowParallelTest(test_follow_parallel.FollowParallelTest): 11 | 12 | def setUp(self): 13 | global_config.backend = "tpu" 14 | 15 | def tearDown(self): 16 | return 17 | 18 | @unittest.skip("unsupported yet.") 19 | def test_shard_parallel_grad_acc(self): 20 | super().test_shard_parallel_grad_acc() 21 | 22 | @unittest.skip("unsupported yet.") 23 | def test_pipeshard_parallel(self): 24 | super().test_pipeshard_parallel() 25 | 26 | 27 | def suite(): 28 | suite = unittest.TestSuite() 29 | if not has_tpu(): 30 | return suite 31 | 32 | suite.addTest(TpuFollowParallelTest("test_shard_parallel")) 33 | return suite 34 | 35 | 36 | if __name__ == "__main__": 37 | runner = unittest.TextTestRunner() 38 | runner.run(suite()) --------------------------------------------------------------------------------