├── .azure-pipelines └── ut.yml ├── .clang-format ├── .codecov.yml ├── .dockerignore ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── codeql.yml │ ├── lint.yml │ ├── ut-cuda.yml │ └── ut-rocm.yml ├── .gitignore ├── .gitmodules ├── .vscode ├── c_cpp_properties.json ├── launch.json └── settings.json ├── CITATION.cff ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── ark ├── CMakeLists.txt ├── api │ ├── context.cpp │ ├── context_test.cpp │ ├── data_type.cpp │ ├── dims.cpp │ ├── dims_test.cpp │ ├── error_test.cpp │ ├── executor.cpp │ ├── executor_test.cpp │ ├── init.cpp │ ├── init_test.cpp │ ├── log.cpp │ ├── model.cpp │ ├── model_graph.cpp │ ├── model_test.cpp │ ├── planner.cpp │ ├── planner_test.cpp │ ├── random.cpp │ ├── tensor.cpp │ ├── version.cpp │ └── version_test.cpp ├── arch.cpp ├── arch.hpp ├── arch_test.cpp ├── bfloat16.cpp ├── bfloat16.h ├── bfloat16_test.cpp ├── codegen.cpp ├── codegen.hpp ├── context_impl.cpp ├── context_impl.hpp ├── cpu_timer.cpp ├── cpu_timer.h ├── env.cpp ├── env.h ├── file_io.cpp ├── file_io.h ├── file_io_test.cpp ├── gpu │ ├── gpu.hpp │ ├── gpu_compile.cpp │ ├── gpu_compile.hpp │ ├── gpu_event.cpp │ ├── gpu_event.hpp │ ├── gpu_kernel.cpp │ ├── gpu_kernel.hpp │ ├── gpu_kernel_test.cpp │ ├── gpu_logging.hpp │ ├── gpu_manager.cpp │ ├── gpu_manager.hpp │ ├── gpu_memory.cpp │ ├── gpu_memory.hpp │ ├── gpu_stream.cpp │ └── gpu_stream.hpp ├── half.cpp ├── half.h ├── half_test.cpp ├── include │ ├── ark.hpp │ ├── ark │ │ ├── context.hpp │ │ ├── data_type.hpp │ │ ├── dims.hpp │ │ ├── error.hpp │ │ ├── executor.hpp │ │ ├── init.hpp │ │ ├── log.hpp │ │ ├── model.hpp │ │ ├── model_graph.hpp │ │ ├── model_ref.hpp │ │ ├── planner.hpp │ │ ├── random.hpp │ │ ├── tensor.hpp │ │ └── version.hpp │ └── kernels │ │ ├── arithmetic.h │ │ ├── ark_kernels.h │ │ ├── cast.h │ │ ├── comm.h │ │ ├── common │ │ ├── arch.h │ │ ├── atomic.h │ │ ├── bf16.h │ │ ├── 
broadcast.h │ │ ├── checker.h │ │ ├── device.h │ │ ├── ewise.h │ │ ├── fp16.h │ │ ├── fp32.h │ │ ├── integer.h │ │ ├── load_store.h │ │ ├── shfl.h │ │ ├── smem.h │ │ ├── static_math.h │ │ ├── sync.h │ │ ├── type_intrinsics.h │ │ ├── unit_op.h │ │ ├── vec.h │ │ └── vector_type.h │ │ ├── copy.h │ │ ├── embedding.h │ │ ├── gemm_ck.h │ │ ├── gemm_cutlass.h │ │ ├── im2col.h │ │ ├── kernel_template.in │ │ ├── layernorm.h │ │ ├── math_functions.h │ │ ├── matmul.h │ │ ├── noop.h │ │ ├── reduce.h │ │ ├── scalar.h │ │ └── transpose.h ├── logging.cpp ├── logging.hpp ├── model │ ├── model_buffer.cpp │ ├── model_buffer.hpp │ ├── model_context_manager.cpp │ ├── model_context_manager.hpp │ ├── model_context_manager_test.cpp │ ├── model_data_type.cpp │ ├── model_data_type.hpp │ ├── model_graph_impl.cpp │ ├── model_graph_impl.hpp │ ├── model_json.cpp │ ├── model_json.hpp │ ├── model_named_type.hpp │ ├── model_node.hpp │ ├── model_offset.cpp │ ├── model_offset.hpp │ ├── model_op.cpp │ ├── model_op.hpp │ ├── model_op_arg.cpp │ ├── model_op_arg.hpp │ ├── model_op_arg_test.cpp │ ├── model_tensor.cpp │ └── model_tensor.hpp ├── ops │ ├── ops_all_reduce.cpp │ ├── ops_all_reduce_test.cpp │ ├── ops_arithmetic.cpp │ ├── ops_arithmetic.hpp │ ├── ops_arithmetic_test.cpp │ ├── ops_broadcast.cpp │ ├── ops_broadcast.hpp │ ├── ops_cast.cpp │ ├── ops_cast.hpp │ ├── ops_cast_test.cpp │ ├── ops_common.cpp │ ├── ops_common.hpp │ ├── ops_communication.cpp │ ├── ops_communication.hpp │ ├── ops_communication_test.cpp │ ├── ops_copy.cpp │ ├── ops_copy.hpp │ ├── ops_copy_test.cpp │ ├── ops_embedding.cpp │ ├── ops_embedding.hpp │ ├── ops_embedding_test.cpp │ ├── ops_identity.cpp │ ├── ops_identity.hpp │ ├── ops_identity_test.cpp │ ├── ops_math.cpp │ ├── ops_math.hpp │ ├── ops_math_test.cpp │ ├── ops_matmul.cpp │ ├── ops_matmul.hpp │ ├── ops_matmul_test.cpp │ ├── ops_noop.cpp │ ├── ops_noop.hpp │ ├── ops_reduce.cpp │ ├── ops_reduce.hpp │ ├── ops_reduce_test.cpp │ ├── ops_refer.cpp │ ├── ops_refer.hpp │ ├── 
ops_reshape.cpp │ ├── ops_reshape.hpp │ ├── ops_reshape_test.cpp │ ├── ops_rope.cpp │ ├── ops_rope.hpp │ ├── ops_rope_test.cpp │ ├── ops_scalar.cpp │ ├── ops_scalar.hpp │ ├── ops_scalar_test.cpp │ ├── ops_sharding.cpp │ ├── ops_sharding_test.cpp │ ├── ops_tensor.cpp │ ├── ops_tensor.hpp │ ├── ops_tensor_test.cpp │ ├── ops_test_common.cpp │ ├── ops_test_common.hpp │ ├── ops_transpose.cpp │ ├── ops_transpose.hpp │ └── ops_transpose_test.cpp ├── range.cpp ├── range.hpp ├── range_test.cpp ├── unique_list.hpp ├── unique_list_test.cpp ├── unittest │ ├── unittest_utils.cpp │ └── unittest_utils.h └── utils │ ├── utils_math.cpp │ ├── utils_math.hpp │ ├── utils_math_test.cpp │ ├── utils_net.cpp │ ├── utils_net.hpp │ ├── utils_net_test.cpp │ ├── utils_string.cpp │ ├── utils_string.hpp │ └── utils_string_test.cpp ├── cmake ├── CheckAmdGpu.cmake ├── CheckNvidiaGpu.cmake ├── FindIBVerbs.cmake ├── FindNUMA.cmake ├── Utils.cmake ├── check_amd_gpu.hip └── check_nvidia_gpu.cu ├── docker ├── base-dev-x.dockerfile ├── base-rocm5.6.dockerfile ├── base-x.dockerfile ├── build-x.dockerfile └── build.sh ├── docs ├── doxygen │ ├── .gitignore │ └── Doxyfile ├── env.md ├── imgs │ ├── GPU-driven_System_Architecture.svg │ └── logos.svg ├── install.md ├── model_file.md ├── plan_file.md ├── quickstart.md ├── sphinx │ ├── .gitignore │ ├── Makefile │ ├── requirements.txt │ └── source │ │ ├── api.rst │ │ ├── conf.py │ │ └── index.rst └── tutorial │ ├── module_tutorial.md │ └── multi_gpu_tutorial.md ├── examples ├── ffn │ ├── Makefile │ └── ffn.cc ├── llama │ ├── README.md │ ├── generator.py │ ├── generator_torch.py │ ├── model.py │ ├── model_test.py │ └── requirements.txt ├── tensor_parallel │ └── parallel_matmul.py ├── transformer │ ├── megatron_ark.py │ ├── megatron_test.py │ ├── transformer_ark.py │ ├── transformer_pytorch.py │ ├── transformer_test.py │ └── transformer_utils.py └── tutorial │ ├── allreduce-packet │ ├── plan_gpu0.json │ ├── plan_gpu1.json │ ├── plan_gpu2.json │ ├── plan_gpu3.json 
│ ├── plan_gpu4.json │ ├── plan_gpu5.json │ ├── plan_gpu6.json │ └── plan_gpu7.json │ ├── allreduce-sm │ ├── plan_gpu0.json │ ├── plan_gpu1.json │ ├── plan_gpu2.json │ ├── plan_gpu3.json │ ├── plan_gpu4.json │ ├── plan_gpu5.json │ ├── plan_gpu6.json │ └── plan_gpu7.json │ ├── default_plan.json │ ├── model.json │ ├── module_tutorial.py │ ├── multi_gpu_plan.py │ ├── multi_gpu_tutorial.py │ ├── plan.json │ ├── plan_1_larger_tile.json │ ├── plan_2_split_k.json │ ├── plan_3_overwrite.json │ ├── plan_tutorial.py │ ├── planner_tutorial.py │ └── quickstart_tutorial.py ├── pyproject.toml ├── python ├── CMakeLists.txt ├── ark │ ├── __init__.py │ ├── data_type.py │ ├── error.py │ ├── init.py │ ├── log.py │ ├── model.py │ ├── module.py │ ├── ops.py │ ├── planner.py │ ├── runtime.py │ ├── serialize.py │ └── tensor.py ├── ark_py.cpp ├── data_type_py.cpp ├── dims_py.cpp ├── error_py.cpp ├── executor_py.cpp ├── init_py.cpp ├── log_py.cpp ├── model_graph_py.cpp ├── model_py.cpp ├── planner_py.cpp ├── random_py.cpp ├── tensor_py.cpp ├── unittest │ ├── common.py │ ├── test.py │ ├── test_data_type.py │ ├── test_error.py │ ├── test_model.py │ ├── test_ops.py │ └── test_runtime.py └── version_py.cpp ├── requirements.txt └── third_party ├── CMakeLists.txt └── patches ├── composable_kernel.patch └── cutlass.patch /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: Google 2 | ColumnLimit: 80 3 | IndentWidth: 4 4 | TabWidth: 4 5 | -------------------------------------------------------------------------------- /.codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: yes 3 | 4 | flag_management: 5 | default_rules: 6 | carryforward: true 7 | paths: 8 | - ark/ 9 | - python/ark/ 10 | 11 | coverage: 12 | status: 13 | project: 14 | default: 15 | target: 85% 16 | threshold: 1% 17 | 18 | ignore: 19 | - "/usr/*" 20 | - "/tmp/*" 21 | - "*/build/*" 22 | - 
"*/dist-packages/*" 23 | - "*/third_party/*" 24 | - "*/ark/*_test.*" 25 | - "*/examples/*" 26 | - "*/python/unittest/*" 27 | - "*/ark/unittest/*" 28 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | #ARK 2 | build/ 3 | 4 | # Python 5 | **/__pycache__ 6 | *.pyc 7 | *.pyo 8 | *.pyd 9 | .pytest_cache/ 10 | 11 | # Git 12 | **/.git 13 | !/.git 14 | **/.gitmodules 15 | **/.dockerignore 16 | .github/ 17 | .azure-pipelines/ 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us fix 4 | title: "[Bug]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. ... 16 | 2. ... 17 | 3. ... 18 | 4. ... 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **System (please complete the following information):** 24 | - Commit Hash of ARK 25 | - OS: [e.g. Ubuntu20.04] 26 | - GPU [e.g. V100, A100] 27 | - Networking Environment [e.g. Single-GPU, Single-node, Multi-node, NVLink, InfiniBand, RoCE] 28 | 29 | **Additional context** 30 | Add any other context about the problem here. 31 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[Feature]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? 
Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | linters: 10 | runs-on: ubuntu-20.04 11 | 12 | steps: 13 | - name: Check out Git repository 14 | uses: actions/checkout@v4 15 | 16 | - name: Install ClangFormat 17 | run: sudo apt-get install -y clang-format 18 | 19 | - name: Run git-clang-format 20 | run: git clang-format --style=file --diff 21 | 22 | - name: Set up Python 23 | uses: actions/setup-python@v4 24 | with: 25 | python-version: 3.8 26 | 27 | - name: Install Python dependencies 28 | run: python3.8 -m pip install black 29 | 30 | - name: Run black 31 | run: python3.8 -m black --check --config pyproject.toml . 32 | 33 | spelling: 34 | runs-on: ubuntu-20.04 35 | 36 | steps: 37 | - name: Check out Git repository 38 | uses: actions/checkout@v4 39 | 40 | - name: Download misspell 41 | run: | 42 | curl -L https://github.com/client9/misspell/releases/download/v0.3.4/misspell_0.3.4_linux_64bit.tar.gz -o /tmp/misspell_0.3.4_linux_64bit.tar.gz 43 | tar -xzf /tmp/misspell_0.3.4_linux_64bit.tar.gz -C . 
44 | 45 | - name: Check spelling 46 | run: | 47 | ./misspell -error .github ark examples python scripts 48 | -------------------------------------------------------------------------------- /.github/workflows/ut-rocm.yml: -------------------------------------------------------------------------------- 1 | name: "Unit Tests (ROCm)" 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | UnitTest: 13 | runs-on: [ self-hosted, AMD ] 14 | defaults: 15 | run: 16 | shell: bash 17 | strategy: 18 | matrix: 19 | rocm: [ rocm6.0 ] 20 | concurrency: 21 | group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.rocm }} 22 | cancel-in-progress: true 23 | # container: 24 | # image: "ghcr.io/microsoft/ark/ark:base-dev-${{ matrix.rocm }}" 25 | # options: --privileged --ipc=host --security-opt seccomp=unconfined --group-add video --ulimit memlock=-1:-1 26 | 27 | steps: 28 | - name: Checkout 29 | uses: actions/checkout@v4 30 | 31 | - name: Dubious ownership exception 32 | run: | 33 | git config --global --add safe.directory /__w/ark/ark 34 | 35 | - name: Build 36 | run: | 37 | mkdir build && cd build 38 | cmake -DCMAKE_BUILD_TYPE=Debug .. 39 | make -j ut 40 | 41 | - name: RunUT 42 | run: | 43 | cd build && ARK_ROOT=$PWD ARK_IGNORE_BINARY_CACHE=1 ctest --stop-on-failure --verbose --schedule-random 44 | 45 | - name: ReportCoverage 46 | run: | 47 | cd build 48 | lcov --capture --directory . --output-file coverage.info 49 | lcov --remove coverage.info \ 50 | '/usr/*' \ 51 | '/tmp/*' \ 52 | '*/third_party/*' \ 53 | '*/ark/*_test.*' \ 54 | '*/examples/*' \ 55 | '*/python/*' \ 56 | '*/ark/unittest/unittest_utils.cc' \ 57 | --output-file coverage.info 58 | lcov --list coverage.info 59 | bash <(curl -s https://codecov.io/bash) -f coverage.info || echo "Codecov did not collect coverage reports" 60 | 61 | - name: BuildPython 62 | run: | 63 | python3 -m pip install -r requirements.txt 64 | python3 -m pip install . 
65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Build temp files 2 | build/ 3 | **/build/ 4 | 5 | # VS Code 6 | .settings/ 7 | 8 | # Python 9 | _pycache__/ 10 | **/__pycache__/ 11 | *.pyc 12 | **/*.pyc 13 | *.pyo 14 | **/*.pyo 15 | 16 | # PyTest 17 | .pytest_cache/ 18 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/cutlass"] 2 | path = third_party/cutlass 3 | url = https://github.com/NVIDIA/cutlass 4 | 5 | [submodule "examples/llama/llama"] 6 | path = examples/llama/llama 7 | url = https://github.com/facebookresearch/llama 8 | 9 | [submodule "third_party/composable_kernel"] 10 | path = third_party/composable_kernel 11 | url = https://github.com/ROCmSoftwarePlatform/composable_kernel 12 | 13 | [submodule "third_party/mscclpp"] 14 | path = third_party/mscclpp 15 | url = https://github.com/microsoft/mscclpp 16 | 17 | [submodule "third_party/json"] 18 | path = third_party/json 19 | url = https://github.com/nlohmann/json 20 | -------------------------------------------------------------------------------- /.vscode/c_cpp_properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Linux", 5 | "includePath": [ 6 | "${workspaceFolder}/**", 7 | "/usr/local/cuda/include", 8 | "/opt/rocm/include" 9 | ], 10 | "cppStandard": "c++17" 11 | } 12 | ], 13 | "version": 4 14 | } 15 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "ops_cast_test", 5 | "type": "cppdbg", 6 | "request": "launch", 7 | "program": 
"${workspaceFolder}/build/ark/ops_cast_test", 8 | "args": [], 9 | "stopAtEntry": false, 10 | "cwd": "${fileDirname}", 11 | "environment": [ 12 | { 13 | "name": "ARK_ROOT", 14 | "value": "${workspaceFolder}/build" 15 | }, 16 | { 17 | "name": "ARK_LOG_LEVEL", 18 | "value": "DEBUG" 19 | } 20 | ], 21 | "externalConsole": false, 22 | "MIMode": "gdb", 23 | "setupCommands": [ 24 | { 25 | "description": "Enable pretty-printing for gdb", 26 | "text": "-enable-pretty-printing", 27 | "ignoreFailures": true 28 | }, 29 | { 30 | "description": "Set Disassembly Flavor to Intel", 31 | "text": "-gdb-set disassembly-flavor intel", 32 | "ignoreFailures": true 33 | } 34 | ] 35 | }, 36 | { 37 | "name": "Python: generator", 38 | "type": "python", 39 | "request": "launch", 40 | "program": "${workspaceFolder}/examples/llama/generator.py", 41 | "console": "integratedTerminal", 42 | "args": [ 43 | "--ckpt_dir=/mnt/llama/llama-2-13b-chat", 44 | "--tok_path=/mnt/llama/tokenizer.model", 45 | "--params_path=/mnt/llama/llama-2-13b-chat/params.json", 46 | "--ngpu=2" 47 | ], 48 | "env": { 49 | "MSCCLPP_DEBUG": "WARN", 50 | "MSCCLPP_DEBUG_SUBSYS": "ALL", 51 | "ARK_LOG_LEVEL": "INFO" 52 | }, 53 | "justMyCode": false, 54 | }, 55 | { 56 | "name": "Python: model tester", 57 | "type": "python", 58 | "request": "launch", 59 | "program": "${workspaceFolder}/examples/llama/model_test.py", 60 | "console": "integratedTerminal", 61 | "args": [ 62 | "--ckpt_dir=/mnt/llama/llama-2-13b-chat", 63 | "--ngpu=2" 64 | ], 65 | "env": { 66 | "MSCCLPP_DEBUG": "WARN", 67 | "MSCCLPP_DEBUG_SUBSYS": "ALL", 68 | "ARK_LOG_LEVEL": "INFO", 69 | "ARK_DISABLE_GRAPH_OPT": "1" 70 | }, 71 | "justMyCode": false, 72 | } 73 | ] 74 | } 75 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "cmake.buildDirectory": "${workspaceFolder}/build", 3 | "cmake.environment": { 4 | "ARK_ROOT": 
"${workspaceFolder}/build", 5 | "ARK_IGNORE_BINARY_CACHE": "1", 6 | // "ARK_LOG_LEVEL": "DEBUG" 7 | }, 8 | "cmake.ctestArgs": [ 9 | "--verbose" 10 | ], 11 | } 12 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: "ARK: A GPU-driven system framework for scalable AI applications" 3 | version: 0.5.0 4 | message: >- 5 | If you use this project in your research, please cite it as below. 6 | authors: 7 | - given-names: Changho 8 | family-names: Hwang 9 | affiliation: Microsoft Research 10 | 11 | repository-code: 'https://github.com/microsoft/ark' 12 | abstract: >- 13 | ARK is a deep learning framework especially designed for highly optimized 14 | performance over distributed GPUs. Specifically, ARK adopts a GPU-driven 15 | execution model, where the GPU autonomously schedules and executes both 16 | computation and communication without any CPU intervention. 17 | ARK provides a set of APIs for users to express their distributed deep 18 | learning applications. ARK then automatically schedules a GPU-driven 19 | execution plan for the application, which generates a GPU kernel code 20 | called loop kernel. The loop kernel is a GPU kernel that contains a loop 21 | that iteratively executes the entire application, including both 22 | computation and communication. ARK then executes the loop kernel on the 23 | distributed GPUs. 
24 | license: MIT 25 | license-url: https://github.com/microsoft/ark/blob/main/LICENSE 26 | 27 | preferred-citation: 28 | type: conference-paper 29 | title: "ARK: GPU-driven Code Execution for Distributed Deep Learning" 30 | authors: 31 | - given-names: Changho 32 | family-names: Hwang 33 | affiliation: Microsoft Research, KAIST 34 | - given-names: KyoungSoo 35 | family-names: Park 36 | affiliation: KAIST 37 | - given-names: Ran 38 | family-names: Shu 39 | affiliation: Microsoft Research 40 | - given-names: Xinyuan 41 | family-names: Qu 42 | affiliation: Microsoft Research 43 | - given-names: Peng 44 | family-names: Cheng 45 | affiliation: Microsoft Research 46 | - given-names: Yongqiang 47 | family-names: Xiong 48 | affiliation: Microsoft Research 49 | conference: 50 | name: 20th USENIX Symposium on Networked Systems Design and Implementation (NSDI '23) 51 | city: Boston 52 | region: MA 53 | country: US 54 | month: 4 55 | year: 2023 56 | url: https://www.usenix.org/conference/nsdi23/presentation/hwang 57 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 7 | feature request as a new Issue. 8 | 9 | For help and questions about using this project, please file them as new Issues. 10 | 11 | ## Microsoft Support Policy 12 | 13 | Support for this project is limited to the resources listed above. 
14 | -------------------------------------------------------------------------------- /ark/api/context.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "context_impl.hpp" 5 | #include "logging.hpp" 6 | 7 | namespace ark { 8 | 9 | Context::Context(Model& model) : impl_(std::make_shared(model)) {} 10 | 11 | int Context::id() const { return this->impl_->id_; } 12 | 13 | std::string Context::get(const std::string& key) const { 14 | if (!this->impl_->has(key)) { 15 | return ""; 16 | } 17 | return this->impl_->get(key).dump(); 18 | } 19 | 20 | void Context::set(const std::string& key, const std::string& value, 21 | ContextType type) { 22 | Json value_json; 23 | try { 24 | value_json = Json::parse(value); 25 | } catch (const ::nlohmann::json::parse_error& e) { 26 | ERR(InvalidUsageError, "Failed to parse context value as JSON: `", 27 | value, "`"); 28 | } 29 | this->impl_->set(key, value_json, type); 30 | } 31 | 32 | } // namespace ark 33 | -------------------------------------------------------------------------------- /ark/api/context_test.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include "ark/context.hpp" 5 | 6 | #include "model/model_node.hpp" 7 | #include "unittest/unittest_utils.h" 8 | 9 | ark::unittest::State test_context() { 10 | ark::Model model; 11 | ark::Tensor t0 = model.tensor({1}, ark::FP32); 12 | ark::Tensor t1 = model.tensor({1}, ark::FP32); 13 | 14 | // node 0 15 | ark::Tensor t2 = model.add(t0, t1); 16 | 17 | ark::Tensor t3; 18 | ark::Tensor t4; 19 | ark::Tensor t5; 20 | { 21 | // node 1 22 | ark::Context ctx(model); 23 | ctx.set("key0", ark::Json("val1").dump()); 24 | t3 = model.relu(t2); 25 | 26 | UNITTEST_EQ(ctx.get("key0"), ark::Json("val1").dump()); 27 | 28 | // node 2 29 | ctx.set("key1", ark::Json("val2").dump()); 30 | t4 = model.sqrt(t3); 31 | 32 | UNITTEST_EQ(ctx.get("key0"), ark::Json("val1").dump()); 33 | UNITTEST_EQ(ctx.get("key1"), ark::Json("val2").dump()); 34 | } 35 | { 36 | // node 3 37 | ark::Context ctx(model); 38 | ctx.set("key0", ark::Json("val3").dump()); 39 | t5 = model.exp(t2); 40 | 41 | UNITTEST_EQ(ctx.get("key0"), ark::Json("val3").dump()); 42 | UNITTEST_EQ(ctx.get("key1"), ""); 43 | } 44 | 45 | UNITTEST_TRUE(model.verify()); 46 | 47 | auto compressed = model.compress(); 48 | UNITTEST_TRUE(compressed.verify()); 49 | 50 | auto nodes = compressed.nodes(); 51 | UNITTEST_EQ(nodes.size(), 4); 52 | 53 | UNITTEST_EQ(nodes[0]->context.size(), 0); 54 | UNITTEST_EQ(nodes[1]->context.size(), 1); 55 | UNITTEST_EQ(nodes[1]->context.at("key0"), ark::Json("val1")); 56 | UNITTEST_EQ(nodes[2]->context.size(), 2); 57 | UNITTEST_EQ(nodes[2]->context.at("key0"), ark::Json("val1")); 58 | UNITTEST_EQ(nodes[2]->context.at("key1"), ark::Json("val2")); 59 | UNITTEST_EQ(nodes[3]->context.size(), 1); 60 | UNITTEST_EQ(nodes[3]->context.at("key0"), ark::Json("val3")); 61 | 62 | return ark::unittest::SUCCESS; 63 | } 64 | 65 | ark::unittest::State test_context_invalid() { 66 | ark::Model model; 67 | ark::Context ctx(model); 68 | ark::Tensor t0 = model.tensor({1}, ark::FP32); 69 | ark::Tensor t1 = model.tensor({1}, 
ark::FP32); 70 | ark::Tensor t2 = model.add(t0, t1); 71 | 72 | UNITTEST_THROW(ctx.set("key", "val"), ark::InvalidUsageError); 73 | 74 | return ark::unittest::SUCCESS; 75 | } 76 | 77 | int main() { 78 | UNITTEST(test_context); 79 | UNITTEST(test_context_invalid); 80 | return 0; 81 | } 82 | -------------------------------------------------------------------------------- /ark/api/data_type.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "ark/data_type.hpp" 5 | 6 | #include 7 | 8 | #include "bfloat16.h" 9 | #include "half.h" 10 | #include "logging.hpp" 11 | #include "model/model_data_type.hpp" 12 | 13 | namespace ark { 14 | 15 | /// 16 | /// NOTE: how to add a new data type 17 | /// 1. Add an instance using `DATA_TYPE_INSTANCE()` macro. 18 | /// 2. Add a registration using `DATA_TYPE_REGISTER()` macro. 19 | /// 3. Expose the symbol in `include/ark/data_type.hpp`. 
20 | /// 21 | 22 | #define DATA_TYPE_INSTANCE(_name, _type) \ 23 | extern const DataType _name( \ 24 | std::make_shared(#_name, #_type, sizeof(_type))); 25 | 26 | #define DATA_TYPE_REGISTER(_name) instances[#_name] = &_name; 27 | 28 | extern const DataType NONE(std::make_shared("NONE", "void", 0)); 29 | DATA_TYPE_INSTANCE(FP32, float); 30 | DATA_TYPE_INSTANCE(FP16, fp16); 31 | DATA_TYPE_INSTANCE(BF16, bf16); 32 | DATA_TYPE_INSTANCE(INT32, int32_t); 33 | DATA_TYPE_INSTANCE(UINT32, uint32_t); 34 | DATA_TYPE_INSTANCE(INT8, int8_t); 35 | DATA_TYPE_INSTANCE(UINT8, uint8_t); 36 | DATA_TYPE_INSTANCE(BYTE, char); 37 | 38 | const DataType &DataType::from_name(const std::string &type_name) { 39 | static std::map instances; 40 | if (instances.empty()) { 41 | DATA_TYPE_REGISTER(NONE); 42 | DATA_TYPE_REGISTER(FP32); 43 | DATA_TYPE_REGISTER(FP16); 44 | DATA_TYPE_REGISTER(BF16); 45 | DATA_TYPE_REGISTER(INT32); 46 | DATA_TYPE_REGISTER(UINT32); 47 | DATA_TYPE_REGISTER(INT8); 48 | DATA_TYPE_REGISTER(UINT8); 49 | DATA_TYPE_REGISTER(BYTE); 50 | } 51 | auto it = instances.find(type_name); 52 | if (it == instances.end()) { 53 | ERR(UnsupportedError, "Unknown data type: ", type_name); 54 | } 55 | return *(it->second); 56 | } 57 | 58 | size_t DataType::bytes() const { return ref_->bytes(); } 59 | 60 | const std::string &DataType::name() const { return ref_->type_name(); } 61 | 62 | } // namespace ark 63 | -------------------------------------------------------------------------------- /ark/api/error_test.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include "ark/error.hpp" 5 | 6 | #include "unittest/unittest_utils.h" 7 | 8 | ark::unittest::State test_error() { 9 | UNITTEST_THROW(throw ark::ModelError("test"), ark::ModelError); 10 | 11 | try { 12 | throw ark::ModelError("test"); 13 | } catch (const ark::ModelError &e) { 14 | UNITTEST_EQ(std::string(e.what()), "test"); 15 | } 16 | 17 | try { 18 | throw ark::ModelError("test"); 19 | } catch (const ark::BaseError &e) { 20 | UNITTEST_EQ(std::string(e.what()), "test"); 21 | } 22 | 23 | return ark::unittest::SUCCESS; 24 | } 25 | 26 | int main() { 27 | UNITTEST(test_error); 28 | return 0; 29 | } 30 | -------------------------------------------------------------------------------- /ark/api/init.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "ark/init.hpp" 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "env.h" 11 | #include "file_io.h" 12 | #include "logging.hpp" 13 | 14 | namespace ark { 15 | 16 | void init() { 17 | LOG(DEBUG, "init ark"); 18 | 19 | // Get the environment variables. 20 | (void)get_env(true); 21 | 22 | // Create the temporary directory if it does not exist. 23 | const std::string &tmp_dir = get_env().path_tmp_dir; 24 | if (!is_exist(tmp_dir)) { 25 | if (create_dir(tmp_dir) != 0) { 26 | ERR(SystemError, 27 | "init failed: failed to create temporary directory ", tmp_dir, 28 | " (errno ", errno, ")"); 29 | } 30 | } else if (!get_env().keep_tmp) { 31 | // Clear the temporary directory if it exists and keep_tmp is false. 
32 | if (clear_dir(tmp_dir) != 0) { 33 | ERR(SystemError, 34 | "init failed: failed to clear temporary directory ", tmp_dir, 35 | " (errno ", errno, ")"); 36 | } 37 | } 38 | } 39 | 40 | } // namespace ark 41 | -------------------------------------------------------------------------------- /ark/api/init_test.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "ark/init.hpp" 5 | 6 | #include "file_io.h" 7 | #include "unittest/unittest_utils.h" 8 | 9 | ark::unittest::State test_init() { 10 | // invalid tmp directory 11 | ::setenv("ARK_TMP", "", 1); 12 | UNITTEST_THROW(ark::init(), ark::SystemError); 13 | 14 | // create a tmp directory 15 | ::setenv("ARK_TMP", "/tmp/ark/.test_init", 1); 16 | ::setenv("ARK_KEEP_TMP", "1", 1); 17 | ark::init(); 18 | 19 | // create a tmp file 20 | ark::write_file("/tmp/ark/.test_init/test", "test"); 21 | 22 | // clear the tmp directory 23 | ::setenv("ARK_KEEP_TMP", "0", 1); 24 | ark::init(); 25 | UNITTEST_TRUE(!ark::is_exist("/tmp/ark/.test_init/test")); 26 | 27 | // given tmp directory is not a directory 28 | ::setenv("ARK_TMP", "/dev/null", 1); 29 | UNITTEST_THROW(ark::init(), ark::SystemError); 30 | 31 | return ark::unittest::SUCCESS; 32 | } 33 | 34 | int main() { 35 | UNITTEST(test_init); 36 | return 0; 37 | } 38 | -------------------------------------------------------------------------------- /ark/api/log.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include "ark/log.hpp" 5 | 6 | #include "logging.hpp" 7 | 8 | namespace ark { 9 | 10 | void log(LogLevel level, const std::string &file, int line, 11 | const std::string &msg) { 12 | _log(level, file, line, msg); 13 | } 14 | 15 | } // namespace ark 16 | -------------------------------------------------------------------------------- /ark/api/model.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "ark/model.hpp" 5 | 6 | #include 7 | 8 | #include "logging.hpp" 9 | 10 | namespace ark { 11 | 12 | Model::Model(int rank, int world_size) : ModelGraph(rank, world_size) { 13 | static size_t next_id = 0; 14 | id_ = next_id++; 15 | } 16 | 17 | Model::Model(const Model &other) : ModelGraph(other), id_(other.id()) {} 18 | 19 | size_t Model::id() const { return id_; } 20 | 21 | Model Model::compress() const { 22 | Model model(*this); 23 | model.compress_nodes(); 24 | return model; 25 | } 26 | 27 | int Model::unique_tag() { 28 | size_t num_ints = size_t(std::numeric_limits::max()) * 2 + 2; 29 | if (tags_.size() == num_ints) { 30 | ERR(ModelError, "no more unique tags"); 31 | } 32 | int next_val; 33 | if (tags_.empty()) { 34 | next_val = std::numeric_limits::min(); 35 | } else if (*tags_.rbegin() < std::numeric_limits::max()) { 36 | next_val = *tags_.rbegin() + 1; 37 | } else { 38 | next_val = std::numeric_limits::min(); 39 | for (int tag : tags_) { 40 | if (tag == next_val) { 41 | next_val++; 42 | } else { 43 | break; 44 | } 45 | } 46 | } 47 | tags_.insert(next_val); 48 | return next_val; 49 | } 50 | 51 | } // namespace ark 52 | -------------------------------------------------------------------------------- /ark/api/model_graph.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include "ark/model_graph.hpp" 5 | 6 | #include "logging.hpp" 7 | #include "model/model_graph_impl.hpp" 8 | #include "model/model_node.hpp" 9 | 10 | namespace ark { 11 | 12 | ModelGraph::ModelGraph(int rank, int world_size) 13 | : impl_(std::make_unique(rank, world_size)) {} 14 | 15 | ModelGraph::ModelGraph(const ModelGraph &other) 16 | : impl_(std::make_unique(*other.impl_)) {} 17 | 18 | ModelGraph::~ModelGraph() = default; 19 | 20 | ModelGraph &ModelGraph::operator=(const ModelGraph &other) { 21 | *impl_ = *other.impl_; 22 | return *this; 23 | } 24 | 25 | /// Get the list of @ref ModelNode in the graph. 26 | std::vector ModelGraph::nodes() const { return impl_->nodes(); } 27 | 28 | std::string ModelGraph::serialize(bool pretty) const { 29 | return impl_->serialize(pretty); 30 | } 31 | 32 | int ModelGraph::rank() const { return impl_->rank(); } 33 | 34 | int ModelGraph::world_size() const { return impl_->world_size(); } 35 | 36 | void ModelGraph::compress_nodes() { impl_->compress_nodes(); } 37 | 38 | bool ModelGraph::compressed() const { return impl_->compressed(); } 39 | 40 | bool ModelGraph::verify() const { return impl_->verify(); } 41 | 42 | } // namespace ark 43 | -------------------------------------------------------------------------------- /ark/api/random.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "ark/random.hpp" 5 | 6 | #include 7 | 8 | namespace ark { 9 | 10 | // Initialize the random number generator. 11 | void srand(int seed) { ::srand(seed); } 12 | 13 | // Generate a random integer. 14 | int rand() { return ::rand(); } 15 | 16 | } // namespace ark 17 | -------------------------------------------------------------------------------- /ark/api/tensor.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 
2 | // Licensed under the MIT license. 3 | 4 | #include "ark/tensor.hpp" 5 | 6 | #include "model/model_data_type.hpp" 7 | #include "model/model_tensor.hpp" 8 | 9 | namespace ark { 10 | 11 | size_t Tensor::id() const { 12 | if (ref_) { 13 | return ref_->id(); 14 | } 15 | return 0; 16 | } 17 | 18 | Dims Tensor::shape() const { 19 | if (ref_) { 20 | return ref_->shape(); 21 | } 22 | return Dims(); 23 | } 24 | 25 | Dims Tensor::strides() const { 26 | if (ref_) { 27 | return ref_->strides(); 28 | } 29 | return Dims(); 30 | } 31 | 32 | Dims Tensor::offsets() const { 33 | if (ref_) { 34 | return ref_->offsets(); 35 | } 36 | return Dims(); 37 | } 38 | 39 | Dims Tensor::padded_shape() const { 40 | if (ref_) { 41 | return ref_->padded_shape(); 42 | } 43 | return Dims(); 44 | } 45 | 46 | const DataType &Tensor::data_type() const { 47 | if (ref_) { 48 | return DataType::from_name(ref_->data_type()->type_name()); 49 | } 50 | return NONE; 51 | } 52 | 53 | std::ostream &operator<<(std::ostream &os, const Tensor &tensor) { 54 | if (tensor.is_null()) { 55 | os << "null"; 56 | } else { 57 | os << tensor.ref()->serialize().dump(); 58 | } 59 | return os; 60 | } 61 | 62 | } // namespace ark 63 | -------------------------------------------------------------------------------- /ark/api/version.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "ark/version.hpp" 5 | 6 | #include 7 | #include 8 | 9 | namespace ark { 10 | 11 | std::string version() { 12 | std::stringstream ss; 13 | ss << ARK_MAJOR << "." << ARK_MINOR << "." << ARK_PATCH; 14 | return ss.str(); 15 | } 16 | 17 | } // namespace ark 18 | -------------------------------------------------------------------------------- /ark/api/version_test.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 
2 | // Licensed under the MIT license. 3 | 4 | #include "ark/version.hpp" 5 | 6 | #include "unittest/unittest_utils.h" 7 | 8 | ark::unittest::State test_version() { 9 | auto version = ark::version(); 10 | 11 | // Check if the version string is in the correct format. 12 | auto dot1 = version.find('.'); 13 | auto dot2 = version.find('.', dot1 + 1); 14 | UNITTEST_NE(dot1, std::string::npos); 15 | UNITTEST_NE(dot2, std::string::npos); 16 | 17 | return ark::unittest::SUCCESS; 18 | } 19 | 20 | int main() { 21 | UNITTEST(test_version); 22 | return 0; 23 | } 24 | -------------------------------------------------------------------------------- /ark/arch.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_ARCH_HPP_ 5 | #define ARK_ARCH_HPP_ 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | namespace ark { 12 | 13 | class Arch; 14 | using ArchRef = std::shared_ptr; 15 | 16 | class Arch { 17 | private: 18 | std::vector category_; 19 | std::string name_; 20 | 21 | public: 22 | Arch(){}; 23 | 24 | Arch(const std::string &c0); 25 | 26 | Arch(const std::string &c0, const std::string &c1); 27 | 28 | Arch(const std::string &c0, const std::string &c1, const std::string &c2); 29 | 30 | Arch(const Arch &other) = default; 31 | 32 | Arch &operator=(const Arch &other); 33 | 34 | const std::string &name() const { return name_; } 35 | 36 | bool operator==(const Arch &other) const; 37 | 38 | bool operator!=(const Arch &other) const { return !(*this == other); } 39 | 40 | const std::vector &category() const { return category_; } 41 | 42 | bool belongs_to(const ArchRef arch) const; 43 | 44 | bool later_than(const ArchRef arch) const; 45 | 46 | static const ArchRef from_name(const std::string &name); 47 | }; 48 | 49 | extern const ArchRef ARCH_ANY; 50 | 51 | extern const ArchRef ARCH_CUDA; 52 | extern const ArchRef ARCH_CUDA_70; 53 | extern const ArchRef 
ARCH_CUDA_80; 54 | extern const ArchRef ARCH_CUDA_90; 55 | 56 | extern const ArchRef ARCH_ROCM; 57 | extern const ArchRef ARCH_ROCM_90A; 58 | extern const ArchRef ARCH_ROCM_942; 59 | 60 | } // namespace ark 61 | 62 | #endif // ARK_ARCH_HPP_ 63 | -------------------------------------------------------------------------------- /ark/arch_test.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "arch.hpp" 5 | 6 | #include "unittest/unittest_utils.h" 7 | 8 | ark::unittest::State test_arch() { 9 | UNITTEST_TRUE(ark::Arch("CUDA") == *ark::ARCH_CUDA); 10 | UNITTEST_TRUE(ark::Arch("CUDA", "80") == *ark::ARCH_CUDA_80); 11 | UNITTEST_TRUE(ark::Arch("CUDA", "80", "XYZ").name() == "CUDA_80_XYZ"); 12 | 13 | UNITTEST_TRUE(ark::ARCH_CUDA->belongs_to(ark::ARCH_ANY)); 14 | UNITTEST_TRUE(ark::ARCH_ROCM->belongs_to(ark::ARCH_ANY)); 15 | UNITTEST_TRUE(ark::ARCH_CUDA->belongs_to(ark::ARCH_CUDA)); 16 | UNITTEST_TRUE(ark::ARCH_CUDA_80->belongs_to(ark::ARCH_CUDA)); 17 | UNITTEST_FALSE(ark::ARCH_ROCM->belongs_to(ark::ARCH_CUDA)); 18 | 19 | UNITTEST_TRUE(ark::ARCH_CUDA_80->belongs_to(ark::ARCH_CUDA_80)); 20 | UNITTEST_TRUE(ark::ARCH_CUDA_90->later_than(ark::ARCH_CUDA_80)); 21 | UNITTEST_FALSE(ark::ARCH_CUDA_80->later_than(ark::ARCH_CUDA_90)); 22 | UNITTEST_FALSE(ark::ARCH_CUDA_80->later_than(ark::ARCH_CUDA_80)); 23 | return ark::unittest::SUCCESS; 24 | } 25 | 26 | int main() { 27 | UNITTEST(test_arch); 28 | return 0; 29 | } 30 | -------------------------------------------------------------------------------- /ark/codegen.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_CODEGEN_HPP_ 5 | #define ARK_CODEGEN_HPP_ 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include "model/model_json.hpp" 12 | 13 | namespace ark { 14 | 15 | class CodeGenerator { 16 | public: 17 | CodeGenerator(const PlanJson &plan, 18 | const std::map &buffer_id_to_offset, 19 | const std::string &name = "ark_kernel"); 20 | 21 | ~CodeGenerator() = default; 22 | 23 | std::string code() const; 24 | 25 | size_t num_procs() const; 26 | 27 | size_t num_warps_per_proc() const; 28 | 29 | private: 30 | class Impl; 31 | std::shared_ptr impl_; 32 | }; 33 | 34 | } // namespace ark 35 | 36 | #endif // ARK_CODEGEN_HPP_ 37 | -------------------------------------------------------------------------------- /ark/context_impl.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "context_impl.hpp" 5 | 6 | #include "logging.hpp" 7 | #include "model/model_context_manager.hpp" 8 | #include "model/model_graph_impl.hpp" 9 | 10 | namespace ark { 11 | 12 | Context::Impl::Impl(Model& model) 13 | : context_manager_(std::make_shared(model)) { 14 | static int next_id = 0; 15 | id_ = next_id++; 16 | } 17 | 18 | Json Context::Impl::get(const std::string& key) const { 19 | return context_manager_->get(key); 20 | } 21 | 22 | void Context::Impl::set(const std::string& key, const Json& value_json, 23 | ContextType type) { 24 | if (type == ContextType::Overwrite) { 25 | context_manager_->set(key, value_json); 26 | } else if (type == ContextType::Extend) { 27 | auto ctx = context_manager_->get(key); 28 | if (ctx.empty()) { 29 | context_manager_->set(key, value_json); 30 | } else if (!ctx.is_object() || !value_json.is_object()) { 31 | ERR(InvalidUsageError, 32 | "Context value must be a JSON object when type is " 33 | "ContextTypeExtend. 
Key: ", 34 | key, ", old value: ", ctx.dump(), 35 | ", new value: ", value_json.dump()); 36 | } else { 37 | for (const auto& [k, v] : value_json.items()) { 38 | ctx[k] = v; 39 | } 40 | context_manager_->set(key, ctx); 41 | } 42 | } else if (type == ContextType::Immutable) { 43 | if (!context_manager_->has(key)) { 44 | context_manager_->set(key, value_json); 45 | } 46 | } else { 47 | ERR(InvalidUsageError, "Unknown context type"); 48 | } 49 | } 50 | 51 | bool Context::Impl::has(const std::string& key) const { 52 | return context_manager_->has(key); 53 | } 54 | 55 | } // namespace ark 56 | -------------------------------------------------------------------------------- /ark/context_impl.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_CONTEXT_IMPL_HPP_ 5 | #define ARK_CONTEXT_IMPL_HPP_ 6 | 7 | #include "ark/context.hpp" 8 | #include "model/model_json.hpp" 9 | 10 | namespace ark { 11 | 12 | class ModelContextManager; 13 | 14 | class Context::Impl { 15 | public: 16 | Impl(Model& model); 17 | 18 | Json get(const std::string& key) const; 19 | 20 | void set(const std::string& key, const Json& value_json, ContextType type); 21 | 22 | bool has(const std::string& key) const; 23 | 24 | protected: 25 | friend class Context; 26 | 27 | std::shared_ptr context_manager_; 28 | int id_; 29 | }; 30 | 31 | } // namespace ark 32 | 33 | #endif // ARK_CONTEXT_IMPL_HPP_ 34 | -------------------------------------------------------------------------------- /ark/cpu_timer.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "cpu_timer.h" 5 | 6 | #include 7 | 8 | namespace ark { 9 | 10 | // Measure current time in second. 
11 | double cpu_timer(void) { 12 | struct timespec tspec; 13 | if (clock_gettime(CLOCK_MONOTONIC, &tspec) == -1) { 14 | return -1; 15 | } 16 | return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec; 17 | } 18 | 19 | // Sleep for `sec` seconds. Returns the result of nanosleep(). 20 | int cpu_timer_sleep(double sec) { 21 | struct timespec tspec; 22 | tspec.tv_sec = (time_t)sec; 23 | tspec.tv_nsec = (long)((sec - tspec.tv_sec) * 1.0e9); 24 | return nanosleep(&tspec, 0); 25 | } 26 | 27 | // Sleep for `nsec` nanoseconds. Returns the result of nanosleep(). The duration is split into whole seconds and a nanosecond remainder because nanosleep() rejects tv_nsec >= 1e9 with EINVAL. 28 | int cpu_ntimer_sleep(long nsec) { 29 | struct timespec tspec; 30 | tspec.tv_sec = (time_t)(nsec / 1000000000L); 31 | tspec.tv_nsec = nsec % 1000000000L; 32 | return nanosleep(&tspec, 0); 33 | } 34 | 35 | } // namespace ark 36 | -------------------------------------------------------------------------------- /ark/cpu_timer.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_CPU_TIMER_H_ 5 | #define ARK_CPU_TIMER_H_ 6 | 7 | namespace ark { 8 | 9 | // Return the current monotonic time in seconds, or -1 on failure. 10 | double cpu_timer(void); 11 | // Sleep for `sec` seconds. 12 | int cpu_timer_sleep(double sec); 13 | // Sleep for `nsec` nanoseconds. 14 | int cpu_ntimer_sleep(long nsec); 15 | 16 | } // namespace ark 17 | 18 | #endif // ARK_CPU_TIMER_H_ 19 | -------------------------------------------------------------------------------- /ark/env.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_ENV_H_ 5 | #define ARK_ENV_H_ 6 | 7 | #include 8 | 9 | namespace ark { 10 | 11 | // Environment variables. 12 | struct Env { 13 | Env(); 14 | // Log level. 
15 | std::string log_level; 16 | // Root directory where ARK is installed. 17 | std::string path_root_dir; 18 | // Temporary directory. 19 | std::string path_tmp_dir; 20 | // If true, we do not remove temporary files in `path_tmp_dir`. 21 | bool keep_tmp; 22 | // Hostfile. 
23 | std::string hostfile; 24 | // Number of ranks per host. 25 | int num_ranks_per_host; 26 | // Disable IB 27 | bool disable_ib; 28 | // Ignore compiled binary cache. 29 | bool ignore_binary_cache; 30 | // Enforce to compile a specific plan file. 31 | std::string enforce_plan_path; 32 | // MSCCL++ bootstrap port. 33 | int mscclpp_port; 34 | }; 35 | 36 | // Get the global Env. 37 | const Env &get_env(bool reset = false); 38 | 39 | } // namespace ark 40 | 41 | #endif // ARK_ENV_H_ 42 | -------------------------------------------------------------------------------- /ark/file_io.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "file_io.h" 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "logging.hpp" 11 | 12 | namespace fs = std::filesystem; 13 | 14 | namespace ark { 15 | 16 | bool is_exist(const std::string &path) { 17 | return fs::directory_entry{path}.exists(); 18 | } 19 | 20 | bool is_dir(const std::string &path) { 21 | return fs::is_directory(fs::status(path)); 22 | } 23 | 24 | bool is_file(const std::string &path) { 25 | return fs::is_regular_file(fs::status(path)); 26 | } 27 | 28 | int create_dir(const std::string &path) { 29 | std::error_code ec; 30 | fs::create_directories(path, ec); 31 | return ec.value(); 32 | } 33 | 34 | int remove_dir(const std::string &path) { 35 | LOG(DEBUG, "remove dir: ", path); 36 | std::error_code ec; 37 | fs::remove_all(path, ec); 38 | return ec.value(); 39 | } 40 | 41 | // Remove all files in a directory. 
42 | int clear_dir(const std::string &path) { 43 | LOG(DEBUG, "clear dir: ", path); 44 | std::error_code ec; 45 | for (const auto &entry : fs::directory_iterator(path, ec)) { 46 | if (ec.value() != 0) { 47 | return ec.value(); 48 | } 49 | fs::remove_all(entry.path(), ec); 50 | if (ec.value() != 0) { 51 | return ec.value(); 52 | } 53 | } 54 | return ec.value(); 55 | } 56 | 57 | std::vector list_dir(const std::string &path) { 58 | std::vector files; 59 | for (const auto &entry : fs::directory_iterator(path)) { 60 | files.push_back(entry.path().string()); 61 | } 62 | return files; 63 | } 64 | 65 | std::string read_file(const std::string &path) { 66 | std::ifstream file(path); 67 | std::stringstream ss; 68 | ss << file.rdbuf(); 69 | return ss.str(); 70 | } 71 | 72 | void write_file(const std::string &path, const std::string &data) { 73 | std::ofstream file(path, std::ios::out | std::ios::trunc); 74 | file << data; 75 | } 76 | 77 | int remove_file(const std::string &path) { 78 | LOG(DEBUG, "remove file: ", path); 79 | std::error_code ec; 80 | fs::remove(path, ec); 81 | return ec.value(); 82 | } 83 | 84 | std::string get_dir(const std::string &path) { 85 | return fs::path(path).parent_path().string(); 86 | } 87 | 88 | } // namespace ark 89 | -------------------------------------------------------------------------------- /ark/file_io.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_FILE_IO_H_ 5 | #define ARK_FILE_IO_H_ 6 | 7 | #include 8 | #include 9 | 10 | namespace ark { 11 | 12 | bool is_exist(const std::string &path); 13 | 14 | bool is_dir(const std::string &path); 15 | bool is_file(const std::string &path); 16 | int create_dir(const std::string &path); 17 | int remove_dir(const std::string &path); 18 | int clear_dir(const std::string &path); 19 | std::vector list_dir(const std::string &path); 20 | 21 | std::string read_file(const std::string &path); 22 | void write_file(const std::string &path, const std::string &data); 23 | int remove_file(const std::string &path); 24 | std::string get_dir(const std::string &path); 25 | 26 | } // namespace ark 27 | 28 | #endif // ARK_FILE_IO_H_ 29 | -------------------------------------------------------------------------------- /ark/gpu/gpu_compile.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_GPU_COMPILE_HPP_ 5 | #define ARK_GPU_COMPILE_HPP_ 6 | 7 | #include 8 | #include 9 | 10 | #include "arch.hpp" 11 | 12 | namespace ark { 13 | 14 | const std::string gpu_compile(const std::vector &codes, 15 | const ArchRef arch, unsigned int max_reg_cnt); 16 | 17 | } // namespace ark 18 | 19 | #endif // ARK_GPU_COMPILE_HPP_ 20 | -------------------------------------------------------------------------------- /ark/gpu/gpu_event.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include "gpu/gpu_event.hpp" 5 | 6 | #include "gpu/gpu_logging.hpp" 7 | #include "gpu/gpu_manager.hpp" 8 | 9 | namespace ark { 10 | class GpuEvent::Impl { 11 | public: 12 | Impl(bool disable_timing); 13 | ~Impl(); 14 | Impl(const Impl&) = delete; 15 | Impl& operator=(const Impl&) = delete; 16 | 17 | void record(gpuStream stream); 18 | float elapsed_msec(const GpuEvent& other) const; 19 | 20 | private: 21 | gpuEvent event_; 22 | }; 23 | 24 | GpuEvent::Impl::Impl(bool disable_timing) { 25 | unsigned int flags = 0; 26 | if (disable_timing) { 27 | flags |= gpuEventDisableTiming; 28 | } 29 | GLOG(gpuEventCreateWithFlags(&event_, flags)); 30 | } 31 | 32 | GpuEvent::Impl::~Impl() { GLOG(gpuEventDestroy(event_)); } 33 | 34 | void GpuEvent::Impl::record(gpuStream stream) { 35 | GLOG(gpuEventRecord(event_, stream)); 36 | } 37 | 38 | float GpuEvent::Impl::elapsed_msec(const GpuEvent& other) const { 39 | float elapsed; 40 | GLOG(gpuEventElapsedTime(&elapsed, other.pimpl_->event_, event_)); 41 | return elapsed; 42 | } 43 | 44 | GpuEvent::GpuEvent(bool disable_timing) 45 | : pimpl_(std::make_shared(disable_timing)) {} 46 | 47 | void GpuEvent::record(gpuStream stream) { pimpl_->record(stream); } 48 | 49 | float GpuEvent::elapsed_msec(const GpuEvent& other) const { 50 | return pimpl_->elapsed_msec(other); 51 | } 52 | 53 | } // namespace ark 54 | -------------------------------------------------------------------------------- /ark/gpu/gpu_event.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_GPU_EVENT_HPP_ 5 | #define ARK_GPU_EVENT_HPP_ 6 | 7 | #include 8 | 9 | #include "gpu/gpu.hpp" 10 | 11 | namespace ark { 12 | 13 | class GpuStream; 14 | class GpuManager; 15 | 16 | class GpuEvent { 17 | public: 18 | ~GpuEvent() = default; 19 | GpuEvent(const GpuEvent &) = delete; 20 | GpuEvent &operator=(const GpuEvent &) = delete; 21 | 22 | void record(gpuStream stream); 23 | float elapsed_msec(const GpuEvent &other) const; 24 | 25 | protected: 26 | friend class GpuManager; 27 | 28 | GpuEvent(bool disable_timing = false); 29 | 30 | private: 31 | class Impl; 32 | std::shared_ptr pimpl_; 33 | }; 34 | } // namespace ark 35 | 36 | #endif // ARK_GPU_EVENT_HPP_ 37 | -------------------------------------------------------------------------------- /ark/gpu/gpu_kernel.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_GPU_KERNEL_HPP_ 5 | #define ARK_GPU_KERNEL_HPP_ 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include "gpu_stream.hpp" 12 | 13 | namespace ark { 14 | 15 | class GpuManager; 16 | 17 | class GpuKernel { 18 | public: 19 | GpuKernel(int gpu_id, const std::string& codes, 20 | const std::array& block_dim, 21 | const std::array& grid_dim, size_t smem_bytes, 22 | const std::string& kernel_name); 23 | 24 | void init(int gpu_id, const std::string& codes, 25 | const std::array& block_dim, 26 | const std::array& grid_dim, size_t smem_bytes, 27 | const std::string& kernel_name); 28 | void compile(); 29 | void launch(gpuStream stream, std::vector& args); 30 | 31 | gpuDeviceptr get_global(const std::string& name, 32 | bool ignore_not_found = false) const; 33 | bool is_compiled() const { return function_ != nullptr; } 34 | 35 | protected: 36 | std::shared_ptr gpu_manager_; 37 | std::string code_; 38 | std::array block_dim_; 39 | std::array grid_dim_; 40 | int smem_bytes_; 41 | std::string kernel_name_; 42 | 
std::string bin_; 43 | gpuModule module_; 44 | gpuFunction function_ = nullptr; 45 | }; 46 | 47 | } // namespace ark 48 | 49 | #endif // ARK_GPU_KERNEL_HPP_ 50 | -------------------------------------------------------------------------------- /ark/gpu/gpu_kernel_test.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "gpu/gpu_kernel.hpp" 5 | 6 | #include "unittest/unittest_utils.h" 7 | 8 | const std::string void_kernel = "extern \"C\" __global__ void kernel() {}"; 9 | 10 | ark::unittest::State test_gpu_kernel() { 11 | ark::GpuKernel kernel(0, void_kernel, {1, 1, 1}, {1, 1, 1}, 0, "kernel"); 12 | UNITTEST_TRUE(!kernel.is_compiled()); 13 | kernel.compile(); 14 | UNITTEST_TRUE(kernel.is_compiled()); 15 | std::vector args; 16 | for (int i = 0; i < 10; i++) { 17 | kernel.launch(nullptr, args); 18 | } 19 | return ark::unittest::SUCCESS; 20 | } 21 | 22 | int main() { 23 | UNITTEST(test_gpu_kernel); 24 | return 0; 25 | } 26 | -------------------------------------------------------------------------------- /ark/gpu/gpu_logging.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_GPU_LOGGING_HPP_ 5 | #define ARK_GPU_LOGGING_HPP_ 6 | 7 | #include "gpu/gpu.hpp" 8 | #include "logging.hpp" 9 | 10 | #define GLOG(cmd) \ 11 | do { \ 12 | ark::gpuError _e = cmd; \ 13 | if (_e != ark::gpuSuccess) { \ 14 | const char *_estr = ark::gpuGetErrorString(_e); \ 15 | ERR(ark::GpuError, _e, " '", _estr, "'"); \ 16 | } \ 17 | } while (0) 18 | 19 | #define GLOG_DRV(cmd) \ 20 | do { \ 21 | ark::gpuDrvError _e = cmd; \ 22 | if (_e != ark::gpuDrvSuccess) { \ 23 | const char *_estr; \ 24 | if (ark::gpuDrvGetErrorString(_e, &_estr) == ark::gpuDrvSuccess) { \ 25 | ERR(ark::GpuError, _e, " '", _estr, "'"); \ 26 | } else { \ 27 | ERR(ark::GpuError, _e); \ 28 | } \ 29 | } \ 30 | } while (0) 31 | 32 | #endif // ARK_GPU_LOGGING_HPP_ 33 | -------------------------------------------------------------------------------- /ark/gpu/gpu_manager.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_GPU_MANAGER_HPP_ 5 | #define ARK_GPU_MANAGER_HPP_ 6 | 7 | #include 8 | 9 | #include "arch.hpp" 10 | #include "gpu/gpu.hpp" 11 | #include "gpu/gpu_event.hpp" 12 | #include "gpu/gpu_memory.hpp" 13 | #include "gpu/gpu_stream.hpp" 14 | 15 | namespace ark { 16 | 17 | class GpuManager { 18 | public: 19 | static std::shared_ptr get_instance(int gpu_id); 20 | 21 | GpuManager(const GpuManager &) = delete; 22 | ~GpuManager() = default; 23 | GpuManager &operator=(const GpuManager &) = delete; 24 | 25 | void set_current() const; 26 | std::shared_ptr malloc(size_t bytes, size_t align = 1, 27 | bool expose = false); 28 | std::shared_ptr malloc_host(size_t bytes, 29 | unsigned int flags = 0); 30 | std::shared_ptr create_event(bool disable_timing = false) const; 31 | std::shared_ptr create_stream() const; 32 | 33 | void launch(gpuFunction function, const std::array &grid_dim, 34 | const std::array &block_dim, int smem_bytes, 35 | gpuStream stream, void **params, void **extra) const; 36 | 37 | struct Info; 38 | const Info &info() const; 39 | 40 | struct Info { 41 | int cc_major; 42 | int cc_minor; 43 | size_t gmem_total; 44 | int smem_total; 45 | int smem_block_total; 46 | int num_sm; 47 | int clk_rate; 48 | int threads_per_warp; 49 | int max_registers_per_block; 50 | int max_threads_per_block; 51 | int max_registers_per_thread = 256; // TODO: how to get this? 52 | int smem_align = 128; // TODO: how to get this? 53 | ArchRef arch; 54 | }; 55 | 56 | private: 57 | GpuManager(int gpu_id); 58 | 59 | class Impl; 60 | std::shared_ptr pimpl_; 61 | }; 62 | 63 | } // namespace ark 64 | 65 | #endif // ARK_GPU_MANAGER_HPP_ 66 | -------------------------------------------------------------------------------- /ark/gpu/gpu_memory.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_GPU_MEMORY_HPP_ 5 | #define ARK_GPU_MEMORY_HPP_ 6 | 7 | #include 8 | #include 9 | 10 | #include "gpu/gpu.hpp" 11 | 12 | namespace ark { 13 | 14 | class GpuManager; 15 | 16 | class GpuMemory { 17 | public: 18 | ~GpuMemory() = default; 19 | size_t bytes() const; 20 | 21 | template 22 | T* ref(size_t offset = 0) const { 23 | return reinterpret_cast((size_t)this->ref_impl(offset)); 24 | } 25 | 26 | private: 27 | friend class GpuManager; 28 | GpuMemory(const GpuManager& manager, size_t bytes, size_t align, 29 | bool expose = false); 30 | 31 | class Impl; 32 | std::shared_ptr pimpl_; 33 | 34 | void* ref_impl(size_t offset = 0) const; 35 | }; 36 | 37 | class GpuHostMemory { 38 | public: 39 | ~GpuHostMemory(); 40 | GpuHostMemory(const GpuHostMemory&) = delete; 41 | GpuHostMemory& operator=(const GpuHostMemory&) = delete; 42 | 43 | template 44 | T* ref() const { 45 | return reinterpret_cast(ptr_); 46 | } 47 | 48 | private: 49 | friend class GpuManager; 50 | GpuHostMemory(const GpuManager& manager, size_t bytes, unsigned int flags); 51 | 52 | void* ptr_; 53 | }; 54 | 55 | } // namespace ark 56 | 57 | #endif // ARK_GPU_MEMORY_HPP_ 58 | -------------------------------------------------------------------------------- /ark/gpu/gpu_stream.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include "gpu/gpu_stream.hpp" 5 | 6 | #include "gpu/gpu_logging.hpp" 7 | #include "gpu/gpu_manager.hpp" 8 | 9 | namespace ark { 10 | class GpuStream::Impl { 11 | public: 12 | Impl(); 13 | ~Impl(); 14 | Impl(const Impl &) = delete; 15 | Impl &operator=(const Impl &) = delete; 16 | 17 | gpuStream get() const { return gpu_stream_; } 18 | gpuError query() const { return gpuStreamQuery(gpu_stream_); } 19 | void sync() const { GLOG(gpuStreamSynchronize(gpu_stream_)); } 20 | 21 | private: 22 | gpuStream gpu_stream_; 23 | }; 24 | 25 | GpuStream::GpuStream() : pimpl_(std::make_shared()) {} 26 | 27 | void GpuStream::sync() const { pimpl_->sync(); } 28 | 29 | gpuError GpuStream::query() const { return pimpl_->query(); } 30 | 31 | gpuStream GpuStream::get() const { return pimpl_->get(); } 32 | 33 | GpuStream::Impl::Impl() { 34 | GLOG(gpuStreamCreateWithFlags(&gpu_stream_, gpuStreamNonBlocking)); 35 | } 36 | 37 | GpuStream::Impl::~Impl() { GLOG(gpuStreamDestroy(gpu_stream_)); } 38 | 39 | } // namespace ark 40 | -------------------------------------------------------------------------------- /ark/gpu/gpu_stream.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_GPU_STREAM_HPP_ 5 | #define ARK_GPU_STREAM_HPP_ 6 | 7 | #include 8 | 9 | #include "gpu/gpu.hpp" 10 | 11 | namespace ark { 12 | 13 | class GpuManager; 14 | 15 | class GpuStream { 16 | public: 17 | ~GpuStream() = default; 18 | void sync() const; 19 | gpuError query() const; 20 | gpuStream get() const; 21 | 22 | protected: 23 | friend class GpuManager; 24 | 25 | GpuStream(); 26 | 27 | private: 28 | class Impl; 29 | std::shared_ptr pimpl_; 30 | }; 31 | } // namespace ark 32 | 33 | #endif // ARK_GPU_STREAM_HPP_ 34 | -------------------------------------------------------------------------------- /ark/include/ark.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_HPP 5 | #define ARK_HPP 6 | 7 | // clang-format off 8 | #include 9 | // clang-format on 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | #endif // ARK_HPP 25 | -------------------------------------------------------------------------------- /ark/include/ark/data_type.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_DATA_TYPE_HPP 5 | #define ARK_DATA_TYPE_HPP 6 | 7 | #include 8 | #include 9 | 10 | namespace ark { 11 | 12 | class DataType; 13 | /* Predefined data type constants; declared extern here and defined elsewhere. */ 14 | extern const DataType NONE; 15 | extern const DataType FP32; 16 | extern const DataType FP16; 17 | extern const DataType BF16; 18 | extern const DataType INT32; 19 | extern const DataType UINT32; 20 | extern const DataType INT8; 21 | extern const DataType UINT8; 22 | extern const DataType BYTE; 23 | 24 | class ModelDataT; 25 | using ModelDataType = std::shared_ptr; 26 | /* Handle wrapping a ModelDataType reference; == and != compare the wrapped references. */ 27 | class DataType { 28 | protected: 29 | friend class Model; 30 | ModelDataType ref_; 31 | 32 | public: 33 | DataType() = default; 34 | DataType(ModelDataType ref) : ref_(ref) {} 35 | DataType(const DataType &other) = default; 36 | DataType &operator=(const DataType &other) = default; 37 | 38 | bool operator==(const DataType &other) const { return ref_ == other.ref_; } 39 | bool operator!=(const DataType &other) const { return ref_ != other.ref_; } 40 | /* True when no reference is held. */ 41 | bool is_null() const { return !ref_; } 42 | 43 | ModelDataType ref() const { return ref_; } 44 | 45 | size_t bytes() const; 46 | 47 | const std::string &name() const; 48 | 49 | static const DataType &from_name(const std::string &type_name); 50 | }; 51 | 52 | } // namespace ark 53 | 54 | #endif // ARK_DATA_TYPE_HPP 55 | -------------------------------------------------------------------------------- /ark/include/ark/dims.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_DIMS_HPP 5 | #define ARK_DIMS_HPP 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | namespace ark { 12 | 13 | // Data type for dimension. 14 | typedef int64_t DimType; 15 | 16 | // DIMS_LEN is the maximum number of dimensions of a tensor. 17 | constexpr DimType DIMS_LEN = 4; 18 | 19 | // Up-to-`DIMS_LEN`-dimensional vector. 
20 | class Dims { 21 | private: 22 | std::vector data_; 23 | 24 | public: 25 | Dims(); 26 | /* Constructors taking one to four explicit dimensions. */ 27 | Dims(DimType d0); 28 | 29 | Dims(DimType d0, DimType d1); 30 | 31 | Dims(DimType d0, DimType d1, DimType d2); 32 | 33 | Dims(DimType d0, DimType d1, DimType d2, DimType d3); 34 | 35 | // Copy another Dims object. 36 | Dims(const Dims &dims_); 37 | // Construct from a vector. If the vector is shorter than DIMS_LEN, put 38 | // following NO_DIMs. Raise an error if the vector is longer than DIMS_LEN. 39 | Dims(const std::vector &vec); 40 | 41 | // Return the number of elements. If the dimensions are invalid, return 42 | // -1. 43 | DimType nelems() const; 44 | // Return the number of valid dimensions. 45 | int ndims() const; 46 | // Return a new Dims object with 4 valid dimensions by prepending 1s. 47 | Dims dims4() const; 48 | // Return true if all valid dimensions are zero. 49 | bool is_zeros() const; 50 | // Return true if the dimensions are empty. 51 | bool is_no_dim() const; 52 | // Presumably returns true if any dimension is negative; confirm against the definition. 53 | bool has_negative() const; 54 | // Return true if the dimensions are invalid. 55 | bool is_invalid() const; 56 | // Return a vector of valid dimensions. 57 | const std::vector &vector() const; 58 | // Insert a dimension at the given index. 59 | void insert(int idx, DimType dim); 60 | // Erase the dimension at the given index and return the erased dimension. 
61 | DimType erase(int idx); 62 | /* Element access by dimension index, in mutable and const forms. */ 63 | DimType &operator[](int idx); 64 | 65 | const DimType &operator[](int idx) const; 66 | 67 | Dims &operator=(const Dims &) = default; 68 | 69 | friend bool operator==(const Dims &a, const Dims &b); 70 | friend bool operator!=(const Dims &a, const Dims &b); 71 | }; 72 | /* Stream insertion for Dims. */ 73 | std::ostream &operator<<(std::ostream &os, const Dims &dims); 74 | 75 | } // namespace ark 76 | 77 | #endif // ARK_DIMS_HPP 78 | -------------------------------------------------------------------------------- /ark/include/ark/error.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_ERROR_HPP 5 | #define ARK_ERROR_HPP 6 | 7 | #include 8 | #include 9 | 10 | namespace ark { 11 | 12 | /// Base class for all ARK errors; what() returns the message given at construction. 13 | class BaseError : public std::exception { 14 | private: 15 | std::string msg_; 16 | 17 | public: 18 | BaseError(const std::string &msg) : msg_(msg) {} 19 | const char *what() const noexcept override { return msg_.c_str(); } 20 | }; 21 | /* Defines an error class `_name` that derives from BaseError and forwards its message to it. */ 22 | #define REGISTER_ERROR_TYPE(_name) \ 23 | class _name : public BaseError { \ 24 | public: \ 25 | _name(const std::string &msg) : BaseError(msg) {} \ 26 | }; 27 | 28 | /// Internal error in ARK, likely a bug. 29 | REGISTER_ERROR_TYPE(InternalError) 30 | /// Invalid usage of ARK API. 31 | REGISTER_ERROR_TYPE(InvalidUsageError) 32 | /// Invalid ARK model definition or usage. 33 | REGISTER_ERROR_TYPE(ModelError) 34 | /// Invalid ARK plan definition or usage. 35 | REGISTER_ERROR_TYPE(PlanError) 36 | /// Unsupported feature triggered. 37 | REGISTER_ERROR_TYPE(UnsupportedError) 38 | /// Error from invalid system state such as a system call failure. 39 | REGISTER_ERROR_TYPE(SystemError) 40 | /// Error from a CUDA/HIP API call. 41 | REGISTER_ERROR_TYPE(GpuError) 42 | /// Error from a unit test. 
43 | REGISTER_ERROR_TYPE(UnitTestError) 44 | 45 | } // namespace ark 46 | 47 | #endif // ARK_ERROR_HPP 48 | -------------------------------------------------------------------------------- /ark/include/ark/init.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_INIT_HPP 5 | #define ARK_INIT_HPP 6 | 7 | namespace ark { 8 | 9 | /// Initialize the ARK runtime. 10 | /// 11 | /// This function should be called by the user before any other functions are 12 | /// called. It is safe to call this function multiple times. 13 | void init(); 14 | 15 | } // namespace ark 16 | 17 | #endif // ARK_INIT_HPP 18 | -------------------------------------------------------------------------------- /ark/include/ark/log.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_LOG_HPP 5 | #define ARK_LOG_HPP 6 | 7 | #include 8 | 9 | namespace ark { 10 | /* Log levels, declared in the order DEBUG, INFO, WARN, ERROR. */ 11 | typedef enum { DEBUG, INFO, WARN, ERROR } LogLevel; 12 | /* Logging entry point: takes a level, a file name, a line number, and a message. */ 13 | void log(LogLevel level, const std::string &file, int line, 14 | const std::string &msg); 15 | 16 | } // namespace ark 17 | 18 | #endif // ARK_LOG_HPP 19 | -------------------------------------------------------------------------------- /ark/include/ark/model_graph.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_MODEL_GRAPH_HPP 5 | #define ARK_MODEL_GRAPH_HPP 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace ark { 13 | /* Model graph constructed from a rank and a world size; its state is held behind the unique_ptr impl_, and the copy constructor and copy assignment are user-declared. */ 14 | class ModelGraph { 15 | public: 16 | ModelGraph(int rank, int world_size); 17 | 18 | ModelGraph(const ModelGraph &other); 19 | 20 | ~ModelGraph(); 21 | 22 | ModelGraph &operator=(const ModelGraph &other); 23 | 24 | int rank() const; 25 | 26 | int world_size() const; 27 | 28 | void compress_nodes(); 29 | 30 | bool compressed() const; 31 | 32 | bool verify() const; 33 | 34 | std::string serialize(bool pretty = true) const; 35 | 36 | /// Get the list of @ref ModelNode in the graph. 37 | std::vector nodes() const; 38 | 39 | protected: 40 | friend class Model; 41 | friend class ModelContextManager; 42 | friend class Context; 43 | 44 | class Impl; 45 | std::unique_ptr impl_; 46 | }; 47 | 48 | } // namespace ark 49 | 50 | #endif // ARK_MODEL_GRAPH_HPP 51 | -------------------------------------------------------------------------------- /ark/include/ark/model_ref.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_MODEL_REF_HPP 5 | #define ARK_MODEL_REF_HPP 6 | 7 | #include 8 | 9 | namespace ark { 10 | /* Forward declarations paired with shared_ptr aliases (the *Ref names). */ 11 | class ModelOp; 12 | using ModelOpRef = std::shared_ptr; 13 | 14 | class ModelBuffer; 15 | using ModelBufferRef = std::shared_ptr; 16 | 17 | class ModelTensor; 18 | using ModelTensorRef = std::shared_ptr; 19 | 20 | class ModelNode; 21 | using ModelNodeRef = std::shared_ptr; 22 | 23 | } // namespace ark 24 | 25 | #endif // ARK_MODEL_REF_HPP 26 | -------------------------------------------------------------------------------- /ark/include/ark/planner.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_PLANNER_HPP 5 | #define ARK_PLANNER_HPP 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace ark { 13 | 14 | template 15 | class Range; 16 | /* Context subclass with setters for processor, warp and SRAM ranges (start, end, and a step that defaults to 1), plus sync and config. */ 17 | class PlannerContext : public Context { 18 | public: 19 | PlannerContext(Model& model); 20 | 21 | void processor_range(int start, int end, int step = 1); 22 | 23 | void warp_range(int start, int end, int step = 1); 24 | 25 | void sram_range(int start, int end, int step = 1); 26 | 27 | void sync(bool sync); 28 | 29 | void config(const std::string& config); 30 | 31 | private: 32 | void check_range(const std::string& key, const Range& range); 33 | }; 34 | /* Planner built from a model and a device id; plan() returns a string and takes a `pretty` flag that defaults to true. */ 35 | class Planner { 36 | public: 37 | Planner(const Model& model, int device_id); 38 | 39 | ~Planner(); 40 | 41 | using ConfigRule = std::function; 43 | 44 | void install_config_rule(ConfigRule rule); 45 | 46 | std::string plan(bool pretty = true) const; 47 | 48 | private: 49 | class Impl; 50 | std::shared_ptr impl_; 51 | }; 52 | 53 | } // namespace ark 54 | 55 | #endif // ARK_PLANNER_HPP 56 | -------------------------------------------------------------------------------- /ark/include/ark/random.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_RANDOM_HPP 5 | #define ARK_RANDOM_HPP 6 | 7 | #include 8 | 9 | namespace ark { 10 | 11 | // Set the random seed; the seed argument defaults to -1. 12 | void srand(int seed = -1); 13 | 14 | // Get a random number. 15 | int rand(); 16 | 17 | } // namespace ark 18 | 19 | #endif // ARK_RANDOM_HPP 20 | -------------------------------------------------------------------------------- /ark/include/ark/tensor.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_TENSOR_HPP 5 | #define ARK_TENSOR_HPP 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace ark { 13 | 14 | /// 15 | /// Tensor is a view of a memory space. 16 | /// 17 | /// Illustration of a single axis of a tensor: 18 | /// 19 | /// 0 offset stride 20 | /// |------------|-------------shape-------------|---------------------------| 21 | /// <-----------------------------> 22 | /// data range of this tensor 23 | /// 24 | class Tensor { 25 | protected: 26 | friend class Model; 27 | ModelTensorRef ref_; 28 | 29 | public: 30 | Tensor() = default; 31 | Tensor(ModelTensorRef ref) : ref_(ref) {} 32 | Tensor(const Tensor &other) = default; 33 | Tensor &operator=(const Tensor &other) = default; 34 | /* == and != compare the wrapped references. */ 35 | bool operator==(const Tensor &other) const { return ref_ == other.ref_; } 36 | bool operator!=(const Tensor &other) const { return ref_ != other.ref_; } 37 | /* True when no reference is held. */ 38 | bool is_null() const { return !ref_; } 39 | 40 | ModelTensorRef ref() const { return ref_; } 41 | 42 | size_t id() const; 43 | 44 | Dims shape() const; 45 | 46 | Dims strides() const; 47 | 48 | Dims offsets() const; 49 | 50 | Dims padded_shape() const; 51 | 52 | const DataType &data_type() const; 53 | }; 54 | /* A default-constructed Tensor. */ 55 | const Tensor NullTensor; 56 | 57 | std::ostream &operator<<(std::ostream &os, const Tensor &tensor); 58 | 59 | } // namespace ark 60 | 61 | #endif // ARK_TENSOR_HPP 62 | -------------------------------------------------------------------------------- /ark/include/ark/version.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_VERSION_HPP 5 | #define ARK_VERSION_HPP 6 | 7 | #include 8 | /* ARK_VERSION encodes the version as major * 10000 + minor * 100 + patch. */ 9 | #define ARK_MAJOR 0 10 | #define ARK_MINOR 5 11 | #define ARK_PATCH 0 12 | #define ARK_VERSION (ARK_MAJOR * 10000 + ARK_MINOR * 100 + ARK_PATCH) 13 | 14 | namespace ark { 15 | 16 | /// Return a version string. 
17 | std::string version(); 18 | 19 | } // namespace ark 20 | 21 | #endif // ARK_VERSION_HPP 22 | -------------------------------------------------------------------------------- /ark/include/kernels/ark_kernels.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | /* The rest of this header is active only when a CUDA or ROCm target architecture macro is defined. */ 4 | #if defined(ARK_TARGET_CUDA_ARCH) || defined(ARK_TARGET_ROCM_ARCH) 5 | 6 | #ifndef ARK_KERNELS_H_ 7 | #define ARK_KERNELS_H_ 8 | 9 | #include 10 | #include 11 | 12 | #include "arithmetic.h" 13 | #include "cast.h" 14 | #include "comm.h" 15 | #include "copy.h" 16 | #include "embedding.h" 17 | #include "im2col.h" 18 | #include "layernorm.h" 19 | #include "math_functions.h" 20 | #include "matmul.h" 21 | #include "noop.h" 22 | #include "reduce.h" 23 | #include "scalar.h" 24 | #include "transpose.h" 25 | 26 | #endif // ARK_KERNELS_H_ 27 | 28 | #endif // defined(ARK_TARGET_CUDA_ARCH) || defined(ARK_TARGET_ROCM_ARCH) 29 | -------------------------------------------------------------------------------- /ark/include/kernels/cast.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_KERNELS_CAST_H_ 5 | #define ARK_KERNELS_CAST_H_ 6 | 7 | #include "common/broadcast.h" 8 | #include "common/type_intrinsics.h" 9 | #include "common/vector_type.h" 10 | 11 | namespace ark { 12 | /* Cast functor; this header defines only the specialization whose last template argument is 2. */ 13 | template 15 | struct Cast; 16 | /* Specialization with NelemPerThread fixed to 2. */ 17 | template 18 | struct Cast<_InShape, _FromType, _ToType, 2> { 19 | using InputType = _FromType; 20 | using OutputType = _ToType; 21 | static const int NelemPerThread = 2; 22 | /* Writes a single element when _InShape::W == 1; otherwise converts two elements, as one 2-wide vector value when such vector types exist for both element types, else one element at a time. */ 23 | static DEVICE void compute(_ToType *output, const _FromType *input) { 24 | if constexpr (_InShape::W == 1) { 25 | *output = type::Cast::compute<_ToType>(*input); 26 | } else if constexpr (type::VtypeExists<_FromType, 2>::value && 27 | type::VtypeExists<_ToType, 2>::value) { 28 | using ToType2 = typename type::Vtype<_ToType, 2>::type; 29 | using FromType2 = typename type::Vtype<_FromType, 2>::type; 30 | ToType2 *pout = reinterpret_cast(output); 31 | const FromType2 *pin = reinterpret_cast(input); 32 | *pout = type::Cast::compute(*pin); 33 | } else { 34 | output[0] = type::Cast::compute<_ToType>(input[0]); 35 | output[1] = type::Cast::compute<_ToType>(input[1]); 36 | } 37 | } 38 | }; 39 | /* Entry point: runs the cast through Broadcast1 for the unit operator `uop_idx`; the trailing unnamed int parameter is unused. */ 40 | template 43 | DEVICE void cast(ToType *out, FromType *in, int uop_idx, int) { 44 | Broadcast1>::run(out, in, 46 | uop_idx); 47 | } 48 | 49 | } // namespace ark 50 | 51 | #endif // ARK_KERNELS_CAST_H_ 52 | -------------------------------------------------------------------------------- /ark/include/kernels/common/arch.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_KERNELS_ARCH_H_ 5 | #define ARK_KERNELS_ARCH_H_ 6 | 7 | #include "device.h" 8 | #include "static_math.h" 9 | 10 | namespace ark { 11 | /* Per-target constants: ThreadsPerWarp is 32 on CUDA targets and 64 on ROCm targets. */ 12 | struct Arch { 13 | #if defined(ARK_TARGET_CUDA_ARCH) 14 | static const int ThreadsPerWarp = 32; 15 | #elif defined(ARK_TARGET_ROCM_ARCH) 16 | static const int ThreadsPerWarp = 64; 17 | #endif 18 | }; 19 | /* Warp index derived from threadIdx.x by a right shift. */ 20 | DEVICE int warp_id() { 21 | return threadIdx.x >> math::log2_up::value; 22 | } 23 | 24 | } // namespace ark 25 | /* ARCH_ALIAS_TYPE typedefs `alias` to the CUDA or HIP type, depending on the target macro. */ 26 | #if defined(ARK_TARGET_CUDA_ARCH) 27 | #define ARCH_ALIAS_TYPE(alias, cuda_type, hip_type) typedef cuda_type alias; 28 | #elif defined(ARK_TARGET_ROCM_ARCH) 29 | #define ARCH_ALIAS_TYPE(alias, cuda_type, hip_type) typedef hip_type alias; 30 | #endif 31 | /* ARCH_ALIAS_FUNC defines a forwarding function `alias` that calls the CUDA or HIP function, depending on the target macro. */ 32 | #if defined(ARK_TARGET_CUDA_ARCH) 33 | #define ARCH_ALIAS_FUNC(alias, cuda_func, hip_func) \ 34 | template \ 35 | inline auto alias(Args &&... args) { \ 36 | return cuda_func(std::forward(args)...); \ 37 | } 38 | #elif defined(ARK_TARGET_ROCM_ARCH) 39 | #define ARCH_ALIAS_FUNC(alias, cuda_func, hip_func) \ 40 | template \ 41 | inline auto alias(Args &&... args) { \ 42 | return hip_func(std::forward(args)...); \ 43 | } 44 | #endif 45 | 46 | #endif // ARK_KERNELS_ARCH_H_ 47 | -------------------------------------------------------------------------------- /ark/include/kernels/common/atomic.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_KERNELS_ATOMIC_H_ 5 | #define ARK_KERNELS_ATOMIC_H_ 6 | 7 | #include 8 | 9 | #include "device.h" 10 | 11 | namespace ark { 12 | /* Atomic load and store wrappers that call mscclpp with memoryOrderRelaxed. */ 13 | template 14 | DEVICE T atomicLoadRelaxed(T *ptr) { 15 | return mscclpp::atomicLoad(ptr, mscclpp::memoryOrderRelaxed); 16 | } 17 | 18 | template 19 | DEVICE void atomicStoreRelaxed(T *ptr, const T &val) { 20 | mscclpp::atomicStore(ptr, val, mscclpp::memoryOrderRelaxed); 21 | } 22 | 23 | } // namespace ark 24 | 25 | #endif // ARK_KERNELS_ATOMIC_H_ 26 | -------------------------------------------------------------------------------- /ark/include/kernels/common/bf16.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_KERNELS_BF16_H_ 5 | #define ARK_KERNELS_BF16_H_ 6 | 7 | #include "arch.h" 8 | #include "device.h" 9 | #include "vector_type.h" 10 | 11 | #if defined(ARK_TARGET_CUDA_ARCH) 12 | #include 13 | #elif defined(ARK_TARGET_ROCM_ARCH) 14 | #include 15 | #endif 16 | 17 | namespace ark { 18 | /* bf16 aliases: the first type argument is the CUDA (__nv_*) name, the second the HIP (__hip_*) name. */ 19 | ARCH_ALIAS_TYPE(bf16, __nv_bfloat16, __hip_bfloat16); 20 | ARCH_ALIAS_TYPE(bf16x2, __nv_bfloat162, __hip_bfloat162); 21 | ARCH_ALIAS_TYPE(bf16_raw, __nv_bfloat16_raw, __hip_bfloat16); 22 | ARCH_ALIAS_TYPE(bf16x2_raw, __nv_bfloat162_raw, __hip_bfloat162); 23 | 24 | namespace type { 25 | 26 | template 27 | struct Constant; 28 | /* Constants are built from raw bit patterns: 0x0 for zero() and 0xff7f for lowest(). */ 29 | template <> 30 | struct Constant { 31 | static DEVICE bf16 zero() { return bf16_raw{0x0}; } 32 | static DEVICE bf16 lowest() { return bf16_raw{0xff7f}; } 33 | }; 34 | 35 | template <> 36 | struct Constant { 37 | static DEVICE bf16x2 zero() { return bf16x2_raw{0x0, 0x0}; } 38 | static DEVICE bf16x2 lowest() { return bf16x2_raw{0xff7f, 0xff7f}; } 39 | }; 40 | /* Vtype specializations mapping to bf16x2 and const bf16x2. */ 41 | template <> 42 | struct Vtype { 43 | using type = bf16x2; 44 | }; 45 | 46 | template <> 47 | struct Vtype { 48 | using type = const bf16x2; 49 | }; 50 | 51 | } // namespace type 52 | 53 | } // namespace ark 
54 | 55 | #endif // ARK_KERNELS_BF16_H_ 56 | -------------------------------------------------------------------------------- /ark/include/kernels/common/checker.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_KERNELS_CHECKER_H_ 5 | #define ARK_KERNELS_CHECKER_H_ 6 | 7 | #include "device.h" 8 | 9 | namespace ark { 10 | 11 | /// Check if two values are equal at compile time; a mismatch fails the static_assert below. 12 | /// @tparam Value0_ First value. 13 | /// @tparam Value1_ Second value. 14 | template 15 | struct IsEq { 16 | static const int Value0 = Value0_; 17 | static const int Value1 = Value1_; 18 | static_assert(Value0 == Value1, "Size mismatch"); 19 | DEVICE void operator()() const { 20 | // Do nothing; the check is the static_assert above. 21 | } 22 | }; 23 | 24 | } // namespace ark 25 | 26 | #endif // ARK_KERNELS_CHECKER_H_ 27 | -------------------------------------------------------------------------------- /ark/include/kernels/common/device.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_KERNELS_DEVICE_H_ 5 | #define ARK_KERNELS_DEVICE_H_ 6 | 7 | #if defined(ARK_TARGET_CUDA_ARCH) && defined(ARK_TARGET_ROCM_ARCH) 8 | static_assert(false, "Multiple GPU architectures"); 9 | #endif // defined(ARK_TARGET_CUDA_ARCH) && defined(ARK_TARGET_ROCM_ARCH) 10 | 11 | #if defined(ARK_TARGET_ROCM_ARCH) 12 | #include 13 | #endif // defined(ARK_TARGET_ROCM_ARCH) 14 | 15 | #if !defined(ARK_TARGET_CUDA_ARCH) && !defined(ARK_TARGET_ROCM_ARCH) 16 | static_assert(false, "Unknown GPU architecture"); 17 | #define ARK_TARGET_CUDA_ARCH 800 // Dummy define 18 | #include // Dummy include 19 | #endif // !defined(ARK_TARGET_CUDA_ARCH) && !defined(ARK_TARGET_ROCM_ARCH) 20 | /* DEVICE marks a function as __forceinline__ __device__. */ 21 | #define DEVICE __forceinline__ __device__ 22 | 23 | #endif // ARK_KERNELS_DEVICE_H_ 24 | -------------------------------------------------------------------------------- /ark/include/kernels/common/ewise.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_KERNELS_EWISE_H_ 5 | #define ARK_KERNELS_EWISE_H_ 6 | 7 | #include "unit_op.h" 8 | 9 | namespace ark { 10 | 11 | /// Element-wise computation operator with a single input. 12 | template 14 | struct Ewise1 { 15 | using UnitOp = UnitOp; 16 | using DataType = typename CompType::DataType; 17 | static const int NelemPerThread = CompType::NelemPerThread; 18 | 19 | static_assert(NelemPerThread > 0, "NelemPerThread must be positive"); 20 | static_assert(UnitOutDims::W % NelemPerThread == 0, 21 | "UnitOutDims::W must be divisible by NelemPerThread"); 22 | 23 | /// Conduct element-wise computation on input and write the result on 24 | /// output. 25 | /// @param out Output data. 26 | /// @param in Input data. 27 | /// @param uop_idx Index of the unit operator. 
28 | static DEVICE void run(DataType *out, DataType *in, int uop_idx) { 29 | int un = UnitOp::uop_idx_n(uop_idx); 30 | int uc = UnitOp::uop_idx_c(uop_idx); 31 | int uh = UnitOp::uop_idx_h(uop_idx); 32 | int uw = UnitOp::uop_idx_w(uop_idx); 33 | /* Each iteration covers NelemPerThread elements starting at tid * NelemPerThread, decomposed into N/C/H/W coordinates inside the unit tile; tid advances by UnitOp::NumThreads. */ 34 | for (int tid = UnitOp::thread_id();; tid += UnitOp::NumThreads) { 35 | int tid_w = (tid * NelemPerThread) % UnitOutDims::W; 36 | int tid_h = 37 | ((tid * NelemPerThread) / UnitOutDims::W) % UnitOutDims::H; 38 | int tid_c = 39 | ((tid * NelemPerThread) / UnitOutDims::HW) % UnitOutDims::C; 40 | int tid_n = (tid * NelemPerThread) / UnitOutDims::CHW; 41 | /* Stop once the N coordinate falls outside the unit tile. */ 42 | if (tid_n >= UnitOutDims::N) { 43 | break; 44 | } 45 | /* Offset the in-tile coordinates by the unit operator's position to get the coordinates passed to CompType::compute. */ 46 | int idx_n = tid_n + un * UnitOutDims::N; 47 | int idx_c = tid_c + uc * UnitOutDims::C; 48 | int idx_h = tid_h + uh * UnitOutDims::H; 49 | int idx_w = tid_w + uw * UnitOutDims::W; 50 | 51 | CompType::compute(out, in, idx_n, idx_c, idx_h, idx_w); 52 | } 53 | /* Synchronize via UnitOp::sync_threads() before returning. */ 54 | UnitOp::sync_threads(); 55 | } 56 | }; 57 | 58 | } // namespace ark 59 | 60 | #endif // ARK_KERNELS_EWISE_H_ 61 | -------------------------------------------------------------------------------- /ark/include/kernels/common/fp16.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_KERNELS_FP16_H_ 5 | #define ARK_KERNELS_FP16_H_ 6 | 7 | #include "device.h" 8 | #include "vector_type.h" 9 | 10 | #if defined(ARK_TARGET_CUDA_ARCH) 11 | #include 12 | #elif defined(ARK_TARGET_ROCM_ARCH) 13 | #include 14 | #endif 15 | 16 | namespace ark { 17 | /* fp16 and fp16x2 alias __half and __half2. */ 18 | using fp16 = __half; 19 | using fp16x2 = __half2; 20 | 21 | namespace type { 22 | 23 | template 24 | struct Constant; 25 | /* Constants come from raw bit patterns (0 and 0xfbff): the CUDA path returns a __half_raw, the ROCm path reads the bits back through a union. */ 26 | template <> 27 | struct Constant { 28 | static DEVICE fp16 zero() { 29 | #if defined(ARK_TARGET_CUDA_ARCH) 30 | return __half_raw{0}; 31 | #elif defined(ARK_TARGET_ROCM_ARCH) 32 | union BitCast { 33 | unsigned short u; 34 | fp16 f; 35 | }; 36 | return BitCast{0}.f; 37 | #endif 38 | } 39 | static DEVICE fp16 lowest() { 40 | #if defined(ARK_TARGET_CUDA_ARCH) 41 | return __half_raw{0xfbff}; 42 | #elif defined(ARK_TARGET_ROCM_ARCH) 43 | union BitCast { 44 | unsigned short u; 45 | fp16 f; 46 | }; 47 | return BitCast{0xfbff}.f; 48 | #endif 49 | } 50 | }; 51 | /* Two-lane variant using the same bit pattern in both lanes. */ 52 | template <> 53 | struct Constant { 54 | static DEVICE fp16x2 zero() { 55 | #if defined(ARK_TARGET_CUDA_ARCH) 56 | return __half2_raw{0, 0}; 57 | #elif defined(ARK_TARGET_ROCM_ARCH) 58 | union BitCast { 59 | unsigned short u[2]; 60 | fp16x2 f; 61 | }; 62 | return BitCast{0, 0}.f; 63 | #endif 64 | } 65 | static DEVICE fp16x2 lowest() { 66 | #if defined(ARK_TARGET_CUDA_ARCH) 67 | return __half2_raw{0xfbff, 0xfbff}; 68 | #elif defined(ARK_TARGET_ROCM_ARCH) 69 | union BitCast { 70 | unsigned short u[2]; 71 | fp16x2 f; 72 | }; 73 | return BitCast{0xfbff, 0xfbff}.f; 74 | #endif 75 | } 76 | }; 77 | /* Vtype specializations mapping to fp16x2 and const fp16x2. */ 78 | template <> 79 | struct Vtype { 80 | using type = fp16x2; 81 | }; 82 | 83 | template <> 84 | struct Vtype { 85 | using type = const fp16x2; 86 | }; 87 | 88 | } // namespace type 89 | 90 | } // namespace ark 91 | 92 | #endif // ARK_KERNELS_FP16_H_ 93 | -------------------------------------------------------------------------------- /ark/include/kernels/common/fp32.h: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_KERNELS_FP32_H_ 5 | #define ARK_KERNELS_FP32_H_ 6 | 7 | #include 8 | 9 | #include "device.h" 10 | #include "vector_type.h" 11 | 12 | namespace ark { 13 | /* fp32, fp32x2 and fp32x4 alias float, float2 and float4. */ 14 | using fp32 = float; 15 | using fp32x2 = float2; 16 | using fp32x4 = float4; 17 | 18 | namespace type { 19 | 20 | template 21 | struct Constant; 22 | /* zero() is 0 and lowest() is -FLT_MAX, per lane for the vector variants. */ 23 | template <> 24 | struct Constant { 25 | static DEVICE fp32 zero() { return 0; } 26 | static DEVICE fp32 lowest() { return -FLT_MAX; } 27 | }; 28 | 29 | template <> 30 | struct Constant { 31 | static DEVICE fp32x2 zero() { return make_float2(0, 0); } 32 | static DEVICE fp32x2 lowest() { return make_float2(-FLT_MAX, -FLT_MAX); } 33 | }; 34 | 35 | template <> 36 | struct Constant { 37 | static DEVICE fp32x4 zero() { return make_float4(0, 0, 0, 0); } 38 | static DEVICE fp32x4 lowest() { 39 | return make_float4(-FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX); 40 | } 41 | }; 42 | /* Vtype specializations mapping to the 2- and 4-lane fp32 vector types and their const forms. */ 43 | template <> 44 | struct Vtype { 45 | using type = fp32x2; 46 | }; 47 | 48 | template <> 49 | struct Vtype { 50 | using type = const fp32x2; 51 | }; 52 | 53 | template <> 54 | struct Vtype { 55 | using type = fp32x4; 56 | }; 57 | 58 | template <> 59 | struct Vtype { 60 | using type = const fp32x4; 61 | }; 62 | 63 | } // namespace type 64 | 65 | } // namespace ark 66 | 67 | #endif // ARK_KERNELS_FP32_H_ 68 | -------------------------------------------------------------------------------- /ark/include/kernels/common/integer.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_KERNELS_INTEGER_H_ 5 | #define ARK_KERNELS_INTEGER_H_ 6 | 7 | #include "device.h" 8 | #include "vector_type.h" 9 | 10 | namespace ark { 11 | /* Aliases for int and unsigned int and their 2- and 4-lane vector types. */ 12 | using i32 = int; 13 | using i32x2 = int2; 14 | using i32x4 = int4; 15 | using ui32 = unsigned int; 16 | using ui32x2 = uint2; 17 | using ui32x4 = uint4; 18 | 19 | namespace type { 20 | 21 | template 22 | struct Constant; 23 | /* Signed lowest() is written as the literal 0x80000000; unsigned lowest() is 0. */ 24 | template <> 25 | struct Constant { 26 | static DEVICE i32 zero() { return 0; } 27 | static DEVICE i32 lowest() { return 0x80000000; } 28 | }; 29 | 30 | template <> 31 | struct Constant { 32 | static DEVICE i32x2 zero() { return make_int2(0, 0); } 33 | static DEVICE i32x2 lowest() { return make_int2(0x80000000, 0x80000000); } 34 | }; 35 | 36 | template <> 37 | struct Constant { 38 | static DEVICE i32x4 zero() { return make_int4(0, 0, 0, 0); } 39 | static DEVICE i32x4 lowest() { 40 | return make_int4(0x80000000, 0x80000000, 0x80000000, 0x80000000); 41 | } 42 | }; 43 | 44 | template <> 45 | struct Constant { 46 | static DEVICE ui32 zero() { return 0; } 47 | static DEVICE ui32 lowest() { return 0; } 48 | }; 49 | 50 | template <> 51 | struct Constant { 52 | static DEVICE ui32x2 zero() { return make_uint2(0, 0); } 53 | static DEVICE ui32x2 lowest() { return make_uint2(0, 0); } 54 | }; 55 | 56 | template <> 57 | struct Constant { 58 | static DEVICE ui32x4 zero() { return make_uint4(0, 0, 0, 0); } 59 | static DEVICE ui32x4 lowest() { return make_uint4(0, 0, 0, 0); } 60 | }; 61 | /* Vtype specializations mapping to the 2- and 4-lane integer vector types and their const forms. */ 62 | template <> 63 | struct Vtype { 64 | using type = i32x2; 65 | }; 66 | 67 | template <> 68 | struct Vtype { 69 | using type = const i32x2; 70 | }; 71 | 72 | template <> 73 | struct Vtype { 74 | using type = i32x4; 75 | }; 76 | 77 | template <> 78 | struct Vtype { 79 | using type = const i32x4; 80 | }; 81 | 82 | template <> 83 | struct Vtype { 84 | using type = ui32x2; 85 | }; 86 | 87 | template <> 88 | struct Vtype { 89 | using type = const ui32x2; 90 | }; 91 | 92 | template <> 93 | struct Vtype { 94 | 
using type = ui32x4; 95 | }; 96 | 97 | template <> 98 | struct Vtype { 99 | using type = const ui32x4; 100 | }; 101 | 102 | } // namespace type 103 | 104 | } // namespace ark 105 | 106 | #endif // ARK_KERNELS_INTEGER_H_ 107 | -------------------------------------------------------------------------------- /ark/include/kernels/common/shfl.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_KERNELS_SHFL_H_ 5 | #define ARK_KERNELS_SHFL_H_ 6 | 7 | #include "device.h" 8 | 9 | namespace ark { 10 | /* SHFL_XOR expands to __shfl_xor_sync with mask 0xffffffff on CUDA targets and to __shfl_xor on ROCm targets. */ 11 | #if defined(ARK_TARGET_CUDA_ARCH) 12 | #define SHFL_XOR(var, lane_mask, width) \ 13 | __shfl_xor_sync(0xffffffff, var, lane_mask, width) 14 | #elif defined(ARK_TARGET_ROCM_ARCH) 15 | #define SHFL_XOR(var, lane_mask, width) __shfl_xor(var, lane_mask, width) 16 | #endif 17 | 18 | } // namespace ark 19 | 20 | #endif // ARK_KERNELS_SHFL_H_ 21 | -------------------------------------------------------------------------------- /ark/include/kernels/common/smem.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_KERNELS_SMEM_H_ 5 | #define ARK_KERNELS_SMEM_H_ 6 | 7 | #include "arch.h" 8 | #include "device.h" 9 | #include "static_math.h" 10 | /* Shared memory array, declared extern __shared__ with no size given here. */ 11 | extern __shared__ int _ARK_SMEM[]; 12 | 13 | // Should be a multiple of 128 and equal to or larger than sync::WarpGroupState. 14 | #define ARK_SMEM_RESERVED_BYTES 128 15 | 16 | namespace ark { 17 | 18 | template 19 | struct SharedMemory { 20 | static DEVICE int smem_base_offset(int smem_per_warp) { 21 | // The smallest warp ID in the uop. 
22 | int least_warp_id = math::gm(warp_id()); 23 | return math::div(least_warp_id * smem_per_warp + 24 | ARK_SMEM_RESERVED_BYTES); 25 | } 26 | /* Pointer into _ARK_SMEM at the index computed by smem_base_offset(), cast to T *. */ 27 | static DEVICE T *get(int smem_per_warp) { 28 | return (T *)&_ARK_SMEM[smem_base_offset(smem_per_warp)]; 29 | } 30 | }; 31 | 32 | } // namespace ark 33 | 34 | #endif // ARK_KERNELS_SMEM_H_ 35 | -------------------------------------------------------------------------------- /ark/include/kernels/common/vec.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_KERNELS_VEC_H_ 5 | #define ARK_KERNELS_VEC_H_ 6 | 7 | #include "static_math.h" 8 | 9 | namespace ark { 10 | 11 | using DimType = long long int; 12 | 13 | template 14 | struct Vec { 15 | static_assert(_D0 >= 0, ""); 16 | static_assert(_D1 >= 0, ""); 17 | static_assert(_D2 >= 0, ""); 18 | static_assert(_D3 >= 0, ""); 19 | 20 | // 4D representation. 21 | static const DimType D0 = _D0; 22 | static const DimType D1 = _D1; 23 | static const DimType D2 = _D2; 24 | static const DimType D3 = _D3; 25 | static const DimType N = _D0; 26 | static const DimType C = _D1; 27 | static const DimType H = _D2; 28 | static const DimType W = _D3; 29 | 30 | // 3D representation. 31 | static const DimType X = _D0; 32 | static const DimType Y = _D1; 33 | static const DimType Z = _D2; 34 | 35 | // Multiplied values.
// Compile-time equality test for two Vec-like types.
//
// `value` is nonzero only when D0, D1, D2 and D3 of `Vec1` all match the
// corresponding members of `Vec2`.
template <typename Vec1, typename Vec2>
struct VecIsEq {
    enum {
        // Equal means "no dimension differs".
        value = !(Vec1::D0 != Vec2::D0 || Vec1::D1 != Vec2::D1 ||
                  Vec1::D2 != Vec2::D2 || Vec1::D3 != Vec2::D3)
    };
};
3 | 4 | #ifndef ARK_KERNELS_COPY_H_ 5 | #define ARK_KERNELS_COPY_H_ 6 | 7 | #include "common/broadcast.h" 8 | 9 | namespace ark { 10 | 11 | template 14 | DEVICE void copy(OutDataType *out, InDataType *in, int uop_idx, 15 | [[maybe_unused]] int smem_per_warp) { 16 | DefaultBroadcast1::run(out, in, uop_idx); 19 | } 20 | 21 | } // namespace ark 22 | 23 | #endif // ARK_KERNELS_COPY_H_ 24 | -------------------------------------------------------------------------------- /ark/include/kernels/noop.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_KERNELS_NOOP_H_ 5 | #define ARK_KERNELS_NOOP_H_ 6 | 7 | #include "common/device.h" 8 | 9 | namespace ark { 10 | 11 | DEVICE void noop(int, int) {} 12 | 13 | } // namespace ark 14 | 15 | #endif // ARK_KERNELS_NOOP_H_ 16 | -------------------------------------------------------------------------------- /ark/include/kernels/scalar.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_KERNELS_SCALAR_H_ 5 | #define ARK_KERNELS_SCALAR_H_ 6 | 7 | #include "common/broadcast.h" 8 | 9 | namespace ark { 10 | 11 | template 13 | DEVICE void scalar_assign(OutDataType *out, float val, int uop_idx, int) { 14 | OutDataType val_cast = type::Cast::compute(val); 15 | using ValDims = Vec<1, 1, 1, 1>; 16 | using ValShape = Vec<1, 1, 1, 1>; 17 | DefaultBroadcast1::run(out, &val_cast, uop_idx); 20 | } 21 | 22 | template 25 | DEVICE void scalar_add(OutDataType *y, InDataType *x, float val, int uop_idx, 26 | int) { 27 | InDataType val_cast = type::Cast::compute(val); 28 | using ValDims = Vec<1, 1, 1, 1>; 29 | using ValShape = Vec<1, 1, 1, 1>; 30 | DefaultBroadcast2::run(y, x, &val_cast, 33 | uop_idx); 34 | } 35 | 36 | template 39 | DEVICE void scalar_mul(OutDataType *y, InDataType *x, float val, int uop_idx, 40 | int) { 41 | InDataType val_cast = type::Cast::compute(val); 42 | using ValDims = Vec<1, 1, 1, 1>; 43 | using ValShape = Vec<1, 1, 1, 1>; 44 | DefaultBroadcast2::run(y, x, &val_cast, 47 | uop_idx); 48 | } 49 | 50 | } // namespace ark 51 | 52 | #endif // ARK_KERNELS_SCALAR_H_ 53 | -------------------------------------------------------------------------------- /ark/logging.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include "logging.hpp" 5 | 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #include "cpu_timer.h" 13 | #include "env.h" 14 | 15 | namespace ark { 16 | 17 | Logging::Logging(const std::string &lv) : pid_{::getpid()} { 18 | if (lv.size() == 0) { 19 | level_ = INFO; 20 | } else if (lv == "DEBUG") { 21 | level_ = DEBUG; 22 | } else if (lv == "WARN") { 23 | level_ = WARN; 24 | } else if (lv == "ERROR") { 25 | level_ = ERROR; 26 | } else { 27 | level_ = INFO; 28 | } 29 | } 30 | 31 | const LogLevel &Logging::get_level() const { return level_; } 32 | 33 | void Logging::set_level(LogLevel lv) { level_ = lv; }; 34 | 35 | //////////////////////////////////////////////////////////////////////////////// 36 | 37 | // Get the global Logging. 38 | Logging &get_logging() { 39 | static std::unique_ptr ark_logging = nullptr; 40 | if (ark_logging.get() == nullptr) { 41 | ark_logging = std::make_unique(get_env().log_level); 42 | } 43 | return *ark_logging; 44 | } 45 | 46 | void _log_header(std::ostream &os, const LogLevel ll, const std::string &file, 47 | const int line) { 48 | os << "ARK " << std::setfill(' ') << std::setw(5) << ::getpid() << ' '; 49 | switch (ll) { 50 | case INFO: 51 | os << "INFO "; 52 | break; 53 | case DEBUG: 54 | os << "DEBUG "; 55 | break; 56 | case WARN: 57 | os << "WARN "; 58 | break; 59 | case ERROR: 60 | os << "ERROR "; 61 | break; 62 | } 63 | std::string file_name; 64 | size_t pos = file.rfind("ark/"); 65 | if (pos == std::string::npos) { 66 | file_name = file; 67 | } else { 68 | file_name = file.substr(pos + 4); 69 | } 70 | os << file_name << ':' << line << ' '; 71 | } 72 | 73 | void set_log_level(LogLevel lv) { get_logging().set_level(lv); } 74 | 75 | const LogLevel &get_log_level() { return get_logging().get_level(); } 76 | 77 | } // namespace ark 78 | -------------------------------------------------------------------------------- /ark/model/model_buffer.cpp: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "model_buffer.hpp" 5 | 6 | #include "logging.hpp" 7 | 8 | namespace ark { 9 | 10 | ModelBuffer::ModelBuffer(int rank) : rank_(rank) { 11 | static size_t id = 0; 12 | id_ = id++; 13 | } 14 | 15 | ModelBuffer::ModelBuffer(size_t id, int rank, 16 | const std::vector &send_tags, 17 | const std::vector &recv_tags) 18 | : id_(id), rank_(rank) { 19 | for (const auto &info : send_tags) { 20 | send_tags_.insert(info); 21 | } 22 | for (const auto &info : recv_tags) { 23 | recv_tags_.insert(info); 24 | } 25 | } 26 | 27 | void ModelBuffer::tag_send(int remote_rank, int tag) { 28 | send_tags_.insert(TagInfo{remote_rank, tag}); 29 | } 30 | 31 | void ModelBuffer::tag_recv(int remote_rank, int tag) { 32 | recv_tags_.insert(TagInfo{remote_rank, tag}); 33 | } 34 | 35 | Json ModelBuffer::serialize() const { 36 | Json j; 37 | j["Id"] = id_; 38 | j["Rank"] = rank_; 39 | Json send_tags = Json::array(); 40 | Json recv_tags = Json::array(); 41 | for (const auto &info : send_tags_) { 42 | send_tags.push_back({info.first, info.second}); 43 | } 44 | for (const auto &info : recv_tags_) { 45 | recv_tags.push_back({info.first, info.second}); 46 | } 47 | j["SendTags"] = send_tags; 48 | j["RecvTags"] = recv_tags; 49 | return j; 50 | } 51 | 52 | std::shared_ptr ModelBuffer::deserialize(const Json &serialized) { 53 | if (!serialized.contains("Id")) { 54 | ERR(ModelError, "ModelBuffer deserialization failed: missing Id"); 55 | } else if (!serialized.contains("Rank")) { 56 | ERR(ModelError, "ModelBuffer deserialization failed: missing Rank"); 57 | } else if (!serialized.contains("SendTags")) { 58 | ERR(ModelError, "ModelBuffer deserialization failed: missing SendTags"); 59 | } else if (!serialized.contains("RecvTags")) { 60 | ERR(ModelError, "ModelBuffer deserialization failed: missing RecvTags"); 61 | } 62 | return 
std::make_shared(serialized["Id"], serialized["Rank"], 63 | serialized["SendTags"], 64 | serialized["RecvTags"]); 65 | } 66 | 67 | } // namespace ark 68 | -------------------------------------------------------------------------------- /ark/model/model_buffer.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_MODEL_BUFFER_HPP_ 5 | #define ARK_MODEL_BUFFER_HPP_ 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include "model_json.hpp" 12 | 13 | namespace ark { 14 | 15 | class ModelBuffer { 16 | public: 17 | // (remote_rank, tag) 18 | using TagInfo = std::pair; 19 | 20 | ModelBuffer(int rank = -1); 21 | 22 | ModelBuffer(size_t id, int rank, const std::vector &send_tags, 23 | const std::vector &recv_tags); 24 | 25 | size_t id() const { return id_; } 26 | 27 | int rank() const { return rank_; } 28 | 29 | const std::set &send_tags() const { return send_tags_; } 30 | 31 | const std::set &recv_tags() const { return recv_tags_; } 32 | 33 | // Identify this buffer as `tag` when sending data to `remote_rank`. 34 | // The same buffer can be tagged multiple times with different tags, 35 | // but the same tag can only be used for one sending buffer. 36 | void tag_send(int remote_rank, int tag); 37 | 38 | // Identify this buffer as `tag` when receiving from `remote_rank`. 39 | // The same buffer can be tagged multiple times with different tags, 40 | // but the same tag can only be used for one receiving buffer. 
41 | void tag_recv(int remote_rank, int tag); 42 | 43 | Json serialize() const; 44 | 45 | static std::shared_ptr deserialize(const Json &serialized); 46 | 47 | private: 48 | size_t id_; 49 | int rank_; 50 | std::set send_tags_; 51 | std::set recv_tags_; 52 | }; 53 | 54 | } // namespace ark 55 | 56 | #endif // ARK_MODEL_BUFFER_HPP_ 57 | -------------------------------------------------------------------------------- /ark/model/model_context_manager.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "model_context_manager.hpp" 5 | 6 | namespace ark { 7 | 8 | ModelContextManager::ModelContextManager(Model& model) 9 | : context_stack_(model.impl_->context_stack_) {} 10 | 11 | ModelContextManager::~ModelContextManager() { 12 | for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { 13 | context_stack_->pop(*it); 14 | } 15 | } 16 | 17 | void ModelContextManager::set(const std::string& key, const Json& value) { 18 | context_stack_->push(key, value); 19 | keys_.push_back(key); 20 | } 21 | 22 | bool ModelContextManager::has(const std::string& key) const { 23 | return context_stack_->has(key); 24 | } 25 | 26 | Json ModelContextManager::get(const std::string& key) const { 27 | return context_stack_->get(key); 28 | } 29 | 30 | } // namespace ark 31 | -------------------------------------------------------------------------------- /ark/model/model_context_manager.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
// Checks that context entries set through a ModelContextManager are attached
// to the nodes created afterwards, and that they disappear again once the
// manager goes out of scope.
ark::unittest::State test_model_context_manager() {
    ark::Model model;
    ark::Tensor t0 = model.tensor({1}, ark::FP32);
    ark::Tensor t1 = model.tensor({1}, ark::FP32);

    // node 0: created before any context is set, so it is expected to carry
    // an empty context.
    ark::Tensor t2 = model.add(t0, t1);

    ark::Tensor t3;
    ark::Tensor t4;
    ark::Tensor t5;
    {
        // node 1: created while only "key0" is set.
        ark::ModelContextManager cm(model);
        cm.set("key0", ark::Json("val1"));
        t3 = model.relu(t2);

        // node 2: "key1" is added on top of "key0", so both are expected.
        cm.set("key1", ark::Json("val2"));
        t4 = model.sqrt(t3);
    }
    {
        // node 3: a fresh manager after the previous one was destroyed; only
        // the new "key0" value is expected.
        ark::ModelContextManager cm(model);
        cm.set("key0", ark::Json("val3"));
        t5 = model.exp(t2);
    }

    UNITTEST_TRUE(model.verify());

    auto compressed = model.compress();
    UNITTEST_TRUE(compressed.verify());

    // The four ops above are expected to remain four separate nodes.
    auto nodes = compressed.nodes();
    UNITTEST_EQ(nodes.size(), 4);

    UNITTEST_EQ(nodes[0]->context.size(), 0);
    UNITTEST_EQ(nodes[1]->context.size(), 1);
    UNITTEST_EQ(nodes[1]->context.at("key0"), ark::Json("val1"));
    UNITTEST_EQ(nodes[2]->context.size(), 2);
    UNITTEST_EQ(nodes[2]->context.at("key0"), ark::Json("val1"));
    UNITTEST_EQ(nodes[2]->context.at("key1"), ark::Json("val2"));
    UNITTEST_EQ(nodes[3]->context.size(), 1);
    UNITTEST_EQ(nodes[3]->context.at("key0"), ark::Json("val3"));

    return ark::unittest::SUCCESS;
}

