├── .clang-format ├── .clang-tidy ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── cmake ├── arch.cmake ├── clang-format.cmake ├── compilation-flags.cmake ├── sources.cmake └── test_x86_64_avx512_ifma.c ├── include ├── internal │ ├── avx512.h │ ├── defs.h │ ├── fast_mul_operators.h │ └── pre_compute.h ├── ntt_avx512_ifma.h ├── ntt_hexl.h ├── ntt_radix4.h ├── ntt_radix4_s390x_vef.h ├── ntt_radix4x4.h ├── ntt_reference.h └── ntt_seal.h ├── src ├── ntt_r2_16_avx512_ifma.c ├── ntt_r4r2_avx512_ifma.c ├── ntt_radix4.c ├── ntt_radix4_avx512_ifma.c ├── ntt_radix4_avx512_ifma_unordered.c ├── ntt_radix4_s390x_vef.c ├── ntt_radix4x4.c └── ntt_reference.c ├── tests ├── bench.c ├── main.c ├── measurements.h ├── pre-commit-script.sh ├── test_cases.h ├── test_correctness.c ├── tests.h └── utils.h └── third_party ├── CMakeLists.txt ├── README.md ├── hexl ├── avx512-util.h ├── fwd-ntt-avx512.c └── ntt-avx512-util.h └── seal └── ntt_seal.c /.clang-format: -------------------------------------------------------------------------------- 1 | AlignAfterOpenBracket: true 2 | AlignConsecutiveMacros: true 3 | AlignConsecutiveAssignments: true 4 | AlignConsecutiveDeclarations: true 5 | AlignEscapedNewlines: Left 6 | AlignTrailingComments: true 7 | AllowAllParametersOfDeclarationOnNextLine: true 8 | AllowAllArgumentsOnNextLine: false 9 | AllowShortCaseLabelsOnASingleLine: true 10 | AllowShortFunctionsOnASingleLine: true 11 | AllowShortIfStatementsOnASingleLine: true 12 | AllowShortLoopsOnASingleLine: true 13 | AlwaysBreakBeforeMultilineStrings: false 14 | AlwaysBreakAfterReturnType: None 15 | BinPackParameters: false 16 | BreakBeforeBraces: Custom 17 | BraceWrapping: 18 | AfterCaseLabel: false 19 | AfterControlStatement: false 20 | AfterEnum: true 21 | AfterExternBlock: false 22 | AfterFunction: true 23 | AfterNamespace: false 24 | AfterStruct: false 25 | AfterUnion: false 26 | BeforeElse: false 27 | SplitEmptyFunction: false 28 | BreakBeforeBinaryOperators: false 29 | ColumnLimit: 82 30 | ContinuationIndentWidth: 2 31 | DerivePointerAlignment: false 32 | IndentCaseLabels: true 33 | IndentPPDirectives: AfterHash 34 | IndentWidth: 2 35 | IndentWrappedFunctionNames: false 36 | MaxEmptyLinesToKeep: 1 37 | NamespaceIndentation: None 38 | PenaltyReturnTypeOnItsOwnLine: 25 39 | PointerAlignment: Right 40 | ReflowComments: true 41 | SpaceAfterCStyleCast: false 42 | SpaceBeforeAssignmentOperators: true 43 | SpaceBeforeParens: Never 44 | SpaceInEmptyParentheses: false 45 | SpacesBeforeTrailingComments: 1 46 | SpacesInContainerLiterals: false 47 | SortIncludes: true 48 | UseTab: Never -------------------------------------------------------------------------------- /.clang-tidy: -------------------------------------------------------------------------------- 1 | # We remove the cert* checks that are related to rand() and srand() 2 | 3 | Checks: '-*, 4 | bugprone-*, 5 | cert-*, 6 | -cert-msc50-cpp, 7 | -cert-msc51-cpp, 8 | -cert-msc30-c, 9 | -cert-msc32-c, 10 | darwin-*, 11 | hicpp-*, 12 | -hicpp-signed-bitwise, 13 | -hicpp-no-assembler, 14 | misc-*, 15 | readability-*' 16 | 17 | WarningsAsErrors: '*' 18 | HeaderFilterRegex: 'B/.*third_party/.*' 19 | FormatStyle: 'file' 20 | CheckOptions: 21 | - key: bugprone-argument-comment.StrictMode 22 | value: '1' 23 | - key: bugprone-argument-comment.CommentBoolLiterals 24 | value: '1' 25 | - key: bugprone-argument-comment.CommentIntegerLiterals 26 | value: '0' 27 | - key: bugprone-argument-comment.CommentFloatLiterals 28 | value: '1' 29 | - key: bugprone-argument-comment.CommentCharacterLiterals 30 | value: '1' 31 | - key: bugprone-argument-comment.CommentUserDefinedLiterals 32 | value: '1' 33 | - key: bugprone-argument-comment.CommentNullPtrs 34 | value: '1' 35 | - key: bugprone-misplaced-widening-cast.CheckImplicitCasts 36 | value: '1' 37 | - key: bugprone-sizeof-expression.WarnOnSizeOfConstant 38 | value: '1' 39 | - key: bugprone-sizeof-expression.WarnOnSizeOfIntegerExpression 40 | value: '1' 41 | - key: bugprone-sizeof-expression.WarnOnSizeOfCompareToConstant 42 | value: '1' 43 | - key: bugprone-suspicious-string-compare.WarnOnImplicitComparison 44 | value: '1' 45 | - key: bugprone-suspicious-string-compare.WarnOnLogicalNotComparison 46 | value: '1' 47 | - key: bugprone-suspicious-string-compare.StringCompareLikeFunctions 48 | value: '1' 49 | - key: google-runtime-int.TypeSufix 50 | value: '_t' 51 | - key: readability-magic-numbers.IgnoredIntegerValues 52 | value: '0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16' 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Object files 5 | *.o 6 | *.ko 7 | *.obj 8 | *.elf 9 | 10 | # Linker output 11 | *.ilk 12 | *.map 13 | *.exp 14 | 15 | # Precompiled Headers 16 | *.gch 17 | *.pch 18 | 19 | # Libraries 20 | *.lib 21 | *.a 22 | *.la 23 | *.lo 24 | 25 | # Shared objects (inc. Windows DLLs) 26 | *.dll 27 | *.so 28 | *.so.* 29 | *.dylib 30 | 31 | # Executables 32 | *.exe 33 | *.out 34 | *.app 35 | *.i*86 36 | *.x86_64 37 | *.hex 38 | 39 | # Debug files 40 | *.dSYM/ 41 | *.su 42 | *.idb 43 | *.pdb 44 | 45 | # Kernel Module Compile Results 46 | *.mod* 47 | *.cmd 48 | .tmp_versions/ 49 | modules.order 50 | Module.symvers 51 | Mkfile.old 52 | dkms.conf 53 | 54 | # CMake compilation dir 55 | build 56 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright IBM Inc. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | cmake_minimum_required(VERSION 3.0.0) 5 | 6 | if(NOT CMAKE_BUILD_TYPE) 7 | set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type" FORCE) 8 | set_property(CACHE CMAKE_BUILD_TYPE PROPERTY 9 | STRINGS "Release" "Debug" "MinSizeRel" "RelWithDebInfo") 10 | endif() 11 | message(STATUS "Build type (CMAKE_BUILD_TYPE): ${CMAKE_BUILD_TYPE}") 12 | 13 | project (ntt-variants C) 14 | 15 | set(INCLUDE_DIR ${PROJECT_SOURCE_DIR}/include) 16 | set(SRC_DIR ${PROJECT_SOURCE_DIR}/src) 17 | set(TESTS_DIR ${PROJECT_SOURCE_DIR}/tests) 18 | set(THIRD_PARTY_DIR ${PROJECT_SOURCE_DIR}/third_party) 19 | 20 | include_directories(${INCLUDE_DIR}) 21 | include_directories(${INCLUDE_DIR}/internal) 22 | 23 | include(cmake/arch.cmake) 24 | 25 | include(cmake/compilation-flags.cmake) 26 | 27 | # Depends on SRC_DIR 28 | # and on arch.cmake 29 | include(cmake/sources.cmake) 30 | 31 | include(cmake/clang-format.cmake) 32 | 33 | add_subdirectory(third_party) 34 | 35 | add_executable(${PROJECT_NAME} 36 | 37 | ${NTT_SOURCES} 38 | ${MAIN_SOURCE} 39 | $ 40 | ) 41 | 42 | set(BENCH ${PROJECT_NAME}-bench) 43 | ADD_EXECUTABLE(${BENCH} 44 | 45 | ${NTT_SOURCES} 46 | ${MAIN_SOURCE} 47 | $ 48 | ) 49 | SET_TARGET_PROPERTIES(${BENCH} PROPERTIES COMPILE_FLAGS "-DTEST_SPEED") 50 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # optimized-number-theoretic-transform-implementations 2 | 3 | This sample code package is an implementation of the Number Theoretic Transform (NTT) 4 | algorithm for the ring R/(X^N + 1) where N=2^m. 5 | 6 | This sample code provides testing binaries but no shared or static libraries. This is because the code is designed to be used for benchmarking purposes only and not in final products. 7 | 8 | ## License 9 | 10 | This project is licensed under the Apache-2.0 License. 11 | 12 | Dependencies 13 | ----- 14 | This package requires 15 | - CMake 3 and above 16 | - A compiler that supports the required C intrinsics (e.g., VMSL on s390x machines or AVX512-IFMA on X86_64 platforms). For example, GCC-10 and Clang-12. 17 | 18 | BUILD 19 | ----- 20 | 21 | To build the directory first create a working directory 22 | ``` 23 | mkdir build 24 | cd build 25 | ``` 26 | 27 | Then, run CMake and compile 28 | ``` 29 | cmake .. 30 | make 31 | ``` 32 | 33 | To run 34 | 35 | `./ntt-variants` 36 | 37 | Additional CMake compilation flags: 38 | - DEBUG - To enable debug prints 39 | 40 | To clean - remove the `build` directory. Note that a "clean" is required prior to compilation with modified flags. 41 | 42 | To format (`clang-format-9` or above is required): 43 | 44 | `make format` 45 | 46 | To use clang-tidy (`clang-tidy-9` is required): 47 | 48 | ``` 49 | CC=clang-12 cmake -DCMAKE_C_CLANG_TIDY="clang-tidy;--format-style=file" .. 50 | make 51 | ``` 52 | 53 | Before committing code, please test it using 54 | `tests/pre-commit-script.sh` 55 | This will run all the sanitizers and also `clang-format` and `clang-tidy` (requires clang-9 to be installed). 56 | 57 | The package was compiled and tested with gcc-10 and clang-12 in 64-bit mode. 58 | Tests were run on a Linux (Ubuntu 20.04.2 LTS) OS on s390x and on ICL x86-64 platforms. 59 | Compilation on other platforms may require some adjustments. 60 | 61 | Performance measurements 62 | ------------------------ 63 | The performance measurements are reported in processor cycles (per single core). The results are obtained using the following methodology. Each measured function was isolated, run 10 times (warm-up), followed by 200 iterations that were clocked and averaged. To minimize the effect of background tasks running on the system, every experiment was repeated 10 times, and the minimum result is reported. 64 | 65 | To run the benchmarking 66 | 67 | `./ntt-variants-bench` 68 | 69 | Testing 70 | ------- 71 | - The library has several fixed test-cases with different values of `q` and `N`. 72 | - The library was run using Address/Undefined-Behaviour sanitizers. 73 | 74 | Notes 75 | ----- 76 | clang-12 achieves much better results than GCC-10. 77 | -------------------------------------------------------------------------------- /cmake/arch.cmake: -------------------------------------------------------------------------------- 1 | # Copyright IBM Inc. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(s390x)$") 5 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DS390X") 6 | set(S390X 1) 7 | elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|amd64|AMD64)$") 8 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DX86_64") 9 | set(X86_64 1) 10 | elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(aarch64|arm64|arm64e)$") 11 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DAARCH64") 12 | set(AARCH64 1) 13 | endif() 14 | 15 | if(X86_64) 16 | # Test AVX512-IFMA 17 | try_run(RUN_RESULT COMPILE_RESULT 18 | "${CMAKE_BINARY_DIR}" "${PROJECT_SOURCE_DIR}/cmake/test_x86_64_avx512_ifma.c" 19 | COMPILE_DEFINITIONS "-march=native -Werror -Wall -Wpedantic" 20 | OUTPUT_VARIABLE OUTPUT 21 | ) 22 | 23 | if(${COMPILE_RESULT} AND (RUN_RESULT EQUAL 0)) 24 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DAVX512_IFMA_SUPPORT") 25 | set(AVX512_IFMA 1) 26 | else() 27 | message(STATUS "The AVX512_IFMA implementation is not supported") 28 | endif() 29 | endif() 30 | -------------------------------------------------------------------------------- /cmake/clang-format.cmake: -------------------------------------------------------------------------------- 1 | # Copyright IBM Inc. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # Some of the definitions in .clang-format require clang-format-9 and above. 5 | find_program(CLANG_FORMAT 6 | NAMES 7 | clang-format-12 8 | clang-format-11 9 | clang-format-10 10 | clang-format-9 11 | clang-format) 12 | 13 | IF(CLANG_FORMAT) 14 | # Get the major version of clang-format 15 | # CLANG_FORMAT_VERSION should be in the format "clang-format version [Major].[Minor].[Patch] " 16 | exec_program(${CLANG_FORMAT} ${CMAKE_CURRENT_SOURCE_DIR} ARGS --version OUTPUT_VARIABLE CLANG_FORMAT_VERSION) 17 | STRING(REGEX REPLACE ".* ([0-9]+)\\.[0-9]+\\.[0-9]+.*" "\\1" CLANG_FORMAT_MAJOR_VERSION ${CLANG_FORMAT_VERSION}) 18 | 19 | message(STATUS "Found version ${CLANG_FORMAT_MAJOR_VERSION} of clang-format.") 20 | if(${CLANG_FORMAT_MAJOR_VERSION} LESS "9") 21 | message(STATUS "To run the format target clang-format version >= 9 is required.") 22 | else() 23 | set(CLANG_FORMAT_FILE_TYPES ${CLANG_FORMAT_FILE_TYPES} ) 24 | file(GLOB_RECURSE CF_FILES1 ${SRC_DIR}/*.c) 25 | file(GLOB_RECURSE CF_FILES2 ${INCLUDE_DIR}/*.h ${INCLUDE_DIR}/internal/*.h) 26 | file(GLOB_RECURSE CF_FILES3 ${TESTS_DIR}/*.c ${TESTS_DIR}/*.h) 27 | set(FILES_TO_FORMAT "${CF_FILES1}" "${CF_FILES2}" "${CF_FILES3}") 28 | 29 | ADD_CUSTOM_TARGET( 30 | format 31 | COMMAND ${CLANG_FORMAT} -i -style=file ${FILES_TO_FORMAT} 32 | COMMENT "Clang-formatting all (*.c/*.h) source files" 33 | ) 34 | endif() 35 | else() 36 | message(STATUS "Did not find clang-format.") 37 | endif() 38 | -------------------------------------------------------------------------------- /cmake/compilation-flags.cmake: -------------------------------------------------------------------------------- 1 | # Copyright IBM Inc. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | if(CMAKE_C_COMPILER_ID MATCHES "Clang") 5 | set(CLANG 1) 6 | endif() 7 | 8 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ggdb -O3 -fPIC") 9 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fvisibility=hidden -Wall -Wextra -Werror -Wpedantic") 10 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wunused -Wcomment -Wchar-subscripts -Wuninitialized -Wshadow") 11 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wwrite-strings -Wformat-security -Wcast-qual -Wunused-result") 12 | 13 | if(S390X) 14 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=z14 -mvx -mzvector") 15 | if (NOT CLANG) 16 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mbranch-cost=3") 17 | endif() 18 | elseif(X86_64) 19 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native -mno-red-zone") 20 | else() 21 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=native") 22 | endif() 23 | 24 | if(MSAN) 25 | if(NOT CLANG) 26 | message(FATAL_ERROR "Cannot enable MSAN unless using Clang") 27 | endif() 28 | 29 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=memory -fsanitize-memory-track-origins -fno-omit-frame-pointer") 30 | endif() 31 | 32 | if(ASAN) 33 | if(NOT CLANG) 34 | message(FATAL_ERROR "Cannot enable ASAN unless using Clang") 35 | endif() 36 | 37 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address -fsanitize-address-use-after-scope -fno-omit-frame-pointer") 38 | endif() 39 | 40 | if(TSAN) 41 | if(NOT CLANG) 42 | message(FATAL_ERROR "Cannot enable TSAN unless using Clang") 43 | endif() 44 | if(S390X) 45 | message(FATAL_ERROR "Cannot enable TSAN for s390x machines") 46 | endif() 47 | 48 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=thread") 49 | endif() 50 | 51 | if(UBSAN) 52 | if(NOT CLANG) 53 | message(FATAL_ERROR "Cannot enable UBSAN unless using Clang") 54 | endif() 55 | 56 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=undefined") 57 | endif() 58 | 59 | if(DEBUG) 60 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DDEBUG") 61 | endif() 62 | 63 | if(INTEL_SDE) 64 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DINTEL_SDE") 65 | endif() 66 | -------------------------------------------------------------------------------- /cmake/sources.cmake: -------------------------------------------------------------------------------- 1 | # Copyright IBM Inc. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | set(NTT_SOURCES 5 | ${SRC_DIR}/ntt_radix4.c 6 | ${SRC_DIR}/ntt_radix4x4.c 7 | ${SRC_DIR}/ntt_reference.c 8 | ) 9 | 10 | if(S390X) 11 | set(NTT_SOURCES ${NTT_SOURCES} 12 | ${SRC_DIR}/ntt_radix4_s390x_vef.c 13 | ) 14 | endif() 15 | 16 | if(X86_64 AND AVX512_IFMA) 17 | set(NTT_SOURCES ${NTT_SOURCES} 18 | ${SRC_DIR}/ntt_radix4_avx512_ifma.c 19 | ${SRC_DIR}/ntt_r4r2_avx512_ifma.c 20 | ${SRC_DIR}/ntt_r2_16_avx512_ifma.c 21 | ${SRC_DIR}/ntt_radix4_avx512_ifma_unordered.c 22 | ) 23 | endif() 24 | 25 | set(MAIN_SOURCE 26 | ${TESTS_DIR}/main.c 27 | ${TESTS_DIR}/bench.c 28 | ${TESTS_DIR}/test_correctness.c 29 | ) 30 | -------------------------------------------------------------------------------- /cmake/test_x86_64_avx512_ifma.c: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #include 5 | #include 6 | 7 | int main(void) 8 | { 9 | __m512i reg = {0}; 10 | uint64_t mem[8] = {0}; 11 | reg = _mm512_loadu_si512((const __m512i*)mem); 12 | reg = _mm512_madd52lo_epu64(reg, reg, reg); 13 | _mm512_storeu_si512((__m512i*)mem, reg); 14 | 15 | return 0; 16 | } 17 | -------------------------------------------------------------------------------- /include/internal/avx512.h: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #pragma once 5 | 6 | #include "defs.h" 7 | 8 | EXTERNC_BEGIN 9 | 10 | #include 11 | 12 | #define ADD(a, b) _mm512_add_epi64(a, b) 13 | #define SUB(a, b) _mm512_sub_epi64(a, b) 14 | #define MIN(a, b) _mm512_min_epu64(a, b) 15 | 16 | #define SHUF(a, b, mask) _mm512_shuffle_i64x2((a), (b), (mask)) 17 | #define PERM(a, idx, b) _mm512_permutex2var_epi64((a), (idx), (b)) 18 | #define UNPACKLO(a, b) _mm512_unpacklo_epi64((a), (b)) 19 | #define UNPACKHI(a, b) _mm512_unpackhi_epi64((a), (b)) 20 | 21 | #define MADDLO(accum, op1, op2) _mm512_madd52lo_epu64((accum), (op1), (op2)) 22 | #define MADDHI(accum, op1, op2) _mm512_madd52hi_epu64((accum), (op1), (op2)) 23 | 24 | #define SET1(val) _mm512_set1_epi64(val) 25 | #define SETR(a, b, c, d, e, f, g, h) \ 26 | _mm512_setr_epi64((a), (b), (c), (d), (e), (f), (g), (h)) 27 | #define LOAD(mem) _mm512_loadu_epi64((mem)) 28 | #define LOADA(mem) _mm512_load_epi64((mem)) 29 | #define STORE(mem, reg) _mm512_storeu_epi64((mem), (reg)) 30 | 31 | #define GATHER(idx, mem, scale) _mm512_i64gather_epi64((idx), (mem), (scale)) 32 | #define SCATTER(mem, idx, reg, scale) \ 33 | _mm512_i64scatter_epi64((mem), (idx), (reg), scale) 34 | 35 | #define SET1_256(val) _mm256_set1_epi64x(val) 36 | #define BROADCAST2HALVES(v1, v2) _mm512_inserti64x4(SET1((v1)), SET1_256((v2)), 1) 37 | 38 | typedef struct mul_op_m512_s { 39 | __m512i op; 40 | __m512i con; 41 | } mul_op_m512_t; 42 | 43 | static inline __m512i reduce_if_greater(const __m512i val, const __m512i mod) 44 | { 45 | return MIN(val, SUB(val, mod)); 46 | } 47 | 48 | static inline __m512i 49 | fast_mul_mod_q2_m512(const mul_op_m512_t w, const __m512i t, const __m512i neg_q) 50 | { 51 | const __m512i zero = SET1(0); 52 | const __m512i Q = MADDHI(zero, w.con, t); 53 | const __m512i tmp = MADDLO(zero, Q, neg_q); 54 | return MADDLO(tmp, w.op, t) & AVX512_IFMA_WORD_SIZE_MASK; 55 | } 56 | 57 | static inline __m512i fast_dbl_mul_mod_q2_m512(const mul_op_m512_t w1, 58 | const mul_op_m512_t w2, 59 | const __m512i t1, 60 | const __m512i t2, 61 | const __m512i neg_q) 62 | { 63 | const __m512i zero = SET1(0); 64 | const __m512i TMP1 = MADDHI(zero, w1.con, t1); 65 | const __m512i Q = MADDHI(TMP1, w2.con, t2); 66 | 67 | const __m512i TMP2 = MADDLO(zero, w1.op, t1); 68 | const __m512i TMP3 = MADDLO(TMP2, w2.op, t2); 69 | const __m512i TMP4 = MADDLO(TMP3, Q, neg_q); 70 | return TMP4 & AVX512_IFMA_WORD_SIZE_MASK; 71 | } 72 | 73 | static inline void fwd_radix4_butterfly_m512(__m512i * X, 74 | __m512i * Y, 75 | __m512i * Z, 76 | __m512i * T, 77 | const mul_op_m512_t w[5], 78 | const uint64_t q_64) 79 | { 80 | const __m512i neg_q = SET1(-1 * q_64); 81 | const __m512i q2 = SET1(q_64 << 1); 82 | const __m512i q4 = SET1(q_64 << 2); 83 | 84 | const __m512i T1 = reduce_if_greater(*X, q4); 85 | const __m512i T2 = fast_mul_mod_q2_m512(w[0], *Z, neg_q); 86 | 87 | const __m512i Y1 = fast_dbl_mul_mod_q2_m512(w[1], w[2], *Y, *T, neg_q); 88 | const __m512i Y2 = fast_dbl_mul_mod_q2_m512(w[3], w[4], *Y, *T, neg_q); 89 | 90 | const __m512i T3 = ADD(T1, T2); 91 | const __m512i T4 = SUB(T1, T2); 92 | 93 | *X = ADD(T3, Y1); 94 | *Y = ADD(SUB(q2, Y1), T3); 95 | *Z = ADD(ADD(q2, Y2), T4); 96 | *T = ADD(SUB(q4, Y2), T4); 97 | } 98 | 99 | static inline void fwd_radix2_butterfly_m512(__m512i * X, 100 | __m512i * Y, 101 | const mul_op_m512_t *w, 102 | const uint64_t q_64) 103 | { 104 | const __m512i neg_q = SET1(-1 * q_64); 105 | const __m512i q2 = SET1(q_64 << 1); 106 | 107 | *X = reduce_if_greater(*X, q2); 108 | const __m512i T = fast_mul_mod_q2_m512(*w, *Y, neg_q); 109 | 110 | *Y = ADD(SUB(q2, T), *X); 111 | *X = ADD(*X, T); 112 | } 113 | 114 | EXTERNC_END 115 | -------------------------------------------------------------------------------- /include/internal/defs.h: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #ifdef __cplusplus 11 | # define EXTERNC extern "C" 12 | # define EXTERNC_BEGIN extern "C" { 13 | # define EXTERNC_END } 14 | #else 15 | # define EXTERNC 16 | # define EXTERNC_BEGIN 17 | # define EXTERNC_END 18 | #endif 19 | 20 | #define SUCCESS 0 21 | #define ERROR (-1) 22 | 23 | #define GUARD(func) \ 24 | { \ 25 | if(SUCCESS != (func)) { \ 26 | return ERROR; \ 27 | } \ 28 | } 29 | 30 | #define GUARD_MSG(func, msg) \ 31 | { \ 32 | if(SUCCESS != (func)) { \ 33 | printf(msg); \ 34 | return ERROR; \ 35 | } \ 36 | } 37 | 38 | #if defined(__GNUC__) || defined(__clang__) 39 | # define UNUSED __attribute__((unused)) 40 | #else 41 | # define UNUSED 42 | #endif 43 | 44 | #define WORD_SIZE 64UL 45 | #define VMSL_WORD_SIZE 56UL 46 | #define AVX512_IFMA_WORD_SIZE 52UL 47 | 48 | #if WORD_SIZE == 64 49 | # define WORD_SIZE_MASK (-1UL) 50 | #else 51 | # define WORD_SIZE_MASK ((1UL << WORD_SIZE) - 1) 52 | #endif 53 | 54 | #define HIGH_WORD(x) ((x) >> WORD_SIZE) 55 | #define LOW_WORD(x) ((x)&WORD_SIZE_MASK) 56 | 57 | #define VMSL_WORD_SIZE_MASK ((1UL << VMSL_WORD_SIZE) - 1) 58 | #define HIGH_VMSL_WORD(x) (uint64_t)((__uint128_t)(x) >> VMSL_WORD_SIZE) 59 | #define LOW_VMSL_WORD(x) ((x)&VMSL_WORD_SIZE_MASK) 60 | 61 | #define AVX512_IFMA_WORD_SIZE_MASK ((1UL << AVX512_IFMA_WORD_SIZE) - 1) 62 | #define AVX512_IFMA_MAX_MODULUS 49UL 63 | #define AVX512_IFMA_MAX_MODULUS_MASK (~((1UL << AVX512_IFMA_MAX_MODULUS) - 1)) 64 | 65 | // Check whether N=2^m where m is odd by masking it. 66 | #define ODD_POWER_MASK 0xaaaaaaaaaaaaaaaa 67 | #define REM1_POWER_MASK 0x2222222222222222 68 | #define REM2_POWER_MASK 0x4444444444444444 69 | #define REM3_POWER_MASK 0x8888888888888888 70 | 71 | #define HAS_AN_EVEN_POWER(n) (!((n)&ODD_POWER_MASK)) 72 | #define HAS_AN_REM1_POWER(n) ((n)&REM1_POWER_MASK) 73 | #define HAS_AN_REM2_POWER(n) ((n)&REM2_POWER_MASK) 74 | #define HAS_AN_REM3_POWER(n) ((n)&REM3_POWER_MASK) 75 | 76 | #if defined(__GNUC__) && (__GNUC__ >= 8) 77 | # define GCC_SUPPORT_UNROLL_PRAGMA 78 | #endif 79 | 80 | #ifdef GCC_SUPPORT_UNROLL_PRAGMA 81 | # define LOOP_UNROLL_2 _Pragma("GCC unroll 2") 82 | # define LOOP_UNROLL_4 _Pragma("GCC unroll 4") 83 | # define LOOP_UNROLL_8 _Pragma("GCC unroll 8") 84 | #elif defined(__clang__) 85 | # define LOOP_UNROLL_2 _Pragma("clang loop unroll_count(2)") 86 | # define LOOP_UNROLL_4 _Pragma("clang loop unroll_count(4)") 87 | # define LOOP_UNROLL_8 _Pragma("clang loop unroll_count(8)") 88 | #else 89 | # define LOOP_UNROLL_2 90 | # define LOOP_UNROLL_4 91 | # define LOOP_UNROLL_8 92 | #endif 93 | 94 | #define ALIGN(n) __attribute__((aligned(n))) 95 | -------------------------------------------------------------------------------- /include/internal/fast_mul_operators.h: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #pragma once 5 | 6 | #include "defs.h" 7 | 8 | EXTERNC_BEGIN 9 | 10 | typedef struct mul_op_s { 11 | __uint128_t op; 12 | __uint128_t con; 13 | } mul_op_t; 14 | 15 | static inline uint64_t reduce_2q_to_q(const uint64_t val, const uint64_t q) 16 | { 17 | return (val < q) ? val : val - q; 18 | } 19 | 20 | static inline uint64_t reduce_4q_to_2q(const uint64_t val, const uint64_t q) 21 | { 22 | return (val < 2 * q) ? val : val - 2 * q; 23 | } 24 | 25 | static inline uint64_t reduce_4q_to_q(const uint64_t val, const uint64_t q) 26 | { 27 | return reduce_2q_to_q(reduce_4q_to_2q(val, q), q); 28 | } 29 | 30 | static inline uint64_t reduce_8q_to_4q(const uint64_t val, const uint64_t q) 31 | { 32 | return (val < 4 * q) ? val : val - 4 * q; 33 | } 34 | 35 | static inline uint64_t reduce_8q_to_2q(const uint64_t val, const uint64_t q) 36 | { 37 | return reduce_4q_to_2q(reduce_8q_to_4q(val, q), q); 38 | } 39 | 40 | static inline uint64_t reduce_8q_to_q(const uint64_t val, const uint64_t q) 41 | { 42 | return reduce_2q_to_q(reduce_8q_to_2q(val, q), q); 43 | } 44 | 45 | #ifndef L_HIGH_WORD 46 | # define L_HIGH_WORD HIGH_WORD 47 | #endif 48 | 49 | static inline uint64_t 50 | fast_mul_mod_q2(const mul_op_t w, const uint64_t t, const uint64_t q) 51 | { 52 | const uint64_t Q = L_HIGH_WORD(w.con * t); 53 | return w.op * t - Q * q; 54 | } 55 | 56 | static inline uint64_t 57 | fast_mul_mod_q(const mul_op_t w, const uint64_t t, const uint64_t q) 58 | { 59 | return reduce_2q_to_q(fast_mul_mod_q2(w, t, q), q); 60 | } 61 | 62 | static inline uint64_t fast_dbl_mul_mod_q2(const mul_op_t w1, 63 | const mul_op_t w2, 64 | const uint64_t t1, 65 | const uint64_t t2, 66 | const uint64_t q) 67 | { 68 | const uint64_t Q = L_HIGH_WORD(w1.con * t1 + w2.con * t2); 69 | return (t1 * w1.op) + (t2 * w2.op) - (Q * q); 70 | } 71 | 72 | static inline void 73 | harvey_fwd_butterfly(uint64_t *X, uint64_t *Y, const mul_op_t w, const uint64_t q) 74 | { 75 | const uint64_t q2 = q << 1; 76 | const uint64_t X1 = reduce_4q_to_2q(*X, q); 77 | const uint64_t T = fast_mul_mod_q2(w, *Y, q); 78 | 79 | *X = X1 + T; 80 | *Y = X1 - T + q2; 81 | } 82 | 83 | static inline void 84 | harvey_bkw_butterfly(uint64_t *X, uint64_t *Y, const mul_op_t w, const uint64_t q) 85 | { 86 | const uint64_t q2 = q << 1; 87 | const uint64_t X1 = reduce_4q_to_2q(*X + *Y, q); 88 | const uint64_t T = *X - *Y + q2; 89 | 90 | *X = X1; 91 | *Y = fast_mul_mod_q2(w, T, q); 92 | } 93 | 94 | static inline void harvey_bkw_butterfly_final(uint64_t * X, 95 | uint64_t * Y, 96 | const mul_op_t w, 97 | const mul_op_t n_inv, 98 | const uint64_t q) 99 | { 100 | const uint64_t q2 = q << 1; 101 | const uint64_t X1 = *X + *Y; 102 | const uint64_t T = *X - *Y + q2; 103 | 104 | *X = fast_mul_mod_q(n_inv, X1, q); 105 | *Y = fast_mul_mod_q(w, T, q); 106 | } 107 | 108 | static inline void radix4_fwd_butterfly(uint64_t * X, 109 | uint64_t * Y, 110 | uint64_t * Z, 111 | uint64_t * T, 112 | const mul_op_t w[5], 113 | const uint64_t q) 114 | { 115 | const uint64_t q2 = q << 1; 116 | const uint64_t q4 = q << 2; 117 | 118 | const uint64_t Y1 = fast_dbl_mul_mod_q2(w[1], w[2], *Y, *T, q); 119 | const uint64_t Y2 = fast_dbl_mul_mod_q2(w[3], w[4], *Y, *T, q); 120 | 121 | const uint64_t T1 = reduce_8q_to_4q(*X, q); 122 | const uint64_t T2 = fast_mul_mod_q2(w[0], *Z, q); 123 | 124 | *X = (T1 + T2 + Y1); 125 | *Y = (T1 + T2 - Y1) + q2; 126 | *Z = (T1 - T2 + Y2) + q2; 127 | *T = (T1 - T2 - Y2) + q4; 128 | } 129 | 130 | static inline void radix4_inv_butterfly(uint64_t * X, 131 | uint64_t * Y, 132 | uint64_t * Z, 133 | uint64_t * T, 134 | const mul_op_t w[5], 135 | const uint64_t q) 136 | { 137 | const uint64_t q4 = q << 2; 138 | 139 | const uint64_t T0 = *Z + *T; 140 | const uint64_t T1 = *X + *Y; 141 | 142 | const uint64_t T2 = q4 + *X - *Y; 143 | const uint64_t T3 = q4 + *Z - *T; 144 | 145 | *X = reduce_8q_to_2q(T1 + T0, q); 146 | *Z = fast_mul_mod_q(w[0], q4 + T1 - T0, q); 147 | *Y = fast_dbl_mul_mod_q2(w[1], w[3], T2, T3, q); 148 | *T = fast_dbl_mul_mod_q2(w[2], w[4], T2, T3, q); 149 | } 150 | 151 | EXTERNC_END 152 | -------------------------------------------------------------------------------- /include/internal/pre_compute.h: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #pragma once 5 | 6 | #include 7 | 8 | #include "defs.h" 9 | 10 | EXTERNC_BEGIN 11 | 12 | // We dont care about the performance of the functions 13 | // in this file, as all of their output serves as 14 | // precomputaitons that we can cache for the NTT computations. 15 | 16 | static inline uint64_t bit_rev_idx(uint64_t idx, uint64_t width) 17 | { 18 | uint64_t ret = 0; 19 | while(width > 0) { 20 | width--; 21 | ret += ((idx & 1) << width); 22 | idx >>= 1; 23 | } 24 | 25 | return ret; 26 | } 27 | 28 | static inline void bit_rev(uint64_t w_powers[], 29 | const uint64_t w[], 30 | const uint64_t N, 31 | const uint64_t width) 32 | { 33 | for(size_t i = 0; i < N; i++) { 34 | w_powers[bit_rev_idx(i, width)] = w[i]; 35 | } 36 | } 37 | 38 | static inline void calc_w(uint64_t w_powers_rev[], 39 | const uint64_t w, 40 | const uint64_t N, 41 | const uint64_t q, 42 | const uint64_t width) 43 | { 44 | uint64_t w_powers[N]; 45 | w_powers[0] = 1; 46 | for(size_t i = 1; i < N; i++) { 47 | w_powers[i] = (uint64_t)(((__uint128_t)w_powers[i - 1] * w) % q); 48 | } 49 | 50 | bit_rev(w_powers_rev, w_powers, N, width); 51 | } 52 | 53 | static inline void calc_w_inv(uint64_t w_inv_rev[], 54 | const uint64_t w_inv, 55 | const uint64_t N, 56 | const uint64_t q, 57 | const uint64_t width) 58 | { 59 | uint64_t w_inv_powers[N]; 60 | w_inv_powers[0] = 1; 61 | for(size_t i = 1; i < N; i++) { 62 | w_inv_powers[i] = (uint64_t)(((__uint128_t)w_inv_powers[i - 1] * w_inv) % q); 63 | } 64 | 65 | bit_rev(w_inv_rev, w_inv_powers, N, width); 66 | } 67 | 68 | static inline void calc_w_con(uint64_t w_con[], 69 | const uint64_t w[], 70 | const uint64_t N, 71 | const uint64_t q, 72 | const uint64_t word_size) 73 | { 74 | for(size_t i = 0; i < N; i++) { 75 | w_con[i] = ((__uint128_t)w[i] << word_size) / q; 76 | } 77 | } 78 | 79 | static uint64_t 80 | calc_ninv_con(const uint64_t Ninv, const uint64_t q, const uint64_t word_size) 81 | { 82 | return ((__uint128_t)Ninv << word_size) / q; 83 | } 84 | 85 | static inline void expand_w(uint64_t w_expanded[], 86 | const uint64_t w[], 87 | const uint64_t N, 88 | const uint64_t q) 89 | { 90 | w_expanded[0] = w[0]; 91 | w_expanded[1] = 0; 92 | w_expanded[2] = w[1]; 93 | w_expanded[3] = 0; 94 | for(size_t i = 4; i < 2 * N; i += 2) { 95 | w_expanded[i] = w[i / 2]; 96 | 97 | if(i % 4 == 0) { 98 | const __uint128_t t = w_expanded[i / 2]; 99 | w_expanded[i + 1] = (t * w[i / 2]) % q; 100 | } else { 101 | const __uint128_t t = w_expanded[(i - 2) / 2]; 102 | w_expanded[i + 1] = q - ((t * w[i / 2]) % q); 103 | } 104 | } 105 | } 106 | 107 | #ifdef AVX512_IFMA_SUPPORT 108 | 109 | static inline void 110 | expand_w_hexl(uint64_t w_expanded[], const uint64_t w[], const uint64_t N) 111 | { 112 | size_t idx = 0; 113 | 114 | memcpy(&w_expanded[idx], w, N / 8 * sizeof(uint64_t)); 115 | idx += N / 8; 116 | 117 | // Duplicate four times for FwdT4 118 | for(size_t i = 0; i < (N / 8); i++) { 119 | w_expanded[idx] = w[(N / 8) + i]; 120 | w_expanded[idx + 1] = w[(N / 8) + i]; 121 | w_expanded[idx + 2] = w[(N / 8) + i]; 122 | w_expanded[idx + 3] = w[(N / 8) + i]; 123 | idx += 4; 124 | } 125 | 126 | // Duplicate four times for FwdT2 127 | for(size_t i = 0; i < (N / 4); i++) { 128 | w_expanded[idx] = w[(N / 4) + i]; 129 | w_expanded[idx + 1] = w[(N / 4) + i]; 130 | idx += 2; 131 | } 132 | 133 | memcpy(&w_expanded[idx], &w[N / 2], N / 2 * sizeof(uint64_t)); 134 | idx += N / 2; 135 | 136 | memset(&w_expanded[idx], 0, (2 * N - idx) * sizeof(uint64_t)); 137 | } 138 | 139 | static inline void permute_w(uint64_t in_out[8]) 140 | { 141 | uint64_t t[8]; 142 | memcpy(t, in_out, 8 * sizeof(uint64_t)); 143 | 144 | in_out[0] = t[0]; 145 | in_out[1] = t[4]; 146 | in_out[2] = t[1]; 147 | in_out[3] = t[5]; 148 | in_out[4] = t[2]; 149 | in_out[5] = t[6]; 150 | in_out[6] = t[3]; 151 | in_out[7] = t[7]; 152 | } 153 | 154 | static inline void expand_w_r4_avx512_ifma(uint64_t w_expanded[], 155 | const uint64_t w[], 156 | const uint64_t N, 157 | const uint64_t q, 158 | const uint64_t unordered) 159 | { 160 | size_t w_idx = 1; 161 | size_t new_w_idx = 1; 162 | 163 | w_expanded[0] = 0; 164 | 165 | // FWD8 166 | if(HAS_AN_EVEN_POWER(N)) { 167 | for(size_t m = 1; w_idx < (N >> 5); m <<= 2) { 168 | for(size_t i = 0; i < m; i++, w_idx++) { 169 | const __uint128_t w1 = w[w_idx]; 170 | const __uint128_t w2 = w[2 * w_idx]; 171 | const __uint128_t w3 = w[2 * w_idx + 1]; 172 | w_expanded[new_w_idx++] = w1; 173 | w_expanded[new_w_idx++] = w2; 174 | w_expanded[new_w_idx++] = (w1 * w2) % q; 175 | w_expanded[new_w_idx++] = w3; 176 | w_expanded[new_w_idx++] = q - ((w1 * w3) % q); 177 | } 178 | w_idx = 4 * m; 179 | } 180 | } else { 181 | // First radix-2 iteration 182 | w_expanded[new_w_idx++] = w[w_idx++]; 183 | 184 | for(size_t m = 2; w_idx < (N >> 5); m <<= 2) { 185 | for(size_t i = 0; i < m; i++, w_idx++) { 186 | const __uint128_t w1 = w[w_idx]; 187 | const __uint128_t w2 = w[2 * w_idx]; 188 | const __uint128_t w3 = w[2 * w_idx + 1]; 189 | w_expanded[new_w_idx++] = w1; 190 | w_expanded[new_w_idx++] = w2; 191 | w_expanded[new_w_idx++] = (w1 * w2) % q; 192 | w_expanded[new_w_idx++] = w3; 193 | w_expanded[new_w_idx++] = q - ((w1 * w3) % q); 194 | } 195 | w_idx = 4 * m; 196 | } 197 | } 198 | 199 | // FWD4 200 | for(w_idx = (N >> 4); w_idx < (N >> 3); w_idx += 2) { 201 | const uint64_t k = 2 * w_idx; 202 | w_expanded[new_w_idx++] = w[w_idx]; 203 | w_expanded[new_w_idx++] = w[w_idx + 1]; 204 | w_expanded[new_w_idx++] = w[k]; 205 | w_expanded[new_w_idx++] = w[k + 2]; 206 | w_expanded[new_w_idx++] = (w[w_idx] * w[k]) % q; 207 | w_expanded[new_w_idx++] = (w[w_idx + 1] * w[k + 2]) % q; 208 | w_expanded[new_w_idx++] = w[k + 1]; 209 | w_expanded[new_w_idx++] = w[k + 2 + 1]; 210 | w_expanded[new_w_idx++] = q - ((w[w_idx] * w[k + 1]) % q); 211 | w_expanded[new_w_idx++] = q - ((w[w_idx + 1] * w[k + 3]) % q); 212 | } 213 | 214 | // Align on an 8-qw boundary 215 | new_w_idx = ((new_w_idx >> 3) << 3) + 8; 216 | 217 | // FWD1 218 | for(w_idx = (N >> 2); w_idx < (N >> 1); w_idx += 8) { 219 | // W1 220 | for(size_t i = 0; i < 8; i++) { 221 | w_expanded[new_w_idx++] = w[w_idx + i]; 222 | } 223 | // W2 224 | for(size_t i = 0; i < 8; i++) { 225 | w_expanded[new_w_idx++] = w[2 * (w_idx + i)]; 226 | } 227 | // W3 228 | for(size_t i = 0; i < 8; i++) { 229 | w_expanded[new_w_idx++] = (w[w_idx + i] * w[2 * (w_idx + i)]) % q; 230 | } 231 | // W4 232 | for(size_t i = 0; i < 8; i++) { 233 | w_expanded[new_w_idx++] = w[2 * (w_idx + i) + 1]; 234 | } 235 | // W5 236 | for(size_t i = 0; i < 8; i++) { 237 | w_expanded[new_w_idx++] = q - ((w[w_idx + i] * w[2 * (w_idx + i) + 1]) % q); 238 | } 239 | 240 | // Need to permute values 241 | if(unordered) { 242 | permute_w(&w_expanded[new_w_idx - 8 * 5]); 243 | permute_w(&w_expanded[new_w_idx - 8 * 4]); 244 | permute_w(&w_expanded[new_w_idx - 8 * 3]); 245 | permute_w(&w_expanded[new_w_idx - 8 * 2]); 246 | permute_w(&w_expanded[new_w_idx - 8 * 1]); 247 | } 248 | } 249 | 250 | memset(&w_expanded[new_w_idx], 0, ((5 * N) - new_w_idx) * sizeof(uint64_t)); 251 | } 252 | 253 | static inline void expand_w_r4r2_avx512_ifma(uint64_t w_expanded[], 254 | const uint64_t w[], 255 | const uint64_t N, 256 | const uint64_t q) 257 | { 258 | size_t w_idx = 1; 259 | size_t new_w_idx = 1; 260 | size_t t = N >> 4; 261 | 262 | w_expanded[0] = 0; 263 | 264 | // FWD8 in radix4 265 | for(size_t m = 1; w_idx < t; m <<= 2) { 266 | for(size_t i = 0; i < m; i++, w_idx++) { 267 | const __uint128_t w1 = w[w_idx]; 268 | const __uint128_t w2 = w[2 * w_idx]; 269 | const __uint128_t w3 = w[2 * w_idx + 1]; 270 | w_expanded[new_w_idx++] = w1; 271 | w_expanded[new_w_idx++] = w2; 272 | w_expanded[new_w_idx++] = (w1 * w2) % q; 273 | w_expanded[new_w_idx++] = w3; 274 | w_expanded[new_w_idx++] = q - ((w1 * w3) % q); 275 | } 276 | w_idx = 4 * m; 277 | } 278 | 279 | // Align on an 8-qw boundary 280 | new_w_idx = ((new_w_idx >> 3) << 3) + 8; 281 | 282 | if(HAS_AN_EVEN_POWER(N)) { 283 | // FWD8 in radix2 284 | memcpy(&w_expanded[new_w_idx], &w[w_idx], t * sizeof(uint64_t)); 285 | new_w_idx += t; 286 | } 287 | 288 | t <<= 1; 289 | 290 | // Duplicate four times for FwdT4 291 | for(size_t i = 0; i < t; i++) { 292 | w_expanded[new_w_idx++] = w[t + i]; 293 | w_expanded[new_w_idx++] = w[t + i]; 294 | w_expanded[new_w_idx++] = w[t + i]; 295 | w_expanded[new_w_idx++] = w[t + i]; 296 | } 297 | t <<= 1; 298 | 299 | // Duplicate four times for FwdT2 300 | for(size_t i = 0; i < t; i += 4) { 301 | w_expanded[new_w_idx++] = w[t + i + 0]; 302 | w_expanded[new_w_idx++] = w[t + i + 0]; 303 | w_expanded[new_w_idx++] = w[t + i + 2]; 304 | w_expanded[new_w_idx++] = w[t + i + 2]; 305 | w_expanded[new_w_idx++] = w[t + i + 1]; 306 | w_expanded[new_w_idx++] = w[t + i + 1]; 307 | w_expanded[new_w_idx++] = w[t + i + 3]; 308 | w_expanded[new_w_idx++] = w[t + i + 3]; 309 | } 310 | t <<= 1; 311 | 312 | for(size_t i = 0; i < t; i += 8) { 313 | w_expanded[new_w_idx++] = w[t + i + 0]; 314 | w_expanded[new_w_idx++] = w[t + i + 4]; 315 | w_expanded[new_w_idx++] = w[t + i + 1]; 316 | w_expanded[new_w_idx++] = w[t + i + 5]; 317 | w_expanded[new_w_idx++] = w[t + i + 2]; 318 | w_expanded[new_w_idx++] = w[t + i + 6]; 319 | w_expanded[new_w_idx++] = w[t + i + 3]; 320 | w_expanded[new_w_idx++] = w[t + i + 7]; 321 | } 322 | 323 | memset(&w_expanded[new_w_idx], 0, ((5 * N) - new_w_idx) * sizeof(uint64_t)); 324 | } 325 | 326 | static inline void expand_w_r2_16_avx512_ifma(uint64_t w_expanded[], 327 | const uint64_t w[], 328 | const uint64_t N) 329 | { 330 | size_t t = N >> 3; 331 | size_t new_w_idx = t; 332 | 333 | memcpy(w_expanded, w, t * sizeof(uint64_t)); 334 | 335 | // Duplicate four times for FwdT4 336 | for(size_t i = 0; i < t; i++) { 337 | w_expanded[new_w_idx++] = w[t + i]; 338 | w_expanded[new_w_idx++] = w[t + i]; 339 | w_expanded[new_w_idx++] = w[t + i]; 340 | w_expanded[new_w_idx++] = w[t + i]; 341 | } 342 | t <<= 1; 343 | 344 | // Duplicate four times for FwdT2 345 | for(size_t i = 0; i < t; i += 4) { 346 | w_expanded[new_w_idx++] = w[t + i + 0]; 347 | w_expanded[new_w_idx++] = w[t + i + 0]; 348 | w_expanded[new_w_idx++] = w[t + i + 2]; 349 | w_expanded[new_w_idx++] = w[t + i + 2]; 350 | w_expanded[new_w_idx++] = w[t + i + 1]; 351 | w_expanded[new_w_idx++] = w[t + i + 1]; 352 | w_expanded[new_w_idx++] = w[t + i + 3]; 353 | w_expanded[new_w_idx++] = w[t + i + 3]; 354 | } 355 | t <<= 1; 356 | 357 | for(size_t i = 0; i < t; i += 8) { 358 | w_expanded[new_w_idx++] = w[t + i + 0]; 359 | w_expanded[new_w_idx++] = w[t + i + 4]; 360 | w_expanded[new_w_idx++] = w[t + i + 1]; 361 | w_expanded[new_w_idx++] = w[t + i + 5]; 362 | w_expanded[new_w_idx++] = w[t + i + 2]; 363 | w_expanded[new_w_idx++] = w[t + i + 6]; 364 | w_expanded[new_w_idx++] = w[t + i + 3]; 365 | w_expanded[new_w_idx++] = w[t + i + 7]; 366 | } 367 | } 368 | 369 | #endif 370 | 371 | EXTERNC_END 372 | -------------------------------------------------------------------------------- /include/ntt_avx512_ifma.h: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #pragma once 5 | 6 | #include "defs.h" 7 | 8 | EXTERNC_BEGIN 9 | 10 | #ifdef AVX512_IFMA_SUPPORT 11 | 12 | # include "avx512.h" 13 | 14 | void fwd_ntt_radix4_avx512_ifma_lazy(uint64_t a[], 15 | uint64_t N, 16 | uint64_t q, 17 | const uint64_t w[], 18 | const uint64_t w_con[]); 19 | 20 | // Assumption N % 2^6 = 0 21 | static inline void 22 | final_reduce_q8(uint64_t a[], const uint64_t N, const uint64_t q_64) 23 | { 24 | const __m512i q = SET1(q_64); 25 | const __m512i q2 = SET1(q_64 << 1); 26 | const __m512i q4 = SET1(q_64 << 2); 27 | 28 | // Final reduction 29 | for(size_t i = 0; i < N; i += 8 * 8) { 30 | LOOP_UNROLL_8 31 | for(size_t j = 0; j < 8; j++) { 32 | __m512i T = reduce_if_greater(LOAD(&a[i + 8 * j]), q4); 33 | T = reduce_if_greater(T, q2); 34 | T = reduce_if_greater(T, q); 35 | STORE(&a[i + 8 * j], T); 36 | } 37 | } 38 | } 39 | 40 | // Assumption N % 2^6 = 0 41 | static inline void 42 | final_reduce_q4(uint64_t a[], const uint64_t N, const uint64_t q_64) 43 | { 44 | const __m512i q = SET1(q_64); 45 | const __m512i q2 = SET1(q_64 << 1); 46 | // Final reduction 47 | for(size_t i = 0; i < N; i += 8 * 8) { 48 | LOOP_UNROLL_8 49 | for(size_t j = 0; j < 8; j++) { 50 | __m512i T = reduce_if_greater(LOAD(&a[i + 8 * j]), q2); 51 | STORE(&a[i + 8 * j], reduce_if_greater(T, q)); 52 | } 53 | } 54 | } 55 | 56 | static inline void fwd_ntt_radix4_avx512_ifma(uint64_t a[], 57 | const uint64_t N, 58 | const uint64_t q, 59 | const uint64_t w[], 60 | const uint64_t w_con[]) 61 | { 62 | fwd_ntt_radix4_avx512_ifma_lazy(a, N, q, w, w_con); 63 | final_reduce_q8(a, N, q); 64 | } 65 | 66 | void fwd_ntt_r4r2_avx512_ifma_lazy(uint64_t a[], 67 | uint64_t N, 68 | uint64_t q, 69 | const uint64_t w[], 70 | const uint64_t w_con[]); 71 | 72 | static inline void fwd_ntt_r4r2_avx512_ifma(uint64_t a[], 73 | const uint64_t N, 74 | const uint64_t q, 75 | const uint64_t w[], 76 | const uint64_t w_con[]) 77 | { 78 | fwd_ntt_r4r2_avx512_ifma_lazy(a, N, q, w, w_con); 79 | final_reduce_q4(a, N, q); 80 | } 81 | 82 | void fwd_ntt_radix4_avx512_ifma_lazy_unordered(uint64_t a[], 83 | uint64_t N, 84 | uint64_t q, 85 | const uint64_t w[], 86 | const uint64_t w_con[]); 87 | 88 | static inline void fwd_ntt_radix4_avx512_ifma_unordered(uint64_t a[], 89 | const uint64_t N, 90 | const uint64_t q, 91 | const uint64_t w[], 92 | const uint64_t w_con[]) 93 | { 94 | fwd_ntt_radix4_avx512_ifma_lazy_unordered(a, N, q, w, w_con); 95 | final_reduce_q8(a, N, q); 96 | } 97 | 98 | void fwd_ntt_r2_16_avx512_ifma_lazy(uint64_t a[], 99 | uint64_t N, 100 | uint64_t q, 101 | const uint64_t w[], 102 | const uint64_t w_con[]); 103 | 104 | static inline void fwd_ntt_r2_16_avx512_ifma(uint64_t a[], 105 | const uint64_t N, 106 | const uint64_t q, 107 | const uint64_t w[], 108 | const uint64_t w_con[]) 109 | { 110 | fwd_ntt_r2_16_avx512_ifma_lazy(a, N, q, w, w_con); 111 | final_reduce_q4(a, N, q); 112 | } 113 | 114 | #endif 115 | 116 | EXTERNC_END 117 | -------------------------------------------------------------------------------- /include/ntt_hexl.h: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #pragma once 5 | 6 | #include "defs.h" 7 | 8 | EXTERNC_BEGIN 9 | 10 | #ifdef AVX512_IFMA_SUPPORT 11 | 12 | // Internal function of Intel HEXL under the license of Intel HEXL 13 | void ForwardTransformToBitReverseAVX512( 14 | uint64_t * operand, 15 | uint64_t degree, 16 | uint64_t mod, 17 | const uint64_t *root_of_unity_powers, 18 | const uint64_t *precon_root_of_unity_powers, 19 | uint64_t input_mod_factor, 20 | uint64_t output_mod_factor, 21 | uint64_t recursion_depth, 22 | uint64_t recursion_half); 23 | 24 | static inline void fwd_ntt_radix2_hexl_lazy(uint64_t a[], 25 | const uint64_t N, 26 | const uint64_t q, 27 | const uint64_t w[], 28 | const uint64_t w_con[]) 29 | { 30 | ForwardTransformToBitReverseAVX512(a, N, q, w, w_con, 2, 2, 0, 0); 31 | } 32 | 33 | static inline void fwd_ntt_radix2_hexl(uint64_t a[], 34 | const uint64_t N, 35 | const uint64_t q, 36 | const uint64_t w[], 37 | const uint64_t w_con[]) 38 | { 39 | ForwardTransformToBitReverseAVX512(a, N, q, w, w_con, 2, 1, 0, 0); 40 | } 41 | 42 | #endif 43 | 44 | EXTERNC_END 45 | -------------------------------------------------------------------------------- /include/ntt_radix4.h: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #pragma once 5 | 6 | #include "fast_mul_operators.h" 7 | 8 | EXTERNC_BEGIN 9 | 10 | void fwd_ntt_radix4_lazy(uint64_t a[], 11 | uint64_t N, 12 | uint64_t q, 13 | const uint64_t w[], 14 | const uint64_t w_con[]); 15 | 16 | static inline void fwd_ntt_radix4(uint64_t a[], 17 | const uint64_t N, 18 | const uint64_t q, 19 | const uint64_t w[], 20 | const uint64_t w_con[]) 21 | { 22 | fwd_ntt_radix4_lazy(a, N, q, w, w_con); 23 | 24 | // Final reduction 25 | for(size_t i = 0; i < N; i++) { 26 | a[i] = reduce_8q_to_q(a[i], q); 27 | } 28 | } 29 | 30 | void inv_ntt_radix4(uint64_t a[], 31 | uint64_t N, 32 | uint64_t q, 33 | mul_op_t n_inv, 34 | const uint64_t w[], 35 | const uint64_t w_con[]); 36 | 37 | EXTERNC_END 38 | -------------------------------------------------------------------------------- /include/ntt_radix4_s390x_vef.h: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #pragma once 5 | 6 | #include "defs.h" 7 | 8 | EXTERNC_BEGIN 9 | 10 | /****************************** 11 | Single input 12 | ******************************/ 13 | void fwd_ntt_radix4_intrinsic_lazy(uint64_t a[], 14 | uint64_t N, 15 | uint64_t q, 16 | const uint64_t w[], 17 | const uint64_t w_con[]); 18 | 19 | static inline void fwd_ntt_radix4_intrinsic(uint64_t a[], 20 | uint64_t N, 21 | uint64_t q, 22 | const uint64_t w[], 23 | const uint64_t w_con[]) 24 | { 25 | fwd_ntt_radix4_intrinsic_lazy(a, N, q, w, w_con); 26 | 27 | // Final reduction 28 | for(size_t i = 0; i < N; i++) { 29 | a[i] = reduce_8q_to_q(a[i], q); 30 | } 31 | } 32 | 33 | void inv_ntt_radix4_intrinsic(uint64_t a[], 34 | uint64_t N, 35 | uint64_t q, 36 | mul_op_t n_inv, 37 | const uint64_t w[], 38 | const uint64_t w_con[]); 39 | 40 | /****************************** 41 | Double input 42 | ******************************/ 43 | void fwd_ntt_radix4_intrinsic_lazy_dbl(uint64_t a1[], 44 | uint64_t a2[], 45 | uint64_t N, 46 | uint64_t q, 47 | const uint64_t w[], 48 | const uint64_t w_con[]); 49 | 50 | static inline void fwd_ntt_radix4_intrinsic_dbl(uint64_t a1[], 51 | uint64_t a2[], 52 | uint64_t N, 53 | uint64_t q, 54 | const uint64_t w[], 55 | const uint64_t w_con[]) 56 | { 57 | fwd_ntt_radix4_intrinsic_lazy_dbl(a1, a2, N, q, w, w_con); 58 | 59 | // Final reduction 60 | for(size_t i = 0; i < N; i++) { 61 | a1[i] = reduce_8q_to_q(a1[i], q); 62 | a2[i] = reduce_8q_to_q(a2[i], q); 63 | } 64 | } 65 | 66 | EXTERNC_END 67 | -------------------------------------------------------------------------------- /include/ntt_radix4x4.h: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #pragma once 5 | 6 | #include "fast_mul_operators.h" 7 | 8 | EXTERNC_BEGIN 9 | 10 | void fwd_ntt_radix4x4_lazy(uint64_t a[], 11 | uint64_t N, 12 | uint64_t q, 13 | const uint64_t w[], 14 | const uint64_t w_con[]); 15 | 16 | static inline void fwd_ntt_radix4x4(uint64_t a[], 17 | const uint64_t N, 18 | const uint64_t q, 19 | const uint64_t w[], 20 | const uint64_t w_con[]) 21 | { 22 | fwd_ntt_radix4x4_lazy(a, N, q, w, w_con); 23 | 24 | // Final reduction 25 | for(size_t i = 0; i < N; i++) { 26 | a[i] = reduce_8q_to_q(a[i], q); 27 | } 28 | } 29 | 30 | EXTERNC_END 31 | -------------------------------------------------------------------------------- /include/ntt_reference.h: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #pragma once 5 | 6 | #include "fast_mul_operators.h" 7 | 8 | EXTERNC_BEGIN 9 | 10 | /****************************** 11 | Single input 12 | ******************************/ 13 | void fwd_ntt_ref_harvey_lazy(uint64_t a[], 14 | uint64_t N, 15 | uint64_t q, 16 | const uint64_t w[], 17 | const uint64_t w_con[]); 18 | 19 | static inline void fwd_ntt_ref_harvey(uint64_t a[], 20 | const uint64_t N, 21 | const uint64_t q, 22 | const uint64_t w[], 23 | const uint64_t w_con[]) 24 | { 25 | fwd_ntt_ref_harvey_lazy(a, N, q, w, w_con); 26 | 27 | // Final reduction 28 | for(size_t i = 0; i < N; i++) { 29 | a[i] = reduce_4q_to_q(a[i], q); 30 | } 31 | } 32 | 33 | void inv_ntt_ref_harvey(uint64_t a[], 34 | uint64_t N, 35 | uint64_t q, 36 | mul_op_t n_inv, 37 | uint64_t word_size, 38 | const uint64_t w[], 39 | const uint64_t w_con[]); 40 | 41 | /****************************** 42 | Double input 43 | ******************************/ 44 | void fwd_ntt_ref_harvey_lazy_dbl(uint64_t a1[], 45 | uint64_t a2[], 46 | uint64_t N, 47 | uint64_t q, 48 | const uint64_t w[], 49 | const uint64_t w_con[]); 50 | 51 | static inline void fwd_ntt_ref_harvey_dbl(uint64_t a1[], 52 | uint64_t a2[], 53 | const uint64_t N, 54 | const uint64_t q, 55 | const uint64_t w[], 56 | const uint64_t w_con[]) 57 | { 58 | fwd_ntt_ref_harvey_lazy_dbl(a1, a2, N, q, w, w_con); 59 | 60 | // Final reduction 61 | for(size_t i = 0; i < N; i++) { 62 | a1[i] = reduce_4q_to_q(a1[i], q); 63 | a2[i] = reduce_4q_to_q(a2[i], q); 64 | } 65 | } 66 | 67 | EXTERNC_END 68 | -------------------------------------------------------------------------------- /include/ntt_seal.h: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #pragma once 5 | 6 | #include "fast_mul_operators.h" 7 | 8 | EXTERNC_BEGIN 9 | 10 | void fwd_ntt_seal_lazy(uint64_t a[], 11 | uint64_t N, 12 | uint64_t q, 13 | const uint64_t w[], 14 | const uint64_t w_con[]); 15 | 16 | static inline void fwd_ntt_seal(uint64_t a[], 17 | const uint64_t N, 18 | const uint64_t q, 19 | const uint64_t w[], 20 | const uint64_t w_con[]) 21 | { 22 | fwd_ntt_seal_lazy(a, N, q, w, w_con); 23 | 24 | // Final reduction 25 | for(size_t i = 0; i < N; i++) { 26 | a[i] = reduce_4q_to_q(a[i], q); 27 | } 28 | } 29 | 30 | void inv_ntt_seal(uint64_t a[], 31 | uint64_t N, 32 | uint64_t q, 33 | uint64_t n_inv, 34 | uint64_t n_inv_con, 35 | const uint64_t w[], 36 | const uint64_t w_con[]); 37 | 38 | EXTERNC_END 39 | -------------------------------------------------------------------------------- /src/ntt_r2_16_avx512_ifma.c: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #ifdef AVX512_IFMA_SUPPORT 5 | 6 | # include "ntt_avx512_ifma.h" 7 | 8 | static inline void fwd16_r2(uint64_t * a, 9 | const uint64_t m, 10 | const uint64_t *w, 11 | const uint64_t *w_con, 12 | const uint64_t q_64) 13 | { 14 | for(size_t j = m; j > 0; --j) { 15 | size_t i = j - 1; 16 | const uint64_t w_idx2 = (8 * i) + m; 17 | const uint64_t w_idx3 = w_idx2 + (8 * m); 18 | const uint64_t w_idx4 = w_idx2 + (16 * m); 19 | const mul_op_m512_t w1 = {SET1(w[i]), SET1(w_con[i])}; 20 | const mul_op_m512_t w2 = {LOADA(&w[w_idx2]), LOADA(&w_con[w_idx2])}; 21 | const mul_op_m512_t w3 = {LOADA(&w[w_idx3]), LOADA(&w_con[w_idx3])}; 22 | const mul_op_m512_t w4 = {LOADA(&w[w_idx4]), LOADA(&w_con[w_idx4])}; 23 | 24 | __m512i X = LOAD(&a[16 * i]); 25 | __m512i Y = LOAD(&a[16 * i + 8]); 26 | __m512i T; 27 | 28 | fwd_radix2_butterfly_m512(&X, &Y, &w1, q_64); 29 | 30 | T = SHUF(X, Y, 0x44); // (0, 1 ,2, 3, 80, 81, 82, 83) 31 | Y = SHUF(X, Y, 0xee); // (4, 5 ,6, 7, 84, 85, 86, 87) 32 | X = T; 33 | 34 | fwd_radix2_butterfly_m512(&X, &Y, &w2, q_64); 35 | 36 | T = SHUF(X, Y, 0x88); // (0, 1 ,80, 81, 4, 5, 84, 85) 37 | Y = SHUF(X, Y, 0xdd); // (2, 3, 82, 83, 6, 7, 86, 87) 38 | X = T; 39 | 40 | fwd_radix2_butterfly_m512(&X, &Y, &w3, q_64); 41 | 42 | __m512i idx1 = SETR(0, 2, 8 + 0, 8 + 2, 4, 6, 8 + 4, 8 + 6); 43 | __m512i idx2 = SETR(1, 3, 8 + 1, 8 + 3, 5, 7, 8 + 5, 8 + 7); 44 | 45 | T = PERM(X, idx1, Y); // (0, 80 ,2, 82, 4, 84, 6, 86) 46 | Y = PERM(X, idx2, Y); // (1, 81 ,3, 83, 5, 85, 7, 87) 47 | X = T; 48 | 49 | fwd_radix2_butterfly_m512(&X, &Y, &w4, q_64); 50 | 51 | STORE(&a[16 * i], UNPACKLO(X, Y)); 52 | STORE(&a[16 * i + 8], UNPACKHI(X, Y)); 53 | } 54 | } 55 | 56 | static inline void fwd8_r2(uint64_t * X_64, 57 | uint64_t * Y_64, 58 | const mul_op_m512_t *w, 59 | const uint64_t q_64) 60 | { 61 | __m512i X = LOAD(X_64); 62 | __m512i Y = LOAD(Y_64); 63 | 64 | fwd_radix2_butterfly_m512(&X, &Y, w, q_64); 65 | 66 | STORE(X_64, X); 67 | STORE(Y_64, Y); 68 | } 69 | 70 | void fwd_ntt_r2_16_avx512_ifma_lazy(uint64_t a[], 71 | uint64_t N, 72 | uint64_t q, 73 | const uint64_t w[], 74 | const uint64_t w_con[]) 75 | { 76 | size_t m = 1; 77 | size_t t = N >> 1; 78 | 79 | for(; m < (N >> 4); m <<= 1, t >>= 1) { 80 | for(size_t j = 0; j < m; j++) { 81 | 82 | const uint64_t k = 2 * t * j; 83 | const mul_op_m512_t w1 = {SET1(w[m + j]), SET1(w_con[m + j])}; 84 | 85 | for(size_t i = k; i < k + t; i += 8) { 86 | fwd8_r2(&a[i], &a[i + t], &w1, q); 87 | } 88 | } 89 | } 90 | 91 | fwd16_r2(a, m, &w[m], &w_con[m], q); 92 | } 93 | 94 | #endif 95 | -------------------------------------------------------------------------------- /src/ntt_r4r2_avx512_ifma.c: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #ifdef AVX512_IFMA_SUPPORT 5 | 6 | # include "ntt_avx512_ifma.h" 7 | # include "ntt_hexl.h" 8 | 9 | static inline void _fwd8_r2(uint64_t * a, 10 | __m512i * X, 11 | __m512i * Y, 12 | const mul_op_m512_t *w2, 13 | const mul_op_m512_t *w3, 14 | const mul_op_m512_t *w4, 15 | const uint64_t q_64) 16 | { 17 | __m512i T; 18 | T = SHUF(*X, *Y, 0x44); // (0, 1 ,2, 3, 80, 81, 82, 83) 19 | *Y = SHUF(*X, *Y, 0xee); // (4, 5 ,6, 7, 84, 85, 86, 87) 20 | *X = T; 21 | 22 | fwd_radix2_butterfly_m512(X, Y, w2, q_64); 23 | 24 | T = SHUF(*X, *Y, 0x88); // (0, 1 ,80, 81, 4, 5, 84, 85) 25 | *Y = SHUF(*X, *Y, 0xdd); // (2, 3, 82, 83, 6, 7, 86, 87) 26 | *X = T; 27 | 28 | fwd_radix2_butterfly_m512(X, Y, w3, q_64); 29 | 30 | __m512i idx1 = SETR(0, 2, 8 + 0, 8 + 2, 4, 6, 8 + 4, 8 + 6); 31 | __m512i idx2 = SETR(1, 3, 8 + 1, 8 + 3, 5, 7, 8 + 5, 8 + 7); 32 | 33 | T = PERM(*X, idx1, *Y); // (0, 80 ,2, 82, 4, 84, 6, 86) 34 | *Y = PERM(*X, idx2, *Y); // (1, 81 ,3, 83, 5, 85, 7, 87) 35 | *X = T; 36 | 37 | fwd_radix2_butterfly_m512(X, Y, w4, q_64); 38 | 39 | STORE(&a[0], UNPACKLO(*X, *Y)); 40 | STORE(&a[8], UNPACKHI(*X, *Y)); 41 | } 42 | 43 | static inline void fwd8_r2(uint64_t * a, 44 | const uint64_t m, 45 | const uint64_t *w, 46 | const uint64_t *w_con, 47 | const uint64_t q_64) 48 | { 49 | const __m512i q4 = SET1(q_64 << 2); 50 | 51 | LOOP_UNROLL_4 52 | for(size_t i = 0; i < m; ++i) { 53 | const uint64_t w_idx2 = (8 * i); 54 | const uint64_t w_idx3 = w_idx2 + (8 * m); 55 | const uint64_t w_idx4 = w_idx2 + (16 * m); 56 | const mul_op_m512_t w2 = {LOADA(&w[w_idx2]), LOADA(&w_con[w_idx2])}; 57 | const mul_op_m512_t w3 = {LOADA(&w[w_idx3]), LOADA(&w_con[w_idx3])}; 58 | const mul_op_m512_t w4 = {LOADA(&w[w_idx4]), LOADA(&w_con[w_idx4])}; 59 | 60 | // Radix-4 butterfly leaves values in [0, 8q) 61 | __m512i X = 62 | reduce_if_greater(LOAD(&a[16 * i]) & AVX512_IFMA_WORD_SIZE_MASK, q4); 63 | __m512i Y = 64 | reduce_if_greater(LOAD(&a[16 * i + 8]) & AVX512_IFMA_WORD_SIZE_MASK, q4); 65 | 66 | _fwd8_r2(&a[16 * i], &X, &Y, &w2, &w3, &w4, q_64); 67 | } 68 | } 69 | 70 | static inline void fwd16_r2(uint64_t * a, 71 | const uint64_t m, 72 | const uint64_t *w, 73 | const uint64_t *w_con, 74 | const uint64_t q_64) 75 | { 76 | const __m512i q4 = SET1(q_64 << 2); 77 | 78 | LOOP_UNROLL_4 79 | for(size_t j = m; j > 0; --j) { 80 | size_t i = j - 1; 81 | const uint64_t w_idx2 = (8 * i) + m; 82 | const uint64_t w_idx3 = w_idx2 + (8 * m); 83 | const uint64_t w_idx4 = w_idx2 + (16 * m); 84 | const mul_op_m512_t w1 = {SET1(w[i]), SET1(w_con[i])}; 85 | const mul_op_m512_t w2 = {LOADA(&w[w_idx2]), LOADA(&w_con[w_idx2])}; 86 | const mul_op_m512_t w3 = {LOADA(&w[w_idx3]), LOADA(&w_con[w_idx3])}; 87 | const mul_op_m512_t w4 = {LOADA(&w[w_idx4]), LOADA(&w_con[w_idx4])}; 88 | 89 | // Radix-4 butterfly leaves values in [0, 8q) 90 | __m512i X = 91 | reduce_if_greater(LOAD(&a[16 * i]) & AVX512_IFMA_WORD_SIZE_MASK, q4); 92 | __m512i Y = 93 | reduce_if_greater(LOAD(&a[16 * i + 8]) & AVX512_IFMA_WORD_SIZE_MASK, q4); 94 | 95 | fwd_radix2_butterfly_m512(&X, &Y, &w1, q_64); 96 | 97 | _fwd8_r2(&a[16 * i], &X, &Y, &w2, &w3, &w4, q_64); 98 | } 99 | } 100 | 101 | static inline void collect_roots_fwd8_r4(mul_op_m512_t w1[5], 102 | const uint64_t w[], 103 | const uint64_t w_con[], 104 | size_t * idx) 105 | { 106 | w1[0].op = SET1(w[*idx]); 107 | w1[1].op = SET1(w[*idx + 1]); 108 | w1[2].op = SET1(w[*idx + 2]); 109 | w1[3].op = SET1(w[*idx + 3]); 110 | w1[4].op = SET1(w[*idx + 4]); 111 | 112 | w1[0].con = SET1(w_con[*idx]); 113 | w1[1].con = SET1(w_con[*idx + 1]); 114 | w1[2].con = SET1(w_con[*idx + 2]); 115 | w1[3].con = SET1(w_con[*idx + 3]); 116 | w1[4].con = SET1(w_con[*idx + 4]); 117 | 118 | *idx += 5; 119 | } 120 | 121 | static inline void fwd8_r4(uint64_t * X_64, 122 | uint64_t * Y_64, 123 | uint64_t * Z_64, 124 | uint64_t * T_64, 125 | const mul_op_m512_t w[5], 126 | const uint64_t q_64) 127 | { 128 | __m512i X = LOAD(X_64); 129 | __m512i Y = LOAD(Y_64); 130 | __m512i Z = LOAD(Z_64); 131 | __m512i T = LOAD(T_64); 132 | 133 | fwd_radix4_butterfly_m512(&X, &Y, &Z, &T, w, q_64); 134 | 135 | STORE(X_64, X); 136 | STORE(Y_64, Y); 137 | STORE(Z_64, Z); 138 | STORE(T_64, T); 139 | } 140 | 141 | void fwd_ntt_r4r2_avx512_ifma_lazy(uint64_t a[], 142 | uint64_t N, 143 | uint64_t q, 144 | const uint64_t w[], 145 | const uint64_t w_con[]) 146 | { 147 | mul_op_m512_t roots[5]; 148 | size_t t = N >> 2; 149 | size_t m = 1; 150 | size_t idx = 1; 151 | 152 | for(; t > 4; m <<= 2) { 153 | for(size_t j = 0; j < m; j++) { 154 | const uint64_t k = 4 * t * j; 155 | collect_roots_fwd8_r4(roots, w, w_con, &idx); 156 | for(size_t i = k; i < k + t; i += 8) { 157 | fwd8_r4(&a[i], &a[i + t], &a[i + 2 * t], &a[i + 3 * t], roots, q); 158 | } 159 | } 160 | t >>= 2; 161 | } 162 | 163 | // Align on an 8-qw boundary 164 | idx = ((idx >> 3) << 3) + 8; 165 | 166 | if(HAS_AN_EVEN_POWER(N)) { 167 | fwd16_r2(a, m, &w[idx], &w_con[idx], q); 168 | } else { 169 | m >>= 1; 170 | fwd8_r2(a, m, &w[idx], &w_con[idx], q); 171 | } 172 | } 173 | 174 | #endif 175 | -------------------------------------------------------------------------------- /src/ntt_radix4.c: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #include "ntt_radix4.h" 5 | #include "fast_mul_operators.h" 6 | 7 | static inline void collect_roots(mul_op_t w1[5], 8 | const uint64_t w[], 9 | const uint64_t w_con[], 10 | const size_t m, 11 | const size_t j) 12 | { 13 | const uint64_t m1 = 2 * (m + j); 14 | w1[0].op = w[m1]; 15 | w1[1].op = w[2 * m1]; 16 | w1[2].op = w[2 * m1 + 1]; 17 | w1[3].op = w[2 * m1 + 2]; 18 | w1[4].op = w[2 * m1 + 3]; 19 | 20 | w1[0].con = w_con[m1]; 21 | w1[1].con = w_con[2 * m1]; 22 | w1[2].con = w_con[2 * m1 + 1]; 23 | w1[3].con = w_con[2 * m1 + 2]; 24 | w1[4].con = w_con[2 * m1 + 3]; 25 | } 26 | 27 | void fwd_ntt_radix4_lazy(uint64_t a[], 28 | const uint64_t N, 29 | const uint64_t q, 30 | const uint64_t w[], 31 | const uint64_t w_con[]) 32 | { 33 | const uint64_t bound_r4 = HAS_AN_EVEN_POWER(N) ? N : (N >> 1); 34 | mul_op_t roots[5]; 35 | size_t t = N >> 2; 36 | 37 | for(size_t m = 1; m < bound_r4; m <<= 2) { 38 | for(size_t j = 0; j < m; j++) { 39 | const uint64_t k = 4 * t * j; 40 | 41 | collect_roots(roots, w, w_con, m, j); 42 | for(size_t i = k; i < k + t; i++) { 43 | radix4_fwd_butterfly(&a[i], &a[i + t], &a[i + 2 * t], &a[i + 3 * t], 44 | roots, q); 45 | } 46 | } 47 | t >>= 2; 48 | } 49 | 50 | // Check whether N=2^m where m is odd. 51 | // If not perform extra radix-2 iteration. 52 | if(HAS_AN_EVEN_POWER(N)) { 53 | return; 54 | } 55 | 56 | for(size_t i = 0; i < N; i += 2) { 57 | const mul_op_t w1 = {w[N + i], w_con[N + i]}; 58 | a[i] = reduce_8q_to_4q(a[i], q); 59 | 60 | harvey_fwd_butterfly(&a[i], &a[i + 1], w1, q); 61 | } 62 | } 63 | 64 | void inv_ntt_radix4(uint64_t a[], 65 | const uint64_t N, 66 | const uint64_t q, 67 | const mul_op_t n_inv, 68 | const uint64_t w[], 69 | const uint64_t w_con[]) 70 | { 71 | uint64_t t = 1; 72 | uint64_t m = N; 73 | mul_op_t roots[5]; 74 | 75 | // 1. Check whether N=2^m where m is even. 76 | // If yes, reduce all values modulo 2q, this also can be done outside of this 77 | // function. Otherwise, perform one radix-2 iteration. 78 | if(HAS_AN_EVEN_POWER(N)) { 79 | for(size_t i = 0; i < N; i++) { 80 | a[i] = reduce_8q_to_2q(a[i], q); 81 | } 82 | 83 | } else { 84 | // Perform the first iteration as a radix-2 iteration. 85 | for(size_t i = 0; i < N; i += 2) { 86 | const mul_op_t w1 = {w[N + i], w_con[N + i]}; 87 | 88 | a[i] = reduce_8q_to_4q(a[i], q); 89 | harvey_bkw_butterfly(&a[i], &a[i + 1], w1, q); 90 | } 91 | 92 | m >>= 1; 93 | t <<= 1; 94 | } 95 | 96 | // 2. Perform radix-4 NTT iterations. 97 | for(m >>= 2; m > 0; m >>= 2) { 98 | for(size_t j = 0; j < m; j++) { 99 | const uint64_t k = 4 * t * j; 100 | collect_roots(roots, w, w_con, m, j); 101 | 102 | for(size_t i = k; i < k + t; i++) { 103 | radix4_inv_butterfly(&a[i], &a[i + t], &a[i + 2 * t], &a[i + 3 * t], 104 | roots, q); 105 | } 106 | } 107 | t <<= 2; 108 | } 109 | 110 | // 3. Normalize the results 111 | for(size_t i = 0; i < N; i++) { 112 | a[i] = fast_mul_mod_q(n_inv, a[i], q); 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /src/ntt_radix4_avx512_ifma.c: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #ifdef AVX512_IFMA_SUPPORT 5 | 6 | # include "ntt_avx512_ifma.h" 7 | 8 | static inline void collect_roots_fwd1(mul_op_m512_t w1[5], 9 | const uint64_t w[], 10 | const uint64_t w_con[], 11 | size_t * idx) 12 | { 13 | w1[0].op = LOADA(&w[*idx]); 14 | w1[1].op = LOADA(&w[*idx + 8]); 15 | w1[2].op = LOADA(&w[*idx + 16]); 16 | w1[3].op = LOADA(&w[*idx + 24]); 17 | w1[4].op = LOADA(&w[*idx + 32]); 18 | 19 | w1[0].con = LOADA(&w_con[*idx]); 20 | w1[1].con = LOADA(&w_con[*idx + 8]); 21 | w1[2].con = LOADA(&w_con[*idx + 16]); 22 | w1[3].con = LOADA(&w_con[*idx + 24]); 23 | w1[4].con = LOADA(&w_con[*idx + 32]); 24 | 25 | *idx += 5 * 8; 26 | } 27 | 28 | static inline void collect_roots_fwd4(mul_op_m512_t w1[5], 29 | const uint64_t w[], 30 | const uint64_t w_con[], 31 | size_t * idx) 32 | { 33 | w1[0].op = BROADCAST2HALVES(w[*idx + 0], w[*idx + 1]); 34 | w1[1].op = BROADCAST2HALVES(w[*idx + 2], w[*idx + 3]); 35 | w1[2].op = BROADCAST2HALVES(w[*idx + 4], w[*idx + 5]); 36 | w1[3].op = BROADCAST2HALVES(w[*idx + 6], w[*idx + 7]); 37 | w1[4].op = BROADCAST2HALVES(w[*idx + 8], w[*idx + 9]); 38 | 39 | w1[0].con = BROADCAST2HALVES(w_con[*idx + 0], w_con[*idx + 1]); 40 | w1[1].con = BROADCAST2HALVES(w_con[*idx + 2], w_con[*idx + 3]); 41 | w1[2].con = BROADCAST2HALVES(w_con[*idx + 4], w_con[*idx + 5]); 42 | w1[3].con = BROADCAST2HALVES(w_con[*idx + 6], w_con[*idx + 7]); 43 | w1[4].con = BROADCAST2HALVES(w_con[*idx + 8], w_con[*idx + 9]); 44 | 45 | *idx += 10; 46 | } 47 | 48 | static inline void collect_roots_fwd8(mul_op_m512_t w1[5], 49 | const uint64_t w[], 50 | const uint64_t w_con[], 51 | size_t * idx) 52 | { 53 | w1[0].op = SET1(w[*idx]); 54 | w1[1].op = SET1(w[*idx + 1]); 55 | w1[2].op = SET1(w[*idx + 2]); 56 | w1[3].op = SET1(w[*idx + 3]); 57 | w1[4].op = SET1(w[*idx + 4]); 58 | 59 | w1[0].con = SET1(w_con[*idx]); 60 | w1[1].con = SET1(w_con[*idx + 1]); 61 | w1[2].con = SET1(w_con[*idx + 2]); 62 | w1[3].con = SET1(w_con[*idx + 3]); 63 | w1[4].con = SET1(w_con[*idx + 4]); 64 | 65 | *idx += 5; 66 | } 67 | 68 | static inline void 69 | fwd1(uint64_t *a, const mul_op_m512_t w[5], const uint64_t q_64) 70 | { 71 | const __m512i idx = _mm512_setr_epi64(0, 4, 8, 12, 16, 20, 24, 28); 72 | 73 | __m512i X = GATHER(idx, &a[0 + 0], 8); 74 | __m512i Y = GATHER(idx, &a[0 + 1], 8); 75 | __m512i Z = GATHER(idx, &a[0 + 2], 8); 76 | __m512i T = GATHER(idx, &a[0 + 3], 8); 77 | 78 | fwd_radix4_butterfly_m512(&X, &Y, &Z, &T, w, q_64); 79 | 80 | SCATTER(&a[0 + 0], idx, X, 8); 81 | SCATTER(&a[0 + 1], idx, Y, 8); 82 | SCATTER(&a[0 + 2], idx, Z, 8); 83 | SCATTER(&a[0 + 3], idx, T, 8); 84 | } 85 | 86 | static inline void 87 | fwd4(uint64_t *a, const mul_op_m512_t w[5], const uint64_t q_64) 88 | { 89 | __m512i X1 = LOAD(&a[0]); 90 | __m512i Y1 = LOAD(&a[8]); 91 | __m512i Z1 = LOAD(&a[16]); 92 | __m512i T1 = LOAD(&a[24]); 93 | 94 | __m512i X = SHUF(X1, Z1, 0x44); 95 | __m512i Y = SHUF(X1, Z1, 0xee); 96 | __m512i Z = SHUF(Y1, T1, 0x44); 97 | __m512i T = SHUF(Y1, T1, 0xee); 98 | 99 | fwd_radix4_butterfly_m512(&X, &Y, &Z, &T, w, q_64); 100 | 101 | X1 = SHUF(X, Y, 0x44); 102 | Z1 = SHUF(X, Y, 0xee); 103 | Y1 = SHUF(Z, T, 0x44); 104 | T1 = SHUF(Z, T, 0xee); 105 | 106 | STORE(&a[0], X1); 107 | STORE(&a[8], Y1); 108 | STORE(&a[16], Z1); 109 | STORE(&a[24], T1); 110 | } 111 | 112 | static inline void fwd8(uint64_t * X_64, 113 | uint64_t * Y_64, 114 | uint64_t * Z_64, 115 | uint64_t * T_64, 116 | const mul_op_m512_t w[5], 117 | const uint64_t q_64) 118 | { 119 | __m512i X = LOAD(X_64); 120 | __m512i Y = LOAD(Y_64); 121 | __m512i Z = LOAD(Z_64); 122 | __m512i T = LOAD(T_64); 123 | 124 | fwd_radix4_butterfly_m512(&X, &Y, &Z, &T, w, q_64); 125 | 126 | STORE(X_64, X); 127 | STORE(Y_64, Y); 128 | STORE(Z_64, Z); 129 | STORE(T_64, T); 130 | } 131 | 132 | void fwd_ntt_radix4_avx512_ifma_lazy(uint64_t a[], 133 | const uint64_t N, 134 | const uint64_t q, 135 | const uint64_t w[], 136 | const uint64_t w_con[]) 137 | { 138 | mul_op_m512_t roots[5]; 139 | size_t bound_r4 = N; 140 | size_t t = N >> 1; 141 | size_t m = 1; 142 | size_t idx = 1; 143 | 144 | // Check whether N=2^m where m is odd. 145 | // If not perform extra radix-2 iteration. 146 | if(!HAS_AN_EVEN_POWER(N)) { 147 | const mul_op_m512_t w1 = {SET1(w[1]), SET1(w_con[1])}; 148 | 149 | for(size_t j = 0; j < t; j += 8) { 150 | __m512i X = LOAD(&a[j]); 151 | __m512i Y = LOAD(&a[j + t]); 152 | 153 | fwd_radix2_butterfly_m512(&X, &Y, &w1, q); 154 | 155 | STORE(&a[j], X); 156 | STORE(&a[j + t], Y); 157 | } 158 | bound_r4 >>= 1; 159 | t >>= 1; 160 | m <<= 1; 161 | idx++; 162 | } 163 | 164 | // Adjust to radix-4 165 | t >>= 1; 166 | 167 | for(; m < bound_r4; m <<= 2) { 168 | if(t >= 8) { 169 | for(size_t j = 0; j < m; j++) { 170 | const uint64_t k = 4 * t * j; 171 | collect_roots_fwd8(roots, w, w_con, &idx); 172 | for(size_t i = k; i < k + t; i += 8) { 173 | fwd8(&a[i], &a[i + t], &a[i + 2 * t], &a[i + 3 * t], roots, q); 174 | } 175 | } 176 | } else if(t == 4) { 177 | for(size_t j = 0; j < m; j += 2) { 178 | collect_roots_fwd4(roots, w, w_con, &idx); 179 | fwd4(&a[4 * 4 * j], roots, q); 180 | } 181 | } else { 182 | // Align on an 8-qw boundary 183 | idx = ((idx >> 3) << 3) + 8; 184 | 185 | LOOP_UNROLL_4 186 | for(size_t j = 0; j < m; j += 8) { 187 | collect_roots_fwd1(roots, w, w_con, &idx); 188 | fwd1(&a[4 * j], roots, q); 189 | } 190 | } 191 | t >>= 2; 192 | } 193 | } 194 | 195 | #endif 196 | -------------------------------------------------------------------------------- /src/ntt_radix4_avx512_ifma_unordered.c: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #ifdef AVX512_IFMA_SUPPORT 5 | 6 | # include "ntt_avx512_ifma.h" 7 | 8 | static inline void collect_roots_fwd1(mul_op_m512_t w1[5], 9 | const uint64_t w[], 10 | const uint64_t w_con[], 11 | size_t * idx) 12 | { 13 | w1[0].op = LOADA(&w[*idx]); 14 | w1[1].op = LOADA(&w[*idx + 8]); 15 | w1[2].op = LOADA(&w[*idx + 16]); 16 | w1[3].op = LOADA(&w[*idx + 24]); 17 | w1[4].op = LOADA(&w[*idx + 32]); 18 | 19 | w1[0].con = LOADA(&w_con[*idx]); 20 | w1[1].con = LOADA(&w_con[*idx + 8]); 21 | w1[2].con = LOADA(&w_con[*idx + 16]); 22 | w1[3].con = LOADA(&w_con[*idx + 24]); 23 | w1[4].con = LOADA(&w_con[*idx + 32]); 24 | 25 | *idx += 5 * 8; 26 | } 27 | 28 | static inline void collect_roots_fwd4(mul_op_m512_t w1[5], 29 | const uint64_t w[], 30 | const uint64_t w_con[], 31 | size_t * idx) 32 | { 33 | w1[0].op = BROADCAST2HALVES(w[*idx + 0], w[*idx + 1]); 34 | w1[1].op = BROADCAST2HALVES(w[*idx + 2], w[*idx + 3]); 35 | w1[2].op = BROADCAST2HALVES(w[*idx + 4], w[*idx + 5]); 36 | w1[3].op = BROADCAST2HALVES(w[*idx + 6], w[*idx + 7]); 37 | w1[4].op = BROADCAST2HALVES(w[*idx + 8], w[*idx + 9]); 38 | 39 | w1[0].con = BROADCAST2HALVES(w_con[*idx + 0], w_con[*idx + 1]); 40 | w1[1].con = BROADCAST2HALVES(w_con[*idx + 2], w_con[*idx + 3]); 41 | w1[2].con = BROADCAST2HALVES(w_con[*idx + 4], w_con[*idx + 5]); 42 | w1[3].con = BROADCAST2HALVES(w_con[*idx + 6], w_con[*idx + 7]); 43 | w1[4].con = BROADCAST2HALVES(w_con[*idx + 8], w_con[*idx + 9]); 44 | 45 | *idx += 10; 46 | } 47 | 48 | static inline void collect_roots_fwd8(mul_op_m512_t w1[5], 49 | const uint64_t w[], 50 | const uint64_t w_con[], 51 | size_t * idx) 52 | { 53 | w1[0].op = SET1(w[*idx]); 54 | w1[1].op = SET1(w[*idx + 1]); 55 | w1[2].op = SET1(w[*idx + 2]); 56 | w1[3].op = SET1(w[*idx + 3]); 57 | w1[4].op = SET1(w[*idx + 4]); 58 | 59 | w1[0].con = SET1(w_con[*idx]); 60 | w1[1].con = SET1(w_con[*idx + 1]); 61 | w1[2].con = SET1(w_con[*idx + 2]); 62 | w1[3].con = SET1(w_con[*idx + 3]); 63 | w1[4].con = SET1(w_con[*idx + 4]); 64 | 65 | *idx += 5; 66 | } 67 | 68 | static inline void 69 | fwd1(uint64_t *a, const mul_op_m512_t w[5], const uint64_t q_64) 70 | { 71 | const __m512i idx = _mm512_setr_epi64(0, 4, 8, 12, 16, 20, 24, 28); 72 | 73 | __m512i X = _mm512_i64gather_epi64(idx, &a[0 + 0], 8); 74 | __m512i Y = _mm512_i64gather_epi64(idx, &a[0 + 1], 8); 75 | __m512i Z = _mm512_i64gather_epi64(idx, &a[0 + 2], 8); 76 | __m512i T = _mm512_i64gather_epi64(idx, &a[0 + 3], 8); 77 | 78 | fwd_radix4_butterfly_m512(&X, &Y, &Z, &T, w, q_64); 79 | 80 | STORE(&a[0], X); 81 | STORE(&a[8], Y); 82 | STORE(&a[16], Z); 83 | STORE(&a[24], T); 84 | } 85 | 86 | static inline void 87 | fwd4(uint64_t *a, const mul_op_m512_t w[5], const uint64_t q_64) 88 | { 89 | __m512i X1 = LOAD(&a[0]); 90 | __m512i Y1 = LOAD(&a[8]); 91 | __m512i Z1 = LOAD(&a[16]); 92 | __m512i T1 = LOAD(&a[24]); 93 | 94 | __m512i X = _mm512_shuffle_i64x2(X1, Z1, 0x44); 95 | __m512i Y = _mm512_shuffle_i64x2(X1, Z1, 0xee); 96 | __m512i Z = _mm512_shuffle_i64x2(Y1, T1, 0x44); 97 | __m512i T = _mm512_shuffle_i64x2(Y1, T1, 0xee); 98 | 99 | fwd_radix4_butterfly_m512(&X, &Y, &Z, &T, w, q_64); 100 | 101 | STORE(&a[0], X); 102 | STORE(&a[8], Y); 103 | STORE(&a[16], Z); 104 | STORE(&a[24], T); 105 | } 106 | 107 | static inline void fwd8(uint64_t * X_64, 108 | uint64_t * Y_64, 109 | uint64_t * Z_64, 110 | uint64_t * T_64, 111 | const mul_op_m512_t w[5], 112 | const uint64_t q_64) 113 | { 114 | __m512i X = LOAD(X_64); 115 | __m512i Y = LOAD(Y_64); 116 | __m512i Z = LOAD(Z_64); 117 | __m512i T = LOAD(T_64); 118 | 119 | fwd_radix4_butterfly_m512(&X, &Y, &Z, &T, w, q_64); 120 | 121 | STORE(X_64, X); 122 | STORE(Y_64, Y); 123 | STORE(Z_64, Z); 124 | STORE(T_64, T); 125 | } 126 | 127 | void fwd_ntt_radix4_avx512_ifma_lazy_unordered(uint64_t a[], 128 | const uint64_t N, 129 | const uint64_t q, 130 | const uint64_t w[], 131 | const uint64_t w_con[]) 132 | { 133 | mul_op_m512_t roots[5]; 134 | size_t bound_r4 = N; 135 | size_t t = N >> 1; 136 | size_t m = 1; 137 | size_t idx = 1; 138 | 139 | // Check whether N=2^m where m is odd. 140 | // If not perform extra radix-2 iteration. 141 | if(!HAS_AN_EVEN_POWER(N)) { 142 | const mul_op_m512_t w1 = {SET1(w[1]), SET1(w_con[1])}; 143 | 144 | for(size_t j = 0; j < t; j += 8) { 145 | __m512i X = LOAD(&a[j]); 146 | __m512i Y = LOAD(&a[j + t]); 147 | 148 | fwd_radix2_butterfly_m512(&X, &Y, &w1, q); 149 | 150 | STORE(&a[j], X); 151 | STORE(&a[j + t], Y); 152 | } 153 | bound_r4 >>= 1; 154 | t >>= 1; 155 | m <<= 1; 156 | idx++; 157 | } 158 | 159 | // Adjust to radix-4 160 | t >>= 1; 161 | 162 | for(; m < bound_r4; m <<= 2) { 163 | if(t >= 8) { 164 | for(size_t j = 0; j < m; j++) { 165 | const uint64_t k = 4 * t * j; 166 | collect_roots_fwd8(roots, w, w_con, &idx); 167 | for(size_t i = k; i < k + t; i += 8) { 168 | fwd8(&a[i], &a[i + t], &a[i + 2 * t], &a[i + 3 * t], roots, q); 169 | } 170 | } 171 | } else if(t == 4) { 172 | for(size_t j = 0; j < m; j += 2) { 173 | collect_roots_fwd4(roots, w, w_con, &idx); 174 | fwd4(&a[4 * 4 * j], roots, q); 175 | } 176 | } else { 177 | // Align on an 8-qw boundary 178 | idx = ((idx >> 3) << 3) + 8; 179 | 180 | for(size_t j = 0; j < m;) { 181 | LOOP_UNROLL_4 182 | for(size_t k = 0; k < 4; k += 1, j += 8) { 183 | collect_roots_fwd1(roots, w, w_con, &idx); 184 | fwd1(&a[4 * j], roots, q); 185 | } 186 | } 187 | } 188 | t >>= 2; 189 | } 190 | } 191 | 192 | #endif 193 | -------------------------------------------------------------------------------- /src/ntt_radix4_s390x_vef.c: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #include 5 | 6 | #define L_HIGH_WORD HIGH_VMSL_WORD 7 | 8 | #include "fast_mul_operators.h" 9 | #include "ntt_radix4_s390x_vef.h" 10 | 11 | #define UL_VMSL_Z(a, b, ctx) (ul_vec) vec_msum_u128(a, b, (ctx)->zero, 0) 12 | #define UL_VMSL (ul_vec) vec_msum_u128 13 | #define VEC_BYTES 16 14 | 15 | typedef vector unsigned long long ul_vec; 16 | typedef vector unsigned char uc_vec; 17 | 18 | typedef struct loop_ctx_s { 19 | mul_op_t w1; 20 | ul_vec r1; 21 | ul_vec r2; 22 | ul_vec r1_con; 23 | ul_vec r2_con; 24 | 25 | // In order not to hit a performance slowdown when using VMSL, 26 | // we mask -q to be smaller than 2^56. This should be fine in cases 27 | // where we only consider the lower 56 bits of the VMSL output. 28 | uint64_t neg_q; 29 | 30 | // q2 = 2*q, q4 = 4*q 31 | uint64_t q2; 32 | uint64_t q4; 33 | ul_vec q2_vec; 34 | ul_vec q4_vec; 35 | 36 | uc_vec zero; 37 | } loop_ctx_t; 38 | 39 | /****************************** 40 | Single input 41 | ******************************/ 42 | 43 | static inline ul_vec extended_shoup_multiply(const ul_vec r, 44 | const ul_vec r_con, 45 | const ul_vec YT, 46 | const ul_vec neg_q, 47 | const loop_ctx_t *ctx) 48 | { 49 | ul_vec t1 = UL_VMSL_Z(r_con, YT, ctx); 50 | ul_vec t2 = UL_VMSL_Z(r, YT, ctx); 51 | t1 = vec_sld((ul_vec)ctx->zero, t1, VEC_BYTES - (VMSL_WORD_SIZE / 8)); 52 | return UL_VMSL(neg_q, t1, (uc_vec)t2, 0); 53 | } 54 | 55 | static inline void single_fwd_butterfly(uint64_t a[], 56 | const uint64_t q, 57 | const loop_ctx_t *ctx, 58 | const size_t i, 59 | const uint64_t t) 60 | { 61 | const uint64_t X = reduce_8q_to_4q(a[i], q); 62 | const uint64_t Z = fast_mul_mod_q2(ctx->w1, a[i + 2 * t], q); 63 | 64 | // Load Y and T 65 | ul_vec YT = {a[i + t], a[i + 3 * t]}; 66 | 67 | ul_vec t1 = {0, ctx->neg_q}; 68 | ul_vec r4_1 = extended_shoup_multiply(ctx->r1, ctx->r1_con, YT, t1, ctx); 69 | ul_vec r5_1 = extended_shoup_multiply(ctx->r2, ctx->r2_con, YT, t1, ctx); 70 | const uint64_t Y1 = LOW_VMSL_WORD(r4_1[1]); 71 | const uint64_t Y2 = LOW_VMSL_WORD(r5_1[1]); 72 | 73 | // Store the results 74 | a[i] = (X + Z + Y1); 75 | a[i + t] = ctx->q2 + (X + Z - Y1); 76 | a[i + 2 * t] = ctx->q2 + (X - Z + Y2); 77 | a[i + 3 * t] = ctx->q4 + (X - Z - Y2); 78 | } 79 | 80 | // Assumption t is even. 81 | static inline void double_fwd_butterfly(uint64_t a[], 82 | const loop_ctx_t *ctx, 83 | const size_t i, 84 | const uint64_t t) 85 | { 86 | // We use vector pointers to load two coefficients of "a" at once. 87 | ul_vec *a_vec = (ul_vec *)&a[i]; 88 | 89 | // Load X and Z, note that because we are using vector pointers 90 | // t is divided by 2. 91 | ul_vec X = a_vec[0]; 92 | ul_vec Z = a_vec[2 * (t >> 1)]; 93 | 94 | // Load Y and T of each of the two iterations into YT1, YT2, respectively. 95 | // Note that here we use t instead of t>>1 as "a" is a uint64_t. 96 | ul_vec YT1 = {a[i + t], a[i + 3 * t]}; 97 | ul_vec YT2 = {a[i + t + 1], a[i + 3 * t + 1]}; 98 | 99 | // X = X mod 4q 100 | X = vec_sel(X, X - ctx->q4_vec, vec_cmpge(X, ctx->q4_vec)); 101 | 102 | // Simple Shoup multiply on two elements in parallel. 103 | ul_vec t1 = {ctx->w1.op, ctx->neg_q}; 104 | ul_vec Q_1 = {Z[0], HIGH_VMSL_WORD(ctx->w1.con * Z[0])}; 105 | ul_vec Q_2 = {Z[1], HIGH_VMSL_WORD(ctx->w1.con * Z[1])}; 106 | Q_1 = UL_VMSL_Z(t1, Q_1, ctx); 107 | Q_2 = UL_VMSL_Z(t1, Q_2, ctx); 108 | Z = LOW_VMSL_WORD(vec_mergel(Q_1, Q_2)); 109 | 110 | // Extended Shoup multiply on two elements in parallel. 111 | t1[0] = 0; 112 | ul_vec r4_1 = extended_shoup_multiply(ctx->r1, ctx->r1_con, YT1, t1, ctx); 113 | ul_vec r4_2 = extended_shoup_multiply(ctx->r1, ctx->r1_con, YT2, t1, ctx); 114 | ul_vec r5_1 = extended_shoup_multiply(ctx->r2, ctx->r2_con, YT1, t1, ctx); 115 | ul_vec r5_2 = extended_shoup_multiply(ctx->r2, ctx->r2_con, YT2, t1, ctx); 116 | 117 | const ul_vec Y1 = LOW_VMSL_WORD(vec_mergel(r4_1, r4_2)); 118 | const ul_vec Y2 = LOW_VMSL_WORD(vec_mergel(r5_1, r5_2)); 119 | 120 | // Store the results 121 | a_vec[0 * (t >> 1)] = (X + Z + Y1); 122 | a_vec[1 * (t >> 1)] = ctx->q2_vec + (X + Z - Y1); 123 | a_vec[2 * (t >> 1)] = ctx->q2_vec + (X - Z + Y2); 124 | a_vec[3 * (t >> 1)] = ctx->q4_vec + (X - Z - Y2); 125 | } 126 | 127 | void fwd_ntt_radix4_intrinsic_lazy(uint64_t a[], 128 | const uint64_t N, 129 | const uint64_t q, 130 | const uint64_t w[], 131 | const uint64_t w_con[]) 132 | { 133 | const size_t bound_r4 = HAS_AN_EVEN_POWER(N) ? N : (N >> 1); 134 | 135 | for(size_t m = 1, t = (N >> 2); m < bound_r4; m <<= 2, t >>= 2) { 136 | for(size_t j = 0; j < m; j++) { 137 | const size_t m1 = 2 * (m + j); 138 | const size_t m2 = 2 * m1; 139 | const size_t k = 4 * t * j; 140 | const loop_ctx_t ctx = {.w1 = {w[m1], w_con[m1]}, 141 | .r1 = {w[m2 + 0], w[m2 + 1]}, 142 | .r2 = {w[m2 + 2], w[m2 + 3]}, 143 | .r1_con = {w_con[m2 + 0], w_con[m2 + 1]}, 144 | .r2_con = {w_con[m2 + 2], w_con[m2 + 3]}, 145 | .zero = {0}, 146 | .neg_q = (-1 * q) & VMSL_WORD_SIZE_MASK, 147 | .q2 = 2 * q, 148 | .q4 = 4 * q, 149 | .q2_vec = {2 * q, 2 * q}, 150 | .q4_vec = {4 * q, 4 * q}}; 151 | 152 | if(t == 1) { 153 | for(size_t i = k; i < k + t; i++) { 154 | single_fwd_butterfly(a, q, &ctx, i, t); 155 | } 156 | } else { 157 | for(size_t i = k; i < k + t; i += 2) { 158 | double_fwd_butterfly(a, &ctx, i, t); 159 | } 160 | } 161 | } 162 | } 163 | 164 | // Check whether N=2^m where m is odd. 165 | if(HAS_AN_EVEN_POWER(N)) { 166 | return; 167 | } 168 | 169 | for(size_t i = 0; i < N; i += 2) { 170 | const mul_op_t w1 = {w[i + N], w_con[i + N]}; 171 | 172 | a[i] = reduce_8q_to_4q(a[i], q); 173 | harvey_fwd_butterfly(&a[i], &a[i + 1], w1, q); 174 | } 175 | } 176 | 177 | static inline void single_inv_butterfly(uint64_t a[], 178 | const uint64_t q, 179 | const loop_ctx_t *ctx, 180 | const size_t i, 181 | const uint64_t t) 182 | { 183 | const uint64_t X = a[i + 0 * t]; 184 | const uint64_t Y = a[i + 1 * t]; 185 | const uint64_t Z = a[i + 2 * t]; 186 | const uint64_t T = a[i + 3 * t]; 187 | 188 | const uint64_t T0 = Z + T; 189 | const uint64_t T1 = X + Y; 190 | const uint64_t T4 = fast_mul_mod_q2(ctx->w1, ctx->q4 + T1 - T0, q); 191 | 192 | ul_vec T23 = {ctx->q4 + X - Y, ctx->q4 + Z - T}; 193 | ul_vec t1 = {0, ctx->neg_q}; 194 | ul_vec r4_1 = extended_shoup_multiply(ctx->r1, ctx->r1_con, T23, t1, ctx); 195 | ul_vec r5_1 = extended_shoup_multiply(ctx->r2, ctx->r2_con, T23, t1, ctx); 196 | 197 | const uint64_t Y1 = LOW_VMSL_WORD(r4_1[1]); 198 | const uint64_t Y2 = LOW_VMSL_WORD(r5_1[1]); 199 | 200 | a[i + 0 * t] = reduce_8q_to_2q(T1 + T0, q); 201 | a[i + 1 * t] = Y1; 202 | a[i + 2 * t] = T4; 203 | a[i + 3 * t] = Y2; 204 | } 205 | 206 | // Assumption t is even. 207 | static inline void double_inv_butterfly(uint64_t a[], 208 | const loop_ctx_t *ctx, 209 | const size_t i, 210 | const size_t t) 211 | { 212 | // We use vector pointers to load two coefficients of "a" at once. 213 | ul_vec *a_vec = (ul_vec *)&a[i]; 214 | 215 | const ul_vec X = a_vec[0 * t]; 216 | const ul_vec Y = a_vec[1 * t]; 217 | const ul_vec Z = a_vec[2 * t]; 218 | const ul_vec T = a_vec[3 * t]; 219 | 220 | const ul_vec T0 = Z + T; 221 | ul_vec T1 = X + Y; 222 | ul_vec T4 = ctx->q4_vec + T1 - T0; 223 | 224 | // Simple Shoup multiply on two elements in parallel. 225 | ul_vec t1 = {ctx->w1.op, ctx->neg_q}; 226 | ul_vec Q_1 = {T4[0], HIGH_VMSL_WORD(ctx->w1.con * T4[0])}; 227 | ul_vec Q_2 = {T4[1], HIGH_VMSL_WORD(ctx->w1.con * T4[1])}; 228 | Q_1 = UL_VMSL_Z(t1, Q_1, ctx); 229 | Q_2 = UL_VMSL_Z(t1, Q_2, ctx); 230 | T4 = LOW_VMSL_WORD(vec_mergel(Q_1, Q_2)); 231 | 232 | ul_vec T2 = ctx->q4_vec + X - Y; 233 | ul_vec T3 = ctx->q4_vec + Z - T; 234 | ul_vec T23a = {T2[0], T3[0]}; 235 | ul_vec T23b = {T2[1], T3[1]}; 236 | 237 | // Extended Shoup multiply on two elements in parallel. 238 | t1[0] = 0; 239 | ul_vec r4_1 = extended_shoup_multiply(ctx->r1, ctx->r1_con, T23a, t1, ctx); 240 | ul_vec r4_2 = extended_shoup_multiply(ctx->r1, ctx->r1_con, T23b, t1, ctx); 241 | ul_vec r5_1 = extended_shoup_multiply(ctx->r2, ctx->r2_con, T23a, t1, ctx); 242 | ul_vec r5_2 = extended_shoup_multiply(ctx->r2, ctx->r2_con, T23b, t1, ctx); 243 | 244 | const ul_vec Y1 = LOW_VMSL_WORD(vec_mergel(r4_1, r4_2)); 245 | const ul_vec Y2 = LOW_VMSL_WORD(vec_mergel(r5_1, r5_2)); 246 | 247 | T1 = T1 + T0; 248 | T1 = vec_sel(T1, T1 - ctx->q4_vec, vec_cmpge(T1, ctx->q4_vec)); 249 | T1 = vec_sel(T1, T1 - ctx->q2_vec, vec_cmpge(T1, ctx->q2_vec)); 250 | a_vec[0 * t] = T1; 251 | a_vec[1 * t] = Y1; 252 | a_vec[2 * t] = T4; 253 | a_vec[3 * t] = Y2; 254 | } 255 | 256 | void inv_ntt_radix4_intrinsic(uint64_t a[], 257 | const uint64_t N, 258 | const uint64_t q, 259 | const mul_op_t n_inv, 260 | const uint64_t w[], 261 | const uint64_t w_con[]) 262 | { 263 | size_t t = 1; 264 | size_t m = N; 265 | 266 | // Check whether N=2^m where m is odd. 267 | if(HAS_AN_EVEN_POWER(N)) { 268 | for(size_t i = 0; i < N; i++) { 269 | a[i] = reduce_8q_to_2q(a[i], q); 270 | } 271 | } else { 272 | // Perform the first iteration as a radix-2 iteration. 273 | for(size_t i = 0; i < N; i += 2) { 274 | const mul_op_t w1 = {w[i + N], w_con[i + N]}; 275 | a[i] = reduce_8q_to_4q(a[i], q); 276 | harvey_bkw_butterfly(&a[i], &a[i + 1], w1, q); 277 | } 278 | m >>= 1; 279 | t <<= 1; 280 | } 281 | 282 | while(m > 0) { 283 | m >>= 2; 284 | for(size_t j = 0; j < m; j++) { 285 | const uint64_t m1 = 2 * (m + j); 286 | const uint64_t m2 = 2 * m1; 287 | const uint64_t k = 4 * t * j; 288 | const loop_ctx_t ctx = {.w1 = {w[m1], w_con[m1]}, 289 | .r1 = {w[m2 + 0], w[m2 + 2]}, 290 | .r2 = {w[m2 + 1], w[m2 + 3]}, 291 | .r1_con = {w_con[m2 + 0], w_con[m2 + 2]}, 292 | .r2_con = {w_con[m2 + 1], w_con[m2 + 3]}, 293 | .zero = {0}, 294 | .neg_q = (-1 * q) & VMSL_WORD_SIZE_MASK, 295 | .q2 = 2 * q, 296 | .q4 = 4 * q, 297 | .q2_vec = {2 * q, 2 * q}, 298 | .q4_vec = {4 * q, 4 * q}}; 299 | 300 | if(t == 1) { 301 | for(size_t i = k; i < k + t; i++) { 302 | single_inv_butterfly(a, q, &ctx, i, t); 303 | } 304 | } else { 305 | for(size_t i = k; i < k + t; i += 2) { 306 | double_inv_butterfly(a, &ctx, i, t >> 1); 307 | } 308 | } 309 | } 310 | 311 | t <<= 2; 312 | } 313 | 314 | for(size_t i = 0; i < N; i++) { 315 | // At the last iteration, multiply by n^-1 mod q 316 | a[i] = fast_mul_mod_q(n_inv, a[i], q); 317 | } 318 | } 319 | 320 | /****************************** 321 | Double input 322 | ******************************/ 323 | void fwd_ntt_radix4_intrinsic_lazy_dbl(uint64_t a1[], 324 | uint64_t a2[], 325 | const uint64_t N, 326 | const uint64_t q, 327 | const uint64_t w[], 328 | const uint64_t w_con[]) 329 | { 330 | const size_t bound_r4 = HAS_AN_EVEN_POWER(N) ? N : (N >> 1); 331 | 332 | for(size_t m = 1, t = (N >> 2); m < bound_r4; m <<= 2, t >>= 2) { 333 | for(size_t j = 0; j < m; j++) { 334 | const size_t m1 = 2 * (m + j); 335 | const size_t m2 = 2 * m1; 336 | const size_t k = 4 * t * j; 337 | const loop_ctx_t ctx = {.w1 = {w[m1], w_con[m1]}, 338 | .r1 = {w[m2 + 0], w[m2 + 1]}, 339 | .r2 = {w[m2 + 2], w[m2 + 3]}, 340 | .r1_con = {w_con[m2 + 0], w_con[m2 + 1]}, 341 | .r2_con = {w_con[m2 + 2], w_con[m2 + 3]}, 342 | .zero = {0}, 343 | .neg_q = (-1 * q) & VMSL_WORD_SIZE_MASK, 344 | .q2 = 2 * q, 345 | .q4 = 4 * q, 346 | .q2_vec = {2 * q, 2 * q}, 347 | .q4_vec = {4 * q, 4 * q}}; 348 | 349 | if(t == 1) { 350 | for(size_t i = k; i < k + t; i++) { 351 | single_fwd_butterfly(a1, q, &ctx, i, t); 352 | single_fwd_butterfly(a2, q, &ctx, i, t); 353 | } 354 | } else { 355 | for(size_t i = k; i < k + t; i += 2) { 356 | double_fwd_butterfly(a1, &ctx, i, t); 357 | double_fwd_butterfly(a2, &ctx, i, t); 358 | } 359 | } 360 | } 361 | } 362 | 363 | // Check whether N=2^m where m is odd. 364 | if(HAS_AN_EVEN_POWER(N)) { 365 | return; 366 | } 367 | 368 | for(size_t i = 0; i < N; i += 2) { 369 | const mul_op_t w1 = {w[i + N], w_con[i + N]}; 370 | a1[i] = reduce_8q_to_4q(a1[i], q); 371 | a2[i] = reduce_8q_to_4q(a2[i], q); 372 | 373 | harvey_fwd_butterfly(&a1[i], &a1[i + 1], w1, q); 374 | harvey_fwd_butterfly(&a2[i], &a2[i + 1], w1, q); 375 | } 376 | } 377 | -------------------------------------------------------------------------------- /src/ntt_radix4x4.c: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #include "ntt_radix4x4.h" 5 | #include "fast_mul_operators.h" 6 | 7 | static inline void collect_roots(mul_op_t w1[5], 8 | const uint64_t w[], 9 | const uint64_t w_con[], 10 | const size_t m, 11 | const size_t j) 12 | { 13 | const uint64_t m1 = 2 * (m + j); 14 | w1[0].op = w[m1]; 15 | w1[1].op = w[2 * m1]; 16 | w1[2].op = w[2 * m1 + 1]; 17 | w1[3].op = w[2 * m1 + 2]; 18 | w1[4].op = w[2 * m1 + 3]; 19 | 20 | w1[0].con = w_con[m1]; 21 | w1[1].con = w_con[2 * m1]; 22 | w1[2].con = w_con[2 * m1 + 1]; 23 | w1[3].con = w_con[2 * m1 + 2]; 24 | w1[4].con = w_con[2 * m1 + 3]; 25 | } 26 | 27 | static inline uint64_t get_iter_reminder(const uint64_t N) 28 | { 29 | if(HAS_AN_REM1_POWER(N)) { 30 | return 1; 31 | } 32 | if(HAS_AN_REM2_POWER(N)) { 33 | return 2; 34 | } 35 | if(HAS_AN_REM3_POWER(N)) { 36 | return 3; 37 | } 38 | return 0; 39 | } 40 | 41 | void fwd_ntt_radix4x4_lazy(uint64_t a[], 42 | const uint64_t N, 43 | const uint64_t q, 44 | const uint64_t w[], 45 | const uint64_t w_con[]) 46 | { 47 | const uint64_t bound_r4 = N; 48 | uint64_t m_rem = get_iter_reminder(N); 49 | 50 | mul_op_t roots[5]; 51 | mul_op_t roots4[4][5]; 52 | size_t t = N >> 2; 53 | 54 | for(size_t m = 1; m < (bound_r4 >> m_rem); m <<= 4) { 55 | for(size_t j = 0; j < m; j++) { 56 | const uint64_t k = 4 * t * j; 57 | size_t t2 = t >> 2; 58 | 59 | collect_roots(roots, w, w_con, m, j); 60 | for(size_t i = 0; i < 4; i++) { 61 | collect_roots(roots4[i], w, w_con, m << 2, 4 * j + i); 62 | } 63 | 64 | // Perform the 16-radix NTT in two steps of radix-4 NTT 65 | for(size_t i = k; i < k + t2; i++) { 66 | for(size_t l = i; l < i + t; l += t2) { 67 | radix4_fwd_butterfly(&a[l], &a[l + t], &a[l + 2 * t], &a[l + 3 * t], 68 | roots, q); 69 | } 70 | size_t x = 0; 71 | for(size_t l = i; l < i + 4 * t; l += t, x++) { 72 | radix4_fwd_butterfly(&a[l], &a[l + t2], &a[l + 2 * t2], &a[l + 3 * t2], 73 | roots4[x], q); 74 | } 75 | } 76 | } 77 | t >>= 4; 78 | } 79 | 80 | // Perform extra iterations if needed 81 | switch(m_rem) { 82 | case 1: 83 | // Perform extra radix-2 iteration. 84 | for(size_t i = 0; i < N; i += 2) { 85 | const mul_op_t w1 = {w[N + i], w_con[N + i]}; 86 | a[i] = reduce_8q_to_4q(a[i], q); 87 | 88 | harvey_fwd_butterfly(&a[i], &a[i + 1], w1, q); 89 | } 90 | return; 91 | case 3: 92 | // Perform extra radix-2 and then radix-4 iteration. 93 | t = 4; 94 | const size_t m = N >> 3; 95 | for(size_t i = 0; i < m; i++) { 96 | const size_t k = 2 * t * i; 97 | const mul_op_t w1 = {w[2 * (m + i)], w_con[2 * (m + i)]}; 98 | a[i] = reduce_8q_to_4q(a[i], q); 99 | 100 | for(size_t j = k; j < k + t; j++) { 101 | harvey_fwd_butterfly(&a[j], &a[j + t], w1, q); 102 | } 103 | } 104 | /* fall through */ 105 | case 2: 106 | // Perform extra radix-4 iteration (for cases 2 and 3). 107 | for(size_t i = 0; i < N; i += 4) { 108 | collect_roots(roots, w, w_con, N >> 2, i >> 2); 109 | radix4_fwd_butterfly(&a[i], &a[i + 1], &a[i + 2], &a[i + 3], roots, q); 110 | } 111 | return; 112 | default: return; 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /src/ntt_reference.c: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #include "ntt_reference.h" 5 | #include "fast_mul_operators.h" 6 | 7 | /****************************** 8 | Single input 9 | ******************************/ 10 | 11 | void fwd_ntt_ref_harvey_lazy(uint64_t a[], 12 | const uint64_t N, 13 | const uint64_t q, 14 | const uint64_t w[], 15 | const uint64_t w_con[]) 16 | { 17 | size_t t = N >> 1; 18 | 19 | for(size_t m = 1; m < N; m <<= 1, t >>= 1) { 20 | size_t k = 0; 21 | for(size_t i = 0; i < m; i++) { 22 | const mul_op_t w1 = {w[m + i], w_con[m + i]}; 23 | 24 | LOOP_UNROLL_4 25 | for(size_t j = k; j < k + t; j++) { 26 | harvey_fwd_butterfly(&a[j], &a[j + t], w1, q); 27 | } 28 | k = k + (2 * t); 29 | } 30 | } 31 | } 32 | 33 | void inv_ntt_ref_harvey(uint64_t a[], 34 | const uint64_t N, 35 | const uint64_t q, 36 | const mul_op_t n_inv, 37 | const uint64_t word_size, 38 | const uint64_t w[], 39 | const uint64_t w_con[]) 40 | { 41 | uint64_t t = 1; 42 | 43 | for(size_t m = N >> 1; m > 1; m >>= 1, t <<= 1) { 44 | size_t k = 0; 45 | for(size_t i = 0; i < m; i++) { 46 | const mul_op_t w1 = {w[m + i], w_con[m + i]}; 47 | 48 | for(size_t j = k; j < k + t; j++) { 49 | harvey_bkw_butterfly(&a[j], &a[j + t], w1, q); 50 | } 51 | k = k + (2 * t); 52 | } 53 | } 54 | 55 | // Final round - the harvey_bkw_butterfly, where the output is multiplies by 56 | // n_inv. Here m=1, k=0, t=N/2. 57 | const __uint128_t tmp = fast_mul_mod_q2(n_inv, w[1], q); 58 | 59 | // We can speed up this integer devision by using barreto reduction. 60 | // However, as it happens only once we keep the code simple. 61 | const mul_op_t w1 = {tmp, (tmp << word_size) / q}; 62 | 63 | for(size_t j = 0; j < t; j++) { 64 | harvey_bkw_butterfly_final(&a[j], &a[j + t], w1, n_inv, q); 65 | } 66 | } 67 | 68 | /****************************** 69 | Double input 70 | ******************************/ 71 | void fwd_ntt_ref_harvey_lazy_dbl(uint64_t a1[], 72 | uint64_t a2[], 73 | const uint64_t N, 74 | const uint64_t q, 75 | const uint64_t w[], 76 | const uint64_t w_con[]) 77 | { 78 | uint64_t t = N >> 1; 79 | 80 | for(size_t m = 1; m < N; m <<= 1, t >>= 1) { 81 | size_t k = 0; 82 | for(size_t i = 0; i < m; i++) { 83 | const mul_op_t w1 = {w[m + i], w_con[m + i]}; 84 | for(size_t j = k; j < k + t; j++) { 85 | harvey_fwd_butterfly(&a1[j], &a1[j + t], w1, q); 86 | harvey_fwd_butterfly(&a2[j], &a2[j + t], w1, q); 87 | } 88 | k = k + (2 * t); 89 | } 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /tests/bench.c: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #include 5 | 6 | #include "measurements.h" 7 | #include "ntt_radix4.h" 8 | #include "ntt_radix4x4.h" 9 | #include "ntt_reference.h" 10 | #include "ntt_seal.h" 11 | #include "tests.h" 12 | #include "utils.h" 13 | 14 | #ifdef S390X 15 | # include "ntt_radix4_s390x_vef.h" 16 | #endif 17 | 18 | #ifdef AVX512_IFMA_SUPPORT 19 | # include "ntt_avx512_ifma.h" 20 | # include "ntt_hexl.h" 21 | #endif 22 | 23 | void report_test_fwd_perf_headers(void) 24 | { 25 | printf(" | fwd " 26 | " | fwd-lazy\n"); 27 | printf("-----------------------------------------------------------------------" 28 | "---------------------"); 29 | printf("--------------------------------------\n"); 30 | printf(" N q"); 31 | printf(" rad2-ref"); 32 | printf(" rad2-SEAL"); 33 | printf(" rad4"); 34 | printf(" rad4x4"); 35 | #ifdef S390X 36 | printf(" rad4-vmsl"); 37 | #elif AVX512_IFMA_SUPPORT 38 | printf(" rad2-hexl"); 39 | printf(" rad2-ifma"); 40 | printf(" rad2-ifma2"); 41 | printf(" r4r2-ifma"); 42 | printf(" r216-ifma"); 43 | #endif 44 | printf(" rad2-dbl"); 45 | #ifdef S390X 46 | printf(" rad4v-dbl"); 47 | #endif 48 | 49 | printf(" rad2-ref"); 50 | printf(" rad2-SEAL"); 51 | printf(" rad4"); 52 | #ifdef S390X 53 | printf(" rad4-vmsl"); 54 | #endif 55 | printf("\n"); 56 | } 57 | 58 | static inline void test_fwd_perf(const test_case_t *t, 59 | uint64_t * a, 60 | uint64_t * b, 61 | const uint64_t * a_cpy) 62 | { 63 | const uint64_t q = t->q; 64 | const uint64_t n = t->n; 65 | 66 | printf("%3.0lu 0x%14.0lx ", t->m, t->q); 67 | 68 | MEASURE(fwd_ntt_ref_harvey(a, n, q, t->w_powers.ptr, t->w_powers_con.ptr)); 69 | memcpy(a, a_cpy, n * sizeof(uint64_t)); 70 | 71 | MEASURE(fwd_ntt_seal(a, n, q, t->w_powers.ptr, t->w_powers_con.ptr)); 72 | memcpy(a, a_cpy, n * sizeof(uint64_t)); 73 | 74 | MEASURE(fwd_ntt_radix4(a, n, q, t->w_powers_r4.ptr, t->w_powers_con_r4.ptr)); 75 | memcpy(a, a_cpy, n * sizeof(uint64_t)); 76 | 77 | MEASURE(fwd_ntt_radix4x4(a, n, q, t->w_powers_r4.ptr, t->w_powers_con_r4.ptr)); 78 | memcpy(a, a_cpy, n * sizeof(uint64_t)); 79 | 80 | #ifdef S390X 81 | MEASURE(fwd_ntt_radix4_intrinsic(a, n, q, t->w_powers_r4.ptr, 82 | t->w_powers_con_r4_vmsl.ptr)); 83 | memcpy(a, a_cpy, n * sizeof(uint64_t)); 84 | #elif AVX512_IFMA_SUPPORT 85 | MEASURE(fwd_ntt_radix2_hexl(a, t->n, t->q, t->w_powers_hexl.ptr, 86 | t->w_powers_con_hexl.ptr)); 87 | memcpy(a, a_cpy, n * sizeof(uint64_t)); 88 | 89 | MEASURE(fwd_ntt_radix4_avx512_ifma(a, t->n, t->q, 90 | t->w_powers_r4_avx512_ifma.ptr, 91 | t->w_powers_con_r4_avx512_ifma.ptr)); 92 | memcpy(a, a_cpy, n * sizeof(uint64_t)); 93 | 94 | MEASURE(fwd_ntt_radix4_avx512_ifma_unordered( 95 | a, t->n, t->q, t->w_powers_r4_avx512_ifma_unordered.ptr, 96 | t->w_powers_con_r4_avx512_ifma_unordered.ptr)); 97 | memcpy(a, a_cpy, n * sizeof(uint64_t)); 98 | 99 | MEASURE(fwd_ntt_r4r2_avx512_ifma(a, t->n, t->q, 100 | t->w_powers_r4r2_avx512_ifma.ptr, 101 | t->w_powers_con_r4r2_avx512_ifma.ptr)); 102 | memcpy(a, a_cpy, n * sizeof(uint64_t)); 103 | 104 | MEASURE(fwd_ntt_r2_16_avx512_ifma(a, t->n, t->q, 105 | t->w_powers_r2_16_avx512_ifma.ptr, 106 | t->w_powers_con_r2_16_avx512_ifma.ptr)); 107 | memcpy(a, a_cpy, n * sizeof(uint64_t)); 108 | #endif 109 | 110 | MEASURE(fwd_ntt_ref_harvey_dbl(a, b, t->n, t->q, t->w_powers.ptr, 111 | t->w_powers_con.ptr)); 112 | memcpy(a, a_cpy, n * sizeof(uint64_t)); 113 | memcpy(b, a_cpy, n * sizeof(uint64_t)); 114 | 115 | #ifdef S390X 116 | MEASURE(fwd_ntt_radix4_intrinsic_dbl(a, b, t->n, t->q, t->w_powers_r4.ptr, 117 | t->w_powers_con_r4_vmsl.ptr)); 118 | #endif 119 | 120 | memcpy(a, a_cpy, n * sizeof(uint64_t)); 121 | memcpy(b, a_cpy, n * sizeof(uint64_t)); 122 | 123 | MEASURE( 124 | fwd_ntt_ref_harvey_lazy(a, n, q, t->w_powers.ptr, t->w_powers_con.ptr);); 125 | memcpy(a, a_cpy, n * sizeof(uint64_t)); 126 | 127 | MEASURE(fwd_ntt_seal_lazy(a, n, q, t->w_powers.ptr, t->w_powers_con.ptr);); 128 | memcpy(a, a_cpy, n * sizeof(uint64_t)); 129 | 130 | MEASURE( 131 | fwd_ntt_radix4_lazy(a, n, q, t->w_powers_r4.ptr, t->w_powers_con_r4.ptr);); 132 | memcpy(a, a_cpy, n * sizeof(uint64_t)); 133 | 134 | #ifdef S390X 135 | MEASURE(fwd_ntt_radix4_intrinsic_lazy(a, n, q, t->w_powers_r4.ptr, 136 | t->w_powers_con_r4_vmsl.ptr)); 137 | #endif 138 | 139 | printf("\n"); 140 | } 141 | 142 | void test_aligned_fwd_perf(const test_case_t *t) 143 | { 144 | const uint64_t n = t->n; 145 | const uint64_t q = t->q; 146 | 147 | // We use a_cpy to reset a after every NTT call. 148 | // This is especially important when dealing with the lazy evaluation functions 149 | // To avoid overflowing and therefore slowdowns of VMSL. 150 | ALIGN(64) uint64_t a[n]; 151 | ALIGN(64) uint64_t b[n]; 152 | ALIGN(64) uint64_t a_cpy[n]; 153 | random_buf(a, n, q); 154 | memcpy(a_cpy, a, sizeof(a)); 155 | memcpy(b, a, sizeof(a)); 156 | 157 | test_fwd_perf(t, a, b, a_cpy); 158 | } 159 | 160 | void test_unaligned_fwd_perf(const test_case_t *t) 161 | { 162 | const uint64_t n = t->n; 163 | const uint64_t q = t->q; 164 | 165 | // We use a_cpy to reset a after every NTT call. 166 | // This is especially important when dealing with the lazy evaluation functions 167 | // To avoid overflowing and therefore slowdowns of VMSL. 168 | unaligned64_ptr_t a; 169 | unaligned64_ptr_t b; 170 | unaligned64_ptr_t a_cpy; 171 | allocate_unaligned_array(&a, n); 172 | allocate_unaligned_array(&b, n); 173 | allocate_unaligned_array(&a_cpy, n); 174 | 175 | random_buf(a.ptr, n, q); 176 | memset(a_cpy.ptr, 0, n * sizeof(uint64_t)); 177 | memcpy(a_cpy.ptr, a.ptr, n * sizeof(uint64_t)); 178 | memset(b.ptr, 0, n * sizeof(uint64_t)); 179 | memcpy(b.ptr, a.ptr, n * sizeof(uint64_t)); 180 | 181 | test_fwd_perf(t, a.ptr, b.ptr, a_cpy.ptr); 182 | 183 | free_unaligned_array(&a); 184 | free_unaligned_array(&b); 185 | free_unaligned_array(&a_cpy); 186 | } 187 | 188 | void report_test_inv_perf_headers(void) 189 | { 190 | printf(" | inv\n"); 191 | printf("------------------------------------------------------\n"); 192 | printf(" N q"); 193 | 194 | printf(" rad2-ref"); 195 | printf(" rad2-SEAL"); 196 | printf(" rad4"); 197 | 198 | #ifdef S390X 199 | printf(" rad4-vmsl"); 200 | #endif 201 | 202 | printf("\n"); 203 | } 204 | 205 | void test_inv_perf(const test_case_t *t) 206 | { 207 | const uint64_t n = t->n; 208 | const uint64_t q = t->q; 209 | 210 | printf("%3.0lu 0x%14.0lx ", t->m, t->q); 211 | 212 | // We use a_cpy to reset a after every NTT call. 213 | // This is especially important when dealing with the lazy evaluation functions 214 | // To avoid overflowing and therefore slowdowns of VMSL. 215 | uint64_t a[n]; 216 | uint64_t a_cpy[n]; 217 | random_buf(a, n, q); 218 | memcpy(a_cpy, a, sizeof(a)); 219 | 220 | MEASURE(inv_ntt_ref_harvey(a, n, q, t->n_inv, WORD_SIZE, t->w_inv_powers.ptr, 221 | t->w_inv_powers_con.ptr)); 222 | memcpy(a, a_cpy, sizeof(a)); 223 | 224 | MEASURE(inv_ntt_seal(a, t->n, t->q, t->n_inv.op, t->n_inv.con, 225 | t->w_inv_powers.ptr, t->w_inv_powers_con.ptr)); 226 | memcpy(a, a_cpy, sizeof(a)); 227 | 228 | MEASURE(inv_ntt_radix4(a, n, q, t->n_inv, t->w_inv_powers_r4.ptr, 229 | t->w_inv_powers_con_r4.ptr)); 230 | memcpy(a, a_cpy, sizeof(a)); 231 | 232 | #ifdef S390X 233 | MEASURE(inv_ntt_radix4_intrinsic(a, n, q, t->n_inv_vmsl, t->w_inv_powers_r4.ptr, 234 | t->w_inv_powers_con_r4_vmsl.ptr)); 235 | #endif 236 | 237 | printf("\n"); 238 | } 239 | 240 | void test_fwd_single_case(const test_case_t *t, const func_num_t func_num) 241 | { 242 | const uint64_t n = t->n; 243 | const uint64_t q = t->q; 244 | 245 | // We use a_cpy to reset a after every NTT call. 246 | // This is especially important when dealing with the lazy evaluation functions 247 | // To avoid overflowing and therefore slowdowns of VMSL. 248 | uint64_t a[n]; 249 | uint64_t a_cpy[n]; 250 | random_buf(a, n, q); 251 | memcpy(a_cpy, a, sizeof(a)); 252 | 253 | switch(func_num) { 254 | case FWD_REF: 255 | MEASURE(fwd_ntt_ref_harvey(a, n, q, t->w_powers.ptr, t->w_powers_con.ptr)); 256 | break; 257 | case FWD_SEAL: 258 | MEASURE(fwd_ntt_seal(a, n, q, t->w_powers.ptr, t->w_powers_con.ptr)); 259 | break; 260 | case FWD_R4: 261 | MEASURE( 262 | fwd_ntt_radix4(a, n, q, t->w_powers_r4.ptr, t->w_powers_con_r4.ptr)); 263 | break; 264 | case FWD_R4x4: 265 | MEASURE( 266 | fwd_ntt_radix4x4(a, n, q, t->w_powers_r4.ptr, t->w_powers_con_r4.ptr)); 267 | break; 268 | case FWD_R4_VMSL: 269 | #ifdef S390X 270 | MEASURE(fwd_ntt_radix4_intrinsic(a, n, q, t->w_powers_r4.ptr, 271 | t->w_powers_con_r4_vmsl.ptr)); 272 | break; 273 | #elif AVX512_IFMA_SUPPORT 274 | case FWD_R4_HEXL: 275 | MEASURE(fwd_ntt_radix2_hexl(a, t->n, t->q, t->w_powers_hexl.ptr, 276 | t->w_powers_con_hexl.ptr)); 277 | break; 278 | case FWD_R4_AVX512_IFMA: 279 | MEASURE(fwd_ntt_radix4_avx512_ifma(a, t->n, t->q, 280 | t->w_powers_r4_avx512_ifma.ptr, 281 | t->w_powers_con_r4_avx512_ifma.ptr)); 282 | break; 283 | case FWD_R4_AVX512_IFMA_UNORDERED: 284 | MEASURE(fwd_ntt_radix4_avx512_ifma_unordered( 285 | a, t->n, t->q, t->w_powers_r4_avx512_ifma_unordered.ptr, 286 | t->w_powers_con_r4_avx512_ifma_unordered.ptr)); 287 | break; 288 | case FWD_R4R2_AVX512_IFMA: 289 | MEASURE(fwd_ntt_r4r2_avx512_ifma(a, t->n, t->q, 290 | t->w_powers_r4r2_avx512_ifma.ptr, 291 | t->w_powers_con_r4r2_avx512_ifma.ptr)); 292 | break; 293 | case FWD_R2_R16_AVX512_IFMA: 294 | MEASURE(fwd_ntt_r2_16_avx512_ifma(a, t->n, t->q, 295 | t->w_powers_r2_16_avx512_ifma.ptr, 296 | t->w_powers_con_r2_16_avx512_ifma.ptr)); 297 | break; 298 | #endif 299 | default: break; 300 | } 301 | } 302 | -------------------------------------------------------------------------------- /tests/main.c: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #include "pre_compute.h" 5 | #include "tests.h" 6 | 7 | int main(UNUSED int argc, UNUSED char *argv[]) 8 | { 9 | init_test_cases(); 10 | 11 | #ifdef TEST_SPEED 12 | if(argc == 2) { 13 | printf("Testing test 9 and func %ld cycle=", strtol(argv[1], NULL, 0)); 14 | test_fwd_single_case(&tests[9], strtol(argv[1], NULL, 0)); 15 | printf("\n"); 16 | return SUCCESS; 17 | } 18 | 19 | printf("\n\nTesting forward NTT with unaligned inputs\n\n"); 20 | report_test_fwd_perf_headers(); 21 | for(size_t i = 0; i < NUM_OF_TEST_CASES; i++) { 22 | test_unaligned_fwd_perf(&tests[i]); 23 | } 24 | 25 | printf("Testing forward NTT with aligned inputs\n\n"); 26 | report_test_fwd_perf_headers(); 27 | for(size_t i = 0; i < NUM_OF_TEST_CASES; i++) { 28 | test_aligned_fwd_perf(&tests[i]); 29 | } 30 | 31 | printf("Testing inverse NTT with unaligned inputs\n\n"); 32 | report_test_inv_perf_headers(); 33 | for(size_t i = 0; i < NUM_OF_TEST_CASES; i++) { 34 | test_inv_perf(&tests[i]); 35 | } 36 | 37 | #else 38 | 39 | for(size_t i = 0; i < NUM_OF_TEST_CASES; i++) { 40 | printf("Test %2.0lu\n", i); 41 | if(SUCCESS != test_correctness(&tests[i])) { 42 | destroy_test_cases(); 43 | return SUCCESS; 44 | } 45 | } 46 | #endif 47 | 48 | destroy_test_cases(); 49 | return SUCCESS; 50 | } 51 | -------------------------------------------------------------------------------- /tests/measurements.h: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #pragma once 5 | 6 | #include "defs.h" 7 | 8 | EXTERNC_BEGIN 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #ifdef TEST_SPEED 18 | 19 | # ifdef INTEL_SDE 20 | static inline void SDE_SSC_MARK(unsigned int mark_id) 21 | { 22 | int ret_val; 23 | __asm__ __volatile__(".byte 0x64, 0x67, 0x90\n\t" // SSC mark (with ID in ebx) 24 | : "=r"(ret_val) 25 | : "b"(mark_id)); 26 | } 27 | 28 | # define SDE_SSC_START SDE_SSC_MARK(1) 29 | # define SDE_SSC_STOP SDE_SSC_MARK(2) 30 | # define MEASURE(x) \ 31 | SDE_SSC_START; \ 32 | do { \ 33 | x; \ 34 | } while(0); \ 35 | SDE_SSC_STOP 36 | 37 | # else 38 | # define WARMUP 10 39 | # define OUTER_REPEAT 10 40 | # define MEASURE_TIMES 200 41 | 42 | static double start_clk; 43 | static double end_clk; 44 | static double total_clk; 45 | static double temp_clk; 46 | 47 | # define NANO_SEC (1000000000UL) 48 | 49 | static inline uint64_t cpucycles(void) 50 | { 51 | struct timespec ts; 52 | clock_gettime(CLOCK_MONOTONIC, &ts); 53 | return (uint64_t)ts.tv_sec * NANO_SEC + ts.tv_nsec; 54 | ; 55 | } 56 | 57 | # define MEASURE(x) \ 58 | for(size_t warmup_itr = 0; warmup_itr < WARMUP; warmup_itr++) { \ 59 | { \ 60 | x; \ 61 | } \ 62 | } \ 63 | total_clk = DBL_MAX; \ 64 | for(size_t outer_itr = 0; outer_itr < OUTER_REPEAT; outer_itr++) { \ 65 | start_clk = cpucycles(); \ 66 | for(size_t clk_itr = 0; clk_itr < MEASURE_TIMES; clk_itr++) { \ 67 | { \ 68 | x; \ 69 | } \ 70 | } \ 71 | end_clk = cpucycles(); \ 72 | temp_clk = (double)(end_clk - start_clk) / MEASURE_TIMES; \ 73 | if(total_clk > temp_clk) total_clk = temp_clk; \ 74 | } \ 75 | printf("%9.0lu ", (uint64_t)total_clk); 76 | 77 | # endif 78 | #else 79 | # define MEASURE(x) \ 80 | do { \ 81 | x; \ 82 | } while(0) 83 | #endif 84 | 85 | EXTERNC_END 86 | -------------------------------------------------------------------------------- /tests/pre-commit-script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | # Copyright IBM Inc. All Rights Reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | # Avoid removing the "build" directory if the script does not run from the 6 | # package root directory 7 | basedir=`pwd` 8 | if [[ ! -f "$basedir/tests/pre-commit-script.sh" ]]; then 9 | >&2 echo "Script does not run from the root directory" 10 | exit 0 11 | fi 12 | 13 | # Clean previous build content 14 | rm -rf build; 15 | 16 | mkdir build; 17 | cd build; 18 | 19 | # Test clang-format 20 | cmake ..; make format; 21 | rm -rf * 22 | 23 | # Test clang-tidy 24 | CC=clang-12 cmake -DCMAKE_C_CLANG_TIDY="clang-tidy;--format-style=file" .. 25 | make -j20 26 | rm -rf * 27 | 28 | for flag in "" "-DASAN=1" "-DUBSAN=1" ; do 29 | CC=clang-12 cmake $flag ..; 30 | make -j20 31 | ./ntt-variants 32 | rm -rf * 33 | done 34 | -------------------------------------------------------------------------------- /tests/test_cases.h: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #pragma once 5 | 6 | #include 7 | 8 | #include "fast_mul_operators.h" 9 | #include "pre_compute.h" 10 | 11 | EXTERNC_BEGIN 12 | 13 | /* 14 | * The pointers in this file are 64 bytes aligned. 15 | */ 16 | typedef struct aligned64_ptr_s { 17 | void * base; 18 | uint64_t *ptr; 19 | } aligned64_ptr_t; 20 | 21 | typedef struct unaligned64_ptr_s { 22 | void * base; 23 | uint64_t *ptr; 24 | } unaligned64_ptr_t; 25 | 26 | static inline int allocate_aligned_array(aligned64_ptr_t *aptr, size_t qw_num) 27 | { 28 | size_t size_to_allocate = qw_num * sizeof(uint64_t) + 64; 29 | if(NULL == ((aptr->base) = malloc(size_to_allocate))) { 30 | printf("Allocation error"); 31 | return ERROR; 32 | } 33 | aptr->ptr = (uint64_t *)(((uint64_t)aptr->base & (~0x3fULL)) + 64); 34 | return SUCCESS; 35 | } 36 | 37 | static inline int allocate_unaligned_array(unaligned64_ptr_t *aptr, size_t qw_num) 38 | { 39 | size_t size_to_allocate = qw_num * sizeof(uint64_t) + 64 + 8; 40 | if(NULL == ((aptr->base) = malloc(size_to_allocate))) { 41 | printf("Allocation error"); 42 | return ERROR; 43 | } 44 | aptr->ptr = (uint64_t *)(((uint64_t)aptr->base & (~0x3fULL)) + 64 + 8); 45 | return SUCCESS; 46 | } 47 | 48 | static inline void free_aligned_array(aligned64_ptr_t *aptr) 49 | { 50 | free(aptr->base); 51 | aptr->base = NULL; 52 | aptr->ptr = NULL; 53 | } 54 | 55 | static inline void free_unaligned_array(unaligned64_ptr_t *aptr) 56 | { 57 | free(aptr->base); 58 | aptr->base = NULL; 59 | aptr->ptr = NULL; 60 | } 61 | 62 | typedef struct test_case_s { 63 | // These parameters are predefined 64 | uint64_t m; 65 | uint64_t q; 66 | uint64_t w; 67 | uint64_t w_inv; // w^(-1) mod q 68 | mul_op_t n_inv; // 2^(-m) mod q 69 | 70 | // These parameters are dinamically computed based on the above values. 71 | uint64_t n; 72 | uint64_t qneg; 73 | uint64_t q2; 74 | uint64_t q4; 75 | aligned64_ptr_t w_powers; 76 | aligned64_ptr_t w_powers_con; 77 | aligned64_ptr_t w_inv_powers; 78 | aligned64_ptr_t w_inv_powers_con; 79 | 80 | // For radix-4 tests 81 | aligned64_ptr_t w_powers_r4; 82 | aligned64_ptr_t w_powers_con_r4; 83 | aligned64_ptr_t w_inv_powers_r4; 84 | aligned64_ptr_t w_inv_powers_con_r4; 85 | 86 | #ifdef S390X 87 | // For radix-4 tests with VMSL (56-bits instead of 64-bits) 88 | aligned64_ptr_t w_powers_con_r4_vmsl; 89 | aligned64_ptr_t w_inv_powers_con_r4_vmsl; 90 | mul_op_t n_inv_vmsl; 91 | #endif 92 | 93 | #ifdef AVX512_IFMA_SUPPORT 94 | // For radix-2 tests with AVX512-IFMA on X86-64 bit platofrms (52-bits) 95 | aligned64_ptr_t w_powers_hexl; 96 | aligned64_ptr_t w_powers_con_hexl; 97 | 98 | aligned64_ptr_t w_powers_r4_avx512_ifma; 99 | aligned64_ptr_t w_powers_con_r4_avx512_ifma; 100 | 101 | aligned64_ptr_t w_powers_r4_avx512_ifma_unordered; 102 | aligned64_ptr_t w_powers_con_r4_avx512_ifma_unordered; 103 | 104 | aligned64_ptr_t w_powers_r4r2_avx512_ifma; 105 | aligned64_ptr_t w_powers_con_r4r2_avx512_ifma; 106 | 107 | aligned64_ptr_t w_powers_r2_16_avx512_ifma; 108 | aligned64_ptr_t w_powers_con_r2_16_avx512_ifma; 109 | #endif 110 | 111 | } test_case_t; 112 | 113 | /* 114 | We used the following sagemath script to generate the values below. 115 | 116 | params = [(7681, 8), (65537, 9), 117 | (65537, 10), (65537, 11), 118 | (65537, 12), (65537, 13), 119 | (65537, 14), (0xc0001, 14), 120 | (0xfff0001, 14), (0x1ffc8001, 14), 121 | (0x7ffe0001, 14), (0xfff88001, 14), 122 | (0x7fffffffe0001, 14), (0x80000001c0001, 14), 123 | (0x80000001c0001, 15),(0x7fffffffe0001, 16) 124 | ] 125 | 126 | for (p, m) in params: 127 | n=2^m 128 | Zp = Integers(p) 129 | w_p = primitive_root(p) 130 | 131 | # Get a primitive 2n'th root of unity w 132 | eta = (p-1)//(2*n) 133 | w = Zp(w_p)^eta 134 | 135 | # Find minimum root 136 | new_w, min_w = w, w 137 | for i in range(2*n): 138 | min_w = min(min_w, new_w) 139 | new_w = new_w * w^2; 140 | print("p=",hex(p), " m=",m, " w=", min_w, " verify=", 1 == min_w^(2*n), " 141 | w_inv=", min_w^-1, " n_inv=",Zp(n)^-1, sep='') 142 | */ 143 | 144 | // We use NOLINT in order to stop clang-tidy from reporting magic numbers 145 | static test_case_t tests[] = { 146 | {.m = 8, .q = 0x1e01, .w = 62, .w_inv = 1115, .n_inv.op = 7651}, // NOLINT 147 | {.m = 9, .q = 0x10001, .w = 431, .w_inv = 55045, .n_inv.op = 65409}, // NOLINT 148 | {.m = 10, .q = 0x10001, .w = 33, .w_inv = 1986, .n_inv.op = 65473}, // NOLINT 149 | {.m = 11, .q = 0x10001, .w = 21, .w_inv = 49933, .n_inv.op = 65505}, // NOLINT 150 | {.m = 12, .q = 0x10001, .w = 13, .w_inv = 15124, .n_inv.op = 65521}, // NOLINT 151 | {.m = 13, .q = 0x10001, .w = 15, .w_inv = 30584, .n_inv.op = 65529}, // NOLINT 152 | {.m = 14, .q = 0x10001, .w = 9, .w_inv = 7282, .n_inv.op = 65533}, // NOLINT 153 | {.m = 14, .q = 0xc0001, .w = 9, .w_inv = 174763, .n_inv.op = 786385}, // NOLINT 154 | {.m = 14, 155 | .q = 0xfff0001, // NOLINT 156 | .w = 10360, // NOLINT 157 | .w_inv = 28987060, // NOLINT 158 | .n_inv.op = 268353541}, // NOLINT 159 | {.m = 14, 160 | .q = 0x1ffc8001, // NOLINT 161 | .w = 101907, 162 | .w_inv = 42191135, // NOLINT 163 | .n_inv.op = 536608783}, // NOLINT 164 | {.m = 14, 165 | .q = 0x7ffe0001, // NOLINT 166 | .w = 320878, 167 | .w_inv = 74168714, // NOLINT 168 | .n_inv.op = 2147221513}, // NOLINT 169 | {.m = 14, 170 | .q = 0xfff88001, // NOLINT 171 | .w = 263641, 172 | .w_inv = 243522111, // NOLINT 173 | .n_inv.op = 4294213663}, // NOLINT 174 | {.m = 14, 175 | .q = 0x7fffffffe0001, // NOLINT 176 | .w = 83051296654, 177 | .w_inv = 374947202223591, // NOLINT 178 | .n_inv.op = 2251662374600713}, // NOLINT 179 | {.m = 14, 180 | .q = 0x80000001c0001, // NOLINT 181 | .w = 72703961923, 182 | .w_inv = 153477749218715, // NOLINT 183 | .n_inv.op = 2251662376566673}, // NOLINT 184 | {.m = 15, 185 | .q = 0x10001, // NOLINT 186 | .w = 3, 187 | .w_inv = 21846, // NOLINT 188 | .n_inv.op = 65535}, // NOLINT 189 | {.m = 15, 190 | .q = 0x80000001c0001, // NOLINT 191 | .w = 82138512871, 192 | .w_inv = 535648572761016, // NOLINT 193 | .n_inv.op = 2251731096043465}, // NOLINT 194 | {.m = 16, // NOLINT 195 | .q = 0x7ffe0001, // NOLINT 196 | .w = 1859, 197 | .w_inv = 1579037640, // NOLINT 198 | .n_inv.op = 2147319811}, 199 | {.m = 16, // NOLINT 200 | .q = 0x7fffffffe0001, // NOLINT 201 | .w = 29454831443, 202 | .w_inv = 520731633805630, // NOLINT 203 | .n_inv.op = 2251765453815811}, // NOLINT 204 | {.m = 17, // NOLINT 205 | .q = 0x100180001, // NOLINT 206 | .w = 79247, 207 | .w_inv = 4203069932, // NOLINT 208 | .n_inv.op = 4296507381}}; // NOLINT 209 | 210 | #define NUM_OF_TEST_CASES (sizeof(tests) / sizeof(test_case_t)) 211 | 212 | static inline int _init_test(test_case_t *t) 213 | { 214 | // For brevity 215 | const uint64_t q = t->q; 216 | const uint64_t w = t->w; 217 | const uint64_t m = t->m; 218 | const uint64_t w_inv = t->w_inv; 219 | const uint64_t n = 1UL << t->m; 220 | 221 | t->n = n; 222 | t->n_inv.con = calc_ninv_con(t->n_inv.op, q, WORD_SIZE); 223 | t->q2 = 2 * q; 224 | t->q4 = 4 * q; 225 | 226 | // Prepare radix-2 w-powers 227 | allocate_aligned_array(&t->w_powers, n); 228 | calc_w(t->w_powers.ptr, w, n, q, m); 229 | 230 | allocate_aligned_array(&t->w_powers_con, n); 231 | calc_w_con(t->w_powers_con.ptr, t->w_powers.ptr, n, q, WORD_SIZE); 232 | 233 | allocate_aligned_array(&t->w_inv_powers, n); 234 | calc_w_inv(t->w_inv_powers.ptr, w_inv, n, q, m); 235 | 236 | allocate_aligned_array(&t->w_inv_powers_con, n); 237 | calc_w_con(t->w_inv_powers_con.ptr, t->w_inv_powers.ptr, n, q, WORD_SIZE); 238 | 239 | // Expand the list of powers to support the radix-4 case. 240 | allocate_aligned_array(&t->w_powers_r4, 2 * n); 241 | expand_w(t->w_powers_r4.ptr, t->w_powers.ptr, n, q); 242 | 243 | allocate_aligned_array(&t->w_powers_con_r4, 2 * n); 244 | calc_w_con(t->w_powers_con_r4.ptr, t->w_powers_r4.ptr, 2 * n, q, WORD_SIZE); 245 | 246 | allocate_aligned_array(&t->w_inv_powers_r4, 2 * n); 247 | expand_w(t->w_inv_powers_r4.ptr, t->w_inv_powers.ptr, n, q); 248 | 249 | allocate_aligned_array(&t->w_inv_powers_con_r4, 2 * n); 250 | calc_w_con(t->w_inv_powers_con_r4.ptr, t->w_inv_powers_r4.ptr, 2 * n, q, 251 | WORD_SIZE); 252 | 253 | #ifdef S390X 254 | t->n_inv_vmsl.con = calc_ninv_con(&t->n_inv.op, q, VMSL_WORD_SIZE); 255 | t->n_inv_vmsl.op = t->n_inv.op; 256 | 257 | // for radix-4 vmsl 258 | allocate_aligned_array(&t->w_powers_con_r4_vmsl, 2 * n); 259 | calc_w_con(t->w_powers_con_r4_vmsl.ptr, t->w_powers_r4.ptr, 2 * t->n, t->q, 260 | VMSL_WORD_SIZE); 261 | 262 | allocate_aligned_array(&t->w_inv_powers_con_r4_vmsl, 2 * n); 263 | calc_w_con(t->w_inv_powers_con_r4_vmsl.ptr, t->w_inv_powers_r4.ptr, 2 * t->n, 264 | t->q, VMSL_WORD_SIZE); 265 | #endif 266 | 267 | #ifdef AVX512_IFMA_SUPPORT 268 | // For avx512-ifma 269 | // In fact, we only need to allocate 1.25n but we allocate 2n just in case. 270 | allocate_aligned_array(&t->w_powers_hexl, 2 * n); 271 | expand_w_hexl(t->w_powers_hexl.ptr, t->w_powers.ptr, n); 272 | 273 | allocate_aligned_array(&t->w_powers_con_hexl, n * 2); 274 | calc_w_con(t->w_powers_con_hexl.ptr, t->w_powers_hexl.ptr, n * 2, q, 275 | AVX512_IFMA_WORD_SIZE); 276 | 277 | allocate_aligned_array(&t->w_powers_r4_avx512_ifma, n * 5); 278 | expand_w_r4_avx512_ifma(t->w_powers_r4_avx512_ifma.ptr, t->w_powers.ptr, n, q, 279 | 0); 280 | 281 | allocate_aligned_array(&t->w_powers_con_r4_avx512_ifma, n * 5); 282 | calc_w_con(t->w_powers_con_r4_avx512_ifma.ptr, t->w_powers_r4_avx512_ifma.ptr, 283 | 5 * n, q, AVX512_IFMA_WORD_SIZE); 284 | 285 | allocate_aligned_array(&t->w_powers_r4_avx512_ifma_unordered, n * 5); 286 | expand_w_r4_avx512_ifma(t->w_powers_r4_avx512_ifma_unordered.ptr, 287 | t->w_powers.ptr, n, q, 1); 288 | 289 | allocate_aligned_array(&t->w_powers_con_r4_avx512_ifma_unordered, n * 5); 290 | calc_w_con(t->w_powers_con_r4_avx512_ifma_unordered.ptr, 291 | t->w_powers_r4_avx512_ifma_unordered.ptr, n * 5, q, 292 | AVX512_IFMA_WORD_SIZE); 293 | 294 | allocate_aligned_array(&t->w_powers_r4r2_avx512_ifma, n * 5); 295 | expand_w_r4r2_avx512_ifma(t->w_powers_r4r2_avx512_ifma.ptr, t->w_powers.ptr, n, 296 | q); 297 | 298 | allocate_aligned_array(&t->w_powers_con_r4r2_avx512_ifma, n * 5); 299 | calc_w_con(t->w_powers_con_r4r2_avx512_ifma.ptr, 300 | t->w_powers_r4r2_avx512_ifma.ptr, n * 5, q, AVX512_IFMA_WORD_SIZE); 301 | 302 | allocate_aligned_array(&t->w_powers_r2_16_avx512_ifma, n * 3); 303 | expand_w_r2_16_avx512_ifma(t->w_powers_r2_16_avx512_ifma.ptr, t->w_powers.ptr, 304 | n); 305 | 306 | allocate_aligned_array(&t->w_powers_con_r2_16_avx512_ifma, n * 3); 307 | calc_w_con(t->w_powers_con_r2_16_avx512_ifma.ptr, 308 | t->w_powers_r2_16_avx512_ifma.ptr, n * 3, q, AVX512_IFMA_WORD_SIZE); 309 | #endif 310 | return 1; 311 | } 312 | 313 | static inline int init_test_cases(void) 314 | { 315 | for(size_t i = 0; i < NUM_OF_TEST_CASES; i++) { 316 | if(!_init_test(&tests[i])) { 317 | return 0; 318 | } 319 | } 320 | return 1; 321 | } 322 | 323 | static inline void _destroy_test(test_case_t *t) 324 | { 325 | // for radix-2 326 | free_aligned_array(&t->w_powers); 327 | free_aligned_array(&t->w_powers_con); 328 | free_aligned_array(&t->w_inv_powers); 329 | free_aligned_array(&t->w_inv_powers_con); 330 | 331 | // for radix-4 332 | free_aligned_array(&t->w_powers_r4); 333 | free_aligned_array(&t->w_powers_con_r4); 334 | free_aligned_array(&t->w_inv_powers_r4); 335 | free_aligned_array(&t->w_inv_powers_con_r4); 336 | 337 | #ifdef S390X 338 | // for VMSL 339 | free_aligned_array(&t->w_powers_con_r4_vmsl); 340 | free_aligned_array(&t->w_inv_powers_con_r4_vmsl); 341 | 342 | #endif 343 | #ifdef AVX512_IFMA_SUPPORT 344 | // for AVX512-IFMA 345 | free_aligned_array(&t->w_powers_hexl); 346 | free_aligned_array(&t->w_powers_con_hexl); 347 | 348 | free_aligned_array(&t->w_powers_r4_avx512_ifma); 349 | free_aligned_array(&t->w_powers_con_r4_avx512_ifma); 350 | 351 | free_aligned_array(&t->w_powers_r4_avx512_ifma_unordered); 352 | free_aligned_array(&t->w_powers_con_r4_avx512_ifma_unordered); 353 | 354 | free_aligned_array(&t->w_powers_r4r2_avx512_ifma); 355 | free_aligned_array(&t->w_powers_con_r4r2_avx512_ifma); 356 | 357 | free_aligned_array(&t->w_powers_r2_16_avx512_ifma); 358 | free_aligned_array(&t->w_powers_con_r2_16_avx512_ifma); 359 | #endif 360 | } 361 | 362 | static inline void destroy_test_cases(void) 363 | { 364 | for(size_t i = 0; i < NUM_OF_TEST_CASES; i++) { 365 | _destroy_test(&tests[i]); 366 | } 367 | } 368 | 369 | EXTERNC_END 370 | -------------------------------------------------------------------------------- /tests/test_correctness.c: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #include 5 | 6 | #include "ntt_radix4.h" 7 | #include "ntt_radix4x4.h" 8 | #include "ntt_reference.h" 9 | #include "ntt_seal.h" 10 | #include "pre_compute.h" 11 | #include "test_cases.h" 12 | #include "utils.h" 13 | 14 | #ifdef S390X 15 | # include "ntt_radix4_s390x_vef.h" 16 | #endif 17 | 18 | #ifdef AVX512_IFMA_SUPPORT 19 | # include "ntt_avx512_ifma.h" 20 | # include "ntt_hexl.h" 21 | #endif 22 | 23 | static inline int test_radix2_scalar(const test_case_t *t, uint64_t a_orig[]) 24 | { 25 | uint64_t a[t->n]; 26 | memcpy(a, a_orig, sizeof(a)); 27 | 28 | printf("Running fwd_ntt_ref_harvey\n"); 29 | fwd_ntt_ref_harvey(a, t->n, t->q, t->w_powers.ptr, t->w_powers_con.ptr); 30 | 31 | printf("Running inv_ntt_ref_harvey\n"); 32 | inv_ntt_ref_harvey(a, t->n, t->q, t->n_inv, WORD_SIZE, t->w_inv_powers.ptr, 33 | t->w_inv_powers_con.ptr); 34 | 35 | GUARD_MSG(memcmp(a_orig, a, sizeof(a)), "Bad results after radix-2 inv\n"); 36 | 37 | return SUCCESS; 38 | } 39 | 40 | static inline int test_radix2_scalar_dbl(const test_case_t *t, 41 | uint64_t a_orig[], 42 | uint64_t b_orig[], 43 | uint64_t a_ntt[]) 44 | { 45 | uint64_t a[t->n]; 46 | uint64_t b[t->n]; 47 | memcpy(a, a_orig, sizeof(a)); 48 | memcpy(b, b_orig, sizeof(a)); 49 | 50 | printf("Running fwd_ntt_ref_harvey_dbl\n"); 51 | fwd_ntt_ref_harvey_dbl(a, b, t->n, t->q, t->w_powers.ptr, t->w_powers_con.ptr); 52 | 53 | GUARD_MSG(memcmp(a_ntt, a, sizeof(a)), 54 | "Bad results after radix-2 scalar double for a\n"); 55 | GUARD_MSG(memcmp(a_ntt, b, sizeof(a)), 56 | "Bad results after radix-2 scalar double for b\n"); 57 | 58 | return SUCCESS; 59 | } 60 | 61 | static inline int 62 | test_radix2_scalar_seal(const test_case_t *t, uint64_t a_orig[], uint64_t a_ntt[]) 63 | { 64 | uint64_t a[t->n]; 65 | memcpy(a, a_orig, sizeof(a)); 66 | 67 | printf("Running fwd_ntt_seal\n"); 68 | fwd_ntt_seal(a, t->n, t->q, t->w_powers.ptr, t->w_powers_con.ptr); 69 | GUARD_MSG(memcmp(a_ntt, a, sizeof(a)), 70 | "Bad results after radix-2 SEAL fwd implementation\n"); 71 | 72 | printf("Running inv_ntt_seal\n"); 73 | inv_ntt_seal(a, t->n, t->q, t->n_inv.op, t->n_inv.con, t->w_inv_powers.ptr, 74 | t->w_inv_powers_con.ptr); 75 | GUARD_MSG(memcmp(a_orig, a, sizeof(a)), 76 | "Bad results after radix-2 SEAL inv implementation\n"); 77 | 78 | return SUCCESS; 79 | } 80 | 81 | static inline int 82 | test_radix4_scalar(const test_case_t *t, uint64_t a_orig[], uint64_t a_ntt[]) 83 | { 84 | uint64_t a[t->n]; 85 | memcpy(a, a_orig, sizeof(a)); 86 | 87 | printf("Running fwd_ntt_radix4\n"); 88 | fwd_ntt_radix4(a, t->n, t->q, t->w_powers_r4.ptr, t->w_powers_con_r4.ptr); 89 | GUARD_MSG(memcmp(a_ntt, a, sizeof(a)), "Bad results after radix-4 fwd\n"); 90 | 91 | printf("Running inv_ntt_radix4\n"); 92 | inv_ntt_radix4(a, t->n, t->q, t->n_inv, t->w_inv_powers_r4.ptr, 93 | t->w_inv_powers_con_r4.ptr); 94 | 95 | GUARD_MSG(memcmp(a_orig, a, sizeof(a)), "Bad results after radix-4 inv\n"); 96 | 97 | return SUCCESS; 98 | } 99 | 100 | static inline int 101 | test_radix4x4_scalar(const test_case_t *t, uint64_t a_orig[], uint64_t a_ntt[]) 102 | { 103 | uint64_t a[t->n]; 104 | memcpy(a, a_orig, sizeof(a)); 105 | 106 | printf("Running fwd_ntt_radix4x4\n"); 107 | fwd_ntt_radix4x4(a, t->n, t->q, t->w_powers_r4.ptr, t->w_powers_con_r4.ptr); 108 | GUARD_MSG(memcmp(a_ntt, a, sizeof(a)), "Bad results after radix-4x4 fwd\n"); 109 | 110 | return SUCCESS; 111 | } 112 | 113 | #ifdef S390X 114 | static inline int 115 | test_radix4_intrinsic(const test_case_t *t, uint64_t a_orig[], uint64_t a_ntt[]) 116 | { 117 | uint64_t a[t->n]; 118 | memcpy(a, a_orig, sizeof(a)); 119 | 120 | printf("Running fwd_ntt_radix4_intrinsic\n"); 121 | fwd_ntt_radix4_intrinsic(a, t->n, t->q, t->w_powers_r4.ptr, 122 | t->w_powers_con_r4_vmsl.ptr); 123 | GUARD_MSG(memcmp(a_ntt, a, sizeof(a)), 124 | "Bad results after radix-4 with intrinsic fwd\n"); 125 | 126 | printf("Running inv_ntt_radix4_intrinsic\n"); 127 | inv_ntt_radix4_intrinsic(a, t->n, t->q, t->n_inv_vmsl, t->w_inv_powers_r4.ptr, 128 | t->w_inv_powers_con_r4_vmsl.ptr); 129 | 130 | GUARD_MSG(memcmp(a_orig, a, sizeof(a)), 131 | "Bad results after radix-4 inv with intrinsic\n"); 132 | 133 | return SUCCESS; 134 | } 135 | 136 | static inline int test_radix4_intrinsic_dbl(const test_case_t *t, 137 | uint64_t a_orig[], 138 | uint64_t b_orig[], 139 | uint64_t a_ntt[]) 140 | { 141 | uint64_t a[t->n]; 142 | uint64_t b[t->n]; 143 | memcpy(a, a_orig, sizeof(a)); 144 | memcpy(b, b_orig, sizeof(a)); 145 | 146 | printf("Running fwd_ntt_ref_harvey_dbl\n"); 147 | fwd_ntt_radix4_intrinsic_dbl(a, b, t->n, t->q, t->w_powers_r4.ptr, 148 | t->w_powers_con_r4_vmsl.ptr); 149 | GUARD_MSG(memcmp(a_ntt, a, sizeof(a)), 150 | "Bad results after radix-2 scalar double for a\n"); 151 | GUARD_MSG(memcmp(a_ntt, b, sizeof(b)), 152 | "Bad results after radix-2 scalar double for b\n"); 153 | 154 | return SUCCESS; 155 | } 156 | #endif 157 | 158 | #ifdef AVX512_IFMA_SUPPORT 159 | static inline int 160 | test_radix2_hexl(const test_case_t *t, uint64_t a_orig[], uint64_t a_ntt[]) 161 | { 162 | // We can't test AVX512-IFMA for q>2^49 so we always success. 163 | if(t->q & AVX512_IFMA_MAX_MODULUS_MASK) { 164 | return SUCCESS; 165 | } 166 | 167 | uint64_t a[t->n]; 168 | memcpy(a, a_orig, sizeof(a)); 169 | 170 | printf("Running fwd_ntt_radix2_hexl\n"); 171 | fwd_ntt_radix2_hexl(a, t->n, t->q, t->w_powers_hexl.ptr, 172 | t->w_powers_con_hexl.ptr); 173 | GUARD_MSG(memcmp(a_ntt, a, sizeof(a)), 174 | "Bad results after HEXL radix-2 with AVX512-IFMA intrinsic fwd\n"); 175 | 176 | return SUCCESS; 177 | } 178 | 179 | static inline void fix_a_order(uint64_t *a, uint64_t n) 180 | { 181 | const __m512i idx = _mm512_setr_epi64(0, 4, 8, 12, 16, 20, 24, 28); 182 | 183 | for(size_t i = 0; i < n; i += (4 * 8)) { 184 | __m512i X = LOAD(&a[i + 0]); 185 | __m512i Y = LOAD(&a[i + 8]); 186 | __m512i Z = LOAD(&a[i + 16]); 187 | __m512i T = LOAD(&a[i + 24]); 188 | 189 | SCATTER(&a[i + 0], idx, X, 8); 190 | SCATTER(&a[i + 1], idx, Y, 8); 191 | SCATTER(&a[i + 2], idx, Z, 8); 192 | SCATTER(&a[i + 3], idx, T, 8); 193 | 194 | X = LOAD(&a[i + 0]); 195 | Y = LOAD(&a[i + 8]); 196 | Z = LOAD(&a[i + 16]); 197 | T = LOAD(&a[i + 24]); 198 | 199 | const __m512i X1 = SHUF(X, Y, 0x44); 200 | const __m512i Z1 = SHUF(X, Y, 0xee); 201 | const __m512i Y1 = SHUF(Z, T, 0x44); 202 | const __m512i T1 = SHUF(Z, T, 0xee); 203 | 204 | STORE(&a[i + 0], X1); 205 | STORE(&a[i + 8], Y1); 206 | STORE(&a[i + 16], Z1); 207 | STORE(&a[i + 24], T1); 208 | } 209 | } 210 | 211 | static inline int 212 | test_radix4_avx512_ifma(const test_case_t *t, uint64_t a_orig[], uint64_t a_ntt[]) 213 | { 214 | // We can't test AVX512-IFMA for q>2^49 so we always success. 215 | if(t->q & AVX512_IFMA_MAX_MODULUS_MASK) { 216 | return SUCCESS; 217 | } 218 | 219 | uint64_t a[t->n]; 220 | memcpy(a, a_orig, sizeof(a)); 221 | 222 | printf("Running fwd_ntt_radix4_avx512_ifma\n"); 223 | fwd_ntt_radix4_avx512_ifma(a, t->n, t->q, t->w_powers_r4_avx512_ifma.ptr, 224 | t->w_powers_con_r4_avx512_ifma.ptr); 225 | GUARD_MSG(memcmp(a_ntt, a, sizeof(a)), 226 | "Bad results after radix-4 with AVX512-IFMA intrinsic fwd\n"); 227 | 228 | memcpy(a, a_orig, sizeof(a)); 229 | printf("Running fwd_ntt_radix4_avx512_ifma_unordered\n"); 230 | fwd_ntt_radix4_avx512_ifma_unordered( 231 | a, t->n, t->q, t->w_powers_r4_avx512_ifma_unordered.ptr, 232 | t->w_powers_con_r4_avx512_ifma_unordered.ptr); 233 | fix_a_order(a, t->n); 234 | GUARD_MSG( 235 | memcmp(a_ntt, a, sizeof(a)), 236 | "Bad results after radix-4 with AVX512-IFMA intrinsic unordered fwd\n"); 237 | 238 | memcpy(a, a_orig, sizeof(a)); 239 | printf("Running fwd_ntt_r4r2_avx512_ifma\n"); 240 | fwd_ntt_r4r2_avx512_ifma(a, t->n, t->q, t->w_powers_r4r2_avx512_ifma.ptr, 241 | t->w_powers_con_r4r2_avx512_ifma.ptr); 242 | GUARD_MSG(memcmp(a_ntt, a, sizeof(a)), 243 | "Bad results after r4r2 with AVX512-IFMA intrinsic fwd\n"); 244 | 245 | memcpy(a, a_orig, sizeof(a)); 246 | printf("Running fwd_ntt_r2_16_avx512_ifma\n"); 247 | fwd_ntt_r2_16_avx512_ifma(a, t->n, t->q, t->w_powers_r2_16_avx512_ifma.ptr, 248 | t->w_powers_con_r2_16_avx512_ifma.ptr); 249 | GUARD_MSG(memcmp(a_ntt, a, sizeof(a)), 250 | "Bad results after r2_16 with AVX512-IFMA intrinsic fwd\n"); 251 | 252 | return SUCCESS; 253 | } 254 | #endif 255 | 256 | int test_correctness(const test_case_t *t) 257 | { 258 | // Prepare input 259 | uint64_t a[t->n]; 260 | uint64_t b[t->n]; 261 | uint64_t a_ntt[t->n]; 262 | uint64_t a_cpy[t->n]; 263 | random_buf(a, t->n, t->q); 264 | memcpy(a_cpy, a, sizeof(a)); 265 | memcpy(b, a, sizeof(a)); 266 | 267 | // Prepare a_ntt = NTT(a) 268 | fwd_ntt_ref_harvey(a_cpy, t->n, t->q, t->w_powers.ptr, t->w_powers_con.ptr); 269 | memcpy(a_ntt, a_cpy, sizeof(a_cpy)); 270 | 271 | GUARD(test_radix2_scalar(t, a)); 272 | GUARD(test_radix2_scalar_dbl(t, a, b, a_ntt)); 273 | GUARD(test_radix2_scalar_seal(t, a, a_ntt)) 274 | GUARD(test_radix4_scalar(t, a, a_ntt)) 275 | GUARD(test_radix4x4_scalar(t, a, a_ntt)) 276 | #ifdef S390X 277 | GUARD(test_radix4_intrinsic(t, a, a_ntt)) 278 | GUARD(test_radix4_intrinsic_dbl(t, a, b, a_ntt)) 279 | #elif AVX512_IFMA_SUPPORT 280 | GUARD(test_radix2_hexl(t, a, a_ntt)) 281 | GUARD(test_radix4_avx512_ifma(t, a, a_ntt)) 282 | #endif 283 | 284 | return SUCCESS; 285 | } 286 | -------------------------------------------------------------------------------- /tests/tests.h: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #pragma once 5 | 6 | EXTERNC_BEGIN 7 | 8 | #include "test_cases.h" 9 | 10 | typedef enum 11 | { 12 | FIRST_FWD = 0, 13 | FWD_REF = FIRST_FWD, 14 | FWD_SEAL, 15 | FWD_R4, 16 | FWD_R4x4, 17 | FWD_R4_VMSL, 18 | FWD_R4_HEXL, 19 | FWD_R4_AVX512_IFMA, 20 | FWD_R4_AVX512_IFMA_UNORDERED, 21 | FWD_R4R2_AVX512_IFMA, 22 | FWD_R2_R16_AVX512_IFMA, 23 | MAX_FWD = FWD_R2_R16_AVX512_IFMA 24 | } func_num_t; 25 | 26 | #ifdef TEST_SPEED 27 | 28 | void report_test_fwd_perf_headers(void); 29 | void report_test_inv_perf_headers(void); 30 | 31 | void test_aligned_fwd_perf(const test_case_t *t); 32 | void test_unaligned_fwd_perf(const test_case_t *t); 33 | void test_inv_perf(const test_case_t *t); 34 | 35 | void test_fwd_single_case(const test_case_t *t, func_num_t func_num); 36 | 37 | #else 38 | 39 | int test_correctness(const test_case_t *t); 40 | 41 | #endif 42 | 43 | EXTERNC_END 44 | -------------------------------------------------------------------------------- /tests/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright IBM Inc. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | EXTERNC_BEGIN 11 | 12 | static inline void random_buf(uint64_t *values, const size_t n, const uint64_t q) 13 | { 14 | for(size_t i = 0; i < n; i++) { 15 | values[i] = rand() % q; 16 | } 17 | } 18 | 19 | static inline void print(UNUSED const uint64_t *values, UNUSED const size_t n) 20 | { 21 | #ifdef DEBUG 22 | for(size_t i = 0; i < n; i++) { 23 | printf("%lx ", values[i]); 24 | } 25 | printf("\n"); 26 | #endif 27 | } 28 | 29 | EXTERNC_END 30 | -------------------------------------------------------------------------------- /third_party/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright IBM Inc. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | set(THIRD_PARTY_SOURCES 5 | ${THIRD_PARTY_DIR}/seal/ntt_seal.c 6 | ) 7 | 8 | if(X86_64 AND AVX512_IFMA) 9 | set(NTT_SOURCES ${NTT_SOURCES} 10 | ${SRC_DIR}/ntt_radix4_avx512_ifma.c 11 | ) 12 | 13 | set(THIRD_PARTY_SOURCES ${THIRD_PARTY_SOURCES} 14 | ${THIRD_PARTY_DIR}/hexl/fwd-ntt-avx512.c 15 | ) 16 | include_directories(${THIRD_PARTY_DIR}/hexl/) 17 | endif() 18 | 19 | set(TEMP ${CMAKE_C_CLANG_TIDY}) 20 | set(CMAKE_C_CLANG_TIDY "") 21 | 22 | add_library( 23 | third_party OBJECT 24 | 25 | ${THIRD_PARTY_SOURCES} 26 | ) 27 | 28 | set(CMAKE_C_CLANG_TIDY ${TEMP}) 29 | -------------------------------------------------------------------------------- /third_party/README.md: -------------------------------------------------------------------------------- 1 | Third party code 2 | ---------------- 3 | The code in this directory was extracted and modified from 4 | 1) Microsoft SEAL GitHub file (dwthandler.h)[https://github.com/microsoft/SEAL/blob/d045f1beff96dff0fccc7fa0c5acb1493a65338c/native/src/seal/util/dwthandler.h] commit d045f1b on 15 Jun 2021 that has an (MIT license)[https://github.com/microsoft/SEAL/blob/main/LICENSE] 5 | 2) Intel HEXL GitHub (fwd-ntt-avx512.cpp)[https://github.com/intel/hexl/blob/db9535c140227010c5c9d6f34a11054b16f02de7/hexl/ntt/fwd-ntt-avx512.cpp] commit 4d9806f 01 Sep 2021 that has an (Apache 2.0 license)[https://github.com/intel/hexl/blob/main/LICENSE] 6 | 7 | The code was converted from C++ to native C by 8 | - Removing templates and converting all relevant functions to `static inline` functions. 9 | - Converting `reinterpret_cast` and `static_cast` to C-style casting. 10 | 11 | Specifically for HEXL we 12 | - Fixed the `BitShift` parameter to 52 and removed code paths (and branches) to its other values. 13 | - Set the `InputLessThanMod` parameter as an input parameter to the relevant functions. 14 | - Converted the `HEXL_LOOP_UNROLL_N` macros to `LOOP_UNROLL_N` macros. 15 | - Defined the `HEXL_CHECK` and ``HEXL_VLOG` macros as empty macros. 16 | -------------------------------------------------------------------------------- /third_party/hexl/avx512-util.h: -------------------------------------------------------------------------------- 1 | /* 2 | * The code in this file was extracted and modified from 3 | b) Intel HEXL GitHub (fwd-ntt-avx512.cpp)[https://github.com/intel/hexl/blob/db9535c140227010c5c9d6f34a11054b16f02de7/hexl/ntt/fwd-ntt-avx512.cpp] commit 4d9806f 01 Sep 2021 that has an (Apache 2.0 license)[https://github.com/intel/hexl/blob/main/LICENSE] 4 | 5 | The code was converted from C++ to native C by 6 | - Removing templates and converting all relevant functions to `static inline` functions. 7 | - Converting `reinterpret_cast` and `static_cast` to C-style casting. 8 | - Fixed the `BitShift` parameter to 52 and removed code paths (and branches) to its other values. 9 | - Set the `InputLessThanMod` parameter as an input parameter to the relevant functions. 10 | - Converted the `HEXL_LOOP_UNROLL_N` macros to `LOOP_UNROLL_N` macros. 11 | - Defined the `HEXL_CHECK` and ``HEXL_VLOG` macros as empty macros. 12 | */ 13 | 14 | #pragma once 15 | 16 | #include 17 | 18 | #define HEXL_CHECK(...) {} 19 | #define HEXL_CHECK_BOUNDS(...) {} 20 | #define HEXL_VLOG(...) {} 21 | 22 | static inline __m512i _mm512_hexl_mullo_epi_64(__m512i x, __m512i y) { 23 | return _mm512_mullo_epi64(x, y); 24 | } 25 | 26 | static inline __m512i _mm512_hexl_mullo_epi_52(__m512i x, __m512i y) { 27 | __m512i zero = _mm512_set1_epi64(0); 28 | return _mm512_madd52lo_epu64(zero, x, y); 29 | } 30 | 31 | static inline __m512i _mm512_hexl_mullo_add_lo_epi_52(__m512i x, __m512i y, 32 | __m512i z) { 33 | __m512i result = _mm512_madd52lo_epu64(x, y, z); 34 | 35 | // Clear high 12 bits from result 36 | const __m512i two_pow52_min1 = _mm512_set1_epi64((1ULL << 52) - 1); 37 | result = _mm512_and_epi64(result, two_pow52_min1); 38 | return result; 39 | } 40 | 41 | static inline __m512i _mm512_hexl_mullo_add_lo_epi_64(__m512i x, __m512i y, 42 | __m512i z) { 43 | __m512i prod = _mm512_mullo_epi64(y, z); 44 | return _mm512_add_epi64(x, prod); 45 | } 46 | 47 | // Returns x mod q across each 64-bit integer SIMD lanes 48 | // Assumes x < InputModFactor * q in all lanes 49 | static const int InputModFactor = 2; 50 | static inline __m512i _mm512_hexl_small_mod_epu64(__m512i x, __m512i q) { 51 | __m512i* q_times_2 = NULL; 52 | __m512i* q_times_4 = NULL; 53 | HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 || 54 | InputModFactor == 4 || InputModFactor == 8, 55 | "InputModFactor must be 1, 2, 4, or 8"); 56 | if (InputModFactor == 1) { 57 | return x; 58 | } 59 | if (InputModFactor == 2) { 60 | return _mm512_min_epu64(x, _mm512_sub_epi64(x, q)); 61 | } 62 | if (InputModFactor == 4) { 63 | HEXL_CHECK(q_times_2 != nullptr, "q_times_2 must not be nullptr"); 64 | x = _mm512_min_epu64(x, _mm512_sub_epi64(x, *q_times_2)); 65 | return _mm512_min_epu64(x, _mm512_sub_epi64(x, q)); 66 | } 67 | if (InputModFactor == 8) { 68 | HEXL_CHECK(q_times_2 != nullptr, "q_times_2 must not be nullptr"); 69 | HEXL_CHECK(q_times_4 != nullptr, "q_times_4 must not be nullptr"); 70 | x = _mm512_min_epu64(x, _mm512_sub_epi64(x, *q_times_4)); 71 | x = _mm512_min_epu64(x, _mm512_sub_epi64(x, *q_times_2)); 72 | return _mm512_min_epu64(x, _mm512_sub_epi64(x, q)); 73 | } 74 | HEXL_CHECK(false, "Invalid InputModFactor"); 75 | return x; // Return dummy value 76 | } 77 | 78 | static inline __m512i _mm512_hexl_mulhi_epi_64(__m512i x, __m512i y) { 79 | // https://stackoverflow.com/questions/28807341/simd-signed-with-unsigned-multiplication-for-64-bit-64-bit-to-128-bit 80 | __m512i lo_mask = _mm512_set1_epi64(0x00000000ffffffff); 81 | // Shuffle high bits with low bits in each 64-bit integer => 82 | // x0_lo, x0_hi, x1_lo, x1_hi, x2_lo, x2_hi, ... 83 | __m512i x_hi = _mm512_shuffle_epi32(x, (_MM_PERM_ENUM)0xB1); 84 | // y0_lo, y0_hi, y1_lo, y1_hi, y2_lo, y2_hi, ... 85 | __m512i y_hi = _mm512_shuffle_epi32(y, (_MM_PERM_ENUM)0xB1); 86 | __m512i z_lo_lo = _mm512_mul_epu32(x, y); // x_lo * y_lo 87 | __m512i z_lo_hi = _mm512_mul_epu32(x, y_hi); // x_lo * y_hi 88 | __m512i z_hi_lo = _mm512_mul_epu32(x_hi, y); // x_hi * y_lo 89 | __m512i z_hi_hi = _mm512_mul_epu32(x_hi, y_hi); // x_hi * y_hi 90 | 91 | // x_hi | x_lo 92 | // x y_hi | y_lo 93 | // ------------------------------ 94 | // [x_lo * y_lo] // z_lo_lo 95 | // + [z_lo * y_hi] // z_lo_hi 96 | // + [x_hi * y_lo] // z_hi_lo 97 | // + [x_hi * y_hi] // z_hi_hi 98 | // ^-----------^ <-- only bits needed 99 | // sum_| hi | mid | lo | 100 | 101 | // Low bits of z_lo_lo are not needed 102 | __m512i z_lo_lo_shift = _mm512_srli_epi64(z_lo_lo, 32); 103 | 104 | // [x_lo * y_lo] // z_lo_lo 105 | // + [z_lo * y_hi] // z_lo_hi 106 | // ------------------------ 107 | // | sum_tmp | 108 | // |sum_mid|sum_lo| 109 | __m512i sum_tmp = _mm512_add_epi64(z_lo_hi, z_lo_lo_shift); 110 | __m512i sum_lo = _mm512_and_si512(sum_tmp, lo_mask); 111 | __m512i sum_mid = _mm512_srli_epi64(sum_tmp, 32); 112 | // | |sum_lo| 113 | // + [x_hi * y_lo] // z_hi_lo 114 | // ------------------ 115 | // [ sum_mid2 ] 116 | __m512i sum_mid2 = _mm512_add_epi64(z_hi_lo, sum_lo); 117 | __m512i sum_mid2_hi = _mm512_srli_epi64(sum_mid2, 32); 118 | __m512i sum_hi = _mm512_add_epi64(z_hi_hi, sum_mid); 119 | return _mm512_add_epi64(sum_hi, sum_mid2_hi); 120 | } 121 | 122 | static inline __m512i _mm512_hexl_mulhi_approx_epi_64(__m512i x, __m512i y) { 123 | // https://stackoverflow.com/questions/28807341/simd-signed-with-unsigned-multiplication-for-64-bit-64-bit-to-128-bit 124 | __m512i lo_mask = _mm512_set1_epi64(0x00000000ffffffff); 125 | // Shuffle high bits with low bits in each 64-bit integer => 126 | // x0_lo, x0_hi, x1_lo, x1_hi, x2_lo, x2_hi, ... 127 | __m512i x_hi = _mm512_shuffle_epi32(x, (_MM_PERM_ENUM)0xB1); 128 | // y0_lo, y0_hi, y1_lo, y1_hi, y2_lo, y2_hi, ... 129 | __m512i y_hi = _mm512_shuffle_epi32(y, (_MM_PERM_ENUM)0xB1); 130 | __m512i z_lo_hi = _mm512_mul_epu32(x, y_hi); // x_lo * y_hi 131 | __m512i z_hi_lo = _mm512_mul_epu32(x_hi, y); // x_hi * y_lo 132 | __m512i z_hi_hi = _mm512_mul_epu32(x_hi, y_hi); // x_hi * y_hi 133 | 134 | // x_hi | x_lo 135 | // x y_hi | y_lo 136 | // ------------------------------ 137 | // [x_lo * y_lo] // unused, resulting in approximation 138 | // + [z_lo * y_hi] // z_lo_hi 139 | // + [x_hi * y_lo] // z_hi_lo 140 | // + [x_hi * y_hi] // z_hi_hi 141 | // ^-----------^ <-- only bits needed 142 | // sum_| hi | mid | lo | 143 | 144 | __m512i sum_lo = _mm512_and_si512(z_lo_hi, lo_mask); 145 | __m512i sum_mid = _mm512_srli_epi64(z_lo_hi, 32); 146 | // | |sum_lo| 147 | // + [x_hi * y_lo] // z_hi_lo 148 | // ------------------ 149 | // [ sum_mid2 ] 150 | __m512i sum_mid2 = _mm512_add_epi64(z_hi_lo, sum_lo); 151 | __m512i sum_mid2_hi = _mm512_srli_epi64(sum_mid2, 32); 152 | __m512i sum_hi = _mm512_add_epi64(z_hi_hi, sum_mid); 153 | return _mm512_add_epi64(sum_hi, sum_mid2_hi); 154 | } 155 | 156 | static inline __m512i _mm512_hexl_mulhi_approx_epi_52(__m512i x, __m512i y) { 157 | __m512i zero = _mm512_set1_epi64(0); 158 | return _mm512_madd52hi_epu64(zero, x, y); 159 | } 160 | 161 | static inline __m512i _mm512_hexl_mulhi_epi_52(__m512i x, __m512i y) { 162 | __m512i zero = _mm512_set1_epi64(0); 163 | return _mm512_madd52hi_epu64(zero, x, y); 164 | } 165 | -------------------------------------------------------------------------------- /third_party/hexl/fwd-ntt-avx512.c: -------------------------------------------------------------------------------- 1 | /* 2 | * The code in this file was extracted and modified from 3 | b) Intel HEXL GitHub (fwd-ntt-avx512.cpp)[https://github.com/intel/hexl/blob/db9535c140227010c5c9d6f34a11054b16f02de7/hexl/ntt/fwd-ntt-avx512.cpp] commit 4d9806f 01 Sep 2021 that has an (Apache 2.0 license)[https://github.com/intel/hexl/blob/main/LICENSE] 4 | 5 | The code was converted from C++ to native C by 6 | - Removing templates and converting all relevant functions to `static inline` functions. 7 | - Converting `reinterpret_cast` and `static_cast` to C-style casting. 8 | - Fixed the `BitShift` parameter to 52 and removed code paths (and branches) to its other values. 9 | - Set the `InputLessThanMod` parameter as an input parameter to the relevant functions. 10 | - Converted the `HEXL_LOOP_UNROLL_N` macros to `LOOP_UNROLL_N` macros. 11 | - Defined the `HEXL_CHECK` and ``HEXL_VLOG` macros as empty macros. 12 | */ 13 | // Copyright (C) 2020-2021 Intel Corporation 14 | // SPDX-License-Identifier: Apache-2.0 15 | 16 | #include 17 | #include 18 | #include 19 | 20 | #include "ntt_hexl.h" 21 | #include "ntt-avx512-util.h" 22 | 23 | UNUSED static const int BitShift = 52; 24 | 25 | /// @brief The Harvey butterfly: assume \p X, \p Y in [0, 4q), and return X', Y' 26 | /// in [0, 4q) such that X', Y' = X + WY, X - WY (mod q). 27 | /// @param[in,out] X Input representing 8 64-bit signed integers in SIMD form 28 | /// @param[in,out] Y Input representing 8 64-bit signed integers in SIMD form 29 | /// @param[in] W Root of unity represented as 8 64-bit signed integers in 30 | /// SIMD form 31 | /// @param[in] W_precon Preconditioned \p W for BitShift-bit Barrett 32 | /// reduction 33 | /// @param[in] neg_modulus Negative modulus, i.e. (-q) represented as 8 64-bit 34 | /// signed integers in SIMD form 35 | /// @param[in] twice_modulus Twice the modulus, i.e. 2*q represented as 8 64-bit 36 | /// signed integers in SIMD form 37 | /// @param InputLessThanMod If true, assumes \p X, \p Y < \p q. Otherwise, 38 | /// assumes \p X, \p Y < 4*\p q 39 | /// @details See Algorithm 4 of https://arxiv.org/pdf/1205.2926.pdf 40 | void FwdButterfly(__m512i* X, __m512i* Y, __m512i W, __m512i W_precon, 41 | __m512i neg_modulus, __m512i twice_modulus, const int InputLessThanMod) { 42 | if (!InputLessThanMod) { 43 | *X = _mm512_hexl_small_mod_epu64(*X, twice_modulus); 44 | } 45 | 46 | __m512i T; 47 | __m512i Q = _mm512_hexl_mulhi_epi_52(W_precon, *Y); 48 | __m512i W_Y = _mm512_hexl_mullo_epi_52(W, *Y); 49 | T = _mm512_hexl_mullo_add_lo_epi_52(W_Y, Q, neg_modulus); 50 | 51 | __m512i twice_mod_minus_T = _mm512_sub_epi64(twice_modulus, T); 52 | *Y = _mm512_add_epi64(*X, twice_mod_minus_T); 53 | *X = _mm512_add_epi64(*X, T); 54 | } 55 | 56 | void FwdT1(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, 57 | uint64_t m, const uint64_t* W, const uint64_t* W_precon) { 58 | const __m512i* v_W_pt = (const __m512i*)(W); 59 | const __m512i* v_W_precon_pt = (const __m512i*)(W_precon); 60 | size_t j1 = 0; 61 | 62 | // 8 | m guaranteed by n >= 16 63 | LOOP_UNROLL_8 64 | for (size_t i = m / 8; i > 0; --i) { 65 | uint64_t* X = operand + j1; 66 | __m512i* v_X_pt = (__m512i*)(X); 67 | 68 | __m512i v_X; 69 | __m512i v_Y; 70 | LoadFwdInterleavedT1(X, &v_X, &v_Y); 71 | __m512i v_W = _mm512_loadu_si512(v_W_pt++); 72 | __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++); 73 | 74 | FwdButterfly(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus, v_twice_mod, false); 75 | WriteFwdInterleavedT1(v_X, v_Y, v_X_pt); 76 | 77 | j1 += 16; 78 | } 79 | } 80 | 81 | void FwdT2(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, 82 | uint64_t m, const uint64_t* W, const uint64_t* W_precon) { 83 | const __m512i* v_W_pt = (const __m512i*)(W); 84 | const __m512i* v_W_precon_pt = (const __m512i*)(W_precon); 85 | 86 | size_t j1 = 0; 87 | // 4 | m guaranteed by n >= 16 88 | LOOP_UNROLL_4 89 | for (size_t i = m / 4; i > 0; --i) { 90 | uint64_t* X = operand + j1; 91 | __m512i* v_X_pt = (__m512i*)(X); 92 | 93 | __m512i v_X; 94 | __m512i v_Y; 95 | LoadFwdInterleavedT2(X, &v_X, &v_Y); 96 | 97 | __m512i v_W = _mm512_loadu_si512(v_W_pt++); 98 | __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++); 99 | 100 | HEXL_CHECK(ExtractValues(v_W)[0] == ExtractValues(v_W)[1], 101 | "bad v_W " << ExtractValues(v_W)); 102 | HEXL_CHECK(ExtractValues(v_W_precon)[0] == ExtractValues(v_W_precon)[1], 103 | "bad v_W_precon " << ExtractValues(v_W_precon)); 104 | FwdButterfly(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus, v_twice_mod, false); 105 | 106 | _mm512_storeu_si512(v_X_pt++, v_X); 107 | _mm512_storeu_si512(v_X_pt, v_Y); 108 | 109 | j1 += 16; 110 | } 111 | } 112 | 113 | void FwdT4(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, 114 | uint64_t m, const uint64_t* W, const uint64_t* W_precon) { 115 | size_t j1 = 0; 116 | const __m512i* v_W_pt = (const __m512i*)(W); 117 | const __m512i* v_W_precon_pt = (const __m512i*)(W_precon); 118 | 119 | // 2 | m guaranteed by n >= 16 120 | LOOP_UNROLL_4 121 | for (size_t i = m / 2; i > 0; --i) { 122 | uint64_t* X = operand + j1; 123 | __m512i* v_X_pt = (__m512i*)(X); 124 | 125 | __m512i v_X; 126 | __m512i v_Y; 127 | LoadFwdInterleavedT4(X, &v_X, &v_Y); 128 | 129 | __m512i v_W = _mm512_loadu_si512(v_W_pt++); 130 | __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++); 131 | FwdButterfly(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus, v_twice_mod, false); 132 | 133 | _mm512_storeu_si512(v_X_pt++, v_X); 134 | _mm512_storeu_si512(v_X_pt, v_Y); 135 | 136 | j1 += 16; 137 | } 138 | } 139 | 140 | void FwdT8(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, 141 | uint64_t t, uint64_t m, const uint64_t* W, 142 | const uint64_t* W_precon, const int InputLessThanMod) { 143 | size_t j1 = 0; 144 | 145 | LOOP_UNROLL_4 146 | for (size_t i = 0; i < m; i++) { 147 | uint64_t* X = operand + j1; 148 | uint64_t* Y = X + t; 149 | 150 | __m512i v_W = _mm512_set1_epi64((int64_t)(*W++)); 151 | __m512i v_W_precon = _mm512_set1_epi64((int64_t)(*W_precon++)); 152 | 153 | __m512i* v_X_pt = (__m512i*)(X); 154 | __m512i* v_Y_pt = (__m512i*)(Y); 155 | 156 | // assume 8 | t 157 | for (size_t j = t / 8; j > 0; --j) { 158 | __m512i v_X = _mm512_loadu_si512(v_X_pt); 159 | __m512i v_Y = _mm512_loadu_si512(v_Y_pt); 160 | 161 | FwdButterfly(&v_X, &v_Y, v_W, v_W_precon,v_neg_modulus, v_twice_mod, InputLessThanMod); 162 | 163 | _mm512_storeu_si512(v_X_pt++, v_X); 164 | _mm512_storeu_si512(v_Y_pt++, v_Y); 165 | } 166 | j1 += (t << 1); 167 | } 168 | } 169 | 170 | // Correction step needed due to extra copies of roots of unity in the 171 | // AVX512 vectors loaded for FwdT2 and FwdT4 172 | size_t compute_new_W_idx(size_t idx, uint64_t n, uint64_t recursion_depth) { 173 | // Originally, from root of unity vector index to loop: 174 | // [0, N/8) => FwdT8 175 | // [N/8, N/4) => FwdT4 176 | // [N/4, N/2) => FwdT2 177 | // [N/2, N) => FwdT1 178 | // The new mapping from AVX512 root of unity vector index to loop: 179 | // [0, N/8) => FwdT8 180 | // [N/8, 5N/8) => FwdT4 181 | // [5N/8, 9N/8) => FwdT2 182 | // [9N/8, 13N/8) => FwdT1 183 | size_t N = n << recursion_depth; 184 | 185 | // FwdT8 range 186 | if (idx <= N / 8) { 187 | return idx; 188 | } 189 | // FwdT4 range 190 | if (idx <= N / 4) { 191 | return (idx - N / 8) * 4 + (N / 8); 192 | } 193 | // FwdT2 range 194 | if (idx <= N / 2) { 195 | return (idx - N / 4) * 2 + (5 * N / 8); 196 | } 197 | // FwdT1 range 198 | return idx + (5 * N / 8); 199 | } 200 | 201 | void ForwardTransformToBitReverseAVX512( 202 | uint64_t* operand, uint64_t n, uint64_t modulus, 203 | const uint64_t* root_of_unity_powers, 204 | const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, 205 | uint64_t output_mod_factor, uint64_t recursion_depth, 206 | uint64_t recursion_half) { 207 | HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); 208 | HEXL_CHECK(modulus < MaximumValue(BitShift) / 4, 209 | "modulus " << modulus << " too large for BitShift " << BitShift 210 | << " => maximum value " << MaximumValue(BitShift) / 4); 211 | HEXL_CHECK_BOUNDS(precon_root_of_unity_powers, n, MaximumValue(BitShift), 212 | "precon_root_of_unity_powers too large"); 213 | HEXL_CHECK_BOUNDS(operand, n, MaximumValue(BitShift), "operand too large"); 214 | // Skip input bound checking for recursive steps 215 | HEXL_CHECK_BOUNDS(operand, (recursion_depth == 0) ? n : 0, 216 | input_mod_factor * modulus, 217 | "operand larger than input_mod_factor * modulus (" 218 | << input_mod_factor << " * " << modulus << ")"); 219 | HEXL_CHECK(n >= 16, 220 | "Don't support small transforms. Need n >= 16, got n = " << n); 221 | HEXL_CHECK( 222 | input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4, 223 | "input_mod_factor must be 1, 2, or 4; got " << input_mod_factor); 224 | HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 4, 225 | "output_mod_factor must be 1 or 4; got " << output_mod_factor); 226 | 227 | uint64_t twice_mod = modulus << 1; 228 | 229 | __m512i v_modulus = _mm512_set1_epi64((int64_t)(modulus)); 230 | __m512i v_neg_modulus = _mm512_set1_epi64(-(int64_t)(modulus)); 231 | __m512i v_twice_mod = _mm512_set1_epi64((int64_t)(twice_mod)); 232 | 233 | HEXL_VLOG(5, "root_of_unity_powers " << std::vector( 234 | root_of_unity_powers, root_of_unity_powers + n)) 235 | HEXL_VLOG(5, 236 | "precon_root_of_unity_powers " << std::vector( 237 | precon_root_of_unity_powers, precon_root_of_unity_powers + n)); 238 | HEXL_VLOG(5, "operand " << std::vector(operand, operand + n)); 239 | 240 | static const size_t base_ntt_size = 1024; 241 | 242 | if (n <= base_ntt_size) { // Perform breadth-first NTT 243 | size_t t = (n >> 1); 244 | size_t m = 1; 245 | size_t W_idx = (m << recursion_depth) + (recursion_half * m); 246 | // First iteration assumes input in [0,p) 247 | if (m < (n >> 3)) { 248 | const uint64_t* W = &root_of_unity_powers[W_idx]; 249 | const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx]; 250 | 251 | if ((input_mod_factor <= 2) && (recursion_depth == 0)) { 252 | FwdT8(operand, v_neg_modulus, v_twice_mod, t, m, W, W_precon, true); 253 | } else { 254 | FwdT8(operand, v_neg_modulus, v_twice_mod, t, m, W, W_precon, false); 255 | } 256 | 257 | t >>= 1; 258 | m <<= 1; 259 | W_idx <<= 1; 260 | } 261 | for (; m < (n >> 3); m <<= 1) { 262 | const uint64_t* W = &root_of_unity_powers[W_idx]; 263 | const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx]; 264 | FwdT8(operand, v_neg_modulus, v_twice_mod, t, m, W, W_precon, false); 265 | t >>= 1; 266 | W_idx <<= 1; 267 | } 268 | 269 | // Do T=4, T=2, T=1 separately 270 | { 271 | size_t new_W_idx = compute_new_W_idx(W_idx, n, recursion_depth); 272 | const uint64_t* W = &root_of_unity_powers[new_W_idx]; 273 | const uint64_t* W_precon = &precon_root_of_unity_powers[new_W_idx]; 274 | FwdT4(operand, v_neg_modulus, v_twice_mod, m, W, W_precon); 275 | 276 | m <<= 1; 277 | W_idx <<= 1; 278 | new_W_idx = compute_new_W_idx(W_idx, n, recursion_depth); 279 | W = &root_of_unity_powers[new_W_idx]; 280 | W_precon = &precon_root_of_unity_powers[new_W_idx]; 281 | FwdT2(operand, v_neg_modulus, v_twice_mod, m, W, W_precon); 282 | 283 | m <<= 1; 284 | W_idx <<= 1; 285 | new_W_idx = compute_new_W_idx(W_idx, n, recursion_depth); 286 | W = &root_of_unity_powers[new_W_idx]; 287 | W_precon = &precon_root_of_unity_powers[new_W_idx]; 288 | FwdT1(operand, v_neg_modulus, v_twice_mod, m, W, W_precon); 289 | } 290 | 291 | if (output_mod_factor == 1) { 292 | // n power of two at least 8 => n divisible by 8 293 | HEXL_CHECK(n % 8 == 0, "n " << n << " not a power of 2"); 294 | __m512i* v_X_pt = (__m512i*)(operand); 295 | for (size_t i = 0; i < n; i += 8) { 296 | __m512i v_X = _mm512_loadu_si512(v_X_pt); 297 | 298 | // Reduce from [0, 4q) to [0, q) 299 | v_X = _mm512_hexl_small_mod_epu64(v_X, v_twice_mod); 300 | v_X = _mm512_hexl_small_mod_epu64(v_X, v_modulus); 301 | 302 | HEXL_CHECK_BOUNDS(ExtractValues(v_X).data(), 8, modulus, 303 | "v_X exceeds bound " << modulus); 304 | 305 | _mm512_storeu_si512(v_X_pt, v_X); 306 | 307 | ++v_X_pt; 308 | } 309 | } 310 | } else { 311 | // Perform depth-first NTT via recursive call 312 | size_t t = (n >> 1); 313 | size_t W_idx = (1ULL << recursion_depth) + recursion_half; 314 | const uint64_t* W = &root_of_unity_powers[W_idx]; 315 | const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx]; 316 | 317 | FwdT8(operand, v_neg_modulus, v_twice_mod, t, 1, W, W_precon, false); 318 | 319 | ForwardTransformToBitReverseAVX512( 320 | operand, n / 2, modulus, root_of_unity_powers, 321 | precon_root_of_unity_powers, input_mod_factor, output_mod_factor, 322 | recursion_depth + 1, recursion_half * 2); 323 | 324 | ForwardTransformToBitReverseAVX512( 325 | &operand[n / 2], n / 2, modulus, root_of_unity_powers, 326 | precon_root_of_unity_powers, input_mod_factor, output_mod_factor, 327 | recursion_depth + 1, recursion_half * 2 + 1); 328 | } 329 | } 330 | 331 | -------------------------------------------------------------------------------- /third_party/hexl/ntt-avx512-util.h: -------------------------------------------------------------------------------- 1 | /* 2 | * The code in this file was extracted and modified from 3 | b) Intel HEXL GitHub (fwd-ntt-avx512.cpp)[https://github.com/intel/hexl/blob/db9535c140227010c5c9d6f34a11054b16f02de7/hexl/ntt/fwd-ntt-avx512.cpp] commit 4d9806f 01 Sep 2021 that has an (Apache 2.0 license)[https://github.com/intel/hexl/blob/main/LICENSE] 4 | 5 | The code was converted from C++ to native C by 6 | - Removing templates and converting all relevant functions to `static inline` functions. 7 | - Converting `reinterpret_cast` and `static_cast` to C-style casting. 8 | - Fixed the `BitShift` parameter to 52 and removed code paths (and branches) to its other values. 9 | - Set the `InputLessThanMod` parameter as an input parameter to the relevant functions. 10 | - Converted the `HEXL_LOOP_UNROLL_N` macros to `LOOP_UNROLL_N` macros. 11 | - Defined the `HEXL_CHECK` and ``HEXL_VLOG` macros as empty macros. 12 | */ 13 | // Copyright (C) 2020-2021 Intel Corporation 14 | // SPDX-License-Identifier: Apache-2.0 15 | 16 | #pragma once 17 | 18 | #include "avx512-util.h" 19 | 20 | // Given input: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 21 | // Returns 22 | // *out1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0); 23 | // *out2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1); 24 | static inline void LoadFwdInterleavedT1(const uint64_t* arg, __m512i* out1, 25 | __m512i* out2) { 26 | const __m512i* arg_512 = (const __m512i*)(arg); 27 | 28 | // 0, 1, 2, 3, 4, 5, 6, 7 29 | __m512i v1 = _mm512_loadu_si512(arg_512++); 30 | // 8, 9, 10, 11, 12, 13, 14, 15 31 | __m512i v2 = _mm512_loadu_si512(arg_512); 32 | 33 | const __m512i perm_idx = _mm512_set_epi64(6, 7, 4, 5, 2, 3, 0, 1); 34 | 35 | // 1, 0, 3, 2, 5, 4, 7, 6 36 | __m512i v1_perm = _mm512_permutexvar_epi64(perm_idx, v1); 37 | // 9, 8, 11, 10, 13, 12, 15, 14 38 | __m512i v2_perm = _mm512_permutexvar_epi64(perm_idx, v2); 39 | 40 | *out1 = _mm512_mask_blend_epi64(0xaa, v1, v2_perm); 41 | *out2 = _mm512_mask_blend_epi64(0xaa, v1_perm, v2); 42 | } 43 | 44 | // Given input: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 45 | // Returns 46 | // *out1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0); 47 | // *out2 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1); 48 | static inline void LoadInvInterleavedT1(const uint64_t* arg, __m512i* out1, 49 | __m512i* out2) { 50 | const __m512i vperm_hi_idx = _mm512_set_epi64(6, 4, 2, 0, 7, 5, 3, 1); 51 | const __m512i vperm_lo_idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); 52 | const __m512i vperm2_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); 53 | 54 | const __m512i* arg_512 = (const __m512i*)(arg); 55 | 56 | // 7, 6, 5, 4, 3, 2, 1, 0 57 | __m512i v_7to0 = _mm512_loadu_si512(arg_512++); 58 | // 15, 14, 13, 12, 11, 10, 9, 8 59 | __m512i v_15to8 = _mm512_loadu_si512(arg_512); 60 | // 7, 5, 3, 1, 6, 4, 2, 0 61 | __m512i perm_lo = _mm512_permutexvar_epi64(vperm_lo_idx, v_7to0); 62 | // 14, 12, 10, 8, 15, 13, 11, 9 63 | __m512i perm_hi = _mm512_permutexvar_epi64(vperm_hi_idx, v_15to8); 64 | 65 | *out1 = _mm512_mask_blend_epi64(0x0f, perm_hi, perm_lo); 66 | *out2 = _mm512_mask_blend_epi64(0xf0, perm_hi, perm_lo); 67 | *out2 = _mm512_permutexvar_epi64(vperm2_idx, *out2); 68 | } 69 | 70 | // Given input: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 71 | // Returns 72 | // *out1 = _mm512_set_epi64(13, 12, 9, 8, 5, 4, 1, 0); 73 | // *out2 = _mm512_set_epi64(15, 14, 11, 10, 7, 6, 3, 2) 74 | static inline void LoadFwdInterleavedT2(const uint64_t* arg, __m512i* out1, 75 | __m512i* out2) { 76 | const __m512i* arg_512 = (const __m512i*)(arg); 77 | 78 | // 11, 10, 9, 8, 3, 2, 1, 0 79 | __m512i v1 = _mm512_loadu_si512(arg_512++); 80 | // 15, 14, 13, 12, 7, 6, 5, 4 81 | __m512i v2 = _mm512_loadu_si512(arg_512); 82 | 83 | const __m512i v1_perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); 84 | 85 | __m512i v1_perm = _mm512_permutexvar_epi64(v1_perm_idx, v1); 86 | __m512i v2_perm = _mm512_permutexvar_epi64(v1_perm_idx, v2); 87 | 88 | *out1 = _mm512_mask_blend_epi64(0xcc, v1, v2_perm); 89 | *out2 = _mm512_mask_blend_epi64(0xcc, v1_perm, v2); 90 | } 91 | 92 | // Given input: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 93 | // Returns 94 | // *out1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0); 95 | // *out2 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1); 96 | static inline void LoadInvInterleavedT2(const uint64_t* arg, __m512i* out1, 97 | __m512i* out2) { 98 | const __m512i* arg_512 = (const __m512i*)(arg); 99 | 100 | __m512i v1 = _mm512_loadu_si512(arg_512++); 101 | __m512i v2 = _mm512_loadu_si512(arg_512); 102 | 103 | const __m512i v1_perm_idx = _mm512_set_epi64(6, 7, 4, 5, 2, 3, 0, 1); 104 | 105 | __m512i v1_perm = _mm512_permutexvar_epi64(v1_perm_idx, v1); 106 | __m512i v2_perm = _mm512_permutexvar_epi64(v1_perm_idx, v2); 107 | 108 | *out1 = _mm512_mask_blend_epi64(0xaa, v1, v2_perm); 109 | *out2 = _mm512_mask_blend_epi64(0xaa, v1_perm, v2); 110 | } 111 | 112 | // Returns 113 | // *out1 = _mm512_set_epi64(arg[11], arg[10], arg[9], arg[8], 114 | // arg[3], arg[2], arg[1], arg[0]); 115 | // *out2 = _mm512_set_epi64(arg[15], arg[14], arg[13], arg[12], 116 | // arg[7], arg[6], arg[5], arg[4]); 117 | static inline void LoadFwdInterleavedT4(const uint64_t* arg, __m512i* out1, 118 | __m512i* out2) { 119 | const __m512i* arg_512 = (const __m512i*)(arg); 120 | 121 | const __m512i vperm2_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); 122 | __m512i v_7to0 = _mm512_loadu_si512(arg_512++); 123 | __m512i v_15to8 = _mm512_loadu_si512(arg_512); 124 | __m512i perm_hi = _mm512_permutexvar_epi64(vperm2_idx, v_15to8); 125 | *out1 = _mm512_mask_blend_epi64(0x0f, perm_hi, v_7to0); 126 | *out2 = _mm512_mask_blend_epi64(0xf0, perm_hi, v_7to0); 127 | *out2 = _mm512_permutexvar_epi64(vperm2_idx, *out2); 128 | } 129 | 130 | static inline void LoadInvInterleavedT4(const uint64_t* arg, __m512i* out1, 131 | __m512i* out2) { 132 | const __m512i* arg_512 = (const __m512i*)(arg); 133 | 134 | // 0, 1, 2, 3, 4, 5, 6, 7 135 | __m512i v1 = _mm512_loadu_si512(arg_512++); 136 | // 8, 9, 10, 11, 12, 13, 14, 15 137 | __m512i v2 = _mm512_loadu_si512(arg_512); 138 | const __m512i perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); 139 | 140 | // 1, 0, 3, 2, 5, 4, 7, 6 141 | __m512i v1_perm = _mm512_permutexvar_epi64(perm_idx, v1); 142 | // 9, 8, 11, 10, 13, 12, 15, 14 143 | __m512i v2_perm = _mm512_permutexvar_epi64(perm_idx, v2); 144 | 145 | *out1 = _mm512_mask_blend_epi64(0xcc, v1, v2_perm); 146 | *out2 = _mm512_mask_blend_epi64(0xcc, v1_perm, v2); 147 | } 148 | 149 | // Given inputs 150 | // @param arg1 = _mm512_set_epi64(15, 14, 13, 12, 11, 10, 9, 8); 151 | // @param arg2 = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0); 152 | // Writes out = {8, 0, 9, 1, 10, 2, 11, 3, 153 | // 12, 4, 13, 5, 14, 6, 15, 7} 154 | static inline void WriteFwdInterleavedT1(__m512i arg1, __m512i arg2, __m512i* out) { 155 | const __m512i vperm2_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); 156 | const __m512i v_X_out_idx = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); 157 | const __m512i v_Y_out_idx = _mm512_set_epi64(3, 7, 2, 6, 1, 5, 0, 4); 158 | 159 | // v_Y => (4, 5, 6, 7, 0, 1, 2, 3) 160 | arg2 = _mm512_permutexvar_epi64(vperm2_idx, arg2); 161 | // 4, 5, 6, 7, 12, 13, 14, 15 162 | __m512i perm_lo = _mm512_mask_blend_epi64(0x0f, arg1, arg2); 163 | 164 | // 8, 9, 10, 11, 0, 1, 2, 3 165 | __m512i perm_hi = _mm512_mask_blend_epi64(0xf0, arg1, arg2); 166 | 167 | arg1 = _mm512_permutexvar_epi64(v_X_out_idx, perm_hi); 168 | arg2 = _mm512_permutexvar_epi64(v_Y_out_idx, perm_lo); 169 | 170 | _mm512_storeu_si512(out++, arg1); 171 | _mm512_storeu_si512(out, arg2); 172 | } 173 | 174 | // Given inputs 175 | // @param arg1 = _mm512_set_epi64(15, 14, 13, 12, 11, 10, 9, 8); 176 | // @param arg2 = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0); 177 | // Writes out = {8, 9, 10, 11, 0, 1, 2, 3, 178 | // 12, 13, 14, 15, 4, 5, 6, 7} 179 | static inline void WriteInvInterleavedT4(__m512i arg1, __m512i arg2, __m512i* out) { 180 | __m256i x0 = _mm512_extracti64x4_epi64(arg1, 0); 181 | __m256i x1 = _mm512_extracti64x4_epi64(arg1, 1); 182 | __m256i y0 = _mm512_extracti64x4_epi64(arg2, 0); 183 | __m256i y1 = _mm512_extracti64x4_epi64(arg2, 1); 184 | __m256i* out_256 = (__m256i*)(out); 185 | _mm256_storeu_si256(out_256++, x0); 186 | _mm256_storeu_si256(out_256++, y0); 187 | _mm256_storeu_si256(out_256++, x1); 188 | _mm256_storeu_si256(out_256++, y1); 189 | } 190 | 191 | // Returns _mm512_set_epi64(arg[3], arg[3], arg[2], arg[2], 192 | // arg[1], arg[1], arg[0], arg[0]); 193 | static inline __m512i LoadWOpT2(const void* arg) { 194 | const __m512i vperm_w_idx = _mm512_set_epi64(3, 3, 2, 2, 1, 1, 0, 0); 195 | 196 | __m256i v_W_256 = _mm256_loadu_si256((const __m256i*)(arg)); 197 | __m512i v_W = _mm512_broadcast_i64x4(v_W_256); 198 | v_W = _mm512_permutexvar_epi64(vperm_w_idx, v_W); 199 | 200 | return v_W; 201 | } 202 | 203 | // Returns _mm512_set_epi64(arg[1], arg[1], arg[1], arg[1], 204 | // arg[0], arg[0], arg[0], arg[0]); 205 | static inline __m512i LoadWOpT4(const void* arg) { 206 | const __m512i vperm_w_idx = _mm512_set_epi64(1, 1, 1, 1, 0, 0, 0, 0); 207 | 208 | __m128i v_W_128 = _mm_loadu_si128((const __m128i*)(arg)); 209 | __m512i v_W = _mm512_broadcast_i64x2(v_W_128); 210 | v_W = _mm512_permutexvar_epi64(vperm_w_idx, v_W); 211 | 212 | return v_W; 213 | } 214 | -------------------------------------------------------------------------------- /third_party/seal/ntt_seal.c: -------------------------------------------------------------------------------- 1 | // Copyright IBM.com, Inc. or its affiliates. All Rights Reserved. 2 | // SPDX-License-Identifier: Apache-2.0 3 | 4 | // The code belwo was taken and modified from 5 | // https://github.com/microsoft/SEAL/blob/d045f1beff96dff0fccc7fa0c5acb1493a65338c/native/src/seal/util/dwthandler.h 6 | // The license for most of this file is therefore 7 | // 8 | // Copyright (c) Microsoft Corporation. All rights reserved. 9 | // Licensed under the MIT license. 10 | 11 | #include "ntt_seal.h" 12 | #include "defs.h" 13 | 14 | // SEAL uses the following Arithmetic inline functions for the NTT implementation 15 | 16 | static inline uint64_t add(const uint64_t a, const uint64_t b) { return a + b; } 17 | 18 | static inline uint64_t 19 | sub(const uint64_t a, const uint64_t b, const uint64_t two_times_modulus_) 20 | { 21 | return a + two_times_modulus_ - b; 22 | } 23 | 24 | static inline uint64_t mul_root(const uint64_t a, 25 | const uint64_t q, 26 | const uint64_t w, 27 | const uint64_t w_con) 28 | { 29 | unsigned long long tmp1; 30 | tmp1 = (unsigned long long)(((__uint128_t)a * (__uint128_t)w_con) >> WORD_SIZE); 31 | return w * a - tmp1 * q; 32 | } 33 | 34 | static inline uint64_t mul_scalar(const uint64_t a, 35 | const uint64_t q, 36 | const uint64_t s, 37 | const uint64_t s_con) 38 | { 39 | return mul_root(a, q, s, s_con); 40 | } 41 | 42 | static inline uint64_t guard(const uint64_t a, const uint64_t two_times_modulus_) 43 | { 44 | return (a >= two_times_modulus_ ? a - two_times_modulus_ : a); 45 | } 46 | 47 | void fwd_ntt_seal_lazy(uint64_t a[], 48 | const uint64_t N, 49 | const uint64_t q, 50 | const uint64_t w[], 51 | const uint64_t w_con[]) 52 | { 53 | // constant transform size 54 | // Original line: size_t n = size_t(1) << log_n; 55 | size_t n = N; 56 | // registers to hold temporary values 57 | uint64_t u; 58 | uint64_t v; 59 | // pointers for faster indexing 60 | uint64_t *x = NULL; 61 | uint64_t *y = NULL; 62 | // variables for indexing 63 | size_t gap = n >> 1; 64 | size_t m = 1; 65 | uint64_t two_times_modulus_ = q << 1; 66 | 67 | for(; m < (n >> 1); m <<= 1) { 68 | size_t offset = 0; 69 | if(gap < 4) { 70 | for(size_t i = 0; i < m; i++) { 71 | ++w; 72 | ++w_con; 73 | x = a + offset; 74 | y = x + gap; 75 | for(size_t j = 0; j < gap; j++) { 76 | u = guard(*x, two_times_modulus_); 77 | v = mul_root(*y, q, *w, *w_con); 78 | *x++ = add(u, v); 79 | *y++ = sub(u, v, two_times_modulus_); 80 | } 81 | offset += gap << 1; 82 | } 83 | } else { 84 | for(size_t i = 0; i < m; i++) { 85 | ++w; 86 | ++w_con; 87 | x = a + offset; 88 | y = x + gap; 89 | for(size_t j = 0; j < gap; j += 4) { 90 | u = guard(*x, two_times_modulus_); 91 | v = mul_root(*y, q, *w, *w_con); 92 | *x++ = add(u, v); 93 | *y++ = sub(u, v, two_times_modulus_); 94 | 95 | u = guard(*x, two_times_modulus_); 96 | v = mul_root(*y, q, *w, *w_con); 97 | *x++ = add(u, v); 98 | *y++ = sub(u, v, two_times_modulus_); 99 | 100 | u = guard(*x, two_times_modulus_); 101 | v = mul_root(*y, q, *w, *w_con); 102 | *x++ = add(u, v); 103 | *y++ = sub(u, v, two_times_modulus_); 104 | 105 | u = guard(*x, two_times_modulus_); 106 | v = mul_root(*y, q, *w, *w_con); 107 | *x++ = add(u, v); 108 | *y++ = sub(u, v, two_times_modulus_); 109 | } 110 | offset += gap << 1; 111 | } 112 | } 113 | gap >>= 1; 114 | } 115 | 116 | for(size_t i = 0; i < m; i++) { 117 | ++w; 118 | ++w_con; 119 | u = guard(a[0], two_times_modulus_); 120 | v = mul_root(a[1], q, *w, *w_con); 121 | a[0] = add(u, v); 122 | a[1] = sub(u, v, two_times_modulus_); 123 | a += 2; 124 | } 125 | } 126 | 127 | void inv_ntt_seal(uint64_t a[], 128 | const uint64_t N, 129 | const uint64_t q, 130 | const uint64_t n_inv, 131 | const uint64_t n_inv_con, 132 | const uint64_t w[], 133 | const uint64_t w_con[]) 134 | { 135 | // constant transform size 136 | // Original line: size_t n = size_t(1) << log_n; 137 | size_t n = N; 138 | // registers to hold temporary values 139 | uint64_t u; 140 | uint64_t v; 141 | // pointers for faster indexing 142 | uint64_t *x = NULL; 143 | uint64_t *y = NULL; 144 | // variables for indexing 145 | size_t gap = 1; 146 | size_t m = n >> 1; 147 | uint64_t two_times_modulus_ = q << 1; 148 | 149 | for(; m > 1; m >>= 1) { 150 | size_t offset = 0; 151 | if(gap < 4) { 152 | for(size_t i = 0; i < m; i++) { 153 | x = a + offset; 154 | y = x + gap; 155 | for(size_t j = 0; j < gap; j++) { 156 | u = *x; 157 | v = *y; 158 | *x++ = guard(add(u, v), two_times_modulus_); 159 | *y++ = 160 | mul_root(sub(u, v, two_times_modulus_), q, w[m + i], w_con[m + i]); 161 | } 162 | offset += gap << 1; 163 | } 164 | } else { 165 | for(size_t i = 0; i < m; i++) { 166 | x = a + offset; 167 | y = x + gap; 168 | for(size_t j = 0; j < gap; j += 4) { 169 | u = *x; 170 | v = *y; 171 | *x++ = guard(add(u, v), two_times_modulus_); 172 | *y++ = 173 | mul_root(sub(u, v, two_times_modulus_), q, w[m + i], w_con[m + i]); 174 | 175 | u = *x; 176 | v = *y; 177 | *x++ = guard(add(u, v), two_times_modulus_); 178 | *y++ = 179 | mul_root(sub(u, v, two_times_modulus_), q, w[m + i], w_con[m + i]); 180 | 181 | u = *x; 182 | v = *y; 183 | *x++ = guard(add(u, v), two_times_modulus_); 184 | *y++ = 185 | mul_root(sub(u, v, two_times_modulus_), q, w[m + i], w_con[m + i]); 186 | 187 | u = *x; 188 | v = *y; 189 | *x++ = guard(add(u, v), two_times_modulus_); 190 | *y++ = 191 | mul_root(sub(u, v, two_times_modulus_), q, w[m + i], w_con[m + i]); 192 | } 193 | offset += gap << 1; 194 | } 195 | } 196 | gap <<= 1; 197 | } 198 | 199 | // Adaption to meet the current code package style 200 | uint64_t scaled_r = mul_root(w[1], q, n_inv, n_inv_con); 201 | uint64_t scaled_r_con = ((__uint128_t)scaled_r << WORD_SIZE) / q; 202 | 203 | x = a; 204 | y = x + gap; 205 | if(gap < 4) { 206 | for(size_t j = 0; j < gap; j++) { 207 | u = guard(*x, two_times_modulus_); 208 | v = *y; 209 | *x++ = 210 | mul_scalar(guard(add(u, v), two_times_modulus_), q, n_inv, n_inv_con); 211 | *y++ = mul_root(sub(u, v, two_times_modulus_), q, scaled_r, scaled_r_con); 212 | } 213 | } else { 214 | for(size_t j = 0; j < gap; j += 4) { 215 | u = guard(*x, two_times_modulus_); 216 | v = *y; 217 | *x++ = 218 | mul_scalar(guard(add(u, v), two_times_modulus_), q, n_inv, n_inv_con); 219 | *y++ = mul_root(sub(u, v, two_times_modulus_), q, scaled_r, scaled_r_con); 220 | 221 | u = guard(*x, two_times_modulus_); 222 | v = *y; 223 | *x++ = 224 | mul_scalar(guard(add(u, v), two_times_modulus_), q, n_inv, n_inv_con); 225 | *y++ = mul_root(sub(u, v, two_times_modulus_), q, scaled_r, scaled_r_con); 226 | 227 | u = guard(*x, two_times_modulus_); 228 | v = *y; 229 | *x++ = 230 | mul_scalar(guard(add(u, v), two_times_modulus_), q, n_inv, n_inv_con); 231 | *y++ = mul_root(sub(u, v, two_times_modulus_), q, scaled_r, scaled_r_con); 232 | 233 | u = guard(*x, two_times_modulus_); 234 | v = *y; 235 | *x++ = 236 | mul_scalar(guard(add(u, v), two_times_modulus_), q, n_inv, n_inv_con); 237 | *y++ = mul_root(sub(u, v, two_times_modulus_), q, scaled_r, scaled_r_con); 238 | } 239 | } 240 | 241 | for(size_t i = 0; i < N; i++) { 242 | a[i] = (a[i] < q) ? a[i] : a[i] - q; 243 | } 244 | } 245 | --------------------------------------------------------------------------------