├── .clang-format ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── CMakeLists.txt ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── benchmarks ├── CMakeLists.txt ├── manual_benchmark.cu └── sweep_parameters.sh ├── cuembed ├── README.md └── include │ ├── embedding_lookup.cuh │ ├── embedding_lookup_kernels.cuh │ ├── embedding_lookup_ops.cuh │ ├── embedding_lookup_types.cuh │ ├── index_transforms.cuh │ └── index_transforms_kernels.cuh ├── docs └── Doxyfile ├── examples └── pytorch │ ├── CMakeLists.txt │ ├── cuembed_embedding.cu │ ├── cuembed_pyt.py │ └── cuembed_test.py ├── tests ├── CMakeLists.txt ├── test_datagen.cpp ├── test_embedding_against_cpu.cu ├── test_embedding_allocation.cu ├── test_embedding_backward.cu ├── test_embedding_forward.cu ├── test_embedding_ops.cu ├── test_embedding_transpose.cu └── test_third_party_utils.cu └── utils ├── CMakeLists.txt ├── include ├── datagen.h ├── embedding_allocation.h ├── embedding_lookup_cpu.hpp ├── embedding_utils.h └── index_transforms_cpu.hpp └── src ├── datagen.cpp ├── embedding_allocation.cu ├── embedding_cpu.cu ├── embedding_gpu_backward.cu ├── embedding_gpu_forward.cu └── embedding_gpu_transpose.cu /.clang-format: -------------------------------------------------------------------------------- 1 | # Run manually to reformat a file: 2 | # clang-format -i --style=file 3 | BasedOnStyle: Google 4 | DerivePointerAlignment: false 5 | BinPackArguments: false 6 | BinPackParameters: false 7 | StatementMacros: 8 | - _Pragma 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.so 3 | *.o 4 | *cache* 5 | .build* 6 | *.npy 7 | bin/* 8 | *.log 9 | build* 10 | docs/html 11 | .vscode/* 12 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule 
"third_party/gtest"] 2 | path = third_party/gtest 3 | url = https://github.com/google/googletest.git 4 | [submodule "third_party/abseil-cpp"] 5 | path = third_party/abseil-cpp 6 | url = https://github.com/abseil/abseil-cpp.git 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # File introduces automated checks triggered on git events 2 | # to enable run `pip install pre-commit && pre-commit install` 3 | 4 | repos: 5 | - repo: local 6 | hooks: 7 | - id: clang-format 8 | name: clang-format 9 | language: system 10 | entry: clang-format 11 | args: [-i] 12 | files: \.(c|cu|cc|cxx|cpp|h|hpp|hxx|cuh)$ 13 | - id: cpplint 14 | name: cpplint 15 | language: system 16 | entry: cpplint 17 | files: \.(c|cu|cc|cxx|cpp|h|hpp|hxx|cuh)$ 18 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | cmake_minimum_required(VERSION 3.23) 17 | set(CMAKE_CXX_STANDARD 17) 18 | project(cuembed CXX CUDA) 19 | enable_language(CUDA) 20 | 21 | set(CMAKE_CUDA_ARCHITECTURES 70 75 80 90) 22 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v") 23 | 24 | set(ABSL_PROPAGATE_CXX_STD ON) 25 | 26 | option(BUILD_TESTS "Build the tests" ON) 27 | option(BUILD_BENCHMARKS "Build the benchmarks" ON) 28 | option(BUILD_EXAMPLES "Build examples" OFF) 29 | 30 | find_package(CUDAToolkit) 31 | 32 | set(CUEMBED_PROJECT_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) 33 | 34 | add_library(cuembed_hdrs INTERFACE ${cuembed_source_files}) 35 | target_include_directories(cuembed_hdrs INTERFACE ${CUEMBED_PROJECT_SOURCE_DIR}) 36 | add_library(cuembed::hdrs ALIAS cuembed_hdrs) 37 | 38 | if (BUILD_TESTS OR BUILD_BENCHMARKS) 39 | # TODO(zejiaz): move to CPM instead of submodule 40 | add_subdirectory(third_party/abseil-cpp) 41 | 42 | # Utility library for benchmarking and testing. 43 | add_subdirectory(utils) 44 | endif() 45 | 46 | # Setup tests 47 | if(BUILD_TESTS) 48 | add_subdirectory(third_party/gtest) 49 | add_subdirectory(tests) 50 | endif() 51 | 52 | # Benchmarks. 53 | if (BUILD_BENCHMARKS) 54 | add_subdirectory(benchmarks) 55 | endif() 56 | 57 | # Examples 58 | if (BUILD_EXAMPLES) 59 | add_subdirectory(examples/pytorch) 60 | endif() -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | Thank you for contributing to cuEmbed! We greatly appreciate your help and engagement. Please follow these guidelines. 3 | 4 | ## Pull Requests 5 | - Create pull requests targeting the main branch. 6 | 7 | - Individual pull requests should be limited in scope and concise. Commit messages should be clear to facilitate code review. Please do participate in code reviews to help improve the quality of the codebase. 
8 | 9 | - Correctness tests are located in the tests/ folder and can be built by setting the BUILD\_TESTS option to ON in `CMakeLists.txt`. Please add tests, and/or modify existing tests, to cover the feature you are adding. All tests should pass or the pull request will not be accepted. 10 | 11 | - Tests should also pass with [NVIDIA Compute Sanitizer](https://developer.nvidia.com/compute-sanitizer), which is included with the CUDA Toolkit. For more information about NVIDIA Compute Sanitizer, please refer to the [documentation](https://docs.nvidia.com/cuda/compute-sanitizer/index.html). 12 | 13 | - Please add or update documentation for any new code or features, both inline with the code and in the `README.md` files where applicable. Examples of high-quality docstrings can be found throughout the code. 14 | 15 | - Please adhere to formatting and linting checks (e.g., clang-format, cpplint). Verify this by first installing the hooks via `pip install pre-commit && pre-commit install`, then running `pre-commit run --all-files` from the source directory. Code that does not pass these checks will not be accepted. 16 | 17 | - If your change is likely to impact performance, please run performance benchmarks (i.e. `cd benchmarks ; ./sweep_parameters.sh`) both with and without your change to check for performance regressions. If you have any questions about what constitutes a performance regression, please bring this up during the code review. 18 | 19 | - Please sign your commit using `git commit -s` or `--signoff` to certify that your work can be contributed to open source. 20 | 21 | - Please don't hesitate to open an issue and we will do our best to reply promptly. 22 | 23 | ## Developer Certificate of Origin 24 | 25 | By signing your commit using `-s` or `--signoff`, you are certifying the following: 26 | 27 | ``` 28 | Developer Certificate of Origin 29 | Version 1.1 30 | 31 | Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
32 | 33 | Everyone is permitted to copy and distribute verbatim copies of this 34 | license document, but changing it is not allowed. 35 | 36 | 37 | Developer's Certificate of Origin 1.1 38 | 39 | By making a contribution to this project, I certify that: 40 | 41 | (a) The contribution was created in whole or in part by me and I 42 | have the right to submit it under the open source license 43 | indicated in the file; or 44 | 45 | (b) The contribution is based upon previous work that, to the best 46 | of my knowledge, is covered under an appropriate open source 47 | license and I have the right under that license to submit that 48 | work with modifications, whether created in whole or in part 49 | by me, under the same open source license (unless I am 50 | permitted to submit under a different license), as indicated 51 | in the file; or 52 | 53 | (c) The contribution was provided directly to me by some other 54 | person who certified (a), (b) or (c) and I have not modified 55 | it. 56 | 57 | (d) I understand and agree that this project and the contribution 58 | are public and that a record of the contribution (including all 59 | personal information I submit with it, including my sign-off) is 60 | maintained indefinitely and may be redistributed consistent with 61 | this project or the open source license(s) involved. 62 | ``` 63 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cuEmbed: embedding lookup kernel library 2 | 3 | ## Overview 4 | cuEmbed is an open-source, header-only CUDA kernel library that accelerates embedding lookup. It aims to achieve high memory bandwidth utilization by maximizing loads in flight when accessing embedding rows. It makes extensive use of C++ templates and compile-time specialization to support a variety of embedding lookup configurations using only a small number of kernels optimized for memory-level parallelism. All of this is intended to make it easy for developers to achieve high performance on embedding lookups in their CUDA programs. 5 | 6 | Supported Operations: 7 | - Forward propagation (fixed-hotness or CSR index formats). 8 | - Backward propagation (COO index format, full or compressed gradients). 9 | - Index transformations (e.g., transpose). 10 | 11 | ## Development Status 12 | 13 | `cuEmbed` is still under development. We aim to keep the host API stable. 
Users should expect changes in the kernel API and corresponding abstractions of operations. 14 | 15 | ## How to use 16 | Core components of cuEmbed are the kernel headers in the `cuembed/include` directory. These files have minimal dependencies on third-party libraries and are safe to be copied into separate libraries. 17 | 18 | ### Adding `cuEmbed` to a CMake Project 19 | We recommend using [CMake Package Manager (CPM)](https://github.com/cpm-cmake/CPM.cmake) to fetch cuEmbed into your project. With CPM, getting cuEmbed is easy: 20 | ``` 21 | CPMAddPackage( 22 | NAME cuembed 23 | GIT_REPOSITORY https://rep_ro:${GITLAB_TOKEN}@gitlab-master.nvidia.com/compute/psx/recommender/cuembed.git 24 | GIT_TAG main 25 | OPTIONS 26 | "BUILD_TESTS OFF" 27 | "BUILD_BENCHMARKS OFF" 28 | ) 29 | 30 | target_link_libraries(my_library cuembed::hdrs) 31 | ``` 32 | 33 | ### Example usage: Forward Propagation 34 | The following example from `utils/src/embedding_allocation.cu` covers the basic usage of the host API for running forward propagation: 35 | ```cpp 36 | template 37 | void RunForward(const utils::AllocationOptions& options, 38 | const thrust::device_vector& embedding, 39 | const thrust::device_vector& indices, 40 | const thrust::device_vector& offsets, 41 | const thrust::device_vector& weights, 42 | thrust::device_vector* result) { 43 | const int* offsets_ptr = nullptr; 44 | int hotness = options.hotness(); 45 | if (options.is_csr()) { 46 | offsets_ptr = offsets.data().get(); 47 | hotness = 0; 48 | } 49 | const ElemT* weight_ptr = nullptr; 50 | if (options.is_weighted()) { 51 | weight_ptr = weights.data().get(); 52 | } 53 | using InputT = ElemT; 54 | using OutputT = ElemT; 55 | EmbeddingForward( 56 | embedding.data().get(), 57 | options.embed_width(), 58 | indices.data().get(), 59 | offsets_ptr, 60 | weight_ptr, 61 | options.batch_size(), 62 | hotness, 63 | options.combine_mode(), 64 | result->data().get()); 65 | } 66 | ``` 67 | In the above example, we call
`EmbeddingForward` with the corresponding data pointers from the embedding table (i.e., `embedding`), the embedding row indices (i.e., `indices`) & offsets indicating the starting position of each set of indices (i.e., `offsets`) & per-sample weights (i.e., `weights`), the output of embedding lookup (i.e., `result`), and workload descriptions (i.e., `embed_width`, `hotness`, `batch_size`, `combine_mode` unwrapped from `options`). The end result of embedding lookup is written into `result`. 68 | 69 | Please refer to `utils/src/embedding_allocation.cu` for more examples, including index transposition and backward propagation. 70 | 71 | Detailed descriptions of the full API and parameters can be found in [cuembed/README.md](https://gitlab-master.nvidia.com/compute/psx/recommender/cuembed/-/blob/main/cuembed/README.md?ref_type=heads). 72 | 73 | ## Building cuEmbed tests and benchmarks 74 | Since cuEmbed is header-only, there is nothing to build to use it. 75 | To build the tests and benchmarks: 76 | 77 | ### Build From Source 78 | ```bash 79 | git clone --recursive https://gitlab-master.nvidia.com/compute/psx/recommender/cuembed 80 | cd cuembed 81 | mkdir build 82 | cd build 83 | cmake -DCMAKE_BUILD_TYPE=Release .. 84 | make 85 | ``` 86 | Binaries will be built into: 87 | - `build/tests` 88 | - `build/bin/benchmarks` 89 | 90 | ## Benchmarks 91 | ### Full Suite Benchmarks 92 | To run benchmarks locally: 93 | ```bash 94 | cd benchmarks/ 95 | ./sweep_parameters.sh 96 | ``` 97 | 98 | ### Manual Benchmark Single Test Case 99 | 100 | Manual benchmarking can be done with the `manual_benchmark` binary, which is built into `build/bin/benchmarks`. This will run the forward, transpose, and backward stages.
101 | 102 | Example: 103 | ```bash 104 | ./bin/benchmarks/manual_benchmark --num_categories 10000000 --embed_width 256 --batch_size 65536 --alpha=1.15 --hotness=64 --csr_input=false --half_embedding_type=true --weighted_sum=false --compressed_grad=true 105 | ``` 106 | 107 | ## Detailed Support Matrix 108 | | | Supported In Current Release | Future Release | 109 | |-----------------------------|-------------------------|--------------------------------------| 110 | | Embedding table size | single table single GPU | multiple tables and multiple devices | 111 | | Embedding cache integration | no | yes | 112 | | Embedding & Output types | fp32, fp16 | bf16 | 113 | | Lookup Index types | int32_t, int64_t | | 114 | | Lookup Index Layout (fwd) | fixed hotness, CSR | COO | 115 | | Lookup Index Layout (bwd) | COO | | 116 | | Reduction type (fwd) | weighted sum, concat, mean | | 117 | | Reduction type (bwd) | weighted sum, concat | mean | 118 | | Reduction precision | fp32, fp16 | bf16 | 119 | | Kernel type | fwd, bwd, transpose | optimizer | 120 | 121 | ## Requirements 122 | - nvcc 12.0+ 123 | - C++ 17 124 | - Volta+ 125 | -------------------------------------------------------------------------------- /benchmarks/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | find_package(CUDAToolkit REQUIRED) 17 | 18 | add_executable(manual_benchmark manual_benchmark.cu) 19 | target_include_directories(manual_benchmark PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) 20 | target_link_libraries( 21 | manual_benchmark PRIVATE 22 | cuembed_hdrs 23 | cuda 24 | utils 25 | absl::log 26 | absl::log_initialize 27 | absl::check 28 | absl::flags 29 | absl::flags_parse) 30 | set_target_properties(manual_benchmark 31 | PROPERTIES 32 | RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin/benchmarks") 33 | -------------------------------------------------------------------------------- /benchmarks/sweep_parameters.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | rm -f manual_benchmark_out.csv 19 | 20 | benchmark=${1:-../build/bin/benchmarks/manual_benchmark} 21 | for alpha in 0.0 1.05 1.15 22 | do 23 | for num_categories in 1000000 10000000 24 | do 25 | for embed_width in 32 128 26 | do 27 | for batch in 1024 32768 131072 28 | do 29 | for hotness in 1 16 64 30 | do 31 | ${benchmark} --num_categories "${num_categories}" --embed_width "${embed_width}" --batch_size "${batch}" --alpha=${alpha} --hotness="${hotness}" --iterations=1000 --enable_csv 32 | done 33 | done 34 | done 35 | done 36 | done 37 | -------------------------------------------------------------------------------- /cuembed/README.md: -------------------------------------------------------------------------------- 1 | # cuEmbed 2 | This directory contains the core implementation of cuembed. 3 | 4 | The library is divided into *embedding lookup* operations (found in `embedding_lookup.cuh`) and *index transformations* (found in `index_transforms.cuh`). 5 | 6 | ## Embedding Lookup Operations 7 | Embedding lookup operations (forward and backward) follow a similar computational pattern, primarily consisting of three kinds of operations: 8 | - Read the lookup indices. 9 | - Read the corresponding input rows. 10 | - Accumulate & write output. 11 | 12 | This embedding forward and backward implementation aims to achieve speed-of-light (SOL) memory bandwidth utilization by maximizing loads in flight when reading embedding rows. Launch parameters of the provided forward and backward kernels are determined by heuristics and the workload description. 13 | 14 | `embedding_lookup.cuh` contains the following primary functions for forward and backward propagation.
15 | 16 | ### Forward Propagation 17 | 18 | The embedding forward operation accepts an embedding table in a dense row-major format, and lookup indices in either a fixed-hotness format or compressed-sparse-row (CSR) format and processes them using one of three currently supported reduction modes: `CombineMode::kSum`, `CombineMode::kMean`, and `CombineMode::kConcat`. 19 | 20 | ```cpp 21 | template 26 | void EmbeddingForward(const InputT* params, 27 | const int embed_width, 28 | const IndexT* indices, 29 | const OffsetT* offsets, 30 | const typename GetElemT* weights, 31 | const int batch_size, 32 | const int num_hots, 33 | const CombineMode mode, 34 | OutputT* ret, 35 | const cudaStream_t stream = 0) 36 | ``` 37 | - For fixed hotness indices, `num_hots` indicates the hotness value, which is the same for every sample in the batch. `offsets` must be nullptr since there is no explicit offset array to be passed to the kernel. 38 | 39 | - For CSR indices, `num_hots` must be 0. `offsets` points to the data of the explicit offset array indicating the starting point of indices for each sample in the batch. 40 | 41 | - For reduction type sum, `weights` can be nullptr to indicate plain reduction. Otherwise the weight of a specific lookup index is applied to the loaded row before reduction. 42 | 43 | - When `fp16_math` is true, math operations (multiplication and summation) on fp16 embedding rows are performed in fp16. 44 | 45 | __Parameters__ 46 | - __params__ Pointer to the embedding table data. 47 | - __embed_width__ Number of elements in each embedding row. 48 | - __indices__ Pointer to the lookup indices. 49 | - __offsets__ Pointer to the offsets (CSR format). Must be nullptr when launching for fixed hotness. 50 | - __weights__ Pointer to the weight array. The weight for a specific lookup index is applied to the loaded embedding row before reduction. If nullptr, the embedding rows are reduced without weighting.
The type for the weights must be the same as the input type. If the input type is structured, then the user needs to define their own `GetElemT` specialization. 51 | - __batch_size__ Batch size of the embedding lookup workload. 52 | - __num_hots__ Number of rows to look up for each sample in the batch. Must be 0 when launching for CSR indices layout. 53 | - __mode__ One of `CombineMode::kSum` (computes the summation of the looked up rows for each sample), `CombineMode::kMean`, or `CombineMode::kConcat` (concatenates all looked up rows). 54 | - __ret__ Pointer to the output location. 55 | - __stream__ Optional. The cudaStream to launch the kernel asynchronously. If not specified, will launch the kernel on default stream. 56 | 57 | ### Backward Propagation 58 | 59 | The backward embedding operation accepts the incoming gradients and uses these along with the transposed indices to generate a gradient with respect to the embedding table. In the event of multiple indices pointing to the same embedding row, the gradients are summed, potentially using atomic operations. 60 | 61 | EmbeddingBackward can produce either full embedding gradients or compressed embedding gradients. For dense embedding gradients, the output embedding gradient is the same size as the original embedding table, but will only be modified in rows specified by the indices. For compressed embedding gradients, the output embedding gradient will have a number of rows which is equal to the number of unique lookup indices, and it will also produce an inverse mapping between the rows of the compressed gradient and the original row IDs in the embedding table.
62 | 63 | ```cpp 64 | template 65 | void EmbeddingBackward(const GradT* grad_y, 66 | const int embed_width, 67 | const int num_grad_embedding_rows, 68 | const int nnz, 69 | const IndexT* transpose_indices, 70 | const IndexT* transpose_sample_ids, 71 | const IndexT* transpose_remapped_indices, 72 | const GradT* transpose_weights, 73 | const bool skip_grad_init, 74 | GradT* grad_embedding, 75 | IndexT* inverse_mapping, 76 | const cudaStream_t stream = 0) 77 | ``` 78 | 79 | - The inputs `transpose_indices`, `transpose_sample_ids`, and `transpose_weights` are indices in coordinate (COO) format, produced by *Transpose*. All repeating indices in `transpose_indices` must be grouped contiguously. 80 | 81 | - If `transpose_weights` is provided, these weights are multiplied with the `grad_y` rows prior to accumulation. `transpose_weights` may be set to nullptr for the unweighted case. 82 | 83 | - The input `transpose_remapped_indices` holds indices needed for compressed gradients, described below, and is produced by *ComputeCompressedGradIndices*. 84 | 85 | - If the full gradient is desired, then `transpose_remapped_indices` should be set to nullptr, `grad_embedding` should be allocated with `num_grad_embedding_rows` rows, which is equal to the total number of categories in this case, and `inverse_mapping` will not be written. 86 | 87 | - If the compressed gradient is desired, then `transpose_remapped_indices` should be provided (i.e. by *ComputeCompressedGradIndices*), `grad_embedding` should be allocated with `num_grad_embedding_rows` rows, which is equal to the number of unique lookup indices (i.e. transpose_remapped_indices.back()+1) in this case, and `inverse_mapping` should be allocated with `num_grad_embedding_rows` elements. 88 | 89 | - The output gradient `grad_embedding` will be initialized to zero prior to the backward lookup operation, unless `skip_grad_init` is set. 90 | 91 | __Parameters__ 92 | - __grad_y__ Pointer to the incoming gradient.
93 | - __embed_width__ Number of elements in each embedding row. 94 | - __num_grad_embedding_rows__ Number of rows in grad_embedding. 95 | - __nnz__ Total number of indices in COO input. 96 | - __transpose_indices__ Pointer to the transposed lookup indices. 97 | - __transpose_sample_ids__ Pointer to the transposed sample IDs. 98 | - __transpose_remapped_indices__ Pointer to the remapped lookup indices (i.e. from *ComputeCompressedGradIndices*), required only if computing compressed gradient. 99 | - __transpose_weights__ Pointer to the weight array. Set to nullptr for unweighted. 100 | - __skip_grad_init__ If true, skip zero-initialization of grad_embedding. 101 | - __grad_embedding__ Pointer to the gradient wrt embedding table. 102 | - __inverse_mapping__ Pointer to the table indices corresponding to each row in `grad_embedding`, produced only for compressed gradients. 103 | - __stream__ Optional. The cudaStream to launch the kernel asynchronously. If not specified, will launch the kernel on default stream. 104 | 105 | 106 | ## Index Transformations 107 | 108 | Index transformations are required to support common use cases of the forward and backward lookup kernels. For example, indices must be converted from the fixed-hotness or compressed-sparse-row (CSR) format used in the forward pass, into a transposed coordinate (COO) format for the backward pass. Additionally, computing compressed gradients during the backward pass requires one to generate a mapping between row ids in the compressed gradient and the row ids in the embedding table. 109 | 110 | Consider the common use-case of calling forward and backward propagation, beginning with indices in CSR format. The lookup indices are stored in an array named `indices`, and ordered according to the sample IDs within the batch (i.e. indices for sample ID 0, followed by sample IDs 1, 2, ..., `batch_size-1`). The `offsets` array contains the offset of the beginning of each sample's indices in the batch.
The `offsets` array has size `batch_size+1`, with the last element containing the length of the `indices` array. When viewed as a sparse matrix, the forward CSR indices would look like this: 111 | 112 | ``` 113 | <-- Embedding Categories--> 114 | Sample 0: X X X 115 | Sample 1: X 116 | Sample 2: X X 117 | ... 118 | Sample bs-1: X X 119 | 120 | ``` 121 | 122 | Note that the rows are samples, the columns are embedding categories, and the data is stored in row-major order. This is helpful to facilitate efficient row-wise reductions within a batch sample during the forward pass. 123 | 124 | However, for the backward pass we need to independently accumulate all the gradients corresponding to a single embedding category (i.e. summing vertically in the above picture). In order to do this efficiently, we want the indices stored instead in a column-major or transposed ordering. A simple and performant way to transpose sparse matrices is to convert to coordinate format (i.e. three nnz-length arrays: rows, columns, and weights), and sort the arrays by either rows or columns. 125 | 126 | We provide the helper functions *ExtractRowIdsFromFixed*, *ExtractRowIdsFromCSR*, and *ExtractRowIdsForConcat* to help convert from the forward Fixed or CSR format to the explicit COO format. We also provide the *Transpose* function to reorder the COO indices by columns instead of rows.
127 | 128 | 129 | `index_transforms.cuh` contains the following functions: 130 | 131 | ### Conversion from Fixed-Hotness to COO 132 | Produce an nnz-length `row_ids` array which has fixed offsets from 0 to batch_size-1, e.g.: `num_hots = 3 -> row_ids = [0, 0, 0, 1, 1, 1, 2, 2, 2, ...]` 133 | ```cpp 134 | template 135 | void ExtractRowIdsFromFixed(const int batch_size, 136 | const int num_hots, 137 | IndexT* row_ids, 138 | const cudaStream_t stream = 0); 139 | ``` 140 | ### Conversion from CSR to COO 141 | Produce an nnz-length `row_ids` array which has explicit offsets read from CSR, e.g.: `offsets : [0, 2, 3, 5] -> row_ids = [0, 0, 1, 2, 2]` 142 | ```cpp 143 | template 144 | void ExtractRowIdsFromCSR(const OffsetT* offsets, 145 | const int batch_size, 146 | IndexT* row_ids, 147 | const cudaStream_t stream = 0); 148 | ``` 149 | ### Conversion to COO for Concat 150 | Produce an nnz-length `row_ids` array which has the sequence 0 .. nnz-1, e.g. `row_ids = [0, 1, 2, 3, ...]` 151 | ```cpp 152 | template 153 | void ExtractRowIdsForConcat(const int nnz, 154 | IndexT* row_ids, 155 | const cudaStream_t stream = 0); 156 | ``` 157 | ### Transpose 158 | 159 | Reorders indices from sample-id-first ordering as is needed during forward to table-index-first ordering needed for backward. Output indices are produced in coordinate (COO) format. 160 | 161 | ```cpp 162 | template 163 | void Transpose(const IndexT* rows, 164 | const IndexT* cols, 165 | const WeightT* weights, 166 | const int nnz, 167 | IndexT* transpose_rows, 168 | IndexT* transpose_cols, 169 | WeightT* transpose_weights, 170 | char* work, 171 | size_t* lwork, 172 | const cudaStream_t stream = 0); 173 | ``` 174 | 175 | 176 | - Input `rows`, `cols`, `weights` are the indices in COO format. For the embedding use case, `rows` contains the sample IDs during forward pass and `cols` contains the embedding lookup indices.
177 | 178 | - Output is stored in the arrays `transpose_rows`, `transpose_cols` and `transpose_weights`, each of which should be allocated with `nnz` elements. 179 | 180 | - If input `weights` is set to nullptr, then output `transpose_weights` will not be set. 181 | 182 | - For the embedding use case, `transpose_rows` contains the embedding lookup indices which are now in sorted order. `transpose_cols` and `transpose_weights` contain the sample IDs and optionally the weights corresponding to the reordered rows. 183 | 184 | - The function should first be called with `work` set to nullptr to perform a workspace query. The required size of `work` array in bytes will be returned in `lwork`. Then the function should be called a second time with `work` pointing to allocated workspace of size `lwork`. 185 | 186 | __Parameters__ 187 | - __rows__ Pointer to the COO row IDs (the sample IDs for the embedding use case). 188 | - __cols__ Pointer to the COO column IDs (the embedding lookup indices for the embedding use case). These are the keys the indices are sorted by. 189 | - __weights__ Pointer to the weight array used during forward. If nullptr, will not produce transposed weights. 190 | - __nnz__ Number of nonzeros. 191 | - __transpose_rows__ Pointer to the output transposed table indices. 192 | - __transpose_cols__ Pointer to the output transposed column IDs (the sample IDs for the embedding use case). 193 | - __transpose_weights__ Pointer to the transposed weight array. If input weights is nullptr, then will not produce transposed weights. 194 | - __work__ Pointer to scratch workspace. Set to nullptr for workspace query. 195 | - __lwork__ Pointer to size of scratch workspace. 196 | - __stream__ Optional. The cudaStream to launch the kernel asynchronously. If not specified, will launch the kernel on default stream. 197 | 198 | 199 | 200 | ### Compressed Gradient Index Conversion 201 | 202 | In some cases, the number of embedding rows actually referenced by the indices is much smaller than the total number of rows.
In these cases, it may be advantageous to produce a compressed gradient which stores only the nonzero rows of the embedding gradient. However, this requires an additional step of remapping the indices from dense embedding row ids to compressed embedding row ids. This process is pictured below. The indices which are initially distributed between 0 and `num_categories` values, are remapped to the range of 0 and `num_unique`, e.g. `indices = [4, 4, 7, 8, 8, 8, 18] -> remapped_indices = [0, 0, 1, 2, 2, 2, 3]` 203 | 204 | We provide the helper function *ComputeCompressedGradIndices* to do the above transformation. Note that the value `num_unique` can be attained from remapped_indices.back() + 1 after calling this function. 205 | 206 | ```cpp 207 | template 208 | void ComputeCompressedGradIndices(const IndexT* indices, 209 | const int nnz, 210 | IndexT* remapped_indices, 211 | char* work, 212 | size_t* lwork, 213 | const cudaStream_t stream = 0) 214 | ``` 215 | 216 | - The function should first be called with `work` set to nullptr to perform a workspace query. The required size of `work` array in bytes will be returned in `lwork`. Then the function should be called a second time with `work` pointing to allocated workspace of size `lwork`. 217 | 218 | __Parameters__ 219 | - __indices__ Pointer to the lookup indices, grouped by index. 220 | - __nnz__ Length of the indices array. 221 | - __remapped_indices__ Pointer to the remapped lookup indices (output) 222 | - __work__ Temporary workspace 223 | - __lwork__ Size of workspace in bytes (input/output) 224 | - __stream__ Optional. The cudaStream to launch the kernel asynchronously. If not specified, will launch the kernel on default stream. 
225 | 226 | 227 | -------------------------------------------------------------------------------- /cuembed/include/embedding_lookup_kernels.cuh: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | // clang-format on 19 | 20 | //! \file 21 | #ifndef CUEMBED_INCLUDE_EMBEDDING_LOOKUP_KERNELS_CUH_ 22 | #define CUEMBED_INCLUDE_EMBEDDING_LOOKUP_KERNELS_CUH_ 23 | 24 | #include "cuembed/include/embedding_lookup_ops.cuh" 25 | 26 | // Parentheses after __launch_bounds__ confuses Doxygen. Make it macro 27 | #define LAUNCH_BOUNDS_1024_1 __launch_bounds__(1024, 1) 28 | 29 | //! cuEmbed main namespace 30 | namespace cuembed { 31 | /*! 32 | * \brief The actual implementation of the embedding lookup kernel. 
33 | */ 34 | template 39 | __device__ __forceinline__ void EmbeddingLookupImpl( // one sample per threadIdx.y row: gather that sample's rows, then write concat or reduced output 40 | const typename AddresserT::InputType* __restrict__ params, 41 | const int embed_width, 42 | const int batch_size, 43 | const typename IndexLoaderT::IndexType* __restrict__ indices, 44 | const typename IndexLoaderT::OffsetType* __restrict__ offsets, 45 | const int num_hots, 46 | const typename IndexLoaderT::WeightType* __restrict__ weights, 47 | typename AddresserT::OutputType* __restrict__ ret) { 48 | const int sample_id = blockIdx.x * blockDim.y + threadIdx.y; // each CTA covers blockDim.y consecutive samples 49 | IndexLoaderT index_loader( 50 | batch_size, sample_id, indices, weights, offsets, num_hots); // NOTE(review): constructed before the bounds check below, presumably so every thread joins any CTA-wide (shared-memory) index loading done by the fixed-hotness loader — confirm before moving 51 | 52 | if (sample_id >= batch_size) { 53 | return; 54 | } 55 | CombinerT combiner; 56 | 57 | AddresserT addresser(params, ret, sample_id, num_hots, embed_width); 58 | 59 | int64_t embed_row_offset = threadIdx.x; // this thread's offset within an embedding row 60 | int64_t output_row_offset = threadIdx.x; // and within an output row 61 | 62 | #pragma unroll UnrollFactor 63 | for (int i = 0; i < index_loader.GetHotness(); ++i) { 64 | auto index = index_loader.GetLookUpIndex(i); 65 | if constexpr (IsWeighted) { // weighted path is compiled in only when IsWeighted is true 66 | auto weight = index_loader.GetWeight(i); 67 | combiner.Gather(addresser.GetEmbeddingAddress(index) + embed_row_offset, 68 | weight); 69 | } else { 70 | combiner.Gather(addresser.GetEmbeddingAddress(index) + embed_row_offset); 71 | } 72 | combiner.OutputForConcatIfNeeded(addresser.GetConcatOutputAddress(i) + 73 | output_row_offset); // per-row output; presumably a no-op unless the combiner implements concat — confirm 74 | } 75 | combiner.OutputForReductionIfNeeded(addresser.GetOutputAddress() + 76 | output_row_offset); // single reduced output row; presumably a no-op for concat — confirm 77 | } 78 | 79 | /*! 80 | * \brief Embedding lookup kernel for fixed hotness and CSR lookup index layout. 81 | * 82 | * Templatization for this kernel is only for the basic operations: 83 | * IndexLoaderT, AddresserT and CombinerT. 84 | * 85 | * Organization for Grid and Blocks: 86 | * 1. Each CTA is a 2D block of threads. 87 | * Each CTA is responsible for processing multiple lookup samples.
88 | * Threads with the same threadIdx.y in the CTA are processing the same 89 | * sample. To maximize loads in flight, each thread loads multiple elements 90 | * per load. 91 | * BlockDim.x = embed_width / elements_per_load. 92 | * BlockDim.y = samples_per_cta. 93 | * 2. The grid is a 1D array of CTAs. 94 | * GridDim.x = batch_size / samples_per_cta. 95 | * 96 | * The kernel contains the following operations: 97 | * 1. IndexLoader calculates the hotness and the index offset loads for a 98 | * specific sample. 99 | * For fixed hotness, IndexLoader loads the lookup indices needed by the CTA into 100 | * shared memory and calculates the offset based on the loaded indices. 101 | * For CSR layout, IndexLoader just does the offset and hotness 102 | * calculation, then loads the lookup index from global memory when the 103 | * index is needed. 104 | * 2. Addresser calculates the lookup addresses based on the lookup indices. 105 | * 3. Combiner issues parallel loads and writes output according to the 106 | * calculated addresses. 107 | * 108 | * Different modes of reduction/concatenation are realized by template 109 | * specialization of the combiner. 110 | * 111 | * Different layouts of the lookup indices are separated by template 112 | * specialization of offsets_or_hotness parameter. 113 | * 114 | * For future integration of Embed Cache, we can templatize the addresser with 115 | * the cache to get the indirect mapping of the actual address.
116 | */ 117 | template 121 | __global__ void LAUNCH_BOUNDS_1024_1 EmbeddingLookUpKernel( // thin __global__ wrapper; takes a typed weights pointer 122 | const typename AddresserT::InputType* __restrict__ params, 123 | const int embed_width, 124 | const int batch_size, 125 | const typename IndexLoaderT::IndexType* __restrict__ indices, 126 | const typename IndexLoaderT::OffsetType* __restrict__ offsets, 127 | const int num_hots, 128 | const typename IndexLoaderT::WeightType* __restrict__ weights, 129 | typename AddresserT::OutputType* __restrict__ ret) { 130 | EmbeddingLookupImpl( // all work happens in the shared device implementation 131 | params, 132 | embed_width, 133 | batch_size, 134 | indices, 135 | offsets, 136 | num_hots, 137 | weights, 138 | ret); 139 | } 140 | 141 | /*! 142 | * \brief Overload of the embedding lookup kernel for the unweighted use case, 143 | * selected when the caller passes a literal nullptr for the weights argument. 144 | * 145 | * This specialization cuts down the number of registers needed for this kernel 146 | * and achieves higher occupancy. 147 | */ 148 | template 152 | __global__ void LAUNCH_BOUNDS_1024_1 EmbeddingLookUpKernel( 153 | const typename AddresserT::InputType* __restrict__ params, 154 | const int embed_width, 155 | const int batch_size, 156 | const typename IndexLoaderT::IndexType* __restrict__ indices, 157 | const typename IndexLoaderT::OffsetType* __restrict__ offsets, 158 | const int num_hots, 159 | std::nullptr_t, 160 | typename AddresserT::OutputType* __restrict__ ret) { 161 | EmbeddingLookupImpl( // all work happens in the shared device implementation 162 | params, 163 | embed_width, 164 | batch_size, 165 | indices, 166 | offsets, 167 | num_hots, 168 | nullptr, // unweighted: no weight array 169 | ret); 170 | } 171 | 172 | /*! 173 | * \brief The actual implementation of the embedding backward kernel.
174 | */ 175 | template 179 | __device__ __forceinline__ void EmbeddingBackwardImpl( // gathers incoming-gradient rows for this block's nonzeros and stores/accumulates them into grad_embedding 180 | const typename GradAddresserT::GradType* __restrict__ grad_y, 181 | const int embed_width, 182 | const typename GradIndexLoaderT::IndexType* __restrict__ transpose_indices, 183 | const typename GradIndexLoaderT:: 184 | IndexType* __restrict__ transpose_sample_ids, 185 | const typename GradIndexLoaderT::WeightType* __restrict__ transpose_weights, 186 | const int nnz, 187 | const int nz_block_size, 188 | typename GradAddresserT::GradType* __restrict__ grad_embedding) { 189 | GradIndexLoaderT index_loader(transpose_indices, 190 | transpose_sample_ids, 191 | transpose_weights, 192 | nnz, 193 | nz_block_size); 194 | 195 | GradAddresserT addresser(grad_y, grad_embedding, embed_width); 196 | 197 | GradCombinerT combiner; 198 | 199 | const int embed_offset = threadIdx.x; // this thread's offset within a row 200 | 201 | // For each nonzero assigned to this block 202 | for (int i = 0; i < index_loader.GetBlockNnz(); i++) { 203 | auto row = index_loader.GetIndex(i); // keep the loader's index type (as the forward kernel does) instead of narrowing to int 204 | auto col = index_loader.GetSampleId(i); 205 | bool write_flag = index_loader.ShouldWrite(i); // how the gathered row is committed is decided by GradIndexLoaderT 206 | bool atomic_flag = index_loader.ShouldAtomic(i); 207 | 208 | if constexpr (IsWeighted) { 209 | auto weight = index_loader.GetWeight(i); 210 | combiner.Gather(addresser.GetGradResultAddress(col) + embed_offset, 211 | weight); 212 | } else { 213 | combiner.Gather(addresser.GetGradResultAddress(col) + embed_offset); 214 | } 215 | combiner.WriteOrAtomic( 216 | addresser.GetGradEmbeddingAddress(row) + embed_offset, 217 | write_flag, 218 | atomic_flag); 219 | } 220 | } 221 | 222 | /*! 223 | * \brief Embedding backward kernel for compressed and full gradients. 224 | * 225 | * Organization for Grid and Blocks: 226 | * 1. Each CTA is a 2D block of threads. 227 | * Each CTA is responsible for processing multiple indices. 228 | * Threads with the same threadIdx.y in the CTA are processing the same 229 | * indices. 230 | * BlockDim.x = embed_width / elements_per_load.
231 | * BlockDim.y = nonzero_blocks_per_cta 232 | * 2. The grid is a 1D array of CTAs. 233 | * GridDim.x = nnz / nz_block_size / nonzero_blocks_per_cta. 234 | * 235 | */ 236 | template 239 | __global__ void EmbeddingBackwardKernel( // thin __global__ wrapper; takes a typed transpose_weights pointer 240 | const typename GradAddresserT::GradType* __restrict__ grad_y, 241 | const int embed_width, 242 | const typename GradIndexLoaderT::IndexType* __restrict__ transpose_indices, 243 | const typename GradIndexLoaderT:: 244 | IndexType* __restrict__ transpose_sample_ids, 245 | const typename GradIndexLoaderT::WeightType* __restrict__ transpose_weights, 246 | const int nnz, 247 | const int nz_block_size, 248 | typename GradAddresserT::GradType* __restrict__ grad_embedding) { 249 | EmbeddingBackwardImpl( // all work happens in the shared device implementation 250 | grad_y, 251 | embed_width, 252 | transpose_indices, 253 | transpose_sample_ids, 254 | transpose_weights, 255 | nnz, 256 | nz_block_size, 257 | grad_embedding); 258 | } 259 | 260 | /*! 261 | * \brief Overload of the embedding backward kernel for the unweighted case, 262 | * selected when the caller passes a literal nullptr for transpose_weights. 263 | */ 264 | template 267 | __global__ void EmbeddingBackwardKernel( 268 | const typename GradAddresserT::GradType* __restrict__ grad_y, 269 | const int embed_width, 270 | const typename GradIndexLoaderT::IndexType* __restrict__ transpose_indices, 271 | const typename GradIndexLoaderT:: 272 | IndexType* __restrict__ transpose_sample_ids, 273 | std::nullptr_t, 274 | const int nnz, 275 | const int nz_block_size, 276 | typename GradAddresserT::GradType* __restrict__ grad_embedding) { 277 | EmbeddingBackwardImpl( // all work happens in the shared device implementation 278 | grad_y, 279 | embed_width, 280 | transpose_indices, 281 | transpose_sample_ids, 282 | nullptr, // unweighted: no weight array 283 | nnz, 284 | nz_block_size, 285 | grad_embedding); 286 | } 287 | 288 | // Generate sparse indices for sparse output gradients: compacted_indices[c] receives the original index of compressed row c (presumably used to fill inverse_mapping — confirm) 289 | template 290 | __global__ void CompactSparseIndicesKernel(const IndexT* indices, 291 | const IndexT* remapped_indices, 292 | IndexT* compacted_indices, 293 | const int nnz) { 294 | int tid = threadIdx.x + blockIdx.x * blockDim.x; // one thread per nonzero
295 | if (tid >= nnz) return; 296 | 297 | // First occurrence of a run of equal indices (assumes equal indices are grouped contiguously): record the original index at its compressed position 298 | if ((tid == 0) || (indices[tid - 1] != indices[tid])) { 299 | IndexT cidx = remapped_indices[tid]; 300 | compacted_indices[cidx] = indices[tid]; 301 | } 302 | } 303 | 304 | } // namespace cuembed 305 | 306 | #endif // CUEMBED_INCLUDE_EMBEDDING_LOOKUP_KERNELS_CUH_ 307 | -------------------------------------------------------------------------------- /cuembed/include/index_transforms.cuh: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | // clang-format on 19 | 20 | //! \file 21 | #ifndef CUEMBED_INCLUDE_INDEX_TRANSFORMS_CUH_ 22 | #define CUEMBED_INCLUDE_INDEX_TRANSFORMS_CUH_ 23 | 24 | // clang-format off 25 | #include 26 | #include 27 | #include 28 | #include 29 | 30 | #include "cuembed/include/embedding_lookup.cuh" 31 | #include "cuembed/include/index_transforms_kernels.cuh" 32 | // clang-format on 33 | 34 | namespace cuembed { 35 | 36 | /*! 37 | * \brief Produce an nnz-length row_ids array which has fixed offsets from 0 to 38 | * batch_size - 1, e.g.: num_hots = 3 -> row_ids = [0, 0, 0, 1, 1, 1, 2, 2, 2, ...]
39 | * 40 | * Requires a workspace query in which the function is called first with the 41 | * "work" parameter set to NULL. The required temporary memory size in bytes 42 | * will be returned in the lwork parameter by reference. 43 | * 44 | */ 45 | template 46 | void ExtractRowIdsFromFixed(const int batch_size, 47 | const int num_hots, 48 | IndexT* row_ids, 49 | const cudaStream_t stream = 0) { 50 | const int nnz = batch_size * num_hots; 51 | const int nthreads = DEFAULT_THREADS_PER_CTA; 52 | ExtractSequenceKernel 53 | <<<(nnz + nthreads - 1) / nthreads, nthreads, 0, stream>>>( 54 | nnz, num_hots, row_ids); 55 | } 56 | 57 | /*! 58 | * \brief Produce a nnz-length row_ids array which has explicit offsets read 59 | * from CSR, e.g.: offsets : [0, 2, 3, 5] -> row_ids = [0, 0, 1, 2, 2] 60 | * 61 | * Requires a workspace query in which the function is called first with the 62 | * "work" parameter set to NULL. The required temporary memory size in bytes 63 | * will be returned in the lwork parameter by reference. 64 | * 65 | */ 66 | template 67 | void ExtractRowIdsFromCSR(const OffsetT* offsets, 68 | const int batch_size, 69 | IndexT* row_ids, 70 | const cudaStream_t stream = 0) { 71 | const int nthreads = DEFAULT_THREADS_PER_CTA; 72 | ExtractRowIdsFromCSRKernel 73 | <<>>(offsets, row_ids); 74 | } 75 | 76 | /*! 77 | * \brief Produce a nnz-length row_ids array which has the sequence 0 .. nnz. 78 | * e.g. row_ids = [0, 1, 2, 3, ...] 79 | * 80 | * Requires a workspace query in which the function is called first with the 81 | * "work" parameter set to NULL. The required temporary memory size in bytes 82 | * will be returned in the lwork parameter by reference. 
83 | * 84 | */ 85 | template 86 | void ExtractRowIdsForConcat(const int nnz, 87 | IndexT* row_ids, 88 | const cudaStream_t stream = 0) { 89 | const int nthreads = DEFAULT_THREADS_PER_CTA; if (nnz <= 0) return; // empty output; a zero-block launch would fail with cudaErrorInvalidConfiguration 90 | ExtractSequenceKernel 91 | <<<(nnz + nthreads - 1) / nthreads, nthreads, 0, stream>>>( 92 | nnz, 1, row_ids); // num_hots = 1 yields the identity sequence 0 .. nnz-1 93 | } 94 | 95 | template 96 | void TransposeUnweighted(const IndexT* rows, 97 | const IndexT* cols, 98 | const int nnz, 99 | IndexT* transpose_rows, 100 | IndexT* transpose_cols, 101 | char* work, 102 | size_t* lwork, 103 | const cudaStream_t stream = 0) { // sorts (cols, rows) pairs by cols; same two-phase workspace query as Transpose 104 | void* nullwork = nullptr; 105 | size_t sort_bytes = 0; 106 | const int begin_bit = 0; 107 | const int end_bit = sizeof(IndexT) * 8; // radix-sort over every bit of the key 108 | cub::DeviceRadixSort::SortPairs(nullwork, 109 | sort_bytes, 110 | cols, // keys: sorted into transpose_rows 111 | transpose_rows, 112 | rows, // values: reordered into transpose_cols 113 | transpose_cols, 114 | nnz, 115 | begin_bit, 116 | end_bit, 117 | stream); 118 | 119 | size_t required_workspace = sort_bytes; 120 | 121 | if (work == nullptr) { 122 | *lwork = required_workspace; 123 | return; 124 | } 125 | 126 | assert(*lwork >= required_workspace); 127 | cub::DeviceRadixSort::SortPairs(static_cast(work), 128 | *lwork, 129 | cols, 130 | transpose_rows, 131 | rows, 132 | transpose_cols, 133 | nnz, 134 | begin_bit, 135 | end_bit, 136 | stream); 137 | } 138 | 139 | template 140 | void TransposeWeighted(const IndexT* rows, 141 | const IndexT* cols, 142 | const WeightT* weights, 143 | const int nnz, 144 | IndexT* transpose_rows, 145 | IndexT* transpose_cols, 146 | WeightT* transpose_weights, 147 | char* work, 148 | size_t* lwork, 149 | const cudaStream_t stream = 0) { 150 | size_t buffer_bytes = size_t{2} * nnz * sizeof(WeightTuple); // input + output tuple buffers at the front of work; size_t math so 2 * nnz cannot overflow int 151 | 152 | WeightTuple* vals_in = 153 | reinterpret_cast*>(work); 154 | WeightTuple* vals_out = vals_in + nnz; // NOTE(review): during the workspace query work is nullptr, so this is arithmetic on a null pointer; only the pointer types matter for the query 155 | 156 | void* nullwork = nullptr; 157 | size_t sort_bytes = 0; 158 | const int begin_bit = 0; 159 | const int end_bit = sizeof(IndexT) * 8; 160 | cub::DeviceRadixSort::SortPairs(nullwork, 161 | sort_bytes,
162 | cols, 163 | transpose_rows, 164 | vals_in, 165 | vals_out, 166 | nnz, 167 | begin_bit, 168 | end_bit, 169 | stream); 170 | 171 | size_t required_workspace = sort_bytes + buffer_bytes; 172 | 173 | if (work == nullptr) { 174 | *lwork = required_workspace; 175 | return; 176 | } 177 | 178 | assert(*lwork >= required_workspace); 179 | if (nnz <= 0) return; // nothing to transpose; the zero-block launches below would be invalid 180 | const int nthreads = DEFAULT_THREADS_PER_CTA; 181 | PackToTuple 182 | <<<(nnz + nthreads - 1) / nthreads, nthreads, 0, stream>>>( 183 | rows, weights, nnz, vals_in); // pack (rows[i], weights[i]) into one value so both follow the sort 184 | 185 | void* sort_workspace = reinterpret_cast(vals_out + nnz); // CUB scratch starts after the two tuple buffers 186 | cub::DeviceRadixSort::SortPairs(static_cast(sort_workspace), 187 | sort_bytes, // what the query asked for; *lwork would overstate the space left after buffer_bytes 188 | cols, 189 | transpose_rows, 190 | vals_in, 191 | vals_out, 192 | nnz, 193 | begin_bit, 194 | end_bit, 195 | stream); 196 | 197 | ExtractFromTuple 198 | <<<(nnz + nthreads - 1) / nthreads, nthreads, 0, stream>>>( 199 | vals_out, nnz, transpose_cols, transpose_weights); // unpack sorted tuples into transpose_cols and transpose_weights 200 | } 201 | 202 | /** 203 | * @brief Reorders indices from sample-id-first ordering as is needed during 204 | * forward to table-index-first ordering needed for backward. Output indices are 205 | * produced in coordinate (COO) format. 206 | * 207 | * @tparam IndexT Index datatype 208 | * @tparam WeightT Weight datatype 209 | * @param rows Pointer to the COO row IDs (the sample IDs for the embedding use 210 | * case). 211 | * @param cols Pointer to the COO column IDs (the embedding lookup indices for the embedding use case); these are the sort keys. 212 | * @param weights Pointer to the weight array used during forward. If nullptr, 213 | * will not produce transposed weights. 214 | * @param nnz Number of nonzeros. 215 | * @param transpose_rows Pointer to the output transposed table indices. 216 | * @param transpose_cols Pointer to the output transposed column IDs (the sample IDs for the embedding use case). 217 | * @param transpose_weights Pointer to the transposed weight array. If input 218 | * weights is nullptr, then will not produce transposed weights. 219 | * @param work Pointer to scratch workspace. Set to nullptr for workspace query.
220 | * @param lwork Pointer to size of scratch workspace. 221 | * @param stream Optional. The cudaStream to launch the kernel asynchronously. 222 | * If not specified, will launch the kernel on default stream. 223 | */ 224 | template 225 | void Transpose(const IndexT* rows, 226 | const IndexT* cols, 227 | const WeightT* weights, 228 | const int nnz, 229 | IndexT* transpose_rows, 230 | IndexT* transpose_cols, 231 | WeightT* transpose_weights, 232 | char* work, 233 | size_t* lwork, 234 | const cudaStream_t stream = 0) { 235 | if (weights == nullptr) { // unweighted: sort (cols, rows) pairs directly; transpose_weights is left untouched 236 | TransposeUnweighted( 237 | rows, cols, nnz, transpose_rows, transpose_cols, work, lwork, stream); 238 | } else { // weighted: sample ids and weights are sorted together 239 | TransposeWeighted(rows, 240 | cols, 241 | weights, 242 | nnz, 243 | transpose_rows, 244 | transpose_cols, 245 | transpose_weights, 246 | work, 247 | lwork, 248 | stream); 249 | } 250 | } 251 | 252 | struct FlagNonzero { // difference operator: 0 when the two values are equal, 1 otherwise 253 | template 254 | __host__ __device__ __forceinline__ T operator()(const T lhs, const T rhs) const { // const: stateless, callable through a const functor 255 | return (lhs == rhs) ? 0 : 1; 256 | } 257 | }; 258 | 259 | /** 260 | * @brief The indices, which are initially distributed between 0 and 261 | * num_categories, are remapped to the range 0 .. num_unique - 1, e.g. 262 | * indices = [4, 4, 7, 8, 8, 8, 18] -> remapped_indices = [0, 0, 1, 2, 2, 2, 3] 263 | * 264 | * Requires a workspace query in which the function is called first with the 265 | * "work" parameter set to NULL. The required temporary memory size in bytes 266 | * will be returned in the lwork parameter by reference. 267 | * 268 | * @tparam IndexT Index datatype 269 | * 270 | * @param indices Pointer to the lookup indices; equal indices must be contiguous (e.g. transpose_rows from Transpose). 271 | * @param nnz Length of the indices array. 272 | * @param remapped_indices Pointer to the remapped lookup indices (output) 273 | * @param work Temporary workspace, or nullptr to perform the workspace query 274 | * @param lwork Size of workspace in bytes (input/output) 275 | * @param stream Optional. The cudaStream to launch the kernel asynchronously.
276 | * If not specified, will launch the kernel on default stream. 277 | */ 278 | template 279 | void ComputeCompressedGradIndices(const IndexT* indices, 280 | const int nnz, 281 | IndexT* remapped_indices, 282 | char* work, 283 | size_t* lwork, 284 | const cudaStream_t stream = 0) { 285 | void* nullwork = nullptr; 286 | size_t scan_storage_bytes = 0; 287 | cub::DeviceScan::InclusiveSum( 288 | nullwork, scan_storage_bytes, indices, remapped_indices, nnz, stream); 289 | size_t ad_storage_bytes = 0; 290 | cub::DeviceAdjacentDifference::SubtractLeftCopy(nullwork, 291 | ad_storage_bytes, 292 | remapped_indices, 293 | remapped_indices, 294 | nnz, 295 | FlagNonzero(), 296 | stream); 297 | size_t required_workspace = std::max(scan_storage_bytes, ad_storage_bytes); 298 | 299 | // Workspace query 300 | if (work == nullptr) { 301 | *lwork = required_workspace; 302 | return; 303 | } 304 | 305 | assert(*lwork >= required_workspace); 306 | 307 | cub::DeviceAdjacentDifference::SubtractLeftCopy(reinterpret_cast(work), 308 | ad_storage_bytes, 309 | indices, 310 | remapped_indices, 311 | nnz, 312 | FlagNonzero(), 313 | stream); 314 | 315 | cudaMemsetAsync(remapped_indices, 0, sizeof(IndexT), stream); 316 | 317 | cub::DeviceScan::InclusiveSum(reinterpret_cast(work), 318 | *lwork, 319 | remapped_indices, 320 | remapped_indices, 321 | nnz, 322 | stream); 323 | } 324 | 325 | } // namespace cuembed 326 | 327 | #endif // CUEMBED_INCLUDE_INDEX_TRANSFORMS_CUH_ 328 | -------------------------------------------------------------------------------- /cuembed/include/index_transforms_kernels.cuh: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 
8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | // clang-format on 19 | 20 | //! \file 21 | #ifndef CUEMBED_INCLUDE_INDEX_TRANSFORMS_KERNELS_CUH_ 22 | #define CUEMBED_INCLUDE_INDEX_TRANSFORMS_KERNELS_CUH_ 23 | 24 | //! cuEmbed main namespace 25 | namespace cuembed { 26 | 27 | // CSR -> COO row ids: block b writes b into row_ids[offsets[b], offsets[b + 1]), its threads striding by blockDim.x. 28 | template 29 | __global__ void ExtractRowIdsFromCSRKernel(const OffsetT* offsets, 30 | IndexT* row_ids) { 31 | const int b = blockIdx.x; 32 | OffsetT start = offsets[b]; 33 | OffsetT end = offsets[b + 1]; 34 | for (OffsetT i = start + threadIdx.x; i < end; i += blockDim.x) { 35 | row_ids[i] = static_cast(b); 36 | } 37 | } 38 | 39 | // Row ids for a regular layout: row_ids[tid] = tid / int_div for each tid < nnz (presumably int_div is the per-row count, e.g. hotness; confirm with caller). 40 | template 41 | __global__ void ExtractSequenceKernel(const int nnz, 42 | const int int_div, 43 | IndexT* row_ids) { 44 | const int tid = threadIdx.x + blockIdx.x * blockDim.x; 45 | if (tid < nnz) { 46 | row_ids[tid] = static_cast(tid / int_div); 47 | } 48 | } 49 | // (index, weight) pair used as the radix-sort value payload in the weighted transpose, so both move together with their sort key. 50 | template 51 | struct WeightTuple { 52 | IndexT idx; 53 | WeightT weight; 54 | }; 55 | // One thread per element: vals[tid] = {indices[tid], weights[tid]} for each tid < nnz. 56 | template 57 | __global__ void PackToTuple(const IndexT* __restrict__ indices, 58 | const WeightT* __restrict__ weights, 59 | const int nnz, 60 | WeightTuple* __restrict__ vals) { 61 | int tid = threadIdx.x + blockIdx.x * blockDim.x; 62 | if (tid < nnz) { 63 | WeightTuple t; 64 | t.idx = indices[tid]; 65 | t.weight = weights[tid]; 66 | vals[tid] = t; 67 | } 68 | } 69 | // Inverse of PackToTuple: splits vals[tid] into indices[tid] and weights[tid] for each tid < nnz. 70 | template 71 | __global__ void ExtractFromTuple( 72 | const WeightTuple* __restrict__ vals, 73 | const int nnz, 74 | IndexT* __restrict__ indices, 75
| WeightT* __restrict__ weights) { 76 | int tid = threadIdx.x + blockIdx.x * blockDim.x; 77 | if (tid < nnz) { 78 | indices[tid] = vals[tid].idx; 79 | weights[tid] = vals[tid].weight; 80 | } 81 | } 82 | 83 | } // namespace cuembed 84 | 85 | #endif // CUEMBED_INCLUDE_INDEX_TRANSFORMS_KERNELS_CUH_ 86 | -------------------------------------------------------------------------------- /examples/pytorch/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(Python REQUIRED COMPONENTS Development) 2 | 3 | set(TORCH_CUDA_ARCH_LIST "7.0;7.5;8.0;9.0" CACHE STRING "List of target GPU architectures") 4 | 5 | message(STATUS "add `python -c 'import torch;print(torch.utils.cmake_prefix_path)'` to CMAKE_PREFIX_PATH") 6 | find_package(Torch REQUIRED) 7 | # Builds libcuembed_pyt, the op library that cuembed_pyt.py loads. 8 | add_library(cuembed_pyt SHARED cuembed_embedding.cu) 9 | target_link_libraries(cuembed_pyt PRIVATE ${TORCH_LIBRARIES} Python::Python cuembed_hdrs) 10 | -------------------------------------------------------------------------------- /examples/pytorch/cuembed_embedding.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "cuembed/include/embedding_lookup.cuh" 8 | #include "cuembed/include/index_transforms.cuh" 9 | 10 | torch::Tensor cuembed_embedding_forward(const torch::Tensor params, 11 | const torch::Tensor indices, 12 | const torch::Tensor offsets, 13 | const torch::Tensor weights, 14 | const std::string mode) { 15 | AT_ASSERT(indices.is_cuda()); 16 | AT_ASSERT(offsets.is_cuda()); 17 | AT_ASSERT(params.is_cuda()); 18 | AT_ASSERT(params.is_contiguous()); 19 | AT_ASSERT(params.scalar_type() == torch::ScalarType::Float); 20 | if (weights.defined()) { 21 | AT_ASSERT(weights.scalar_type() == torch::ScalarType::Float); 22 | } 23 | AT_ASSERT(indices.scalar_type() == torch::ScalarType::Long); 24 | AT_ASSERT(offsets.scalar_type() == torch::ScalarType::Long); 25 | using IndexType = int64_t; 26 | 27 | int num_features = params.size(0); 28 | int embed_width = params.size(1); 29 | 30 | int batch_size = offsets.numel() - 1; 31 | auto outputs = torch::empty(batch_size * embed_width, params.options()); 32 | 33 | AT_ASSERT(mode == "sum"); 34 | auto combine_mode = cuembed::CombineMode::kSum; 35 | cuembed::EmbeddingForward( 36 | params.data_ptr(), 37 | embed_width, 38 | indices.contiguous().data_ptr(), 39 | offsets.contiguous().data_ptr(), 40 | weights.defined() ? 
weights.contiguous().data_ptr() : nullptr, 41 | batch_size, 42 | 0, 43 | combine_mode, 44 | outputs.mutable_data_ptr(), 45 | at::cuda::getCurrentCUDAStream()); 46 | 47 | return outputs.reshape({batch_size, embed_width}); 48 | } 49 | 50 | torch::Tensor cuembed_extract_row_ids_from_csr(const torch::Tensor offsets, 51 | const int64_t nnz) { 52 | AT_ASSERT(offsets.is_cuda()); 53 | AT_ASSERT(offsets.scalar_type() == torch::ScalarType::Long); 54 | using IndexType = int64_t; 55 | 56 | int batch_size = offsets.size(0); 57 | auto row_ids = torch::empty(nnz, offsets.options()); 58 | cuembed::ExtractRowIdsFromCSR(offsets.data_ptr(), 59 | batch_size, 60 | row_ids.mutable_data_ptr(), 61 | at::cuda::getCurrentCUDAStream()); 62 | return row_ids; 63 | } 64 | 65 | std::tuple cuembed_transpose( 66 | const torch::Tensor rows, 67 | const torch::Tensor cols, 68 | const torch::Tensor weights) { 69 | AT_ASSERT(rows.is_cuda()); 70 | AT_ASSERT(cols.is_cuda()); 71 | AT_ASSERT(rows.scalar_type() == torch::ScalarType::Long); 72 | AT_ASSERT(cols.scalar_type() == torch::ScalarType::Long); 73 | if (weights.defined()) { 74 | AT_ASSERT(weights.scalar_type() == torch::ScalarType::Float); 75 | } 76 | 77 | int nnz = rows.size(0); 78 | using IndexType = int64_t; 79 | auto transpose_rows = torch::empty(nnz, rows.options()); 80 | auto transpose_cols = torch::empty(nnz, cols.options()); 81 | 82 | // TODO(niskos): Propely return a None tensor if there are no weights 83 | auto transpose_weights_size = weights.defined() ? nnz : 0; 84 | auto transpose_weights = torch::empty( 85 | transpose_weights_size, at::device(rows.device()).dtype(at::kFloat)); 86 | 87 | size_t lwork = 0; 88 | cuembed::Transpose( 89 | rows.data_ptr(), 90 | cols.data_ptr(), 91 | weights.defined() ? 
weights.contiguous().data_ptr() : nullptr, 92 | nnz, 93 | transpose_rows.mutable_data_ptr(), 94 | transpose_cols.mutable_data_ptr(), 95 | transpose_weights.mutable_data_ptr(), 96 | nullptr, 97 | &lwork, 98 | at::cuda::getCurrentCUDAStream()); 99 | auto work = torch::empty(lwork, at::device(rows.device()).dtype(at::kByte)); 100 | cuembed::Transpose( 101 | rows.data_ptr(), 102 | cols.data_ptr(), 103 | weights.defined() ? weights.contiguous().data_ptr() : nullptr, 104 | nnz, 105 | transpose_rows.mutable_data_ptr(), 106 | transpose_cols.mutable_data_ptr(), 107 | transpose_weights.mutable_data_ptr(), 108 | reinterpret_cast(work.mutable_data_ptr()), 109 | &lwork, 110 | at::cuda::getCurrentCUDAStream()); 111 | return {transpose_rows, transpose_cols, transpose_weights}; 112 | } 113 | 114 | torch::Tensor cuembed_embedding_backward( 115 | const torch::Tensor y_grad, 116 | const int64_t num_categories, 117 | const torch::Tensor transpose_indices, 118 | const torch::Tensor transpose_sample_ids, 119 | const torch::Tensor transpose_weights) { 120 | AT_ASSERT(transpose_indices.is_cuda()); 121 | AT_ASSERT(transpose_sample_ids.is_cuda()); 122 | AT_ASSERT(y_grad.is_cuda()); 123 | AT_ASSERT(y_grad.is_contiguous()); 124 | 125 | AT_ASSERT(transpose_indices.scalar_type() == torch::ScalarType::Long); 126 | AT_ASSERT(transpose_sample_ids.scalar_type() == torch::ScalarType::Long); 127 | AT_ASSERT(y_grad.scalar_type() == torch::ScalarType::Float); 128 | using IndexType = int64_t; 129 | 130 | // Allocate grad_embedding 131 | int embed_width = y_grad.size(1); 132 | int nnz = transpose_indices.size(0); 133 | auto grad_embedding = 134 | torch::zeros(num_categories * embed_width, y_grad.options()); 135 | 136 | // Call backward 137 | cuembed::EmbeddingBackward( 138 | y_grad.data_ptr(), 139 | embed_width, 140 | num_categories, 141 | nnz, 142 | transpose_indices.contiguous().data_ptr(), 143 | transpose_sample_ids.contiguous().data_ptr(), 144 | nullptr, /*transpose_remapped_indices*/ 145 | 
transpose_weights.defined() 146 | ? transpose_weights.contiguous().data_ptr() 147 | : nullptr, 148 | true, /*skip_grad_init*/ 149 | grad_embedding.mutable_data_ptr(), 150 | nullptr, /*inverse_mapping*/ 151 | at::cuda::getCurrentCUDAStream()); 152 | 153 | return grad_embedding.reshape({num_categories, embed_width}); 154 | } 155 | 156 | TORCH_LIBRARY(cuembed_pyt, m) { 157 | m.def( 158 | "cuembed_extract_row_ids_from_csr(Tensor offsets, int nnz)" 159 | " ->Tensor", 160 | &cuembed_extract_row_ids_from_csr); 161 | m.def( 162 | "cuembed_transpose(Tensor rows, Tensor cols, Tensor weights) ->" 163 | " (Tensor, Tensor, Tensor)", 164 | &cuembed_transpose); 165 | m.def( 166 | "cuembed_embedding_forward(Tensor params, Tensor indices," 167 | " Tensor offsets, Tensor weights, str mode) -> Tensor", 168 | &cuembed_embedding_forward); 169 | m.def( 170 | "cuembed_embedding_backward(Tensor y_grad, int num_categories," 171 | " Tensor transpose_indices, Tensor transpose_sample_ids, Tensor " 172 | "transpose_weights) -> Tensor", 173 | &cuembed_embedding_backward); 174 | } 175 | -------------------------------------------------------------------------------- /examples/pytorch/cuembed_pyt.py: -------------------------------------------------------------------------------- 1 | from absl import app 2 | from absl import flags 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | torch.ops.load_library("../../build/examples/pytorch/libcuembed_pyt.so") 8 | 9 | cuembed_extract_row_ids_from_csr = ( 10 | torch.ops.cuembed_pyt.cuembed_extract_row_ids_from_csr 11 | ) 12 | cuembed_transpose = torch.ops.cuembed_pyt.cuembed_transpose 13 | cuembed_embedding_forward = torch.ops.cuembed_pyt.cuembed_embedding_forward 14 | cuembed_embedding_backward = torch.ops.cuembed_pyt.cuembed_embedding_backward 15 | 16 | def cuembed_forward(params, idx, offsets, weights): 17 | return cuembed_embedding_forward(params, idx, offsets, weights, mode="sum") 18 | 19 | def cuembed_backward(ctx, out_grad): 20 | 
idx = ctx.saved_tensors[0] 21 | offsets = ctx.saved_tensors[1] 22 | weights = ctx.saved_tensors[2] 23 | num_categories = ctx.num_categories 24 | nnz = idx.size(0) 25 | 26 | # Assuming equivalent of nn.EmbeddingBag's `include_last_offset=True` 27 | sample_ids = cuembed_extract_row_ids_from_csr(offsets[:-1], nnz) 28 | transpose_indices, transpose_sample_ids, transpose_weights = \ 29 | cuembed_transpose(sample_ids, idx, weights) 30 | 31 | # This means weights = None during forward 32 | if(transpose_weights.numel() == 0): 33 | transpose_weights = None 34 | 35 | grad_embedding = cuembed_embedding_backward( 36 | out_grad, num_categories, transpose_indices, transpose_sample_ids, transpose_weights) 37 | 38 | # no grad for indices, offsets, or weights 39 | return grad_embedding, None, None, None 40 | 41 | def setup_context(ctx, inputs, output): 42 | params, idx, offsets, weights = inputs 43 | ctx.save_for_backward(idx, offsets, weights) 44 | ctx.num_categories = params.size(0) 45 | 46 | # Need to register this as a custom op to allow torch.compile to work 47 | @torch.library.custom_op("cuemb::cuemb_embedding", mutates_args=()) 48 | def cuemb_embedding( 49 | params : torch.Tensor, idx : torch.Tensor, offsets : torch.Tensor, weights : torch.Tensor = None) -> torch.Tensor: 50 | return cuembed_forward(params, idx, offsets, weights) 51 | 52 | @cuemb_embedding.register_fake 53 | def _(params : torch.Tensor, idx : torch.Tensor, offsets : torch.Tensor, weights : torch.Tensor = None): 54 | return torch.empty_like(params).reshape([offsets.shape[0]-1, params.shape[1]]) 55 | 56 | cuemb_embedding.register_autograd(cuembed_backward, setup_context=setup_context) -------------------------------------------------------------------------------- /examples/pytorch/cuembed_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from cuembed_pyt import cuemb_embedding 4 | 5 | torch.manual_seed(0) 6 | 7 | def 
test_cuembed(embedding_bag, indices, offsets, weights): 8 | if(weights != None): 9 | res = cuemb_embedding(embedding_bag.weight, indices, offsets, weights) 10 | ref = embedding_bag(indices, offsets, weights) 11 | else: 12 | res = cuemb_embedding(embedding_bag.weight, indices, offsets) 13 | ref = embedding_bag(indices, offsets) 14 | 15 | print('fprop test pass = ', (res == ref).all()) 16 | 17 | embedding_bag.weight.grad = None 18 | torch.mean(res).backward() 19 | grad_res = embedding_bag.weight.grad.clone() 20 | 21 | embedding_bag.weight.grad = None 22 | torch.mean(ref).backward() 23 | grad_ref = embedding_bag.weight.grad.clone() 24 | 25 | # might not be exactly equal because cuEmbed uses atomics in back pass 26 | print('bprop test pass = ', torch.allclose(grad_res, grad_ref), '\n') 27 | 28 | # test cases 29 | k = 958 30 | embedding_bag = nn.EmbeddingBag( 31 | num_embeddings=k, 32 | embedding_dim=128, 33 | mode='sum', 34 | include_last_offset=True, # type: ignore Argument of type "bool | None" cannot be assigned ... 35 | padding_idx=None, 36 | dtype=torch.float32 37 | ).to(device='cuda') 38 | 39 | n = 2880000 40 | indices = k * torch.rand([n], device='cuda') 41 | indices = indices.to(dtype=torch.long) 42 | 43 | offsets = torch.tensor([ i for i in range(n)]+[n],device='cuda') 44 | weights = torch.rand([n], device='cuda', dtype=torch.float32) 45 | 46 | test_cuembed(embedding_bag, indices, offsets, weights) 47 | test_cuembed(embedding_bag, indices, offsets, None) 48 | 49 | k = 2048 50 | embedding_bag = nn.EmbeddingBag( 51 | num_embeddings=k, 52 | embedding_dim=64, 53 | mode='sum', 54 | include_last_offset=True, # type: ignore Argument of type "bool | None" cannot be assigned ... 
55 | padding_idx=None, 56 | dtype=torch.float32 57 | ).to(device='cuda') 58 | 59 | n = 104217 60 | indices = k * torch.rand([n], device='cuda') 61 | indices = indices.to(dtype=torch.long) 62 | 63 | offsets = torch.tensor([ i for i in range(n)]+[n],device='cuda', dtype=torch.long) 64 | weights = torch.rand([n], device='cuda', dtype=torch.float32) 65 | 66 | test_cuembed(embedding_bag, indices, offsets, weights) 67 | test_cuembed(embedding_bag, indices, offsets, None) -------------------------------------------------------------------------------- /tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # Build one test executable per tests/*.cpp and tests/*.cu file, named after its source file. 16 | file(GLOB test_source_files_cpp "${CMAKE_SOURCE_DIR}/tests/*.cpp") 17 | file(GLOB test_source_files_cu "${CMAKE_SOURCE_DIR}/tests/*.cu") 18 | set(test_source_files ${test_source_files_cpp} ${test_source_files_cu}) 19 | # Each binary links gtest_main and is written to ${PROJECT_BINARY_DIR}/bin/test. 20 | foreach(test_file ${test_source_files}) 21 | get_filename_component(test_name ${test_file} NAME_WE) 22 | add_executable(${test_name} ${test_file}) 23 | target_include_directories(${test_name} PRIVATE googletest ${CUDAToolkit_INCLUDE_DIRS}) 24 | target_link_libraries( 25 | ${test_name} PRIVATE 26 | cuembed_hdrs 27 | gtest 28 | gtest_main 29 | cuda 30 | utils 31 | absl::log 32 | absl::check) 33 | set_target_properties(${test_name} 34 | PROPERTIES 35 | RUNTIME_OUTPUT_DIRECTORY "${PROJECT_BINARY_DIR}/bin/test" 36 | ) 37 | endforeach(test_file) 38 | -------------------------------------------------------------------------------- /tests/test_datagen.cpp: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License.
17 | */ 18 | // clang-format on 19 | 20 | #include 21 | 22 | #include "utils/include/datagen.h" 23 | 24 | namespace cuembed { 25 | namespace index_generators { 26 | // Typed fixture: draws indices from a FeatureGenerator and checks their range, uniqueness and distribution. 27 | template 28 | class DataGeneratorTest : public ::testing::Test { 29 | protected: 30 | using IndexType = T; 31 | 32 | void setUpParams(const int num_categories, const int num_hot) { 33 | num_categories_ = num_categories; 34 | num_hot_ = num_hot; 35 | } 36 | // Draws batch_size samples; each sample must contain exactly num_hot_ indices. 37 | std::vector> generateIndices(const int batch_size) { 38 | std::vector> result; 39 | for (int i = 0; i < batch_size; i++) { 40 | std::vector generated_idx = generator_->getCategoryIndices(); 41 | EXPECT_EQ(generated_idx.size(), num_hot_); 42 | result.push_back(std::move(generated_idx)); 43 | } 44 | return result; 45 | } 46 | // Each index must lie in [1, num_categories_] and be unique within its sample. 47 | void sanityCheckIndices(const std::vector>& indices) { 48 | for (const auto& batch : indices) { 49 | std::set used_indices; 50 | for (const auto index : batch) { 51 | EXPECT_GT(index, 0); 52 | EXPECT_LE(index, num_categories_); 53 | // Checks that the indices are generated with no repetitions. 54 | EXPECT_EQ(used_indices.count(index), 0); 55 | used_indices.insert(index); 56 | } 57 | } 58 | } 59 | // Counts occurrences of each index value in [0, num_categories_]; slot 0 is the reserved id. 60 | std::vector computeHistogram( 61 | const std::vector>& generated_indices) { 62 | std::vector result(num_categories_ + 1, 0); 63 | for (const auto& sample : generated_indices) { 64 | for (const auto& index : sample) { 65 | result[index]++; 66 | } 67 | } 68 | return result; 69 | } 70 | // Replaces the generator under test; the fixture takes ownership of new_generator. 71 | void resetGenerator(FeatureGenerator* new_generator) { 72 | generator_.reset(new_generator); 73 | } 74 | 75 | // Generates indices according to the given batch size. 76 | // Computes histogram of the generated indices. 77 | // Checks that the computed histogram normalized to probability distribution 78 | // is within delta compared to the provided expected probability.
79 | void checkGeneratorMatchesExpectation( 80 | const int batch_size, 81 | const std::vector& expected_probability, 82 | const double delta = 1e-4) { 83 | const auto histogram = 84 | this->computeHistogram(this->generateIndices(batch_size)); 85 | EXPECT_EQ(histogram.size(), expected_probability.size()); 86 | 87 | for (size_t i = 0; i < histogram.size(); i++) { 88 | if (i == 0) { 89 | // Category feature 0 is reserved. 90 | EXPECT_EQ(histogram[i], 0); 91 | } else { 92 | EXPECT_NEAR( 93 | static_cast(histogram[i]) / static_cast(batch_size), 94 | expected_probability[i], 95 | delta); 96 | } 97 | } 98 | } 99 | 100 | private: 101 | std::unique_ptr> generator_; 102 | int num_categories_; 103 | int num_hot_; 104 | }; 105 | 106 | TYPED_TEST_SUITE_P(DataGeneratorTest); 107 | 108 | // Test that psx power law generator follows the analytical distribution. 109 | TYPED_TEST_P(DataGeneratorTest, OneHotPsxPowerLawGenerationWorks) { 110 | using IndexType = typename TestFixture::IndexType; 111 | const int kNumCategories = 9; 112 | const int kNumHot = 1; 113 | const int kBatchSize = 4000000; 114 | const double kAlpha = 1.15; 115 | std::vector expected_probability(kNumCategories + 1, 0); 116 | 117 | double sum = 0; 118 | for (size_t i = 1; i < expected_probability.size(); i++) { 119 | // f(x) = x^(-alpha); its integral over [i, i + 1] is proportional to 120 | // i^(1 - alpha) - (i + 1)^(1 - alpha), and so is the expression below; the constant factor cancels in the normalization. 121 | expected_probability[i] = 122 | (-kAlpha) * pow(static_cast(i), (1 - kAlpha)) - 123 | (-kAlpha) * pow(static_cast(i + 1), (1 - kAlpha)); 124 | sum += expected_probability[i]; 125 | } 126 | // Normalize the probability distribution. 127 | for (auto& prob : expected_probability) { 128 | prob /= sum; 129 | } 130 | 131 | this->setUpParams(kNumCategories, kNumHot); 132 | this->resetGenerator( 133 | new index_generators::PowerLawFeatureGenerator( 134 | kNumCategories, kNumHot, kAlpha)); 135 | 136 | const double delta = 1e-3; // allow 0.1% off from expected.
137 | this->checkGeneratorMatchesExpectation( 138 | kBatchSize, expected_probability, delta); 139 | } 140 | 141 | // Check that multi-hot generators do not have repetitions. 142 | // Check that all indices generated are within range [1, num_categories]. 143 | TYPED_TEST_P(DataGeneratorTest, MultiHotGenerationWorks) { 144 | using IndexType = typename TestFixture::IndexType; 145 | const int kNumCategories = 1000; 146 | const int kNumHot = 64; 147 | const int kBatchSize = 40000; 148 | const double kAlpha = 1.15; 149 | 150 | std::vector*> generators({ 151 | new index_generators::PowerLawFeatureGenerator( 152 | kNumCategories, kNumHot, kAlpha, false, false, PowerLawType::kPsx), 153 | }); 154 | 155 | for (auto& generator : generators) { 156 | this->setUpParams(kNumCategories, kNumHot); 157 | this->resetGenerator(generator); 158 | this->sanityCheckIndices(this->generateIndices(kBatchSize)); 159 | } 160 | } 161 | 162 | REGISTER_TYPED_TEST_SUITE_P(DataGeneratorTest, 163 | OneHotPsxPowerLawGenerationWorks, 164 | MultiHotGenerationWorks); 165 | 166 | using Types = ::testing::Types; 167 | INSTANTIATE_TYPED_TEST_SUITE_P(DataGen, DataGeneratorTest, Types); 168 | 169 | } // namespace index_generators 170 | } // namespace cuembed 171 | -------------------------------------------------------------------------------- /tests/test_embedding_against_cpu.cu: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | // clang-format on 19 | 20 | #include 21 | #include 22 | 23 | #include "absl/log/check.h" 24 | #include "absl/strings/str_format.h" 25 | #include "cuembed/include/embedding_lookup.cuh" 26 | #include "gtest/gtest.h" 27 | #include "utils/include/embedding_allocation.h" 28 | #include "utils/include/embedding_utils.h" 29 | 30 | // CPU reference implementations 31 | #include "utils/include/embedding_lookup_cpu.hpp" 32 | 33 | namespace cuembed { 34 | 35 | template 36 | class TestAgainstCpuRef 37 | : public ::testing::TestWithParam { 38 | protected: 39 | TestAgainstCpuRef() { 40 | options_ = GetParam(); 41 | AllocateHost(options_, &u_a); 42 | AllocateDevice(options_, u_a, &d_a); 43 | cuembed::utils::RunForwardReference( 44 | options_, 45 | u_a.embedding, 46 | u_a.indices, 47 | u_a.offsets, 48 | u_a.weights, 49 | &u_a.result); 50 | 51 | const int nnz = u_a.indices.size(); 52 | cuembed::utils::RunTransposeReference( 53 | options_, 54 | u_a.indices, 55 | u_a.offsets, 56 | u_a.weights, 57 | nnz, 58 | &u_a.transpose_indices, 59 | &u_a.transpose_remapped_indices, 60 | &u_a.transpose_sample_ids, 61 | &u_a.transpose_weights); 62 | 63 | cuembed::utils::RunBackwardReference( 64 | options_, 65 | u_a.grad_y, 66 | u_a.transpose_indices, 67 | u_a.transpose_remapped_indices, 68 | u_a.transpose_sample_ids, 69 | u_a.transpose_weights, 70 | u_a.offsets, 71 | nnz, 72 | &u_a.grad_embedding, 73 | &u_a.inverse_mapping); 74 | } 75 | 76 | void RunTestForward() { 77 | cuembed::utils::RunForward(options_, 78 | d_a.embedding, 79 | d_a.indices, 
80 | d_a.offsets, 81 | d_a.weights, 82 | &d_a.result); 83 | CHECK_CUDA(cudaDeviceSynchronize()); 84 | CheckResultForward(); 85 | } 86 | 87 | void RunTestTranspose() { 88 | const int nnz = d_a.indices.size(); 89 | cuembed::utils::RunTranspose( 90 | options_, 91 | d_a.indices, 92 | d_a.offsets, 93 | d_a.weights, 94 | nnz, 95 | &d_a.transpose_indices, 96 | &d_a.transpose_remapped_indices, 97 | &d_a.transpose_sample_ids, 98 | &d_a.transpose_weights, 99 | &d_a.sample_ids, 100 | &d_a.transpose_workspace); 101 | CheckResultTranspose(); 102 | } 103 | 104 | void RunTestBackward() { 105 | const int nnz = d_a.indices.size(); 106 | cuembed::utils::RunTranspose( 107 | options_, 108 | d_a.indices, 109 | d_a.offsets, 110 | d_a.weights, 111 | nnz, 112 | &d_a.transpose_indices, 113 | &d_a.transpose_remapped_indices, 114 | &d_a.transpose_sample_ids, 115 | &d_a.transpose_weights, 116 | &d_a.sample_ids, 117 | &d_a.transpose_workspace); 118 | 119 | const int num_unique = options_.compressed_grad() 120 | ? d_a.transpose_remapped_indices.back() + 1 121 | : 0; 122 | cuembed::utils::RunBackward( 123 | options_, 124 | d_a.grad_y, 125 | d_a.transpose_indices, 126 | d_a.transpose_remapped_indices, 127 | d_a.transpose_sample_ids, 128 | d_a.transpose_weights, 129 | d_a.offsets, 130 | nnz, 131 | num_unique, 132 | &d_a.grad_embedding, 133 | &d_a.inverse_mapping); 134 | CheckResultBackward(); 135 | } 136 | 137 | private: 138 | void CheckResultForward() { 139 | if (options_.combine_mode() == CombineMode::kSum) { 140 | EXPECT_EQ(d_a.result.size(), 141 | options_.batch_size() * options_.embed_width()); 142 | } else if (options_.combine_mode() == CombineMode::kConcat) { 143 | EXPECT_EQ( 144 | d_a.result.size(), 145 | options_.batch_size() * options_.hotness() * options_.embed_width()); 146 | } else if (options_.combine_mode() == CombineMode::kMean) { 147 | EXPECT_EQ(d_a.result.size(), 148 | options_.batch_size() * options_.embed_width()); 149 | } else { 150 | EXPECT_TRUE(false) << "Reduce type not 
supported"; 151 | } 152 | 153 | // Weighted summation on host vs. device may not be exact. 154 | if (options_.is_weighted()) { 155 | const float tolerance = 1e-4f; 156 | EXPECT_TRUE(thrust::equal(d_a.result.begin(), 157 | d_a.result.end(), 158 | u_a.result.begin(), 159 | cuembed::utils::Near(tolerance))); 160 | } else { 161 | EXPECT_TRUE(thrust::equal( 162 | d_a.result.begin(), d_a.result.end(), u_a.result.begin())); 163 | } 164 | } 165 | 166 | void CheckResultTranspose() { 167 | EXPECT_TRUE(thrust::equal(d_a.transpose_indices.begin(), 168 | d_a.transpose_indices.end(), 169 | u_a.transpose_indices.begin())); 170 | EXPECT_TRUE(thrust::equal(d_a.transpose_remapped_indices.begin(), 171 | d_a.transpose_remapped_indices.end(), 172 | u_a.transpose_remapped_indices.begin())); 173 | 174 | // Check that sample_ids and weights sum to the same integer values 175 | // This allows transpose to order sample ids differently within an index 176 | IndexT d_sum = 0; 177 | IndexT ref_sum = 0; 178 | int64_t wt_sum = 0; 179 | int64_t ref_wt_sum = 0; 180 | for (int i = 0; i < u_a.transpose_sample_ids.size(); i++) { 181 | if (i > 0 && (u_a.transpose_indices[i - 1] != u_a.transpose_indices[i])) { 182 | EXPECT_TRUE(d_sum == ref_sum); 183 | d_sum = 0; 184 | ref_sum = 0; 185 | 186 | if (options_.is_weighted()) { 187 | EXPECT_TRUE(wt_sum == ref_wt_sum); 188 | wt_sum = 0; 189 | ref_wt_sum = 0; 190 | } 191 | } 192 | 193 | d_sum += d_a.transpose_sample_ids[i]; 194 | ref_sum += u_a.transpose_sample_ids[i]; 195 | 196 | if (options_.is_weighted()) { 197 | ElemT wt = d_a.transpose_weights[i]; 198 | ElemT ref_wt = u_a.transpose_weights[i]; 199 | int64_t wt_int = 0; 200 | int64_t ref_wt_int = 0; 201 | memcpy(&wt_int, &wt, std::min(sizeof(int64_t), sizeof(ElemT))); 202 | memcpy(&ref_wt_int, &ref_wt, std::min(sizeof(int64_t), sizeof(ElemT))); 203 | wt_sum += wt_int; 204 | ref_wt_sum += ref_wt_int; 205 | } 206 | } 207 | } 208 | 209 | void CheckResultBackward() { 210 | 
EXPECT_TRUE(thrust::equal(d_a.grad_embedding.begin(), 211 | d_a.grad_embedding.end(), 212 | u_a.grad_embedding.begin())); 213 | if (options_.compressed_grad()) { 214 | EXPECT_TRUE(thrust::equal(d_a.inverse_mapping.begin(), 215 | d_a.inverse_mapping.end(), 216 | u_a.inverse_mapping.begin())); 217 | } 218 | } 219 | 220 | utils::UniversalEmbeddingAllocation 221 | u_a; 222 | utils::DeviceEmbeddingAllocation d_a; 223 | utils::AllocationOptions options_; 224 | }; 225 | 226 | // Some macros to make test allocation statement more concise. 227 | #define ALLOC \ 228 | (utils::AllocationOptions().num_categories(20_K).skip_grad_init(false)) 229 | #define SUM CombineMode::kSum 230 | #define CONCAT CombineMode::kConcat 231 | #define AVG CombineMode::kMean 232 | #define CSR is_csr(true) 233 | #define WEIGHTED is_weighted(true) 234 | #define CMPGRAD compressed_grad(true) 235 | #define BS batch_size 236 | auto lookup_test_values = testing::Values( 237 | ALLOC.BS(3).embed_width(2).hotness(4).combine_mode(SUM), 238 | ALLOC.BS(3).embed_width(2).hotness(4).combine_mode(SUM).CSR, 239 | ALLOC.BS(3).embed_width(2).hotness(4).combine_mode(SUM).WEIGHTED, 240 | ALLOC.BS(3).embed_width(2).hotness(4).combine_mode(SUM).CSR.WEIGHTED, 241 | ALLOC.BS(3).embed_width(2).hotness(4).combine_mode(AVG), 242 | ALLOC.BS(3).embed_width(2).hotness(4).combine_mode(AVG).CSR, 243 | ALLOC.BS(3).embed_width(2).hotness(4).combine_mode(CONCAT), 244 | ALLOC.BS(3).embed_width(2).hotness(4).combine_mode(SUM).CMPGRAD, 245 | ALLOC.BS(3).embed_width(2).hotness(4).combine_mode(CONCAT).CMPGRAD, 246 | ALLOC.BS(3).embed_width(4).hotness(4).combine_mode(SUM), 247 | ALLOC.BS(3).embed_width(4).hotness(4).combine_mode(SUM).CSR, 248 | ALLOC.BS(3).embed_width(4).hotness(4).combine_mode(AVG), 249 | ALLOC.BS(3).embed_width(4).hotness(4).combine_mode(AVG).CSR, 250 | ALLOC.BS(3).embed_width(4).hotness(4).combine_mode(SUM).WEIGHTED, 251 | ALLOC.BS(3).embed_width(4).hotness(4).combine_mode(SUM).CSR.WEIGHTED, 252 | 
ALLOC.BS(3).embed_width(4).hotness(4).combine_mode(CONCAT), 253 | ALLOC.BS(3).embed_width(4).hotness(4).combine_mode(SUM).CMPGRAD, 254 | ALLOC.BS(3).embed_width(4).hotness(4).combine_mode(CONCAT).CMPGRAD, 255 | ALLOC.BS(1023).embed_width(32).hotness(26).combine_mode(SUM), 256 | ALLOC.BS(1023).embed_width(32).hotness(26).combine_mode(SUM).CSR, 257 | ALLOC.BS(1023).embed_width(32).hotness(26).combine_mode(SUM).WEIGHTED, 258 | ALLOC.BS(1023).embed_width(32).hotness(26).combine_mode(SUM).CSR.WEIGHTED, 259 | ALLOC.BS(1023).embed_width(32).hotness(26).combine_mode(AVG), 260 | ALLOC.BS(1023).embed_width(32).hotness(26).combine_mode(AVG).CSR, 261 | ALLOC.BS(1023).embed_width(32).hotness(26).combine_mode(CONCAT), 262 | ALLOC.BS(1023).embed_width(32).hotness(26).combine_mode(SUM).CMPGRAD, 263 | ALLOC.BS(1023).embed_width(32).hotness(26).combine_mode(CONCAT).CMPGRAD, 264 | ALLOC.BS(1023).embed_width(36).hotness(26).combine_mode(SUM), 265 | ALLOC.BS(1023).embed_width(36).hotness(26).combine_mode(SUM).CSR, 266 | ALLOC.BS(1023).embed_width(36).hotness(26).combine_mode(SUM).WEIGHTED, 267 | ALLOC.BS(1023).embed_width(36).hotness(26).combine_mode(SUM).CSR.WEIGHTED, 268 | ALLOC.BS(1023).embed_width(36).hotness(26).combine_mode(AVG), 269 | ALLOC.BS(1023).embed_width(36).hotness(26).combine_mode(AVG).CSR, 270 | ALLOC.BS(1023).embed_width(36).hotness(26).combine_mode(CONCAT), 271 | ALLOC.BS(1023).embed_width(36).hotness(26).combine_mode(SUM).CMPGRAD, 272 | ALLOC.BS(1023).embed_width(36).hotness(26).combine_mode(CONCAT).CMPGRAD, 273 | ALLOC.BS(3).embed_width(512).hotness(63).combine_mode(SUM), 274 | ALLOC.BS(3).embed_width(512).hotness(63).combine_mode(SUM).CSR, 275 | ALLOC.BS(3).embed_width(512).hotness(63).combine_mode(SUM).WEIGHTED, 276 | ALLOC.BS(3).embed_width(512).hotness(63).combine_mode(SUM).CSR.WEIGHTED, 277 | ALLOC.BS(3).embed_width(512).hotness(63).combine_mode(AVG), 278 | ALLOC.BS(3).embed_width(512).hotness(63).combine_mode(AVG).CSR, 279 | 
ALLOC.BS(3).embed_width(512).hotness(63).combine_mode(CONCAT), 280 | ALLOC.BS(3).embed_width(512).hotness(63).combine_mode(SUM).CMPGRAD, 281 | ALLOC.BS(3).embed_width(512).hotness(63).combine_mode(CONCAT).CMPGRAD, 282 | ALLOC.BS(1023).embed_width(512).hotness(63).combine_mode(SUM), 283 | ALLOC.BS(1023).embed_width(512).hotness(63).combine_mode(SUM).CSR, 284 | ALLOC.BS(1023).embed_width(512).hotness(63).combine_mode(SUM).WEIGHTED, 285 | ALLOC.BS(1023).embed_width(512).hotness(63).combine_mode(SUM).CSR.WEIGHTED, 286 | ALLOC.BS(1023).embed_width(512).hotness(63).combine_mode(CONCAT), 287 | ALLOC.BS(1023).embed_width(514).hotness(63).combine_mode(SUM), 288 | ALLOC.BS(1023).embed_width(514).hotness(63).combine_mode(SUM).CSR, 289 | ALLOC.BS(1023).embed_width(514).hotness(63).combine_mode(SUM).WEIGHTED, 290 | ALLOC.BS(1023).embed_width(514).hotness(63).combine_mode(SUM).CSR.WEIGHTED, 291 | ALLOC.BS(1023).embed_width(514).hotness(63).combine_mode(CONCAT), 292 | ALLOC.BS(1023).embed_width(514).hotness(63).combine_mode(SUM).CMPGRAD, 293 | ALLOC.BS(1023).embed_width(514).hotness(63).combine_mode(CONCAT).CMPGRAD); 294 | #undef ALLOC 295 | #undef SUM 296 | #undef CONCAT 297 | #undef AVG 298 | #undef CSR 299 | #undef WEIGHTED 300 | #undef CMPGRAD 301 | #undef BS 302 | 303 | class TestAgainstCpuEmbed32Idx32 304 | : public TestAgainstCpuRef {}; 305 | class TestAgainstCpuEmbed32Idx64 306 | : public TestAgainstCpuRef {}; 307 | class TestAgainstCpuEmbed16Idx32Reduce32 308 | : public TestAgainstCpuRef<__half, int32_t, true> {}; 309 | class TestAgainstCpuEmbed16Idx64Reduce32 310 | : public TestAgainstCpuRef<__half, int64_t, true> {}; 311 | class TestAgainstCpuEmbed16Idx32Reduce16 312 | : public TestAgainstCpuRef<__half, int32_t, false> {}; 313 | class TestAgainstCpuEmbed16Idx64Reduce16 314 | : public TestAgainstCpuRef<__half, int64_t, false> {}; 315 | 316 | std::string ToString(CombineMode value) { 317 | switch (value) { 318 | case CombineMode::kSum: 319 | return "Sum"; 320 | case 
CombineMode::kConcat: 321 | return "Concat"; 322 | case CombineMode::kMean: 323 | return "Mean"; 324 | default: 325 | return "Unknown"; 326 | } 327 | } 328 | 329 | std::string GenerateTestName(const utils::AllocationOptions& options) { 330 | std::string result = absl::StrFormat( 331 | "Width%dBatch%dHot%d%s%s%s%s", 332 | options.embed_width(), 333 | options.batch_size(), 334 | options.hotness(), 335 | ToString(options.combine_mode()), 336 | options.is_csr() ? "CSR" : "FixedHot", 337 | options.is_weighted() ? "Weight" : "NoWeight", 338 | options.compressed_grad() ? "CompressedGrad" : "FullGrad"); 339 | return result; 340 | } 341 | 342 | INSTANTIATE_TEST_SUITE_P( 343 | EmbeddingLookup, 344 | TestAgainstCpuEmbed32Idx32, 345 | lookup_test_values, 346 | [](const testing::TestParamInfo& 347 | info) { return GenerateTestName(info.param); }); 348 | 349 | INSTANTIATE_TEST_SUITE_P( 350 | EmbeddingLookup, 351 | TestAgainstCpuEmbed32Idx64, 352 | lookup_test_values, 353 | [](const testing::TestParamInfo& 354 | info) { return GenerateTestName(info.param); }); 355 | 356 | INSTANTIATE_TEST_SUITE_P( 357 | EmbeddingLookup, 358 | TestAgainstCpuEmbed16Idx32Reduce16, 359 | lookup_test_values, 360 | [](const testing::TestParamInfo< 361 | TestAgainstCpuEmbed16Idx32Reduce16::ParamType>& info) { 362 | return GenerateTestName(info.param); 363 | }); 364 | 365 | INSTANTIATE_TEST_SUITE_P( 366 | EmbeddingLookup, 367 | TestAgainstCpuEmbed16Idx64Reduce16, 368 | lookup_test_values, 369 | [](const testing::TestParamInfo< 370 | TestAgainstCpuEmbed16Idx64Reduce16::ParamType>& info) { 371 | return GenerateTestName(info.param); 372 | }); 373 | 374 | INSTANTIATE_TEST_SUITE_P( 375 | EmbeddingLookup, 376 | TestAgainstCpuEmbed16Idx32Reduce32, 377 | lookup_test_values, 378 | [](const testing::TestParamInfo< 379 | TestAgainstCpuEmbed16Idx32Reduce32::ParamType>& info) { 380 | return GenerateTestName(info.param); 381 | }); 382 | 383 | INSTANTIATE_TEST_SUITE_P( 384 | EmbeddingLookup, 385 | 
TestAgainstCpuEmbed16Idx64Reduce32, 386 | lookup_test_values, 387 | [](const testing::TestParamInfo< 388 | TestAgainstCpuEmbed16Idx64Reduce32::ParamType>& info) { 389 | return GenerateTestName(info.param); 390 | }); 391 | 392 | TEST_P(TestAgainstCpuEmbed32Idx32, TestForward) { RunTestForward(); } 393 | TEST_P(TestAgainstCpuEmbed32Idx64, TestForward) { RunTestForward(); } 394 | TEST_P(TestAgainstCpuEmbed16Idx32Reduce16, TestForward) { RunTestForward(); } 395 | TEST_P(TestAgainstCpuEmbed16Idx64Reduce16, TestForward) { RunTestForward(); } 396 | TEST_P(TestAgainstCpuEmbed16Idx32Reduce32, TestForward) { RunTestForward(); } 397 | TEST_P(TestAgainstCpuEmbed16Idx64Reduce32, TestForward) { RunTestForward(); } 398 | 399 | TEST_P(TestAgainstCpuEmbed32Idx32, TestTranspose) { RunTestTranspose(); } 400 | TEST_P(TestAgainstCpuEmbed32Idx64, TestTranspose) { RunTestTranspose(); } 401 | TEST_P(TestAgainstCpuEmbed16Idx32Reduce16, TestTranspose) { 402 | RunTestTranspose(); 403 | } 404 | TEST_P(TestAgainstCpuEmbed16Idx64Reduce16, TestTranspose) { 405 | RunTestTranspose(); 406 | } 407 | TEST_P(TestAgainstCpuEmbed16Idx32Reduce32, TestTranspose) { 408 | RunTestTranspose(); 409 | } 410 | TEST_P(TestAgainstCpuEmbed16Idx64Reduce32, TestTranspose) { 411 | RunTestTranspose(); 412 | } 413 | 414 | TEST_P(TestAgainstCpuEmbed32Idx32, TestBackward) { RunTestBackward(); } 415 | TEST_P(TestAgainstCpuEmbed32Idx64, TestBackward) { RunTestBackward(); } 416 | TEST_P(TestAgainstCpuEmbed16Idx32Reduce16, TestBackward) { RunTestBackward(); } 417 | TEST_P(TestAgainstCpuEmbed16Idx64Reduce16, TestBackward) { RunTestBackward(); } 418 | TEST_P(TestAgainstCpuEmbed16Idx32Reduce32, TestBackward) { RunTestBackward(); } 419 | TEST_P(TestAgainstCpuEmbed16Idx64Reduce32, TestBackward) { RunTestBackward(); } 420 | } // namespace cuembed 421 | -------------------------------------------------------------------------------- /tests/test_embedding_allocation.cu: 
-------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | // clang-format on 19 | 20 | #include "gtest/gtest.h" 21 | #include "utils/include/embedding_allocation.h" 22 | #include "utils/include/embedding_utils.h" 23 | 24 | namespace cuembed { 25 | 26 | namespace utils { 27 | 28 | class EmbeddingAllocationTest : public ::testing::TestWithParam { 29 | public: 30 | EmbeddingAllocationTest() { 31 | options_.num_categories(kNumCategories) 32 | .batch_size(kBatchSize) 33 | .hotness(kHotness) 34 | .alpha(kAlpha) 35 | .embed_width(kWidth); 36 | } 37 | void FinishSetup(const CombineMode combine_mode, 38 | const bool is_csr, 39 | const bool compressed_grad) { 40 | options_.combine_mode(combine_mode); 41 | options_.is_csr(is_csr); 42 | options_.compressed_grad(compressed_grad); 43 | AllocateHost(options_, &u_a); 44 | } 45 | 46 | void RunTest(const CombineMode combine_mode, 47 | const bool is_csr, 48 | const bool compressed_grad) { 49 | FinishSetup(combine_mode, is_csr, compressed_grad); 50 | ValidateOptions(combine_mode, is_csr, compressed_grad); 51 | ValidateAllocations(combine_mode, is_csr, compressed_grad); 52 | ValidateIndices(is_csr); 53 | ValidateWeights(); 54 | } 55 | 56 | 
private: 57 | void ValidateOptions(const CombineMode combine_mode, 58 | const bool is_csr, 59 | const bool compressed_grad) { 60 | EXPECT_EQ(options_.num_categories(), kNumCategories); 61 | EXPECT_EQ(options_.batch_size(), kBatchSize); 62 | EXPECT_EQ(options_.alpha(), kAlpha); 63 | EXPECT_EQ(options_.embed_width(), kWidth); 64 | EXPECT_EQ(options_.combine_mode(), combine_mode); 65 | EXPECT_EQ(options_.is_csr(), is_csr); 66 | EXPECT_EQ(options_.compressed_grad(), compressed_grad); 67 | } 68 | 69 | void ValidateAllocations(const CombineMode combine_mode, 70 | const bool is_csr, 71 | const bool compressed_grad) { 72 | EXPECT_EQ(u_a.embedding.size(), kNumCategories * kWidth); 73 | if (compressed_grad) { 74 | EXPECT_LE(u_a.grad_embedding.size(), u_a.indices.size() * kWidth); 75 | EXPECT_LE(u_a.inverse_mapping.size(), u_a.indices.size()); 76 | } else { 77 | EXPECT_EQ(u_a.grad_embedding.size(), kNumCategories * kWidth); 78 | EXPECT_EQ(u_a.inverse_mapping.size(), 0); 79 | } 80 | if (is_csr) { 81 | EXPECT_LE(u_a.indices.size(), kBatchSize * kHotness); 82 | EXPECT_EQ(u_a.indices.size(), u_a.offsets.back()); 83 | EXPECT_EQ(u_a.transpose_remapped_indices.size(), u_a.offsets.back()); 84 | EXPECT_EQ(u_a.transpose_sample_ids.size(), u_a.offsets.back()); 85 | EXPECT_EQ(u_a.result.size(), kBatchSize * kWidth); 86 | } else { 87 | EXPECT_EQ(u_a.indices.size(), kBatchSize * kHotness); 88 | EXPECT_EQ(u_a.transpose_indices.size(), kBatchSize * kHotness); 89 | EXPECT_EQ(u_a.transpose_remapped_indices.size(), kBatchSize * kHotness); 90 | EXPECT_EQ(u_a.transpose_sample_ids.size(), kBatchSize * kHotness); 91 | if (combine_mode == CombineMode::kConcat) { 92 | EXPECT_EQ(u_a.result.size(), kBatchSize * kWidth * kHotness); 93 | EXPECT_EQ(u_a.grad_y.size(), kBatchSize * kWidth * kHotness); 94 | } else { 95 | EXPECT_EQ(u_a.result.size(), kBatchSize * kWidth); 96 | EXPECT_EQ(u_a.grad_y.size(), kBatchSize * kWidth); 97 | } 98 | } 99 | } 100 | 101 | void ValidateIndices(const bool is_csr) { 102 | 
for (int sample_id = 0; sample_id < kBatchSize; sample_id++) { 103 | const int hotness_for_sample = 104 | is_csr ? (u_a.offsets[sample_id + 1] - u_a.offsets[sample_id]) 105 | : options_.hotness(); 106 | const int index_start = 107 | is_csr ? u_a.offsets[sample_id] : sample_id * options_.hotness(); 108 | EXPECT_GE(hotness_for_sample, 0); 109 | 110 | // Check for no repetitions for a sample. 111 | std::set used_indices; 112 | for (int i_hotness = 0; i_hotness < hotness_for_sample; i_hotness++) { 113 | auto index = u_a.indices[index_start + i_hotness]; 114 | EXPECT_TRUE(used_indices.find(index) == used_indices.end()); 115 | used_indices.insert(index); 116 | } 117 | } 118 | } 119 | 120 | void ValidateWeights() { 121 | EXPECT_EQ(u_a.weights.size(), u_a.indices.size()); 122 | for (auto weight : u_a.weights) { 123 | EXPECT_GE(weight, 0.0f); 124 | EXPECT_LE(weight, 1.0f); 125 | } 126 | } 127 | 128 | AllocationOptions options_; 129 | UniversalEmbeddingAllocation u_a; 130 | 131 | const int32_t kNumCategories = 7_M; 132 | const int32_t kBatchSize = 65_K; 133 | const int32_t kHotness = 32; 134 | const float kAlpha = 1.5f; 135 | const int32_t kWidth = 16; 136 | }; 137 | 138 | TEST_P(EmbeddingAllocationTest, FixedHotnessWorks) { 139 | auto combine_mode = GetParam(); 140 | const bool is_csr = false; 141 | const bool compressed_grad = false; 142 | RunTest(combine_mode, is_csr, compressed_grad); 143 | } 144 | 145 | TEST_P(EmbeddingAllocationTest, CSRWorks) { 146 | auto combine_mode = GetParam(); 147 | if (combine_mode == CombineMode::kConcat) { 148 | return; 149 | } 150 | const bool is_csr = true; 151 | const bool compressed_grad = false; 152 | RunTest(combine_mode, is_csr, compressed_grad); 153 | } 154 | 155 | TEST_P(EmbeddingAllocationTest, SparseGradientWorks) { 156 | auto combine_mode = GetParam(); 157 | const bool is_csr = false; 158 | const bool compressed_grad = false; 159 | RunTest(combine_mode, is_csr, compressed_grad); 160 | } 161 | 162 | INSTANTIATE_TEST_SUITE_P( 163 | 
EmbeddingAllocationTest, 164 | EmbeddingAllocationTest, 165 | ::testing::Values( 166 | CombineMode::kConcat, 167 | CombineMode::kSum) // This runs the test for each of these values 168 | ); 169 | 170 | } // namespace utils 171 | } // namespace cuembed 172 | -------------------------------------------------------------------------------- /tests/test_embedding_backward.cu: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | // clang-format on 19 | 20 | #include 21 | #include 22 | 23 | #include "absl/log/check.h" 24 | #include "absl/strings/str_format.h" 25 | #include "cuembed/include/embedding_lookup.cuh" 26 | #include "cuembed/include/index_transforms.cuh" 27 | #include "gtest/gtest.h" 28 | #include "utils/include/embedding_allocation.h" 29 | #include "utils/include/embedding_utils.h" 30 | 31 | // CPU reference implementations 32 | #include "utils/include/embedding_lookup_cpu.hpp" 33 | #include "utils/include/index_transforms_cpu.hpp" 34 | 35 | namespace cuembed { 36 | 37 | enum class DeviceType { kGPU, kCPU }; 38 | 39 | template 40 | class EmbeddingBackwardRefTest : public ::testing::Test { 41 | public: 42 | void LaunchTest(const CombineMode mode, 43 | const DeviceType device, 44 | const bool is_weighted, 45 | const bool compressed_grad, 46 | const bool skip_grad_init = false) { 47 | this->CreateResult(compressed_grad, skip_grad_init); 48 | this->LaunchKernel( 49 | mode, device, is_weighted, compressed_grad, skip_grad_init); 50 | this->CheckResult(mode, is_weighted, compressed_grad, skip_grad_init); 51 | } 52 | 53 | typedef typename T::EmbedType EmbedType; 54 | typedef typename T::IndexType IndexType; 55 | typedef int OffsetType; 56 | 57 | private: 58 | void CreateResult(const bool compressed_grad, const bool skip_grad_init) { 59 | grad_embedding_.reset(new thrust::universal_vector( 60 | num_categories_ * embed_width_)); 61 | // If kernel doesn't initialize the gradient then do it here 62 | if (skip_grad_init) { 63 | thrust::fill(grad_embedding_->begin(), grad_embedding_->end(), 0); 64 | } 65 | if (compressed_grad) { 66 | inverse_mapping_.reset( 67 | new thrust::universal_vector(num_categories_)); 68 | } 69 | } 70 | void LaunchKernel(const CombineMode mode, 71 | const DeviceType device, 72 | const bool is_weighted, 73 | const bool compressed_grad, 74 | const bool skip_grad_init) { 75 | const EmbedType* transpose_weights = 76 | (!is_weighted) ? 
nullptr : transpose_weights_.data().get(); 77 | const thrust::universal_vector* grad_y = nullptr; 78 | const thrust::universal_vector* transpose_sample_ids = nullptr; 79 | if (mode == CombineMode::kSum || mode == CombineMode::kMean) { 80 | grad_y = &grad_y_sum_; 81 | transpose_sample_ids = &transpose_sample_ids_; 82 | } else if (mode == CombineMode::kConcat) { 83 | grad_y = &grad_y_concat_; 84 | transpose_sample_ids = &transpose_sample_ids_concat_; 85 | } 86 | const IndexType* transpose_remapped_indices = 87 | (compressed_grad) ? transpose_remapped_indices_.data().get() : nullptr; 88 | const int num_grad_embedding_rows = 89 | (compressed_grad) ? num_unique_ : num_categories_; 90 | IndexType* inverse_mapping = 91 | (compressed_grad) ? inverse_mapping_->data().get() : nullptr; 92 | if (device == DeviceType::kCPU) { 93 | EmbeddingBackwardCpu( 94 | grad_y->data().get(), 95 | embed_width_, 96 | num_grad_embedding_rows, 97 | nnz_, 98 | transpose_indices_.data().get(), 99 | transpose_sample_ids->data().get(), 100 | transpose_remapped_indices, 101 | transpose_weights, 102 | skip_grad_init, 103 | grad_embedding_->data().get(), 104 | inverse_mapping); 105 | } else if (device == DeviceType::kGPU) { 106 | EmbeddingBackward( 107 | grad_y->data().get(), 108 | embed_width_, 109 | num_grad_embedding_rows, 110 | nnz_, 111 | transpose_indices_.data().get(), 112 | transpose_sample_ids->data().get(), 113 | transpose_remapped_indices, 114 | transpose_weights, 115 | skip_grad_init, 116 | grad_embedding_->data().get(), 117 | inverse_mapping); 118 | CHECK_CUDA(cudaDeviceSynchronize()); 119 | } else { 120 | CHECK(false) << "not supported device"; 121 | } 122 | } 123 | void CheckResult(const CombineMode mode, 124 | const bool is_weighted, 125 | const bool compressed_grad, 126 | const bool skip_grad_init) { 127 | const thrust::universal_vector* reference = nullptr; 128 | if (compressed_grad) { 129 | if (mode == CombineMode::kSum && is_weighted) { 130 | reference = 
&ref_compressed_grad_embedding_sum_weighted_; 131 | } else if (mode == CombineMode::kSum && (!is_weighted)) { 132 | reference = &ref_compressed_grad_embedding_sum_; 133 | } else if (mode == CombineMode::kConcat && (!is_weighted)) { 134 | reference = &ref_compressed_grad_embedding_concat_; 135 | } else if (mode == CombineMode::kConcat && is_weighted) { 136 | reference = &ref_compressed_grad_embedding_concat_weighted_; 137 | } 138 | } else { 139 | if (mode == CombineMode::kSum && is_weighted) { 140 | reference = &ref_grad_embedding_sum_weighted_; 141 | } else if (mode == CombineMode::kSum && (!is_weighted)) { 142 | reference = &ref_grad_embedding_sum_; 143 | } else if (mode == CombineMode::kConcat && (!is_weighted)) { 144 | reference = &ref_grad_embedding_concat_; 145 | } else if (mode == CombineMode::kConcat && is_weighted) { 146 | reference = &ref_grad_embedding_concat_weighted_; 147 | } 148 | } 149 | if (reference != nullptr) { 150 | for (size_t i = 0; i < grad_embedding_->size(); ++i) { 151 | EXPECT_EQ(grad_embedding_->data()[i], reference->data()[i]); 152 | } 153 | } 154 | if (compressed_grad) { 155 | for (size_t i = 0; i < num_unique_; ++i) { 156 | EXPECT_EQ(inverse_mapping_->data()[i], ref_inverse_mapping_.data()[i]); 157 | } 158 | } 159 | } 160 | 161 | private: 162 | const int embed_width_{4}; 163 | const int hotness_{2}; 164 | const int num_categories_{5}; 165 | const int batch_size_{2}; 166 | const int nnz_{4}; 167 | const int num_unique_{3}; 168 | const thrust::universal_vector transpose_indices_{0, 1, 3, 3}; 169 | const thrust::universal_vector transpose_remapped_indices_{ 170 | 0, 1, 2, 2}; 171 | const thrust::universal_vector transpose_sample_ids_{1, 0, 0, 1}; 172 | const thrust::universal_vector transpose_sample_ids_concat_{ 173 | 2, 0, 1, 3}; 174 | const thrust::universal_vector transpose_weights_{ 175 | 3.f, 1.f, 0.5f, 3.f}; 176 | const thrust::universal_vector grad_y_sum_{ 177 | 1., 2., 3., 4., 5., 6., 7., 8.}; 178 | const thrust::universal_vector 
grad_y_concat_{ 179 | 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16.}; 180 | thrust::universal_vector ref_grad_embedding_sum_{ 181 | 5., 6., 7., 8., 1., 2., 3., 4., 0., 0., 182 | 0., 0., 6., 8., 10., 12., 0., 0., 0., 0.}; 183 | thrust::universal_vector ref_grad_embedding_sum_weighted_{ 184 | 15., 18., 21., 24., 1., 2., 3., 4., 0., 0., 185 | 0., 0., 15.5, 19., 22.5, 26., 0., 0., 0., 0.}; 186 | thrust::universal_vector ref_grad_embedding_concat_{ 187 | 9., 10., 11., 12., 1., 2., 3., 4., 0., 0., 188 | 0., 0., 18., 20., 22., 24., 0., 0., 0., 0.}; 189 | thrust::universal_vector ref_grad_embedding_concat_weighted_{ 190 | 27., 30., 33., 36., 1., 2., 3., 4., 0., 0., 191 | 0., 0., 41.5, 45., 48.5, 52., 0., 0., 0., 0.}; 192 | thrust::universal_vector ref_inverse_mapping_{0, 1, 3}; 193 | thrust::universal_vector ref_compressed_grad_embedding_sum_{ 194 | 5., 6., 7., 8., 1., 2., 3., 4., 6., 8., 10., 12.}; 195 | thrust::universal_vector 196 | ref_compressed_grad_embedding_sum_weighted_{ 197 | 15., 18., 21., 24., 1., 2., 3., 4., 15.5, 19., 22.5, 26.}; 198 | thrust::universal_vector ref_compressed_grad_embedding_concat_{ 199 | 9., 10., 11., 12., 1., 2., 3., 4., 18., 20., 22., 24.}; 200 | thrust::universal_vector 201 | ref_compressed_grad_embedding_concat_weighted_{ 202 | 27., 30., 33., 36., 1., 2., 3., 4., 41.5, 45., 48.5, 52.}; 203 | 204 | std::unique_ptr> grad_embedding_; 205 | std::unique_ptr> inverse_mapping_; 206 | }; 207 | 208 | TYPED_TEST_SUITE_P(EmbeddingBackwardRefTest); 209 | 210 | TYPED_TEST_P(EmbeddingBackwardRefTest, TestFixedHotnessAgainstRefCpu) { 211 | for (const auto compressed_grad : {false}) { 212 | for (const auto weighted : {false, true}) { 213 | for (const auto mode : {CombineMode::kSum, CombineMode::kConcat}) { 214 | this->LaunchTest(mode, DeviceType::kCPU, weighted, compressed_grad); 215 | } 216 | } 217 | } 218 | } 219 | 220 | TYPED_TEST_P(EmbeddingBackwardRefTest, TestFixedHotnessAgainstRefGpu) { 221 | for (const auto compressed_grad 
: {false}) { 222 | for (const auto weighted : {false, true}) { 223 | for (const auto mode : {CombineMode::kSum, CombineMode::kConcat}) { 224 | this->LaunchTest(mode, DeviceType::kGPU, weighted, compressed_grad); 225 | } 226 | } 227 | } 228 | } 229 | 230 | TYPED_TEST_P(EmbeddingBackwardRefTest, TestCSRAgainstRefCpu) { 231 | for (const auto compressed_grad : {false}) { 232 | for (const auto weighted : {false, true}) { 233 | for (const auto mode : {CombineMode::kSum, CombineMode::kConcat}) { 234 | this->LaunchTest(mode, DeviceType::kCPU, weighted, compressed_grad); 235 | } 236 | } 237 | } 238 | } 239 | 240 | TYPED_TEST_P(EmbeddingBackwardRefTest, TestCSRAgainstRefGpu) { 241 | for (const auto compressed_grad : {false}) { 242 | for (const auto weighted : {false, true}) { 243 | for (const auto mode : {CombineMode::kSum, CombineMode::kConcat}) { 244 | this->LaunchTest(mode, DeviceType::kGPU, weighted, compressed_grad); 245 | } 246 | } 247 | } 248 | } 249 | 250 | TYPED_TEST_P(EmbeddingBackwardRefTest, TestSkipInitGradAgainstRefCpu) { 251 | const bool weighted = false; 252 | const auto mode = CombineMode::kSum; 253 | for (const auto skip_grad_init : {false, true}) { 254 | for (const auto compressed_grad : {false, true}) { 255 | this->LaunchTest( 256 | mode, DeviceType::kCPU, weighted, compressed_grad, skip_grad_init); 257 | } 258 | } 259 | } 260 | 261 | TYPED_TEST_P(EmbeddingBackwardRefTest, TestSkipInitGradAgainstRefGpu) { 262 | const bool weighted = false; 263 | const auto mode = CombineMode::kSum; 264 | for (const auto skip_grad_init : {false, true}) { 265 | for (const auto compressed_grad : {false, true}) { 266 | this->LaunchTest( 267 | mode, DeviceType::kGPU, weighted, compressed_grad, skip_grad_init); 268 | } 269 | } 270 | } 271 | 272 | REGISTER_TYPED_TEST_SUITE_P(EmbeddingBackwardRefTest, 273 | TestFixedHotnessAgainstRefCpu, 274 | TestFixedHotnessAgainstRefGpu, 275 | TestCSRAgainstRefCpu, 276 | TestCSRAgainstRefGpu, 277 | TestSkipInitGradAgainstRefCpu, 278 | 
TestSkipInitGradAgainstRefGpu); 279 | 280 | template 281 | struct EmbedTestTypeCombo { 282 | typedef EmbedT EmbedType; 283 | typedef IndexT IndexType; 284 | }; 285 | 286 | typedef ::testing::Types, 287 | EmbedTestTypeCombo, 288 | EmbedTestTypeCombo<__half, int32_t>, 289 | EmbedTestTypeCombo<__half, int64_t>> 290 | EmbedTestTypes; 291 | 292 | char EmbeddingBackwardRefTestNameGlobal[] = "EmbeddingBackwardRefTest_"; 293 | INSTANTIATE_TYPED_TEST_SUITE_P( 294 | EmbeddingBackward, 295 | EmbeddingBackwardRefTest, 296 | EmbedTestTypes, 297 | utils::EmbeddingRefTestNames); 298 | 299 | } // namespace cuembed 300 | -------------------------------------------------------------------------------- /tests/test_embedding_forward.cu: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | // clang-format on 19 | 20 | #include 21 | #include 22 | 23 | #include "absl/log/check.h" 24 | #include "absl/strings/str_format.h" 25 | #include "cuembed/include/embedding_lookup.cuh" 26 | #include "gtest/gtest.h" 27 | #include "utils/include/embedding_allocation.h" 28 | #include "utils/include/embedding_utils.h" 29 | 30 | // CPU reference implementations 31 | #include "utils/include/embedding_lookup_cpu.hpp" 32 | 33 | namespace cuembed { 34 | 35 | enum class DeviceType { kGPU, kCPU }; 36 | 37 | template 38 | class EmbeddingRefTest : public ::testing::Test { 39 | public: 40 | void LaunchTest(const CombineMode mode, 41 | const DeviceType device, 42 | const bool is_csr, 43 | const bool is_weighted) { 44 | this->CreateResult(mode); 45 | this->LaunchKernel(mode, device, is_csr, is_weighted); 46 | this->CheckResult(mode, is_weighted); 47 | } 48 | 49 | typedef typename T::EmbedType EmbedType; 50 | typedef typename T::IndexType IndexType; 51 | typedef int OffsetType; 52 | 53 | private: 54 | void CreateResult(const CombineMode mode) { 55 | if (mode != CombineMode::kConcat) { 56 | result_.reset( 57 | new thrust::universal_vector(batch_size_ * embed_width_)); 58 | } else { 59 | result_.reset(new thrust::universal_vector( 60 | batch_size_ * embed_width_ * hotness_)); 61 | } 62 | } 63 | void LaunchKernel(const CombineMode mode, 64 | const DeviceType device, 65 | const bool is_csr, 66 | const bool is_weighted) { 67 | int* offsets = is_csr ? offsets_.data().get() : nullptr; 68 | int hotness = is_csr ? 0 : hotness_; 69 | EmbedType* weights = (mode == CombineMode::kConcat || (!is_weighted)) 70 | ? 
nullptr 71 | : weights_.data().get(); 72 | const bool kFp16Math = false; 73 | if (device == DeviceType::kCPU) { 74 | EmbeddingForwardCpu(embedding_.data().get(), 79 | embed_width_, 80 | batch_size_, 81 | hotness, 82 | indices_.data().get(), 83 | offsets, 84 | weights, 85 | result_->data().get(), 86 | mode); 87 | } else if (device == DeviceType::kGPU) { 88 | EmbeddingForward( 89 | embedding_.data().get(), 90 | embed_width_, 91 | indices_.data().get(), 92 | offsets, 93 | weights, 94 | batch_size_, 95 | hotness, 96 | mode, 97 | result_->data().get()); 98 | CHECK_CUDA(cudaDeviceSynchronize()); 99 | } else { 100 | CHECK(false) << "not supported device"; 101 | } 102 | } 103 | void CheckResult(const CombineMode mode, const bool is_weighted) { 104 | const thrust::universal_vector* reference = nullptr; 105 | if (mode == CombineMode::kSum && is_weighted) { 106 | reference = &ref_result_sum_weighted_; 107 | } else if (mode == CombineMode::kSum && (!is_weighted)) { 108 | reference = &ref_result_sum_; 109 | } else if (mode == CombineMode::kConcat) { 110 | reference = &ref_result_concat_; 111 | } else if (mode == CombineMode::kMean) { 112 | reference = &ref_result_avg_; 113 | } 114 | for (size_t i = 0; i < result_->size(); ++i) { 115 | EXPECT_EQ(result_->data()[i], reference->data()[i]); 116 | } 117 | } 118 | 119 | private: 120 | const int embed_width_{4}; 121 | const int hotness_{2}; 122 | thrust::universal_vector embedding_{ 123 | 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 124 | 11., 12., 13., 14., 15., 16., 17., 18., 19., 20.}; 125 | const int batch_size_{2}; 126 | thrust::universal_vector indices_{1, 3, 0, 4}; 127 | thrust::universal_vector offsets_{0, 2, 4}; 128 | thrust::universal_vector weights_{1.f, 0.5f, 1.f, 0.5f}; 129 | const thrust::universal_vector ref_result_concat_{ 130 | 5., 6., 7., 8., 13., 14., 15., 16., 1., 2., 3., 4., 17., 18., 19., 20.}; 131 | const thrust::universal_vector ref_result_sum_{ 132 | 18., 133 | 20., 134 | 22., 135 | 24., 136 | 18., 137 | 20., 138 
// Remainder of the EmbeddingRefTest fixture (its opening lies before this
// chunk): the reference arrays ref_result_avg_ and ref_result_sum_weighted_,
// then the typed tests. Each test calls LaunchTest(mode, device, is_csr,
// weighted) for weighted in {true, false}; fixed-hotness tests cover kSum,
// kConcat and kMean, CSR tests cover kSum and kMean, and weighted combinations
// with kMean are skipped. CPU and GPU variants differ only in the DeviceType
// passed. The suite is registered and instantiated for the four
// EmbedTestTypeCombo entries, with names produced by utils::EmbeddingRefTestNames.
| 22., 139 | 24., 140 | }; 141 | const thrust::universal_vector ref_result_avg_{ 142 | 9., 143 | 10., 144 | 11., 145 | 12., 146 | 9., 147 | 10., 148 | 11., 149 | 12., 150 | }; 151 | const thrust::universal_vector ref_result_sum_weighted_{ 152 | 11.5, 153 | 13., 154 | 14.5, 155 | 16., 156 | 9.5, 157 | 11., 158 | 12.5, 159 | 14., 160 | }; 161 | 162 | std::unique_ptr> result_; 163 | }; 164 | 165 | TYPED_TEST_SUITE_P(EmbeddingRefTest); 166 | 167 | TYPED_TEST_P(EmbeddingRefTest, TestFixedHotnessAgainstRefCpu) { 168 | const bool kIsCSR = false; 169 | for (const auto weighted : {true, false}) { 170 | for (const auto mode : 171 | {CombineMode::kSum, CombineMode::kConcat, CombineMode::kMean}) { 172 | if (mode == CombineMode::kMean && weighted) { 173 | continue; 174 | } 175 | this->LaunchTest(mode, DeviceType::kCPU, kIsCSR, weighted); 176 | } 177 | } 178 | } 179 | 180 | TYPED_TEST_P(EmbeddingRefTest, TestCSRAgainstRefCpu) { 181 | const bool kIsCSR = true; 182 | for (const auto weighted : {true, false}) { 183 | for (const auto mode : {CombineMode::kSum, CombineMode::kMean}) { 184 | if (mode == CombineMode::kMean && weighted) { 185 | continue; 186 | } 187 | this->LaunchTest(mode, DeviceType::kCPU, kIsCSR, weighted); 188 | } 189 | } 190 | } 191 | 192 | TYPED_TEST_P(EmbeddingRefTest, TestFixedHotnessAgainstRefGpu) { 193 | const bool kIsCSR = false; 194 | for (const auto weighted : {true, false}) { 195 | for (const auto mode : 196 | {CombineMode::kSum, CombineMode::kConcat, CombineMode::kMean}) { 197 | if (mode == CombineMode::kMean && weighted) { 198 | continue; 199 | } 200 | this->LaunchTest(mode, DeviceType::kGPU, kIsCSR, weighted); 201 | } 202 | } 203 | } 204 | 205 | TYPED_TEST_P(EmbeddingRefTest, TestCSRAgainstRefGpu) { 206 | const bool kIsCSR = true; 207 | for (const auto weighted : {true, false}) { 208 | for (const auto mode : {CombineMode::kSum, CombineMode::kMean}) { 209 | if (mode == CombineMode::kMean && weighted) { 210 | continue; 211 | } 212 | this->LaunchTest(mode,
DeviceType::kGPU, kIsCSR, weighted); 213 | } 214 | } 215 | } 216 | 217 | REGISTER_TYPED_TEST_SUITE_P(EmbeddingRefTest, 218 | TestFixedHotnessAgainstRefCpu, 219 | TestCSRAgainstRefCpu, 220 | TestFixedHotnessAgainstRefGpu, 221 | TestCSRAgainstRefGpu); 222 | 223 | template 224 | struct EmbedTestTypeCombo { 225 | typedef EmbedT EmbedType; 226 | typedef IndexT IndexType; 227 | }; 228 | 229 | typedef ::testing::Types, 230 | EmbedTestTypeCombo, 231 | EmbedTestTypeCombo<__half, int32_t>, 232 | EmbedTestTypeCombo<__half, int64_t>> 233 | EmbedTestTypes; 234 | 235 | char ForwardRefTestNameGlobal[] = "ForwardRefTest_"; 236 | INSTANTIATE_TYPED_TEST_SUITE_P( 237 | Embedding, 238 | EmbeddingRefTest, 239 | EmbedTestTypes, 240 | utils::EmbeddingRefTestNames); 241 | 242 | } // namespace cuembed 243 | -------------------------------------------------------------------------------- /tests/test_embedding_transpose.cu: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License.
// test_embedding_transpose.cu: TransposeRefTest is a typed fixture that
// transposes a 4-entry input (indices {1, 3, 0, 4}, sample ids {0, 0, 1, 1},
// weights {1, 0.5, 1, 0.5}) and compares every output element with the
// ref_transpose_* arrays.
//  - kCPU uses TransposeCpu.
//  - kGPU uses the two-call Transpose API: the first call passes a null
//    workspace to obtain lwork, the second passes a device workspace of lwork
//    elements, and a device sync follows.
// When is_weighted is false, nullptr is passed for both weight arrays and
// weights are not checked.
// NOTE(review): ref_transpose_sample_ids_concat_ is not referenced by any test
// in this file.
17 | */ 18 | // clang-format on 19 | 20 | #include 21 | #include 22 | 23 | #include "absl/log/check.h" 24 | #include "absl/strings/str_format.h" 25 | #include "cuembed/include/index_transforms.cuh" 26 | #include "gtest/gtest.h" 27 | #include "utils/include/embedding_allocation.h" 28 | #include "utils/include/embedding_utils.h" 29 | 30 | // CPU reference implementations 31 | #include "utils/include/index_transforms_cpu.hpp" 32 | 33 | namespace cuembed { 34 | 35 | enum class DeviceType { kGPU, kCPU }; 36 | 37 | template 38 | class TransposeRefTest : public ::testing::Test { 39 | public: 40 | void LaunchTest(const DeviceType device, const bool is_weighted) { 41 | this->CreateResult(); 42 | this->LaunchKernel(device, is_weighted); 43 | this->CheckResult(is_weighted); 44 | } 45 | 46 | typedef typename T::EmbedType EmbedType; 47 | typedef typename T::IndexType IndexType; 48 | 49 | private: 50 | void CreateResult() { 51 | transpose_indices_.reset(new thrust::universal_vector(nnz_)); 52 | transpose_sample_ids_.reset(new thrust::universal_vector(nnz_)); 53 | transpose_weights_.reset(new thrust::universal_vector(nnz_)); 54 | } 55 | void LaunchKernel(const DeviceType device, const bool is_weighted) { 56 | EmbedType* weights = (!is_weighted) ? nullptr : weights_.data().get(); 57 | EmbedType* transpose_weights = 58 | (!is_weighted) ?
nullptr : transpose_weights_->data().get(); 59 | if (device == DeviceType::kCPU) { 60 | TransposeCpu(sample_ids_.data().get(), 61 | indices_.data().get(), 62 | weights, 63 | nnz_, 64 | transpose_indices_->data().get(), 65 | transpose_sample_ids_->data().get(), 66 | transpose_weights); 67 | } else if (device == DeviceType::kGPU) { 68 | size_t lwork; 69 | Transpose(sample_ids_.data().get(), 70 | indices_.data().get(), 71 | weights, 72 | nnz_, 73 | transpose_indices_->data().get(), 74 | transpose_sample_ids_->data().get(), 75 | transpose_weights, 76 | nullptr, 77 | &lwork); 78 | 79 | thrust::device_vector t_work(lwork); 80 | 81 | Transpose(sample_ids_.data().get(), 82 | indices_.data().get(), 83 | weights, 84 | nnz_, 85 | transpose_indices_->data().get(), 86 | transpose_sample_ids_->data().get(), 87 | transpose_weights, 88 | t_work.data().get(), 89 | &lwork); 90 | CHECK_CUDA(cudaDeviceSynchronize()); 91 | } else { 92 | CHECK(false) << "not supported device"; 93 | } 94 | } 95 | void CheckResult(const bool is_weighted) { 96 | for (int64_t i = 0; i < nnz_; ++i) { 97 | EXPECT_EQ(transpose_indices_->data()[i], 98 | ref_transpose_indices_.data()[i]); 99 | 100 | // No repeated indices so exact match is ensured for weights and 101 | // sample_ids 102 | EXPECT_EQ(transpose_sample_ids_->data()[i], 103 | ref_transpose_sample_ids_.data()[i]); 104 | if (is_weighted) { 105 | EXPECT_EQ(transpose_weights_->data()[i], 106 | ref_transpose_weights_.data()[i]); 107 | } 108 | } 109 | } 110 | 111 | private: 112 | const int nnz_{4}; 113 | thrust::universal_vector indices_{1, 3, 0, 4}; 114 | thrust::universal_vector sample_ids_{0, 0, 1, 1}; 115 | thrust::universal_vector weights_{1.f, 0.5f, 1.f, 0.5f}; 116 | const thrust::universal_vector ref_transpose_indices_{0, 1, 3, 4}; 117 | const thrust::universal_vector ref_transpose_sample_ids_{ 118 | 1, 0, 0, 1}; 119 | const thrust::universal_vector ref_transpose_sample_ids_concat_{ 120 | 2, 0, 1, 3}; 121 | const thrust::universal_vector
ref_transpose_weights_{ 122 | 1.f, 1.f, 0.5f, 0.5f}; 123 | 124 | std::unique_ptr> transpose_indices_; 125 | std::unique_ptr> transpose_sample_ids_; 126 | std::unique_ptr> transpose_weights_; 127 | }; 128 | 129 | TYPED_TEST_SUITE_P(TransposeRefTest); 130 | 131 | TYPED_TEST_P(TransposeRefTest, TestAgainstRefCpu) { 132 | for (const auto weighted : {true, false}) { 133 | this->LaunchTest(DeviceType::kCPU, weighted); 134 | } 135 | } 136 | 137 | TYPED_TEST_P(TransposeRefTest, TestAgainstRefGpu) { 138 | for (const auto weighted : {true, false}) { 139 | this->LaunchTest(DeviceType::kGPU, weighted); 140 | } 141 | } 142 | 143 | REGISTER_TYPED_TEST_SUITE_P(TransposeRefTest, 144 | TestAgainstRefCpu, 145 | TestAgainstRefGpu); 146 | 147 | template 148 | struct EmbedTestTypeCombo { 149 | typedef EmbedT EmbedType; 150 | typedef IndexT IndexType; 151 | }; 152 | 153 | typedef ::testing::Types, 154 | EmbedTestTypeCombo, 155 | EmbedTestTypeCombo<__half, int32_t>, 156 | EmbedTestTypeCombo<__half, int64_t>> 157 | EmbedTestTypes; 158 | 159 | char TransposeRefTestNameGlobal[] = "TransposeRefTest_"; 160 | INSTANTIATE_TYPED_TEST_SUITE_P( 161 | Transpose, 162 | TransposeRefTest, 163 | EmbedTestTypes, 164 | utils::EmbeddingRefTestNames); 165 | 166 | } // namespace cuembed 167 | -------------------------------------------------------------------------------- /tests/test_third_party_utils.cu: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License.
// test_third_party_utils.cu: smoke tests showing that the third-party
// dependencies link and behave. They cover:
//  - a trivial gtest assertion;
//  - absl LOG at INFO and ERROR levels;
//  - absl CHECK_EQ, including a death test on failure;
//  - a thrust host -> device -> sort -> host round-trip checked against 1..4.
// NOTE(review): in the Thrust test, `i + 1` is size_t. If h_vec's element type
// is signed, EXPECT_EQ may emit -Wsign-compare.
8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | // clang-format on 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #include "absl/log/check.h" 26 | #include "absl/log/log.h" 27 | #include "gtest/gtest.h" 28 | #include "utils/include/datagen.h" 29 | 30 | TEST(GtestTest, BasicTest) { EXPECT_EQ(1, 1); } 31 | 32 | TEST(LoggingTest, BasicLogging) { 33 | // An INFO-level log statement must not produce a fatal failure. 34 | EXPECT_NO_FATAL_FAILURE(LOG(INFO) << "This is an info log for testing."); 35 | 36 | // Test other logging levels. 37 | EXPECT_NO_FATAL_FAILURE(LOG(ERROR) << "This is an error log for testing."); 38 | } 39 | 40 | TEST(CheckTest, BasicCheck) { 41 | EXPECT_NO_FATAL_FAILURE(CHECK_EQ(1, 1)); 42 | EXPECT_DEATH(CHECK_EQ(1, 2), ""); 43 | } 44 | 45 | TEST(Thrust, BasicTest) { 46 | thrust::host_vector h_vec{1, 3, 2, 4}; 47 | thrust::device_vector d_vec = h_vec; 48 | thrust::sort(d_vec.begin(), d_vec.end()); 49 | thrust::copy(d_vec.begin(), d_vec.end(), h_vec.begin()); 50 | for (size_t i = 0; i < h_vec.size(); i++) { 51 | EXPECT_EQ(h_vec[i], i + 1); 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /utils/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# utils object library.
# Include directories list only real paths (CUDAToolkit_INCLUDE_DIRS).
# absl::log, absl::check and gtest are CMake target names. In
# target_include_directories they would be taken as (nonexistent) relative
# directories, so they are omitted there; their include paths reach `utils`
# through the usage requirements of target_link_libraries.
# REQUIRED makes configuration fail immediately when the CUDA toolkit is
# missing, instead of failing later on the unresolved CUDA::cudart target.
2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | include_directories(${CUEMBED_PROJECT_SOURCE_DIR}) 17 | 18 | find_package(CUDAToolkit REQUIRED) 19 | add_library(utils OBJECT src/embedding_allocation.cu src/embedding_gpu_forward.cu src/embedding_gpu_transpose.cu src/embedding_gpu_backward.cu src/embedding_cpu.cu src/datagen.cpp) 20 | target_include_directories(utils PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) 21 | target_link_libraries(utils PRIVATE CUDA::cudart absl::log absl::check gtest) 22 | -------------------------------------------------------------------------------- /utils/include/datagen.h: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// datagen.h declares two generators:
//  - FeatureGenerator, the abstract base for category-index generators
//    (derived classes implement generateIndex());
//  - PowerLawFeatureGenerator, which draws indices from a power law with
//    exponent alpha.
// This line range also starts embedding_allocation.h, whose _M and _K
// user-defined literals multiply by 1024 * 1024 and 1024 respectively.
15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | // clang-format on 19 | 20 | #ifndef UTILS_INCLUDE_DATAGEN_H_ 21 | #define UTILS_INCLUDE_DATAGEN_H_ 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | namespace cuembed { 29 | namespace index_generators { 30 | 31 | // Abstract base class for generators of feature categories. Derived classes 32 | // will be specialized to draw random values from particular distributions. Any 33 | // generator requires two parameters for the features: number of categories and 34 | // the number of indices per sample (aka hotness). Each call to the 35 | // getCategoryIndices() method returns a C++ vector with the randomly generated 36 | // category indices. There will be no index repetitions in this returned vector. 37 | template 38 | class FeatureGenerator { 39 | public: 40 | // Deleted to ensure that parameters are always initialized upon construction. 41 | FeatureGenerator() = delete; 42 | 43 | // Constructs an object given the number of categories and hotness for a 44 | // feature. 45 | FeatureGenerator(const IndexType num_categories, 46 | const int num_hot, 47 | const bool shuffle = false, 48 | const bool permute = false); 49 | 50 | virtual ~FeatureGenerator() {} 51 | 52 | // Each derived feature generator should implement this method. 53 | virtual IndexType generateIndex() = 0; 54 | 55 | // Returns a vector of random category indices. 56 | std::vector getCategoryIndices(); 57 | 58 | // Returns the number of categories for the feature. 59 | int getNumCategories() const { return num_categories_; } 60 | 61 | // Returns the hotness for the feature (number of lookups per sample).
// NOTE(review): getNumCategories() above returns int although num_categories_
// is IndexType. With a 64-bit IndexType, counts above INT_MAX would be narrowed.
62 | size_t getNumHot() const { return static_cast(num_hot_); } 63 | 64 | IndexType getPermutedIndex(int index) const; 65 | 66 | const std::vector& getInversePermutation() const { 67 | return this->inverse_permutation_; 68 | } 69 | 70 | protected: 71 | IndexType num_categories_ = 0; 72 | int num_hot_ = 0; 73 | bool shuffle_ = false; 74 | bool permute_ = false; 75 | std::vector permutation_; 76 | std::vector inverse_permutation_; 77 | }; 78 | 79 | // PowerLaw type specifies the type of power law distribution used in the 80 | // generator. 81 | enum class PowerLawType { 82 | // Generate random indices in [1, num_categories] range according to power 83 | // law. 84 | kPsx = 0 85 | }; 86 | 87 | // A class for generating category indices drawn from a power law distribution. 88 | // Category index 0 is not generated as it is assumed to be reserved for a 89 | // "missing" category. Thus, given num_categories, returned indices are drawn 90 | // from [1, num_categories] range. Each returned set of indices contains exactly 91 | // num_hot indices, with no repetitions. Power law distribution is specified 92 | // via its exponent, alpha > 0. Smaller indices correspond to more frequent 93 | // categories (i.e. 1 will be the most frequent category, 2 - the second most 94 | // frequent one, etc.). The constructor below exposes no math_numpy option; 95 | // the distribution type is selected via PowerLawType. 96 | template 97 | class PowerLawFeatureGenerator : public FeatureGenerator { 98 | public: 99 | PowerLawFeatureGenerator() = delete; 100 | 101 | PowerLawFeatureGenerator(const IndexType num_categories, 102 | const int num_hot, 103 | const double alpha, 104 | const bool shuffle = false, 105 | const bool permute = false, 106 | const PowerLawType type = PowerLawType::kPsx); 107 | 108 | IndexType generateIndex() override; 109 | 110 | protected: 111 | double alpha_ = 0.; // Exponent for the power-law distribution.
112 | PowerLawType type_ = PowerLawType::kPsx; 113 | std::default_random_engine generator_; 114 | std::unique_ptr> distribution_; 115 | }; 116 | } // namespace index_generators 117 | } // namespace cuembed 118 | 119 | #endif // UTILS_INCLUDE_DATAGEN_H_ 120 | -------------------------------------------------------------------------------- /utils/include/embedding_allocation.h: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | // clang-format on 19 | 20 | #ifndef UTILS_INCLUDE_EMBEDDING_ALLOCATION_H_ 21 | #define UTILS_INCLUDE_EMBEDDING_ALLOCATION_H_ 22 | 23 | #include 24 | #include 25 | #include 26 | 27 | #include "cuembed/include/embedding_lookup_types.cuh" 28 | 29 | namespace cuembed { 30 | 31 | // User-defined literals live in cuembed namespace for ease of use. 32 | // _M multiplies by 2^20 (1024 * 1024), not 10^6. E.g., num_categories = 10_M. 33 | constexpr unsigned long long operator"" _M( // NOLINT 34 | unsigned long long int val) { // NOLINT 35 | return val * 1024 * 1024; 36 | } 37 | 38 | // _K multiplies by 2^10 (1024), not 10^3. E.g., batch_size = 64_K.
// embedding_allocation.h (continued): the _K literal, then AllocationOptions.
// AllocationOptions is an options holder. Its setters return
// AllocationOptions& (presumably *this, for chaining; the definitions are not
// in this header), and its getters return the stored fields.
// Defaults:
//  - all sizes 0, alpha 0;
//  - permute_indices and shuffle_indices true;
//  - is_csr, is_weighted, compressed_grad and skip_grad_init false;
//  - combine mode kSum.
// NOTE(review): the num_categories setter takes int and the shuffle_indices
// setter takes int32_t, while the stored fields are int32_t and bool.
// Signatures are left as-is to match their out-of-header definitions.
39 | constexpr unsigned long long operator"" _K( // NOLINT 40 | unsigned long long int val) { // NOLINT 41 | return val * 1024; 42 | } 43 | 44 | namespace utils { 45 | 46 | // A wrapper class for different options. 47 | class AllocationOptions { 48 | public: 49 | AllocationOptions() {} 50 | 51 | // Setters and getters of each allocation option. 52 | AllocationOptions& num_categories(int num_categories); 53 | int32_t num_categories() const { return num_categories_; } 54 | 55 | AllocationOptions& batch_size(int32_t batch_size); 56 | int32_t batch_size() const { return batch_size_; } 57 | 58 | AllocationOptions& hotness(int32_t hotness); 59 | int32_t hotness() const { return hotness_; } 60 | 61 | AllocationOptions& alpha(float alpha); 62 | float alpha() const { return alpha_; } 63 | 64 | AllocationOptions& combine_mode(CombineMode type); 65 | CombineMode combine_mode() const { return combine_mode_; } 66 | 67 | AllocationOptions& embed_width(int32_t embed_width); 68 | int32_t embed_width() const { return embed_width_; } 69 | 70 | AllocationOptions& permute_indices(bool permute_indices); 71 | bool permute_indices() const { return permute_indices_; } 72 | 73 | AllocationOptions& shuffle_indices(int32_t shuffle_indices); 74 | bool shuffle_indices() const { return shuffle_indices_; } 75 | 76 | AllocationOptions& is_csr(bool is_csr); 77 | bool is_csr() const { return is_csr_; } 78 | 79 | AllocationOptions& is_weighted(bool is_weighted); 80 | bool is_weighted() const { return is_weighted_; } 81 | 82 | AllocationOptions& compressed_grad(bool compressed_grad); 83 | bool compressed_grad() const { return compressed_grad_; } 84 | 85 | AllocationOptions& skip_grad_init(bool skip_grad_init); 86 | bool skip_grad_init() const { return skip_grad_init_; } 87 | 88 | private: 89 | int32_t num_categories_{0}; 90 | int32_t batch_size_{0}; 91 | int32_t hotness_{0}; 92 | float alpha_{0}; 93 | int32_t embed_width_{0}; 94 | bool permute_indices_{true}; 95 | bool shuffle_indices_{true}; 96 |
// UniversalEmbeddingAllocation and DeviceEmbeddingAllocation bundle the same
// named buffers, as thrust::universal_vector and thrust::device_vector
// respectively.
// AllocateHost takes the options and a UniversalEmbeddingAllocation* to fill.
// AllocateDevice takes the options, an existing universal allocation and a
// DeviceEmbeddingAllocation* to fill.
// Both accept forward_only (default false); their definitions are not in this
// header.
bool is_csr_{false}; 97 | bool is_weighted_{false}; 98 | bool compressed_grad_{false}; 99 | bool skip_grad_init_{false}; 100 | CombineMode combine_mode_{CombineMode::kSum}; 101 | }; 102 | 103 | template 109 | struct UniversalEmbeddingAllocation { 110 | thrust::universal_vector embedding; 111 | thrust::universal_vector indices; 112 | thrust::universal_vector offsets; 113 | thrust::universal_vector weights; 114 | thrust::universal_vector result; 115 | thrust::universal_vector transpose_indices; 116 | thrust::universal_vector transpose_remapped_indices; 117 | thrust::universal_vector transpose_sample_ids; 118 | thrust::universal_vector transpose_weights; 119 | thrust::universal_vector sample_ids; 120 | thrust::universal_vector transpose_workspace; 121 | thrust::universal_vector grad_y; 122 | thrust::universal_vector grad_embedding; 123 | thrust::universal_vector inverse_mapping; 124 | }; 125 | 126 | template 132 | struct DeviceEmbeddingAllocation { 133 | thrust::device_vector embedding; 134 | thrust::device_vector indices; 135 | thrust::device_vector offsets; 136 | thrust::device_vector weights; 137 | thrust::device_vector result; 138 | thrust::device_vector transpose_indices; 139 | thrust::device_vector transpose_remapped_indices; 140 | thrust::device_vector transpose_sample_ids; 141 | thrust::device_vector sample_ids; 142 | thrust::device_vector transpose_weights; 143 | thrust::device_vector transpose_workspace; 144 | thrust::device_vector grad_y; 145 | thrust::device_vector grad_embedding; 146 | thrust::device_vector inverse_mapping; 147 | }; 148 | 149 | template 155 | void AllocateHost(const AllocationOptions& options, 156 | UniversalEmbeddingAllocation* allocation, 162 | bool forward_only = false); 163 | 164 | template 170 | void AllocateDevice( 171 | const AllocationOptions& options, 172 | const UniversalEmbeddingAllocation& universal_allocation, 178 | DeviceEmbeddingAllocation* 179 | device_allocation, 180 | bool forward_only = false); 181 | 182 | } //
namespace utils 183 | } // namespace cuembed 184 | 185 | #endif // UTILS_INCLUDE_EMBEDDING_ALLOCATION_H_ 186 | -------------------------------------------------------------------------------- /utils/include/embedding_lookup_cpu.hpp: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | // clang-format on 19 | 20 | #ifndef UTILS_INCLUDE_EMBEDDING_LOOKUP_CPU_HPP_ 21 | #define UTILS_INCLUDE_EMBEDDING_LOOKUP_CPU_HPP_ 22 | 23 | #include 24 | 25 | #include 26 | #include 27 | #include 28 | 29 | #include "absl/log/check.h" 30 | #include "absl/log/log.h" 31 | #include "cuembed/include/embedding_lookup_types.cuh" 32 | 33 | namespace cuembed { 34 | 35 | template 40 | void EmbeddingForwardCpu(const InputT* params, 41 | const int embed_width, 42 | const int batch_size, 43 | const int num_hots, 44 | const IndexT* indices, 45 | const OffsetT* offsets, 46 | const GetElemT* weights, 47 | OutputT* ret, 48 | const CombineMode reduce) { 49 | using ElemT = GetElemT; 50 | // Weights can only be sum. 51 | CHECK(weights == nullptr || reduce == CombineMode::kSum); 52 | // CSR or fixed hotness. 
53 | CHECK((offsets != nullptr && num_hots == 0) || 54 | (offsets == nullptr && num_hots > 0)); 55 | // CSR does not support concat. 56 | CHECK(offsets == nullptr || reduce != CombineMode::kConcat); 57 | for (int i = 0; i < batch_size; ++i) { 58 | for (int k = 0; k < embed_width; ++k) { 59 | using SumT = typename std::conditional::type; 60 | SumT sum = 0.0f; 61 | int hotness = 62 | (offsets == nullptr) ? num_hots : (offsets[i + 1] - offsets[i]); 63 | int index_start = (offsets == nullptr) ? i * num_hots : offsets[i]; 64 | IndexT write_idx = i * embed_width + k; 65 | for (int j = 0; j < hotness; ++j) { 66 | int64_t read_idx = 67 | static_cast(indices[index_start + j]) * embed_width + k; 68 | if (reduce == CombineMode::kConcat) { 69 | write_idx = index_start * embed_width + j * embed_width + k; 70 | ret[write_idx] = params[read_idx]; 71 | } else if (reduce == CombineMode::kSum || 72 | reduce == CombineMode::kMean) { 73 | ElemT weight = (weights == nullptr) ? static_cast(1.0f) 74 | : weights[index_start + j]; 75 | sum += VecCast(params[read_idx]) * weight; 76 | } else { 77 | CHECK(false) << "reduce type not supported."; 78 | } 79 | } 80 | if (reduce == CombineMode::kSum) { 81 | ret[write_idx] = VecCast(sum); 82 | } else if (reduce == CombineMode::kMean) { 83 | if (hotness == 0) { 84 | // Preserve sign. 
85 | ret[write_idx] = 86 | VecCast(sum * static_cast(0.0f)); 87 | } else { 88 | ret[write_idx] = 89 | VecCast(sum * static_cast(1.0f / hotness)); 90 | } 91 | } 92 | } 93 | } 94 | } 95 | 96 | template 97 | void EmbeddingBackwardCpu(const GradT* result_grad, 98 | const int embed_width, 99 | const int num_grad_embedding_rows, 100 | const int nnz, 101 | const IndexT* transpose_indices, 102 | const IndexT* transpose_sample_ids, 103 | const IndexT* transpose_remapped_indices, 104 | const GradT* transpose_weights, 105 | const bool skip_grad_init, 106 | GradT* grad_embedding, 107 | IndexT* inverse_mapping) { 108 | using WeightT = GradT; 109 | if (nnz == 0) return; 110 | if (transpose_remapped_indices != nullptr) { 111 | CHECK(grad_embedding != nullptr); 112 | CHECK(inverse_mapping != nullptr); 113 | 114 | // Set grad embedding indices 115 | inverse_mapping[0] = transpose_indices[0]; 116 | int cnt = 1; 117 | for (int i = 1; i < nnz; i++) { 118 | if (transpose_remapped_indices[i - 1] != transpose_remapped_indices[i]) { 119 | inverse_mapping[cnt] = transpose_indices[i]; 120 | cnt++; 121 | } 122 | } 123 | } 124 | if (!skip_grad_init) { 125 | memset(grad_embedding, 126 | 0, 127 | (int64_t)num_grad_embedding_rows * (int64_t)embed_width * 128 | sizeof(GradT)); 129 | } 130 | // Loop over nnz, load index and offset, then loop over embed dim 131 | for (int nz = 0; nz < nnz; nz++) { 132 | IndexT index = (transpose_remapped_indices != nullptr) 133 | ? transpose_remapped_indices[nz] 134 | : transpose_indices[nz]; 135 | IndexT sample_id = transpose_sample_ids[nz]; 136 | WeightT weight = (transpose_weights != nullptr) 137 | ? 
transpose_weights[nz] 138 | : static_cast(1.0f); 139 | for (int e = 0; e < embed_width; e++) { 140 | grad_embedding[e + index * embed_width] += 141 | result_grad[e + sample_id * embed_width] * weight; 142 | } 143 | } 144 | } 145 | 146 | } // namespace cuembed 147 | 148 | #endif // UTILS_INCLUDE_EMBEDDING_LOOKUP_CPU_HPP_ 149 | -------------------------------------------------------------------------------- /utils/include/embedding_utils.h: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
// embedding_utils.h: test/benchmark helpers.
//  - CHECK_CUDA(cmd) evaluates cmd once. When the result is not cudaSuccess it
//    LOG(FATAL)s with the file, the line and cudaGetErrorString().
//  - utils::Near is an absolute-tolerance comparator (|a - b| <= tolerance_)
//    for float, and for __half via __half2float.
//  - EmbeddingRefTestNames::GetName builds gtest names as the test_name_str
//    prefix, then an Embed[...] label and an Index[...] label chosen from
//    T::EmbedType and T::IndexType, then the type index i.
//  - The Run* functions are declarations only; their definitions are not in
//    this header. RunForward, RunTranspose and RunBackward take
//    thrust::device_vector buffers; the *Reference variants take
//    thrust::universal_vector buffers.
17 | */ 18 | // clang-format on 19 | 20 | #ifndef UTILS_INCLUDE_EMBEDDING_UTILS_H_ 21 | #define UTILS_INCLUDE_EMBEDDING_UTILS_H_ 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | #include 29 | 30 | #include "absl/log/log.h" 31 | #include "absl/strings/str_format.h" 32 | #include "cuembed/include/embedding_lookup_types.cuh" 33 | #include "utils/include/embedding_allocation.h" 34 | 35 | #define CHECK_CUDA(cmd) \ 36 | do { \ 37 | cudaError_t e = cmd; \ 38 | if (e != cudaSuccess) { \ 39 | LOG(FATAL) << absl::StrFormat("Failed: Cuda error %s:%d '%s'\n", \ 40 | __FILE__, \ 41 | __LINE__, \ 42 | cudaGetErrorString(e)); \ 43 | } \ 44 | } while (0) 45 | 46 | namespace cuembed { 47 | namespace utils { 48 | struct Near { 49 | float tolerance_; 50 | 51 | explicit Near(const float tol) : tolerance_(tol) {} 52 | 53 | __host__ __device__ bool operator()(const float& a, const float& b) const { 54 | return fabsf(a - b) <= tolerance_; 55 | } 56 | 57 | __host__ __device__ bool operator()(const __half& a, const __half& b) const { 58 | return fabsf(__half2float(a) - __half2float(b)) <= tolerance_; 59 | } 60 | }; 61 | 62 | template 63 | class EmbeddingRefTestNames { 64 | public: 65 | template 66 | static std::string GetName(int i) { 67 | typedef typename T::EmbedType EmbedType; 68 | typedef typename T::IndexType IndexType; 69 | 70 | std::string test_name = std::string(test_name_str); 71 | if (std::is_same_v) { 72 | test_name += std::string("Embed[float]"); 73 | } 74 | if (std::is_same_v) { 75 | test_name += std::string("Embed[half]"); 76 | } 77 | if (std::is_same_v) { 78 | test_name += std::string("Index[int32]"); 79 | } 80 | if (std::is_same_v) { 81 | test_name += std::string("Index[int64]"); 82 | } 83 | test_name += std::to_string(i); 84 | return test_name; 85 | } 86 | }; 87 | 88 | template 89 | void RunForward(const utils::AllocationOptions& options, 90 | const thrust::device_vector& embedding, 91 | const thrust::device_vector& indices, 92 | const
thrust::device_vector& offsets, 93 | const thrust::device_vector& weights, 94 | thrust::device_vector* result); 95 | 96 | template 97 | void RunForwardReference(const utils::AllocationOptions& options, 98 | const thrust::universal_vector& embedding, 99 | const thrust::universal_vector& indices, 100 | const thrust::universal_vector& offsets, 101 | const thrust::universal_vector& weights, 102 | thrust::universal_vector* result); 103 | 104 | template 105 | void RunTranspose(const utils::AllocationOptions& options, 106 | const thrust::device_vector& indices, 107 | const thrust::device_vector& offsets, 108 | const thrust::device_vector& weights, 109 | const OffsetT nnz, 110 | thrust::device_vector* transpose_indices, 111 | thrust::device_vector* transpose_remapped_indices, 112 | thrust::device_vector* transpose_sample_ids, 113 | thrust::device_vector* transpose_weights, 114 | thrust::device_vector* sample_ids, 115 | thrust::device_vector* transpose_workspace); 116 | 117 | template 118 | void RunTransposeReference( 119 | const utils::AllocationOptions& options, 120 | const thrust::universal_vector& indices, 121 | const thrust::universal_vector& offsets, 122 | const thrust::universal_vector& weights, 123 | const int nnz, 124 | thrust::universal_vector* transpose_indices, 125 | thrust::universal_vector* transpose_remapped_indices, 126 | thrust::universal_vector* transpose_sample_ids, 127 | thrust::universal_vector* transpose_weights); 128 | 129 | template 130 | void RunBackward( 131 | const utils::AllocationOptions& options, 132 | const thrust::device_vector& grad_y, 133 | const thrust::device_vector& transpose_indices, 134 | const thrust::device_vector& transpose_remapped_indices, 135 | const thrust::device_vector& transpose_sample_ids, 136 | const thrust::device_vector& transpose_weights, 137 | const thrust::device_vector& offsets, 138 | const OffsetT nnz, 139 | const OffsetT num_unique, 140 | thrust::device_vector* grad_embedding, 141 | thrust::device_vector*
inverse_mapping); 142 | 143 | template 144 | void RunBackwardReference( 145 | const utils::AllocationOptions& options, 146 | const thrust::universal_vector& grad_y, 147 | const thrust::universal_vector& transpose_indices, 148 | const thrust::universal_vector& transpose_remapped_indices, 149 | const thrust::universal_vector& transpose_sample_ids, 150 | const thrust::universal_vector& transpose_weights, 151 | const thrust::universal_vector& offsets, 152 | const int nnz, 153 | thrust::universal_vector* grad_embedding, 154 | thrust::universal_vector* inverse_mapping); 155 | 156 | } // namespace utils 157 | } // namespace cuembed 158 | 159 | #endif // UTILS_INCLUDE_EMBEDDING_UTILS_H_ 160 | -------------------------------------------------------------------------------- /utils/include/index_transforms_cpu.hpp: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License.
// index_transforms_cpu.hpp: CPU reference index transforms.
//  - ExtractRowIdsFromFixedCpu: row_ids[b * num_hots + h] = b.
//  - ExtractRowIdsFromCSRCpu: writes b once for every entry in
//    [offsets[b], offsets[b + 1]), packed from row_ids[0].
//  - ExtractRowIdsForConcatCpu: row_ids[i] = i for i in [0, nnz).
//  - ComputeCompressedGradIndicesCpu: remapped_indices[i] is the number of
//    value changes seen in indices[0..i]. This is a dense rank only when equal
//    indices are adjacent (e.g. the array is sorted).
//  - TransposeCpu: pairs cols[i] (index) with rows[i] (sample id) and
//    weights[i] (0 when weights is nullptr), and sorts by (index, sample id,
//    weight). It then writes indices to transpose_rows, sample ids to
//    transpose_cols and, if non-null, weights to transpose_weights.
// NOTE(review): `tuples` is filled with nnz push_backs and could reserve(nnz)
// first.
17 | */ 18 | // clang-format on 19 | 20 | #ifndef UTILS_INCLUDE_INDEX_TRANSFORMS_CPU_HPP_ 21 | #define UTILS_INCLUDE_INDEX_TRANSFORMS_CPU_HPP_ 22 | 23 | #include 24 | 25 | #include 26 | #include 27 | #include 28 | 29 | #include "absl/log/check.h" 30 | #include "absl/log/log.h" 31 | #include "cuembed/include/embedding_lookup_types.cuh" 32 | 33 | namespace cuembed { 34 | 35 | template 36 | void ExtractRowIdsFromFixedCpu(const int batch_size, 37 | const int num_hots, 38 | IndexT* row_ids) { 39 | for (int b = 0; b < batch_size; b++) { 40 | for (int h = 0; h < num_hots; h++) { 41 | row_ids[b * num_hots + h] = static_cast(b); 42 | } 43 | } 44 | } 45 | 46 | template 47 | void ExtractRowIdsFromCSRCpu(const OffsetT* offsets, 48 | const int batch_size, 49 | IndexT* row_ids) { 50 | int cnt = 0; 51 | for (int b = 0; b < batch_size; b++) { 52 | for (OffsetT o = offsets[b]; o < offsets[b + 1]; o++) { 53 | row_ids[cnt] = static_cast(b); 54 | cnt++; 55 | } 56 | } 57 | } 58 | 59 | template 60 | void ExtractRowIdsForConcatCpu(const int nnz, IndexT* row_ids) { 61 | for (int i = 0; i < nnz; i++) { 62 | row_ids[i] = static_cast(i); 63 | } 64 | } 65 | 66 | template 67 | void ComputeCompressedGradIndicesCpu(const IndexT* indices, 68 | const int nnz, 69 | IndexT* remapped_indices) { 70 | IndexT unique_cnt = 0; 71 | for (int64_t cnt = 0; cnt < nnz; cnt++) { 72 | if ((cnt > 0) && (indices[cnt] != indices[cnt - 1])) { 73 | unique_cnt++; 74 | } 75 | remapped_indices[cnt] = unique_cnt; 76 | } 77 | } 78 | 79 | template 80 | struct index_tuple { 81 | IndexT idx; 82 | IndexT sid; 83 | WeightT wt; 84 | }; 85 | 86 | template 87 | void TransposeCpu(const IndexT* rows, 88 | const IndexT* cols, 89 | const WeightT* weights, 90 | const int nnz, 91 | IndexT* transpose_rows, 92 | IndexT* transpose_cols, 93 | WeightT* transpose_weights) { 94 | // Fill indices and weights into vector 95 | std::vector > tuples; 96 | for (int cnt = 0; cnt < nnz; cnt++) { 97 | index_tuple tuple; 98 | tuple.idx = cols[cnt]; 99
| tuple.sid = rows[cnt]; 100 | tuple.wt = (weights != nullptr) ? weights[cnt] : static_cast(0); 101 | tuples.push_back(tuple); 102 | } 103 | 104 | // Sort lexicographically by (index, sample id, weight) 105 | std::sort(tuples.begin(), 106 | tuples.end(), 107 | [](const index_tuple& a, 108 | const index_tuple& b) { 109 | if (a.idx < b.idx) return true; 110 | if (a.idx > b.idx) return false; 111 | if (a.sid < b.sid) return true; 112 | if (a.sid > b.sid) return false; 113 | if (a.wt < b.wt) return true; 114 | return false; 115 | }); 116 | 117 | // Copy to output 118 | for (int64_t cnt = 0; cnt < nnz; cnt++) { 119 | transpose_rows[cnt] = tuples[cnt].idx; 120 | transpose_cols[cnt] = tuples[cnt].sid; 121 | if (transpose_weights != nullptr) { 122 | transpose_weights[cnt] = tuples[cnt].wt; 123 | } 124 | } 125 | } 126 | 127 | } // namespace cuembed 128 | 129 | #endif // UTILS_INCLUDE_INDEX_TRANSFORMS_CPU_HPP_ 130 | -------------------------------------------------------------------------------- /utils/src/datagen.cpp: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License.
17 | */ 18 | // clang-format on 19 | 20 | #include "utils/include/datagen.h" 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | namespace cuembed { 28 | namespace index_generators { 29 | 30 | // Function that "translates" a value drawn uniformly from [0, 1) range into a 31 | // value drawn from a power-law distribution. Power-law distribution is 32 | // characterized by the range of values [min_value, max_value) and the exponent 33 | // value alpha. Assumptions: 34 | // * k_min >= 1 35 | // * alpha > 0 && alpha != 1 36 | // See the accompanying derivation.jpg file for a derivation of the equation 37 | // used in this function. TODO: check if max_value can really be generated or if 38 | // returned values are in [min_value, max_value) range. 39 | template 40 | float translateToPowerLaw(const Type min_value, 41 | const Type max_value, 42 | const Type alpha, 43 | const Type random_uniform_value) { 44 | const Type gamma = 1 - alpha; 45 | Type y = pow( 46 | random_uniform_value * (pow(max_value, gamma) - pow(min_value, gamma)) + 47 | pow(min_value, gamma), 48 | 1.0 / gamma); 49 | return y; 50 | } 51 | 52 | template 53 | FeatureGenerator::FeatureGenerator(const IndexType num_categories, 54 | const int num_hot, 55 | const bool shuffle, 56 | const bool permute) 57 | : num_categories_(num_categories), 58 | num_hot_(num_hot), 59 | shuffle_(shuffle), 60 | permute_(permute) { 61 | // 0 is reserved. Need at least one additional category to generate indices 62 | // from. 
63 | assert(this->num_categories_ > 1); 64 | if (permute) { 65 | this->permutation_.resize(num_categories + 1); 66 | this->inverse_permutation_.resize(num_categories + 1); 67 | std::iota(this->permutation_.begin(), this->permutation_.end(), 0); 68 | std::random_shuffle(this->permutation_.begin(), this->permutation_.end()); 69 | 70 | for (IndexType i = 0; i < num_categories + 1; ++i) { 71 | this->inverse_permutation_[this->permutation_[i]] = i; 72 | } 73 | } 74 | } 75 | 76 | template 77 | IndexType FeatureGenerator::getPermutedIndex(int index) const { 78 | if (this->permute_) { 79 | return this->permutation_[index]; 80 | } else { 81 | return index; 82 | } 83 | } 84 | 85 | template 86 | std::vector FeatureGenerator::getCategoryIndices() { 87 | // A set created to track already used indices. 88 | std::set used_indices; 89 | while (used_indices.size() < this->getNumHot()) { 90 | used_indices.insert(this->getPermutedIndex(this->generateIndex())); 91 | } 92 | 93 | std::vector indices; 94 | for (const auto& x : used_indices) { 95 | indices.push_back(x); 96 | } 97 | 98 | if (this->shuffle_) { 99 | std::random_shuffle(indices.begin(), indices.end()); 100 | } 101 | 102 | return indices; 103 | } 104 | 105 | template 106 | PowerLawFeatureGenerator::PowerLawFeatureGenerator( 107 | const IndexType num_categories, 108 | const int num_hot, 109 | const double alpha, 110 | const bool shuffle, 111 | const bool permute, 112 | const PowerLawType type) 113 | : FeatureGenerator(num_categories, num_hot, shuffle, permute), 114 | alpha_(alpha), 115 | type_(type) { 116 | distribution_.reset(new std::uniform_real_distribution(0., 1.)); 117 | } 118 | 119 | template 120 | IndexType PowerLawFeatureGenerator::generateIndex() { 121 | const double x = (*distribution_)(generator_); 122 | IndexType y = -1; 123 | 124 | // translateToPowerLaw(1., num_categories + 1, *, *) generates an index 125 | // within range [1, num_categories + 1). Then cast to IndexType to range [1, 126 | // num_categories]. 
127 | y = IndexType(translateToPowerLaw( 128 | 1., static_cast(this->num_categories_ + 1), alpha_, x)); 129 | 130 | return y; 131 | } // namespace index_generators 132 | 133 | template class FeatureGenerator; 134 | template class FeatureGenerator; 135 | template class PowerLawFeatureGenerator; 136 | template class PowerLawFeatureGenerator; 137 | 138 | } // namespace index_generators 139 | } // namespace cuembed 140 | -------------------------------------------------------------------------------- /utils/src/embedding_allocation.cu: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | // clang-format on 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | #include "absl/log/check.h" 28 | #include "cuembed/include/index_transforms.cuh" 29 | #include "utils/include/datagen.h" 30 | #include "utils/include/embedding_allocation.h" 31 | #include "utils/include/embedding_utils.h" 32 | 33 | namespace cuembed { 34 | 35 | namespace utils { 36 | AllocationOptions& AllocationOptions::num_categories(int32_t num_categories) { 37 | num_categories_ = num_categories; 38 | return *this; 39 | } 40 | 41 | AllocationOptions& AllocationOptions::batch_size(int32_t batch_size) { 42 | batch_size_ = batch_size; 43 | return *this; 44 | } 45 | 46 | AllocationOptions& AllocationOptions::hotness(int32_t hotness) { 47 | hotness_ = hotness; 48 | return *this; 49 | } 50 | 51 | AllocationOptions& AllocationOptions::alpha(float alpha) { 52 | alpha_ = alpha; 53 | return *this; 54 | } 55 | 56 | AllocationOptions& AllocationOptions::combine_mode(CombineMode type) { 57 | combine_mode_ = type; 58 | return *this; 59 | } 60 | 61 | AllocationOptions& AllocationOptions::embed_width(int32_t embed_width) { 62 | embed_width_ = embed_width; 63 | return *this; 64 | } 65 | 66 | AllocationOptions& AllocationOptions::permute_indices(bool permute_indices) { 67 | permute_indices_ = permute_indices; 68 | return *this; 69 | } 70 | 71 | AllocationOptions& AllocationOptions::shuffle_indices(int32_t shuffle_indices) { 72 | shuffle_indices_ = shuffle_indices; 73 | return *this; 74 | } 75 | 76 | AllocationOptions& AllocationOptions::is_csr(bool is_csr) { 77 | is_csr_ = is_csr; 78 | return *this; 79 | } 80 | 81 | AllocationOptions& AllocationOptions::is_weighted(bool is_weighted) { 82 | is_weighted_ = is_weighted; 83 | return *this; 84 | } 85 | 86 | AllocationOptions& AllocationOptions::compressed_grad(bool compressed_grad) { 87 | compressed_grad_ = compressed_grad; 88 | return *this; 89 | } 90 | 91 | AllocationOptions& 
AllocationOptions::skip_grad_init(bool skip_grad_init) { 92 | skip_grad_init_ = skip_grad_init; 93 | return *this; 94 | } 95 | 96 | template 101 | void AllocateForward(const AllocationOptions& options, 102 | thrust::universal_vector* embedding, 103 | thrust::universal_vector* indices, 104 | thrust::universal_vector* offsets, 105 | thrust::universal_vector* weights, 106 | thrust::universal_vector* result) { 107 | CHECK(options.num_categories() > 0 && options.batch_size() > 0 && 108 | options.hotness() > 0 && options.embed_width() > 0); 109 | embedding->resize(((int64_t)options.num_categories()) * options.embed_width(), 110 | 0); 111 | 112 | // Fill embedding with random values. 113 | std::default_random_engine rng(123456); 114 | std::uniform_real_distribution dist(-1, 1); 115 | std::generate( 116 | embedding->begin(), embedding->end(), [&] { return (InputT)dist(rng); }); 117 | 118 | if (options.combine_mode() != CombineMode::kConcat) { 119 | result->resize(((int64_t)options.batch_size()) * options.embed_width(), 0); 120 | } else { 121 | result->resize(((int64_t)options.batch_size()) * options.embed_width() * 122 | options.hotness(), 123 | 0); 124 | } 125 | 126 | // Generate offsets for CSR representation. Each batch may lookup a random 127 | // number of values (maximum num_hotness). 128 | offsets->resize(options.batch_size() + 1); 129 | CHECK_GT(options.batch_size(), 0); 130 | (*offsets)[0] = 0; // Starting value 131 | { 132 | std::uniform_int_distribution<> distrib(0, options.hotness()); 133 | for (int i = 0; i < options.batch_size(); ++i) { 134 | (*offsets)[i + 1] = (*offsets)[i] + distrib(rng); 135 | } 136 | } 137 | // Generate lookup indices. Generate num_hotness of indices for each sample. 138 | // Copy the first hotness_for_sample indices into generated indices. 
139 | auto generator = index_generators::PowerLawFeatureGenerator( 140 | options.num_categories() - 1, 141 | options.hotness(), 142 | options.alpha(), 143 | options.shuffle_indices(), 144 | options.permute_indices(), 145 | index_generators::PowerLawType::kPsx); 146 | 147 | indices->clear(); 148 | for (int i = 0; i < options.batch_size(); i++) { 149 | auto generated_idx = generator.getCategoryIndices(); 150 | CHECK_EQ(generated_idx.size(), options.hotness()); 151 | int hotness_for_sample = options.hotness(); 152 | if (options.is_csr()) { 153 | hotness_for_sample = (*offsets)[i + 1] - (*offsets)[i]; 154 | } 155 | indices->insert(indices->end(), 156 | generated_idx.begin(), 157 | generated_idx.begin() + hotness_for_sample); 158 | } 159 | 160 | // Generate weights. Weights are either 0.5f or 0.25f for easier correctness 161 | // checking. 162 | { 163 | weights->resize(indices->size()); 164 | std::bernoulli_distribution distrib(0.5); 165 | for (size_t i = 0; i < weights->size(); i++) { 166 | (*weights)[i] = distrib(rng) ? 
0.5f : 0.25f; 167 | } 168 | } 169 | } 170 | 171 | template 172 | void AllocateTranspose( 173 | const AllocationOptions& options, 174 | const int nnz, 175 | thrust::universal_vector* transpose_indices, 176 | thrust::universal_vector* transpose_remapped_indices, 177 | thrust::universal_vector* transpose_sample_ids, 178 | thrust::universal_vector* transpose_weights, 179 | thrust::universal_vector* sample_ids, 180 | thrust::universal_vector* transpose_workspace) { 181 | transpose_indices->resize(nnz, 0); 182 | transpose_remapped_indices->resize(nnz, 0); 183 | transpose_sample_ids->resize(nnz, 0); 184 | transpose_weights->resize(nnz, 0); 185 | sample_ids->resize(nnz, 0); 186 | 187 | // Allocate scratch space 188 | size_t lwork_transpose = 0; 189 | size_t lwork_compressed_grad = 0; 190 | thrust::universal_vector tmp_sample_ids((int64_t)nnz); 191 | thrust::universal_vector tmp_indices((int64_t)nnz); 192 | thrust::universal_vector tmp_weights((int64_t)nnz); 193 | const WeightT* weight_ptr = nullptr; 194 | WeightT* transpose_weight_ptr = nullptr; 195 | if (options.is_weighted()) { 196 | weight_ptr = tmp_weights.data().get(); 197 | transpose_weight_ptr = transpose_weights->data().get(); 198 | } 199 | Transpose(tmp_sample_ids.data().get(), 200 | tmp_indices.data().get(), 201 | weight_ptr, 202 | nnz, 203 | transpose_indices->data().get(), 204 | transpose_sample_ids->data().get(), 205 | transpose_weight_ptr, 206 | nullptr, 207 | &lwork_transpose); 208 | 209 | if (options.compressed_grad()) { 210 | ComputeCompressedGradIndices( 211 | transpose_indices->data().get(), 212 | nnz, 213 | transpose_remapped_indices->data().get(), 214 | nullptr, 215 | &lwork_compressed_grad); 216 | } 217 | 218 | transpose_workspace->resize(std::max(lwork_transpose, lwork_compressed_grad)); 219 | } 220 | 221 | template 222 | void AllocateBackward(const AllocationOptions& options, 223 | thrust::universal_vector* grad_y, 224 | thrust::universal_vector* grad_embedding, 225 | thrust::universal_vector* 
inverse_mapping, 226 | const int num_unique) { 227 | if (options.combine_mode() != CombineMode::kConcat) { 228 | grad_y->resize(((int64_t)options.batch_size()) * options.embed_width(), 0); 229 | } else { 230 | grad_y->resize(((int64_t)options.batch_size()) * options.embed_width() * 231 | options.hotness(), 232 | 0); 233 | } 234 | std::default_random_engine rng(654321); 235 | std::uniform_int_distribution dist(-10, 10); 236 | std::generate( 237 | grad_y->begin(), grad_y->end(), [&] { return (GradT)dist(rng); }); 238 | 239 | if (options.compressed_grad()) { 240 | grad_embedding->resize((int64_t)num_unique * options.embed_width(), 0); 241 | inverse_mapping->resize(num_unique, 0); 242 | } else { 243 | grad_embedding->resize( 244 | ((int64_t)options.num_categories()) * options.embed_width(), 0); 245 | inverse_mapping->resize(0, 0); 246 | } 247 | } 248 | 249 | template 255 | void AllocateHost(const AllocationOptions& options, 256 | UniversalEmbeddingAllocation* allocation, 262 | bool forward_only) { 263 | AllocateForward( 264 | options, 265 | &allocation->embedding, 266 | &allocation->indices, 267 | &allocation->offsets, 268 | &allocation->weights, 269 | &allocation->result); 270 | 271 | if (forward_only) { 272 | return; 273 | } 274 | 275 | int nnz = allocation->indices.size(); 276 | AllocateTranspose(options, 277 | nnz, 278 | &allocation->transpose_indices, 279 | &allocation->transpose_remapped_indices, 280 | &allocation->transpose_sample_ids, 281 | &allocation->transpose_weights, 282 | &allocation->sample_ids, 283 | &allocation->transpose_workspace); 284 | 285 | // Compute num_unique 286 | thrust::device_vector indices_copy = allocation->indices; 287 | 288 | // Sort the copy of the indices 289 | thrust::sort(indices_copy.begin(), indices_copy.end()); 290 | 291 | // Get the end of the unique range 292 | auto unique_end = thrust::unique(indices_copy.begin(), indices_copy.end()); 293 | 294 | // Calculate the number of unique indices 295 | int num_unique = unique_end - 
indices_copy.begin(); 296 | 297 | AllocateBackward(options, 298 | &allocation->grad_y, 299 | &allocation->grad_embedding, 300 | &allocation->inverse_mapping, 301 | num_unique); 302 | } 303 | 304 | template 310 | void AllocateDevice( 311 | const AllocationOptions& options, 312 | const UniversalEmbeddingAllocation& u_a, 318 | DeviceEmbeddingAllocation* 319 | d_a, 320 | bool forward_only) { 321 | // Resize the device vectors to match the universal vectors 322 | d_a->embedding.resize(u_a.embedding.size()); 323 | d_a->indices.resize(u_a.indices.size()); 324 | d_a->offsets.resize(u_a.offsets.size()); 325 | d_a->weights.resize(u_a.weights.size()); 326 | d_a->result.resize(u_a.result.size()); 327 | if (!forward_only) { 328 | d_a->transpose_indices.resize(u_a.transpose_indices.size()); 329 | d_a->transpose_remapped_indices.resize( 330 | u_a.transpose_remapped_indices.size()); 331 | d_a->transpose_sample_ids.resize(u_a.transpose_sample_ids.size()); 332 | d_a->transpose_weights.resize(u_a.transpose_weights.size()); 333 | d_a->sample_ids.resize(u_a.sample_ids.size()); 334 | d_a->transpose_workspace.resize(u_a.transpose_workspace.size()); 335 | d_a->grad_y.resize(u_a.grad_y.size()); 336 | d_a->grad_embedding.resize(u_a.grad_embedding.size()); 337 | d_a->inverse_mapping.resize(u_a.inverse_mapping.size()); 338 | } 339 | 340 | // Copy input data from universal vectors to device vectors 341 | thrust::copy( 342 | u_a.embedding.begin(), u_a.embedding.end(), d_a->embedding.begin()); 343 | thrust::copy(u_a.indices.begin(), u_a.indices.end(), d_a->indices.begin()); 344 | thrust::copy(u_a.offsets.begin(), u_a.offsets.end(), d_a->offsets.begin()); 345 | thrust::copy(u_a.weights.begin(), u_a.weights.end(), d_a->weights.begin()); 346 | if (!forward_only) { 347 | thrust::copy(u_a.grad_y.begin(), u_a.grad_y.end(), d_a->grad_y.begin()); 348 | } 349 | } 350 | 351 | #define ALLOCATE_TEMPLATE(InputT, OutputT, IndexT, WeightT, OffsetT, GradT) \ 352 | template void \ 353 | AllocateHost( \ 354 | 
const AllocationOptions& options, \ 355 | UniversalEmbeddingAllocation* u_a, \ 361 | bool forward_only); \ 362 | template void \ 363 | AllocateDevice( \ 364 | const AllocationOptions& options, \ 365 | const UniversalEmbeddingAllocation& u_a, \ 371 | DeviceEmbeddingAllocation* d_a, \ 377 | bool forward_only); 378 | 379 | ALLOCATE_TEMPLATE(float, float, int32_t, float, int, float) 380 | ALLOCATE_TEMPLATE(float, float, int64_t, float, int, float) 381 | ALLOCATE_TEMPLATE(__half, __half, int32_t, __half, int, __half) 382 | ALLOCATE_TEMPLATE(__half, __half, int64_t, __half, int, __half) 383 | ALLOCATE_TEMPLATE(__half, __half, int32_t, __half, int, float) 384 | ALLOCATE_TEMPLATE(__half, __half, int64_t, __half, int, float) 385 | 386 | } // namespace utils 387 | 388 | } // namespace cuembed 389 | -------------------------------------------------------------------------------- /utils/src/embedding_cpu.cu: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | // clang-format on 19 | 20 | #include "absl/log/check.h" 21 | #include "utils/include/embedding_allocation.h" 22 | #include "utils/include/embedding_utils.h" 23 | 24 | // CPU reference implementations 25 | #include "utils/include/embedding_lookup_cpu.hpp" 26 | #include "utils/include/index_transforms_cpu.hpp" 27 | 28 | namespace cuembed { 29 | 30 | namespace utils { 31 | 32 | template 33 | void RunForwardReference(const utils::AllocationOptions& options, 34 | const thrust::universal_vector& embedding, 35 | const thrust::universal_vector& indices, 36 | const thrust::universal_vector& offsets, 37 | const thrust::universal_vector& weights, 38 | thrust::universal_vector* result) { 39 | const OffsetT* offsets_ptr = nullptr; 40 | int hotness = options.hotness(); 41 | if (options.is_csr()) { 42 | offsets_ptr = offsets.data().get(); 43 | hotness = 0; 44 | } 45 | 46 | const ElemT* weight_ptr = nullptr; 47 | if (options.is_weighted()) { 48 | weight_ptr = weights.data().get(); 49 | } 50 | using InputT = ElemT; 51 | using OutputT = ElemT; 52 | EmbeddingForwardCpu( 53 | embedding.data().get(), 54 | options.embed_width(), 55 | options.batch_size(), 56 | hotness, 57 | indices.data().get(), 58 | offsets_ptr, 59 | weight_ptr, 60 | result->data().get(), 61 | options.combine_mode()); 62 | } 63 | 64 | #define RUN_FORWARD_TEMPLATE(ElemT, IndexT, OffsetT, fp16_math) \ 65 | template void RunForwardReference( \ 66 | const utils::AllocationOptions& options, \ 67 | const thrust::universal_vector& embedding, \ 68 | const thrust::universal_vector& indices, \ 69 | const thrust::universal_vector& offsets, \ 70 | const thrust::universal_vector& weights, \ 71 | thrust::universal_vector* result); 72 | 73 | RUN_FORWARD_TEMPLATE(float, int32_t, int, false); 74 | RUN_FORWARD_TEMPLATE(float, int64_t, int, false); 75 | RUN_FORWARD_TEMPLATE(__half, int32_t, int, false); 76 | RUN_FORWARD_TEMPLATE(__half, int64_t, int, false); 77 | RUN_FORWARD_TEMPLATE(float, int32_t, int, true); 78 | 
RUN_FORWARD_TEMPLATE(float, int64_t, int, true); 79 | RUN_FORWARD_TEMPLATE(__half, int32_t, int, true); 80 | RUN_FORWARD_TEMPLATE(__half, int64_t, int, true); 81 | 82 | #undef RUN_FORWARD_TEMPLATE 83 | 84 | template 85 | void RunTransposeReference( 86 | const utils::AllocationOptions& options, 87 | const thrust::universal_vector& indices, 88 | const thrust::universal_vector& offsets, 89 | const thrust::universal_vector& weights, 90 | const int nnz, 91 | thrust::universal_vector* transpose_indices, 92 | thrust::universal_vector* transpose_remapped_indices, 93 | thrust::universal_vector* transpose_sample_ids, 94 | thrust::universal_vector* transpose_weights) { 95 | // Extract rows 96 | thrust::universal_vector sample_ids(indices.size(), 0); 97 | if (options.combine_mode() == CombineMode::kConcat) { 98 | ExtractRowIdsForConcatCpu(nnz, sample_ids.data().get()); 99 | } else if (options.is_csr()) { 100 | ExtractRowIdsFromCSRCpu( 101 | offsets.data().get(), options.batch_size(), sample_ids.data().get()); 102 | } else { 103 | ExtractRowIdsFromFixedCpu( 104 | options.batch_size(), options.hotness(), sample_ids.data().get()); 105 | } 106 | 107 | const WeightT* weight_ptr = nullptr; 108 | WeightT* transpose_weight_ptr = nullptr; 109 | if (options.is_weighted()) { 110 | weight_ptr = weights.data().get(); 111 | transpose_weight_ptr = transpose_weights->data().get(); 112 | } 113 | 114 | TransposeCpu(sample_ids.data().get(), 115 | indices.data().get(), 116 | weight_ptr, 117 | nnz, 118 | transpose_indices->data().get(), 119 | transpose_sample_ids->data().get(), 120 | transpose_weight_ptr); 121 | 122 | // Compute sparse indices 123 | if (options.compressed_grad()) { 124 | ComputeCompressedGradIndicesCpu( 125 | transpose_indices->data().get(), 126 | nnz, 127 | transpose_remapped_indices->data().get()); 128 | } 129 | } 130 | 131 | #define RUN_TRANSPOSE_TEMPLATE(IndexT, OffsetT, WeightT) \ 132 | template void RunTransposeReference( \ 133 | const utils::AllocationOptions& options, \ 
134 | const thrust::universal_vector& indices, \ 135 | const thrust::universal_vector& offsets, \ 136 | const thrust::universal_vector& weights, \ 137 | const int nnz, \ 138 | thrust::universal_vector* transpose_indices, \ 139 | thrust::universal_vector* transpose_remapped_indices, \ 140 | thrust::universal_vector* transpose_sample_ids, \ 141 | thrust::universal_vector* transpose_weights); 142 | 143 | RUN_TRANSPOSE_TEMPLATE(int32_t, int, float); 144 | RUN_TRANSPOSE_TEMPLATE(int64_t, int, float); 145 | RUN_TRANSPOSE_TEMPLATE(int32_t, int, __half); 146 | RUN_TRANSPOSE_TEMPLATE(int64_t, int, __half); 147 | 148 | #undef RUN_TRANSPOSE_TEMPLATE 149 | 150 | template 151 | void RunBackwardReference( 152 | const utils::AllocationOptions& options, 153 | const thrust::universal_vector& grad_y, 154 | const thrust::universal_vector& transpose_indices, 155 | const thrust::universal_vector& transpose_remapped_indices, 156 | const thrust::universal_vector& transpose_sample_ids, 157 | const thrust::universal_vector& transpose_weights, 158 | const thrust::universal_vector& offsets, 159 | const int nnz, 160 | thrust::universal_vector* grad_embedding, 161 | thrust::universal_vector* inverse_mapping) { 162 | const ElemT* transpose_weight_ptr = nullptr; 163 | if (options.is_weighted()) { 164 | transpose_weight_ptr = transpose_weights.data().get(); 165 | } 166 | const IndexT* transpose_remapped_indices_ptr = nullptr; 167 | IndexT* inverse_mapping_ptr = nullptr; 168 | int num_grad_embedding_rows = options.num_categories(); 169 | if (options.compressed_grad()) { 170 | transpose_remapped_indices_ptr = transpose_remapped_indices.data().get(); 171 | inverse_mapping_ptr = inverse_mapping->data().get(); 172 | num_grad_embedding_rows = transpose_remapped_indices[nnz - 1] + 1; 173 | } 174 | 175 | EmbeddingBackwardCpu(grad_y.data().get(), 176 | options.embed_width(), 177 | num_grad_embedding_rows, 178 | nnz, 179 | transpose_indices.data().get(), 180 | transpose_sample_ids.data().get(), 181 | 
transpose_remapped_indices_ptr, 182 | transpose_weight_ptr, 183 | options.skip_grad_init(), 184 | grad_embedding->data().get(), 185 | inverse_mapping_ptr); 186 | } 187 | 188 | #define RUN_BACKWARD_TEMPLATE(GradT, IndexT, OffsetT) \ 189 | template void RunBackwardReference( \ 190 | const utils::AllocationOptions& options, \ 191 | const thrust::universal_vector& grad_y, \ 192 | const thrust::universal_vector& transpose_indices, \ 193 | const thrust::universal_vector& transpose_remapped_indices, \ 194 | const thrust::universal_vector& transpose_sample_ids, \ 195 | const thrust::universal_vector& transpose_weights, \ 196 | const thrust::universal_vector& offsets, \ 197 | const int nnz, \ 198 | thrust::universal_vector* grad_embedding, \ 199 | thrust::universal_vector* inverse_mapping); 200 | 201 | RUN_BACKWARD_TEMPLATE(float, int32_t, int); 202 | RUN_BACKWARD_TEMPLATE(float, int64_t, int); 203 | RUN_BACKWARD_TEMPLATE(__half, int32_t, int); 204 | RUN_BACKWARD_TEMPLATE(__half, int64_t, int); 205 | 206 | #undef RUN_BACKWARD_TEMPLATE 207 | 208 | } // namespace utils 209 | 210 | } // namespace cuembed 211 | -------------------------------------------------------------------------------- /utils/src/embedding_gpu_backward.cu: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | // clang-format on 19 | 20 | #include 21 | 22 | #include "absl/log/check.h" 23 | #include "cuembed/include/embedding_lookup.cuh" 24 | #include "utils/include/embedding_allocation.h" 25 | #include "utils/include/embedding_utils.h" 26 | 27 | namespace cuembed { 28 | 29 | namespace utils { 30 | 31 | template 32 | void RunBackward( 33 | const utils::AllocationOptions& options, 34 | const thrust::device_vector& grad_y, 35 | const thrust::device_vector& transpose_indices, 36 | const thrust::device_vector& transpose_remapped_indices, 37 | const thrust::device_vector& transpose_sample_ids, 38 | const thrust::device_vector& transpose_weights, 39 | const thrust::device_vector& offsets, 40 | const OffsetT nnz, 41 | const OffsetT num_unique, 42 | thrust::device_vector* grad_embedding, 43 | thrust::device_vector* inverse_mapping) { 44 | const IndexT* transpose_remapped_indices_ptr = nullptr; 45 | IndexT* inverse_mapping_ptr = nullptr; 46 | if (options.compressed_grad()) { 47 | transpose_remapped_indices_ptr = transpose_remapped_indices.data().get(); 48 | inverse_mapping_ptr = inverse_mapping->data().get(); 49 | } 50 | 51 | const ElemT* transpose_weights_ptr = nullptr; 52 | if (options.is_weighted()) { 53 | transpose_weights_ptr = transpose_weights.data().get(); 54 | } 55 | 56 | cuembed::EmbeddingBackward( 57 | grad_y.data().get(), 58 | options.embed_width(), 59 | options.compressed_grad() ? 
num_unique : options.num_categories(), 60 | nnz, 61 | transpose_indices.data().get(), 62 | transpose_sample_ids.data().get(), 63 | transpose_remapped_indices_ptr, 64 | transpose_weights_ptr, 65 | options.skip_grad_init(), 66 | grad_embedding->data().get(), 67 | inverse_mapping_ptr); 68 | } 69 | 70 | #define RUN_BACKWARD_TEMPLATE(GradT, IndexT, OffsetT) \ 71 | template void RunBackward( \ 72 | const utils::AllocationOptions& options, \ 73 | const thrust::device_vector& grad_y, \ 74 | const thrust::device_vector& transpose_indices, \ 75 | const thrust::device_vector& transpose_remapped_indices, \ 76 | const thrust::device_vector& transpose_sample_ids, \ 77 | const thrust::device_vector& transpose_weights, \ 78 | const thrust::device_vector& offsets, \ 79 | const OffsetT nnz, \ 80 | const OffsetT num_unique, \ 81 | thrust::device_vector* grad_embedding, \ 82 | thrust::device_vector* inverse_mapping); 83 | 84 | RUN_BACKWARD_TEMPLATE(float, int32_t, int); 85 | RUN_BACKWARD_TEMPLATE(float, int64_t, int); 86 | RUN_BACKWARD_TEMPLATE(__half, int32_t, int); 87 | RUN_BACKWARD_TEMPLATE(__half, int64_t, int); 88 | 89 | #undef RUN_BACKWARD_TEMPLATE 90 | 91 | } // namespace utils 92 | 93 | } // namespace cuembed 94 | -------------------------------------------------------------------------------- /utils/src/embedding_gpu_forward.cu: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 
8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | // clang-format on 19 | 20 | #include "absl/log/check.h" 21 | #include "cuembed/include/embedding_lookup.cuh" 22 | #include "utils/include/embedding_allocation.h" 23 | #include "utils/include/embedding_utils.h" 24 | 25 | namespace cuembed { 26 | 27 | namespace utils { 28 | 29 | template 30 | void RunForward(const utils::AllocationOptions& options, 31 | const thrust::device_vector& embedding, 32 | const thrust::device_vector& indices, 33 | const thrust::device_vector& offsets, 34 | const thrust::device_vector& weights, 35 | thrust::device_vector* result) { 36 | const int* offsets_ptr = nullptr; 37 | int hotness = options.hotness(); 38 | if (options.is_csr()) { 39 | offsets_ptr = offsets.data().get(); 40 | hotness = 0; 41 | } 42 | const ElemT* weight_ptr = nullptr; 43 | if (options.is_weighted()) { 44 | weight_ptr = weights.data().get(); 45 | } 46 | using InputT = ElemT; 47 | using OutputT = ElemT; 48 | EmbeddingForward( 49 | embedding.data().get(), 50 | options.embed_width(), 51 | indices.data().get(), 52 | offsets_ptr, 53 | weight_ptr, 54 | options.batch_size(), 55 | hotness, 56 | options.combine_mode(), 57 | result->data().get()); 58 | } 59 | 60 | #define RUN_FORWARD_TEMPLATE(ElemT, IndexT, OffsetT, fp16_math) \ 61 | template void RunForward( \ 62 | const utils::AllocationOptions& options, \ 63 | const thrust::device_vector& embedding, \ 64 | const thrust::device_vector& indices, \ 65 | const thrust::device_vector& offsets, \ 66 | const thrust::device_vector& weights, \ 67 | thrust::device_vector* result); 68 
| /* Explicit instantiations emitted by this file: ElemT in {float, __half}, IndexT in {int32_t, int64_t}, fp16_math in {false, true}, always with OffsetT = int (8 combinations). Any other combination is not instantiated here. */ 69 | RUN_FORWARD_TEMPLATE(float, int32_t, int, false); 70 | RUN_FORWARD_TEMPLATE(float, int64_t, int, false); 71 | RUN_FORWARD_TEMPLATE(__half, int32_t, int, false); 72 | RUN_FORWARD_TEMPLATE(__half, int64_t, int, false); 73 | RUN_FORWARD_TEMPLATE(float, int32_t, int, true); 74 | RUN_FORWARD_TEMPLATE(float, int64_t, int, true); 75 | RUN_FORWARD_TEMPLATE(__half, int32_t, int, true); 76 | RUN_FORWARD_TEMPLATE(__half, int64_t, int, true); 77 | /* The helper macro is only needed for the instantiations above. */ 78 | #undef RUN_FORWARD_TEMPLATE 79 | 80 | } // namespace utils 81 | 82 | } // namespace cuembed 83 | -------------------------------------------------------------------------------- /utils/src/embedding_gpu_transpose.cu: -------------------------------------------------------------------------------- 1 | // clang-format off 2 | /* 3 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | * SPDX-License-Identifier: Apache-2.0 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License.
17 | */ 18 | // clang-format on 19 | /* NOTE(review): text inside angle brackets appears to have been stripped from this copy: the #include at line 20 has no header name, the template keyword at line 31 has no parameter list, the calls to the index_transforms.cuh helpers have no template arguments, and thrust::device_vector has no element type. Restore these from the original source before building. */ 20 | #include 21 | 22 | #include "absl/log/check.h" 23 | #include "cuembed/include/index_transforms.cuh" 24 | #include "utils/include/embedding_allocation.h" 25 | #include "utils/include/embedding_utils.h" 26 | 27 | namespace cuembed { 28 | 29 | namespace utils { 30 | /* RunTranspose: drives the index_transforms.cuh helpers for one input configuration. (1) sample_ids is filled by one of the ExtractRowIds* helpers, chosen by the input layout (presumably one row id per index entry). (2) Transpose is run over (sample_ids, indices, weights, nnz) and writes transpose_indices, transpose_sample_ids and, when options.is_weighted(), transpose_weights. (3) When options.compressed_grad() is set, ComputeCompressedGradIndices fills transpose_remapped_indices from transpose_indices. nnz is the entry count handed to these helpers. transpose_workspace is scratch memory shared by Transpose and ComputeCompressedGradIndices; its size() is passed as the workspace size. Output vectors are assumed to be pre-sized, presumably to nnz (TODO: confirm). See index_transforms.cuh for the exact ordering Transpose produces. */ 31 | template 32 | void RunTranspose(const utils::AllocationOptions& options, 33 | const thrust::device_vector& indices, 34 | const thrust::device_vector& offsets, 35 | const thrust::device_vector& weights, 36 | const OffsetT nnz, 37 | thrust::device_vector* transpose_indices, 38 | thrust::device_vector* transpose_remapped_indices, 39 | thrust::device_vector* transpose_sample_ids, 40 | thrust::device_vector* transpose_weights, 41 | thrust::device_vector* sample_ids, 42 | thrust::device_vector* transpose_workspace) { 43 | /* Step 1: derive per-entry row ids from the input layout (concat mode, CSR offsets, or fixed hotness). */ if (options.combine_mode() == CombineMode::kConcat) { 44 | ExtractRowIdsForConcat(nnz, sample_ids->data().get()); 45 | } else if (options.is_csr()) { 46 | ExtractRowIdsFromCSR( 47 | offsets.data().get(), options.batch_size(), sample_ids->data().get()); 48 | } else { 49 | ExtractRowIdsFromFixed( 50 | options.batch_size(), options.hotness(), sample_ids->data().get()); 51 | } 52 | /* Both weight pointers stay nullptr for unweighted input, so transpose_weights is only dereferenced when options.is_weighted(). */ 53 | const WeightT* weight_ptr = nullptr; 54 | WeightT* transpose_weight_ptr = nullptr; 55 | if (options.is_weighted()) { 56 | weight_ptr = weights.data().get(); 57 | transpose_weight_ptr = transpose_weights->data().get(); 58 | } 59 | /* Step 2: transpose. lwork starts as the full workspace size and is passed by pointer. */ 60 | size_t lwork = transpose_workspace->size(); 61 | Transpose(sample_ids->data().get(), 62 | indices.data().get(), 63 | weight_ptr, 64 | nnz, 65 | transpose_indices->data().get(), 66 | transpose_sample_ids->data().get(), 67 | transpose_weight_ptr, 68 | transpose_workspace->data().get(), 69 | &lwork); 70 | /* Step 3 (optional): remapped indices for compressed gradients, reusing the same workspace. NOTE(review): the same lwork variable was passed to Transpose above. If Transpose writes back through that pointer, this call receives that value instead of transpose_workspace->size(); confirm against index_transforms.cuh. */ 71 | if (options.compressed_grad()) { 72 | ComputeCompressedGradIndices( 73 | transpose_indices->data().get(), 74 | nnz, 75 | transpose_remapped_indices->data().get(), 76 | transpose_workspace->data().get(), 77 | &lwork); 78 | } 79 | } 80 |
/* RUN_TRANSPOSE_TEMPLATE(IndexT, OffsetT, WeightT) expands to an explicit instantiation definition of RunTranspose for one type combination. This makes the definition available to callers in other translation units. */ 81 | #define RUN_TRANSPOSE_TEMPLATE(IndexT, OffsetT, WeightT) \ 82 | template void RunTranspose( \ 83 | const utils::AllocationOptions& options, \ 84 | const thrust::device_vector& indices, \ 85 | const thrust::device_vector& offsets, \ 86 | const thrust::device_vector& weights, \ 87 | const OffsetT nnz, \ 88 | thrust::device_vector* transpose_indices, \ 89 | thrust::device_vector* transpose_remapped_indices, \ 90 | thrust::device_vector* transpose_sample_ids, \ 91 | thrust::device_vector* transpose_weights, \ 92 | thrust::device_vector* sample_ids, \ 93 | thrust::device_vector* transpose_workspace); 94 | /* Explicit instantiations emitted by this file: IndexT in {int32_t, int64_t}, WeightT in {float, __half}, always with OffsetT = int (4 combinations). Any other combination is not instantiated here. */ 95 | RUN_TRANSPOSE_TEMPLATE(int32_t, int, float); 96 | RUN_TRANSPOSE_TEMPLATE(int64_t, int, float); 97 | RUN_TRANSPOSE_TEMPLATE(int32_t, int, __half); 98 | RUN_TRANSPOSE_TEMPLATE(int64_t, int, __half); 99 | /* The helper macro is only needed for the instantiations above. */ 100 | #undef RUN_TRANSPOSE_TEMPLATE 101 | 102 | } // namespace utils 103 | 104 | } // namespace cuembed 105 | --------------------------------------------------------------------------------