// Test entry point: runs the single test case above.
int main() {
    UNITTEST(test_model_context_manager);
    return 0;
}
3 | 4 | #ifndef ARK_MODEL_DATA_TYPE_HPP_ 5 | #define ARK_MODEL_DATA_TYPE_HPP_ 6 | 7 | #include 8 | #include 9 | 10 | #include "model_named_type.hpp" 11 | 12 | namespace ark { 13 | 14 | class ModelDataT; 15 | using ModelDataType = std::shared_ptr; 16 | 17 | class ModelDataT : public ModelNamedT { 18 | public: 19 | ModelDataT(const std::string &type_name, const std::string &type_str, 20 | size_t bytes) 21 | : ModelNamedT(type_name), type_str_(type_str), bytes_(bytes) {} 22 | 23 | ModelDataT(const ModelDataT &) = default; 24 | 25 | const std::string &type_str() const; 26 | 27 | size_t bytes() const; 28 | 29 | private: 30 | std::string type_str_; 31 | size_t bytes_; 32 | }; 33 | 34 | using ModelDataType = std::shared_ptr; 35 | 36 | } // namespace ark 37 | 38 | #endif // ARK_MODEL_DATA_TYPE_HPP_ 39 | -------------------------------------------------------------------------------- /ark/model/model_json.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_MODEL_JSON_HPP_ 5 | #define ARK_MODEL_JSON_HPP_ 6 | 7 | #include 8 | 9 | namespace ark { 10 | 11 | using Json = ::nlohmann::ordered_json; 12 | 13 | class ModelJson : public Json { 14 | public: 15 | ModelJson(const Json &json); 16 | std::string dump_pretty(int indent = 0, int indent_step = 2) const; 17 | }; 18 | 19 | class PlanJson : public Json { 20 | public: 21 | PlanJson(const Json &json = nullptr); 22 | std::string dump_pretty(int indent = 0, int indent_step = 2) const; 23 | }; 24 | 25 | } // namespace ark 26 | 27 | #endif // ARK_MODEL_JSON_HPP_ 28 | -------------------------------------------------------------------------------- /ark/model/model_named_type.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
// Base class for objects that carry a type name.
class ModelNamedT {
   public:
    // Store a copy of `type_name`.
    ModelNamedT(const std::string &type_name) : type_name_{type_name} {}

    // Return the stored type name.
    const std::string &type_name() const { return this->type_name_; }

   private:
    std::string type_name_;
};
3 | 4 | #include "model_offset.hpp" 5 | 6 | #include "logging.hpp" 7 | #include "model_buffer.hpp" 8 | #include "model_data_type.hpp" 9 | #include "model_tensor.hpp" 10 | 11 | namespace ark { 12 | 13 | ModelOffset::ModelOffset(ModelTensorRef tensor) { 14 | auto st = tensor->strides(); 15 | auto of = tensor->offsets(); 16 | int ndims = st.ndims(); 17 | size_t offset = 0; 18 | for (int idx = ndims - 1; idx >= 0; --idx) { 19 | size_t inc = of[idx]; 20 | for (int j = idx + 1; j < ndims; ++j) { 21 | inc *= st[j]; 22 | } 23 | offset += inc * tensor->data_type()->bytes(); 24 | } 25 | buffer_id_ = tensor->buffer()->id(); 26 | value_ = offset; 27 | } 28 | 29 | Json ModelOffset::serialize() const { 30 | Json j; 31 | j["BufferId"] = buffer_id_; 32 | j["Value"] = value_; 33 | return j; 34 | } 35 | 36 | std::shared_ptr ModelOffset::deserialize(const Json &serialized) { 37 | if (!serialized.contains("BufferId")) { 38 | ERR(ModelError, "ModelOffset deserialization failed: missing BufferId"); 39 | } else if (!serialized.contains("Value")) { 40 | ERR(ModelError, "ModelOffset deserialization failed: missing Value"); 41 | } 42 | return std::make_shared(serialized["BufferId"], 43 | serialized["Value"]); 44 | } 45 | 46 | } // namespace ark 47 | -------------------------------------------------------------------------------- /ark/model/model_offset.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_MODEL_OFFSET_HPP_ 5 | #define ARK_MODEL_OFFSET_HPP_ 6 | 7 | #include "ark/model_ref.hpp" 8 | #include "model_json.hpp" 9 | 10 | namespace ark { 11 | 12 | class ModelOffset { 13 | private: 14 | size_t buffer_id_; 15 | size_t value_; 16 | 17 | public: 18 | ModelOffset(size_t buffer_id, size_t value) 19 | : buffer_id_(buffer_id), value_(value) {} 20 | 21 | ModelOffset(ModelTensorRef tensor); 22 | 23 | size_t buffer_id() const { return buffer_id_; } 24 | 25 | size_t value() const { return value_; } 26 | 27 | Json serialize() const; 28 | 29 | static std::shared_ptr deserialize(const Json &serialized); 30 | }; 31 | 32 | } // namespace ark 33 | 34 | #endif // ARK_MODEL_OFFSET_HPP_ 35 | -------------------------------------------------------------------------------- /ark/model/model_op_arg.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_MODEL_OP_ARG_HPP_ 5 | #define ARK_MODEL_OP_ARG_HPP_ 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include "ark/dims.hpp" 12 | #include "ark/model_ref.hpp" 13 | #include "model_json.hpp" 14 | #include "model_named_type.hpp" 15 | #include "model_offset.hpp" 16 | 17 | namespace ark { 18 | 19 | template 20 | class ModelOpArgTName; 21 | 22 | #define REGISTER_MODEL_OP_ARG_TYPE(_name, _type) \ 23 | template <> \ 24 | class ModelOpArgTName<_type> { \ 25 | public: \ 26 | ModelOpArgTName() : name(#_name), type_str(#_type){}; \ 27 | const std::string name; \ 28 | const std::string type_str; \ 29 | }; 30 | 31 | class ModelOpArg : public ModelNamedT { 32 | public: 33 | ModelOpArg(); 34 | 35 | template 36 | ModelOpArg(T val) 37 | : ModelNamedT(ModelOpArgTName().name), 38 | type_str_(ModelOpArgTName().type_str), 39 | val_(val) {} 40 | 41 | template 42 | T value() const { 43 | return std::any_cast(val_); 44 | } 45 | 46 | const std::string &type_str() const { return type_str_; } 47 | 48 | 
Json serialize() const; 49 | 50 | static ModelOpArg deserialize(const Json &serialized); 51 | 52 | private: 53 | std::string type_str_; 54 | std::any val_; 55 | }; 56 | 57 | REGISTER_MODEL_OP_ARG_TYPE(INT, int) 58 | REGISTER_MODEL_OP_ARG_TYPE(UINT32, uint32_t) 59 | REGISTER_MODEL_OP_ARG_TYPE(INT64, int64_t) 60 | REGISTER_MODEL_OP_ARG_TYPE(UINT64, uint64_t) 61 | REGISTER_MODEL_OP_ARG_TYPE(BOOL, bool) 62 | REGISTER_MODEL_OP_ARG_TYPE(FLOAT, float) 63 | REGISTER_MODEL_OP_ARG_TYPE(DIMS, Dims) 64 | REGISTER_MODEL_OP_ARG_TYPE(TENSOR, ModelTensorRef) 65 | REGISTER_MODEL_OP_ARG_TYPE(OFFSET, ModelOffset) 66 | 67 | } // namespace ark 68 | 69 | #endif // ARK_MODEL_OP_ARG_HPP_ 70 | -------------------------------------------------------------------------------- /ark/model/model_tensor.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_MODEL_TENSOR_HPP_ 5 | #define ARK_MODEL_TENSOR_HPP_ 6 | 7 | #include "ark/dims.hpp" 8 | #include "ark/model_ref.hpp" 9 | #include "model_json.hpp" 10 | 11 | namespace ark { 12 | 13 | class ModelDataT; 14 | using ModelDataType = std::shared_ptr; 15 | 16 | class ModelTensor { 17 | public: 18 | ModelTensor(ModelDataType data_type, ModelBufferRef buffer, 19 | const Dims &shape, const Dims &strides = {}, 20 | const Dims &offsets = {}, const Dims &padded_shape = {}); 21 | 22 | ModelTensor(const ModelTensor &other); 23 | 24 | size_t id() const { return id_; } 25 | 26 | ModelDataType data_type() const { return data_type_; } 27 | 28 | ModelBufferRef buffer() const { return buffer_; } 29 | 30 | const Dims &shape() const { return shape_; } 31 | 32 | const Dims &strides() const { return strides_; } 33 | 34 | const Dims &offsets() const { return offsets_; } 35 | 36 | const Dims &padded_shape() const { return padded_shape_; } 37 | 38 | size_t shape_bytes() const; 39 | 40 | Json serialize() const; 41 | 42 | static 
std::shared_ptr deserialize(const Json &serialized); 43 | 44 | private: 45 | static size_t next_id(); 46 | 47 | size_t id_; 48 | ModelDataType data_type_; 49 | ModelBufferRef buffer_; 50 | Dims shape_; 51 | Dims strides_; 52 | Dims offsets_; 53 | Dims padded_shape_; 54 | }; 55 | 56 | } // namespace ark 57 | 58 | #endif // ARK_MODEL_TENSOR_HPP_ 59 | -------------------------------------------------------------------------------- /ark/ops/ops_all_reduce.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "ops_common.hpp" 5 | 6 | namespace ark { 7 | 8 | Tensor Model::all_reduce(Tensor input, int gpu_id, int gpu_num, Tensor output, 9 | const std::string &) { 10 | std::vector tags(gpu_num); 11 | for (int i = 0; i < gpu_num; i++) { 12 | tags[i] = this->unique_tag(); 13 | } 14 | if (output.is_null()) { 15 | output = this->copy(input); 16 | } 17 | Tensor prev_recv = NullTensor; 18 | Tensor cumulate = output; 19 | for (int i = 1; i < gpu_num; i++) { 20 | int gpu_dst = (gpu_id + i) % gpu_num; 21 | int gpu_src = (gpu_id + gpu_num - i) % gpu_num; 22 | Tensor send_data; 23 | if (prev_recv.is_null()) { 24 | send_data = input; 25 | } else { 26 | send_data = this->identity(input, {prev_recv}); 27 | } 28 | send_data = this->send(send_data, gpu_dst, tags[gpu_id]); 29 | Tensor send_done_tensor = this->send_done(send_data); 30 | Tensor recv_buf = this->tensor(output.shape(), output.data_type()); 31 | Tensor recv = this->identity(recv_buf, {send_done_tensor}); 32 | recv = this->recv(recv_buf, gpu_src, tags[gpu_src]); 33 | prev_recv = recv; 34 | cumulate = this->add(cumulate, recv, cumulate); 35 | } 36 | return cumulate; 37 | } 38 | 39 | } // namespace ark 40 | -------------------------------------------------------------------------------- /ark/ops/ops_arithmetic.cpp: -------------------------------------------------------------------------------- 1 | // 
Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "ops_arithmetic.hpp" 5 | 6 | #include "ops_common.hpp" 7 | 8 | namespace ark { 9 | 10 | ModelOpAdd::ModelOpAdd(ModelTensorRef input, ModelTensorRef other, 11 | ModelTensorRef output) 12 | : ModelOpBroadcast2("Add", input, other, output) {} 13 | 14 | Tensor Model::add(Tensor input, Tensor other, Tensor output, 15 | const std::string &name) { 16 | return impl_ 17 | ->create_op(name, input.ref_, other.ref_, output.ref_) 18 | ->result_tensors()[0]; 19 | } 20 | 21 | ModelOpMul::ModelOpMul(ModelTensorRef input, ModelTensorRef other, 22 | ModelTensorRef output) 23 | : ModelOpBroadcast2("Mul", input, other, output) {} 24 | 25 | Tensor Model::mul(Tensor input, Tensor other, Tensor output, 26 | const std::string &name) { 27 | return impl_ 28 | ->create_op(name, input.ref_, other.ref_, output.ref_) 29 | ->result_tensors()[0]; 30 | } 31 | 32 | ModelOpSub::ModelOpSub(ModelTensorRef input, ModelTensorRef other, 33 | ModelTensorRef output) 34 | : ModelOpBroadcast2("Sub", input, other, output) {} 35 | 36 | Tensor Model::sub(Tensor input, Tensor other, Tensor output, 37 | const std::string &name) { 38 | return impl_ 39 | ->create_op(name, input.ref_, other.ref_, output.ref_) 40 | ->result_tensors()[0]; 41 | } 42 | 43 | ModelOpDiv::ModelOpDiv(ModelTensorRef input, ModelTensorRef other, 44 | ModelTensorRef output) 45 | : ModelOpBroadcast2("Div", input, other, output) {} 46 | 47 | Tensor Model::div(Tensor input, Tensor other, Tensor output, 48 | const std::string &name) { 49 | return impl_ 50 | ->create_op(name, input.ref_, other.ref_, output.ref_) 51 | ->result_tensors()[0]; 52 | } 53 | 54 | } // namespace ark 55 | -------------------------------------------------------------------------------- /ark/ops/ops_arithmetic.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
namespace ark {

// Two-input addition op. Takes `input` and `other` and an `output` tensor;
// built on ModelOpBroadcast2.
class ModelOpAdd : public ModelOpBroadcast2 {
   public:
    ModelOpAdd() = default;
    ModelOpAdd(ModelTensorRef input, ModelTensorRef other,
               ModelTensorRef output);
};

// Two-input multiplication op; same constructor shape as ModelOpAdd.
class ModelOpMul : public ModelOpBroadcast2 {
   public:
    ModelOpMul() = default;
    ModelOpMul(ModelTensorRef input, ModelTensorRef other,
               ModelTensorRef output);
};

// Two-input subtraction op; same constructor shape as ModelOpAdd.
class ModelOpSub : public ModelOpBroadcast2 {
   public:
    ModelOpSub() = default;
    ModelOpSub(ModelTensorRef input, ModelTensorRef other,
               ModelTensorRef output);
};

// Two-input division op; same constructor shape as ModelOpAdd.
class ModelOpDiv : public ModelOpBroadcast2 {
   public:
    ModelOpDiv() = default;
    ModelOpDiv(ModelTensorRef input, ModelTensorRef other,
               ModelTensorRef output);
};

}  // namespace ark
3 | 4 | #ifndef ARK_OPS_BROADCAST_HPP_ 5 | #define ARK_OPS_BROADCAST_HPP_ 6 | 7 | #include "model/model_op.hpp" 8 | 9 | namespace ark { 10 | 11 | class ModelOpBroadcast1 : public ModelOp { 12 | public: 13 | ModelOpBroadcast1() = default; 14 | ModelOpBroadcast1(const std::string &type_name, ModelTensorRef input, 15 | ModelTensorRef output); 16 | 17 | std::string impl_name(const Json &config) const override; 18 | 19 | std::vector impl_args(const Json &config) const override; 20 | 21 | Json default_config(const ArchRef arch = ARCH_ANY) const override; 22 | }; 23 | 24 | class ModelOpBroadcast2 : public ModelOp { 25 | public: 26 | ModelOpBroadcast2() = default; 27 | ModelOpBroadcast2(const std::string &type_name, ModelTensorRef input, 28 | ModelTensorRef other, ModelTensorRef output); 29 | 30 | std::string impl_name(const Json &config) const override; 31 | 32 | std::vector impl_args(const Json &config) const override; 33 | 34 | Json default_config(const ArchRef arch = ARCH_ANY) const override; 35 | }; 36 | 37 | } // namespace ark 38 | 39 | #endif // ARK_OPS_BROADCAST_HPP_ 40 | -------------------------------------------------------------------------------- /ark/ops/ops_cast.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_OPS_CAST_HPP_ 5 | #define ARK_OPS_CAST_HPP_ 6 | 7 | #include "ops_broadcast.hpp" 8 | #include "ops_tensor.hpp" 9 | 10 | namespace ark { 11 | 12 | class ModelOpCast : public ModelOpBroadcast1 { 13 | public: 14 | ModelOpCast() = default; 15 | ModelOpCast(ModelTensorRef input, ModelDataType data_type, 16 | ModelTensorRef output); 17 | }; 18 | 19 | class ModelOpByteCast : public ModelOpTensor { 20 | public: 21 | ModelOpByteCast() = default; 22 | ModelOpByteCast(ModelTensorRef input, ModelDataType data_type, 23 | const Dims &shape, const Dims &strides, const Dims &offsets, 24 | const Dims &padded_shape); 25 | }; 26 | 27 | } // namespace ark 28 | 29 | #endif // ARK_OPS_CAST_HPP_ 30 | -------------------------------------------------------------------------------- /ark/ops/ops_common.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_OPS_COMMON_HPP_ 5 | #define ARK_OPS_COMMON_HPP_ 6 | 7 | #include 8 | 9 | #include "ark/model.hpp" 10 | #include "logging.hpp" 11 | #include "model/model_buffer.hpp" 12 | #include "model/model_data_type.hpp" 13 | #include "model/model_graph_impl.hpp" 14 | #include "model/model_offset.hpp" 15 | #include "model/model_op.hpp" 16 | #include "model/model_tensor.hpp" 17 | 18 | namespace ark { 19 | 20 | void check_null(ModelTensorRef tensor); 21 | 22 | void check_match_data_type(ModelTensorRef t, ModelDataType dt); 23 | 24 | void check_match_data_type(ModelTensorRef a, ModelTensorRef b); 25 | 26 | void check_match_shape(ModelTensorRef tensor, const Dims &shape); 27 | 28 | void check_match_padded_shape(ModelTensorRef tensor, const Dims &padded_shape); 29 | 30 | /// Return the output shape of broadcasting between two shapes. 31 | /// Follow NumPy rules. 32 | /// https://numpy.org/doc/stable/user/basics.broadcasting.html 33 | /// @param dims1 The first shape. 34 | /// @param dims2 The second shape. 
35 | Dims broadcast_shape(const Dims &dims1, const Dims &dims2); 36 | 37 | void check_fields_config(const Json &config, 38 | const std::vector &fields); 39 | 40 | void check_fields_args(const std::map &args, 41 | const std::vector &fields); 42 | 43 | } // namespace ark 44 | 45 | #endif // ARK_OPS_COMMON_HPP_ 46 | -------------------------------------------------------------------------------- /ark/ops/ops_copy.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "ops_copy.hpp" 5 | 6 | #include "ops_common.hpp" 7 | 8 | namespace ark { 9 | 10 | ModelOpCopy::ModelOpCopy(ModelTensorRef input, ModelTensorRef output) 11 | : ModelOpBroadcast1( 12 | "Copy", input, 13 | output ? output 14 | : std::make_shared( 15 | input->data_type(), std::make_shared(), 16 | input->shape())) { 17 | if (output) { 18 | check_match_data_type(input, output); 19 | } 20 | verify(); 21 | } 22 | 23 | Tensor Model::copy(Tensor input, Tensor output, const std::string &name) { 24 | return impl_->create_op(name, input.ref_, output.ref_) 25 | ->result_tensors()[0]; 26 | } 27 | 28 | } // namespace ark 29 | -------------------------------------------------------------------------------- /ark/ops/ops_copy.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_OPS_COPY_HPP_ 5 | #define ARK_OPS_COPY_HPP_ 6 | 7 | #include "ops_broadcast.hpp" 8 | 9 | namespace ark { 10 | 11 | class ModelOpCopy : public ModelOpBroadcast1 { 12 | public: 13 | ModelOpCopy() = default; 14 | ModelOpCopy(ModelTensorRef input, ModelTensorRef output); 15 | }; 16 | 17 | } // namespace ark 18 | 19 | #endif // ARK_OPS_COPY_HPP_ 20 | -------------------------------------------------------------------------------- /ark/ops/ops_embedding.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_OPS_EMBEDDING_HPP_ 5 | #define ARK_OPS_EMBEDDING_HPP_ 6 | 7 | #include "model/model_op.hpp" 8 | 9 | namespace ark { 10 | 11 | class ModelOpEmbedding : public ModelOp { 12 | public: 13 | ModelOpEmbedding() = default; 14 | ModelOpEmbedding(ModelTensorRef input, ModelTensorRef weight, 15 | ModelTensorRef output); 16 | 17 | std::string impl_name(const Json &config) const override; 18 | 19 | std::vector impl_args(const Json &config) const override; 20 | 21 | Json default_config(const ArchRef arch = ARCH_ANY) const override; 22 | }; 23 | 24 | } // namespace ark 25 | 26 | #endif // ARK_OPS_EMBEDDING_HPP_ 27 | -------------------------------------------------------------------------------- /ark/ops/ops_identity.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include "ops_identity.hpp" 5 | 6 | #include 7 | 8 | #include "ops_common.hpp" 9 | 10 | namespace ark { 11 | 12 | ModelOpIdentity::ModelOpIdentity(ModelTensorRef input, 13 | const std::vector &deps) 14 | : ModelOpTensor(input->buffer(), input->shape(), input->data_type(), 15 | input->strides(), input->offsets(), input->padded_shape()) { 16 | std::set dep_set; 17 | dep_set.emplace(input); 18 | read_tensors_.emplace_back(input); 19 | for (auto &dep : deps) { 20 | if (dep_set.emplace(dep).second) { 21 | read_tensors_.emplace_back(dep); 22 | } 23 | } 24 | 25 | verify(); 26 | } 27 | 28 | Tensor Model::identity(Tensor input, const std::vector &deps, 29 | const std::string &name) { 30 | std::vector deps_ref; 31 | for (auto &dep : deps) { 32 | deps_ref.emplace_back(dep.ref_); 33 | } 34 | return impl_->create_op(name, input.ref_, deps_ref) 35 | ->result_tensors()[0]; 36 | } 37 | 38 | } // namespace ark 39 | -------------------------------------------------------------------------------- /ark/ops/ops_identity.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_OPS_IDENTITY_HPP_ 5 | #define ARK_OPS_IDENTITY_HPP_ 6 | 7 | #include "ops_tensor.hpp" 8 | 9 | namespace ark { 10 | 11 | class ModelOpIdentity : public ModelOpTensor { 12 | public: 13 | ModelOpIdentity() = default; 14 | ModelOpIdentity(ModelTensorRef input, 15 | const std::vector &deps); 16 | }; 17 | 18 | } // namespace ark 19 | 20 | #endif // ARK_OPS_IDENTITY_HPP_ 21 | -------------------------------------------------------------------------------- /ark/ops/ops_identity_test.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include "ark/executor.hpp" 5 | #include "model/model_node.hpp" 6 | #include "model/model_op.hpp" 7 | #include "ops_test_common.hpp" 8 | 9 | ark::unittest::State test_ops_identity_model() { 10 | // OpNode graph: 11 | // 12 | // ReluOp --+ 13 | // | 14 | // ReluOp --+--> ReluOp 15 | // 16 | 17 | ark::Model model; 18 | ark::Tensor t0 = model.tensor({1}, ark::FP32); 19 | ark::Tensor t1 = model.tensor({1}, ark::FP32); 20 | ark::Tensor t2 = model.tensor({1}, ark::FP32); 21 | 22 | ark::Tensor r0 = model.relu(t0); 23 | ark::Tensor r1 = model.relu(t1); 24 | ark::Tensor t3 = model.identity(t2, {r0, r1}); 25 | 26 | ark::Tensor t4 = model.relu(t3); 27 | UNITTEST_TRUE(model.verify()); 28 | 29 | auto compressed = model.compress(); 30 | UNITTEST_TRUE(compressed.verify()); 31 | auto nodes = compressed.nodes(); 32 | UNITTEST_EQ(nodes.size(), 3); 33 | 34 | UNITTEST_EQ(nodes[0]->op->result_tensors()[0], r0.ref()); 35 | UNITTEST_EQ(nodes[0]->producers.size(), 0); 36 | UNITTEST_EQ(nodes[0]->consumers.size(), 1); 37 | 38 | UNITTEST_EQ(nodes[1]->op->result_tensors()[0], r1.ref()); 39 | UNITTEST_EQ(nodes[1]->producers.size(), 0); 40 | UNITTEST_EQ(nodes[1]->consumers.size(), 1); 41 | 42 | UNITTEST_EQ(nodes[2]->op->result_tensors()[0], t4.ref()); 43 | UNITTEST_EQ(nodes[2]->producers.size(), 2); 44 | UNITTEST_EQ(nodes[2]->consumers.size(), 0); 45 | 46 | return ark::unittest::SUCCESS; 47 | } 48 | 49 | ark::unittest::State test_ops_identity() { 50 | ark::Model model; 51 | // float buf[2][3][4][5]; 52 | ark::Tensor tns0 = model.tensor({2, 3, 4, 5}, ark::FP32); 53 | ark::Tensor tns1 = model.identity(tns0); 54 | 55 | // For preventing optimize-out 56 | model.noop(tns0); 57 | model.noop(tns1); 58 | 59 | // Create an executor 60 | ark::DefaultExecutor exe(model); 61 | exe.compile(); 62 | 63 | int num_elem = 2 * 3 * 4 * 5; 64 | 65 | // Fill tensor data: {1.0, 2.0, 3.0, ..., 120.0} 66 | std::vector data_vec(num_elem); 67 | std::iota(data_vec.begin(), data_vec.end(), 1.0f); 68 | 
exe.tensor_write(tns0, data_vec); 69 | 70 | // Check identity values 71 | std::vector ref_val(num_elem); 72 | exe.tensor_read(tns1, ref_val); 73 | for (int i = 0; i < num_elem; ++i) { 74 | UNITTEST_EQ(ref_val[i], (float)(i + 1)); 75 | } 76 | 77 | return ark::unittest::SUCCESS; 78 | } 79 | 80 | int main() { 81 | UNITTEST(test_ops_identity_model); 82 | UNITTEST(test_ops_identity); 83 | return 0; 84 | } 85 | -------------------------------------------------------------------------------- /ark/ops/ops_math.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_OPS_MATH_HPP_ 5 | #define ARK_OPS_MATH_HPP_ 6 | 7 | #include "ops_broadcast.hpp" 8 | 9 | namespace ark { 10 | 11 | class ModelOpMath : public ModelOpBroadcast1 { 12 | public: 13 | ModelOpMath() = default; 14 | ModelOpMath(const std::string &type_name, ModelTensorRef input, 15 | ModelTensorRef output); 16 | }; 17 | 18 | class ModelOpExp : public ModelOpMath { 19 | public: 20 | ModelOpExp() = default; 21 | ModelOpExp(ModelTensorRef input, ModelTensorRef output); 22 | }; 23 | 24 | class ModelOpGelu : public ModelOpMath { 25 | public: 26 | ModelOpGelu() = default; 27 | ModelOpGelu(ModelTensorRef input, ModelTensorRef output); 28 | }; 29 | 30 | class ModelOpRelu : public ModelOpMath { 31 | public: 32 | ModelOpRelu() = default; 33 | ModelOpRelu(ModelTensorRef input, ModelTensorRef output); 34 | }; 35 | 36 | class ModelOpRsqrt : public ModelOpMath { 37 | public: 38 | ModelOpRsqrt() = default; 39 | ModelOpRsqrt(ModelTensorRef input, ModelTensorRef output); 40 | }; 41 | 42 | class ModelOpSigmoid : public ModelOpMath { 43 | public: 44 | ModelOpSigmoid() = default; 45 | ModelOpSigmoid(ModelTensorRef input, ModelTensorRef output); 46 | }; 47 | 48 | class ModelOpSqrt : public ModelOpMath { 49 | public: 50 | ModelOpSqrt() = default; 51 | ModelOpSqrt(ModelTensorRef input, ModelTensorRef output); 
52 | }; 53 | 54 | } // namespace ark 55 | 56 | #endif // ARK_OPS_MATH_HPP_ 57 | -------------------------------------------------------------------------------- /ark/ops/ops_matmul.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_OPS_MATMUL_HPP_ 5 | #define ARK_OPS_MATMUL_HPP_ 6 | 7 | #include "model/model_op.hpp" 8 | 9 | namespace ark { 10 | 11 | class ModelOpMatmul : public ModelOp { 12 | public: 13 | ModelOpMatmul() = default; 14 | ModelOpMatmul(ModelTensorRef input, ModelTensorRef other, 15 | ModelTensorRef output, bool trans_input, bool trans_other); 16 | 17 | std::string impl_name(const Json &config) const override; 18 | 19 | std::vector impl_args(const Json &config) const override; 20 | 21 | Json default_config(const ArchRef arch = ARCH_ANY) const override; 22 | }; 23 | 24 | } // namespace ark 25 | 26 | #endif // ARK_OPS_MATMUL_HPP_ 27 | -------------------------------------------------------------------------------- /ark/ops/ops_noop.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include "ops_noop.hpp" 5 | 6 | #include "ops_common.hpp" 7 | 8 | namespace ark { 9 | 10 | ModelOpNoop::ModelOpNoop(ModelTensorRef input) : ModelOp("Noop") { 11 | read_tensors_ = {input}; 12 | verify(); 13 | } 14 | 15 | std::string ModelOpNoop::impl_name([[maybe_unused]] const Json &config) const { 16 | return function_name_string("noop"); 17 | } 18 | 19 | std::vector ModelOpNoop::impl_args([ 20 | [maybe_unused]] const Json &config) const { 21 | return {}; 22 | } 23 | 24 | Json ModelOpNoop::default_config([[maybe_unused]] const ArchRef arch) const { 25 | Json config; 26 | config["NumWarps"] = 1; 27 | config["SramBytes"] = 0; 28 | config["NumTasks"] = 0; 29 | return config; 30 | } 31 | 32 | void Model::noop(Tensor input, const std::string &name) { 33 | impl_->create_op(name, input.ref_); 34 | } 35 | 36 | } // namespace ark 37 | -------------------------------------------------------------------------------- /ark/ops/ops_noop.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_OPS_NOOP_HPP_ 5 | #define ARK_OPS_NOOP_HPP_ 6 | 7 | #include "model/model_op.hpp" 8 | 9 | namespace ark { 10 | 11 | class ModelOpNoop : public ModelOp { 12 | public: 13 | ModelOpNoop() = default; 14 | ModelOpNoop(ModelTensorRef input); 15 | 16 | std::string impl_name(const Json &config) const override; 17 | 18 | std::vector impl_args(const Json &config) const override; 19 | 20 | Json default_config(const ArchRef arch = ARCH_ANY) const override; 21 | }; 22 | 23 | } // namespace ark 24 | 25 | #endif // ARK_OPS_NOOP_HPP_ 26 | -------------------------------------------------------------------------------- /ark/ops/ops_reduce.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_OPS_REDUCE_HPP_ 5 | #define ARK_OPS_REDUCE_HPP_ 6 | 7 | #include "model/model_op.hpp" 8 | 9 | namespace ark { 10 | 11 | class ModelOpReduce : public ModelOp { 12 | public: 13 | ModelOpReduce() = default; 14 | ModelOpReduce(const std::string &type_name, ModelTensorRef input, int axis, 15 | bool keepdims, ModelTensorRef output); 16 | 17 | std::string impl_name(const Json &config) const override; 18 | 19 | std::vector impl_args(const Json &config) const override; 20 | 21 | Json default_config(const ArchRef arch = ARCH_ANY) const override; 22 | }; 23 | 24 | class ModelOpReduceMax : public ModelOpReduce { 25 | public: 26 | ModelOpReduceMax() = default; 27 | ModelOpReduceMax(ModelTensorRef input, int axis, bool keepdims, 28 | ModelTensorRef output) 29 | : ModelOpReduce("ReduceMax", input, axis, keepdims, output) {} 30 | }; 31 | 32 | class ModelOpReduceMean : public ModelOpReduce { 33 | public: 34 | ModelOpReduceMean() = default; 35 | ModelOpReduceMean(ModelTensorRef input, int axis, bool keepdims, 36 | ModelTensorRef output) 37 | : ModelOpReduce("ReduceMean", input, axis, keepdims, output) {} 38 | }; 39 | 40 | class ModelOpReduceSum : public ModelOpReduce { 41 | public: 42 | ModelOpReduceSum() = default; 43 | ModelOpReduceSum(ModelTensorRef input, int axis, bool keepdims, 44 | ModelTensorRef output) 45 | : ModelOpReduce("ReduceSum", input, axis, keepdims, output) {} 46 | }; 47 | 48 | } // namespace ark 49 | 50 | #endif // ARK_OPS_REDUCE_HPP_ 51 | -------------------------------------------------------------------------------- /ark/ops/ops_refer.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include "ops_refer.hpp" 5 | 6 | #include "ops_common.hpp" 7 | 8 | namespace ark { 9 | 10 | ModelOpRefer::ModelOpRefer(ModelTensorRef input, const Dims &shape, 11 | const Dims &strides, const Dims &offsets, 12 | const Dims &padded_shape) 13 | : ModelOpTensor(input->buffer(), shape, input->data_type(), strides, 14 | offsets, padded_shape) { 15 | read_tensors_ = {input}; 16 | verify(); 17 | } 18 | 19 | Tensor Model::refer(Tensor input, const Dims &shape, const Dims &strides, 20 | const Dims &offsets, const Dims &padded_shape, 21 | const std::string &name) { 22 | return impl_ 23 | ->create_op(name, input.ref_, shape, strides, offsets, 24 | padded_shape) 25 | ->result_tensors()[0]; 26 | } 27 | 28 | } // namespace ark 29 | -------------------------------------------------------------------------------- /ark/ops/ops_refer.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_OPS_REFER_HPP_ 5 | #define ARK_OPS_REFER_HPP_ 6 | 7 | #include "ops_tensor.hpp" 8 | 9 | namespace ark { 10 | 11 | class ModelOpRefer : public ModelOpTensor { 12 | public: 13 | ModelOpRefer() = default; 14 | ModelOpRefer(ModelTensorRef input, const Dims &shape, const Dims &strides, 15 | const Dims &offsets, const Dims &padded_shape); 16 | }; 17 | 18 | } // namespace ark 19 | 20 | #endif // ARK_OPS_REFER_HPP_ 21 | -------------------------------------------------------------------------------- /ark/ops/ops_reshape.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_OPS_RESHAPE_HPP_ 5 | #define ARK_OPS_RESHAPE_HPP_ 6 | 7 | #include "ark/dims.hpp" 8 | #include "ark/model.hpp" 9 | #include "model/model_op.hpp" 10 | #include "ops_tensor.hpp" 11 | 12 | namespace ark { 13 | 14 | class ModelOpReshape : public ModelOpTensor { 15 | public: 16 | ModelOpReshape() = default; 17 | ModelOpReshape(ModelTensorRef input, const Dims &shape, const Dims &strides, 18 | const Dims &offsets); 19 | }; 20 | 21 | } // namespace ark 22 | 23 | #endif // ARK_OPS_RESHAPE_HPP_ 24 | -------------------------------------------------------------------------------- /ark/ops/ops_rope.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "ops_rope.hpp" 5 | 6 | #include "ops_common.hpp" 7 | 8 | namespace ark { 9 | 10 | ModelOpRope::ModelOpRope(ModelTensorRef input, ModelTensorRef other, 11 | ModelTensorRef output) 12 | : ModelOpBroadcast2("Rope", input, other, output) {} 13 | 14 | Tensor Model::rope(Tensor input, Tensor other, Tensor output, 15 | const std::string &name) { 16 | return impl_ 17 | ->create_op(name, input.ref_, other.ref_, output.ref_) 18 | ->result_tensors()[0]; 19 | } 20 | 21 | } // namespace ark 22 | -------------------------------------------------------------------------------- /ark/ops/ops_rope.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_OPS_ROPE_HPP_ 5 | #define ARK_OPS_ROPE_HPP_ 6 | 7 | #include "ops_broadcast.hpp" 8 | 9 | namespace ark { 10 | 11 | class ModelOpRope : public ModelOpBroadcast2 { 12 | public: 13 | ModelOpRope() = default; 14 | ModelOpRope(ModelTensorRef input, ModelTensorRef weight, 15 | ModelTensorRef output); 16 | }; 17 | 18 | } // namespace ark 19 | 20 | #endif // ARK_OPS_ROPE_HPP_ 21 | -------------------------------------------------------------------------------- /ark/ops/ops_scalar.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_OPS_SCALAR_HPP_ 5 | #define ARK_OPS_SCALAR_HPP_ 6 | 7 | #include "ark/data_type.hpp" 8 | #include "ops_broadcast.hpp" 9 | 10 | namespace ark { 11 | 12 | class ModelOpScalarAssign : public ModelOp { 13 | public: 14 | ModelOpScalarAssign() = default; 15 | ModelOpScalarAssign(float val, const Dims &shape, ModelDataType data_type, 16 | ModelTensorRef output); 17 | 18 | std::string impl_name(const Json &config) const override; 19 | 20 | std::vector impl_args(const Json &config) const override; 21 | 22 | Json default_config(const ArchRef arch = ARCH_ANY) const override; 23 | }; 24 | 25 | class ModelOpScalarAdd : public ModelOpBroadcast1 { 26 | public: 27 | ModelOpScalarAdd() = default; 28 | ModelOpScalarAdd(ModelTensorRef input, float val, ModelTensorRef output); 29 | 30 | std::vector impl_args(const Json &config) const override; 31 | }; 32 | 33 | class ModelOpScalarMul : public ModelOpBroadcast1 { 34 | public: 35 | ModelOpScalarMul() = default; 36 | ModelOpScalarMul(ModelTensorRef input, float val, ModelTensorRef output); 37 | 38 | std::vector impl_args(const Json &config) const override; 39 | }; 40 | 41 | } // namespace ark 42 | 43 | #endif // ARK_OPS_SCALAR_HPP_ 44 | -------------------------------------------------------------------------------- /ark/ops/ops_sharding.cpp: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "ops_common.hpp" 5 | 6 | namespace ark { 7 | 8 | // Shard `input` along `axis` into `dim_per_shard`-dimensional shards. 9 | std::vector Model::sharding(Tensor input, DimType axis, 10 | DimType dim_per_shard, 11 | const std::string &name) { 12 | if (axis >= DIMS_LEN) { 13 | ERR(ModelError, "invlaid axis value: ", axis); 14 | } 15 | if ((input.shape()[axis] % dim_per_shard) != 0) { 16 | ERR(ModelError, "dimension length of axis ", axis, " (", 17 | input.shape()[axis], 18 | ") is not divided by the dimension per shard (", dim_per_shard, 19 | ")."); 20 | } 21 | std::vector shards; 22 | DimType num_shard = input.shape()[axis] / dim_per_shard; 23 | Dims shard_shape = input.shape(); 24 | Dims shard_offs = input.offsets(); 25 | for (DimType i = 0; i < num_shard; ++i) { 26 | shard_shape[axis] = dim_per_shard; 27 | Tensor shard = 28 | this->refer(input, shard_shape, input.strides(), shard_offs, {}, 29 | name + "/shard_" + std::to_string(i)); 30 | shards.emplace_back(shard); 31 | shard_offs[axis] += dim_per_shard; 32 | } 33 | return shards; 34 | } 35 | 36 | } // namespace ark 37 | -------------------------------------------------------------------------------- /ark/ops/ops_sharding_test.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include "ark/model.hpp" 5 | #include "logging.hpp" 6 | #include "model/model_node.hpp" 7 | #include "model/model_op.hpp" 8 | #include "unittest/unittest_utils.h" 9 | 10 | ark::unittest::State test_ops_sharding_model() { 11 | // OpNode graph: 12 | // 13 | // ReluOp --+ 14 | // | 15 | // ReluOp --+ 16 | // | 17 | // ReluOp --+--> ReluOp 18 | // 19 | 20 | ark::Model model; 21 | ark::Tensor t0 = model.tensor({3}, ark::FP32); 22 | 23 | std::vector vec = model.sharding(t0, 0, 1); 24 | UNITTEST_EQ(vec.size(), 3); 25 | 26 | ark::Tensor t1 = vec[0]; 27 | ark::Tensor t2 = vec[1]; 28 | ark::Tensor t3 = vec[2]; 29 | 30 | ark::Tensor r0 = model.relu(t1); 31 | ark::Tensor r1 = model.relu(t2); 32 | ark::Tensor r2 = model.relu(t3); 33 | 34 | ark::Tensor t4 = model.identity(t0, {r0, r1, r2}); 35 | 36 | ark::Tensor t5 = model.relu(t4); 37 | UNITTEST_TRUE(model.verify()); 38 | 39 | auto compressed = model.compress(); 40 | UNITTEST_TRUE(compressed.verify()); 41 | auto nodes = compressed.nodes(); 42 | UNITTEST_EQ(nodes.size(), 4); 43 | 44 | UNITTEST_EQ(nodes[0]->op->result_tensors()[0], r0.ref()); 45 | UNITTEST_EQ(nodes[0]->producers.size(), 0); 46 | UNITTEST_EQ(nodes[0]->consumers.size(), 1); 47 | 48 | UNITTEST_EQ(nodes[1]->op->result_tensors()[0], r1.ref()); 49 | UNITTEST_EQ(nodes[1]->producers.size(), 0); 50 | UNITTEST_EQ(nodes[1]->consumers.size(), 1); 51 | 52 | UNITTEST_EQ(nodes[2]->op->result_tensors()[0], r2.ref()); 53 | UNITTEST_EQ(nodes[2]->producers.size(), 0); 54 | UNITTEST_EQ(nodes[2]->consumers.size(), 1); 55 | 56 | UNITTEST_EQ(nodes[3]->op->result_tensors()[0], t5.ref()); 57 | UNITTEST_EQ(nodes[3]->producers.size(), 3); 58 | UNITTEST_EQ(nodes[3]->consumers.size(), 0); 59 | 60 | return ark::unittest::SUCCESS; 61 | } 62 | 63 | int main() { 64 | UNITTEST(test_ops_sharding_model); 65 | return 0; 66 | } 67 | -------------------------------------------------------------------------------- /ark/ops/ops_tensor.cpp: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "ops_tensor.hpp" 5 | 6 | #include "logging.hpp" 7 | #include "ops_common.hpp" 8 | 9 | namespace ark { 10 | 11 | ModelOpTensor::ModelOpTensor(ModelBufferRef buffer, const Dims &shape, 12 | ModelDataType data_type, const Dims &strides, 13 | const Dims &offsets, const Dims &padded_shape) 14 | : ModelOp("Tensor", true) { 15 | if (!buffer) { 16 | buffer = std::make_shared(); 17 | } 18 | 19 | ModelTensorRef tensor = std::make_shared( 20 | data_type, buffer, shape, strides, offsets, padded_shape); 21 | 22 | result_tensors_.emplace_back(tensor); 23 | 24 | verify(); 25 | } 26 | 27 | Tensor Model::tensor(const Dims &shape, const DataType &data_type, 28 | const Dims &strides, const Dims &offsets, 29 | const Dims &padded_shape, int rank, 30 | const std::string &name) { 31 | if (rank != -1) { 32 | if (rank == this->rank()) { 33 | rank = -1; 34 | } else if (rank < 0 || rank >= this->world_size()) { 35 | ERR(ModelError, "Invalid rank %d", rank); 36 | } 37 | } 38 | return impl_ 39 | ->create_op(name, std::make_shared(rank), 40 | shape, data_type.ref(), strides, offsets, 41 | padded_shape) 42 | ->result_tensors()[0]; 43 | } 44 | 45 | } // namespace ark 46 | -------------------------------------------------------------------------------- /ark/ops/ops_tensor.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_OPS_TENSOR_HPP_ 5 | #define ARK_OPS_TENSOR_HPP_ 6 | 7 | #include "ark/model.hpp" 8 | #include "model/model_op.hpp" 9 | 10 | namespace ark { 11 | 12 | class ModelOpTensor : public ModelOp { 13 | public: 14 | ModelOpTensor() = default; 15 | ModelOpTensor(ModelBufferRef buffer, const Dims &shape, 16 | ModelDataType data_type, const Dims &strides, 17 | const Dims &offsets, const Dims &padded_shape); 18 | }; 19 | 20 | } // namespace ark 21 | 22 | #endif // ARK_OPS_TENSOR_HPP_ 23 | -------------------------------------------------------------------------------- /ark/ops/ops_transpose.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_OPS_TRANSPOSE_HPP_ 5 | #define ARK_OPS_TRANSPOSE_HPP_ 6 | 7 | #include "model/model_op.hpp" 8 | 9 | namespace ark { 10 | 11 | class ModelOpTranspose : public ModelOp { 12 | public: 13 | ModelOpTranspose() = default; 14 | ModelOpTranspose(ModelTensorRef input, 15 | const std::vector &permutation, 16 | ModelTensorRef output); 17 | 18 | std::string impl_name(const Json &config) const override; 19 | 20 | std::vector impl_args(const Json &config) const override; 21 | 22 | Json default_config(const ArchRef arch = ARCH_ANY) const override; 23 | }; 24 | 25 | } // namespace ark 26 | 27 | #endif // ARK_OPS_TRANSPOSE_HPP_ 28 | -------------------------------------------------------------------------------- /ark/range.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include "range.hpp" 5 | 6 | namespace ark { 7 | 8 | std::ostream& operator<<(std::ostream& os, const Range& range) { 9 | if (range.step() == 1) { 10 | os << "(" << *range.begin() << ", " << *range.end() << ")"; 11 | } else { 12 | os << "(" << *range.begin() << ", " << *range.end() << ", " 13 | << range.step() << ")"; 14 | } 15 | return os; 16 | } 17 | 18 | } // namespace ark 19 | -------------------------------------------------------------------------------- /ark/range_test.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "range.hpp" 5 | 6 | #include "unittest/unittest_utils.h" 7 | 8 | ark::unittest::State test_range() { 9 | ark::Range r1(0, 10); 10 | UNITTEST_EQ(r1.step(), 1); 11 | int v = 0; 12 | int cnt = 0; 13 | for (auto i : r1) { 14 | UNITTEST_EQ(i, v++); 15 | cnt++; 16 | } 17 | UNITTEST_EQ(cnt, 10); 18 | 19 | ark::Range r2(0, 10, 3); 20 | UNITTEST_EQ(r2.step(), 3); 21 | v = 0; 22 | cnt = 0; 23 | for (auto i : r2) { 24 | UNITTEST_EQ(i, v); 25 | v += 3; 26 | cnt++; 27 | } 28 | UNITTEST_EQ(cnt, 4); 29 | 30 | ark::Range r3(13, 1, -3); 31 | UNITTEST_EQ(r3.step(), -3); 32 | v = 13; 33 | cnt = 0; 34 | for (auto i : r3) { 35 | UNITTEST_EQ(i, v); 36 | v -= 3; 37 | cnt++; 38 | } 39 | UNITTEST_EQ(cnt, 4); 40 | 41 | ark::Range r4(0, 0); 42 | UNITTEST_EQ(r4.step(), 1); 43 | cnt = 0; 44 | for ([[maybe_unused]] auto i : r4) { 45 | cnt++; 46 | } 47 | UNITTEST_EQ(cnt, 0); 48 | 49 | ark::Range r5(2, 19, 3); // 2, 5, 8, 11, 14, 17 50 | ark::Range r6(3, 17, 2); // 3, 5, 7, 9, 11, 13, 15 51 | auto intersec = r5.intersection(r6); 52 | UNITTEST_EQ(intersec[0], 5); 53 | UNITTEST_EQ(intersec[1], 11); 54 | 55 | return ark::unittest::State::SUCCESS; 56 | } 57 | 58 | int main() { 59 | UNITTEST(test_range); 60 | return 0; 61 | } 62 | -------------------------------------------------------------------------------- /ark/unique_list_test.cpp: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "unique_list.hpp" 5 | 6 | #include "unittest/unittest_utils.h" 7 | 8 | ark::unittest::State test_unique_list() { 9 | ark::UniqueList list; 10 | list.push_back(1); 11 | list.push_back(2); 12 | list.push_back(3); 13 | list.push_back(1); 14 | list.push_back(2); 15 | list.push_back(3); 16 | UNITTEST_EQ(list.size(), 3); 17 | UNITTEST_EQ(list[0], 1); 18 | UNITTEST_EQ(list[1], 2); 19 | UNITTEST_EQ(list[2], 3); 20 | 21 | list.clear(); 22 | UNITTEST_EQ(list.size(), 0); 23 | 24 | list.push_back(1); 25 | list.push_back(2); 26 | list.push_back(3); 27 | list.push_back(1); 28 | list.push_back(2); 29 | list.push_back(3); 30 | list.push_back(4); 31 | UNITTEST_EQ(list.size(), 4); 32 | UNITTEST_EQ(list[0], 1); 33 | UNITTEST_EQ(list[1], 2); 34 | UNITTEST_EQ(list[2], 3); 35 | UNITTEST_EQ(list[3], 4); 36 | 37 | list.clear(); 38 | UNITTEST_EQ(list.size(), 0); 39 | 40 | list.push_back(1); 41 | list.push_back(2); 42 | list.push_back(3); 43 | 44 | list.erase(1); 45 | UNITTEST_EQ(list.size(), 2); 46 | UNITTEST_EQ(list[0], 2); 47 | UNITTEST_EQ(list[1], 3); 48 | 49 | list.clear(); 50 | UNITTEST_EQ(list.size(), 0); 51 | 52 | list.push_back(1); 53 | list.push_back(2); 54 | list.push_back(3); 55 | 56 | list.erase(0); 57 | UNITTEST_EQ(list.size(), 3); 58 | UNITTEST_EQ(list[0], 1); 59 | UNITTEST_EQ(list[1], 2); 60 | UNITTEST_EQ(list[2], 3); 61 | 62 | list.clear(); 63 | UNITTEST_EQ(list.size(), 0); 64 | 65 | list.push_back(1); 66 | list.push_back(2); 67 | list.push_back(3); 68 | 69 | list.erase(2); 70 | UNITTEST_EQ(list.size(), 2); 71 | UNITTEST_EQ(list[0], 1); 72 | UNITTEST_EQ(list[1], 3); 73 | 74 | return ark::unittest::SUCCESS; 75 | } 76 | 77 | int main() { 78 | UNITTEST(test_unique_list); 79 | return 0; 80 | } 81 | -------------------------------------------------------------------------------- 
/ark/utils/utils_math.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "utils/utils_math.hpp" 5 | 6 | #include "logging.hpp" 7 | 8 | namespace ark { 9 | namespace math { 10 | 11 | // Calculate the ceiling of x / div. 12 | size_t div_up(size_t x, size_t div) { 13 | if (div == 0) { 14 | ERR(InvalidUsageError, "division by zero"); 15 | } 16 | if (x == 0) { 17 | return 0; 18 | } 19 | return 1 + ((x - 1) / div); 20 | } 21 | 22 | // Calculate the minimum multiple of u that is greater than or equal to x. 23 | size_t pad(size_t x, size_t u) { return div_up(x, u) * u; } 24 | 25 | // Return true if x is a power of 2. 26 | bool is_pow2(size_t x) { 27 | if (x == 0) { 28 | return false; 29 | } 30 | return (x & (x - 1)) == 0; 31 | } 32 | 33 | // Return the log base 2 of x. x must be a power of 2. 34 | unsigned int ilog2(unsigned int x) { 35 | if (x == 0) { 36 | ERR(InvalidUsageError, "log of zero is undefined"); 37 | } 38 | return (sizeof(unsigned int) * 8) - __builtin_clz(x) - 1; 39 | } 40 | 41 | // Greatest Common Divisor. 42 | size_t gcd(size_t a, size_t b) { 43 | if (a == 0) { 44 | return b; 45 | } 46 | if (b == 0) { 47 | return a; 48 | } 49 | while (b != 0) { 50 | size_t t = b; 51 | b = a % b; 52 | a = t; 53 | } 54 | return a; 55 | } 56 | 57 | // Least Common Multiple. 58 | size_t lcm(size_t a, size_t b) { return a / gcd(a, b) * b; } 59 | 60 | } // namespace math 61 | } // namespace ark 62 | -------------------------------------------------------------------------------- /ark/utils/utils_math.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #ifndef ARK_UTILS_MATH_HPP_ 5 | #define ARK_UTILS_MATH_HPP_ 6 | 7 | #include 8 | 9 | namespace ark { 10 | namespace math { 11 | 12 | size_t div_up(size_t x, size_t div); 13 | size_t pad(size_t x, size_t u); 14 | bool is_pow2(size_t x); 15 | unsigned int ilog2(unsigned int x); 16 | size_t gcd(size_t a, size_t b); 17 | size_t lcm(size_t a, size_t b); 18 | 19 | } // namespace math 20 | } // namespace ark 21 | 22 | #endif // ARK_UTILS_MATH_HPP_ 23 | -------------------------------------------------------------------------------- /ark/utils/utils_net.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "utils_net.hpp" 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | 12 | #include "env.h" 13 | #include "file_io.h" 14 | #include "logging.hpp" 15 | 16 | namespace ark { 17 | 18 | static std::vector hosts; 19 | 20 | const std::string get_host(int idx, bool reset) { 21 | if (reset) { 22 | hosts.clear(); 23 | } 24 | if (hosts.size() == 0) { 25 | const auto &hostfile = get_env().hostfile; 26 | if (!is_file(hostfile)) { 27 | LOG(WARN, "cannot open hostfile: ", hostfile, ", assume localhost"); 28 | hosts.push_back("127.0.0.1"); 29 | } else { 30 | std::ifstream ifs(hostfile); 31 | std::string line; 32 | int host_idx = 0; 33 | while (std::getline(ifs, line)) { 34 | // Hostname to IP 35 | struct hostent *ent = ::gethostbyname(line.c_str()); 36 | if (ent == nullptr) { 37 | ERR(InvalidUsageError, "cannot resolve hostname: ", line); 38 | } 39 | char *host = ::inet_ntoa(*(struct in_addr *)ent->h_addr); 40 | LOG(INFO, "HOST ", host_idx, ": ", host); 41 | hosts.emplace_back(host); 42 | host_idx++; 43 | } 44 | } 45 | } 46 | if ((idx < 0) || (idx >= (int)hosts.size())) { 47 | ERR(InvalidUsageError, "invalid host index: ", idx); 48 | } 49 | return hosts[idx]; 50 | } 51 | 52 | } // namespace ark 53 | 
-------------------------------------------------------------------------------- /ark/utils/utils_net.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_UTILS_NET_HPP_ 5 | #define ARK_UTILS_NET_HPP_ 6 | 7 | #include 8 | 9 | namespace ark { 10 | 11 | /// Return a hostname from the hostfile. 12 | /// @param idx Index of the hostname to return. 13 | /// @param reset Whether to reread the hostfile. 14 | /// @return The hostname. 15 | const std::string get_host(int idx, bool reset = false); 16 | 17 | } // namespace ark 18 | 19 | #endif // ARK_UTILS_NET_HPP_ 20 | -------------------------------------------------------------------------------- /ark/utils/utils_net_test.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "utils_net.hpp" 5 | 6 | #include "env.h" 7 | #include "file_io.h" 8 | #include "unittest/unittest_utils.h" 9 | 10 | ark::unittest::State test_ipc_hosts() { 11 | auto tmp_dir = ark::get_env().path_tmp_dir; 12 | auto tmp_hostfile = tmp_dir + "/.test_ipc_hostfile"; 13 | ark::write_file(tmp_hostfile, "127.0.0.1\n127.0.0.1\n127.0.0.1\n"); 14 | ::setenv("ARK_HOSTFILE", tmp_hostfile.c_str(), 1); 15 | ark::init(); 16 | 17 | UNITTEST_EQ(ark::get_host(0, true), "127.0.0.1"); 18 | UNITTEST_EQ(ark::get_host(1), "127.0.0.1"); 19 | UNITTEST_EQ(ark::get_host(2), "127.0.0.1"); 20 | 21 | UNITTEST_THROW(ark::get_host(-1), ark::InvalidUsageError); 22 | UNITTEST_THROW(ark::get_host(3), ark::InvalidUsageError); 23 | 24 | ark::remove_file(tmp_hostfile); 25 | 26 | return ark::unittest::SUCCESS; 27 | } 28 | 29 | ark::unittest::State test_ipc_hosts_unknown_host() { 30 | auto tmp_dir = ark::get_env().path_tmp_dir; 31 | auto tmp_hostfile = tmp_dir + "/.test_ipc_hostfile"; 32 | ark::write_file(tmp_hostfile, 
"unknown\nunknown\nunknown\n"); 33 | ::setenv("ARK_HOSTFILE", tmp_hostfile.c_str(), 1); 34 | ark::init(); 35 | 36 | UNITTEST_THROW(ark::get_host(0, true), ark::InvalidUsageError); 37 | 38 | ark::remove_file(tmp_hostfile); 39 | 40 | return ark::unittest::SUCCESS; 41 | } 42 | 43 | int main() { 44 | UNITTEST(test_ipc_hosts); 45 | UNITTEST(test_ipc_hosts_unknown_host); 46 | return 0; 47 | } 48 | -------------------------------------------------------------------------------- /ark/utils/utils_string.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "utils_string.hpp" 5 | 6 | #include "logging.hpp" 7 | 8 | namespace ark { 9 | 10 | bool is_pascal(const std::string &str) { 11 | if (str.empty()) { 12 | return false; 13 | } 14 | if (!std::isupper(str[0])) { 15 | return false; 16 | } 17 | for (size_t i = 1; i < str.size(); ++i) { 18 | if (!std::isalnum(str[i])) { 19 | return false; 20 | } 21 | } 22 | return true; 23 | } 24 | 25 | std::string pascal_to_snake(const std::string &str) { 26 | if (!is_pascal(str)) { 27 | ERR(InvalidUsageError, "given string (", str, 28 | ") is not in Pascal case"); 29 | } 30 | std::string ret; 31 | for (size_t i = 0; i < str.size(); ++i) { 32 | if (i > 0 && std::isupper(str[i])) { 33 | ret.push_back('_'); 34 | } 35 | ret.push_back(std::tolower(str[i])); 36 | } 37 | return ret; 38 | } 39 | 40 | std::string to_upper(const std::string &str) { 41 | std::string ret; 42 | for (size_t i = 0; i < str.size(); ++i) { 43 | ret.push_back(std::toupper(str[i])); 44 | } 45 | return ret; 46 | } 47 | 48 | std::string to_lower(const std::string &str) { 49 | std::string ret; 50 | for (size_t i = 0; i < str.size(); ++i) { 51 | ret.push_back(std::tolower(str[i])); 52 | } 53 | return ret; 54 | } 55 | 56 | } // namespace ark 57 | -------------------------------------------------------------------------------- /ark/utils/utils_string.hpp: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ARK_UTILS_STRING_HPP_ 5 | #define ARK_UTILS_STRING_HPP_ 6 | 7 | #include 8 | 9 | namespace ark { 10 | 11 | bool is_pascal(const std::string &str); 12 | 13 | std::string pascal_to_snake(const std::string &str); 14 | 15 | std::string to_upper(const std::string &str); 16 | 17 | std::string to_lower(const std::string &str); 18 | 19 | } // namespace ark 20 | 21 | #endif // ARK_UTILS_STRING_HPP_ 22 | -------------------------------------------------------------------------------- /ark/utils/utils_string_test.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "utils/utils_string.hpp" 5 | 6 | #include "unittest/unittest_utils.h" 7 | 8 | ark::unittest::State test_utils_string() { 9 | UNITTEST_TRUE(ark::is_pascal("PascalCase")); 10 | UNITTEST_FALSE(ark::is_pascal("")); 11 | UNITTEST_FALSE(ark::is_pascal("notPascalCase")); 12 | UNITTEST_FALSE(ark::is_pascal("Not_PascalCase")); 13 | 14 | UNITTEST_EQ(ark::pascal_to_snake("PascalCase"), "pascal_case"); 15 | 16 | UNITTEST_EQ(ark::to_upper("upper"), "UPPER"); 17 | UNITTEST_EQ(ark::to_lower("UPPER"), "upper"); 18 | 19 | return ark::unittest::SUCCESS; 20 | } 21 | 22 | int main() { 23 | UNITTEST(test_utils_string); 24 | return 0; 25 | } 26 | -------------------------------------------------------------------------------- /cmake/CheckAmdGpu.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | set(AMD_FOUND "FALSE") 5 | 6 | set(CMAKE_PREFIX_PATH "/opt/rocm;${CMAKE_PREFIX_PATH}") 7 | # Temporal fix for rocm5.6 8 | set(ENV{amd_comgr_DIR} "/opt/rocm/lib/cmake/amd_comgr") 9 | set(ENV{AMDDeviceLibs_DIR} "/opt/rocm/lib/cmake/AMDDeviceLibs") 10 | 11 | find_package(hip QUIET) 12 | 13 | if(NOT hip_FOUND) 14 | return() 15 | endif() 16 | 17 | enable_language(HIP) 18 | 19 | set(CHECK_SRC "${CMAKE_CURRENT_SOURCE_DIR}/cmake/check_amd_gpu.hip") 20 | 21 | try_run(RUN_RESULT COMPILE_SUCCESS SOURCES ${CHECK_SRC}) 22 | 23 | if(COMPILE_SUCCESS AND RUN_RESULT EQUAL 0) 24 | set(AMD_FOUND "TRUE") 25 | endif() 26 | -------------------------------------------------------------------------------- /cmake/CheckNvidiaGpu.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | set(NVIDIA_FOUND "FALSE") 5 | 6 | find_package(CUDAToolkit) 7 | 8 | if(NOT CUDAToolkit_FOUND) 9 | return() 10 | endif() 11 | 12 | set(CMAKE_CUDA_ARCHITECTURES "60") 13 | if(NOT CMAKE_CUDA_COMPILER) 14 | # In case the CUDA Toolkit directory is not in the PATH 15 | find_program(CUDA_COMPILER 16 | NAMES nvcc 17 | PATHS ${CUDAToolkit_BIN_DIR}) 18 | if(NOT CUDA_COMPILER) 19 | message(WARNING "Could not find nvcc in ${CUDAToolkit_BIN_DIR}") 20 | unset(CMAKE_CUDA_ARCHITECTURES) 21 | return() 22 | endif() 23 | set(CMAKE_CUDA_COMPILER "${CUDA_COMPILER}") 24 | endif() 25 | enable_language(CUDA) 26 | 27 | set(CHECK_SRC "${CMAKE_CURRENT_SOURCE_DIR}/cmake/check_nvidia_gpu.cu") 28 | 29 | try_run(RUN_RESULT COMPILE_SUCCESS SOURCES ${CHECK_SRC}) 30 | 31 | if(COMPILE_SUCCESS AND RUN_RESULT EQUAL 0) 32 | set(NVIDIA_FOUND "TRUE") 33 | elseif(COMPILE_SUCCESS) 34 | message(WARNING "CUDA compiler found but no NVIDIA GPU detected") 35 | else() 36 | message(WARNING "CUDA compiler found but failed to compile a CUDA program") 37 | unset(CMAKE_CUDA_ARCHITECTURES) 38 | unset(CMAKE_CUDA_COMPILER) 39 | endif() 40 
| -------------------------------------------------------------------------------- /cmake/FindIBVerbs.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # Find the IB Verbs libraries 5 | # 6 | # The following variables are optionally searched for defaults 7 | # IBVERBS_ROOT_DIR: Base directory where all ibverbs components are found 8 | # IBVERBS_INCLUDE_DIR: Directory where ibverbs headers are found 9 | # IBVERBS_LIB_DIR: Directory where ibverbs libraries are found 10 | 11 | # The following are set after configuration is done: 12 | # IBVERBS_FOUND 13 | # IBVERBS_INCLUDE_DIRS 14 | # IBVERBS_LIBRARIES 15 | 16 | # An imported target ARK::ibverbs is created if the library is found. 17 | 18 | find_path(IBVERBS_INCLUDE_DIRS 19 | NAMES infiniband/verbs.h 20 | HINTS 21 | ${IBVERBS_INCLUDE_DIR} 22 | ${IBVERBS_ROOT_DIR} 23 | ${IBVERBS_ROOT_DIR}/include 24 | ) 25 | 26 | find_library(IBVERBS_LIBRARIES 27 | NAMES ibverbs 28 | HINTS 29 | ${IBVERBS_LIB_DIR} 30 | ${IBVERBS_ROOT_DIR} 31 | ${IBVERBS_ROOT_DIR}/lib 32 | ) 33 | 34 | include(FindPackageHandleStandardArgs) 35 | find_package_handle_standard_args(IBVerbs DEFAULT_MSG IBVERBS_INCLUDE_DIRS IBVERBS_LIBRARIES) 36 | mark_as_advanced(IBVERBS_INCLUDE_DIRS IBVERBS_LIBRARIES) 37 | # Use the find_path() result IBVERBS_INCLUDE_DIRS below; IBVERBS_INCLUDE_DIR is only an optional user hint and may be unset. 38 | if(IBVERBS_FOUND) 39 | if(NOT TARGET ARK::ibverbs) 40 | add_library(ARK::ibverbs UNKNOWN IMPORTED) 41 | endif() 42 | set_target_properties(ARK::ibverbs PROPERTIES 43 | INTERFACE_INCLUDE_DIRECTORIES "${IBVERBS_INCLUDE_DIRS}" 44 | IMPORTED_LINK_INTERFACE_LANGUAGES "C" 45 | IMPORTED_LOCATION "${IBVERBS_LIBRARIES}" 46 | ) 47 | endif() 48 | -------------------------------------------------------------------------------- /cmake/FindNUMA.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license.
3 | 4 | # Find the numa libraries 5 | # 6 | # The following variables are optionally searched for defaults 7 | # NUMA_ROOT_DIR: Base directory where all numa components are found 8 | # NUMA_INCLUDE_DIR: Directory where numa headers are found 9 | # NUMA_LIB_DIR: Directory where numa libraries are found 10 | 11 | # The following are set after configuration is done: 12 | # NUMA_FOUND 13 | # NUMA_INCLUDE_DIRS 14 | # NUMA_LIBRARIES 15 | 16 | # An imported target ARK::numa is created if the library is found. 17 | 18 | find_path(NUMA_INCLUDE_DIRS 19 | NAMES numa.h 20 | HINTS 21 | ${NUMA_INCLUDE_DIR} 22 | ${NUMA_ROOT_DIR} 23 | ${NUMA_ROOT_DIR}/include 24 | ) 25 | 26 | find_library(NUMA_LIBRARIES 27 | NAMES numa 28 | HINTS 29 | ${NUMA_LIB_DIR} 30 | ${NUMA_ROOT_DIR} 31 | ${NUMA_ROOT_DIR}/lib 32 | ) 33 | 34 | include(FindPackageHandleStandardArgs) 35 | find_package_handle_standard_args(NUMA DEFAULT_MSG NUMA_INCLUDE_DIRS NUMA_LIBRARIES) 36 | mark_as_advanced(NUMA_INCLUDE_DIRS NUMA_LIBRARIES) 37 | # Use the find_path() result NUMA_INCLUDE_DIRS below; NUMA_INCLUDE_DIR is only an optional user hint and may be unset. 38 | if(NUMA_FOUND) 39 | if(NOT TARGET ARK::numa) 40 | add_library(ARK::numa UNKNOWN IMPORTED) 41 | endif() 42 | set_target_properties(ARK::numa PROPERTIES 43 | INTERFACE_INCLUDE_DIRECTORIES "${NUMA_INCLUDE_DIRS}" 44 | IMPORTED_LINK_INTERFACE_LANGUAGES "C" 45 | IMPORTED_LOCATION "${NUMA_LIBRARIES}" 46 | ) 47 | endif() 48 | -------------------------------------------------------------------------------- /cmake/Utils.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license.
3 | 4 | # git-clang-format 5 | find_program(GIT_CLANG_FORMAT git-clang-format) 6 | if(GIT_CLANG_FORMAT) 7 | message(STATUS "Found git-clang-format: ${GIT_CLANG_FORMAT}") 8 | set(FIND_DIRS 9 | ${PROJECT_SOURCE_DIR}/ark 10 | ${PROJECT_SOURCE_DIR}/python 11 | ${PROJECT_SOURCE_DIR}/examples 12 | ) 13 | add_custom_target(cpplint 14 | COMMAND ${GIT_CLANG_FORMAT} --style=file --diff || true 15 | ) 16 | add_custom_target(cpplint-autofix 17 | COMMAND ${GIT_CLANG_FORMAT} --style=file --force --extensions cc,cpp,h,hpp,cu,in,hip || true 18 | ) 19 | else() 20 | message(STATUS "git-clang-format not found.") 21 | endif() 22 | 23 | # black 24 | find_program(BLACK black) 25 | if(BLACK) 26 | add_custom_target(pylint 27 | COMMAND python3 -m black --check --config ${PROJECT_SOURCE_DIR}/pyproject.toml ${PROJECT_SOURCE_DIR} 28 | ) 29 | add_custom_target(pylint-autofix 30 | COMMAND python3 -m black --config ${PROJECT_SOURCE_DIR}/pyproject.toml ${PROJECT_SOURCE_DIR} 31 | ) 32 | else() 33 | message(STATUS "black not found.") 34 | endif() 35 | 36 | # lcov 37 | find_program(LCOV lcov) 38 | if(LCOV) 39 | message(STATUS "Found lcov: ${LCOV}") 40 | add_custom_target(lcov 41 | COMMAND ${LCOV} --directory . --capture --output-file coverage.info 42 | COMMAND ${LCOV} --remove coverage.info 43 | '/usr/*' 44 | '/tmp/*' 45 | '*/third_party/*' 46 | '*/ark/*_test.*' 47 | '*/examples/*' 48 | '*/python/*' 49 | '*/ark/unittest/unittest_utils.cc' 50 | --output-file coverage.info 51 | COMMAND ${LCOV} --list coverage.info 52 | ) 53 | else() 54 | message(STATUS "lcov not found.") 55 | endif() 56 | -------------------------------------------------------------------------------- /cmake/check_amd_gpu.hip: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include 5 | 6 | __global__ void kernel() {} 7 | 8 | int main() { 9 | int cnt; 10 | hipError_t err = hipGetDeviceCount(&cnt); 11 | if (err != hipSuccess || cnt == 0) { 12 | return 1; 13 | } 14 | return 0; 15 | } 16 | -------------------------------------------------------------------------------- /cmake/check_nvidia_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include 5 | 6 | __global__ void kernel() {} 7 | 8 | int main() { 9 | int cnt; 10 | cudaError_t err = cudaGetDeviceCount(&cnt); 11 | if (err != cudaSuccess || cnt == 0) { 12 | return 1; 13 | } 14 | return 0; 15 | } 16 | -------------------------------------------------------------------------------- /docker/base-dev-x.dockerfile: -------------------------------------------------------------------------------- 1 | ARG BASE_IMAGE=ghcr.io/microsoft/ark/ark:base-cuda12.1 2 | FROM ${BASE_IMAGE} 3 | 4 | LABEL maintainer="ARK" 5 | LABEL org.opencontainers.image.source https://github.com/microsoft/ark 6 | 7 | ENV ARK_SRC_DIR="/tmp/ark" \ 8 | CMAKE_VERSION="3.26.4" 9 | 10 | ADD . 
${ARK_SRC_DIR} 11 | WORKDIR ${ARK_SRC_DIR} 12 | 13 | # Install Lcov 14 | RUN apt-get update && \ 15 | apt-get install -y --no-install-recommends \ 16 | lcov \ 17 | && \ 18 | apt-get autoremove && \ 19 | apt-get clean && \ 20 | rm -rf /var/lib/apt/lists/* /tmp/* 21 | 22 | # Install cmake 3.26.4 23 | ENV CMAKE_HOME="/tmp/cmake-${CMAKE_VERSION}-linux-x86_64" \ 24 | CMAKE_URL="https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-linux-x86_64.tar.gz" 25 | RUN curl -L ${CMAKE_URL} -o ${CMAKE_HOME}.tar.gz && \ 26 | tar xzf ${CMAKE_HOME}.tar.gz -C /usr/local && \ 27 | rm -rf ${CMAKE_HOME}.tar.gz 28 | ENV PATH="/usr/local/cmake-${CMAKE_VERSION}-linux-x86_64/bin:${PATH}" 29 | 30 | # Set PATH 31 | RUN echo PATH="${PATH}" > /etc/environment 32 | 33 | # Cleanup 34 | WORKDIR / 35 | RUN rm -rf ${ARK_SRC_DIR} 36 | -------------------------------------------------------------------------------- /docker/base-rocm5.6.dockerfile: -------------------------------------------------------------------------------- 1 | # Temporal Dockerfile for building ARK base image for ROCm 5.6 2 | 3 | ARG BASE_IMAGE=rocm/dev-ubuntu-20.04:5.6.1-complete 4 | FROM ${BASE_IMAGE} 5 | 6 | LABEL maintainer="ARK" 7 | LABEL org.opencontainers.image.source https://github.com/microsoft/ark 8 | 9 | ENV DEBIAN_FRONTEND=noninteractive 10 | 11 | RUN rm -rf /opt/nvidia 12 | 13 | RUN apt-get update && \ 14 | apt-get install -y --no-install-recommends \ 15 | build-essential \ 16 | ca-certificates \ 17 | curl \ 18 | git \ 19 | libcap2 \ 20 | libnuma-dev \ 21 | openssh-client \ 22 | openssh-server \ 23 | python3-dev \ 24 | python3-pip \ 25 | python3-setuptools \ 26 | python3-wheel \ 27 | sudo \ 28 | wget \ 29 | && \ 30 | apt-get autoremove && \ 31 | apt-get clean && \ 32 | rm -rf /var/lib/apt/lists/* /tmp/* 33 | 34 | # Install OFED 35 | ENV OFED_VERSION=5.2-2.2.3.0 36 | RUN cd /tmp && \ 37 | wget -q 
https://content.mellanox.com/ofed/MLNX_OFED-${OFED_VERSION}/MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tgz && \ 38 | tar xzf MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tgz && \ 39 | MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64/mlnxofedinstall --user-space-only --without-fw-update --force --all && \ 40 | rm -rf /tmp/MLNX_OFED_LINUX-${OFED_VERSION}* 41 | 42 | # Install OpenMPI 43 | ENV OPENMPI_VERSION=4.1.5 44 | RUN cd /tmp && \ 45 | export ompi_v_parsed="$(echo ${OPENMPI_VERSION} | sed -E 's/^([0-9]+)\.([0-9]+)\..*/\1.\2/')" && \ 46 | wget -q https://download.open-mpi.org/release/open-mpi/v${ompi_v_parsed}/openmpi-${OPENMPI_VERSION}.tar.gz && \ 47 | tar xzf openmpi-${OPENMPI_VERSION}.tar.gz && \ 48 | cd openmpi-${OPENMPI_VERSION} && \ 49 | ./configure --prefix=/usr/local/mpi && \ 50 | make -j && \ 51 | make install && \ 52 | cd .. && \ 53 | rm -rf /tmp/openmpi-${OPENMPI_VERSION}* 54 | 55 | ARG EXTRA_LD_PATH=/opt/rocm/lib 56 | ENV PATH="/usr/local/mpi/bin:${PATH}" \ 57 | LD_LIBRARY_PATH="/usr/local/mpi/lib:${EXTRA_LD_PATH}:${LD_LIBRARY_PATH}" 58 | 59 | RUN echo PATH="${PATH}" > /etc/environment && \ 60 | echo LD_LIBRARY_PATH="${LD_LIBRARY_PATH}" >> /etc/environment 61 | 62 | # Copy amd_hip_bf16.h from ROCm 5.7 63 | ADD amd_hip_bf16.h /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h 64 | 65 | ENTRYPOINT [] 66 | -------------------------------------------------------------------------------- /docker/base-x.dockerfile: -------------------------------------------------------------------------------- 1 | ARG BASE_IMAGE=nvidia/cuda:12.1.1-devel-ubuntu20.04 2 | FROM ${BASE_IMAGE} 3 | 4 | LABEL maintainer="ARK" 5 | LABEL org.opencontainers.image.source https://github.com/microsoft/ark 6 | 7 | ENV DEBIAN_FRONTEND=noninteractive 8 | 9 | RUN rm -rf /opt/nvidia 10 | 11 | RUN apt-get update && \ 12 | apt-get install -y --no-install-recommends \ 13 | build-essential \ 14 | ca-certificates \ 15 | curl \ 16 | git \ 17 | libcap2 \ 18 | libnuma-dev \ 19 | 
openssh-client \ 20 | openssh-server \ 21 | python3-dev \ 22 | python3-pip \ 23 | python3-setuptools \ 24 | python3-wheel \ 25 | sudo \ 26 | wget \ 27 | && \ 28 | apt-get autoremove && \ 29 | apt-get clean && \ 30 | rm -rf /var/lib/apt/lists/* /tmp/* 31 | 32 | # Install OFED 33 | ENV OFED_VERSION=5.2-2.2.3.0 34 | RUN cd /tmp && \ 35 | wget -q https://content.mellanox.com/ofed/MLNX_OFED-${OFED_VERSION}/MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tgz && \ 36 | tar xzf MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tgz && \ 37 | MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64/mlnxofedinstall --user-space-only --without-fw-update --force --all && \ 38 | rm -rf /tmp/MLNX_OFED_LINUX-${OFED_VERSION}* 39 | 40 | # Install OpenMPI 41 | ENV OPENMPI_VERSION=4.1.5 42 | RUN cd /tmp && \ 43 | export ompi_v_parsed="$(echo ${OPENMPI_VERSION} | sed -E 's/^([0-9]+)\.([0-9]+)\..*/\1.\2/')" && \ 44 | wget -q https://download.open-mpi.org/release/open-mpi/v${ompi_v_parsed}/openmpi-${OPENMPI_VERSION}.tar.gz && \ 45 | tar xzf openmpi-${OPENMPI_VERSION}.tar.gz && \ 46 | cd openmpi-${OPENMPI_VERSION} && \ 47 | ./configure --prefix=/usr/local/mpi && \ 48 | make -j && \ 49 | make install && \ 50 | cd .. 
&& \ 51 | rm -rf /tmp/openmpi-${OPENMPI_VERSION}* 52 | 53 | ARG EXTRA_LD_PATH=/usr/local/cuda-12.1/compat:/usr/local/cuda-12.1/lib64 54 | ENV PATH="/usr/local/mpi/bin:${PATH}" \ 55 | LD_LIBRARY_PATH="/usr/local/mpi/lib:${EXTRA_LD_PATH}:${LD_LIBRARY_PATH}" 56 | 57 | RUN echo PATH="${PATH}" > /etc/environment && \ 58 | echo LD_LIBRARY_PATH="${LD_LIBRARY_PATH}" >> /etc/environment 59 | 60 | ENTRYPOINT [] 61 | -------------------------------------------------------------------------------- /docker/build-x.dockerfile: -------------------------------------------------------------------------------- 1 | ARG BASE_IMAGE 2 | FROM ${BASE_IMAGE} 3 | 4 | LABEL maintainer="ARK" 5 | LABEL org.opencontainers.image.source https://github.com/microsoft/ark 6 | 7 | ENV DEBIAN_FRONTEND=noninteractive \ 8 | CMAKE_VERSION="3.26.4" 9 | 10 | RUN rm -rf /opt/nvidia 11 | 12 | RUN apt-get update && \ 13 | apt-get install -y --no-install-recommends \ 14 | build-essential \ 15 | ca-certificates \ 16 | curl \ 17 | git \ 18 | libcap2 \ 19 | libnuma-dev \ 20 | openssh-client \ 21 | openssh-server \ 22 | python3-dev \ 23 | python3-pip \ 24 | python3-setuptools \ 25 | python3-wheel \ 26 | sudo \ 27 | wget \ 28 | && \ 29 | apt-get autoremove && \ 30 | apt-get clean && \ 31 | rm -rf /var/lib/apt/lists/* /tmp/* 32 | 33 | # Install OFED 34 | ENV OFED_VERSION=5.2-2.2.3.0 35 | RUN cd /tmp && \ 36 | wget -q https://content.mellanox.com/ofed/MLNX_OFED-${OFED_VERSION}/MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tgz && \ 37 | tar xzf MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tgz && \ 38 | MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64/mlnxofedinstall --user-space-only --without-fw-update --force --all && \ 39 | rm -rf /tmp/MLNX_OFED_LINUX-${OFED_VERSION}* 40 | 41 | # Install cmake 3.26.4 42 | ENV CMAKE_HOME="/tmp/cmake-${CMAKE_VERSION}-linux-x86_64" \ 43 | CMAKE_URL="https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-linux-x86_64.tar.gz" 44 | RUN 
curl -L ${CMAKE_URL} -o ${CMAKE_HOME}.tar.gz && \ 45 | tar xzf ${CMAKE_HOME}.tar.gz -C /usr/local && \ 46 | rm -rf ${CMAKE_HOME}.tar.gz 47 | ENV PATH="/usr/local/cmake-${CMAKE_VERSION}-linux-x86_64/bin:${PATH}" 48 | 49 | ARG EXTRA_LD_PATH 50 | ENV LD_LIBRARY_PATH="${EXTRA_LD_PATH}:${LD_LIBRARY_PATH}" 51 | # Record PATH and LD_LIBRARY_PATH in /etc/environment as two separate commands (one entry per line). 52 | RUN echo PATH="${PATH}" > /etc/environment && \ 53 | echo LD_LIBRARY_PATH="${LD_LIBRARY_PATH}" >> /etc/environment 54 | 55 | ENTRYPOINT [] 56 | -------------------------------------------------------------------------------- /docker/build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | declare -A buildImageTable 6 | buildImageTable=( 7 | ["cuda11.8"]="nvidia/cuda:11.8.0-devel-ubuntu20.04" 8 | ["cuda12.1"]="nvidia/cuda:12.1.1-devel-ubuntu20.04" 9 | ["cuda12.2"]="nvidia/cuda:12.2.2-devel-ubuntu20.04" 10 | ["cuda12.4"]="nvidia/cuda:12.4.1-devel-ubuntu20.04" 11 | ["rocm5.7"]="rocm/dev-ubuntu-20.04:5.7" 12 | ["rocm6.0"]="rocm/dev-ubuntu-20.04:6.0" 13 | ["rocm6.1"]="rocm/dev-ubuntu-20.04:6.1" 14 | ) 15 | 16 | declare -A baseImageTable 17 | baseImageTable=( 18 | ["cuda11.8"]="nvidia/cuda:11.8.0-devel-ubuntu20.04" 19 | ["cuda12.1"]="nvidia/cuda:12.1.1-devel-ubuntu20.04" 20 | ["cuda12.2"]="nvidia/cuda:12.2.2-devel-ubuntu20.04" 21 | ["cuda12.4"]="nvidia/cuda:12.4.1-devel-ubuntu20.04" 22 | ["rocm5.7"]="rocm/dev-ubuntu-20.04:5.7-complete" 23 | ["rocm6.0"]="rocm/dev-ubuntu-20.04:6.0-complete" 24 | ["rocm6.1"]="rocm/dev-ubuntu-20.04:6.1-complete" 25 | ) 26 | 27 | declare -A extraLdPathTable 28 | extraLdPathTable=( 29 | ["cuda11.8"]="/usr/local/cuda-11.8/lib64" 30 | ["cuda12.1"]="/usr/local/cuda-12.1/compat:/usr/local/cuda-12.1/lib64" 31 | ["cuda12.2"]="/usr/local/cuda-12.2/compat:/usr/local/cuda-12.2/lib64" 32 | ["cuda12.4"]="/usr/local/cuda-12.4/compat:/usr/local/cuda-12.4/lib64" 33 | ["rocm5.7"]="/opt/rocm/lib" 34 | ["rocm6.0"]="/opt/rocm/lib" 35 | ["rocm6.1"]="/opt/rocm/lib" 36 | ) 37 | 38
| GHCR="ghcr.io/microsoft/ark/ark" 39 | TARGET=${1} 40 | 41 | print_usage() { 42 | echo "Usage: $0 [cuda11.8|cuda12.1|cuda12.2|cuda12.4|rocm5.7|rocm6.0|rocm6.1]" 43 | } 44 | 45 | if [[ ! -v "baseImageTable[${TARGET}]" ]]; then 46 | echo "Invalid target: ${TARGET}" 47 | print_usage 48 | exit 1 49 | fi 50 | echo "Target: ${TARGET}" 51 | 52 | SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" 53 | 54 | cd ${SCRIPT_DIR}/.. 55 | 56 | docker build -t ${GHCR}:build-${TARGET} \ 57 | -f docker/build-x.dockerfile \ 58 | --build-arg BASE_IMAGE=${buildImageTable[${TARGET}]} \ 59 | --build-arg EXTRA_LD_PATH=${extraLdPathTable[${TARGET}]} . 60 | 61 | docker build -t ${GHCR}:base-${TARGET} \ 62 | -f docker/base-x.dockerfile \ 63 | --build-arg BASE_IMAGE=${baseImageTable[${TARGET}]} \ 64 | --build-arg EXTRA_LD_PATH=${extraLdPathTable[${TARGET}]} . 65 | 66 | docker build -t ${GHCR}:base-dev-${TARGET} \ 67 | -f docker/base-dev-x.dockerfile \ 68 | --build-arg BASE_IMAGE=${GHCR}:base-${TARGET} . 69 | -------------------------------------------------------------------------------- /docs/doxygen/.gitignore: -------------------------------------------------------------------------------- 1 | doxygen/ 2 | -------------------------------------------------------------------------------- /docs/env.md: -------------------------------------------------------------------------------- 1 | # Environment Variables 2 | 3 | - `ARK_ROOT` (Default: `/usr/local/ark`) 4 | 5 | The installation directory of ARK. For C++, defaults to `/usr/local/ark` when unset. For Python, defaults to the ARK Python module's path. 6 | 7 | - `ARK_LOG_LEVEL` (Default: `INFO`; Options: `DEBUG`, `INFO`, `WARN`, `ERROR`) 8 | 9 | The log level of ARK. Use `DEBUG` for verbose debugging information and use `ERROR` for quiet execution. 10 | 11 | - `ARK_TMP` (Default: `/tmp/ark`) 12 | 13 | A directory to store temporary files that ARK generates.
14 | 15 | - `ARK_KEEP_TMP` (Default: `1`; Options: `0`, `1`) 16 | 17 | If set to `1`, keep the temporary files in the `ARK_TMP` directory; if set to `0`, remove them. 18 | 19 | - `ARK_HOSTFILE` (Default: `${ARK_ROOT}/hostfile`) 20 | 21 | Path to a hostfile. It must be set for multi-node execution. Ranks will be assigned in the order that hosts appear in the hostfile (`ARK_NUM_RANKS_PER_HOST` ranks per host). 22 | 23 | - `ARK_NUM_RANKS_PER_HOST` (Default: `8`) 24 | 25 | The number of ranks that each host runs. The behavior is undefined if the total number of ranks is not a multiple of `ARK_NUM_RANKS_PER_HOST`. 26 | 27 | - `ARK_DISABLE_IB` (Default: `0`; Options: `0`, `1`) 28 | 29 | If set to `1`, disable ibverbs networking (i.e., disable multi-node execution). 30 | -------------------------------------------------------------------------------- /docs/sphinx/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | -------------------------------------------------------------------------------- /docs/sphinx/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | SPHINXOPTS ?= 5 | SPHINXBUILD ?= sphinx-build 6 | SOURCEDIR = source 7 | BUILDDIR = build 8 | 9 | # Put it first so that "make" without argument is like "make help". 10 | help: 11 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 12 | 13 | .PHONY: help Makefile 14 | 15 | # Catch-all target: route all unknown targets to Sphinx using the new 16 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
17 | %: Makefile 18 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 19 | -------------------------------------------------------------------------------- /docs/sphinx/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx-book-theme 3 | sphinx-prompt 4 | rinohtype 5 | myst-parser 6 | -------------------------------------------------------------------------------- /docs/sphinx/source/api.rst: -------------------------------------------------------------------------------- 1 | ****************** 2 | API Document 3 | ****************** 4 | 5 | ``ark`` 6 | ===================== 7 | 8 | .. automodule:: ark 9 | :members: init, srand, rand, NO_DIM, DIMS_LEN 10 | 11 | 12 | ``ark.Dims`` 13 | ===================== 14 | 15 | .. automodule:: ark.Dims 16 | :members: size, ndims, __getitem__, __setitem__, __repr__ 17 | 18 | ``ark.Model`` 19 | ===================== 20 | 21 | .. automodule:: ark.Model 22 | :members: tensor, reshape, identity, sharding, reduce, layernorm, softmax, 23 | transpose, linear, im2col, conv2d, max_pool, scale, relu, gelu, add, mul, send, 24 | send_done, recv, all_reduce 25 | 26 | ``ark.Executor`` 27 | ===================== 28 | 29 | .. automodule:: ark.Executor 30 | :members: compile, launch, run, wait, stop 31 | 32 | -------------------------------------------------------------------------------- /docs/sphinx/source/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | # Configuration file for the Sphinx documentation builder. 5 | # 6 | # This file only contains a selection of the most common options. 
For a full 7 | # list see the documentation: 8 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 9 | 10 | # -- Path setup -------------------------------------------------------------- 11 | 12 | # If extensions (or modules to document with autodoc) are in another directory, 13 | # add these directories to sys.path here. If the directory is relative to the 14 | # documentation root, use os.path.abspath to make it absolute, like shown here. 15 | # 16 | import ark 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "ARK" 21 | copyright = "2023, ARK Team" 22 | author = "ARK Team" 23 | version = "0.5.0" 24 | release = "0.5.0" 25 | 26 | # -- General configuration --------------------------------------------------- 27 | 28 | # Add any Sphinx extension module names here, as strings. They can be 29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 30 | # ones. 31 | extensions = [ 32 | "myst_parser", 33 | "rinoh.frontend.sphinx", 34 | "sphinx.ext.todo", 35 | "sphinx.ext.viewcode", 36 | "sphinx.ext.autodoc", 37 | "sphinx.ext.napoleon", 38 | "sphinx-prompt", 39 | ] 40 | 41 | # Add any paths that contain templates here, relative to this directory. 42 | templates_path = ["_templates"] 43 | 44 | # List of patterns, relative to source directory, that match files and 45 | # directories to ignore when looking for source files. 46 | # This pattern also affects html_static_path and html_extra_path. 47 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 48 | 49 | 50 | # -- Options for HTML output ------------------------------------------------- 51 | 52 | # The theme to use for HTML and HTML Help pages. See the documentation for 53 | # a list of builtin themes. 54 | # 55 | html_theme = "sphinx_book_theme" 56 | 57 | # Add any paths that contain custom static files (such as style sheets) here, 58 | # relative to this directory. 
They are copied after the builtin static files, 59 | # so a file named "default.css" will overwrite the builtin "default.css". 60 | html_static_path = ["_static"] 61 | 62 | myst_enable_extensions = ["deflist"] 63 | 64 | # -- Options for PDF generation with Rinohtype ------------------------------- 65 | 66 | latex_elements = { 67 | "papersize": "letterpaper", 68 | "pointsize": "10pt", 69 | "preamble": "", 70 | "figure_align": "htbp", 71 | } 72 | -------------------------------------------------------------------------------- /docs/sphinx/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Copyright (c) Microsoft Corporation. 2 | Licensed under the MIT license. 3 | 4 | Welcome to ARK's documentation! 5 | ================================= 6 | 7 | This document explains the usage of the ARK system. 8 | 9 | .. toctree:: 10 | :caption: ARK Guides 11 | 12 | api 13 | 14 | 15 | .. toctree:: 16 | 17 | ARK GitHub Repo 18 | -------------------------------------------------------------------------------- /examples/ffn/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | ARK_ROOT ?= /usr/local/ark 5 | CUDIR ?= /usr/local/cuda 6 | 7 | CXX := g++ 8 | CXXFLAGS := -std=c++17 -Wall -Wextra 9 | INCLUDE := -I$(ARK_ROOT)/include -I $(CUDIR)/include -I$(ARK_ROOT)/include/kernels 10 | LDFLAGS := -L$(CUDIR)/lib64/stubs -Wl,-rpath,$(CUDIR)/lib64 11 | LDLIBS := -lcuda -lnvidia-ml -lnvrtc -lpthread -lrt -libverbs -lnuma 12 | 13 | all: build/ffn 14 | 15 | build/ffn: build/ffn.o 16 | $(CXX) -o $@ $< -L$(ARK_ROOT)/lib -lark $(LDFLAGS) $(LDLIBS) 17 | 18 | build/ffn.o: ffn.cc 19 | mkdir -p $(@D) 20 | $(CXX) -o $@ $(CXXFLAGS) $(INCLUDE) -c $< 21 | 22 | clean: 23 | rm -r build/ 24 | -------------------------------------------------------------------------------- /examples/llama/README.md: -------------------------------------------------------------------------------- 1 | # Llama2 over ARK 2 | 3 | Llama2 examples over ARK. 4 | 5 | ## Quick Start 6 | 7 | 0. Install ARK Python following the [ARK Install Instructions](../../docs/install.md). 8 | 9 | 1. Install Llama2 requirements. 10 | 11 | ```bash 12 | python3 -m pip install -r requirements.txt 13 | ``` 14 | 15 | 2. Update submodules. 16 | 17 | ```bash 18 | git submodule update --init --recursive 19 | ``` 20 | 21 | 3. Install `llama` submodule. 22 | 23 | ```bash 24 | cd llama 25 | python3 -m pip install -e . 26 | cd .. 27 | ``` 28 | 29 | 4. Download Llama2 model weights and tokenizer weights. 30 | * The model and tokenizer should be compatible with the [official PyTorch implementation](https://github.com/facebookresearch/llama/blob/main/llama). 31 | 32 | 5. Run the model accuracy test. `--pth_path` is the path to the model weights file (`consolidated.00.pth`). 33 | 34 | ```bash 35 | python3 model_test.py --pth_path=/path/to/model/weights.pth 36 | ``` 37 | 38 | 6. Test text generation. 
`--pth_path` is the path to the model weights file (`consolidated.00.pth`), `--tok_path` is the path to the tokenizer weights file (`tokenizer.model`), and `--params_path` is the path to the model parameters (`params.json`). 39 | 40 | ```bash 41 | python3 generator.py --pth_path=consolidated.00.pth --tok_path=tokenizer.model --params_path=params.json 42 | ``` 43 | 44 | ## Multi-GPU Inference 45 | 46 | Multi-GPU version will be added in a future release. 47 | -------------------------------------------------------------------------------- /examples/llama/generator_torch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing as mp 3 | import os 4 | 5 | from llama import Llama 6 | 7 | 8 | def worker(args: argparse.Namespace, rank: int): 9 | os.environ["RANK"] = str(rank) 10 | os.environ["LOCAL_RANK"] = str(rank) 11 | 12 | def log(msg): 13 | print(f"[Rank {rank}] {msg}") 14 | 15 | prompt_list = ["Where is the capital of France?"] 16 | generator = Llama.build( 17 | ckpt_dir=args.ckpt_dir, 18 | tokenizer_path=args.tok_path, 19 | max_seq_len=512, 20 | max_batch_size=1, 21 | ) 22 | generation_tokens = generator.generate( 23 | prompt_tokens=[ 24 | generator.tokenizer.encode( 25 | f"[INST] {prompt} [/INST]", bos=True, eos=False 26 | ) 27 | for prompt in prompt_list 28 | ], 29 | max_gen_len=512, 30 | temperature=0, 31 | top_p=0.9, 32 | logprobs=False, 33 | echo=False, 34 | ) 35 | output_text = [ 36 | {"generation": generator.tokenizer.decode(t)} for t in generation_tokens 37 | ] 38 | if rank == 0: 39 | log(f"{output_text}") 40 | return 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("--ckpt_dir", type=str, required=True) 46 | parser.add_argument("--params_path", type=str, required=True) 47 | parser.add_argument("--tok_path", type=str, required=True) 48 | parser.add_argument("--ngpus", type=int, default=1) 49 | 50 | args = parser.parse_args() 
51 | 52 | os.environ["MASTER_ADDR"] = "localhost" 53 | os.environ["MASTER_PORT"] = "29500" 54 | os.environ["WORLD_SIZE"] = str(args.ngpus) 55 | 56 | procs = [] 57 | for i in range(args.ngpus): 58 | p = mp.Process(target=worker, args=(args, i)) 59 | p.start() 60 | procs.append(p) 61 | 62 | for p in procs: 63 | p.join() 64 | -------------------------------------------------------------------------------- /examples/llama/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | fairscale 3 | sentencepiece 4 | -------------------------------------------------------------------------------- /examples/transformer/transformer_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | import numpy as np 6 | import torch.nn as nn 7 | 8 | import ark 9 | 10 | d_model = 512 # Dimension of word embeddings 11 | d_ff = 2048 # Dimension of the hidden layer in the feed-forward network 12 | d_k = d_v = 64 # Dimensions of K(=Q) and V in the attention mechanism 13 | n_layers = 2 # Number of encoder and decoder layers 14 | n_heads = 8 # Number of heads in Multi-Head Attention set to 8 15 | 16 | batch_size = 1 17 | seq_len = 64 18 | src_vocab_size = 128 19 | 20 | # The number of input tokens is 10 21 | # Used for constructing the masks 22 | input_seq_len = 10 23 | 24 | # Megatron-LM on 2 GPU 25 | num_gpu = 2 26 | n_heads_per_gpu = n_heads // num_gpu 27 | -------------------------------------------------------------------------------- /examples/tutorial/planner_tutorial.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 

import ark
import time
import torch
import torch.nn.functional as F


class VanillaSoftmax(ark.Module):
    """Softmax over the last axis, composed from individual ARK ops.

    No planner hints are given here; compare with ``Softmax`` below, which
    issues the same ops inside ``ark.PlannerContext`` blocks.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``exp(input - max) / sum(exp(input - max))`` along axis -1."""
        # Max-subtracted form of softmax: reduce, shift, exponentiate,
        # reduce again, normalize.
        row_max = ark.reduce_max(input, axis=-1)
        output = ark.sub(input, row_max)
        output = ark.exp(output)
        row_sum = ark.reduce_sum(output, axis=-1)
        output = ark.div(output, row_sum)
        return output


class Softmax(ark.Module):
    """Same op sequence as ``VanillaSoftmax``, issued under planner contexts.

    The outer ``ark.PlannerContext`` supplies warp/SRAM ranges, a sync flag
    and a base config; the inner contexts add a per-op ``ImplType`` or
    ``Tile`` config for the ops created inside them.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Build the softmax ops for ``input`` and return the output tensor."""
        with ark.PlannerContext(
            warp_range=[0, 8],
            sram_range=[0, 0],
            sync=False,
            config={
                "NumWarps": 1,
                "SramBytes": 0,
                "NumTasks": 65536,
            },
        ):
            # Reductions are tagged "WarpWise"; element-wise ops get a
            # [1, 2048] tile.
            with ark.PlannerContext(config={"ImplType": "WarpWise"}):
                row_max = ark.reduce_max(input, axis=-1)
            with ark.PlannerContext(config={"Tile": [1, 2048]}):
                output = ark.sub(input, row_max)
                output = ark.exp(output)
            with ark.PlannerContext(config={"ImplType": "WarpWise"}):
                row_sum = ark.reduce_sum(output, axis=-1)
            with ark.PlannerContext(config={"Tile": [1, 2048]}):
                output = ark.div(output, row_sum)
            return output


# NOTE: the name shadows the builtin ``eval``; it is kept unchanged so any
# existing caller of this tutorial helper keeps working.
def eval(tensor: ark.Tensor):
    """Launch a runtime, run the model once and return ``tensor`` as a torch tensor."""
    with ark.Runtime() as rt:
        rt.launch()
        rt.run()
        return tensor.to_torch()


def perf(iterations: int = 1000):
    """Return the mean time per iteration, in seconds.

    Args:
        iterations: Number of iterations passed to ``Runtime.run``. Defaults
            to 1000, the value that was previously hard-coded.

    Raises:
        ValueError: If ``iterations`` is not positive.
    """
    if iterations <= 0:
        raise ValueError("iterations must be a positive integer")
    with ark.Runtime() as rt:
        rt.launch()

        # perf_counter is monotonic and high-resolution; time.time() follows
        # the system clock and can jump while the measurement is running.
        start = time.perf_counter()
        rt.run(iter=iterations)
        end = time.perf_counter()
        return (end - start) / iterations


if __name__ == "__main__":
    ark.init()

    shape = (32, 2048, 2048)

    # input_tensor = torch.randn(*shape).to("cuda:0")
    input_tensor = ark.tensor(shape)

    output = Softmax()(input_tensor)

    # if torch.allclose(eval(output), F.softmax(input_tensor, dim=-1), atol=1e-5):
    #     print("Correct result")
    # else:
    #     print("Incorrect result")

    print(f"Performance: {(perf() * 1e3):.3f} ms/iter")
-------------------------------------------------------------------------------- /examples/tutorial/quickstart_tutorial.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import numpy as np 5 | import ark 6 | 7 | 8 | def quickstart_tutorial(): 9 | # Initialize the ARK environments 10 | ark.init() 11 | 12 | M, N = 64, 64 13 | # Create an input tensor 14 | input_tensor = ark.tensor([M, N], ark.fp16) 15 | # Create another tensor 16 | other_tensor = ark.tensor([M, N], ark.fp16) 17 | 18 | # Add the two tensors 19 | output_tensor = ark.add(input_tensor, other_tensor) 20 | 21 | # Initialize the ARK runtime 22 | runtime = ark.Runtime() 23 | 24 | # Launch the ARK runtime 25 | runtime.launch() 26 | 27 | # Initialize the input and other tensor with random values 28 | input_tensor_host = np.random.rand(M, N).astype(np.float16) 29 | input_tensor.from_numpy(input_tensor_host) 30 | other_tensor_host = np.random.rand(M, N).astype(np.float16) 31 | other_tensor.from_numpy(other_tensor_host) 32 | 33 | # Run the ARK program 34 | runtime.run() 35 | 36 | # Copy the output tensor from device memory to host memory, if dst is 37 | # None, a new numpy array of the same shape as the src tensor will be returned 38 | output_tensor_host = output_tensor.to_numpy() 39 | # Check if the output tensor is equal to the sum of the input and other tensor 40 | np.testing.assert_allclose( 41 | output_tensor_host, input_tensor_host + other_tensor_host 42 | ) 43 | 44 | # Stop the ARK runtime (undo Runtime.launch()) 45 | runtime.stop() 46 | 47 | # Reset the ARK runtime (free all resources) 48 | runtime.reset() 49 | 50 | print("Quickstart tutorial is successful!") 51 | 52 | 53 | if __name__ == "__main__": 54 | quickstart_tutorial() 55 | -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["scikit-build-core"] 3 | build-backend = "scikit_build_core.build" 4 | 5 | [project] 6 | name = "ark" 7 | version = "0.5.0" 8 | 9 | [tool.scikit-build] 10 | cmake.version = ">=3.25" 11 | cmake.args = [] 12 | cmake.verbose = false 13 | cmake.build-type = "Release" 14 | cmake.targets = ["ark_py"] 15 | wheel.packages = ["python/ark"] 16 | wheel.license-files = ["LICENSE", "CITATION.cff", "CODE_OF_CONDUCT.md", "README.md", "SECURITY.md", "SUPPORT.md"] 17 | install.strip = true 18 | build-dir = "build/{wheel_tag}" 19 | 20 | [tool.scikit-build.cmake.define] 21 | ARK_BUILD_PYTHON = "ON" 22 | 23 | [tool.black] 24 | line-length = 80 25 | target-version = ['py38'] 26 | include = '\.pyi?$' 27 | exclude = '/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist|third_party|docs|examples/llama/llama)/' 28 | -------------------------------------------------------------------------------- /python/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Module) 5 | if(Python3_FOUND) 6 | if(${Python3_VERSION} VERSION_LESS 3.8) 7 | message(FATAL_ERROR "Python version must be at least 3.8") 8 | endif() 9 | endif() 10 | 11 | include(FetchContent) 12 | FetchContent_Declare( 13 | pybind11 14 | GIT_REPOSITORY https://github.com/pybind/pybind11.git 15 | GIT_TAG v2.11.1 16 | ) 17 | FetchContent_MakeAvailable(pybind11) 18 | 19 | file(GLOB_RECURSE BIND_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) 20 | pybind11_add_module(ark_py ${BIND_SOURCES}) 21 | set_target_properties(ark_py PROPERTIES OUTPUT_NAME core LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ark) 22 | target_link_libraries(ark_py PRIVATE ark_static) 23 | target_include_directories(ark_py SYSTEM PRIVATE ${DLPACK_INCLUDE_DIRS}) 24 | target_include_directories(ark_py PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../ark) 25 | add_custom_target(py_copy 26 | COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/ark ${CMAKE_CURRENT_BINARY_DIR}/ark 27 | ) 28 | add_dependencies(ark_py py_copy) 29 | 30 | if(ARK_USE_CUDA) 31 | target_include_directories(ark_py SYSTEM PRIVATE 32 | ${CUDAToolkit_INCLUDE_DIRS} 33 | ) 34 | endif() 35 | 36 | if(ARK_USE_ROCM) 37 | target_include_directories(ark_py SYSTEM PRIVATE 38 | /opt/rocm/include 39 | ) 40 | endif() 41 | -------------------------------------------------------------------------------- /python/ark/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | import os 5 | 6 | if os.environ.get("ARK_ROOT", None) is None: 7 | os.environ["ARK_ROOT"] = os.path.abspath(os.path.dirname(__file__)) 8 | 9 | from .core import version 10 | from .model import Model 11 | 12 | 13 | __version__ = version() 14 | 15 | 16 | def version(): 17 | """Returns the version of ARK.""" 18 | return __version__ 19 | 20 | 21 | def set_rank(rank): 22 | """Sets the rank of the current process.""" 23 | Model.set_rank(rank) 24 | 25 | 26 | def set_world_size(world_size): 27 | """Sets the world size of the current process.""" 28 | Model.set_world_size(world_size) 29 | 30 | 31 | from .init import init 32 | from .tensor import Dims, Tensor, Parameter 33 | from .module import Module 34 | from .runtime import Runtime 35 | from .serialize import save, load 36 | from .data_type import * 37 | from .ops import * 38 | from .planner import * 39 | from .error import * 40 | -------------------------------------------------------------------------------- /python/ark/error.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .core import BaseError 5 | from .core import InternalError 6 | from .core import InvalidUsageError 7 | from .core import ModelError 8 | from .core import PlanError 9 | from .core import UnsupportedError 10 | from .core import SystemError 11 | from .core import GpuError 12 | 13 | __all__ = [ 14 | "BaseError", 15 | "InternalError", 16 | "InvalidUsageError", 17 | "ModelError", 18 | "PlanError", 19 | "UnsupportedError", 20 | "SystemError", 21 | "GpuError", 22 | ] 23 | -------------------------------------------------------------------------------- /python/ark/init.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from . 
import core 5 | from .model import Model 6 | from .runtime import RuntimeState 7 | 8 | __all__ = ["init"] 9 | 10 | 11 | def init(): 12 | """Initializes ARK.""" 13 | Model.reset() 14 | if RuntimeState.executor is not None: 15 | if not RuntimeState.executor.destroyed(): 16 | RuntimeState.executor.destroy() 17 | core.init() 18 | -------------------------------------------------------------------------------- /python/ark/log.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import inspect 5 | from .core import LogLevel, log 6 | from .error import * 7 | from .error import __all__ as error_all 8 | 9 | __all__ = [*error_all, "DEBUG", "INFO", "WARN"] 10 | 11 | 12 | def DEBUG(msg: str) -> None: 13 | frame = inspect.currentframe().f_back 14 | info = inspect.getframeinfo(frame) 15 | log(LogLevel.DEBUG, info.filename, info.lineno, msg) 16 | 17 | 18 | def INFO(msg: str) -> None: 19 | frame = inspect.currentframe().f_back 20 | info = inspect.getframeinfo(frame) 21 | log(LogLevel.INFO, info.filename, info.lineno, msg) 22 | 23 | 24 | def WARN(msg: str) -> None: 25 | frame = inspect.currentframe().f_back 26 | info = inspect.getframeinfo(frame) 27 | log(LogLevel.WARN, info.filename, info.lineno, msg) 28 | -------------------------------------------------------------------------------- /python/ark/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from typing import NewType 5 | from .core import CoreModel 6 | 7 | 8 | __all__ = ["Model"] 9 | 10 | ModelState = NewType("ModelState", None) 11 | 12 | 13 | class Model(CoreModel): 14 | @staticmethod 15 | def get_model(): 16 | """ 17 | Get the underlying model. 
18 | """ 19 | if ModelState.model is None: 20 | ModelState.model = Model(ModelState.rank, ModelState.world_size) 21 | return ModelState.model 22 | 23 | @staticmethod 24 | def get_rank(): 25 | """ 26 | Get the rank of the model. 27 | """ 28 | return ModelState.rank 29 | 30 | @staticmethod 31 | def get_world_size(): 32 | """ 33 | Get the world size of the model. 34 | """ 35 | return ModelState.world_size 36 | 37 | @staticmethod 38 | def set_rank(rank: int): 39 | """ 40 | Set the rank of the model. 41 | """ 42 | ModelState.rank = rank 43 | 44 | @staticmethod 45 | def set_world_size(world_size: int): 46 | """ 47 | Set the world size of the model. 48 | """ 49 | ModelState.world_size = world_size 50 | 51 | @staticmethod 52 | def reset(): 53 | """ 54 | Reset the model state. 55 | """ 56 | ModelState.model = None 57 | ModelState.rank = 0 58 | ModelState.world_size = 1 59 | 60 | def compress(self) -> "Model": 61 | """ 62 | Compress the model. 63 | """ 64 | return super().compress() 65 | 66 | def serialize(self, pretty: bool = True) -> str: 67 | """ 68 | Serialize the model. 69 | 70 | Args: 71 | pretty: Whether to pretty print the model. 72 | 73 | Returns: 74 | The serialized model. 75 | """ 76 | return super().serialize(pretty) 77 | 78 | 79 | class ModelState: 80 | """ 81 | The ModelState class is used to store the state of the model. 82 | """ 83 | 84 | model: Model = None 85 | rank: int = 0 86 | world_size: int = 1 87 | -------------------------------------------------------------------------------- /python/ark/serialize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 

import pickle
import logging


def save(state_dict, state_dict_file_path: str):
    """
    Save the state_dict of a module to a file.

    The object is pickled to ``state_dict_file_path``. If ``state_dict`` is
    not a ``dict`` it is still written, but a warning naming the target path
    is logged first.

    Args:
        state_dict: The object to pickle (expected to be a dict).
        state_dict_file_path: Destination file path, opened in binary mode.
    """
    if not isinstance(state_dict, dict):
        # The path is passed as a lazy %-argument; the message therefore
        # needs a matching placeholder, otherwise formatting the record fails.
        logging.warning(
            "Warning: Invalid state_dict saved to %s", state_dict_file_path
        )
    with open(state_dict_file_path, "wb") as f:
        pickle.dump(state_dict, f)


def load(state_dict_file_path: str):
    """
    Load the state_dict of a module from a file.

    Args:
        state_dict_file_path: Path of a file previously written by ``save``.

    Returns:
        The unpickled object. A warning is logged if it is not a ``dict``,
        but the object is returned regardless.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only load state_dict files that come from a trusted source.
    with open(state_dict_file_path, "rb") as f:
        state_dict = pickle.load(f)
    if not isinstance(state_dict, dict):
        logging.warning("Warning: Invalid state_dict file")
    return state_dict
register_tensor(m); 37 | register_version(m); 38 | } 39 | -------------------------------------------------------------------------------- /python/data_type_py.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | namespace py = pybind11; 11 | 12 | void register_data_type(py::module &m) { 13 | py::class_(m, "CoreDataType") 14 | .def("__eq__", &ark::DataType::operator==) 15 | .def("__ne__", &ark::DataType::operator!=) 16 | .def("is_null", &ark::DataType::is_null) 17 | .def("bytes", &ark::DataType::bytes) 18 | .def("name", &ark::DataType::name, py::return_value_policy::reference) 19 | .def_static("from_name", &ark::DataType::from_name); 20 | 21 | m.attr("NONE") = &ark::NONE; 22 | m.attr("FP32") = &ark::FP32; 23 | m.attr("FP16") = &ark::FP16; 24 | m.attr("BF16") = &ark::BF16; 25 | m.attr("INT32") = &ark::INT32; 26 | m.attr("UINT32") = &ark::UINT32; 27 | m.attr("INT8") = &ark::INT8; 28 | m.attr("UINT8") = &ark::UINT8; 29 | m.attr("BYTE") = &ark::BYTE; 30 | } 31 | -------------------------------------------------------------------------------- /python/dims_py.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | namespace py = pybind11; 12 | 13 | void register_dims(py::module &m) { 14 | m.attr("DIMS_LEN") = py::int_(ark::DIMS_LEN); 15 | 16 | py::class_(m, "CoreDims") 17 | .def(py::init<>()) 18 | .def(py::init()) 19 | .def(py::init()) 20 | .def(py::init()) 21 | .def(py::init()) 22 | .def(py::init()) 23 | .def(py::init &>()) 24 | .def("nelems", &ark::Dims::nelems) 25 | .def("ndims", &ark::Dims::ndims) 26 | .def("vector", &ark::Dims::vector, py::return_value_policy::reference) 27 | .def("__getitem__", 28 | [](const ark::Dims &d, ark::DimType idx) { return d[idx]; }) 29 | .def("__setitem__", [](ark::Dims &d, ark::DimType idx, 30 | ark::DimType value) { d[idx] = value; }) 31 | .def("__repr__", [](const ark::Dims &d) { 32 | std::ostringstream os; 33 | os << d; 34 | return os.str(); 35 | }); 36 | } 37 | -------------------------------------------------------------------------------- /python/error_py.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | namespace py = pybind11; 11 | 12 | #define REGISTER_ERROR_PY(_name) \ 13 | py::register_exception(m, #_name, m.attr("BaseError").ptr()) 14 | 15 | void register_error(py::module &m) { 16 | py::register_exception(m, "BaseError"); 17 | 18 | REGISTER_ERROR_PY(InternalError); 19 | REGISTER_ERROR_PY(InvalidUsageError); 20 | REGISTER_ERROR_PY(ModelError); 21 | REGISTER_ERROR_PY(PlanError); 22 | REGISTER_ERROR_PY(UnsupportedError); 23 | REGISTER_ERROR_PY(SystemError); 24 | REGISTER_ERROR_PY(GpuError); 25 | } 26 | -------------------------------------------------------------------------------- /python/init_py.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | namespace py = pybind11; 11 | 12 | void register_init(py::module &m) { m.def("init", &ark::init); } 13 | -------------------------------------------------------------------------------- /python/log_py.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | namespace py = pybind11; 11 | 12 | void register_log(py::module &m) { 13 | py::enum_(m, "LogLevel") 14 | .value("DEBUG", ark::LogLevel::DEBUG) 15 | .value("INFO", ark::LogLevel::INFO) 16 | .value("WARN", ark::LogLevel::WARN) 17 | .value("ERROR", ark::LogLevel::ERROR) 18 | .export_values(); 19 | m.def("log", &ark::log); 20 | } 21 | -------------------------------------------------------------------------------- /python/model_graph_py.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | namespace py = pybind11; 11 | 12 | void register_model_graph(py::module &m) { 13 | py::class_(m, "CoreModelGraph") 14 | .def("serialize", &ark::ModelGraph::serialize, 15 | py::arg("pretty") = true); 16 | } 17 | -------------------------------------------------------------------------------- /python/planner_py.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | namespace py = pybind11; 12 | 13 | void register_planner(py::module &m) { 14 | py::class_(m, "CorePlannerContext") 15 | .def(py::init()) 16 | .def("processor_range", &ark::PlannerContext::processor_range, 17 | py::arg("start"), py::arg("end"), py::arg("step") = 1) 18 | .def("warp_range", &ark::PlannerContext::warp_range, py::arg("start"), 19 | py::arg("end"), py::arg("step") = 1) 20 | .def("sram_range", &ark::PlannerContext::sram_range, py::arg("start"), 21 | py::arg("end"), py::arg("step") = 1) 22 | .def("sync", &ark::PlannerContext::sync, py::arg("sync")) 23 | .def("config", &ark::PlannerContext::config, py::arg("config")); 24 | 25 | py::class_(m, "CorePlanner") 26 | .def(py::init()) 27 | .def("install_config_rule", 28 | [](ark::Planner *self, const py::function &rule) { 29 | self->install_config_rule( 30 | [rule](const std::string &op, const std::string &arch) { 31 | return rule(op, arch).cast(); 32 | }); 33 | }) 34 | .def("plan", &ark::Planner::plan, py::arg("pretty") = true); 35 | } 36 | -------------------------------------------------------------------------------- /python/random_py.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | namespace py = pybind11; 11 | 12 | void register_random(py::module &m) { 13 | m.def("srand", &ark::srand, py::arg("seed")); 14 | m.def("rand", &ark::rand); 15 | } 16 | -------------------------------------------------------------------------------- /python/tensor_py.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
def pytest_ark(need_torch: bool = False):
    """
    Decorator for ARK unit tests.

    The decorated test calls ``ark.init()`` before running its body.

    Args:
        need_torch: If True and ``torch`` cannot be imported, the test is
            marked as skipped instead of being wrapped.

    Returns:
        A decorator that takes the test function and returns the wrapped
        (or skip-marked) test.
    """
    import functools

    def decorator(test_func):
        if need_torch:
            try:
                import torch  # noqa: F401  (availability probe only)
            except ImportError:
                return pytest.mark.skip(reason="torch is not installed")(
                    test_func
                )

        # wraps() keeps the test's name, docstring and signature visible to
        # pytest and to introspection instead of exposing a bare "wrapper".
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            ark.init()
            return test_func(*args, **kwargs)

        return wrapper

    return decorator
3 | 4 | from test_data_type import * 5 | from test_error import * 6 | from test_model import * 7 | from test_ops import * 8 | from test_runtime import * 9 | -------------------------------------------------------------------------------- /python/unittest/test_error.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from common import ark, pytest_ark 5 | 6 | 7 | @pytest_ark() 8 | def test_error(): 9 | try: 10 | raise ark.InternalError("test") 11 | except ark.BaseError as e: 12 | assert isinstance(e, ark.InternalError) 13 | assert str(e) == "test" 14 | -------------------------------------------------------------------------------- /python/unittest/test_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from common import ark, pytest_ark 5 | import json 6 | 7 | 8 | @pytest_ark() 9 | def test_model(): 10 | input_tensor = ark.tensor([64, 64], ark.fp16) 11 | other_tensor = ark.tensor([64, 64], ark.fp16) 12 | ark.add(input_tensor, other_tensor) 13 | 14 | m = ark.Model.get_model().compress() 15 | m_json = json.loads(m.serialize()) 16 | 17 | assert m_json.get("Nodes", None) is not None 18 | assert len(m_json["Nodes"]) == 1 19 | assert m_json["Nodes"][0].get("Op", None) is not None 20 | assert m_json["Nodes"][0]["Op"].get("Type", None) == "Add" 21 | 22 | ark.Model.reset() 23 | 24 | m = ark.Model.get_model().compress() 25 | m_json = json.loads(m.serialize()) 26 | 27 | assert m_json.get("Nodes", None) is not None 28 | assert len(m_json["Nodes"]) == 0 29 | -------------------------------------------------------------------------------- /python/unittest/test_runtime.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | from common import ark, pytest_ark 5 | 6 | 7 | @pytest_ark() 8 | def test_runtime_empty(): 9 | with ark.Runtime.get_runtime() as rt: 10 | rt.launch() 11 | rt.run() 12 | rt.stop() 13 | -------------------------------------------------------------------------------- /python/version_py.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | namespace py = pybind11; 11 | 12 | void register_version(py::module &m) { m.def("version", &ark::version); } 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-build-core 2 | pyproject_metadata 3 | pytest-cov 4 | numpy 5 | --------------------------------------------------------------------------------