├── .clang-format ├── .github └── workflows │ ├── auto-release.yml │ ├── cmake-build-test-arm64.yml │ ├── cmake-build-test-darwin.yml │ ├── cmake-build-test-win64.yml │ └── cmake-build-test.yml ├── .gitignore ├── CMakeLists.txt ├── CONTRIBUTING.md ├── Changelog.md ├── LICENSE ├── README.md ├── docs ├── api-full.svg ├── logo-inpher1.png ├── logo-inpher2.png ├── logo-sandboxaq-black.svg └── logo-sandboxaq-white.svg ├── manifest.yaml ├── scripts ├── auto-release.sh ├── ci-pkg └── prepare-release ├── spqlios ├── CMakeLists.txt ├── arithmetic │ ├── module_api.c │ ├── scalar_vector_product.c │ ├── vec_rnx_api.c │ ├── vec_rnx_approxdecomp_avx.c │ ├── vec_rnx_approxdecomp_ref.c │ ├── vec_rnx_arithmetic.c │ ├── vec_rnx_arithmetic.h │ ├── vec_rnx_arithmetic_avx.c │ ├── vec_rnx_arithmetic_plugin.h │ ├── vec_rnx_arithmetic_private.h │ ├── vec_rnx_conversions_ref.c │ ├── vec_rnx_svp_ref.c │ ├── vec_rnx_vmp_avx.c │ ├── vec_rnx_vmp_ref.c │ ├── vec_znx.c │ ├── vec_znx_arithmetic.h │ ├── vec_znx_arithmetic_private.h │ ├── vec_znx_avx.c │ ├── vec_znx_big.c │ ├── vec_znx_dft.c │ ├── vec_znx_dft_avx2.c │ ├── vector_matrix_product.c │ ├── vector_matrix_product_avx.c │ ├── zn_api.c │ ├── zn_approxdecomp_ref.c │ ├── zn_arithmetic.h │ ├── zn_arithmetic_plugin.h │ ├── zn_arithmetic_private.h │ ├── zn_conversions_ref.c │ ├── zn_vmp_int16_avx.c │ ├── zn_vmp_int16_ref.c │ ├── zn_vmp_int32_avx.c │ ├── zn_vmp_int32_ref.c │ ├── zn_vmp_int8_avx.c │ ├── zn_vmp_int8_ref.c │ ├── zn_vmp_ref.c │ └── znx_small.c ├── coeffs │ ├── coeffs_arithmetic.c │ ├── coeffs_arithmetic.h │ └── coeffs_arithmetic_avx.c ├── commons.c ├── commons.h ├── commons_private.c ├── commons_private.h ├── cplx │ ├── README.md │ ├── cplx_common.c │ ├── cplx_conversions.c │ ├── cplx_conversions_avx2_fma.c │ ├── cplx_execute.c │ ├── cplx_fallbacks_aarch64.c │ ├── cplx_fft.h │ ├── cplx_fft16_avx_fma.s │ ├── cplx_fft16_avx_fma_win32.s │ ├── cplx_fft_asserts.c │ ├── cplx_fft_avx2_fma.c │ ├── cplx_fft_avx512.c │ ├── cplx_fft_internal.h 
│ ├── cplx_fft_private.h │ ├── cplx_fft_ref.c │ ├── cplx_fft_sse.c │ ├── cplx_fftvec_avx2_fma.c │ ├── cplx_fftvec_ref.c │ ├── cplx_ifft16_avx_fma.s │ ├── cplx_ifft16_avx_fma_win32.s │ ├── cplx_ifft_avx2_fma.c │ ├── cplx_ifft_ref.c │ └── spqlios_cplx_fft.c ├── ext │ └── neon_accel │ │ ├── macrof.h │ │ └── macrofx4.h ├── q120 │ ├── q120_arithmetic.h │ ├── q120_arithmetic_avx2.c │ ├── q120_arithmetic_private.h │ ├── q120_arithmetic_ref.c │ ├── q120_arithmetic_simple.c │ ├── q120_common.h │ ├── q120_fallbacks_aarch64.c │ ├── q120_ntt.c │ ├── q120_ntt.h │ ├── q120_ntt_avx2.c │ └── q120_ntt_private.h ├── reim │ ├── reim_conversions.c │ ├── reim_conversions_avx.c │ ├── reim_execute.c │ ├── reim_fallbacks_aarch64.c │ ├── reim_fft.h │ ├── reim_fft16_avx_fma.s │ ├── reim_fft16_avx_fma_win32.s │ ├── reim_fft4_avx_fma.c │ ├── reim_fft8_avx_fma.c │ ├── reim_fft_avx2.c │ ├── reim_fft_core_template.h │ ├── reim_fft_ifft.c │ ├── reim_fft_internal.h │ ├── reim_fft_neon.c │ ├── reim_fft_private.h │ ├── reim_fft_ref.c │ ├── reim_fftvec_addmul_fma.c │ ├── reim_fftvec_addmul_ref.c │ ├── reim_ifft16_avx_fma.s │ ├── reim_ifft16_avx_fma_win32.s │ ├── reim_ifft4_avx_fma.c │ ├── reim_ifft8_avx_fma.c │ ├── reim_ifft_avx2.c │ ├── reim_ifft_ref.c │ ├── reim_to_tnx_avx.c │ └── reim_to_tnx_ref.c └── reim4 │ ├── reim4_arithmetic.h │ ├── reim4_arithmetic_avx2.c │ ├── reim4_arithmetic_ref.c │ ├── reim4_execute.c │ ├── reim4_fallbacks_aarch64.c │ ├── reim4_fftvec_addmul_fma.c │ ├── reim4_fftvec_addmul_ref.c │ ├── reim4_fftvec_conv_fma.c │ ├── reim4_fftvec_conv_ref.c │ ├── reim4_fftvec_internal.h │ ├── reim4_fftvec_private.h │ └── reim4_fftvec_public.h └── test ├── CMakeLists.txt ├── spqlios_coeffs_arithmetic_test.cpp ├── spqlios_cplx_conversions_test.cpp ├── spqlios_cplx_fft_bench.cpp ├── spqlios_cplx_test.cpp ├── spqlios_q120_arithmetic_bench.cpp ├── spqlios_q120_arithmetic_test.cpp ├── spqlios_q120_ntt_bench.cpp ├── spqlios_q120_ntt_test.cpp ├── spqlios_reim4_arithmetic_bench.cpp ├── 
spqlios_reim4_arithmetic_test.cpp ├── spqlios_reim_conversions_test.cpp ├── spqlios_reim_test.cpp ├── spqlios_svp_product_test.cpp ├── spqlios_svp_test.cpp ├── spqlios_test.cpp ├── spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp ├── spqlios_vec_rnx_conversions_test.cpp ├── spqlios_vec_rnx_ppol_test.cpp ├── spqlios_vec_rnx_test.cpp ├── spqlios_vec_rnx_vmp_test.cpp ├── spqlios_vec_znx_big_test.cpp ├── spqlios_vec_znx_dft_test.cpp ├── spqlios_vec_znx_test.cpp ├── spqlios_vmp_product_test.cpp ├── spqlios_zn_approxdecomp_test.cpp ├── spqlios_zn_conversions_test.cpp ├── spqlios_zn_vmp_test.cpp ├── spqlios_znx_small_test.cpp └── testlib ├── fft64_dft.cpp ├── fft64_dft.h ├── fft64_layouts.cpp ├── fft64_layouts.h ├── mod_q120.cpp ├── mod_q120.h ├── negacyclic_polynomial.cpp ├── negacyclic_polynomial.h ├── negacyclic_polynomial_impl.h ├── ntt120_dft.cpp ├── ntt120_dft.h ├── ntt120_layouts.cpp ├── ntt120_layouts.h ├── polynomial_vector.cpp ├── polynomial_vector.h ├── random.cpp ├── reim4_elem.cpp ├── reim4_elem.h ├── sha3.c ├── sha3.h ├── test_commons.cpp ├── test_commons.h ├── test_hash.cpp ├── vec_rnx_layout.cpp ├── vec_rnx_layout.h ├── zn_layouts.cpp └── zn_layouts.h /.clang-format: -------------------------------------------------------------------------------- 1 | # Use the Google style in this project. 2 | BasedOnStyle: Google 3 | 4 | # Some folks prefer to write "int& foo" while others prefer "int &foo". The 5 | # Google Style Guide only asks for consistency within a project, we chose 6 | # "int& foo" for this project: 7 | DerivePointerAlignment: false 8 | PointerAlignment: Left 9 | 10 | # The Google Style Guide only asks for consistency w.r.t. "east const" vs. 11 | # "const west" alignment of cv-qualifiers. In this project we use "east const". 
12 | QualifierAlignment: Left 13 | 14 | ColumnLimit: 120 15 | -------------------------------------------------------------------------------- /.github/workflows/auto-release.yml: -------------------------------------------------------------------------------- 1 | name: Auto-Release 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: [ "main" ] 7 | 8 | jobs: 9 | build: 10 | name: Auto-Release 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | with: 16 | fetch-depth: 3 17 | # sparse-checkout: manifest.yaml scripts/auto-release.sh 18 | 19 | - run: 20 | ${{github.workspace}}/scripts/auto-release.sh 21 | -------------------------------------------------------------------------------- /.github/workflows/cmake-build-test-arm64.yml: -------------------------------------------------------------------------------- 1 | name: CMake-Build-Test-Arm64 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: [ "main" ] 7 | pull_request: 8 | branches: [ "main" ] 9 | types: [ labeled, opened, synchronize, reopened ] 10 | 11 | jobs: 12 | build-arm64: 13 | name: CMake-Build-Test-Arm64 14 | if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'check-on-arm64') 15 | runs-on: self-hosted-arm64 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | 20 | - name: Configure CMake 21 | run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=Release -DENABLE_TESTING=ON -DWARNING_PARANOID=ON -DDEVMODE_INSTALL=ON 22 | 23 | - name: Build 24 | run: cmake --build ${{github.workspace}}/build 25 | 26 | - name: Test 27 | run: cd ${{github.workspace}}/build; ctest 28 | 29 | - name: Ci Package 30 | if: github.event_name != 'pull_request' 31 | env: 32 | CI_CREDS: ${{ secrets.CICREDS }} 33 | run: ./scripts/ci-pkg create 34 | 35 | -------------------------------------------------------------------------------- /.github/workflows/cmake-build-test-darwin.yml: 
-------------------------------------------------------------------------------- 1 | name: CMake-Build-Test-Darwin 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: [ "main" ] 7 | pull_request: 8 | branches: [ "main" ] 9 | 10 | jobs: 11 | build-darwin: 12 | name: CMake-Build-Test-Darwin 13 | runs-on: macos-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | 18 | - name: Brew 19 | run: brew install cmake googletest google-benchmark 20 | 21 | - name: Configure CMake 22 | run: cmake -B build -DCMAKE_BUILD_TYPE=Release -DENABLE_TESTING=ON -DWARNING_PARANOID=ON -DDEVMODE_INSTALL=ON 23 | 24 | - name: Build 25 | run: cmake --build build 26 | 27 | - name: Test 28 | run: cd build; ctest 29 | 30 | - name: Package 31 | if: github.event_name != 'pull_request' 32 | env: 33 | CI_CREDS: ${{ secrets.CICREDS }} 34 | run: ./scripts/ci-pkg create 35 | 36 | -------------------------------------------------------------------------------- /.github/workflows/cmake-build-test-win64.yml: -------------------------------------------------------------------------------- 1 | name: CMake-Build-Test-Win64 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: [ "main" ] 7 | pull_request: 8 | branches: [ "main" ] 9 | 10 | jobs: 11 | build-msys-clang64: 12 | name: CMake-Build-Test-Win64 13 | runs-on: windows-latest 14 | defaults: 15 | run: 16 | shell: msys2 {0} 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | 21 | - uses: msys2/setup-msys2@v2 22 | with: 23 | msystem: CLANG64 24 | #update: true 25 | install: git 26 | pacboy: toolchain:p cmake:p gtest:p benchmark:p 27 | 28 | - name: Configure CMake 29 | run: cmake -B build -DCMAKE_BUILD_TYPE=Release -DENABLE_TESTING=ON -DWARNING_PARANOID=ON -DDEVMODE_INSTALL=ON 30 | 31 | - name: Build 32 | run: cmake --build build 33 | 34 | - name: Test 35 | working-directory: build 36 | run: ctest --output-on-failure 37 | 38 | - name: CI Package 39 | if: github.event_name != 'pull_request' 40 | env: 41 | CI_CREDS: ${{ secrets.CICREDS }} 
42 | run: ./scripts/ci-pkg create 43 | 44 | -------------------------------------------------------------------------------- /.github/workflows/cmake-build-test.yml: -------------------------------------------------------------------------------- 1 | name: CMake-Build-Test 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: [ "main" ] 7 | pull_request: 8 | branches: [ "main" ] 9 | 10 | jobs: 11 | build: 12 | name: CMake-Build-Test 13 | runs-on: ubuntu-latest 14 | container: ngama75/spqlios-ci:latest 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Configure CMake 20 | run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=Release -DENABLE_TESTING=ON -DWARNING_PARANOID=ON -DDEVMODE_INSTALL=ON 21 | 22 | - name: Build 23 | run: cmake --build ${{github.workspace}}/build 24 | 25 | - name: Test 26 | run: cd ${{github.workspace}}/build; ctest 27 | 28 | - name: Package 29 | if: github.event_name != 'pull_request' 30 | env: 31 | CI_CREDS: ${{ secrets.CICREDS }} 32 | run: ./scripts/ci-pkg create 33 | 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | cmake-build-* 2 | .idea 3 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.8) 2 | project(spqlios) 3 | 4 | # read the current version from the manifest file 5 | file(READ "manifest.yaml" manifest) 6 | string(REGEX MATCH "version: +(([0-9]+)\\.([0-9]+)\\.([0-9]+))" SPQLIOS_VERSION_BLAH ${manifest}) 7 | #message(STATUS "Version: ${SPQLIOS_VERSION_BLAH}") 8 | set(SPQLIOS_VERSION ${CMAKE_MATCH_1}) 9 | set(SPQLIOS_VERSION_MAJOR ${CMAKE_MATCH_2}) 10 | set(SPQLIOS_VERSION_MINOR ${CMAKE_MATCH_3}) 11 | set(SPQLIOS_VERSION_PATCH ${CMAKE_MATCH_4}) 12 | message(STATUS "Compiling spqlios-fft version: 
${SPQLIOS_VERSION_MAJOR}.${SPQLIOS_VERSION_MINOR}.${SPQLIOS_VERSION_PATCH}") 13 | 14 | #set(ENABLE_SPQLIOS_F128 ON CACHE BOOL "Enable float128 via libquadmath") 15 | set(WARNING_PARANOID ON CACHE BOOL "Treat all warnings as errors") 16 | set(ENABLE_TESTING ON CACHE BOOL "Compiles unittests and integration tests") 17 | set(DEVMODE_INSTALL OFF CACHE BOOL "Install private headers and testlib (mainly for CI)") 18 | 19 | if (NOT CMAKE_BUILD_TYPE OR CMAKE_BUILD_TYPE STREQUAL "") 20 | set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type: Release or Debug" FORCE) 21 | endif() 22 | message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") 23 | 24 | if (WARNING_PARANOID) 25 | add_compile_options(-Wall -Werror -Wno-unused-command-line-argument) 26 | endif() 27 | 28 | message(STATUS "CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") 29 | message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") 30 | message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}") 31 | 32 | if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)") 33 | set(X86 ON) 34 | set(AARCH64 OFF) 35 | else () 36 | set(X86 OFF) 37 | # set(ENABLE_SPQLIOS_F128 OFF) # float128 are only supported for x86 targets 38 | endif () 39 | if (CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64)|(arm64)") 40 | set(AARCH64 ON) 41 | endif () 42 | 43 | if (CMAKE_SYSTEM_NAME MATCHES "(Windows)|(MSYS)") 44 | set(WIN32 ON) 45 | endif () 46 | if (WIN32) 47 | #overrides for win32 48 | set(X86 OFF) 49 | set(AARCH64 OFF) 50 | set(X86_WIN32 ON) 51 | else() 52 | set(X86_WIN32 OFF) 53 | set(WIN32 OFF) 54 | endif (WIN32) 55 | 56 | message(STATUS "--> WIN32: ${WIN32}") 57 | message(STATUS "--> X86_WIN32: ${X86_WIN32}") 58 | message(STATUS "--> X86_LINUX: ${X86}") 59 | message(STATUS "--> AARCH64: ${AARCH64}") 60 | 61 | # compiles the main library in spqlios 62 | add_subdirectory(spqlios) 63 | 64 | # compiles and activates unittests and itests 65 | if (${ENABLE_TESTING}) 66 | enable_testing() 67 | add_subdirectory(test) 68 | endif() 69 | 
70 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to SPQlios-fft 2 | 3 | The spqlios-fft team encourages contributions. 4 | We encourage users to fix bugs, improve the documentation, write tests and to enhance the code, or ask for new features. 5 | We encourage researchers to contribute with implementations of their FFT or NTT algorithms. 6 | In the following we are trying to give some guidance on how to contribute effectively. 7 | 8 | ## Communication ## 9 | 10 | Communication in the spqlios-fft project happens mainly on [GitHub](https://github.com/tfhe/spqlios-fft/issues). 11 | 12 | All communications are public, so please make sure to maintain professional behaviour in 13 | all published comments. See [Code of Conduct](https://www.contributor-covenant.org/version/2/1/code_of_conduct/) for 14 | guidelines. 15 | 16 | ## Reporting Bugs or Requesting features ## 17 | 18 | Bug should be filed at [https://github.com/tfhe/spqlios-fft/issues](https://github.com/tfhe/spqlios-fft/issues). 19 | 20 | Features can also be requested there, in this case, please ensure that the features you request are self-contained, 21 | easy to define, and generic enough to be used in different use-cases. Please provide an example of use-cases if 22 | possible. 23 | 24 | ## Setting up topic branches and generating pull requests 25 | 26 | This section applies to people that already have write access to the repository. Specific instructions for pull-requests 27 | from public forks will be given later. 28 | 29 | To implement some changes, please follow these steps: 30 | 31 | - Create a "topic branch". Usually, the branch name should be `username/small-title` 32 | or better `username/issuenumber-small-title` where `issuenumber` is the number of 33 | the github issue number that is tackled. 34 | - Push any needed commits to your branch. 
Make sure it compiles in `CMAKE_BUILD_TYPE=Debug` and `=Release`, with `-DWARNING_PARANOID=ON`. 35 | - When the branch is nearly ready for review, please open a pull request, and add the label `check-on-arm64` 36 | - Do as many commits as necessary until all CI checks pass and all PR comments have been resolved. 37 | 38 | > _During the process, you may optionally use `git rebase -i` to clean up your commit history. If you elect to do so, 39 | please at the very least make sure that nobody else is working or has forked from your branch: the conflicts it would generate 40 | and the human hours to fix them are not worth it. `Git merge` remains the preferred option._ 41 | 42 | - Finally, when all reviews are positive and all CI checks pass, you may merge your branch via the github webpage. 43 | 44 | ### Keep your pull requests limited to a single issue 45 | 46 | Pull requests should be as small/atomic as possible. 47 | 48 | ### Coding Conventions 49 | 50 | * Please make sure that your code is formatted according to the `.clang-format` file and 51 | that all files end with a newline character. 52 | * Please make sure that all the functions declared in the public api have relevant doxygen comments. 53 | Preferably, functions in the private apis should also contain a brief doxygen description. 54 | 55 | ### Versions and History 56 | 57 | * **Stable API** The project uses semantic versioning on the functions that are listed as `stable` in the documentation. A version has 58 | the form `x.y.z` 59 | * a patch release that increments `z` does not modify the stable API. 60 | * a minor release that increments `y` adds a new feature to the stable API. 61 | * In the unlikely case where we need to change or remove a feature, we will trigger a major release that 62 | increments `x`. 
63 | 64 | > _If any, we will mark those features as deprecated at least six months before the major release._ 65 | 66 | * **Experimental API** Features that are not part of the stable section in the documentation are experimental features: you may test them at 67 | your own risk, 68 | but keep in mind that semantic versioning does not apply to them. 69 | 70 | > _If you have a use-case that uses an experimental feature, we encourage 71 | > you to tell us about it, so that this feature reaches the stable section faster!_ 72 | 73 | * **Version history** The current version is reported in `manifest.yaml`, any change of version comes up with a tag on the main branch, and the history between releases is summarized in `Changelog.md`. It is the main source of truth for anyone who wishes to 74 | get insight about 75 | the history of the repository (not the commit graph). 76 | 77 | > Note: _The commit graph of git is for git's internal use only. Its main purpose is to reduce potential merge conflicts to a minimum, even in scenarios where multiple features are developed in parallel: it may therefore be non-linear. If, as humans, we like to see a linear history, please read `Changelog.md` instead!_ 78 | -------------------------------------------------------------------------------- /Changelog.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 5 | 6 | ## [2.0.0] - 2024-08-21 7 | 8 | - Initial release of the `vec_znx` (except convolution products), `vec_rnx` and `zn` apis. 9 | - Hardware acceleration available: AVX2 (most parts) 10 | - APIs are documented in the wiki and are in "beta mode": during the 2.x -> 3.x transition, functions whose API is satisfactory in test projects will pass in "stable mode". 
11 | 12 | ## [1.0.0] - 2023-07-18 13 | 14 | - Initial release of the double precision fft on the reim and cplx backends 15 | - Coeffs-space conversions cplx <-> znx32 and tnx32 16 | - FFT-space conversions cplx <-> reim4 layouts 17 | - FFT-space multiplications on the cplx, reim and reim4 layouts. 18 | - In this first release, the only platform supported is linux x86_64 (generic C code, and avx2/fma). It compiles on arm64, but without any acceleration. 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SPQlios library 2 | 3 | 4 | 5 | The SPQlios library provides fast arithmetic for Fully Homomorphic Encryption, and other lattice constructions that arise in post quantum cryptography. 6 | 7 | 8 | 9 | Namely, it is divided into 4 sections: 10 | 11 | * The low-level DFT section support FFT over 64-bit floats, as well as NTT modulo one fixed 120-bit modulus. It is an upgrade of the original spqlios-fft module embedded in the TFHE library since 2016. The DFT section exposes the traditional DFT, inverse-DFT, and coefficient-wise multiplications in DFT space. 12 | * The VEC_ZNX section exposes fast algebra over vectors of small integer polynomial modulo $X^N+1$. It proposed in particular efficient (prepared) vector-matrix products, scalar-vector products, convolution products, and element-wise products, operations that naturally occurs on gadget-decomposed Ring-LWE coordinates. 13 | * The RNX section is a simpler variant of VEC_ZNX, to represent single polynomials modulo $X^N+1$ (over the reals or over the torus) when the coefficient precision fits on 64-bit doubles. The small vector-matrix API of the RNX section is particularly adapted to reproducing the fastest CGGI-based bootstrappings. 
14 | * The ZN section focuses over vector and matrix algebra over scalars (used by scalar LWE, or scalar key-switches, but also on non-ring schemes like Frodo, FrodoPIR, and SimplePIR). 15 | 16 | ### A high value target for hardware accelerations 17 | 18 | SPQlios is more than a library, it is also a good target for hardware developers. 19 | On one hand, the arithmetic operations that are defined in the library have a clear standalone mathematical definition. And at the same time, the amount of work in each operations is sufficiently large so that meaningful functions only require a few of these. 20 | 21 | This makes the SPQlios API a high value target for hardware acceleration, that targets FHE. 22 | 23 | ### SPQLios is not an FHE library, but a huge enabler 24 | 25 | SPQlios itself is not an FHE library: there is no ciphertext, plaintext or key. It is a mathematical library that exposes efficient algebra over polynomials. Using the functions exposed, it is possible to quickly build efficient FHE libraries, with support for the main schemes based on Ring-LWE: BFV, BGV, CGGI, DM, CKKS. 26 | 27 | 28 | ## Dependencies 29 | 30 | The SPQLIOS-FFT library is a C library that can be compiled with a standard C compiler, and depends only on libc and libm. The API 31 | interface can be used in a regular C code, and any other language via classical foreign APIs. 32 | 33 | The unittests and integration tests are in an optional part of the code, and are written in C++. These tests rely on 34 | [```benchmark```](https://github.com/google/benchmark), and [```gtest```](https://github.com/google/googletest) libraries, and therefore require a C++17 compiler. 35 | 36 | Currently, the project has been tested with the gcc,g++ >= 11.3.0 compiler under Linux (x86_64). In the future, we plan to 37 | extend the compatibility to other compilers, platforms and operating systems. 
38 | 39 | 40 | ## Installation 41 | 42 | The library uses a classical ```cmake``` build mechanism: use ```cmake``` to create a ```build``` folder in the top level directory and run ```make``` from inside it. This assumes that the standard tool ```cmake``` is already installed on the system, and an up-to-date c++ compiler (i.e. g++ >=11.3.0) as well. 43 | 44 | It will compile the shared library in optimized mode, and ```make install``` installs it to the desired prefix folder (by default ```/usr/local/lib```). 45 | 46 | If you want to choose additional compile options (i.e. other installation folder, debug mode, tests), you need to run cmake manually and pass the desired options: 47 | ``` 48 | mkdir build 49 | cd build 50 | cmake .. -DCMAKE_INSTALL_PREFIX=/usr/ 51 | make 52 | ``` 53 | The available options are the following: 54 | 55 | | Variable Name | values | 56 | | -------------------- | ------------------------------------------------------------ | 57 | | CMAKE_INSTALL_PREFIX | */usr/local* installation folder (libs go in lib/ and headers in include/) | 58 | | WARNING_PARANOID | All warnings are shown and treated as errors. 
On by default | 59 | ENABLE_TESTING | Compiles unit tests and integration tests | 60 | 61 | ------ 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /docs/logo-inpher1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tfhe/spqlios-arithmetic/7ea875f61c51c67687ec9c15ae0eae04cda961f8/docs/logo-inpher1.png -------------------------------------------------------------------------------- /docs/logo-inpher2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tfhe/spqlios-arithmetic/7ea875f61c51c67687ec9c15ae0eae04cda961f8/docs/logo-inpher2.png -------------------------------------------------------------------------------- /manifest.yaml: -------------------------------------------------------------------------------- 1 | library: spqlios-fft 2 | version: 2.0.0 3 | -------------------------------------------------------------------------------- /scripts/auto-release.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # this script generates one tag if there is a version change in manifest.yaml 4 | cd `dirname $0`/.. 
5 | if [ "v$1" = "v-y" ]; then 6 | echo "production mode!"; 7 | fi 8 | changes=`git diff HEAD~1..HEAD -- manifest.yaml | grep 'version:'` 9 | oldversion=$(echo "$changes" | grep '^-version:' | cut '-d ' -f2) 10 | version=$(echo "$changes" | grep '^+version:' | cut '-d ' -f2) 11 | echo "Versions: $oldversion --> $version" 12 | if [ "v$oldversion" = "v$version" ]; then 13 | echo "Same version - nothing to do"; exit 0; 14 | fi 15 | if [ "v$1" = "v-y" ]; then 16 | git config user.name github-actions 17 | git config user.email github-actions@github.com 18 | git tag -a "v$version" -m "Version $version" 19 | git push origin "v$version" 20 | else 21 | cat </dev/null 75 | rm -f "$DIR/$FNAME" 2>/dev/null 76 | DESTDIR="$DIR/dist" cmake --install build || exit 1 77 | if [ -d "$DIR/dist$CI_INSTALL_PREFIX" ]; then 78 | tar -C "$DIR/dist" -cvzf "$DIR/$FNAME" . 79 | else 80 | # fix since msys can mess up the paths 81 | REAL_DEST=`find "$DIR/dist" -type d -exec test -d "{}$CI_INSTALL_PREFIX" \; -print` 82 | echo "REAL_DEST: $REAL_DEST" 83 | [ -d "$REAL_DEST$CI_INSTALL_PREFIX" ] && tar -C "$REAL_DEST" -cvzf "$DIR/$FNAME" . 
84 | fi 85 | [ -f "$DIR/$FNAME" ] || { echo "failed to create $DIR/$FNAME"; exit 1; } 86 | [ "x$CI_CREDS" = "x" ] && { echo "CI_CREDS is not set: not uploading"; exit 1; } 87 | curl -u "$CI_CREDS" -T "$DIR/$FNAME" "$CI_REPO_URL/$FNAME" 88 | fi 89 | 90 | if [ "x$1" = "xinstall" ]; then 91 | [ "x$CI_CREDS" = "x" ] && { echo "CI_CREDS is not set: not downloading"; exit 1; } 92 | # cleaning 93 | rm -rf "$DESTDIR$CI_INSTALL_PREFIX"/* 2>/dev/null 94 | rm -f "$DIR/$FNAME" 2>/dev/null 95 | # downloading 96 | curl -u "$CI_CREDS" -o "$DIR/$FNAME" "$CI_REPO_URL/$FNAME" 97 | [ -f "$DIR/$FNAME" ] || { echo "failed to download $DIR/$FNAME"; exit 0; } 98 | # installing 99 | mkdir -p $DESTDIR 100 | tar -C "$DESTDIR" -xvzf "$DIR/$FNAME" 101 | exit 0 102 | fi 103 | -------------------------------------------------------------------------------- /spqlios/arithmetic/scalar_vector_product.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "vec_znx_arithmetic_private.h" 4 | 5 | EXPORT uint64_t bytes_of_svp_ppol(const MODULE* module) { return module->func.bytes_of_svp_ppol(module); } 6 | 7 | EXPORT uint64_t fft64_bytes_of_svp_ppol(const MODULE* module) { return module->nn * sizeof(double); } 8 | 9 | EXPORT SVP_PPOL* new_svp_ppol(const MODULE* module) { return spqlios_alloc(bytes_of_svp_ppol(module)); } 10 | 11 | EXPORT void delete_svp_ppol(SVP_PPOL* ppol) { spqlios_free(ppol); } 12 | 13 | // public wrappers 14 | EXPORT void svp_prepare(const MODULE* module, // N 15 | SVP_PPOL* ppol, // output 16 | const int64_t* pol // a 17 | ) { 18 | module->func.svp_prepare(module, ppol, pol); 19 | } 20 | 21 | /** @brief prepares a svp polynomial */ 22 | EXPORT void fft64_svp_prepare_ref(const MODULE* module, // N 23 | SVP_PPOL* ppol, // output 24 | const int64_t* pol // a 25 | ) { 26 | reim_from_znx64(module->mod.fft64.p_conv, ppol, pol); 27 | reim_fft(module->mod.fft64.p_fft, (double*)ppol); 28 | } 29 | 30 | EXPORT void 
svp_apply_dft(const MODULE* module, // N 31 | const VEC_ZNX_DFT* res, uint64_t res_size, // output 32 | const SVP_PPOL* ppol, // prepared pol 33 | const int64_t* a, uint64_t a_size, uint64_t a_sl) { 34 | module->func.svp_apply_dft(module, // N 35 | res, 36 | res_size, // output 37 | ppol, // prepared pol 38 | a, a_size, a_sl); 39 | } 40 | 41 | // result = ppol * a 42 | EXPORT void fft64_svp_apply_dft_ref(const MODULE* module, // N 43 | const VEC_ZNX_DFT* res, uint64_t res_size, // output 44 | const SVP_PPOL* ppol, // prepared pol 45 | const int64_t* a, uint64_t a_size, uint64_t a_sl // a 46 | ) { 47 | const uint64_t nn = module->nn; 48 | double* const dres = (double*)res; 49 | double* const dppol = (double*)ppol; 50 | 51 | const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size; 52 | for (uint64_t i = 0; i < auto_end_idx; ++i) { 53 | const int64_t* a_ptr = a + i * a_sl; 54 | double* const res_ptr = dres + i * nn; 55 | // copy the polynomial to res, apply fft in place, call fftvec_mul in place. 
56 | reim_from_znx64(module->mod.fft64.p_conv, res_ptr, a_ptr); 57 | reim_fft(module->mod.fft64.p_fft, res_ptr); 58 | reim_fftvec_mul(module->mod.fft64.mul_fft, res_ptr, res_ptr, dppol); 59 | } 60 | 61 | // then extend with zeros 62 | memset(dres + auto_end_idx * nn, 0, (res_size - auto_end_idx) * nn * sizeof(double)); 63 | } 64 | -------------------------------------------------------------------------------- /spqlios/arithmetic/vec_rnx_approxdecomp_avx.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "immintrin.h" 4 | #include "vec_rnx_arithmetic_private.h" 5 | 6 | /** @brief sets res = gadget_decompose(a) */ 7 | EXPORT void rnx_approxdecomp_from_tnxdbl_avx( // 8 | const MOD_RNX* module, // N 9 | const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K 10 | double* res, uint64_t res_size, uint64_t res_sl, // res 11 | const double* a // a 12 | ) { 13 | const uint64_t nn = module->n; 14 | if (nn < 4) return rnx_approxdecomp_from_tnxdbl_ref(module, gadget, res, res_size, res_sl, a); 15 | const uint64_t ell = gadget->ell; 16 | const __m256i k = _mm256_set1_epi64x(gadget->k); 17 | const __m256d add_cst = _mm256_set1_pd(gadget->add_cst); 18 | const __m256i and_mask = _mm256_set1_epi64x(gadget->and_mask); 19 | const __m256i or_mask = _mm256_set1_epi64x(gadget->or_mask); 20 | const __m256d sub_cst = _mm256_set1_pd(gadget->sub_cst); 21 | const uint64_t msize = res_size <= ell ? 
res_size : ell; 22 | // gadget decompose column by column 23 | if (msize == ell) { 24 | // this is the main scenario when msize == ell 25 | double* const last_r = res + (msize - 1) * res_sl; 26 | for (uint64_t j = 0; j < nn; j += 4) { 27 | double* rr = last_r + j; 28 | const double* aa = a + j; 29 | __m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst); 30 | __m256i t_int = _mm256_castpd_si256(t_dbl); 31 | do { 32 | __m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask); 33 | _mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst)); 34 | t_int = _mm256_srlv_epi64(t_int, k); 35 | rr -= res_sl; 36 | } while (rr >= res); 37 | } 38 | } else if (msize > 0) { 39 | // otherwise, if msize < ell: there is one additional rshift 40 | const __m256i first_rsh = _mm256_set1_epi64x((ell - msize) * gadget->k); 41 | double* const last_r = res + (msize - 1) * res_sl; 42 | for (uint64_t j = 0; j < nn; j += 4) { 43 | double* rr = last_r + j; 44 | const double* aa = a + j; 45 | __m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst); 46 | __m256i t_int = _mm256_srlv_epi64(_mm256_castpd_si256(t_dbl), first_rsh); 47 | do { 48 | __m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask); 49 | _mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst)); 50 | t_int = _mm256_srlv_epi64(t_int, k); 51 | rr -= res_sl; 52 | } while (rr >= res); 53 | } 54 | } 55 | // zero-out the last slices (if any) 56 | for (uint64_t i = msize; i < res_size; ++i) { 57 | memset(res + i * res_sl, 0, nn * sizeof(double)); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /spqlios/arithmetic/vec_rnx_approxdecomp_ref.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "vec_rnx_arithmetic_private.h" 4 | 5 | typedef union di { 6 | double dv; 7 | uint64_t uv; 8 | } di_t; 9 | 10 | /** @brief new gadget: delete with 
delete_tnxdbl_approxdecomp_gadget */ 11 | EXPORT TNXDBL_APPROXDECOMP_GADGET* new_tnxdbl_approxdecomp_gadget( // 12 | const MOD_RNX* module, // N 13 | uint64_t k, uint64_t ell // base 2^K and size 14 | ) { 15 | if (k * ell > 50) return spqlios_error("gadget requires a too large fp precision"); 16 | TNXDBL_APPROXDECOMP_GADGET* res = spqlios_alloc(sizeof(TNXDBL_APPROXDECOMP_GADGET)); 17 | res->k = k; 18 | res->ell = ell; 19 | // double add_cst; // double(3.2^(51-ell.K) + 1/2.(sum 2^(-iK)) for i=[0,ell[) 20 | union di add_cst; 21 | add_cst.dv = UINT64_C(3) << (51 - ell * k); 22 | for (uint64_t i = 0; i < ell; ++i) { 23 | add_cst.uv |= UINT64_C(1) << ((i + 1) * k - 1); 24 | } 25 | res->add_cst = add_cst.dv; 26 | // uint64_t and_mask; // uint64(2^(K)-1) 27 | res->and_mask = (UINT64_C(1) << k) - 1; 28 | // uint64_t or_mask; // double(2^52) 29 | union di or_mask; 30 | or_mask.dv = (UINT64_C(1) << 52); 31 | res->or_mask = or_mask.uv; 32 | // double sub_cst; // double(2^52 + 2^(K-1)) 33 | res->sub_cst = ((UINT64_C(1) << 52) + (UINT64_C(1) << (k - 1))); 34 | return res; 35 | } 36 | 37 | EXPORT void delete_tnxdbl_approxdecomp_gadget(TNXDBL_APPROXDECOMP_GADGET* gadget) { spqlios_free(gadget); } 38 | 39 | /** @brief sets res = gadget_decompose(a) */ 40 | EXPORT void rnx_approxdecomp_from_tnxdbl_ref( // 41 | const MOD_RNX* module, // N 42 | const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K 43 | double* res, uint64_t res_size, uint64_t res_sl, // res 44 | const double* a // a 45 | ) { 46 | const uint64_t nn = module->n; 47 | const uint64_t k = gadget->k; 48 | const uint64_t ell = gadget->ell; 49 | const double add_cst = gadget->add_cst; 50 | const uint64_t and_mask = gadget->and_mask; 51 | const uint64_t or_mask = gadget->or_mask; 52 | const double sub_cst = gadget->sub_cst; 53 | const uint64_t msize = res_size <= ell ? 
res_size : ell; 54 | const uint64_t first_rsh = (ell - msize) * k; 55 | // gadget decompose column by column 56 | if (msize > 0) { 57 | double* const last_r = res + (msize - 1) * res_sl; 58 | for (uint64_t j = 0; j < nn; ++j) { 59 | double* rr = last_r + j; 60 | di_t t = {.dv = a[j] + add_cst}; 61 | if (msize < ell) t.uv >>= first_rsh; 62 | do { 63 | di_t u; 64 | u.uv = (t.uv & and_mask) | or_mask; 65 | *rr = u.dv - sub_cst; 66 | t.uv >>= k; 67 | rr -= res_sl; 68 | } while (rr >= res); 69 | } 70 | } 71 | // zero-out the last slices (if any) 72 | for (uint64_t i = msize; i < res_size; ++i) { 73 | memset(res + i * res_sl, 0, nn * sizeof(double)); 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /spqlios/arithmetic/vec_rnx_conversions_ref.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "vec_rnx_arithmetic_private.h" 4 | #include "zn_arithmetic_private.h" 5 | 6 | EXPORT void vec_rnx_to_znx32_ref( // 7 | const MOD_RNX* module, // N 8 | int32_t* res, uint64_t res_size, uint64_t res_sl, // res 9 | const double* a, uint64_t a_size, uint64_t a_sl // a 10 | ) { 11 | const uint64_t nn = module->n; 12 | const uint64_t msize = res_size < a_size ? res_size : a_size; 13 | for (uint64_t i = 0; i < msize; ++i) { 14 | dbl_round_to_i32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn); 15 | } 16 | for (uint64_t i = msize; i < res_size; ++i) { 17 | memset(res + i * res_sl, 0, nn * sizeof(int32_t)); 18 | } 19 | } 20 | 21 | EXPORT void vec_rnx_from_znx32_ref( // 22 | const MOD_RNX* module, // N 23 | double* res, uint64_t res_size, uint64_t res_sl, // res 24 | const int32_t* a, uint64_t a_size, uint64_t a_sl // a 25 | ) { 26 | const uint64_t nn = module->n; 27 | const uint64_t msize = res_size < a_size ? 
res_size : a_size; 28 | for (uint64_t i = 0; i < msize; ++i) { 29 | i32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn); 30 | } 31 | for (uint64_t i = msize; i < res_size; ++i) { 32 | memset(res + i * res_sl, 0, nn * sizeof(int32_t)); 33 | } 34 | } 35 | EXPORT void vec_rnx_to_tnx32_ref( // 36 | const MOD_RNX* module, // N 37 | int32_t* res, uint64_t res_size, uint64_t res_sl, // res 38 | const double* a, uint64_t a_size, uint64_t a_sl // a 39 | ) { 40 | const uint64_t nn = module->n; 41 | const uint64_t msize = res_size < a_size ? res_size : a_size; 42 | for (uint64_t i = 0; i < msize; ++i) { 43 | dbl_to_tn32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn); 44 | } 45 | for (uint64_t i = msize; i < res_size; ++i) { 46 | memset(res + i * res_sl, 0, nn * sizeof(int32_t)); 47 | } 48 | } 49 | EXPORT void vec_rnx_from_tnx32_ref( // 50 | const MOD_RNX* module, // N 51 | double* res, uint64_t res_size, uint64_t res_sl, // res 52 | const int32_t* a, uint64_t a_size, uint64_t a_sl // a 53 | ) { 54 | const uint64_t nn = module->n; 55 | const uint64_t msize = res_size < a_size ? res_size : a_size; 56 | for (uint64_t i = 0; i < msize; ++i) { 57 | tn32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn); 58 | } 59 | for (uint64_t i = msize; i < res_size; ++i) { 60 | memset(res + i * res_sl, 0, nn * sizeof(int32_t)); 61 | } 62 | } 63 | 64 | static void dbl_to_tndbl_ref( // 65 | const void* UNUSED, // N 66 | double* res, uint64_t res_size, // res 67 | const double* a, uint64_t a_size // a 68 | ) { 69 | static const double OFF_CST = INT64_C(3) << 51; 70 | const uint64_t msize = res_size < a_size ? 
res_size : a_size; 71 | for (uint64_t i = 0; i < msize; ++i) { 72 | double ai = a[i] + OFF_CST; 73 | res[i] = a[i] - (ai - OFF_CST); 74 | } 75 | memset(res + msize, 0, (res_size - msize) * sizeof(double)); 76 | } 77 | 78 | EXPORT void vec_rnx_to_tnxdbl_ref( // 79 | const MOD_RNX* module, // N 80 | double* res, uint64_t res_size, uint64_t res_sl, // res 81 | const double* a, uint64_t a_size, uint64_t a_sl // a 82 | ) { 83 | const uint64_t nn = module->n; 84 | const uint64_t msize = res_size < a_size ? res_size : a_size; 85 | for (uint64_t i = 0; i < msize; ++i) { 86 | dbl_to_tndbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn); 87 | } 88 | for (uint64_t i = msize; i < res_size; ++i) { 89 | memset(res + i * res_sl, 0, nn * sizeof(int32_t)); 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /spqlios/arithmetic/vec_rnx_svp_ref.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "../coeffs/coeffs_arithmetic.h" 4 | #include "vec_rnx_arithmetic_private.h" 5 | 6 | EXPORT uint64_t fft64_bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return module->n * sizeof(double); } 7 | 8 | EXPORT RNX_SVP_PPOL* new_rnx_svp_ppol(const MOD_RNX* module) { return spqlios_alloc(bytes_of_rnx_svp_ppol(module)); } 9 | 10 | EXPORT void delete_rnx_svp_ppol(RNX_SVP_PPOL* ppol) { spqlios_free(ppol); } 11 | 12 | /** @brief prepares a svp polynomial */ 13 | EXPORT void fft64_rnx_svp_prepare_ref(const MOD_RNX* module, // N 14 | RNX_SVP_PPOL* ppol, // output 15 | const double* pol // a 16 | ) { 17 | double* const dppol = (double*)ppol; 18 | rnx_divide_by_m_ref(module->n, module->m, dppol, pol); 19 | reim_fft(module->precomp.fft64.p_fft, dppol); 20 | } 21 | 22 | EXPORT void fft64_rnx_svp_apply_ref( // 23 | const MOD_RNX* module, // N 24 | double* res, uint64_t res_size, uint64_t res_sl, // output 25 | const RNX_SVP_PPOL* ppol, // prepared pol 26 | const double* a, uint64_t a_size, uint64_t 
a_sl // a 27 | ) { 28 | const uint64_t nn = module->n; 29 | double* const dppol = (double*)ppol; 30 | 31 | const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size; 32 | for (uint64_t i = 0; i < auto_end_idx; ++i) { 33 | const double* a_ptr = a + i * a_sl; 34 | double* const res_ptr = res + i * res_sl; 35 | // copy the polynomial to res, apply fft in place, call fftvec 36 | // _mul, apply ifft in place. 37 | memcpy(res_ptr, a_ptr, nn * sizeof(double)); 38 | reim_fft(module->precomp.fft64.p_fft, (double*)res_ptr); 39 | reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, res_ptr, res_ptr, dppol); 40 | reim_ifft(module->precomp.fft64.p_ifft, res_ptr); 41 | } 42 | 43 | // then extend with zeros 44 | for (uint64_t i = auto_end_idx; i < res_size; ++i) { 45 | memset(res + i * res_sl, 0, nn * sizeof(double)); 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /spqlios/arithmetic/vec_znx_avx.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "../coeffs/coeffs_arithmetic.h" 4 | #include "../reim4/reim4_arithmetic.h" 5 | #include "vec_znx_arithmetic_private.h" 6 | 7 | // specialized function (ref) 8 | 9 | // Note: these functions do not have an avx variant. 10 | #define znx_copy_i64_avx znx_copy_i64_ref 11 | #define znx_zero_i64_avx znx_zero_i64_ref 12 | 13 | EXPORT void vec_znx_add_avx(const MODULE* module, // N 14 | int64_t* res, uint64_t res_size, uint64_t res_sl, // res 15 | const int64_t* a, uint64_t a_size, uint64_t a_sl, // a 16 | const int64_t* b, uint64_t b_size, uint64_t b_sl // b 17 | ) { 18 | const uint64_t nn = module->nn; 19 | if (a_size <= b_size) { 20 | const uint64_t sum_idx = res_size < a_size ? res_size : a_size; 21 | const uint64_t copy_idx = res_size < b_size ? 
res_size : b_size; 22 | // add up to the smallest dimension 23 | for (uint64_t i = 0; i < sum_idx; ++i) { 24 | znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); 25 | } 26 | // then copy to the largest dimension 27 | for (uint64_t i = sum_idx; i < copy_idx; ++i) { 28 | znx_copy_i64_avx(nn, res + i * res_sl, b + i * b_sl); 29 | } 30 | // then extend with zeros 31 | for (uint64_t i = copy_idx; i < res_size; ++i) { 32 | znx_zero_i64_avx(nn, res + i * res_sl); 33 | } 34 | } else { 35 | const uint64_t sum_idx = res_size < b_size ? res_size : b_size; 36 | const uint64_t copy_idx = res_size < a_size ? res_size : a_size; 37 | // add up to the smallest dimension 38 | for (uint64_t i = 0; i < sum_idx; ++i) { 39 | znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); 40 | } 41 | // then copy to the largest dimension 42 | for (uint64_t i = sum_idx; i < copy_idx; ++i) { 43 | znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl); 44 | } 45 | // then extend with zeros 46 | for (uint64_t i = copy_idx; i < res_size; ++i) { 47 | znx_zero_i64_avx(nn, res + i * res_sl); 48 | } 49 | } 50 | } 51 | 52 | EXPORT void vec_znx_sub_avx(const MODULE* module, // N 53 | int64_t* res, uint64_t res_size, uint64_t res_sl, // res 54 | const int64_t* a, uint64_t a_size, uint64_t a_sl, // a 55 | const int64_t* b, uint64_t b_size, uint64_t b_sl // b 56 | ) { 57 | const uint64_t nn = module->nn; 58 | if (a_size <= b_size) { 59 | const uint64_t sub_idx = res_size < a_size ? res_size : a_size; 60 | const uint64_t copy_idx = res_size < b_size ? 
res_size : b_size; 61 | // subtract up to the smallest dimension 62 | for (uint64_t i = 0; i < sub_idx; ++i) { 63 | znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); 64 | } 65 | // then negate to the largest dimension 66 | for (uint64_t i = sub_idx; i < copy_idx; ++i) { 67 | znx_negate_i64_avx(nn, res + i * res_sl, b + i * b_sl); 68 | } 69 | // then extend with zeros 70 | for (uint64_t i = copy_idx; i < res_size; ++i) { 71 | znx_zero_i64_avx(nn, res + i * res_sl); 72 | } 73 | } else { 74 | const uint64_t sub_idx = res_size < b_size ? res_size : b_size; 75 | const uint64_t copy_idx = res_size < a_size ? res_size : a_size; 76 | // subtract up to the smallest dimension 77 | for (uint64_t i = 0; i < sub_idx; ++i) { 78 | znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); 79 | } 80 | // then copy to the largest dimension 81 | for (uint64_t i = sub_idx; i < copy_idx; ++i) { 82 | znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl); 83 | } 84 | // then extend with zeros 85 | for (uint64_t i = copy_idx; i < res_size; ++i) { 86 | znx_zero_i64_avx(nn, res + i * res_sl); 87 | } 88 | } 89 | } 90 | 91 | EXPORT void vec_znx_negate_avx(const MODULE* module, // N 92 | int64_t* res, uint64_t res_size, uint64_t res_sl, // res 93 | const int64_t* a, uint64_t a_size, uint64_t a_sl // a 94 | ) { 95 | uint64_t nn = module->nn; 96 | uint64_t smin = res_size < a_size ? 
res_size : a_size; 97 | for (uint64_t i = 0; i < smin; ++i) { 98 | znx_negate_i64_avx(nn, res + i * res_sl, a + i * a_sl); 99 | } 100 | for (uint64_t i = smin; i < res_size; ++i) { 101 | znx_zero_i64_ref(nn, res + i * res_sl); 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /spqlios/arithmetic/vec_znx_dft_avx2.c: -------------------------------------------------------------------------------- 1 | #include "vec_znx_arithmetic_private.h" 2 | -------------------------------------------------------------------------------- /spqlios/arithmetic/zn_approxdecomp_ref.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "zn_arithmetic_private.h" 4 | 5 | EXPORT TNDBL_APPROXDECOMP_GADGET* new_tndbl_approxdecomp_gadget(const MOD_Z* module, // 6 | uint64_t k, uint64_t ell) { 7 | if (k * ell > 50) { 8 | return spqlios_error("approx decomposition requested is too precise for doubles"); 9 | } 10 | if (k < 1) { 11 | return spqlios_error("approx decomposition supports k>=1"); 12 | } 13 | TNDBL_APPROXDECOMP_GADGET* res = malloc(sizeof(TNDBL_APPROXDECOMP_GADGET)); 14 | memset(res, 0, sizeof(TNDBL_APPROXDECOMP_GADGET)); 15 | res->k = k; 16 | res->ell = ell; 17 | double add_cst = INT64_C(3) << (51 - k * ell); 18 | for (uint64_t i = 0; i < ell; ++i) { 19 | add_cst += pow(2., -(double)(i * k + 1)); 20 | } 21 | res->add_cst = add_cst; 22 | res->and_mask = (UINT64_C(1) << k) - 1; 23 | res->sub_cst = UINT64_C(1) << (k - 1); 24 | for (uint64_t i = 0; i < ell; ++i) res->rshifts[i] = (ell - 1 - i) * k; 25 | return res; 26 | } 27 | EXPORT void delete_tndbl_approxdecomp_gadget(TNDBL_APPROXDECOMP_GADGET* ptr) { free(ptr); } 28 | 29 | EXPORT int default_init_tndbl_approxdecomp_gadget(const MOD_Z* module, // 30 | TNDBL_APPROXDECOMP_GADGET* res, // 31 | uint64_t k, uint64_t ell) { 32 | return 0; 33 | } 34 | 35 | typedef union { 36 | double dv; 37 | uint64_t uv; 38 | } du_t; 39 | 40 | 
#define IMPL_ixx_approxdecomp_from_tndbl_ref(ITYPE) \ 41 | if (res_size != a_size * gadget->ell) NOT_IMPLEMENTED(); \ 42 | const uint64_t ell = gadget->ell; \ 43 | const double add_cst = gadget->add_cst; \ 44 | const uint8_t* const rshifts = gadget->rshifts; \ 45 | const ITYPE and_mask = gadget->and_mask; \ 46 | const ITYPE sub_cst = gadget->sub_cst; \ 47 | ITYPE* rr = res; \ 48 | const double* aa = a; \ 49 | const double* aaend = a + a_size; \ 50 | while (aa < aaend) { \ 51 | du_t t = {.dv = *aa + add_cst}; \ 52 | for (uint64_t i = 0; i < ell; ++i) { \ 53 | ITYPE v = (ITYPE)(t.uv >> rshifts[i]); \ 54 | *rr = (v & and_mask) - sub_cst; \ 55 | ++rr; \ 56 | } \ 57 | ++aa; \ 58 | } 59 | 60 | /** @brief sets res = gadget_decompose(a) (int8_t* output) */ 61 | EXPORT void default_i8_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N 62 | const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget 63 | int8_t* res, uint64_t res_size, // res (in general, size ell.a_size) 64 | const double* a, uint64_t a_size // 65 | ){IMPL_ixx_approxdecomp_from_tndbl_ref(int8_t)} 66 | 67 | /** @brief sets res = gadget_decompose(a) (int16_t* output) */ 68 | EXPORT void default_i16_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N 69 | const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget 70 | int16_t* res, uint64_t res_size, // res 71 | const double* a, uint64_t a_size // a 72 | ){IMPL_ixx_approxdecomp_from_tndbl_ref(int16_t)} 73 | 74 | /** @brief sets res = gadget_decompose(a) (int32_t* output) */ 75 | EXPORT void default_i32_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N 76 | const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget 77 | int32_t* res, uint64_t res_size, // res 78 | const double* a, uint64_t a_size // a 79 | ) { 80 | IMPL_ixx_approxdecomp_from_tndbl_ref(int32_t) 81 | } 82 | -------------------------------------------------------------------------------- /spqlios/arithmetic/zn_arithmetic_plugin.h: -------------------------------------------------------------------------------- 1 
| #ifndef SPQLIOS_ZN_ARITHMETIC_PLUGIN_H 2 | #define SPQLIOS_ZN_ARITHMETIC_PLUGIN_H 3 | 4 | #include "zn_arithmetic.h" 5 | 6 | typedef typeof(i8_approxdecomp_from_tndbl) I8_APPROXDECOMP_FROM_TNDBL_F; 7 | typedef typeof(i16_approxdecomp_from_tndbl) I16_APPROXDECOMP_FROM_TNDBL_F; 8 | typedef typeof(i32_approxdecomp_from_tndbl) I32_APPROXDECOMP_FROM_TNDBL_F; 9 | typedef typeof(bytes_of_zn32_vmp_pmat) BYTES_OF_ZN32_VMP_PMAT_F; 10 | typedef typeof(zn32_vmp_prepare_contiguous) ZN32_VMP_PREPARE_CONTIGUOUS_F; 11 | typedef typeof(zn32_vmp_apply_i32) ZN32_VMP_APPLY_I32_F; 12 | typedef typeof(zn32_vmp_apply_i16) ZN32_VMP_APPLY_I16_F; 13 | typedef typeof(zn32_vmp_apply_i8) ZN32_VMP_APPLY_I8_F; 14 | typedef typeof(dbl_to_tn32) DBL_TO_TN32_F; 15 | typedef typeof(tn32_to_dbl) TN32_TO_DBL_F; 16 | typedef typeof(dbl_round_to_i32) DBL_ROUND_TO_I32_F; 17 | typedef typeof(i32_to_dbl) I32_TO_DBL_F; 18 | typedef typeof(dbl_round_to_i64) DBL_ROUND_TO_I64_F; 19 | typedef typeof(i64_to_dbl) I64_TO_DBL_F; 20 | 21 | typedef struct z_module_vtable_t Z_MODULE_VTABLE; 22 | struct z_module_vtable_t { 23 | I8_APPROXDECOMP_FROM_TNDBL_F* i8_approxdecomp_from_tndbl; 24 | I16_APPROXDECOMP_FROM_TNDBL_F* i16_approxdecomp_from_tndbl; 25 | I32_APPROXDECOMP_FROM_TNDBL_F* i32_approxdecomp_from_tndbl; 26 | BYTES_OF_ZN32_VMP_PMAT_F* bytes_of_zn32_vmp_pmat; 27 | ZN32_VMP_PREPARE_CONTIGUOUS_F* zn32_vmp_prepare_contiguous; 28 | ZN32_VMP_APPLY_I32_F* zn32_vmp_apply_i32; 29 | ZN32_VMP_APPLY_I16_F* zn32_vmp_apply_i16; 30 | ZN32_VMP_APPLY_I8_F* zn32_vmp_apply_i8; 31 | DBL_TO_TN32_F* dbl_to_tn32; 32 | TN32_TO_DBL_F* tn32_to_dbl; 33 | DBL_ROUND_TO_I32_F* dbl_round_to_i32; 34 | I32_TO_DBL_F* i32_to_dbl; 35 | DBL_ROUND_TO_I64_F* dbl_round_to_i64; 36 | I64_TO_DBL_F* i64_to_dbl; 37 | }; 38 | 39 | #endif // SPQLIOS_ZN_ARITHMETIC_PLUGIN_H 40 | -------------------------------------------------------------------------------- /spqlios/arithmetic/zn_vmp_int16_avx.c: 
-------------------------------------------------------------------------------- 1 | #define INTTYPE int16_t 2 | #define INTSN i16 3 | 4 | #include "zn_vmp_int32_avx.c" 5 | -------------------------------------------------------------------------------- /spqlios/arithmetic/zn_vmp_int16_ref.c: -------------------------------------------------------------------------------- 1 | #define INTTYPE int16_t 2 | #define INTSN i16 3 | 4 | #include "zn_vmp_int32_ref.c" 5 | -------------------------------------------------------------------------------- /spqlios/arithmetic/zn_vmp_int32_ref.c: -------------------------------------------------------------------------------- 1 | // This file is actually a template: it will be compiled multiple times with 2 | // different INTTYPES 3 | #ifndef INTTYPE 4 | #define INTTYPE int32_t 5 | #define INTSN i32 6 | #endif 7 | 8 | #include 9 | 10 | #include "zn_arithmetic_private.h" 11 | 12 | #define concat_inner(aa, bb, cc) aa##_##bb##_##cc 13 | #define concat(aa, bb, cc) concat_inner(aa, bb, cc) 14 | #define zn32_vec_fn(cc) concat(zn32_vec, INTSN, cc) 15 | 16 | // the ref version shares the same implementation for each fixed column size 17 | // optimized implementations may do something different. 
18 | static __always_inline void IMPL_zn32_vec_matcols_ref( 19 | const uint64_t NCOLS, // fixed number of columns 20 | uint64_t nrows, // nrows of b 21 | int32_t* res, // result: size NCOLS, only the first min(b_sl, NCOLS) are relevant 22 | const INTTYPE* a, // a: nrows-sized vector 23 | const int32_t* b, uint64_t b_sl // b: nrows * min(b_sl, NCOLS) matrix 24 | ) { 25 | memset(res, 0, NCOLS * sizeof(int32_t)); 26 | for (uint64_t row = 0; row < nrows; ++row) { 27 | int32_t ai = a[row]; 28 | const int32_t* bb = b + row * b_sl; 29 | for (uint64_t i = 0; i < NCOLS; ++i) { 30 | res[i] += ai * bb[i]; 31 | } 32 | } 33 | } 34 | 35 | void zn32_vec_fn(mat32cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { 36 | IMPL_zn32_vec_matcols_ref(32, nrows, res, a, b, b_sl); 37 | } 38 | void zn32_vec_fn(mat24cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { 39 | IMPL_zn32_vec_matcols_ref(24, nrows, res, a, b, b_sl); 40 | } 41 | void zn32_vec_fn(mat16cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { 42 | IMPL_zn32_vec_matcols_ref(16, nrows, res, a, b, b_sl); 43 | } 44 | void zn32_vec_fn(mat8cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { 45 | IMPL_zn32_vec_matcols_ref(8, nrows, res, a, b, b_sl); 46 | } 47 | 48 | typedef void (*vm_f)(uint64_t nrows, // 49 | int32_t* res, // 50 | const INTTYPE* a, // 51 | const int32_t* b, uint64_t b_sl // 52 | ); 53 | static const vm_f zn32_vec_mat8kcols_ref[4] = { // 54 | zn32_vec_fn(mat8cols_ref), // 55 | zn32_vec_fn(mat16cols_ref), // 56 | zn32_vec_fn(mat24cols_ref), // 57 | zn32_vec_fn(mat32cols_ref)}; 58 | 59 | /** @brief applies a vmp product (int32_t* input) */ 60 | EXPORT void concat(default_zn32_vmp_apply, INTSN, ref)( // 61 | const MOD_Z* module, // 62 | int32_t* res, uint64_t res_size, // 63 | const INTTYPE* a, uint64_t a_size, // 64 | const ZN32_VMP_PMAT* pmat, uint64_t 
nrows, uint64_t ncols) { 65 | const uint64_t rows = a_size < nrows ? a_size : nrows; 66 | const uint64_t cols = res_size < ncols ? res_size : ncols; 67 | const uint64_t ncolblk = cols >> 5; 68 | const uint64_t ncolrem = cols & 31; 69 | // copy the first full blocks 70 | const uint32_t full_blk_size = nrows * 32; 71 | const int32_t* mat = (int32_t*)pmat; 72 | int32_t* rr = res; 73 | for (uint64_t blk = 0; // 74 | blk < ncolblk; // 75 | ++blk, mat += full_blk_size, rr += 32) { 76 | zn32_vec_fn(mat32cols_ref)(rows, rr, a, mat, 32); 77 | } 78 | // last block 79 | if (ncolrem) { 80 | uint64_t orig_rem = ncols - (ncolblk << 5); 81 | uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem; 82 | int32_t tmp[32]; 83 | zn32_vec_mat8kcols_ref[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl); 84 | memcpy(rr, tmp, ncolrem * sizeof(int32_t)); 85 | } 86 | // trailing bytes 87 | memset(res + cols, 0, (res_size - cols) * sizeof(int32_t)); 88 | } 89 | -------------------------------------------------------------------------------- /spqlios/arithmetic/zn_vmp_int8_avx.c: -------------------------------------------------------------------------------- 1 | #define INTTYPE int8_t 2 | #define INTSN i8 3 | 4 | #include "zn_vmp_int32_avx.c" 5 | -------------------------------------------------------------------------------- /spqlios/arithmetic/zn_vmp_int8_ref.c: -------------------------------------------------------------------------------- 1 | #define INTTYPE int8_t 2 | #define INTSN i8 3 | 4 | #include "zn_vmp_int32_ref.c" 5 | -------------------------------------------------------------------------------- /spqlios/arithmetic/znx_small.c: -------------------------------------------------------------------------------- 1 | #include "vec_znx_arithmetic_private.h" 2 | 3 | /** @brief res = a * b : small integer polynomial product */ 4 | EXPORT void fft64_znx_small_single_product(const MODULE* module, // N 5 | int64_t* res, // output 6 | const int64_t* a, // a 7 | const int64_t* b, // b 8 | uint8_t* tmp) 
{ 9 | const uint64_t nn = module->nn; 10 | double* const ffta = (double*)tmp; 11 | double* const fftb = ((double*)tmp) + nn; 12 | reim_from_znx64(module->mod.fft64.p_conv, ffta, a); 13 | reim_from_znx64(module->mod.fft64.p_conv, fftb, b); 14 | reim_fft(module->mod.fft64.p_fft, ffta); 15 | reim_fft(module->mod.fft64.p_fft, fftb); 16 | reim_fftvec_mul_simple(module->m, ffta, ffta, fftb); 17 | reim_ifft(module->mod.fft64.p_ifft, ffta); 18 | reim_to_znx64(module->mod.fft64.p_reim_to_znx, res, ffta); 19 | } 20 | 21 | /** @brief tmp bytes required for znx_small_single_product */ 22 | EXPORT uint64_t fft64_znx_small_single_product_tmp_bytes(const MODULE* module) { 23 | return 2 * module->nn * sizeof(double); 24 | } 25 | 26 | /** @brief res = a * b : small integer polynomial product */ 27 | EXPORT void znx_small_single_product(const MODULE* module, // N 28 | int64_t* res, // output 29 | const int64_t* a, // a 30 | const int64_t* b, // b 31 | uint8_t* tmp) { 32 | module->func.znx_small_single_product(module, res, a, b, tmp); 33 | } 34 | 35 | /** @brief tmp bytes required for znx_small_single_product */ 36 | EXPORT uint64_t znx_small_single_product_tmp_bytes(const MODULE* module) { 37 | return module->func.znx_small_single_product_tmp_bytes(module); 38 | } 39 | -------------------------------------------------------------------------------- /spqlios/coeffs/coeffs_arithmetic.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_COEFFS_ARITHMETIC_H 2 | #define SPQLIOS_COEFFS_ARITHMETIC_H 3 | 4 | #include "../commons.h" 5 | 6 | /** res = a + b */ 7 | EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b); 8 | EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b); 9 | /** res = a - b */ 10 | EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b); 11 | EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, 
const int64_t* b); 12 | /** res = -a */ 13 | EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a); 14 | EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a); 15 | /** res = a */ 16 | EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a); 17 | /** res = 0 */ 18 | EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res); 19 | 20 | /** res = a / m where m is a power of 2 */ 21 | EXPORT void rnx_divide_by_m_ref(uint64_t nn, double m, double* res, const double* a); 22 | EXPORT void rnx_divide_by_m_avx(uint64_t nn, double m, double* res, const double* a); 23 | 24 | /** 25 | * @param res = X^p *in mod X^nn +1 26 | * @param nn the ring dimension 27 | * @param p a power for the rotation -2nn <= p <= 2nn 28 | * @param in is a rnx/znx vector of dimension nn 29 | */ 30 | EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in); 31 | EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in); 32 | EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res); 33 | EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res); 34 | 35 | /** 36 | * @brief res(X) = in(X^p) 37 | * @param nn the ring dimension 38 | * @param p is odd integer and must be between 0 < p < 2nn 39 | * @param in is a rnx/znx vector of dimension nn 40 | */ 41 | EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in); 42 | EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in); 43 | EXPORT void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res); 44 | EXPORT void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res); 45 | 46 | /** 47 | * @brief res = (X^p-1).in 48 | * @param nn the ring dimension 49 | * @param p must be between -2nn <= p <= 2nn 50 | * @param in is a rnx/znx vector of dimension nn 51 | */ 52 | EXPORT void rnx_mul_xp_minus_one(uint64_t nn, int64_t p, double* res, 
const double* in); 53 | EXPORT void znx_mul_xp_minus_one(uint64_t nn, int64_t p, int64_t* res, const int64_t* in); 54 | EXPORT void rnx_mul_xp_minus_one_inplace(uint64_t nn, int64_t p, double* res); 55 | 56 | /** 57 | * @brief Normalize input plus carry mod-2^k. The following 58 | * equality holds @c {in + carry_in == out + carry_out . 2^k}. 59 | * 60 | * @c in must be in [-2^62 .. 2^62] 61 | * 62 | * @c out is in [ -2^(base_k-1), 2^(base_k-1) [. 63 | * 64 | * @c carry_in and @carry_out have at most 64+1-k bits. 65 | * 66 | * Null @c carry_in or @c carry_out are ignored. 67 | * 68 | * @param[in] nn the ring dimension 69 | * @param[in] base_k the base k 70 | * @param out output normalized znx 71 | * @param carry_out output carry znx 72 | * @param[in] in input znx 73 | * @param[in] carry_in input carry znx 74 | */ 75 | EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in, 76 | const int64_t* carry_in); 77 | 78 | #endif // SPQLIOS_COEFFS_ARITHMETIC_H 79 | -------------------------------------------------------------------------------- /spqlios/coeffs/coeffs_arithmetic_avx.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "../commons_private.h" 4 | #include "coeffs_arithmetic.h" 5 | 6 | // res = a + b. 
// res = a + b over nn int64 coefficients; nn must be a power of 2.
EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
  if (nn <= 2) {
    if (nn == 1) {
      // single coefficient: plain scalar add
      res[0] = a[0] + b[0];
    } else {
      // two coefficients: one 128-bit lane
      const __m128i va = _mm_loadu_si128((const __m128i*)a);
      const __m128i vb = _mm_loadu_si128((const __m128i*)b);
      _mm_storeu_si128((__m128i*)res, _mm_add_epi64(va, vb));
    }
  } else {
    // nn is a power of two >= 4: four coefficients per 256-bit iteration
    for (uint64_t i = 0; i < nn; i += 4) {
      const __m256i va = _mm256_loadu_si256((const __m256i*)(a + i));
      const __m256i vb = _mm256_loadu_si256((const __m256i*)(b + i));
      _mm256_storeu_si256((__m256i*)(res + i), _mm256_add_epi64(va, vb));
    }
  }
}
/** @brief res = a / m, vectorized; m is expected to be a power of 2
 *  (so multiplying by 1/m is exact). n must be a power of two. */
EXPORT void rnx_divide_by_m_avx(uint64_t n, double m, double* res, const double* a) {
  // TODO: see if there is a faster way of dividing by a power of 2?
  const double invm = 1. / m;
  if (n < 8) {
    // explicit small sizes; anything else below 8 is not a power of two
    if (n == 1) {
      res[0] = a[0] * invm;
    } else if (n == 2) {
      _mm_storeu_pd(res, _mm_mul_pd(_mm_loadu_pd(a), _mm_set1_pd(invm)));
    } else if (n == 4) {
      _mm256_storeu_pd(res, _mm256_mul_pd(_mm256_loadu_pd(a), _mm256_set1_pd(invm)));
    } else {
      NOT_SUPPORTED();  // non-power of 2
    }
    return;
  }
  // main loop: 8 doubles (two 256-bit lanes) per iteration
  const __m256d vinv = _mm256_set1_pd(invm);
  for (uint64_t i = 0; i < n; i += 8) {
    _mm256_storeu_pd(res + i, _mm256_mul_pd(_mm256_loadu_pd(a + i), vinv));
    _mm256_storeu_pd(res + i + 4, _mm256_mul_pd(_mm256_loadu_pd(a + i + 4), vinv));
  }
}
\ 33 | } 34 | 35 | EXPORT void* UNDEFINED_p_ii(int32_t n, int32_t m); 36 | EXPORT void* UNDEFINED_p_uu(uint32_t n, uint32_t m); 37 | EXPORT double* UNDEFINED_dp_pi(const void* p, int32_t n); 38 | EXPORT void* UNDEFINED_vp_pi(const void* p, int32_t n); 39 | EXPORT void* UNDEFINED_vp_pu(const void* p, uint32_t n); 40 | EXPORT void UNDEFINED_v_vpdp(const void* p, double* a); 41 | EXPORT void UNDEFINED_v_vpvp(const void* p, void* a); 42 | EXPORT double* NOT_IMPLEMENTED_dp_i(int32_t n); 43 | EXPORT void* NOT_IMPLEMENTED_vp_i(int32_t n); 44 | EXPORT void* NOT_IMPLEMENTED_vp_u(uint32_t n); 45 | EXPORT void NOT_IMPLEMENTED_v_dp(double* a); 46 | EXPORT void NOT_IMPLEMENTED_v_vp(void* p); 47 | EXPORT void NOT_IMPLEMENTED_v_idpdpdp(int32_t n, double* a, const double* b, const double* c); 48 | EXPORT void NOT_IMPLEMENTED_v_uvpcvpcvp(uint32_t n, void* r, const void* a, const void* b); 49 | EXPORT void NOT_IMPLEMENTED_v_uvpvpcvp(uint32_t n, void* a, void* b, const void* o); 50 | 51 | // windows 52 | 53 | #if defined(_WIN32) || defined(__APPLE__) 54 | #define __always_inline inline __attribute((always_inline)) 55 | #endif 56 | 57 | EXPORT void spqlios_free(void* address); 58 | 59 | EXPORT void* spqlios_alloc(uint64_t size); 60 | EXPORT void* spqlios_alloc_custom_align(uint64_t align, uint64_t size); 61 | 62 | #define USE_LIBM_SIN_COS 63 | #ifndef USE_LIBM_SIN_COS 64 | // if at some point, we want to remove the libm dependency, we can 65 | // consider this: 66 | EXPORT double internal_accurate_cos(double x); 67 | EXPORT double internal_accurate_sin(double x); 68 | EXPORT void internal_accurate_sincos(double* rcos, double* rsin, double x); 69 | #define m_accurate_cos internal_accurate_cos 70 | #define m_accurate_sin internal_accurate_sin 71 | #else 72 | // let's use libm sin and cos 73 | #define m_accurate_cos cos 74 | #define m_accurate_sin sin 75 | #endif 76 | 77 | #endif // SPQLIOS_COMMONS_H 78 | -------------------------------------------------------------------------------- 
/** @brief returns 0 iff the pointed-to double is (up to sign) a power of two.
 *
 * An IEEE-754 binary64 value is a power of two exactly when its 52-bit
 * fraction field is all zero; the sign and exponent fields are ignored here.
 * Fix: the previous mask 0x7FFFFFFFFFFFFUL covered only 51 bits, so values
 * whose fraction has just the top bit set (e.g. 1.5, 3.0) were wrongly
 * classified as powers of two.
 */
EXPORT uint64_t is_not_pow2_double(void* doublevalue) {
  // 0xFFFFFFFFFFFFF = the 52 low bits = the full binary64 fraction field
  return (*(uint64_t*)doublevalue) & UINT64_C(0xFFFFFFFFFFFFF);
}
/** @brief smallest multiple of 64 greater than or equal to size */
uint64_t ceilto64b(uint64_t size) {
  // round up by adding 63 then clearing the 6 low bits
  const uint64_t mask = UINT64_C(63);
  return (size + mask) & ~mask;
}
uint32_t revbits(uint32_t nbits, uint32_t value); 60 | 61 | /** 62 | * @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,... 63 | * essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/ 64 | EXPORT double fracrevbits(uint32_t i); 65 | 66 | /** @brief smallest multiple of 64 higher or equal to size */ 67 | EXPORT uint64_t ceilto64b(uint64_t size); 68 | 69 | /** @brief smallest multiple of 32 higher or equal to size */ 70 | EXPORT uint64_t ceilto32b(uint64_t size); 71 | 72 | #endif // SPQLIOS_COMMONS_PRIVATE_H 73 | -------------------------------------------------------------------------------- /spqlios/cplx/README.md: -------------------------------------------------------------------------------- 1 | In this folder, we deal with the full complex FFT in `C[X] mod X^M-i`. 2 | One complex is represented by two consecutive doubles `(real,imag)` 3 | Note that a real polynomial sum_{j=0}^{N-1} p_j.X^j mod X^N+1 4 | corresponds to the complex polynomial of half degree `M=N/2`: 5 | `sum_{j=0}^{M-1} (p_{j} + i.p_{j+M}) X^j mod X^M-i` 6 | 7 | For a complex polynomial A(X) sum c_i X^i of degree M-1 8 | or a real polynomial sum a_i X^i of degree N 9 | 10 | coefficient space: 11 | a_0,a_M,a_1,a_{M+1},...,a_{M-1},a_{2M-1} 12 | or equivalently 13 | Re(c_0),Im(c_0),Re(c_1),Im(c_1),...Re(c_{M-1}),Im(c_{M-1}) 14 | 15 | eval space: 16 | c(omega_{0}),...,c(omega_{M-1}) 17 | 18 | where 19 | omega_j = omega^{1+rev_{2N}(j)} 20 | and omega = exp(i.pi/N) 21 | 22 | rev_{2N}(j) is the number that has the log2(2N) bits of j in reverse order. 
-------------------------------------------------------------------------------- /spqlios/cplx/cplx_common.c: -------------------------------------------------------------------------------- 1 | #include "cplx_fft_internal.h" 2 | 3 | void cplx_set(CPLX r, const CPLX a) { 4 | r[0] = a[0]; 5 | r[1] = a[1]; 6 | } 7 | void cplx_neg(CPLX r, const CPLX a) { 8 | r[0] = -a[0]; 9 | r[1] = -a[1]; 10 | } 11 | void cplx_add(CPLX r, const CPLX a, const CPLX b) { 12 | r[0] = a[0] + b[0]; 13 | r[1] = a[1] + b[1]; 14 | } 15 | void cplx_sub(CPLX r, const CPLX a, const CPLX b) { 16 | r[0] = a[0] - b[0]; 17 | r[1] = a[1] - b[1]; 18 | } 19 | void cplx_mul(CPLX r, const CPLX a, const CPLX b) { 20 | double re = a[0] * b[0] - a[1] * b[1]; 21 | r[1] = a[0] * b[1] + a[1] * b[0]; 22 | r[0] = re; 23 | } 24 | 25 | /** 26 | * @brief splits 2h evaluations of one polynomials into 2 times h evaluations of even/odd polynomial 27 | * Input: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y) 28 | * Output: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z) 29 | * where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z 30 | * @param h number of "coefficients" h >= 1 31 | * @param data 2h complex coefficients interleaved and 256b aligned 32 | * @param powom y represented as (yre,yim) 33 | */ 34 | EXPORT void cplx_split_fft_ref(int32_t h, CPLX* data, const CPLX powom) { 35 | CPLX* d0 = data; 36 | CPLX* d1 = data + h; 37 | for (uint64_t i = 0; i < h; ++i) { 38 | CPLX diff; 39 | cplx_sub(diff, d0[i], d1[i]); 40 | cplx_add(d0[i], d0[i], d1[i]); 41 | cplx_mul(d1[i], diff, powom); 42 | } 43 | } 44 | 45 | /** 46 | * @brief Do two layers of itwiddle (i.e. split). 
47 | * Input/output: d0,d1,d2,d3 of length h 48 | * Algo: 49 | * itwiddle(d0,d1,om[0]),itwiddle(d2,d3,i.om[0]) 50 | * itwiddle(d0,d2,om[1]),itwiddle(d1,d3,om[1]) 51 | * @param h number of "coefficients" h >= 1 52 | * @param data 4h complex coefficients interleaved and 256b aligned 53 | * @param powom om[0] (re,im) and om[1] where om[1]=om[0]^2 54 | */ 55 | EXPORT void cplx_bisplit_fft_ref(int32_t h, CPLX* data, const CPLX powom[2]) { 56 | CPLX* d0 = data; 57 | CPLX* d2 = data + 2*h; 58 | const CPLX* om0 = powom; 59 | CPLX iom0; 60 | iom0[0]=powom[0][1]; 61 | iom0[1]=-powom[0][0]; 62 | const CPLX* om1 = powom+1; 63 | cplx_split_fft_ref(h, d0, *om0); 64 | cplx_split_fft_ref(h, d2, iom0); 65 | cplx_split_fft_ref(2*h, d0, *om1); 66 | } 67 | 68 | /** 69 | * Input: Q(y),Q(-y) 70 | * Output: P_0(z),P_1(z) 71 | * where Q(X)=P_0(X^2)+X.P_1(X^2) and y^2 = z 72 | * @param data 2 complexes coefficients interleaved and 256b aligned 73 | * @param powom (z,-z) interleaved: (zre,zim,-zre,-zim) 74 | */ 75 | void split_fft_last_ref(CPLX* data, const CPLX powom) { 76 | CPLX diff; 77 | cplx_sub(diff, data[0], data[1]); 78 | cplx_add(data[0], data[0], data[1]); 79 | cplx_mul(data[1], diff, powom); 80 | } 81 | -------------------------------------------------------------------------------- /spqlios/cplx/cplx_execute.c: -------------------------------------------------------------------------------- 1 | #include "cplx_fft_internal.h" 2 | #include "cplx_fft_private.h" 3 | 4 | EXPORT void cplx_from_znx32(const CPLX_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a) { 5 | tables->function(tables, r, a); 6 | } 7 | EXPORT void cplx_from_tnx32(const CPLX_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a) { 8 | tables->function(tables, r, a); 9 | } 10 | EXPORT void cplx_to_tnx32(const CPLX_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a) { 11 | tables->function(tables, r, a); 12 | } 13 | EXPORT void cplx_fftvec_mul(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const 
void* b) { 14 | tables->function(tables, r, a, b); 15 | } 16 | EXPORT void cplx_fftvec_addmul(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b) { 17 | tables->function(tables, r, a, b); 18 | } 19 | -------------------------------------------------------------------------------- /spqlios/cplx/cplx_fallbacks_aarch64.c: -------------------------------------------------------------------------------- 1 | #include "cplx_fft_internal.h" 2 | #include "cplx_fft_private.h" 3 | 4 | EXPORT void cplx_fftvec_addmul_fma(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b) { 5 | UNDEFINED(); // not defined for non x86 targets 6 | } 7 | EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b) { 8 | UNDEFINED(); 9 | } 10 | EXPORT void cplx_fftvec_addmul_sse(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) { 11 | UNDEFINED(); 12 | } 13 | EXPORT void cplx_fftvec_addmul_avx512(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, 14 | const void* b) { 15 | UNDEFINED(); 16 | } 17 | EXPORT void cplx_fft16_avx_fma(void* data, const void* omega) { UNDEFINED(); } 18 | EXPORT void cplx_ifft16_avx_fma(void* data, const void* omega) { UNDEFINED(); } 19 | EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) { UNDEFINED(); } 20 | EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) { UNDEFINED(); } 21 | EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* x, const void* c) { UNDEFINED(); } 22 | EXPORT void cplx_fft_avx2_fma(const CPLX_FFT_PRECOMP* tables, void* data){UNDEFINED()} EXPORT 23 | void cplx_ifft_avx2_fma(const CPLX_IFFT_PRECOMP* itables, void* data){UNDEFINED()} EXPORT 24 | void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om){ 25 | UNDEFINED()} 
EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, 26 | const void* om){UNDEFINED()} EXPORT 27 | void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice, 28 | const void* om){UNDEFINED()} EXPORT 29 | void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice, 30 | const void* om){UNDEFINED()} 31 | 32 | // DEPRECATED? 33 | EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b){UNDEFINED()} EXPORT 34 | void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b){UNDEFINED()} EXPORT 35 | void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a){UNDEFINED()} 36 | 37 | // executors 38 | //EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* itables, void* data) { 39 | // itables->function(itables, data); 40 | //} 41 | //EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data) { tables->function(tables, data); } 42 | -------------------------------------------------------------------------------- /spqlios/cplx/cplx_fft_asserts.c: -------------------------------------------------------------------------------- 1 | #include "cplx_fft_private.h" 2 | #include "../commons_private.h" 3 | 4 | __always_inline void my_asserts() { 5 | STATIC_ASSERT(sizeof(FFT_FUNCTION)==8); 6 | STATIC_ASSERT(sizeof(CPLX_FFT_PRECOMP)==40); 7 | STATIC_ASSERT(sizeof(CPLX_IFFT_PRECOMP)==40); 8 | } 9 | -------------------------------------------------------------------------------- /spqlios/cplx/cplx_fft_private.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_CPLX_FFT_PRIVATE_H 2 | #define SPQLIOS_CPLX_FFT_PRIVATE_H 3 | 4 | #include "cplx_fft.h" 5 | 6 | typedef struct cplx_twiddle_precomp CPLX_FFTVEC_TWIDDLE_PRECOMP; 7 | typedef struct cplx_bitwiddle_precomp CPLX_FFTVEC_BITWIDDLE_PRECOMP; 8 | 9 | typedef void (*IFFT_FUNCTION)(const CPLX_IFFT_PRECOMP*, void*); 
10 | typedef void (*FFT_FUNCTION)(const CPLX_FFT_PRECOMP*, void*); 11 | // conversions 12 | typedef void (*FROM_ZNX32_FUNCTION)(const CPLX_FROM_ZNX32_PRECOMP*, void*, const int32_t*); 13 | typedef void (*TO_ZNX32_FUNCTION)(const CPLX_FROM_ZNX32_PRECOMP*, int32_t*, const void*); 14 | typedef void (*FROM_TNX32_FUNCTION)(const CPLX_FROM_TNX32_PRECOMP*, void*, const int32_t*); 15 | typedef void (*TO_TNX32_FUNCTION)(const CPLX_TO_TNX32_PRECOMP*, int32_t*, const void*); 16 | typedef void (*FROM_RNX64_FUNCTION)(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x); 17 | typedef void (*TO_RNX64_FUNCTION)(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x); 18 | typedef void (*ROUND_TO_RNX64_FUNCTION)(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x); 19 | // fftvec operations 20 | typedef void (*FFTVEC_MUL_FUNCTION)(const CPLX_FFTVEC_MUL_PRECOMP*, void*, const void*, const void*); 21 | typedef void (*FFTVEC_ADDMUL_FUNCTION)(const CPLX_FFTVEC_ADDMUL_PRECOMP*, void*, const void*, const void*); 22 | 23 | typedef void (*FFTVEC_TWIDDLE_FUNCTION)(const CPLX_FFTVEC_TWIDDLE_PRECOMP*, void*, const void*, const void*); 24 | typedef void (*FFTVEC_BITWIDDLE_FUNCTION)(const CPLX_FFTVEC_BITWIDDLE_PRECOMP*, void*, uint64_t, const void*); 25 | 26 | struct cplx_ifft_precomp { 27 | IFFT_FUNCTION function; 28 | int64_t m; 29 | uint64_t buf_size; 30 | double* powomegas; 31 | void* aligned_buffers; 32 | }; 33 | 34 | struct cplx_fft_precomp { 35 | FFT_FUNCTION function; 36 | int64_t m; 37 | uint64_t buf_size; 38 | double* powomegas; 39 | void* aligned_buffers; 40 | }; 41 | 42 | struct cplx_from_znx32_precomp { 43 | FROM_ZNX32_FUNCTION function; 44 | int64_t m; 45 | }; 46 | 47 | struct cplx_to_znx32_precomp { 48 | TO_ZNX32_FUNCTION function; 49 | int64_t m; 50 | double divisor; 51 | }; 52 | 53 | struct cplx_from_tnx32_precomp { 54 | FROM_TNX32_FUNCTION function; 55 | int64_t m; 56 | }; 57 | 58 | struct cplx_to_tnx32_precomp { 59 | TO_TNX32_FUNCTION 
function; 60 | int64_t m; 61 | double divisor; 62 | }; 63 | 64 | struct cplx_from_rnx64_precomp { 65 | FROM_RNX64_FUNCTION function; 66 | int64_t m; 67 | }; 68 | 69 | struct cplx_to_rnx64_precomp { 70 | TO_RNX64_FUNCTION function; 71 | int64_t m; 72 | double divisor; 73 | }; 74 | 75 | struct cplx_round_to_rnx64_precomp { 76 | ROUND_TO_RNX64_FUNCTION function; 77 | int64_t m; 78 | double divisor; 79 | uint32_t log2bound; 80 | }; 81 | 82 | typedef struct cplx_mul_precomp { 83 | FFTVEC_MUL_FUNCTION function; 84 | int64_t m; 85 | } CPLX_FFTVEC_MUL_PRECOMP; 86 | 87 | typedef struct cplx_addmul_precomp { 88 | FFTVEC_ADDMUL_FUNCTION function; 89 | int64_t m; 90 | } CPLX_FFTVEC_ADDMUL_PRECOMP; 91 | 92 | struct cplx_twiddle_precomp { 93 | FFTVEC_TWIDDLE_FUNCTION function; 94 | int64_t m; 95 | }; 96 | 97 | struct cplx_bitwiddle_precomp { 98 | FFTVEC_BITWIDDLE_FUNCTION function; 99 | int64_t m; 100 | }; 101 | 102 | EXPORT void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om); 103 | EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om); 104 | EXPORT void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice, 105 | const void* om); 106 | EXPORT void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice, 107 | const void* om); 108 | 109 | #endif // SPQLIOS_CPLX_FFT_PRIVATE_H 110 | -------------------------------------------------------------------------------- /spqlios/cplx/cplx_fftvec_ref.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "../commons_private.h" 4 | #include "cplx_fft_internal.h" 5 | #include "cplx_fft_private.h" 6 | 7 | EXPORT void cplx_fftvec_addmul_ref(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) { 8 | const uint32_t m = precomp->m; 9 | const CPLX* aa = (CPLX*)a; 10 
/** @brief heap-allocates and initializes an ADDMUL precomp for size m;
 *  returns NULL (and frees the allocation) if initialization fails. */
EXPORT CPLX_FFTVEC_ADDMUL_PRECOMP* new_cplx_fftvec_addmul_precomp(uint32_t m) {
  // fix: allocate sizeof the ADDMUL struct, not the MUL struct (copy-paste
  // slip; benign only while the two structs happen to share a layout)
  CPLX_FFTVEC_ADDMUL_PRECOMP* r = malloc(sizeof(CPLX_FFTVEC_ADDMUL_PRECOMP));
  return spqlios_keep_or_free(r, init_cplx_fftvec_addmul_precomp(r, m));
}
/** @brief convenience wrapper: r = a .* b over m complex slots.
 *  Lazily builds and caches one MUL precomp per power-of-two size m.
 *  NOTE(review): the lazy init of the static cache is unsynchronized —
 *  looks unsafe if first invoked concurrently from several threads; confirm. */
EXPORT void cplx_fftvec_mul_simple(uint32_t m, void* r, const void* a, const void* b) {
  static CPLX_FFTVEC_MUL_PRECOMP precomps[31] = {0};
  CPLX_FFTVEC_MUL_PRECOMP* const entry = precomps + log2m(m);
  if (entry->function == 0) {
    // first call for this size: build the precomp in place
    if (init_cplx_fftvec_mul_precomp(entry, m) == 0) abort();
  }
  entry->function(entry, r, a, b);
}
13 | */ 14 | 15 | #ifndef MACROF_H 16 | #define MACROF_H 17 | 18 | #include 19 | 20 | // c <= addr x1 21 | #define vload(c, addr) c = vld1q_f64(addr); 22 | // c <= addr interleave 2 23 | #define vload2(c, addr) c = vld2q_f64(addr); 24 | // c <= addr interleave 4 25 | #define vload4(c, addr) c = vld4q_f64(addr); 26 | 27 | #define vstore(addr, c) vst1q_f64(addr, c); 28 | // addr <= c 29 | #define vstore2(addr, c) vst2q_f64(addr, c); 30 | // addr <= c 31 | #define vstore4(addr, c) vst4q_f64(addr, c); 32 | 33 | // c <= addr x2 34 | #define vloadx2(c, addr) c = vld1q_f64_x2(addr); 35 | // c <= addr x3 36 | #define vloadx3(c, addr) c = vld1q_f64_x3(addr); 37 | 38 | // addr <= c 39 | #define vstorex2(addr, c) vst1q_f64_x2(addr, c); 40 | 41 | // c = a - b 42 | #define vfsub(c, a, b) c = vsubq_f64(a, b); 43 | 44 | // c = a + b 45 | #define vfadd(c, a, b) c = vaddq_f64(a, b); 46 | 47 | // c = a * b 48 | #define vfmul(c, a, b) c = vmulq_f64(a, b); 49 | 50 | // c = a * n (n is constant) 51 | #define vfmuln(c, a, n) c = vmulq_n_f64(a, n); 52 | 53 | // Swap from a|b to b|a 54 | #define vswap(c, a) c = vextq_f64(a, a, 1); 55 | 56 | // c = a * b[i] 57 | #define vfmul_lane(c, a, b, i) c = vmulq_laneq_f64(a, b, i); 58 | 59 | // c = 1/a 60 | #define vfinv(c, a) c = vdivq_f64(vdupq_n_f64(1.0), a); 61 | 62 | // c = -a 63 | #define vfneg(c, a) c = vnegq_f64(a); 64 | 65 | #define transpose_f64(a, b, t, ia, ib, it) \ 66 | t.val[it] = a.val[ia]; \ 67 | a.val[ia] = vzip1q_f64(a.val[ia], b.val[ib]); \ 68 | b.val[ib] = vzip2q_f64(t.val[it], b.val[ib]); 69 | 70 | /* 71 | * c = a + jb 72 | * c[0] = a[0] - b[1] 73 | * c[1] = a[1] + b[0] 74 | */ 75 | #define vfcaddj(c, a, b) c = vcaddq_rot90_f64(a, b); 76 | 77 | /* 78 | * c = a - jb 79 | * c[0] = a[0] + b[1] 80 | * c[1] = a[1] - b[0] 81 | */ 82 | #define vfcsubj(c, a, b) c = vcaddq_rot270_f64(a, b); 83 | 84 | // c[0] = c[0] + b[0]*a[0], c[1] = c[1] + b[1]*a[0] 85 | #define vfcmla(c, a, b) c = vcmlaq_f64(c, a, b); 86 | 87 | // c[0] = c[0] - 
b[1]*a[1], c[1] = c[1] + b[0]*a[1] 88 | #define vfcmla_90(c, a, b) c = vcmlaq_rot90_f64(c, a, b); 89 | 90 | // c[0] = c[0] - b[0]*a[0], c[1] = c[1] - b[1]*a[0] 91 | #define vfcmla_180(c, a, b) c = vcmlaq_rot180_f64(c, a, b); 92 | 93 | // c[0] = c[0] + b[1]*a[1], c[1] = c[1] - b[0]*a[1] 94 | #define vfcmla_270(c, a, b) c = vcmlaq_rot270_f64(c, a, b); 95 | 96 | /* 97 | * Complex MUL: c = a*b 98 | * c[0] = a[0]*b[0] - a[1]*b[1] 99 | * c[1] = a[0]*b[1] + a[1]*b[0] 100 | */ 101 | #define FPC_CMUL(c, a, b) \ 102 | c = vmulq_laneq_f64(b, a, 0); \ 103 | c = vcmlaq_rot90_f64(c, a, b); 104 | 105 | /* 106 | * Complex MUL: c = a * conjugate(b) = a * (b[0], -b[1]) 107 | * c[0] = b[0]*a[0] + b[1]*a[1] 108 | * c[1] = + b[0]*a[1] - b[1]*a[0] 109 | */ 110 | #define FPC_CMUL_CONJ(c, a, b) \ 111 | c = vmulq_laneq_f64(a, b, 0); \ 112 | c = vcmlaq_rot270_f64(c, b, a); 113 | 114 | #if FMA == 1 115 | // d = c + a *b 116 | #define vfmla(d, c, a, b) d = vfmaq_f64(c, a, b); 117 | // d = c - a * b 118 | #define vfmls(d, c, a, b) d = vfmsq_f64(c, a, b); 119 | // d = c + a * b[i] 120 | #define vfmla_lane(d, c, a, b, i) d = vfmaq_laneq_f64(c, a, b, i); 121 | // d = c - a * b[i] 122 | #define vfmls_lane(d, c, a, b, i) d = vfmsq_laneq_f64(c, a, b, i); 123 | 124 | #else 125 | // d = c + a *b 126 | #define vfmla(d, c, a, b) d = vaddq_f64(c, vmulq_f64(a, b)); 127 | // d = c - a *b 128 | #define vfmls(d, c, a, b) d = vsubq_f64(c, vmulq_f64(a, b)); 129 | // d = c + a * b[i] 130 | #define vfmla_lane(d, c, a, b, i) \ 131 | d = vaddq_f64(c, vmulq_laneq_f64(a, b, i)); 132 | 133 | #define vfmls_lane(d, c, a, b, i) \ 134 | d = vsubq_f64(c, vmulq_laneq_f64(a, b, i)); 135 | 136 | #endif 137 | 138 | #endif 139 | -------------------------------------------------------------------------------- /spqlios/q120/q120_arithmetic_private.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_Q120_ARITHMETIC_DEF_H 2 | #define SPQLIOS_Q120_ARITHMETIC_DEF_H 3 | 4 | #include 
5 | 6 | typedef struct _q120_mat1col_product_baa_precomp { 7 | uint64_t h; 8 | uint64_t h_pow_red[4]; 9 | #ifndef NDEBUG 10 | double res_bit_size; 11 | #endif 12 | } q120_mat1col_product_baa_precomp; 13 | 14 | typedef struct _q120_mat1col_product_bbb_precomp { 15 | uint64_t h; 16 | uint64_t s1h_pow_red[4]; 17 | uint64_t s2l_pow_red[4]; 18 | uint64_t s2h_pow_red[4]; 19 | uint64_t s3l_pow_red[4]; 20 | uint64_t s3h_pow_red[4]; 21 | uint64_t s4l_pow_red[4]; 22 | uint64_t s4h_pow_red[4]; 23 | #ifndef NDEBUG 24 | double res_bit_size; 25 | #endif 26 | } q120_mat1col_product_bbb_precomp; 27 | 28 | typedef struct _q120_mat1col_product_bbc_precomp { 29 | uint64_t h; 30 | uint64_t s2l_pow_red[4]; 31 | uint64_t s2h_pow_red[4]; 32 | #ifndef NDEBUG 33 | double res_bit_size; 34 | #endif 35 | } q120_mat1col_product_bbc_precomp; 36 | 37 | #endif // SPQLIOS_Q120_ARITHMETIC_DEF_H 38 | -------------------------------------------------------------------------------- /spqlios/q120/q120_common.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_Q120_COMMON_H 2 | #define SPQLIOS_Q120_COMMON_H 3 | 4 | #include 5 | 6 | #if !defined(SPQLIOS_Q120_USE_29_BIT_PRIMES) && !defined(SPQLIOS_Q120_USE_30_BIT_PRIMES) && \ 7 | !defined(SPQLIOS_Q120_USE_31_BIT_PRIMES) 8 | #define SPQLIOS_Q120_USE_30_BIT_PRIMES 9 | #endif 10 | 11 | /** 12 | * 29-bit primes and 2*2^16 roots of unity 13 | */ 14 | #ifdef SPQLIOS_Q120_USE_29_BIT_PRIMES 15 | #define Q1 ((1u << 29) - 2 * (1u << 17) + 1) 16 | #define OMEGA1 78289835 17 | #define Q1_CRT_CST 301701286 // (Q2*Q3*Q4)^-1 mod Q1 18 | 19 | #define Q2 ((1u << 29) - 5 * (1u << 17) + 1) 20 | #define OMEGA2 178519192 21 | #define Q2_CRT_CST 536020447 // (Q1*Q3*Q4)^-1 mod Q2 22 | 23 | #define Q3 ((1u << 29) - 26 * (1u << 17) + 1) 24 | #define OMEGA3 483889678 25 | #define Q3_CRT_CST 86367873 // (Q1*Q2*Q4)^-1 mod Q3 26 | 27 | #define Q4 ((1u << 29) - 35 * (1u << 17) + 1) 28 | #define OMEGA4 239808033 29 | #define 
Q4_CRT_CST 147030781 // (Q1*Q2*Q3)^-1 mod Q4 30 | #endif 31 | 32 | /** 33 | * 30-bit primes and 2*2^16 roots of unity 34 | */ 35 | #ifdef SPQLIOS_Q120_USE_30_BIT_PRIMES 36 | #define Q1 ((1u << 30) - 2 * (1u << 17) + 1) 37 | #define OMEGA1 1070907127 38 | #define Q1_CRT_CST 43599465 // (Q2*Q3*Q4)^-1 mod Q1 39 | 40 | #define Q2 ((1u << 30) - 17 * (1u << 17) + 1) 41 | #define OMEGA2 315046632 42 | #define Q2_CRT_CST 292938863 // (Q1*Q3*Q4)^-1 mod Q2 43 | 44 | #define Q3 ((1u << 30) - 23 * (1u << 17) + 1) 45 | #define OMEGA3 309185662 46 | #define Q3_CRT_CST 594011630 // (Q1*Q2*Q4)^-1 mod Q3 47 | 48 | #define Q4 ((1u << 30) - 42 * (1u << 17) + 1) 49 | #define OMEGA4 846468380 50 | #define Q4_CRT_CST 140177212 // (Q1*Q2*Q3)^-1 mod Q4 51 | #endif 52 | 53 | /** 54 | * 31-bit primes and 2*2^16 roots of unity 55 | */ 56 | #ifdef SPQLIOS_Q120_USE_31_BIT_PRIMES 57 | #define Q1 ((1u << 31) - 1 * (1u << 17) + 1) 58 | #define OMEGA1 1615402923 59 | #define Q1_CRT_CST 1811422063 // (Q2*Q3*Q4)^-1 mod Q1 60 | 61 | #define Q2 ((1u << 31) - 4 * (1u << 17) + 1) 62 | #define OMEGA2 1137738560 63 | #define Q2_CRT_CST 2093150204 // (Q1*Q3*Q4)^-1 mod Q2 64 | 65 | #define Q3 ((1u << 31) - 11 * (1u << 17) + 1) 66 | #define OMEGA3 154880552 67 | #define Q3_CRT_CST 164149010 // (Q1*Q2*Q4)^-1 mod Q3 68 | 69 | #define Q4 ((1u << 31) - 23 * (1u << 17) + 1) 70 | #define OMEGA4 558784885 71 | #define Q4_CRT_CST 225197446 // (Q1*Q2*Q3)^-1 mod Q4 72 | #endif 73 | 74 | static const uint32_t PRIMES_VEC[4] = {Q1, Q2, Q3, Q4}; 75 | static const uint32_t OMEGAS_VEC[4] = {OMEGA1, OMEGA2, OMEGA3, OMEGA4}; 76 | 77 | #define MAX_ELL 10000 78 | 79 | // each number x mod Q120 is represented by uint64_t[4] with (non-unique) values (x mod q1, x mod q2,x mod q3,x mod q4), 80 | // each between [0 and 2^32-1] 81 | typedef struct _q120a q120a; 82 | 83 | // each number x mod Q120 is represented by uint64_t[4] with (non-unique) values (x mod q1, x mod q2,x mod q3,x mod q4), 84 | // each between [0 and 2^64-1] 85 | 
typedef struct _q120b q120b; 86 | 87 | // each number x mod Q120 is represented by uint32_t[8] with values (x mod q1, 2^32x mod q1, x mod q2, 2^32.x mod q2, x 88 | // mod q3, 2^32.x mod q3, x mod q4, 2^32.x mod q4) each between [0 and 2^32-1] 89 | typedef struct _q120c q120c; 90 | 91 | typedef struct _q120x2b q120x2b; 92 | typedef struct _q120x2c q120x2c; 93 | 94 | #endif // SPQLIOS_Q120_COMMON_H 95 | -------------------------------------------------------------------------------- /spqlios/q120/q120_fallbacks_aarch64.c: -------------------------------------------------------------------------------- 1 | #include "q120_ntt_private.h" 2 | 3 | EXPORT void q120_ntt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data) { UNDEFINED(); } 4 | 5 | EXPORT void q120_intt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data) { UNDEFINED(); } 6 | -------------------------------------------------------------------------------- /spqlios/q120/q120_ntt.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_Q120_NTT_H 2 | #define SPQLIOS_Q120_NTT_H 3 | 4 | #include "../commons.h" 5 | #include "q120_common.h" 6 | 7 | typedef struct _q120_ntt_precomp q120_ntt_precomp; 8 | 9 | EXPORT q120_ntt_precomp* q120_new_ntt_bb_precomp(const uint64_t n); 10 | EXPORT void q120_del_ntt_bb_precomp(q120_ntt_precomp* precomp); 11 | 12 | EXPORT q120_ntt_precomp* q120_new_intt_bb_precomp(const uint64_t n); 13 | EXPORT void q120_del_intt_bb_precomp(q120_ntt_precomp* precomp); 14 | 15 | /** 16 | * @brief computes a direct ntt in-place over data. 17 | */ 18 | EXPORT void q120_ntt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data); 19 | 20 | /** 21 | * @brief computes an inverse ntt in-place over data. 
22 | */ 23 | EXPORT void q120_intt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data); 24 | 25 | #endif // SPQLIOS_Q120_NTT_H 26 | -------------------------------------------------------------------------------- /spqlios/q120/q120_ntt_private.h: -------------------------------------------------------------------------------- 1 | #include "q120_ntt.h" 2 | 3 | #ifndef NDEBUG 4 | #define CHECK_BOUNDS 1 5 | #define VERBOSE 6 | #else 7 | #define CHECK_BOUNDS 0 8 | #endif 9 | 10 | #ifndef VERBOSE 11 | #define LOG(...) ; 12 | #else 13 | #define LOG(...) printf(__VA_ARGS__); 14 | #endif 15 | 16 | typedef struct _q120_ntt_step_precomp { 17 | uint64_t q2bs[4]; // q2bs = 2^{bs-31}.q[k] 18 | uint64_t bs; // inputs at this iterations must be in Q_n 19 | uint64_t half_bs; // == ceil(bs/2) 20 | uint64_t mask; // (1< 2 | 3 | #include "reim_fft_private.h" 4 | 5 |
/* NOTE(review): everything between the '<' above and this point was lost in extraction
   (rest of q120_ntt_private.h and the header of this reim conversions file). */
/*
 * Converts 2m int64 coefficients x[] into 2m doubles at r (reim layout: m reals then
 * m imaginaries).  AVX2 has no packed int64 -> double conversion, so an exact bit
 * trick is used instead: add 2^51 (ADD_CST) so every lane is non-negative and fits
 * in 52 bits, OR in the exponent bits of the double 2^52 (EXPO), which makes each
 * lane read as the exact double 2^52 + (x + 2^51), then subtract 3*2^51 = 2^52 + 2^51
 * (SUB_CST) to obtain x as a double.  Exact whenever |x| < 2^51 (the name suggests
 * callers keep |x| < 2^50 -- TODO confirm).  Processes 4 lanes per iteration, so 2m
 * is presumably required to be a multiple of 4 -- TODO confirm callers guarantee it.
 */
void reim_from_znx64_bnd50_fma(const REIM_FROM_ZNX64_PRECOMP* precomp, void* r, const int64_t* x) { 6 | static const double EXPO = INT64_C(1) << 52; 7 | const int64_t ADD_CST = INT64_C(1) << 51; 8 | const double SUB_CST = INT64_C(3) << 51; 9 | 10 | const __m256d SUB_CST_4 = _mm256_set1_pd(SUB_CST); 11 | const __m256i ADD_CST_4 = _mm256_set1_epi64x(ADD_CST); 12 | const __m256d EXPO_4 = _mm256_set1_pd(EXPO); 13 | 14 | double(*out)[4] = (double(*)[4])r; 15 | __m256i* in = (__m256i*)x; 16 | __m256i* inend = (__m256i*)(x + (precomp->m << 1)); 17 | do { 18 | // read the next value 19 | __m256i a = _mm256_loadu_si256(in); 20 | a = _mm256_add_epi64(a, ADD_CST_4); 21 | __m256d ad = _mm256_castsi256_pd(a); 22 | ad = _mm256_or_pd(ad, EXPO_4); 23 | ad = _mm256_sub_pd(ad, SUB_CST_4); 24 | // store the next value 25 | _mm256_storeu_pd(out[0], ad); 26 | ++out; 27 | ++in; 28 | } while (in < inend); 29 | } 30 | 31 | // version where the output norm can be as big as 2^63 32 |
/*
 * Converts 2m doubles at x (reim layout) into 2m int64 at r; each output is an
 * integer close to a[i]/divisor (rounded to nearest -- the first step below adds
 * sign(a)*divisor/2 so that the later mantissa truncation rounds the quotient).
 * This variant supports quotients of magnitude up to ~2^63: since AVX2 has no packed
 * double -> int64 conversion, the IEEE-754 fields are decoded by hand (see the
 * inline comments in the body).
 */
void reim_to_znx64_avx2_bnd63_fma(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* x) { 33 | static const uint64_t SIGN_MASK =
0x8000000000000000UL; 34 | static const uint64_t EXPO_MASK = 0x7FF0000000000000UL; 35 | static const uint64_t MANTISSA_MASK = 0x000FFFFFFFFFFFFFUL; 36 | static const uint64_t MANTISSA_MSB = 0x0010000000000000UL; 37 | const double divisor_bits = precomp->divisor * ((double)(INT64_C(1) << 52)); 38 | const double offset = precomp->divisor / 2.; 39 | 40 | const __m256d SIGN_MASK_4 = _mm256_castsi256_pd(_mm256_set1_epi64x(SIGN_MASK)); 41 | const __m256i EXPO_MASK_4 = _mm256_set1_epi64x(EXPO_MASK); 42 | const __m256i MANTISSA_MASK_4 = _mm256_set1_epi64x(MANTISSA_MASK); 43 | const __m256i MANTISSA_MSB_4 = _mm256_set1_epi64x(MANTISSA_MSB); 44 | const __m256d offset_4 = _mm256_set1_pd(offset); 45 | const __m256i divi_bits_4 = _mm256_castpd_si256(_mm256_set1_pd(divisor_bits)); 46 | 47 | double(*in)[4] = (double(*)[4])x; 48 | __m256i* out = (__m256i*)r; 49 | __m256i* outend = (__m256i*)(r + (precomp->m << 1)); 50 | do { 51 | // read the next value 52 | __m256d a = _mm256_loadu_pd(in[0]); 53 | // a += sign(a) * divisor/2 (so the truncation below rounds the quotient to nearest) 54 | __m256d asign = _mm256_and_pd(a, SIGN_MASK_4); 55 | a = _mm256_add_pd(a, _mm256_or_pd(asign, offset_4)); 56 | // sign: either 0 or -1 (IEEE sign bit broadcast to the whole lane) 57 | __m256i sign_mask = _mm256_castpd_si256(asign); 58 | sign_mask = _mm256_sub_epi64(_mm256_set1_epi64x(0), _mm256_srli_epi64(sign_mask, 63)); 59 |
/*
 * The result magnitude is significand(a) shifted by biased_exponent(a) minus the
 * biased exponent of divisor * 2^52.  Both shift directions are formed; exactly one
 * of the two counts lands in [0, 63] while the other is >= 64, and AVX2 variable
 * shifts (vpsllvq/vpsrlvq) yield 0 for counts above 63, so OR-ing the two partial
 * results below keeps only the meaningful one.
 */
// compute the exponents 60 | __m256i a0exp = _mm256_and_si256(_mm256_castpd_si256(a), EXPO_MASK_4); 61 | __m256i a0lsh = _mm256_sub_epi64(a0exp, divi_bits_4); 62 | __m256i a0rsh = _mm256_sub_epi64(divi_bits_4, a0exp); 63 | a0lsh = _mm256_srli_epi64(a0lsh, 52); 64 | a0rsh = _mm256_srli_epi64(a0rsh, 52); 65 | // compute the new mantissa (restore the implicit leading 1 of the significand) 66 | __m256i a0pos = _mm256_and_si256(_mm256_castpd_si256(a), MANTISSA_MASK_4); 67 | a0pos = _mm256_or_si256(a0pos, MANTISSA_MSB_4); 68 | a0lsh = _mm256_sllv_epi64(a0pos, a0lsh); 69 | a0rsh = _mm256_srlv_epi64(a0pos, a0rsh); 70 | __m256i final = _mm256_or_si256(a0lsh, a0rsh); 71 | // negate if the sign was negative: final = (final ^ sign_mask) - sign_mask (two's complement) 72 | final = _mm256_xor_si256(final,
sign_mask); 73 | final = _mm256_sub_epi64(final, sign_mask); 74 | // store the rounded quotient 75 | _mm256_storeu_si256(out, final); 76 | ++out; 77 | ++in; 78 | } while (out < outend); 79 | } 80 | 81 | // version where the output norm can be as big as 2^50 82 |
/*
 * Same conversion as above (round a[i]/divisor to nearest), restricted to quotients
 * of small magnitude (name suggests < 2^50), which allows a single renormalisation:
 * adding divisor * (2^52 + 2^51) forces the sum onto the fixed exponent of
 * divisor * 2^52, so the low 52 mantissa bits hold round(a/divisor) + 2^51
 * (mod 2^52); masking the mantissa and subtracting 2^51 recovers the signed result
 * (the wrap-around of the subtraction yields correct negatives).  Presumably
 * requires divisor to be a power of two so the sum stays exact -- TODO confirm.
 */
void reim_to_znx64_avx2_bnd50_fma(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* x) { 83 | static const uint64_t MANTISSA_MASK = 0x000FFFFFFFFFFFFFUL; 84 | const int64_t SUB_CST = INT64_C(1) << 51; 85 | const double add_cst = precomp->divisor * ((double)(INT64_C(3) << 51)); 86 | 87 | const __m256i SUB_CST_4 = _mm256_set1_epi64x(SUB_CST); 88 | const __m256d add_cst_4 = _mm256_set1_pd(add_cst); 89 | const __m256i MANTISSA_MASK_4 = _mm256_set1_epi64x(MANTISSA_MASK); 90 | 91 | double(*in)[4] = (double(*)[4])x; 92 | __m256i* out = (__m256i*)r; 93 | __m256i* outend = (__m256i*)(r + (precomp->m << 1)); 94 | do { 95 | // read the next value 96 | __m256d a = _mm256_loadu_pd(in[0]); 97 | a = _mm256_add_pd(a, add_cst_4); 98 | __m256i ai = _mm256_castpd_si256(a); 99 | ai = _mm256_and_si256(ai, MANTISSA_MASK_4); 100 | ai = _mm256_sub_epi64(ai, SUB_CST_4); 101 | // store the next value 102 | _mm256_storeu_si256(out, ai); 103 | ++out; 104 | ++in; 105 | } while (out < outend); 106 | } 107 | -------------------------------------------------------------------------------- /spqlios/reim/reim_execute.c: -------------------------------------------------------------------------------- 1 | #include "reim_fft_internal.h" 2 | #include "reim_fft_private.h" 3 |
/* The entry points below are thin dispatchers: each precomputed table stores in its
   'function' member the implementation selected when the table was built, and these
   wrappers simply forward their arguments to it. */
4 | EXPORT void reim_from_znx32(const REIM_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a) { 5 | tables->function(tables, r, a); 6 | } 7 | 8 | EXPORT void reim_from_tnx32(const REIM_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a) { 9 | tables->function(tables, r, a); 10 | } 11 | 12 | EXPORT void reim_to_tnx32(const REIM_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a) { 13 | tables->function(tables, r, a); 14 | } 15 | 16 | EXPORT void
reim_fftvec_mul(const REIM_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b) { 17 | tables->function(tables, r, a, b); 18 | } 19 | 20 | EXPORT void reim_fftvec_addmul(const REIM_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, const double* b) { 21 | tables->function(tables, r, a, b); 22 | } 23 | -------------------------------------------------------------------------------- /spqlios/reim/reim_fallbacks_aarch64.c: -------------------------------------------------------------------------------- 1 | #include "reim_fft_private.h" 2 | 3 | EXPORT void reim_fftvec_addmul_fma(const REIM_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a, 4 | const double* b) { 5 | UNDEFINED(); 6 | } 7 | EXPORT void reim_fftvec_mul_fma(const REIM_FFTVEC_MUL_PRECOMP* precomp, double* r, const double* a, const double* b) { 8 | UNDEFINED(); 9 | } 10 | 11 | EXPORT void reim_fft_avx2_fma(const REIM_FFT_PRECOMP* tables, double* data) { UNDEFINED(); } 12 | EXPORT void reim_ifft_avx2_fma(const REIM_IFFT_PRECOMP* tables, double* data) { UNDEFINED(); } 13 | 14 | //EXPORT void reim_fft(const REIM_FFT_PRECOMP* tables, double* data) { tables->function(tables, data); } 15 | //EXPORT void reim_ifft(const REIM_IFFT_PRECOMP* tables, double* data) { tables->function(tables, data); } 16 | -------------------------------------------------------------------------------- /spqlios/reim/reim_fft4_avx_fma.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "reim_fft_private.h" 6 | 7 | __always_inline void reim_ctwiddle_avx_fma(__m128d* ra, __m128d* rb, __m128d* ia, __m128d* ib, const __m128d omre, 8 | const __m128d omim) { 9 | // rb * omre - ib * omim; 10 | __m128d rprod0 = _mm_mul_pd(*ib, omim); 11 | rprod0 = _mm_fmsub_pd(*rb, omre, rprod0); 12 | 13 | // rb * omim + ib * omre; 14 | __m128d iprod0 = _mm_mul_pd(*rb, omim); 15 | iprod0 = _mm_fmadd_pd(*ib, omre, iprod0); 16 | 17 | *rb = 
_mm_sub_pd(*ra, rprod0); 18 | *ib = _mm_sub_pd(*ia, iprod0); 19 | *ra = _mm_add_pd(*ra, rprod0); 20 | *ia = _mm_add_pd(*ia, iprod0); 21 | } 22 |
/*
 * Size-4 forward FFT over one reim block: 4 reals at dre, 4 imaginaries at dim,
 * twiddles packed at ompv.  Pass "1" butterflies pairs (0,2) and (1,3) with a single
 * broadcast twiddle; pass "2" de-interleaves to pairs (0,1),(2,3) with unpacklo/hi
 * and applies two twiddles at once -- the second lane's twiddle is i times the
 * first, encoded by the re/im swap plus the sign flip (fft4neg).  Step "4" stores
 * the results back in place.
 */
23 | EXPORT void reim_fft4_avx_fma(double* dre, double* dim, const void* ompv) { 24 | const double* omp = (const double*)ompv; 25 | 26 | __m128d ra0 = _mm_loadu_pd(dre); 27 | __m128d ra2 = _mm_loadu_pd(dre + 2); 28 | __m128d ia0 = _mm_loadu_pd(dim); 29 | __m128d ia2 = _mm_loadu_pd(dim + 2); 30 | 31 | // 1 32 | { 33 | // duplicate omegas in precomp? 34 | __m128d om = _mm_loadu_pd(omp); 35 | __m128d omre = _mm_permute_pd(om, 0); 36 | __m128d omim = _mm_permute_pd(om, 3); 37 | 38 | reim_ctwiddle_avx_fma(&ra0, &ra2, &ia0, &ia2, omre, omim); 39 | } 40 | 41 | // 2 42 | { 43 | const __m128d fft4neg = _mm_castsi128_pd(_mm_set_epi64x(UINT64_C(1) << 63, 0)); 44 | __m128d om = _mm_loadu_pd(omp + 2); // om: r,i 45 | __m128d omim = _mm_permute_pd(om, 1); // omim: i,r 46 | __m128d omre = _mm_xor_pd(om, fft4neg); // omre: r,-i 47 | 48 | __m128d rb = _mm_unpackhi_pd(ra0, ra2); // (r0, r1), (r2, r3) -> (r1, r3) 49 | __m128d ib = _mm_unpackhi_pd(ia0, ia2); // (i0, i1), (i2, i3) -> (i1, i3) 50 | __m128d ra = _mm_unpacklo_pd(ra0, ra2); // (r0, r1), (r2, r3) -> (r0, r2) 51 | __m128d ia = _mm_unpacklo_pd(ia0, ia2); // (i0, i1), (i2, i3) -> (i0, i2) 52 | 53 | reim_ctwiddle_avx_fma(&ra, &rb, &ia, &ib, omre, omim); 54 | 55 | ra0 = _mm_unpacklo_pd(ra, rb); 56 | ia0 = _mm_unpacklo_pd(ia, ib); 57 | ra2 = _mm_unpackhi_pd(ra, rb); 58 | ia2 = _mm_unpackhi_pd(ia, ib); 59 | } 60 | 61 | // 4 62 | _mm_storeu_pd(dre, ra0); 63 | _mm_storeu_pd(dre + 2, ra2); 64 | _mm_storeu_pd(dim, ia0); 65 | _mm_storeu_pd(dim + 2, ia2); 66 | } 67 | -------------------------------------------------------------------------------- /spqlios/reim/reim_fft8_avx_fma.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 |
/* NOTE(review): the three angle-bracket include targets above were lost in extraction. */
5 | #include "reim_fft_private.h" 6 |
/* Forward butterfly with twiddle on 4 complex lanes held as split re/im registers:
   (a, b) -> (a + om*b, a - om*b) with om = omre + i*omim applied lane-wise
   (mirrors the __m128d variant used by reim_fft4_avx_fma). */
7 | __always_inline void reim_ctwiddle_avx_fma(__m256d*
ra, __m256d* rb, __m256d* ia, __m256d* ib, const __m256d omre, 8 | const __m256d omim) { 9 | // rb * omre - ib * omim; 10 | __m256d rprod0 = _mm256_mul_pd(*ib, omim); 11 | rprod0 = _mm256_fmsub_pd(*rb, omre, rprod0); 12 | 13 | // rb * omim + ib * omre; 14 | __m256d iprod0 = _mm256_mul_pd(*rb, omim); 15 | iprod0 = _mm256_fmadd_pd(*ib, omre, iprod0); 16 | 17 | *rb = _mm256_sub_pd(*ra, rprod0); 18 | *ib = _mm256_sub_pd(*ia, iprod0); 19 | *ra = _mm256_add_pd(*ra, rprod0); 20 | *ia = _mm256_add_pd(*ia, iprod0); 21 | } 22 | 23 | EXPORT void reim_fft8_avx_fma(double* dre, double* dim, const void* ompv) { 24 | const double* omp = (const double*)ompv; 25 | 26 | __m256d ra0 = _mm256_loadu_pd(dre); 27 | __m256d ra4 = _mm256_loadu_pd(dre + 4); 28 | __m256d ia0 = _mm256_loadu_pd(dim); 29 | __m256d ia4 = _mm256_loadu_pd(dim + 4); 30 | 31 | // 1 32 | { 33 | // duplicate omegas in precomp? 34 | __m128d omri = _mm_loadu_pd(omp); 35 | __m256d omriri = _mm256_set_m128d(omri, omri); 36 | __m256d omi = _mm256_permute_pd(omriri, 15); 37 | __m256d omr = _mm256_permute_pd(omriri, 0); 38 | 39 | reim_ctwiddle_avx_fma(&ra0, &ra4, &ia0, &ia4, omr, omi); 40 | } 41 | 42 | // 2 43 | { 44 | const __m128d fft8neg = _mm_castsi128_pd(_mm_set_epi64x(UINT64_C(1) << 63, 0)); 45 | __m128d omri = _mm_loadu_pd(omp + 2); // r,i 46 | __m128d omrmi = _mm_xor_pd(omri, fft8neg); // r,-i 47 | __m256d omrirmi = _mm256_set_m128d(omrmi, omri); // r,i,r,-i 48 | __m256d omi = _mm256_permute_pd(omrirmi, 3); // i,i,r,r 49 | __m256d omr = _mm256_permute_pd(omrirmi, 12); // r,r,-i,-i 50 | 51 | __m256d rb = _mm256_permute2f128_pd(ra0, ra4, 0x31); 52 | __m256d ib = _mm256_permute2f128_pd(ia0, ia4, 0x31); 53 | __m256d ra = _mm256_permute2f128_pd(ra0, ra4, 0x20); 54 | __m256d ia = _mm256_permute2f128_pd(ia0, ia4, 0x20); 55 | 56 | reim_ctwiddle_avx_fma(&ra, &rb, &ia, &ib, omr, omi); 57 | 58 | ra0 = _mm256_permute2f128_pd(ra, rb, 0x20); 59 | ra4 = _mm256_permute2f128_pd(ra, rb, 0x31); 60 | ia0 = _mm256_permute2f128_pd(ia, ib, 
0x20); 61 | ia4 = _mm256_permute2f128_pd(ia, ib, 0x31); 62 | } 63 | 64 | // 3 65 | { 66 | const __m256d fft8neg2 = _mm256_castsi256_pd(_mm256_set_epi64x(UINT64_C(1) << 63, UINT64_C(1) << 63, 0, 0)); 67 | __m256d om = _mm256_loadu_pd(omp + 4); // r0,r1,i0,i1 68 | __m256d omi = _mm256_permute2f128_pd(om, om, 1); // i0,i1,r0,r1 69 | __m256d omr = _mm256_xor_pd(om, fft8neg2); // r0,r1,-i0,-i1 70 | 71 | __m256d rb = _mm256_unpackhi_pd(ra0, ra4); 72 | __m256d ib = _mm256_unpackhi_pd(ia0, ia4); 73 | __m256d ra = _mm256_unpacklo_pd(ra0, ra4); 74 | __m256d ia = _mm256_unpacklo_pd(ia0, ia4); 75 | 76 | reim_ctwiddle_avx_fma(&ra, &rb, &ia, &ib, omr, omi); 77 | 78 | ra4 = _mm256_unpackhi_pd(ra, rb); 79 | ia4 = _mm256_unpackhi_pd(ia, ib); 80 | ra0 = _mm256_unpacklo_pd(ra, rb); 81 | ia0 = _mm256_unpacklo_pd(ia, ib); 82 | } 83 | 84 | // 4 85 | _mm256_storeu_pd(dre, ra0); 86 | _mm256_storeu_pd(dre + 4, ra4); 87 | _mm256_storeu_pd(dim, ia0); 88 | _mm256_storeu_pd(dim + 4, ia4); 89 | } 90 | -------------------------------------------------------------------------------- /spqlios/reim/reim_fft_ifft.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "../commons_private.h" 7 | #include "reim_fft_internal.h" 8 | #include "reim_fft_private.h" 9 | 10 | double accurate_cos(int32_t i, int32_t n) { // cos(2pi*i/n) 11 | i = ((i % n) + n) % n; 12 | if (i >= 3 * n / 4) return cos(2. * M_PI * (n - i) / (double)(n)); 13 | if (i >= 2 * n / 4) return -cos(2. * M_PI * (i - n / 2) / (double)(n)); 14 | if (i >= 1 * n / 4) return -cos(2. * M_PI * (n / 2 - i) / (double)(n)); 15 | return cos(2. * M_PI * (i) / (double)(n)); 16 | } 17 | 18 | double accurate_sin(int32_t i, int32_t n) { // sin(2pi*i/n) 19 | i = ((i % n) + n) % n; 20 | if (i >= 3 * n / 4) return -sin(2. * M_PI * (n - i) / (double)(n)); 21 | if (i >= 2 * n / 4) return -sin(2. 
* M_PI * (i - n / 2) / (double)(n)); 22 | if (i >= 1 * n / 4) return sin(2. * M_PI * (n / 2 - i) / (double)(n)); 23 | return sin(2. * M_PI * (i) / (double)(n)); 24 | } 25 | 26 | 27 | EXPORT double* reim_ifft_precomp_get_buffer(const REIM_IFFT_PRECOMP* tables, uint32_t buffer_index) { 28 | return (double*)((uint8_t*) tables->aligned_buffers + buffer_index * tables->buf_size); 29 | } 30 | 31 | EXPORT double* reim_fft_precomp_get_buffer(const REIM_FFT_PRECOMP* tables, uint32_t buffer_index) { 32 | return (double*)((uint8_t*) tables->aligned_buffers + buffer_index * tables->buf_size); 33 | } 34 | 35 | 36 | EXPORT void reim_fft(const REIM_FFT_PRECOMP* tables, double* data) { tables->function(tables, data); } 37 | EXPORT void reim_ifft(const REIM_IFFT_PRECOMP* tables, double* data) { tables->function(tables, data); } 38 | -------------------------------------------------------------------------------- /spqlios/reim/reim_fft_private.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_REIM_FFT_PRIVATE_H 2 | #define SPQLIOS_REIM_FFT_PRIVATE_H 3 | 4 | #include "../commons_private.h" 5 | #include "reim_fft.h" 6 | 7 | #define STATIC_ASSERT(condition) (void)sizeof(char[-1 + 2 * !!(condition)]) 8 | 9 | typedef struct reim_twiddle_precomp REIM_FFTVEC_TWIDDLE_PRECOMP; 10 | typedef struct reim_bitwiddle_precomp REIM_FFTVEC_BITWIDDLE_PRECOMP; 11 | 12 | typedef void (*FFT_FUNC)(const REIM_FFT_PRECOMP*, double*); 13 | typedef void (*IFFT_FUNC)(const REIM_IFFT_PRECOMP*, double*); 14 | typedef void (*FFTVEC_MUL_FUNC)(const REIM_FFTVEC_MUL_PRECOMP*, double*, const double*, const double*); 15 | typedef void (*FFTVEC_ADDMUL_FUNC)(const REIM_FFTVEC_ADDMUL_PRECOMP*, double*, const double*, const double*); 16 | 17 | typedef void (*FROM_ZNX32_FUNC)(const REIM_FROM_ZNX32_PRECOMP*, void*, const int32_t*); 18 | typedef void (*FROM_ZNX64_FUNC)(const REIM_FROM_ZNX64_PRECOMP*, void*, const int64_t*); 19 | typedef void (*FROM_TNX32_FUNC)(const 
REIM_FROM_TNX32_PRECOMP*, void*, const int32_t*); 20 | typedef void (*TO_TNX32_FUNC)(const REIM_TO_TNX32_PRECOMP*, int32_t*, const void*); 21 | typedef void (*TO_TNX_FUNC)(const REIM_TO_TNX_PRECOMP*, double*, const double*); 22 | typedef void (*TO_ZNX64_FUNC)(const REIM_TO_ZNX64_PRECOMP*, int64_t*, const void*); 23 | typedef void (*FFTVEC_TWIDDLE_FUNC)(const REIM_FFTVEC_TWIDDLE_PRECOMP*, void*, const void*, const void*); 24 | typedef void (*FFTVEC_BITWIDDLE_FUNC)(const REIM_FFTVEC_BITWIDDLE_PRECOMP*, void*, uint64_t, const void*); 25 | 26 | typedef struct reim_fft_precomp { 27 | FFT_FUNC function; 28 | int64_t m; ///< complex dimension warning: reim uses n=2N=4m 29 | uint64_t buf_size; ///< size of aligned_buffers (multiple of 64B) 30 | double* powomegas; ///< 64B aligned 31 | void* aligned_buffers; ///< 64B aligned 32 | } REIM_FFT_PRECOMP; 33 | 34 | typedef struct reim_ifft_precomp { 35 | IFFT_FUNC function; 36 | int64_t m; // warning: reim uses n=2N=4m 37 | uint64_t buf_size; ///< size of aligned_buffers (multiple of 64B) 38 | double* powomegas; 39 | void* aligned_buffers; 40 | } REIM_IFFT_PRECOMP; 41 | 42 | typedef struct reim_mul_precomp { 43 | FFTVEC_MUL_FUNC function; 44 | int64_t m; 45 | } REIM_FFTVEC_MUL_PRECOMP; 46 | 47 | typedef struct reim_addmul_precomp { 48 | FFTVEC_ADDMUL_FUNC function; 49 | int64_t m; 50 | } REIM_FFTVEC_ADDMUL_PRECOMP; 51 | 52 | 53 | struct reim_from_znx32_precomp { 54 | FROM_ZNX32_FUNC function; 55 | int64_t m; 56 | }; 57 | 58 | struct reim_from_znx64_precomp { 59 | FROM_ZNX64_FUNC function; 60 | int64_t m; 61 | }; 62 | 63 | struct reim_from_tnx32_precomp { 64 | FROM_TNX32_FUNC function; 65 | int64_t m; 66 | }; 67 | 68 | struct reim_to_tnx32_precomp { 69 | TO_TNX32_FUNC function; 70 | int64_t m; 71 | double divisor; 72 | }; 73 | 74 | struct reim_to_tnx_precomp { 75 | TO_TNX_FUNC function; 76 | int64_t m; 77 | double divisor; 78 | uint32_t log2overhead; 79 | double add_cst; 80 | uint64_t mask_and; 81 | uint64_t mask_or; 82 | double 
sub_cst; 83 | }; 84 | 85 | struct reim_to_znx64_precomp { 86 | TO_ZNX64_FUNC function; 87 | int64_t m; 88 | double divisor; 89 | }; 90 | 91 | struct reim_twiddle_precomp { 92 | FFTVEC_TWIDDLE_FUNC function; 93 | int64_t m; 94 | }; 95 | 96 | struct reim_bitwiddle_precomp { 97 | FFTVEC_BITWIDDLE_FUNC function; 98 | int64_t m; 99 | }; 100 | 101 | #endif // SPQLIOS_REIM_FFT_PRIVATE_H 102 | -------------------------------------------------------------------------------- /spqlios/reim/reim_fftvec_addmul_fma.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "reim_fft_private.h" 6 | 7 | EXPORT void reim_fftvec_addmul_fma(const REIM_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a, 8 | const double* b) { 9 | const uint64_t m = precomp->m; 10 | double* rr_ptr = r; 11 | double* ri_ptr = r + m; 12 | const double* ar_ptr = a; 13 | const double* ai_ptr = a + m; 14 | const double* br_ptr = b; 15 | const double* bi_ptr = b + m; 16 | 17 | const double* const rend_ptr = ri_ptr; 18 | while (rr_ptr != rend_ptr) { 19 | __m256d rr = _mm256_loadu_pd(rr_ptr); 20 | __m256d ri = _mm256_loadu_pd(ri_ptr); 21 | const __m256d ar = _mm256_loadu_pd(ar_ptr); 22 | const __m256d ai = _mm256_loadu_pd(ai_ptr); 23 | const __m256d br = _mm256_loadu_pd(br_ptr); 24 | const __m256d bi = _mm256_loadu_pd(bi_ptr); 25 | 26 | rr = _mm256_fmsub_pd(ai, bi, rr); 27 | rr = _mm256_fmsub_pd(ar, br, rr); 28 | ri = _mm256_fmadd_pd(ar, bi, ri); 29 | ri = _mm256_fmadd_pd(ai, br, ri); 30 | 31 | _mm256_storeu_pd(rr_ptr, rr); 32 | _mm256_storeu_pd(ri_ptr, ri); 33 | 34 | rr_ptr += 4; 35 | ri_ptr += 4; 36 | ar_ptr += 4; 37 | ai_ptr += 4; 38 | br_ptr += 4; 39 | bi_ptr += 4; 40 | } 41 | } 42 | 43 | EXPORT void reim_fftvec_mul_fma(const REIM_FFTVEC_MUL_PRECOMP* precomp, double* r, const double* a, const double* b) { 44 | const uint64_t m = precomp->m; 45 | double* rr_ptr = r; 46 | double* ri_ptr = r + m; 47 | const double* ar_ptr 
= a; 48 | const double* ai_ptr = a + m; 49 | const double* br_ptr = b; 50 | const double* bi_ptr = b + m; 51 | 52 | const double* const rend_ptr = ri_ptr; 53 | while (rr_ptr != rend_ptr) { 54 | const __m256d ar = _mm256_loadu_pd(ar_ptr); 55 | const __m256d ai = _mm256_loadu_pd(ai_ptr); 56 | const __m256d br = _mm256_loadu_pd(br_ptr); 57 | const __m256d bi = _mm256_loadu_pd(bi_ptr); 58 | 59 | const __m256d t1 = _mm256_mul_pd(ai, bi); 60 | const __m256d t2 = _mm256_mul_pd(ar, bi); 61 | 62 | __m256d rr = _mm256_fmsub_pd(ar, br, t1); 63 | __m256d ri = _mm256_fmadd_pd(ai, br, t2); 64 | 65 | _mm256_storeu_pd(rr_ptr, rr); 66 | _mm256_storeu_pd(ri_ptr, ri); 67 | 68 | rr_ptr += 4; 69 | ri_ptr += 4; 70 | ar_ptr += 4; 71 | ai_ptr += 4; 72 | br_ptr += 4; 73 | bi_ptr += 4; 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /spqlios/reim/reim_fftvec_addmul_ref.c: -------------------------------------------------------------------------------- 1 | #include "reim_fft_internal.h" 2 | #include "reim_fft_private.h" 3 | 4 | EXPORT void reim_fftvec_addmul_ref(const REIM_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a, 5 | const double* b) { 6 | const uint64_t m = precomp->m; 7 | for (uint64_t i = 0; i < m; ++i) { 8 | double re = a[i] * b[i] - a[i + m] * b[i + m]; 9 | double im = a[i] * b[i + m] + a[i + m] * b[i]; 10 | r[i] += re; 11 | r[i + m] += im; 12 | } 13 | } 14 | 15 | EXPORT void reim_fftvec_mul_ref(const REIM_FFTVEC_MUL_PRECOMP* precomp, double* r, const double* a, const double* b) { 16 | const uint64_t m = precomp->m; 17 | for (uint64_t i = 0; i < m; ++i) { 18 | double re = a[i] * b[i] - a[i + m] * b[i + m]; 19 | double im = a[i] * b[i + m] + a[i + m] * b[i]; 20 | r[i] = re; 21 | r[i + m] = im; 22 | } 23 | } 24 | 25 | EXPORT REIM_FFTVEC_ADDMUL_PRECOMP* new_reim_fftvec_addmul_precomp(uint32_t m) { 26 | REIM_FFTVEC_ADDMUL_PRECOMP* reps = malloc(sizeof(REIM_FFTVEC_ADDMUL_PRECOMP)); 27 | reps->m = m; 28 | if 
(CPU_SUPPORTS("fma")) { 29 | if (m >= 4) { 30 | reps->function = reim_fftvec_addmul_fma; 31 | } else { 32 | reps->function = reim_fftvec_addmul_ref; 33 | } 34 | } else { 35 | reps->function = reim_fftvec_addmul_ref; 36 | } 37 | return reps; 38 | } 39 |
/*
 * Builds the dispatch table for the coefficient-wise complex product r = a*b in
 * fftvec layout (m reals followed by m imaginaries).  Selects the FMA kernel when
 * the CPU supports it and m >= 4 (the vector kernel consumes 4 doubles per step, so
 * m is presumably also required to be a multiple of 4 -- TODO confirm); otherwise it
 * falls back to the portable reference loop.
 * NOTE(review): the malloc result is not checked -- allocation failure would
 * dereference NULL; the addmul constructor above has the same issue.
 */
40 | EXPORT REIM_FFTVEC_MUL_PRECOMP* new_reim_fftvec_mul_precomp(uint32_t m) { 41 | REIM_FFTVEC_MUL_PRECOMP* reps = malloc(sizeof(REIM_FFTVEC_MUL_PRECOMP)); 42 | reps->m = m; 43 | if (CPU_SUPPORTS("fma")) { 44 | if (m >= 4) { 45 | reps->function = reim_fftvec_mul_fma; 46 | } else { 47 | reps->function = reim_fftvec_mul_ref; 48 | } 49 | } else { 50 | reps->function = reim_fftvec_mul_ref; 51 | } 52 | return reps; 53 | } 54 | 55 | -------------------------------------------------------------------------------- /spqlios/reim/reim_ifft4_avx_fma.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 |
/* NOTE(review): the three angle-bracket include targets above were lost in extraction. */
5 | #include "reim_fft_private.h" 6 |
/* Inverse butterfly with twiddle on 2 complex lanes held as split re/im registers:
   (a, b) -> (a + b, (a - b) * om) with om = omre + i*omim applied lane-wise. */
7 | __always_inline void reim_invctwiddle_avx_fma(__m128d* ra, __m128d* rb, __m128d* ia, __m128d* ib, const __m128d omre, 8 | const __m128d omim) { 9 | __m128d rdiff = _mm_sub_pd(*ra, *rb); 10 | __m128d idiff = _mm_sub_pd(*ia, *ib); 11 | *ra = _mm_add_pd(*ra, *rb); 12 | *ia = _mm_add_pd(*ia, *ib); 13 | 14 | *rb = _mm_mul_pd(idiff, omim); 15 | *rb = _mm_fmsub_pd(rdiff, omre, *rb); 16 | 17 | *ib = _mm_mul_pd(rdiff, omim); 18 | *ib = _mm_fmadd_pd(idiff, omre, *ib); 19 | } 20 |
/*
 * Size-4 inverse FFT over one reim block: 4 reals at dre, 4 imaginaries at dim,
 * twiddles packed at ompv.  Mirrors reim_fft4_avx_fma with the passes in reverse
 * order: pass "1" de-interleaves to pairs (0,1),(2,3) and applies two lane twiddles
 * derived by swap-and-negate (see the r,i comments), pass "2" butterflies (0,2) and
 * (1,3) with one broadcast twiddle.
 */
21 | EXPORT void reim_ifft4_avx_fma(double* dre, double* dim, const void* ompv) { 22 | const double* omp = (const double*)ompv; 23 | 24 | __m128d ra0 = _mm_loadu_pd(dre); 25 | __m128d ra2 = _mm_loadu_pd(dre + 2); 26 | __m128d ia0 = _mm_loadu_pd(dim); 27 | __m128d ia2 = _mm_loadu_pd(dim + 2); 28 | 29 | // 1 30 | { 31 | const __m128d ifft4neg = _mm_castsi128_pd(_mm_set_epi64x(UINT64_C(1) << 63, 0)); 32 | __m128d omre = _mm_loadu_pd(omp); // omre: r,i 33 | __m128d omim = _mm_xor_pd(_mm_permute_pd(omre, 1), ifft4neg); // omim: i,-r 34 | 35 | __m128d ra =
_mm_unpacklo_pd(ra0, ra2); // (r0, r1), (r2, r3) -> (r0, r2) 36 | __m128d ia = _mm_unpacklo_pd(ia0, ia2); // (i0, i1), (i2, i3) -> (i0, i2) 37 | __m128d rb = _mm_unpackhi_pd(ra0, ra2); // (r0, r1), (r2, r3) -> (r1, r3) 38 | __m128d ib = _mm_unpackhi_pd(ia0, ia2); // (i0, i1), (i2, i3) -> (i1, i3) 39 | 40 | reim_invctwiddle_avx_fma(&ra, &rb, &ia, &ib, omre, omim); 41 | 42 | ra0 = _mm_unpacklo_pd(ra, rb); 43 | ia0 = _mm_unpacklo_pd(ia, ib); 44 | ra2 = _mm_unpackhi_pd(ra, rb); 45 | ia2 = _mm_unpackhi_pd(ia, ib); 46 | } 47 | 48 | // 2 49 | { 50 | __m128d om = _mm_loadu_pd(omp + 2); 51 | __m128d omre = _mm_permute_pd(om, 0); 52 | __m128d omim = _mm_permute_pd(om, 3); 53 | 54 | reim_invctwiddle_avx_fma(&ra0, &ra2, &ia0, &ia2, omre, omim); 55 | } 56 | 57 | // 4 58 | _mm_storeu_pd(dre, ra0); 59 | _mm_storeu_pd(dre + 2, ra2); 60 | _mm_storeu_pd(dim, ia0); 61 | _mm_storeu_pd(dim + 2, ia2); 62 | } 63 | -------------------------------------------------------------------------------- /spqlios/reim/reim_ifft8_avx_fma.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "reim_fft_private.h" 6 | 7 | __always_inline void reim_invctwiddle_avx_fma(__m256d* ra, __m256d* rb, __m256d* ia, __m256d* ib, const __m256d omre, 8 | const __m256d omim) { 9 | __m256d rdiff = _mm256_sub_pd(*ra, *rb); 10 | __m256d idiff = _mm256_sub_pd(*ia, *ib); 11 | *ra = _mm256_add_pd(*ra, *rb); 12 | *ia = _mm256_add_pd(*ia, *ib); 13 | 14 | *rb = _mm256_mul_pd(idiff, omim); 15 | *rb = _mm256_fmsub_pd(rdiff, omre, *rb); 16 | 17 | *ib = _mm256_mul_pd(rdiff, omim); 18 | *ib = _mm256_fmadd_pd(idiff, omre, *ib); 19 | } 20 | 21 | EXPORT void reim_ifft8_avx_fma(double* dre, double* dim, const void* ompv) { 22 | const double* omp = (const double*)ompv; 23 | 24 | __m256d ra0 = _mm256_loadu_pd(dre); 25 | __m256d ra4 = _mm256_loadu_pd(dre + 4); 26 | __m256d ia0 = _mm256_loadu_pd(dim); 27 | __m256d ia4 = _mm256_loadu_pd(dim + 4); 28 | 29 
| // 1 30 | { 31 | const __m256d fft8neg2 = _mm256_castsi256_pd(_mm256_set_epi64x(UINT64_C(1) << 63, UINT64_C(1) << 63, 0, 0)); 32 | __m256d omr = _mm256_loadu_pd(omp); // r0,r1,i0,i1 33 | __m256d omiirr = _mm256_permute2f128_pd(omr, omr, 1); // i0,i1,r0,r1 34 | __m256d omi = _mm256_xor_pd(omiirr, fft8neg2); // i0,i1,-r0,-r1 35 | 36 | __m256d rb = _mm256_unpackhi_pd(ra0, ra4); 37 | __m256d ib = _mm256_unpackhi_pd(ia0, ia4); 38 | __m256d ra = _mm256_unpacklo_pd(ra0, ra4); 39 | __m256d ia = _mm256_unpacklo_pd(ia0, ia4); 40 | 41 | reim_invctwiddle_avx_fma(&ra, &rb, &ia, &ib, omr, omi); 42 | 43 | ra4 = _mm256_unpackhi_pd(ra, rb); 44 | ia4 = _mm256_unpackhi_pd(ia, ib); 45 | ra0 = _mm256_unpacklo_pd(ra, rb); 46 | ia0 = _mm256_unpacklo_pd(ia, ib); 47 | } 48 | 49 | // 2 50 | { 51 | const __m128d ifft8neg = _mm_castsi128_pd(_mm_set_epi64x(0, UINT64_C(1) << 63)); 52 | __m128d omri = _mm_loadu_pd(omp + 4); // r,i 53 | __m128d ommri = _mm_xor_pd(omri, ifft8neg); // -r,i 54 | __m256d omrimri = _mm256_set_m128d(ommri, omri); // r,i,-r,i 55 | __m256d omi = _mm256_permute_pd(omrimri, 3); // i,i,-r,-r 56 | __m256d omr = _mm256_permute_pd(omrimri, 12); // r,r,i,i 57 | 58 | __m256d rb = _mm256_permute2f128_pd(ra0, ra4, 0x31); 59 | __m256d ib = _mm256_permute2f128_pd(ia0, ia4, 0x31); 60 | __m256d ra = _mm256_permute2f128_pd(ra0, ra4, 0x20); 61 | __m256d ia = _mm256_permute2f128_pd(ia0, ia4, 0x20); 62 | 63 | reim_invctwiddle_avx_fma(&ra, &rb, &ia, &ib, omr, omi); 64 | 65 | ra0 = _mm256_permute2f128_pd(ra, rb, 0x20); 66 | ra4 = _mm256_permute2f128_pd(ra, rb, 0x31); 67 | ia0 = _mm256_permute2f128_pd(ia, ib, 0x20); 68 | ia4 = _mm256_permute2f128_pd(ia, ib, 0x31); 69 | } 70 | 71 | // 3 72 | { 73 | __m128d omri = _mm_loadu_pd(omp + 6); // r,i 74 | __m256d omriri = _mm256_set_m128d(omri, omri); // r,i,r,i 75 | __m256d omi = _mm256_permute_pd(omriri, 15); // i,i,i,i 76 | __m256d omr = _mm256_permute_pd(omriri, 0); // r,r,r,r 77 | 78 | reim_invctwiddle_avx_fma(&ra0, &ra4, &ia0, &ia4, omr, 
omi); 79 | } 80 | 81 | // 4 82 | _mm256_storeu_pd(dre, ra0); 83 | _mm256_storeu_pd(dre + 4, ra4); 84 | _mm256_storeu_pd(dim, ia0); 85 | _mm256_storeu_pd(dim + 4, ia4); 86 | } 87 | -------------------------------------------------------------------------------- /spqlios/reim/reim_to_tnx_avx.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "../commons_private.h" 6 | #include "reim_fft_internal.h" 7 | #include "reim_fft_private.h" 8 | 9 | typedef union {double d; uint64_t u;} dblui64_t; 10 | 11 | EXPORT void reim_to_tnx_avx(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x) { 12 | const uint64_t n = tables->m << 1; 13 | const __m256d add_cst = _mm256_set1_pd(tables->add_cst); 14 | const __m256d mask_and = _mm256_castsi256_pd(_mm256_set1_epi64x(tables->mask_and)); 15 | const __m256d mask_or = _mm256_castsi256_pd(_mm256_set1_epi64x(tables->mask_or)); 16 | const __m256d sub_cst = _mm256_set1_pd(tables->sub_cst); 17 | __m256d reg0,reg1; 18 | for (uint64_t i=0; i 2 | #include 3 | 4 | #include "../commons_private.h" 5 | #include "reim_fft_internal.h" 6 | #include "reim_fft_private.h" 7 | 8 | EXPORT void reim_to_tnx_basic_ref(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x) { 9 | const uint64_t n = tables->m << 1; 10 | double divisor = tables->divisor; 11 | for (uint64_t i=0; im << 1; 21 | double add_cst = tables->add_cst; 22 | uint64_t mask_and = tables->mask_and; 23 | uint64_t mask_or = tables->mask_or; 24 | double sub_cst = tables->sub_cst; 25 | dblui64_t cur; 26 | for (uint64_t i=0; i 52) return spqlios_error("log2overhead is too large"); 38 | res->m = m; 39 | res->divisor = divisor; 40 | res->log2overhead = log2overhead; 41 | // 52 + 11 + 1 42 | // ......1.......01(1)|expo|sign 43 | // .......=========(1)|expo|sign msbbits = log2ovh + 2 + 11 + 1 44 | uint64_t nbits = 50 - log2overhead; 45 | dblui64_t ovh_cst; 46 | ovh_cst.d = 0.5 + (6<add_cst = ovh_cst.d 
* divisor; 48 | res->mask_and = ((UINT64_C(1) << nbits) - 1); 49 | res->mask_or = ovh_cst.u & ((UINT64_C(-1)) << nbits); 50 | res->sub_cst = ovh_cst.d; 51 | // TODO: check selection logic 52 | if (CPU_SUPPORTS("avx2")) { 53 | if (m >= 8) { 54 | res->function = reim_to_tnx_avx; 55 | } else { 56 | res->function = reim_to_tnx_ref; 57 | } 58 | } else { 59 | res->function = reim_to_tnx_ref; 60 | } 61 | return res; 62 | } 63 | 64 | EXPORT REIM_TO_TNX_PRECOMP* new_reim_to_tnx_precomp(uint32_t m, double divisor, uint32_t log2overhead) { 65 | REIM_TO_TNX_PRECOMP* res = malloc(sizeof(*res)); 66 | if (!res) return spqlios_error(strerror(errno)); 67 | return spqlios_keep_or_free(res, init_reim_to_tnx_precomp(res, m, divisor, log2overhead)); 68 | } 69 | 70 | EXPORT void reim_to_tnx(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x) { 71 | tables->function(tables, r, x); 72 | } 73 | -------------------------------------------------------------------------------- /spqlios/reim4/reim4_execute.c: -------------------------------------------------------------------------------- 1 | #include "reim4_fftvec_internal.h" 2 | #include "reim4_fftvec_private.h" 3 | 4 | EXPORT void reim4_fftvec_addmul(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, 5 | const double* b) { 6 | tables->function(tables, r, a, b); 7 | } 8 | 9 | EXPORT void reim4_fftvec_mul(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b) { 10 | tables->function(tables, r, a, b); 11 | } 12 | 13 | EXPORT void reim4_from_cplx(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a) { 14 | tables->function(tables, r, a); 15 | } 16 | 17 | EXPORT void reim4_to_cplx(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a) { 18 | tables->function(tables, r, a); 19 | } 20 | -------------------------------------------------------------------------------- /spqlios/reim4/reim4_fallbacks_aarch64.c: 
-------------------------------------------------------------------------------- 1 | #include "reim4_fftvec_private.h" 2 | 3 | EXPORT void reim4_fftvec_mul_fma(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, 4 | const double* b){UNDEFINED()} 5 | 6 | EXPORT void reim4_fftvec_addmul_fma(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, 7 | const double* b){UNDEFINED()} 8 | 9 | EXPORT void reim4_from_cplx_fma(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a){UNDEFINED()} 10 | 11 | EXPORT void reim4_to_cplx_fma(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a){UNDEFINED()} 12 | -------------------------------------------------------------------------------- /spqlios/reim4/reim4_fftvec_addmul_fma.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "reim4_fftvec_private.h" 6 | 7 | EXPORT void reim4_fftvec_addmul_fma(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r_ptr, const double* a_ptr, 8 | const double* b_ptr) { 9 | const double* const rend_ptr = r_ptr + (tables->m << 1); 10 | while (r_ptr != rend_ptr) { 11 | __m256d rr = _mm256_loadu_pd(r_ptr); 12 | __m256d ri = _mm256_loadu_pd(r_ptr + 4); 13 | const __m256d ar = _mm256_loadu_pd(a_ptr); 14 | const __m256d ai = _mm256_loadu_pd(a_ptr + 4); 15 | const __m256d br = _mm256_loadu_pd(b_ptr); 16 | const __m256d bi = _mm256_loadu_pd(b_ptr + 4); 17 | 18 | rr = _mm256_fmsub_pd(ai, bi, rr); 19 | rr = _mm256_fmsub_pd(ar, br, rr); 20 | ri = _mm256_fmadd_pd(ar, bi, ri); 21 | ri = _mm256_fmadd_pd(ai, br, ri); 22 | 23 | _mm256_storeu_pd(r_ptr, rr); 24 | _mm256_storeu_pd(r_ptr + 4, ri); 25 | 26 | r_ptr += 8; 27 | a_ptr += 8; 28 | b_ptr += 8; 29 | } 30 | } 31 | 32 | EXPORT void reim4_fftvec_mul_fma(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r_ptr, const double* a_ptr, 33 | const double* b_ptr) { 34 | const double* const rend_ptr = r_ptr + (tables->m << 1); 35 | while 
(r_ptr != rend_ptr) { 36 | const __m256d ar = _mm256_loadu_pd(a_ptr); 37 | const __m256d ai = _mm256_loadu_pd(a_ptr + 4); 38 | const __m256d br = _mm256_loadu_pd(b_ptr); 39 | const __m256d bi = _mm256_loadu_pd(b_ptr + 4); 40 | 41 | const __m256d t1 = _mm256_mul_pd(ai, bi); 42 | const __m256d t2 = _mm256_mul_pd(ar, bi); 43 | 44 | __m256d rr = _mm256_fmsub_pd(ar, br, t1); 45 | __m256d ri = _mm256_fmadd_pd(ai, br, t2); 46 | 47 | _mm256_storeu_pd(r_ptr, rr); 48 | _mm256_storeu_pd(r_ptr + 4, ri); 49 | 50 | r_ptr += 8; 51 | a_ptr += 8; 52 | b_ptr += 8; 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /spqlios/reim4/reim4_fftvec_addmul_ref.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "../commons_private.h" 7 | #include "reim4_fftvec_internal.h" 8 | #include "reim4_fftvec_private.h" 9 | 10 | // Initializes the addmul precomp: stores m and selects the kernel. void* init_reim4_fftvec_addmul_precomp(REIM4_FFTVEC_ADDMUL_PRECOMP* res, uint32_t m) { 11 | res->m = m; 12 | if (CPU_SUPPORTS("fma")) { 13 | if (m >= 4) { // FIX: was m >= 2. The fma kernel strides 8 doubles (4 complexes) per iteration against rend_ptr = r + 2m with a != test, so m=2 would step past the end and never terminate. m >= 4 matches init_reim4_fftvec_mul_precomp. 14 | res->function = reim4_fftvec_addmul_fma; 15 | } else { 16 | res->function = reim4_fftvec_addmul_ref; 17 | } 18 | } else { 19 | res->function = reim4_fftvec_addmul_ref; 20 | } 21 | return res; 22 | } 23 | 24 | EXPORT REIM4_FFTVEC_ADDMUL_PRECOMP* new_reim4_fftvec_addmul_precomp(uint32_t m) { 25 | REIM4_FFTVEC_ADDMUL_PRECOMP* res = malloc(sizeof(*res)); 26 | if (!res) return spqlios_error(strerror(errno)); 27 | return spqlios_keep_or_free(res, init_reim4_fftvec_addmul_precomp(res, m)); 28 | } 29 | 30 | // Reference r += a*b on reim4 layout: blocks of 4 reals then 4 imaginaries. EXPORT void reim4_fftvec_addmul_ref(const REIM4_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a, 31 | const double* b) { 32 | const uint64_t m = precomp->m; 33 | for (uint64_t j = 0; j < m / 4; ++j) { 34 | for (uint64_t i = 0; i < 4; ++i) { 35 | double re = a[i] * b[i] - a[i + 4] * b[i + 4]; 36 | double im = a[i] * b[i + 4] + a[i + 4] * b[i]; 37 | r[i] += re; 38 |
r[i + 4] += im; 39 | } 40 | a += 8; 41 | b += 8; 42 | r += 8; 43 | } 44 | } 45 | 46 | EXPORT void reim4_fftvec_addmul_simple(uint32_t m, double* r, const double* a, const double* b) { 47 | static REIM4_FFTVEC_ADDMUL_PRECOMP precomp[32]; 48 | REIM4_FFTVEC_ADDMUL_PRECOMP* p = precomp + log2m(m); 49 | if (!p->function) { 50 | if (!init_reim4_fftvec_addmul_precomp(p, m)) abort(); 51 | } 52 | p->function(p, r, a, b); 53 | } 54 | 55 | void* init_reim4_fftvec_mul_precomp(REIM4_FFTVEC_MUL_PRECOMP* res, uint32_t m) { 56 | res->m = m; 57 | if (CPU_SUPPORTS("fma")) { 58 | if (m >= 4) { 59 | res->function = reim4_fftvec_mul_fma; 60 | } else { 61 | res->function = reim4_fftvec_mul_ref; 62 | } 63 | } else { 64 | res->function = reim4_fftvec_mul_ref; 65 | } 66 | return res; 67 | } 68 | 69 | EXPORT REIM4_FFTVEC_MUL_PRECOMP* new_reim4_fftvec_mul_precomp(uint32_t m) { 70 | REIM4_FFTVEC_MUL_PRECOMP* res = malloc(sizeof(*res)); 71 | if (!res) return spqlios_error(strerror(errno)); 72 | return spqlios_keep_or_free(res, init_reim4_fftvec_mul_precomp(res, m)); 73 | } 74 | 75 | EXPORT void reim4_fftvec_mul_ref(const REIM4_FFTVEC_MUL_PRECOMP* precomp, double* r, const double* a, const double* b) { 76 | const uint64_t m = precomp->m; 77 | for (uint64_t j = 0; j < m / 4; ++j) { 78 | for (uint64_t i = 0; i < 4; ++i) { 79 | double re = a[i] * b[i] - a[i + 4] * b[i + 4]; 80 | double im = a[i] * b[i + 4] + a[i + 4] * b[i]; 81 | r[i] = re; 82 | r[i + 4] = im; 83 | } 84 | a += 8; 85 | b += 8; 86 | r += 8; 87 | } 88 | } 89 | 90 | EXPORT void reim4_fftvec_mul_simple(uint32_t m, double* r, const double* a, const double* b) { 91 | static REIM4_FFTVEC_MUL_PRECOMP precomp[32]; 92 | REIM4_FFTVEC_MUL_PRECOMP* p = precomp + log2m(m); 93 | if (!p->function) { 94 | if (!init_reim4_fftvec_mul_precomp(p, m)) abort(); 95 | } 96 | p->function(p, r, a, b); 97 | } 98 | -------------------------------------------------------------------------------- /spqlios/reim4/reim4_fftvec_conv_fma.c: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "reim4_fftvec_private.h" 6 | 7 | EXPORT void reim4_from_cplx_fma(const REIM4_FROM_CPLX_PRECOMP* tables, double* r_ptr, const void* a) { 8 | const double* const rend_ptr = r_ptr + (tables->m << 1); 9 | 10 | const double* a_ptr = (double*)a; 11 | while (r_ptr != rend_ptr) { 12 | __m256d t1 = _mm256_loadu_pd(a_ptr); 13 | __m256d t2 = _mm256_loadu_pd(a_ptr + 4); 14 | 15 | _mm256_storeu_pd(r_ptr, _mm256_unpacklo_pd(t1, t2)); 16 | _mm256_storeu_pd(r_ptr + 4, _mm256_unpackhi_pd(t1, t2)); 17 | 18 | r_ptr += 8; 19 | a_ptr += 8; 20 | } 21 | } 22 | 23 | EXPORT void reim4_to_cplx_fma(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a_ptr) { 24 | const double* const aend_ptr = a_ptr + (tables->m << 1); 25 | double* r_ptr = (double*)r; 26 | 27 | while (a_ptr != aend_ptr) { 28 | __m256d t1 = _mm256_loadu_pd(a_ptr); 29 | __m256d t2 = _mm256_loadu_pd(a_ptr + 4); 30 | 31 | _mm256_storeu_pd(r_ptr, _mm256_unpacklo_pd(t1, t2)); 32 | _mm256_storeu_pd(r_ptr + 4, _mm256_unpackhi_pd(t1, t2)); 33 | 34 | r_ptr += 8; 35 | a_ptr += 8; 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /spqlios/reim4/reim4_fftvec_conv_ref.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "../commons_private.h" 7 | #include "reim4_fftvec_internal.h" 8 | #include "reim4_fftvec_private.h" 9 | 10 | EXPORT void reim4_from_cplx_ref(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a) { 11 | const double* x = (double*)a; 12 | const uint64_t m = tables->m; 13 | for (uint64_t i = 0; i < m / 4; ++i) { 14 | double r0 = x[0]; 15 | double i0 = x[1]; 16 | double r1 = x[2]; 17 | double i1 = x[3]; 18 | double r2 = x[4]; 19 | double i2 = x[5]; 20 | double r3 = x[6]; 21 | double i3 = x[7]; 22 | r[0] = r0; 23 | r[1] = r2; 
24 | r[2] = r1; 25 | r[3] = r3; 26 | r[4] = i0; 27 | r[5] = i2; 28 | r[6] = i1; 29 | r[7] = i3; 30 | x += 8; 31 | r += 8; 32 | } 33 | } 34 | 35 | void* init_reim4_from_cplx_precomp(REIM4_FROM_CPLX_PRECOMP* res, uint32_t nn) { 36 | res->m = nn / 2; // NOTE(review): the public header documents the parameter of new_reim4_from_cplx_precomp as the complex dimension m, yet it is halved here while sibling inits store it unchanged — confirm whether nn is the real dimension 2m or this is a bug. 37 | if (CPU_SUPPORTS("fma")) { 38 | if (nn >= 4) { 39 | res->function = reim4_from_cplx_fma; 40 | } else { 41 | res->function = reim4_from_cplx_ref; 42 | } 43 | } else { 44 | res->function = reim4_from_cplx_ref; 45 | } 46 | return res; 47 | } 48 | 49 | EXPORT REIM4_FROM_CPLX_PRECOMP* new_reim4_from_cplx_precomp(uint32_t m) { 50 | REIM4_FROM_CPLX_PRECOMP* res = malloc(sizeof(*res)); 51 | if (!res) return spqlios_error(strerror(errno)); 52 | return spqlios_keep_or_free(res, init_reim4_from_cplx_precomp(res, m)); 53 | } 54 | 55 | EXPORT void reim4_from_cplx_simple(uint32_t m, double* r, const void* a) { 56 | static REIM4_FROM_CPLX_PRECOMP precomp[32]; 57 | REIM4_FROM_CPLX_PRECOMP* p = precomp + log2m(m); 58 | if (!p->function) { 59 | if (!init_reim4_from_cplx_precomp(p, m)) abort(); 60 | } 61 | p->function(p, r, a); 62 | } 63 | 64 | // Reference conversion from reim4 layout back to interleaved (re,im) pairs. EXPORT void reim4_to_cplx_ref(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a) { 65 | double* y = (double*)r; 66 | const uint64_t m = tables->m; 67 | for (uint64_t i = 0; i < m / 4; ++i) { 68 | double r0 = a[0]; 69 | double r2 = a[1]; 70 | double r1 = a[2]; 71 | double r3 = a[3]; 72 | double i0 = a[4]; 73 | double i2 = a[5]; 74 | double i1 = a[6]; 75 | double i3 = a[7]; 76 | y[0] = r0; 77 | y[1] = i0; 78 | y[2] = r1; 79 | y[3] = i1; 80 | y[4] = r2; 81 | y[5] = i2; 82 | y[6] = r3; 83 | y[7] = i3; 84 | a += 8; 85 | y += 8; 86 | } 87 | } 88 | 89 | void* init_reim4_to_cplx_precomp(REIM4_TO_CPLX_PRECOMP* res, uint32_t m) { 90 | res->m = m; 91 | if (CPU_SUPPORTS("fma")) { 92 | if (m >= 4) { // FIX: was m >= 2. reim4_to_cplx_fma strides 8 doubles (4 complexes) per iteration over a 2m-double buffer with a != end test; m=2 would overrun. m >= 4 matches init_reim4_fftvec_mul_precomp. 93 | res->function = reim4_to_cplx_fma; 94 | } else { 95 | res->function = reim4_to_cplx_ref; 96 | } 97 | } else { 98 | res->function = reim4_to_cplx_ref; 99 | } 100 | return res; 101 | } 102 | 103 |
EXPORT REIM4_TO_CPLX_PRECOMP* new_reim4_to_cplx_precomp(uint32_t m) { 104 | REIM4_TO_CPLX_PRECOMP* res = malloc(sizeof(*res)); 105 | if (!res) return spqlios_error(strerror(errno)); 106 | return spqlios_keep_or_free(res, init_reim4_to_cplx_precomp(res, m)); 107 | } 108 | 109 | EXPORT void reim4_to_cplx_simple(uint32_t m, void* r, const double* a) { 110 | static REIM4_TO_CPLX_PRECOMP precomp[32]; 111 | REIM4_TO_CPLX_PRECOMP* p = precomp + log2m(m); 112 | if (!p->function) { 113 | if (!init_reim4_to_cplx_precomp(p, m)) abort(); 114 | } 115 | p->function(p, r, a); 116 | } 117 | -------------------------------------------------------------------------------- /spqlios/reim4/reim4_fftvec_internal.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_REIM4_FFTVEC_INTERNAL_H 2 | #define SPQLIOS_REIM4_FFTVEC_INTERNAL_H 3 | 4 | #include "reim4_fftvec_public.h" 5 | 6 | EXPORT void reim4_fftvec_mul_ref(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b); 7 | EXPORT void reim4_fftvec_mul_fma(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b); 8 | 9 | EXPORT void reim4_fftvec_addmul_ref(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, 10 | const double* b); 11 | EXPORT void reim4_fftvec_addmul_fma(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, 12 | const double* b); 13 | 14 | EXPORT void reim4_from_cplx_ref(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a); 15 | EXPORT void reim4_from_cplx_fma(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a); 16 | 17 | EXPORT void reim4_to_cplx_ref(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a); 18 | EXPORT void reim4_to_cplx_fma(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a); 19 | 20 | #endif // SPQLIOS_REIM4_FFTVEC_INTERNAL_H 21 | -------------------------------------------------------------------------------- 
/spqlios/reim4/reim4_fftvec_private.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_REIM4_FFTVEC_PRIVATE_H 2 | #define SPQLIOS_REIM4_FFTVEC_PRIVATE_H 3 | 4 | #include "reim4_fftvec_public.h" 5 | 6 | #define STATIC_ASSERT(condition) (void)sizeof(char[-1 + 2 * !!(condition)]) 7 | 8 | typedef void (*R4_FFTVEC_MUL_FUNC)(const REIM4_FFTVEC_MUL_PRECOMP*, double*, const double*, const double*); 9 | typedef void (*R4_FFTVEC_ADDMUL_FUNC)(const REIM4_FFTVEC_ADDMUL_PRECOMP*, double*, const double*, const double*); 10 | typedef void (*R4_FROM_CPLX_FUNC)(const REIM4_FROM_CPLX_PRECOMP*, double*, const void*); 11 | typedef void (*R4_TO_CPLX_FUNC)(const REIM4_TO_CPLX_PRECOMP*, void*, const double*); 12 | 13 | struct reim4_mul_precomp { 14 | R4_FFTVEC_MUL_FUNC function; 15 | int64_t m; 16 | }; 17 | 18 | struct reim4_addmul_precomp { 19 | R4_FFTVEC_ADDMUL_FUNC function; 20 | int64_t m; 21 | }; 22 | 23 | struct reim4_from_cplx_precomp { 24 | R4_FROM_CPLX_FUNC function; 25 | int64_t m; 26 | }; 27 | 28 | struct reim4_to_cplx_precomp { 29 | R4_TO_CPLX_FUNC function; 30 | int64_t m; 31 | }; 32 | 33 | #endif // SPQLIOS_REIM4_FFTVEC_PRIVATE_H 34 | -------------------------------------------------------------------------------- /spqlios/reim4/reim4_fftvec_public.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_REIM4_FFTVEC_PUBLIC_H 2 | #define SPQLIOS_REIM4_FFTVEC_PUBLIC_H 3 | 4 | #include "../commons.h" 5 | 6 | typedef struct reim4_addmul_precomp REIM4_FFTVEC_ADDMUL_PRECOMP; 7 | typedef struct reim4_mul_precomp REIM4_FFTVEC_MUL_PRECOMP; 8 | typedef struct reim4_from_cplx_precomp REIM4_FROM_CPLX_PRECOMP; 9 | typedef struct reim4_to_cplx_precomp REIM4_TO_CPLX_PRECOMP; 10 | 11 | EXPORT REIM4_FFTVEC_MUL_PRECOMP* new_reim4_fftvec_mul_precomp(uint32_t m); 12 | EXPORT void reim4_fftvec_mul(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b); 13 | #define 
delete_reim4_fftvec_mul_precomp free 14 | 15 | EXPORT REIM4_FFTVEC_ADDMUL_PRECOMP* new_reim4_fftvec_addmul_precomp(uint32_t nn); 16 | EXPORT void reim4_fftvec_addmul(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, const double* b); 17 | #define delete_reim4_fftvec_addmul_precomp free 18 | 19 | /** 20 | * @brief prepares a conversion from the cplx fftvec layout to the reim4 layout. 21 | * @param m complex dimension m from C[X] mod X^m-i. 22 | */ 23 | EXPORT REIM4_FROM_CPLX_PRECOMP* new_reim4_from_cplx_precomp(uint32_t m); 24 | EXPORT void reim4_from_cplx(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a); 25 | #define delete_reim4_from_cplx_precomp free 26 | 27 | /** 28 | * @brief prepares a conversion from the reim4 fftvec layout to the cplx layout 29 | * @param m the complex dimension m from C[X] mod X^m-i. 30 | */ 31 | EXPORT REIM4_TO_CPLX_PRECOMP* new_reim4_to_cplx_precomp(uint32_t m); 32 | EXPORT void reim4_to_cplx(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a); 33 | #define delete_reim4_to_cplx_precomp free 34 | 35 | /** 36 | * @brief Simpler API for the fftvec multiplication function. 37 | * For each dimension, the precomputed tables for this dimension are generated automatically the first time. 38 | * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ 39 | EXPORT void reim4_fftvec_mul_simple(uint32_t m, double* r, const double* a, const double* b); 40 | 41 | /** 42 | * @brief Simpler API for the fftvec addmul function. 43 | * For each dimension, the precomputed tables for this dimension are generated automatically the first time. 44 | * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ 45 | EXPORT void reim4_fftvec_addmul_simple(uint32_t m, double* r, const double* a, const double* b); 46 | 47 | /** 48 | * @brief Simpler API for cplx conversion. 
49 | * For each dimension, the precomputed tables for this dimension are generated automatically the first time. 50 | * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ 51 | EXPORT void reim4_from_cplx_simple(uint32_t m, double* r, const void* a); 52 | 53 | /** 54 | * @brief Simpler API for to cplx conversion. 55 | * For each dimension, the precomputed tables for this dimension are generated automatically the first time. 56 | * It is advised to do one dry-run call per desired dimension before using in a multithread environment */ 57 | EXPORT void reim4_to_cplx_simple(uint32_t m, void* r, const double* a); 58 | 59 | #endif // SPQLIOS_REIM4_FFTVEC_PUBLIC_H -------------------------------------------------------------------------------- /test/spqlios_cplx_conversions_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "spqlios/cplx/cplx_fft_internal.h" 6 | #include "spqlios/cplx/cplx_fft_private.h" 7 | 8 | #ifdef __x86_64__ 9 | TEST(fft, cplx_from_znx32_ref_vs_fma) { 10 | const uint32_t m = 128; 11 | int32_t* src = (int32_t*)spqlios_alloc_custom_align(32, 10 * m * sizeof(int32_t)); 12 | CPLX* dst1 = (CPLX*)(src + 2 * m); 13 | CPLX* dst2 = (CPLX*)(src + 6 * m); 14 | for (uint64_t i = 0; i < 2 * m; ++i) { 15 | src[i] = rand() - RAND_MAX / 2; 16 | } 17 | CPLX_FROM_ZNX32_PRECOMP precomp; 18 | precomp.m = m; 19 | cplx_from_znx32_ref(&precomp, dst1, src); 20 | // cplx_from_znx32_simple(m, 32, dst1, src); 21 | cplx_from_znx32_avx2_fma(&precomp, dst2, src); 22 | for (uint64_t i = 0; i < m; ++i) { 23 | ASSERT_EQ(dst1[i][0], dst2[i][0]); 24 | ASSERT_EQ(dst1[i][1], dst2[i][1]); 25 | } 26 | spqlios_free(src); 27 | } 28 | #endif 29 | 30 | #ifdef __x86_64__ 31 | TEST(fft, cplx_from_tnx32_ref_vs_fma) { 32 | const uint32_t m = 128; 33 | int32_t* src = (int32_t*)spqlios_alloc_custom_align(32, 10 * m * sizeof(int32_t)); 34 | CPLX* dst1 = 
(CPLX*)(src + 2 * m); 35 | CPLX* dst2 = (CPLX*)(src + 6 * m); 36 | for (uint64_t i = 0; i < 2 * m; ++i) { 37 | src[i] = rand() + (rand() << 20); 38 | } 39 | CPLX_FROM_TNX32_PRECOMP precomp; 40 | precomp.m = m; 41 | cplx_from_tnx32_ref(&precomp, dst1, src); 42 | // cplx_from_tnx32_simple(m, dst1, src); 43 | cplx_from_tnx32_avx2_fma(&precomp, dst2, src); 44 | for (uint64_t i = 0; i < m; ++i) { 45 | ASSERT_EQ(dst1[i][0], dst2[i][0]); 46 | ASSERT_EQ(dst1[i][1], dst2[i][1]); 47 | } 48 | spqlios_free(src); 49 | } 50 | #endif 51 | 52 | #ifdef __x86_64__ 53 | TEST(fft, cplx_to_tnx32_ref_vs_fma) { 54 | for (const uint32_t m : {8, 128, 1024, 65536}) { 55 | for (const double divisor : {double(1), double(m), double(0.5)}) { 56 | CPLX* src = (CPLX*)spqlios_alloc_custom_align(32, 10 * m * sizeof(int32_t)); 57 | int32_t* dst1 = (int32_t*)(src + m); 58 | int32_t* dst2 = (int32_t*)(src + 2 * m); 59 | for (uint64_t i = 0; i < 2 * m; ++i) { 60 | src[i][0] = (rand() / double(RAND_MAX) - 0.5) * pow(2., 19 - (rand() % 60)) * divisor; 61 | src[i][1] = (rand() / double(RAND_MAX) - 0.5) * pow(2., 19 - (rand() % 60)) * divisor; 62 | } 63 | CPLX_TO_TNX32_PRECOMP precomp; 64 | precomp.m = m; 65 | precomp.divisor = divisor; 66 | cplx_to_tnx32_ref(&precomp, dst1, src); 67 | cplx_to_tnx32_avx2_fma(&precomp, dst2, src); 68 | // cplx_to_tnx32_simple(m, divisor, 18, dst2, src); 69 | for (uint64_t i = 0; i < 2 * m; ++i) { 70 | double truevalue = 71 | (src[i % m][i / m] / divisor - floor(src[i % m][i / m] / divisor + 0.5)) * (INT64_C(1) << 32); 72 | if (fabs(truevalue - floor(truevalue)) == 0.5) { 73 | // ties can differ by 0, 1 or -1 74 | ASSERT_LE(abs(dst1[i] - dst2[i]), 1) // FIX: tolerance was 0, contradicting the comment above — ref and avx2 may legitimately round ties in opposite directions, so allow a difference of at most 1. 75 | << i << " " << dst1[i] << " " << dst2[i] << " " << truevalue << std::endl; 76 | } else { 77 | // otherwise, we should have equality 78 | ASSERT_LE(abs(dst1[i] - dst2[i]), 0) 79 | << i << " " << dst1[i] << " " << dst2[i] << " " << truevalue << std::endl; 80 | } 81 | } 82 | spqlios_free(src); 83 | } 84 | } 85 | } 86 | #endif
87 | -------------------------------------------------------------------------------- /test/spqlios_cplx_fft_bench.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include "../spqlios/cplx/cplx_fft_internal.h" 14 | #include "spqlios/reim/reim_fft.h" 15 | 16 | using namespace std; 17 | 18 | void init_random_values(uint64_t n, double* v) { 19 | for (uint64_t i = 0; i < n; ++i) v[i] = rand() - (RAND_MAX >> 1); 20 | } 21 | 22 | void benchmark_cplx_fft(benchmark::State& state) { 23 | const int32_t nn = state.range(0); 24 | CPLX_FFT_PRECOMP* a = new_cplx_fft_precomp(nn / 2, 1); 25 | double* c = (double*)cplx_fft_precomp_get_buffer(a, 0); 26 | init_random_values(nn, c); 27 | for (auto _ : state) { 28 | // cplx_fft_simple(nn/2, c); 29 | cplx_fft(a, c); 30 | } 31 | delete_cplx_fft_precomp(a); 32 | } 33 | 34 | void benchmark_cplx_ifft(benchmark::State& state) { 35 | const int32_t nn = state.range(0); 36 | CPLX_IFFT_PRECOMP* a = new_cplx_ifft_precomp(nn / 2, 1); 37 | double* c = (double*)cplx_ifft_precomp_get_buffer(a, 0); 38 | init_random_values(nn, c); 39 | for (auto _ : state) { 40 | // cplx_ifft_simple(nn/2, c); 41 | cplx_ifft(a, c); 42 | } 43 | delete_cplx_ifft_precomp(a); 44 | } 45 | 46 | void benchmark_reim_fft(benchmark::State& state) { 47 | const int32_t nn = state.range(0); 48 | const uint32_t m = nn / 2; 49 | REIM_FFT_PRECOMP* a = new_reim_fft_precomp(m, 1); 50 | double* c = reim_fft_precomp_get_buffer(a, 0); 51 | init_random_values(nn, c); 52 | for (auto _ : state) { 53 | // cplx_fft_simple(nn/2, c); 54 | reim_fft(a, c); 55 | } 56 | delete_reim_fft_precomp(a); 57 | } 58 | 59 | #ifdef __aarch64__ 60 | EXPORT REIM_FFT_PRECOMP* new_reim_fft_precomp_neon(uint32_t m, uint32_t num_buffers); 61 | EXPORT void reim_fft_neon(const REIM_FFT_PRECOMP* precomp, double* d); 62 | 63 | void 
benchmark_reim_fft_neon(benchmark::State& state) { 64 | const int32_t nn = state.range(0); 65 | const uint32_t m = nn / 2; 66 | REIM_FFT_PRECOMP* a = new_reim_fft_precomp_neon(m, 1); 67 | double* c = reim_fft_precomp_get_buffer(a, 0); 68 | init_random_values(nn, c); 69 | for (auto _ : state) { 70 | // cplx_fft_simple(nn/2, c); 71 | reim_fft_neon(a, c); 72 | } 73 | delete_reim_fft_precomp(a); 74 | } 75 | #endif 76 | 77 | void benchmark_reim_ifft(benchmark::State& state) { 78 | const int32_t nn = state.range(0); 79 | const uint32_t m = nn / 2; 80 | REIM_IFFT_PRECOMP* a = new_reim_ifft_precomp(m, 1); 81 | double* c = reim_ifft_precomp_get_buffer(a, 0); 82 | init_random_values(nn, c); 83 | for (auto _ : state) { 84 | // cplx_ifft_simple(nn/2, c); 85 | reim_ifft(a, c); 86 | } 87 | delete_reim_ifft_precomp(a); 88 | } 89 | 90 | // #define ARGS Arg(1024)->Arg(8192)->Arg(32768)->Arg(65536) 91 | #define ARGS Arg(64)->Arg(256)->Arg(1024)->Arg(2048)->Arg(4096)->Arg(8192)->Arg(16384)->Arg(32768)->Arg(65536) 92 | 93 | int main(int argc, char** argv) { 94 | ::benchmark::Initialize(&argc, argv); 95 | if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return 1; 96 | std::cout << "Dimensions n in the benchmark below are in \"real FFT\" modulo X^n+1" << std::endl; 97 | std::cout << "The complex dimension m (modulo X^m-i) is half of it" << std::endl; 98 | BENCHMARK(benchmark_cplx_fft)->ARGS; 99 | BENCHMARK(benchmark_cplx_ifft)->ARGS; 100 | BENCHMARK(benchmark_reim_fft)->ARGS; 101 | #ifdef __aarch64__ 102 | BENCHMARK(benchmark_reim_fft_neon)->ARGS; 103 | #endif 104 | BENCHMARK(benchmark_reim_ifft)->ARGS; 105 | // if (CPU_SUPPORTS("avx512f")) { 106 | // BENCHMARK(bench_cplx_fftvec_twiddle_avx512)->ARGS; 107 | // BENCHMARK(bench_cplx_fftvec_bitwiddle_avx512)->ARGS; 108 | //} 109 | ::benchmark::RunSpecifiedBenchmarks(); 110 | ::benchmark::Shutdown(); 111 | return 0; 112 | } 113 | -------------------------------------------------------------------------------- 
/test/spqlios_q120_ntt_bench.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "spqlios/q120/q120_ntt.h" 6 | 7 | #define ARGS Arg(1 << 10)->Arg(1 << 11)->Arg(1 << 12)->Arg(1 << 13)->Arg(1 << 14)->Arg(1 << 15)->Arg(1 << 16) 8 | 9 | template 10 | void benchmark_ntt(benchmark::State& state) { 11 | const uint64_t n = state.range(0); 12 | q120_ntt_precomp* precomp = q120_new_ntt_bb_precomp(n); 13 | 14 | uint64_t* px = new uint64_t[n * 4]; 15 | for (uint64_t i = 0; i < 4 * n; i++) { 16 | px[i] = (rand() << 31) + rand(); 17 | } 18 | for (auto _ : state) { 19 | f(precomp, (q120b*)px); 20 | } 21 | delete[] px; 22 | q120_del_ntt_bb_precomp(precomp); 23 | } 24 | 25 | template 26 | void benchmark_intt(benchmark::State& state) { 27 | const uint64_t n = state.range(0); 28 | q120_ntt_precomp* precomp = q120_new_intt_bb_precomp(n); 29 | 30 | uint64_t* px = new uint64_t[n * 4]; 31 | for (uint64_t i = 0; i < 4 * n; i++) { 32 | px[i] = (rand() << 31) + rand(); 33 | } 34 | for (auto _ : state) { 35 | f(precomp, (q120b*)px); 36 | } 37 | delete[] px; 38 | q120_del_intt_bb_precomp(precomp); 39 | } 40 | 41 | BENCHMARK(benchmark_ntt)->Name("q120_ntt_bb_avx2")->ARGS; 42 | BENCHMARK(benchmark_intt)->Name("q120_intt_bb_avx2")->ARGS; 43 | 44 | BENCHMARK_MAIN(); 45 | -------------------------------------------------------------------------------- /test/spqlios_reim4_arithmetic_bench.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "spqlios/reim4/reim4_arithmetic.h" 4 | 5 | void init_random_values(uint64_t n, double* v) { 6 | for (uint64_t i = 0; i < n; ++i) 7 | v[i] = (double(rand() % (UINT64_C(1) << 14)) - (UINT64_C(1) << 13)) / (UINT64_C(1) << 12); 8 | } 9 | 10 | // Run the benchmark 11 | BENCHMARK_MAIN(); 12 | 13 | #undef ARGS 14 | #define ARGS Args({47, 16384})->Args({93, 32768}) 15 | 16 | /* 17 | * reim4_vec_mat1col_product 18 | * 
reim4_vec_mat2col_product 19 | * reim4_vec_mat3col_product 20 | * reim4_vec_mat4col_product 21 | */ 22 | 23 | template 25 | void benchmark_reim4_vec_matXcols_product(benchmark::State& state) { 26 | const uint64_t nrows = state.range(0); 27 | 28 | double* u = new double[nrows * 8]; 29 | init_random_values(8 * nrows, u); 30 | double* v = new double[nrows * X * 8]; 31 | init_random_values(X * 8 * nrows, v); 32 | double* dst = new double[X * 8]; 33 | 34 | for (auto _ : state) { 35 | fnc(nrows, dst, u, v); 36 | } 37 | 38 | delete[] dst; 39 | delete[] v; 40 | delete[] u; 41 | } 42 | 43 | #undef ARGS 44 | #define ARGS Arg(128)->Arg(1024)->Arg(4096) 45 | 46 | #ifdef __x86_64__ 47 | BENCHMARK(benchmark_reim4_vec_matXcols_product<1, reim4_vec_mat1col_product_avx2>)->ARGS; 48 | // TODO: please remove when fixed: 49 | BENCHMARK(benchmark_reim4_vec_matXcols_product<2, reim4_vec_mat2cols_product_avx2>)->ARGS; 50 | #endif 51 | BENCHMARK(benchmark_reim4_vec_matXcols_product<1, reim4_vec_mat1col_product_ref>)->ARGS; 52 | BENCHMARK(benchmark_reim4_vec_matXcols_product<2, reim4_vec_mat2cols_product_ref>)->ARGS; 53 | -------------------------------------------------------------------------------- /test/spqlios_svp_product_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "../spqlios/arithmetic/vec_znx_arithmetic_private.h" 4 | #include "testlib/fft64_dft.h" 5 | #include "testlib/fft64_layouts.h" 6 | #include "testlib/polynomial_vector.h" 7 | 8 | // todo: remove when registered 9 | typedef typeof(fft64_svp_prepare_ref) SVP_PREPARE_F; 10 | 11 | void test_fft64_svp_prepare(SVP_PREPARE_F svp_prepare) { 12 | for (uint64_t n : {2, 4, 8, 64, 128}) { 13 | MODULE* module = new_module_info(n, FFT64); 14 | znx_i64 in = znx_i64::random_log2bound(n, 40); 15 | fft64_svp_ppol_layout out(n); 16 | reim_fft64vec expect = simple_fft64(in); 17 | svp_prepare(module, out.data, in.data()); 18 | const double* ed = (double*)expect.data(); 
19 | const double* ac = (double*)out.data; 20 | for (uint64_t i = 0; i < n; ++i) { 21 | ASSERT_LE(abs(ed[i] - ac[i]), 1e-10) << i << n; 22 | } 23 | delete_module_info(module); 24 | } 25 | } 26 | 27 | TEST(svp_prepare, fft64_svp_prepare_ref) { test_fft64_svp_prepare(fft64_svp_prepare_ref); } 28 | TEST(svp_prepare, svp_prepare) { test_fft64_svp_prepare(svp_prepare); } 29 | -------------------------------------------------------------------------------- /test/spqlios_svp_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "../spqlios/arithmetic/vec_znx_arithmetic_private.h" 4 | #include "testlib/fft64_dft.h" 5 | #include "testlib/fft64_layouts.h" 6 | #include "testlib/polynomial_vector.h" 7 | 8 | void test_fft64_svp_apply_dft(SVP_APPLY_DFT_F svp) { 9 | for (uint64_t n : {2, 4, 8, 64, 128}) { 10 | MODULE* module = new_module_info(n, FFT64); 11 | // poly 1 to multiply - create and prepare 12 | fft64_svp_ppol_layout ppol(n); 13 | ppol.fill_random(1.); 14 | for (uint64_t sa : {3, 5, 8}) { 15 | for (uint64_t sr : {3, 5, 8}) { 16 | uint64_t a_sl = n + uniform_u64_bits(2); 17 | // poly 2 to multiply 18 | znx_vec_i64_layout a(n, sa, a_sl); 19 | a.fill_random(19); 20 | // original operation result 21 | fft64_vec_znx_dft_layout res(n, sr); 22 | thash hash_a_before = a.content_hash(); 23 | thash hash_ppol_before = ppol.content_hash(); 24 | svp(module, res.data, sr, ppol.data, a.data(), sa, a_sl); 25 | ASSERT_EQ(a.content_hash(), hash_a_before); 26 | ASSERT_EQ(ppol.content_hash(), hash_ppol_before); 27 | // create expected value 28 | reim_fft64vec ppo = ppol.get_copy(); 29 | std::vector expect(sr); 30 | for (uint64_t i = 0; i < sr; ++i) { 31 | expect[i] = ppo * simple_fft64(a.get_copy_zext(i)); 32 | } 33 | // this is the largest precision we can safely expect 34 | double prec_expect = n * pow(2., 19 - 52); 35 | for (uint64_t i = 0; i < sr; ++i) { 36 | reim_fft64vec actual = res.get_copy_zext(i); 37 | 
ASSERT_LE(infty_dist(actual, expect[i]), prec_expect); 38 | } 39 | } 40 | } 41 | 42 | delete_module_info(module); 43 | } 44 | } 45 | 46 | TEST(fft64_svp_apply_dft, svp_apply_dft) { test_fft64_svp_apply_dft(svp_apply_dft); } 47 | TEST(fft64_svp_apply_dft, fft64_svp_apply_dft_ref) { test_fft64_svp_apply_dft(fft64_svp_apply_dft_ref); } 48 | -------------------------------------------------------------------------------- /test/spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "spqlios/arithmetic/vec_rnx_arithmetic_private.h" 3 | #include "testlib/vec_rnx_layout.h" 4 | 5 | static void test_rnx_approxdecomp(RNX_APPROXDECOMP_FROM_TNXDBL_F approxdec) { 6 | for (const uint64_t nn : {2, 4, 8, 32}) { 7 | MOD_RNX* module = new_rnx_module_info(nn, FFT64); 8 | for (const uint64_t ell : {1, 2, 7}) { 9 | for (const uint64_t k : {2, 5}) { 10 | TNXDBL_APPROXDECOMP_GADGET* gadget = new_tnxdbl_approxdecomp_gadget(module, k, ell); 11 | for (const uint64_t res_size : {ell, ell - 1, ell + 1}) { 12 | const uint64_t res_sl = nn + uniform_u64_bits(2); 13 | rnx_vec_f64_layout in(nn, 1, nn); 14 | in.fill_random(3); 15 | rnx_vec_f64_layout out(nn, res_size, res_sl); 16 | approxdec(module, gadget, out.data(), res_size, res_sl, in.data()); 17 | // reconstruct the output 18 | uint64_t msize = std::min(res_size, ell); 19 | double err_bnd = msize == ell ? 
pow(2., -double(msize * k) - 1) : pow(2., -double(msize * k)); 20 | for (uint64_t j = 0; j < nn; ++j) { 21 | double in_j = in.data()[j]; 22 | double out_j = 0; 23 | for (uint64_t i = 0; i < res_size; ++i) { 24 | out_j += out.get_copy(i).get_coeff(j) * pow(2., -double((i + 1) * k)); 25 | } 26 | double err = out_j - in_j; 27 | double err_abs = fabs(err - rint(err)); 28 | ASSERT_LE(err_abs, err_bnd); 29 | } 30 | } 31 | delete_tnxdbl_approxdecomp_gadget(gadget); 32 | } 33 | } 34 | delete_rnx_module_info(module); 35 | } 36 | } 37 | 38 | TEST(vec_rnx, rnx_approxdecomp) { test_rnx_approxdecomp(rnx_approxdecomp_from_tnxdbl); } 39 | TEST(vec_rnx, rnx_approxdecomp_ref) { test_rnx_approxdecomp(rnx_approxdecomp_from_tnxdbl_ref); } 40 | #ifdef __x86_64__ 41 | TEST(vec_rnx, rnx_approxdecomp_avx) { test_rnx_approxdecomp(rnx_approxdecomp_from_tnxdbl_avx); } 42 | #endif 43 | -------------------------------------------------------------------------------- /test/spqlios_vec_rnx_ppol_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "spqlios/arithmetic/vec_rnx_arithmetic_private.h" 4 | #include "spqlios/reim/reim_fft.h" 5 | #include "test/testlib/vec_rnx_layout.h" 6 | 7 | static void test_vec_rnx_svp_prepare(RNX_SVP_PREPARE_F* rnx_svp_prepare, BYTES_OF_RNX_SVP_PPOL_F* tmp_bytes) { 8 | for (uint64_t n : {2, 4, 8, 64}) { 9 | MOD_RNX* mod = new_rnx_module_info(n, FFT64); 10 | const double invm = 1. 
/ mod->m; 11 | 12 | rnx_f64 in = rnx_f64::random_log2bound(n, 40); 13 | rnx_f64 in_divide_by_m = rnx_f64::zero(n); 14 | for (uint64_t i = 0; i < n; ++i) { 15 | in_divide_by_m.set_coeff(i, in.get_coeff(i) * invm); 16 | } 17 | fft64_rnx_svp_ppol_layout out(n); 18 | reim_fft64vec expect = simple_fft64(in_divide_by_m); 19 | rnx_svp_prepare(mod, out.data, in.data()); 20 | const double* ed = (double*)expect.data(); 21 | const double* ac = (double*)out.data; 22 | for (uint64_t i = 0; i < n; ++i) { 23 | ASSERT_LE(abs(ed[i] - ac[i]), 1e-10) << i << n; 24 | } 25 | delete_rnx_module_info(mod); 26 | } 27 | } 28 | TEST(vec_rnx, vec_rnx_svp_prepare) { test_vec_rnx_svp_prepare(rnx_svp_prepare, bytes_of_rnx_svp_ppol); } 29 | TEST(vec_rnx, vec_rnx_svp_prepare_ref) { 30 | test_vec_rnx_svp_prepare(fft64_rnx_svp_prepare_ref, fft64_bytes_of_rnx_svp_ppol); 31 | } 32 | 33 | static void test_vec_rnx_svp_apply(RNX_SVP_APPLY_F* apply) { 34 | for (uint64_t n : {2, 4, 8, 64, 128}) { 35 | MOD_RNX* mod = new_rnx_module_info(n, FFT64); 36 | 37 | // poly 1 to multiply - create and prepare 38 | fft64_rnx_svp_ppol_layout ppol(n); 39 | ppol.fill_random(1.); 40 | for (uint64_t sa : {3, 5, 8}) { 41 | for (uint64_t sr : {3, 5, 8}) { 42 | uint64_t a_sl = n + uniform_u64_bits(2); 43 | uint64_t r_sl = n + uniform_u64_bits(2); 44 | // poly 2 to multiply 45 | rnx_vec_f64_layout a(n, sa, a_sl); 46 | a.fill_random(19); 47 | 48 | // original operation result 49 | rnx_vec_f64_layout res(n, sr, r_sl); 50 | thash hash_a_before = a.content_hash(); 51 | thash hash_ppol_before = ppol.content_hash(); 52 | apply(mod, res.data(), sr, r_sl, ppol.data, a.data(), sa, a_sl); 53 | ASSERT_EQ(a.content_hash(), hash_a_before); 54 | ASSERT_EQ(ppol.content_hash(), hash_ppol_before); 55 | // create expected value 56 | reim_fft64vec ppo = ppol.get_copy(); 57 | std::vector expect(sr); 58 | for (uint64_t i = 0; i < sr; ++i) { 59 | expect[i] = simple_ifft64(ppo * simple_fft64(a.get_copy_zext(i))); 60 | } 61 | // this is the largest 
precision we can safely expect 62 | double prec_expect = n * pow(2., 19 - 50); 63 | for (uint64_t i = 0; i < sr; ++i) { 64 | rnx_f64 actual = res.get_copy_zext(i); 65 | ASSERT_LE(infty_dist(actual, expect[i]), prec_expect); 66 | } 67 | } 68 | } 69 | delete_rnx_module_info(mod); 70 | } 71 | } 72 | TEST(vec_rnx, vec_rnx_svp_apply) { test_vec_rnx_svp_apply(rnx_svp_apply); } 73 | TEST(vec_rnx, vec_rnx_svp_apply_ref) { test_vec_rnx_svp_apply(fft64_rnx_svp_apply_ref); } 74 | -------------------------------------------------------------------------------- /test/spqlios_zn_approxdecomp_test.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "spqlios/arithmetic/zn_arithmetic_private.h" 3 | #include "testlib/test_commons.h" 4 | 5 | template 6 | static void test_tndbl_approxdecomp( // 7 | void (*approxdec)(const MOD_Z*, const TNDBL_APPROXDECOMP_GADGET*, INTTYPE*, uint64_t, const double*, uint64_t) // 8 | ) { 9 | for (const uint64_t nn : {1, 3, 8, 51}) { 10 | MOD_Z* module = new_z_module_info(DEFAULT); 11 | for (const uint64_t ell : {1, 2, 7}) { 12 | for (const uint64_t k : {2, 5}) { 13 | TNDBL_APPROXDECOMP_GADGET* gadget = new_tndbl_approxdecomp_gadget(module, k, ell); 14 | for (const uint64_t res_size : {ell * nn}) { 15 | std::vector in(nn); 16 | std::vector out(res_size); 17 | for (double& x : in) x = uniform_f64_bounds(-10, 10); 18 | approxdec(module, gadget, out.data(), res_size, in.data(), nn); 19 | // reconstruct the output 20 | double err_bnd = pow(2., -double(ell * k) - 1); 21 | for (uint64_t j = 0; j < nn; ++j) { 22 | double in_j = in[j]; 23 | double out_j = 0; 24 | for (uint64_t i = 0; i < ell; ++i) { 25 | out_j += out[ell * j + i] * pow(2., -double((i + 1) * k)); 26 | } 27 | double err = out_j - in_j; 28 | double err_abs = fabs(err - rint(err)); 29 | ASSERT_LE(err_abs, err_bnd); 30 | } 31 | } 32 | delete_tndbl_approxdecomp_gadget(gadget); 33 | } 34 | } 35 | delete_z_module_info(module); 36 
| } 37 | } 38 | 39 | TEST(vec_rnx, i8_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(i8_approxdecomp_from_tndbl); } 40 | TEST(vec_rnx, default_i8_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(default_i8_approxdecomp_from_tndbl_ref); } 41 | 42 | TEST(vec_rnx, i16_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(i16_approxdecomp_from_tndbl); } 43 | TEST(vec_rnx, default_i16_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(default_i16_approxdecomp_from_tndbl_ref); } 44 | 45 | TEST(vec_rnx, i32_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(i32_approxdecomp_from_tndbl); } 46 | TEST(vec_rnx, default_i32_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(default_i32_approxdecomp_from_tndbl_ref); } 47 | -------------------------------------------------------------------------------- /test/spqlios_zn_conversions_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "testlib/test_commons.h" 5 | 6 | template 7 | static void test_conv(void (*conv_f)(const MOD_Z*, DST_T* res, uint64_t res_size, const SRC_T* a, uint64_t a_size), 8 | DST_T (*ideal_conv_f)(SRC_T x), SRC_T (*random_f)()) { 9 | MOD_Z* module = new_z_module_info(DEFAULT); 10 | for (uint64_t a_size : {0, 1, 2, 42}) { 11 | for (uint64_t res_size : {0, 1, 2, 42}) { 12 | for (uint64_t trials = 0; trials < 100; ++trials) { 13 | std::vector a(a_size); 14 | std::vector res(res_size); 15 | uint64_t msize = std::min(a_size, res_size); 16 | for (SRC_T& x : a) x = random_f(); 17 | conv_f(module, res.data(), res_size, a.data(), a_size); 18 | for (uint64_t i = 0; i < msize; ++i) { 19 | DST_T expect = ideal_conv_f(a[i]); 20 | DST_T actual = res[i]; 21 | ASSERT_EQ(expect, actual); 22 | } 23 | for (uint64_t i = msize; i < res_size; ++i) { 24 | DST_T expect = 0; 25 | SRC_T actual = res[i]; 26 | ASSERT_EQ(expect, actual); 27 | } 28 | } 29 | } 30 | } 31 | delete_z_module_info(module); 32 | } 33 | 34 | static int32_t ideal_dbl_to_tn32(double a) { 
35 | double _2p32 = INT64_C(1) << 32; 36 | double a_mod_1 = a - rint(a); 37 | int64_t t = rint(a_mod_1 * _2p32); 38 | return int32_t(t); 39 | } 40 | 41 | static double random_f64_10() { return uniform_f64_bounds(-10, 10); } 42 | 43 | static void test_dbl_to_tn32(DBL_TO_TN32_F dbl_to_tn32_f) { 44 | test_conv(dbl_to_tn32_f, ideal_dbl_to_tn32, random_f64_10); 45 | } 46 | 47 | TEST(zn_arithmetic, dbl_to_tn32) { test_dbl_to_tn32(dbl_to_tn32); } 48 | TEST(zn_arithmetic, dbl_to_tn32_ref) { test_dbl_to_tn32(dbl_to_tn32_ref); } 49 | 50 | static double ideal_tn32_to_dbl(int32_t a) { 51 | const double _2p32 = INT64_C(1) << 32; 52 | return double(a) / _2p32; 53 | } 54 | 55 | static int32_t random_t32() { return uniform_i64_bits(32); } 56 | 57 | static void test_tn32_to_dbl(TN32_TO_DBL_F tn32_to_dbl_f) { test_conv(tn32_to_dbl_f, ideal_tn32_to_dbl, random_t32); } 58 | 59 | TEST(zn_arithmetic, tn32_to_dbl) { test_tn32_to_dbl(tn32_to_dbl); } 60 | TEST(zn_arithmetic, tn32_to_dbl_ref) { test_tn32_to_dbl(tn32_to_dbl_ref); } 61 | 62 | static int32_t ideal_dbl_round_to_i32(double a) { return int32_t(rint(a)); } 63 | 64 | static double random_dbl_explaw_18() { return uniform_f64_bounds(-1., 1.) 
* pow(2., uniform_u64_bits(6) % 19); } 65 | 66 | static void test_dbl_round_to_i32(DBL_ROUND_TO_I32_F dbl_round_to_i32_f) { 67 | test_conv(dbl_round_to_i32_f, ideal_dbl_round_to_i32, random_dbl_explaw_18); 68 | } 69 | 70 | TEST(zn_arithmetic, dbl_round_to_i32) { test_dbl_round_to_i32(dbl_round_to_i32); } 71 | TEST(zn_arithmetic, dbl_round_to_i32_ref) { test_dbl_round_to_i32(dbl_round_to_i32_ref); } 72 | 73 | static double ideal_i32_to_dbl(int32_t a) { return double(a); } 74 | 75 | static int32_t random_i32_explaw_18() { return uniform_i64_bits(uniform_u64_bits(6) % 19); } 76 | 77 | static void test_i32_to_dbl(I32_TO_DBL_F i32_to_dbl_f) { 78 | test_conv(i32_to_dbl_f, ideal_i32_to_dbl, random_i32_explaw_18); 79 | } 80 | 81 | TEST(zn_arithmetic, i32_to_dbl) { test_i32_to_dbl(i32_to_dbl); } 82 | TEST(zn_arithmetic, i32_to_dbl_ref) { test_i32_to_dbl(i32_to_dbl_ref); } 83 | 84 | static int64_t ideal_dbl_round_to_i64(double a) { return rint(a); } 85 | 86 | static double random_dbl_explaw_50() { return uniform_f64_bounds(-1., 1.) 
* pow(2., uniform_u64_bits(7) % 51); } 87 | 88 | static void test_dbl_round_to_i64(DBL_ROUND_TO_I64_F dbl_round_to_i64_f) { 89 | test_conv(dbl_round_to_i64_f, ideal_dbl_round_to_i64, random_dbl_explaw_50); 90 | } 91 | 92 | TEST(zn_arithmetic, dbl_round_to_i64) { test_dbl_round_to_i64(dbl_round_to_i64); } 93 | TEST(zn_arithmetic, dbl_round_to_i64_ref) { test_dbl_round_to_i64(dbl_round_to_i64_ref); } 94 | 95 | static double ideal_i64_to_dbl(int64_t a) { return double(a); } 96 | 97 | static int64_t random_i64_explaw_50() { return uniform_i64_bits(uniform_u64_bits(7) % 51); } 98 | 99 | static void test_i64_to_dbl(I64_TO_DBL_F i64_to_dbl_f) { 100 | test_conv(i64_to_dbl_f, ideal_i64_to_dbl, random_i64_explaw_50); 101 | } 102 | 103 | TEST(zn_arithmetic, i64_to_dbl) { test_i64_to_dbl(i64_to_dbl); } 104 | TEST(zn_arithmetic, i64_to_dbl_ref) { test_i64_to_dbl(i64_to_dbl_ref); } 105 | -------------------------------------------------------------------------------- /test/spqlios_zn_vmp_test.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "spqlios/arithmetic/zn_arithmetic_private.h" 3 | #include "testlib/zn_layouts.h" 4 | 5 | static void test_zn_vmp_prepare(ZN32_VMP_PREPARE_CONTIGUOUS_F prep) { 6 | MOD_Z* module = new_z_module_info(DEFAULT); 7 | for (uint64_t nrows : {1, 2, 5, 15}) { 8 | for (uint64_t ncols : {1, 2, 32, 42, 67}) { 9 | std::vector src(nrows * ncols); 10 | zn32_pmat_layout out(nrows, ncols); 11 | for (int32_t& x : src) x = uniform_i64_bits(32); 12 | prep(module, out.data, src.data(), nrows, ncols); 13 | for (uint64_t i = 0; i < nrows; ++i) { 14 | for (uint64_t j = 0; j < ncols; ++j) { 15 | int32_t in = src[i * ncols + j]; 16 | int32_t actual = out.get(i, j); 17 | ASSERT_EQ(actual, in); 18 | } 19 | } 20 | } 21 | } 22 | delete_z_module_info(module); 23 | } 24 | 25 | TEST(zn, zn32_vmp_prepare_contiguous) { test_zn_vmp_prepare(zn32_vmp_prepare_contiguous); } 26 | TEST(zn, 
default_zn32_vmp_prepare_contiguous_ref) { test_zn_vmp_prepare(default_zn32_vmp_prepare_contiguous_ref); } 27 | 28 | template 29 | static void test_zn_vmp_apply(void (*apply)(const MOD_Z*, int32_t*, uint64_t, const INTTYPE*, uint64_t, 30 | const ZN32_VMP_PMAT*, uint64_t, uint64_t)) { 31 | MOD_Z* module = new_z_module_info(DEFAULT); 32 | for (uint64_t nrows : {1, 2, 5, 15}) { 33 | for (uint64_t ncols : {1, 2, 32, 42, 67}) { 34 | for (uint64_t a_size : {1, 2, 5, 15}) { 35 | for (uint64_t res_size : {1, 2, 32, 42, 67}) { 36 | std::vector a(a_size); 37 | zn32_pmat_layout out(nrows, ncols); 38 | std::vector res(res_size); 39 | for (INTTYPE& x : a) x = uniform_i64_bits(32); 40 | out.fill_random(); 41 | std::vector expect = vmp_product(a.data(), a_size, res_size, out); 42 | apply(module, res.data(), res_size, a.data(), a_size, out.data, nrows, ncols); 43 | for (uint64_t i = 0; i < res_size; ++i) { 44 | int32_t exp = expect[i]; 45 | int32_t actual = res[i]; 46 | ASSERT_EQ(actual, exp); 47 | } 48 | } 49 | } 50 | } 51 | } 52 | delete_z_module_info(module); 53 | } 54 | 55 | TEST(zn, zn32_vmp_apply_i32) { test_zn_vmp_apply(zn32_vmp_apply_i32); } 56 | TEST(zn, zn32_vmp_apply_i16) { test_zn_vmp_apply(zn32_vmp_apply_i16); } 57 | TEST(zn, zn32_vmp_apply_i8) { test_zn_vmp_apply(zn32_vmp_apply_i8); } 58 | 59 | TEST(zn, default_zn32_vmp_apply_i32_ref) { test_zn_vmp_apply(default_zn32_vmp_apply_i32_ref); } 60 | TEST(zn, default_zn32_vmp_apply_i16_ref) { test_zn_vmp_apply(default_zn32_vmp_apply_i16_ref); } 61 | TEST(zn, default_zn32_vmp_apply_i8_ref) { test_zn_vmp_apply(default_zn32_vmp_apply_i8_ref); } 62 | 63 | #ifdef __x86_64__ 64 | TEST(zn, default_zn32_vmp_apply_i32_avx) { test_zn_vmp_apply(default_zn32_vmp_apply_i32_avx); } 65 | TEST(zn, default_zn32_vmp_apply_i16_avx) { test_zn_vmp_apply(default_zn32_vmp_apply_i16_avx); } 66 | TEST(zn, default_zn32_vmp_apply_i8_avx) { test_zn_vmp_apply(default_zn32_vmp_apply_i8_avx); } 67 | #endif 68 | 
-------------------------------------------------------------------------------- /test/spqlios_znx_small_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "../spqlios/arithmetic/vec_znx_arithmetic_private.h" 4 | #include "testlib/negacyclic_polynomial.h" 5 | 6 | static void test_znx_small_single_product(ZNX_SMALL_SINGLE_PRODUCT_F product, 7 | ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F product_tmp_bytes) { 8 | for (const uint64_t nn : {2, 4, 8, 64}) { 9 | MODULE* module = new_module_info(nn, FFT64); 10 | znx_i64 a = znx_i64::random_log2bound(nn, 20); 11 | znx_i64 b = znx_i64::random_log2bound(nn, 20); 12 | znx_i64 expect = naive_product(a, b); 13 | znx_i64 actual(nn); 14 | std::vector tmp(znx_small_single_product_tmp_bytes(module)); 15 | fft64_znx_small_single_product(module, actual.data(), a.data(), b.data(), tmp.data()); 16 | ASSERT_EQ(actual, expect) << actual.get_coeff(0) << " vs. " << expect.get_coeff(0); 17 | delete_module_info(module); 18 | } 19 | } 20 | 21 | TEST(znx_small, fft64_znx_small_single_product) { 22 | test_znx_small_single_product(fft64_znx_small_single_product, fft64_znx_small_single_product_tmp_bytes); 23 | } 24 | TEST(znx_small, znx_small_single_product) { 25 | test_znx_small_single_product(znx_small_single_product, znx_small_single_product_tmp_bytes); 26 | } 27 | -------------------------------------------------------------------------------- /test/testlib/fft64_dft.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_FFT64_DFT_H 2 | #define SPQLIOS_FFT64_DFT_H 3 | 4 | #include "negacyclic_polynomial.h" 5 | #include "reim4_elem.h" 6 | 7 | class reim_fft64vec { 8 | std::vector v; 9 | 10 | public: 11 | reim_fft64vec() = default; 12 | explicit reim_fft64vec(uint64_t n); 13 | reim_fft64vec(uint64_t n, const double* data); 14 | uint64_t nn() const; 15 | static reim_fft64vec zero(uint64_t n); 16 | /** random complex coefficients 
(unstructured) */ 17 | static reim_fft64vec random(uint64_t n, double log2bound); 18 | /** random fft of a small int polynomial */ 19 | static reim_fft64vec dft_random(uint64_t n, uint64_t log2bound); 20 | double* data(); 21 | const double* data() const; 22 | void save_as(double* dest) const; 23 | reim4_elem get_blk(uint64_t blk) const; 24 | void set_blk(uint64_t blk, const reim4_elem& value); 25 | }; 26 | 27 | reim_fft64vec operator+(const reim_fft64vec& a, const reim_fft64vec& b); 28 | reim_fft64vec operator-(const reim_fft64vec& a, const reim_fft64vec& b); 29 | reim_fft64vec operator*(const reim_fft64vec& a, const reim_fft64vec& b); 30 | reim_fft64vec operator*(double coeff, const reim_fft64vec& v); 31 | reim_fft64vec& operator+=(reim_fft64vec& a, const reim_fft64vec& b); 32 | reim_fft64vec& operator-=(reim_fft64vec& a, const reim_fft64vec& b); 33 | 34 | /** infty distance */ 35 | double infty_dist(const reim_fft64vec& a, const reim_fft64vec& b); 36 | 37 | reim_fft64vec simple_fft64(const znx_i64& polynomial); 38 | znx_i64 simple_rint_ifft64(const reim_fft64vec& fftvec); 39 | rnx_f64 naive_ifft64(const reim_fft64vec& fftvec); 40 | reim_fft64vec simple_fft64(const rnx_f64& polynomial); 41 | rnx_f64 simple_ifft64(const reim_fft64vec& v); 42 | 43 | #endif // SPQLIOS_FFT64_DFT_H 44 | -------------------------------------------------------------------------------- /test/testlib/fft64_layouts.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_FFT64_LAYOUTS_H 2 | #define SPQLIOS_FFT64_LAYOUTS_H 3 | 4 | #include "../../spqlios/arithmetic/vec_znx_arithmetic.h" 5 | #include "fft64_dft.h" 6 | #include "negacyclic_polynomial.h" 7 | #include "reim4_elem.h" 8 | 9 | /** @brief test layout for the VEC_ZNX_DFT */ 10 | struct fft64_vec_znx_dft_layout { 11 | public: 12 | const uint64_t nn; 13 | const uint64_t size; 14 | VEC_ZNX_DFT* const data; 15 | reim_vector_view view; 16 | /** @brief fill with random double values 
(unstructured) */ 17 | void fill_random(double log2bound); 18 | /** @brief fill with random ffts of small int polynomials */ 19 | void fill_dft_random(uint64_t log2bound); 20 | reim4_elem get(uint64_t idx, uint64_t blk) const; 21 | reim4_elem get_zext(uint64_t idx, uint64_t blk) const; 22 | void set(uint64_t idx, uint64_t blk, const reim4_elem& value); 23 | fft64_vec_znx_dft_layout(uint64_t n, uint64_t size); 24 | void fill_random_log2bound(uint64_t bits); 25 | void fill_dft_random_log2bound(uint64_t bits); 26 | double* get_addr(uint64_t idx); 27 | const double* get_addr(uint64_t idx) const; 28 | reim_fft64vec get_copy_zext(uint64_t idx) const; 29 | void set(uint64_t idx, const reim_fft64vec& value); 30 | thash content_hash() const; 31 | ~fft64_vec_znx_dft_layout(); 32 | }; 33 | 34 | /** @brief test layout for the VEC_ZNX_BIG */ 35 | class fft64_vec_znx_big_layout { 36 | public: 37 | const uint64_t nn; 38 | const uint64_t size; 39 | VEC_ZNX_BIG* const data; 40 | fft64_vec_znx_big_layout(uint64_t n, uint64_t size); 41 | void fill_random(); 42 | znx_i64 get_copy(uint64_t index) const; 43 | znx_i64 get_copy_zext(uint64_t index) const; 44 | void set(uint64_t index, const znx_i64& value); 45 | thash content_hash() const; 46 | ~fft64_vec_znx_big_layout(); 47 | }; 48 | 49 | /** @brief test layout for the VMP_PMAT */ 50 | class fft64_vmp_pmat_layout { 51 | public: 52 | const uint64_t nn; 53 | const uint64_t nrows; 54 | const uint64_t ncols; 55 | VMP_PMAT* const data; 56 | fft64_vmp_pmat_layout(uint64_t n, uint64_t nrows, uint64_t ncols); 57 | double* get_addr(uint64_t row, uint64_t col, uint64_t blk) const; 58 | reim4_elem get(uint64_t row, uint64_t col, uint64_t blk) const; 59 | thash content_hash() const; 60 | reim4_elem get_zext(uint64_t row, uint64_t col, uint64_t blk) const; 61 | reim_fft64vec get_zext(uint64_t row, uint64_t col) const; 62 | void set(uint64_t row, uint64_t col, uint64_t blk, const reim4_elem& v) const; 63 | void set(uint64_t row, uint64_t col, const 
reim_fft64vec& value); 64 | /** @brief fill with random double values (unstructured) */ 65 | void fill_random(double log2bound); 66 | /** @brief fill with random ffts of small int polynomials */ 67 | void fill_dft_random(uint64_t log2bound); 68 | ~fft64_vmp_pmat_layout(); 69 | }; 70 | 71 | /** @brief test layout for the SVP_PPOL */ 72 | class fft64_svp_ppol_layout { 73 | public: 74 | const uint64_t nn; 75 | SVP_PPOL* const data; 76 | fft64_svp_ppol_layout(uint64_t n); 77 | thash content_hash() const; 78 | reim_fft64vec get_copy() const; 79 | void set(const reim_fft64vec&); 80 | /** @brief fill with random double values (unstructured) */ 81 | void fill_random(double log2bound); 82 | /** @brief fill with random ffts of small int polynomials */ 83 | void fill_dft_random(uint64_t log2bound); 84 | ~fft64_svp_ppol_layout(); 85 | }; 86 | 87 | /** @brief test layout for the CNV_PVEC_L */ 88 | class fft64_cnv_left_layout { 89 | const uint64_t nn; 90 | const uint64_t size; 91 | CNV_PVEC_L* const data; 92 | fft64_cnv_left_layout(uint64_t n, uint64_t size); 93 | reim4_elem get(uint64_t idx, uint64_t blk); 94 | thash content_hash() const; 95 | ~fft64_cnv_left_layout(); 96 | }; 97 | 98 | /** @brief test layout for the CNV_PVEC_R */ 99 | class fft64_cnv_right_layout { 100 | const uint64_t nn; 101 | const uint64_t size; 102 | CNV_PVEC_R* const data; 103 | fft64_cnv_right_layout(uint64_t n, uint64_t size); 104 | reim4_elem get(uint64_t idx, uint64_t blk); 105 | thash content_hash() const; 106 | ~fft64_cnv_right_layout(); 107 | }; 108 | 109 | #endif // SPQLIOS_FFT64_LAYOUTS_H 110 | -------------------------------------------------------------------------------- /test/testlib/mod_q120.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_MOD_Q120_H 2 | #define SPQLIOS_MOD_Q120_H 3 | 4 | #include 5 | 6 | #include "../../spqlios/q120/q120_common.h" 7 | #include "test_commons.h" 8 | 9 | /** @brief centered modulo q */ 10 | int64_t 
centermod(int64_t v, int64_t q); 11 | int64_t centermod(uint64_t v, int64_t q); 12 | 13 | /** @brief this class represents an integer mod Q120 */ 14 | class mod_q120 { 15 | public: 16 | static constexpr int64_t Qi[] = {Q1, Q2, Q3, Q4}; 17 | int64_t a[4]; 18 | mod_q120(int64_t a1, int64_t a2, int64_t a3, int64_t a4); 19 | mod_q120(); 20 | __int128_t to_int128() const; 21 | static mod_q120 from_q120a(const void* addr); 22 | static mod_q120 from_q120b(const void* addr); 23 | static mod_q120 from_q120c(const void* addr); 24 | void save_as_q120a(void* dest) const; 25 | void save_as_q120b(void* dest) const; 26 | void save_as_q120c(void* dest) const; 27 | }; 28 | 29 | mod_q120 operator+(const mod_q120& x, const mod_q120& y); 30 | mod_q120 operator-(const mod_q120& x, const mod_q120& y); 31 | mod_q120 operator*(const mod_q120& x, const mod_q120& y); 32 | mod_q120& operator+=(mod_q120& x, const mod_q120& y); 33 | mod_q120& operator-=(mod_q120& x, const mod_q120& y); 34 | mod_q120& operator*=(mod_q120& x, const mod_q120& y); 35 | std::ostream& operator<<(std::ostream& out, const mod_q120& x); 36 | bool operator==(const mod_q120& x, const mod_q120& y); 37 | mod_q120 pow(const mod_q120& x, int32_t k); 38 | mod_q120 half(const mod_q120& x); 39 | 40 | /** @brief a uniformly drawn number mod Q120 */ 41 | mod_q120 uniform_q120(); 42 | /** @brief a uniformly random mod Q120 layout A (4 integers < 2^32) */ 43 | void uniform_q120a(void* dest); 44 | /** @brief a uniformly random mod Q120 layout B (4 integers < 2^64) */ 45 | void uniform_q120b(void* dest); 46 | /** @brief a uniformly random mod Q120 layout C (4 integers repr. 
x,2^32x) */ 47 | void uniform_q120c(void* dest); 48 | 49 | #endif // SPQLIOS_MOD_Q120_H 50 | -------------------------------------------------------------------------------- /test/testlib/negacyclic_polynomial.cpp: -------------------------------------------------------------------------------- 1 | #include "negacyclic_polynomial_impl.h" 2 | 3 | // explicit instantiation 4 | EXPLICIT_INSTANTIATE_POLYNOMIAL(__int128_t); 5 | EXPLICIT_INSTANTIATE_POLYNOMIAL(int64_t); 6 | EXPLICIT_INSTANTIATE_POLYNOMIAL(double); 7 | 8 | double infty_dist(const rnx_f64& a, const rnx_f64& b) { 9 | const uint64_t nn = a.nn(); 10 | const double* aa = a.data(); 11 | const double* bb = b.data(); 12 | double res = 0.; 13 | for (uint64_t i = 0; i < nn; ++i) { 14 | double d = fabs(aa[i] - bb[i]); 15 | if (d > res) res = d; 16 | } 17 | return res; 18 | } 19 | -------------------------------------------------------------------------------- /test/testlib/negacyclic_polynomial.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_NEGACYCLIC_POLYNOMIAL_H 2 | #define SPQLIOS_NEGACYCLIC_POLYNOMIAL_H 3 | 4 | #include 5 | 6 | #include "test_commons.h" 7 | 8 | template 9 | class polynomial; 10 | typedef polynomial<__int128_t> znx_i128; 11 | typedef polynomial znx_i64; 12 | typedef polynomial rnx_f64; 13 | 14 | template 15 | class polynomial { 16 | public: 17 | std::vector coeffs; 18 | /** @brief create a polynomial out of existing coeffs */ 19 | polynomial(uint64_t N, const T* c); 20 | /** @brief zero polynomial of dimension N */ 21 | explicit polynomial(uint64_t N); 22 | /** @brief empty polynomial (dim 0) */ 23 | polynomial(); 24 | 25 | /** @brief ring dimension */ 26 | uint64_t nn() const; 27 | /** @brief special setter (accept any indexes, and does the negacyclic translation) */ 28 | void set_coeff(int64_t i, T v); 29 | /** @brief special getter (accept any indexes, and does the negacyclic translation) */ 30 | T get_coeff(int64_t i) const; 31 | /** 
@brief returns the coefficient layout */ 32 | T* data(); 33 | /** @brief returns the coefficient layout (const version) */ 34 | const T* data() const; 35 | /** @brief saves to the layout */ 36 | void save_as(T* dest) const; 37 | /** @brief zero */ 38 | static polynomial zero(uint64_t n); 39 | /** @brief random polynomial with coefficients in [-2^log2bounds, 2^log2bounds]*/ 40 | static polynomial random_log2bound(uint64_t n, uint64_t log2bound); 41 | /** @brief random polynomial with coefficients in [-2^log2bounds, 2^log2bounds]*/ 42 | static polynomial random(uint64_t n); 43 | /** @brief random polynomial with coefficient in [lb;ub] */ 44 | static polynomial random_bound(uint64_t n, const T lb, const T ub); 45 | }; 46 | 47 | /** @brief equality operator (used during tests) */ 48 | template 49 | bool operator==(const polynomial& a, const polynomial& b); 50 | 51 | /** @brief addition operator (used during tests) */ 52 | template 53 | polynomial operator+(const polynomial& a, const polynomial& b); 54 | 55 | /** @brief subtraction operator (used during tests) */ 56 | template 57 | polynomial operator-(const polynomial& a, const polynomial& b); 58 | 59 | /** @brief negation operator (used during tests) */ 60 | template 61 | polynomial operator-(const polynomial& a); 62 | 63 | template 64 | polynomial naive_product(const polynomial& a, const polynomial& b); 65 | 66 | /** @brief distance between two real polynomials (used during tests) */ 67 | double infty_dist(const rnx_f64& a, const rnx_f64& b); 68 | 69 | #endif // SPQLIOS_NEGACYCLIC_POLYNOMIAL_H 70 | -------------------------------------------------------------------------------- /test/testlib/ntt120_dft.cpp: -------------------------------------------------------------------------------- 1 | #include "ntt120_dft.h" 2 | 3 | #include "mod_q120.h" 4 | 5 | // @brief alternative version of the NTT 6 | 7 | /** for all s=k/2^17, root_of_unity(s) = omega_0^k */ 8 | static mod_q120 root_of_unity(double s) { 9 | static mod_q120 
omega_2pow17{OMEGA1, OMEGA2, OMEGA3, OMEGA4}; 10 | static double _2pow17 = 1 << 17; 11 | return pow(omega_2pow17, s * _2pow17); 12 | } 13 | static mod_q120 root_of_unity_inv(double s) { 14 | static mod_q120 omega_2pow17{OMEGA1, OMEGA2, OMEGA3, OMEGA4}; 15 | static double _2pow17 = 1 << 17; 16 | return pow(omega_2pow17, -s * _2pow17); 17 | } 18 | 19 | /** recursive naive ntt */ 20 | static void q120_ntt_naive_rec(uint64_t n, double entry_pwr, mod_q120* data) { 21 | if (n == 1) return; 22 | const uint64_t h = n / 2; 23 | const double s = entry_pwr / 2.; 24 | mod_q120 om = root_of_unity(s); 25 | for (uint64_t j = 0; j < h; ++j) { 26 | mod_q120 om_right = data[h + j] * om; 27 | data[h + j] = data[j] - om_right; 28 | data[j] = data[j] + om_right; 29 | } 30 | q120_ntt_naive_rec(h, s, data); 31 | q120_ntt_naive_rec(h, s + 0.5, data + h); 32 | } 33 | static void q120_intt_naive_rec(uint64_t n, double entry_pwr, mod_q120* data) { 34 | if (n == 1) return; 35 | const uint64_t h = n / 2; 36 | const double s = entry_pwr / 2.; 37 | q120_intt_naive_rec(h, s, data); 38 | q120_intt_naive_rec(h, s + 0.5, data + h); 39 | mod_q120 om = root_of_unity_inv(s); 40 | for (uint64_t j = 0; j < h; ++j) { 41 | mod_q120 dat_diff = half(data[j] - data[h + j]); 42 | data[j] = half(data[j] + data[h + j]); 43 | data[h + j] = dat_diff * om; 44 | } 45 | } 46 | 47 | /** user friendly version */ 48 | q120_nttvec simple_ntt120(const znx_i64& polynomial) { 49 | const uint64_t n = polynomial.nn(); 50 | q120_nttvec res(n); 51 | for (uint64_t i = 0; i < n; ++i) { 52 | int64_t xi = polynomial.get_coeff(i); 53 | res.v[i] = mod_q120(xi, xi, xi, xi); 54 | } 55 | q120_ntt_naive_rec(n, 0.5, res.v.data()); 56 | return res; 57 | } 58 | 59 | znx_i128 simple_intt120(const q120_nttvec& fftvec) { 60 | const uint64_t n = fftvec.nn(); 61 | q120_nttvec copy = fftvec; 62 | znx_i128 res(n); 63 | q120_intt_naive_rec(n, 0.5, copy.v.data()); 64 | for (uint64_t i = 0; i < n; ++i) { 65 | res.set_coeff(i, copy.v[i].to_int128()); 
66 | } 67 | return res; 68 | } 69 | bool operator==(const q120_nttvec& a, const q120_nttvec& b) { return a.v == b.v; } 70 | 71 | std::vector q120_ntt_naive(const std::vector& x) { 72 | std::vector res = x; 73 | q120_ntt_naive_rec(res.size(), 0.5, res.data()); 74 | return res; 75 | } 76 | q120_nttvec::q120_nttvec(uint64_t n) : v(n) {} 77 | q120_nttvec::q120_nttvec(uint64_t n, const q120b* data) : v(n) { 78 | int64_t* d = (int64_t*)data; 79 | for (uint64_t i = 0; i < n; ++i) { 80 | v[i] = mod_q120::from_q120b(d + 4 * i); 81 | } 82 | } 83 | q120_nttvec::q120_nttvec(uint64_t n, const q120c* data) : v(n) { 84 | int64_t* d = (int64_t*)data; 85 | for (uint64_t i = 0; i < n; ++i) { 86 | v[i] = mod_q120::from_q120c(d + 4 * i); 87 | } 88 | } 89 | uint64_t q120_nttvec::nn() const { return v.size(); } 90 | q120_nttvec q120_nttvec::zero(uint64_t n) { return q120_nttvec(n); } 91 | void q120_nttvec::save_as(q120a* dest) const { 92 | int64_t* const d = (int64_t*)dest; 93 | const uint64_t n = nn(); 94 | for (uint64_t i = 0; i < n; ++i) { 95 | v[i].save_as_q120a(d + 4 * i); 96 | } 97 | } 98 | void q120_nttvec::save_as(q120b* dest) const { 99 | int64_t* const d = (int64_t*)dest; 100 | const uint64_t n = nn(); 101 | for (uint64_t i = 0; i < n; ++i) { 102 | v[i].save_as_q120b(d + 4 * i); 103 | } 104 | } 105 | void q120_nttvec::save_as(q120c* dest) const { 106 | int64_t* const d = (int64_t*)dest; 107 | const uint64_t n = nn(); 108 | for (uint64_t i = 0; i < n; ++i) { 109 | v[i].save_as_q120c(d + 4 * i); 110 | } 111 | } 112 | mod_q120 q120_nttvec::get_blk(uint64_t blk) const { 113 | REQUIRE_DRAMATICALLY(blk < nn(), "blk overflow"); 114 | return v[blk]; 115 | } 116 | q120_nttvec q120_nttvec::random(uint64_t n) { 117 | q120_nttvec res(n); 118 | for (uint64_t i = 0; i < n; ++i) { 119 | res.v[i] = uniform_q120(); 120 | } 121 | return res; 122 | } 123 | -------------------------------------------------------------------------------- /test/testlib/ntt120_dft.h: 
-------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_NTT120_DFT_H 2 | #define SPQLIOS_NTT120_DFT_H 3 | 4 | #include 5 | 6 | #include "../../spqlios/q120/q120_arithmetic.h" 7 | #include "mod_q120.h" 8 | #include "negacyclic_polynomial.h" 9 | #include "test_commons.h" 10 | 11 | class q120_nttvec { 12 | public: 13 | std::vector v; 14 | q120_nttvec() = default; 15 | explicit q120_nttvec(uint64_t n); 16 | q120_nttvec(uint64_t n, const q120b* data); 17 | q120_nttvec(uint64_t n, const q120c* data); 18 | uint64_t nn() const; 19 | static q120_nttvec zero(uint64_t n); 20 | static q120_nttvec random(uint64_t n); 21 | void save_as(q120a* dest) const; 22 | void save_as(q120b* dest) const; 23 | void save_as(q120c* dest) const; 24 | mod_q120 get_blk(uint64_t blk) const; 25 | }; 26 | 27 | q120_nttvec simple_ntt120(const znx_i64& polynomial); 28 | znx_i128 simple_intt120(const q120_nttvec& fftvec); 29 | bool operator==(const q120_nttvec& a, const q120_nttvec& b); 30 | 31 | #endif // SPQLIOS_NTT120_DFT_H 32 | -------------------------------------------------------------------------------- /test/testlib/ntt120_layouts.cpp: -------------------------------------------------------------------------------- 1 | #include "ntt120_layouts.h" 2 | 3 | mod_q120x2::mod_q120x2() {} 4 | mod_q120x2::mod_q120x2(const mod_q120& a, const mod_q120& b) { 5 | value[0] = a; 6 | value[1] = b; 7 | } 8 | mod_q120x2::mod_q120x2(q120x2b* addr) { 9 | uint64_t* p = (uint64_t*)addr; 10 | value[0] = mod_q120::from_q120b(p); 11 | value[1] = mod_q120::from_q120b(p + 4); 12 | } 13 | 14 | ntt120_vec_znx_dft_layout::ntt120_vec_znx_dft_layout(uint64_t n, uint64_t size) 15 | : nn(n), // 16 | size(size), // 17 | data((VEC_ZNX_DFT*)alloc64(n * size * 4 * sizeof(uint64_t))) {} 18 | 19 | mod_q120x2 ntt120_vec_znx_dft_layout::get_copy_zext(uint64_t idx, uint64_t blk) { 20 | return mod_q120x2(get_blk(idx, blk)); 21 | } 22 | q120x2b* ntt120_vec_znx_dft_layout::get_blk(uint64_t 
idx, uint64_t blk) { 23 | REQUIRE_DRAMATICALLY(idx < size, "idx overflow"); 24 | REQUIRE_DRAMATICALLY(blk < nn / 2, "blk overflow"); 25 | uint64_t* d = (uint64_t*)data; 26 | return (q120x2b*)(d + 4 * nn * idx + 8 * blk); 27 | } 28 | ntt120_vec_znx_dft_layout::~ntt120_vec_znx_dft_layout() { spqlios_free(data); } 29 | q120_nttvec ntt120_vec_znx_dft_layout::get_copy_zext(uint64_t idx) { 30 | int64_t* d = (int64_t*)data; 31 | if (idx < size) { 32 | return q120_nttvec(nn, (q120b*)(d + idx * nn * 4)); 33 | } else { 34 | return q120_nttvec::zero(nn); 35 | } 36 | } 37 | void ntt120_vec_znx_dft_layout::set(uint64_t idx, const q120_nttvec& value) { 38 | REQUIRE_DRAMATICALLY(idx < size, "index overflow: " << idx << " / " << size); 39 | q120b* dest_addr = (q120b*)((int64_t*)data + idx * nn * 4); 40 | value.save_as(dest_addr); 41 | } 42 | void ntt120_vec_znx_dft_layout::fill_random() { 43 | for (uint64_t i = 0; i < size; ++i) { 44 | set(i, q120_nttvec::random(nn)); 45 | } 46 | } 47 | thash ntt120_vec_znx_dft_layout::content_hash() const { return test_hash(data, nn * size * 4 * sizeof(int64_t)); } 48 | ntt120_vec_znx_big_layout::ntt120_vec_znx_big_layout(uint64_t n, uint64_t size) 49 | : nn(n), // 50 | size(size), 51 | data((VEC_ZNX_BIG*)alloc64(n * size * sizeof(__int128_t))) {} 52 | 53 | znx_i128 ntt120_vec_znx_big_layout::get_copy(uint64_t index) const { return znx_i128(nn, get_addr(index)); } 54 | znx_i128 ntt120_vec_znx_big_layout::get_copy_zext(uint64_t index) const { 55 | if (index < size) { 56 | return znx_i128(nn, get_addr(index)); 57 | } else { 58 | return znx_i128::zero(nn); 59 | } 60 | } 61 | __int128* ntt120_vec_znx_big_layout::get_addr(uint64_t index) const { 62 | REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); 63 | return (__int128_t*)data + index * nn; 64 | } 65 | void ntt120_vec_znx_big_layout::set(uint64_t index, const znx_i128& value) { value.save_as(get_addr(index)); } 66 | 
ntt120_vec_znx_big_layout::~ntt120_vec_znx_big_layout() { spqlios_free(data); } 67 | -------------------------------------------------------------------------------- /test/testlib/ntt120_layouts.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_NTT120_LAYOUTS_H 2 | #define SPQLIOS_NTT120_LAYOUTS_H 3 | 4 | #include "../../spqlios/arithmetic/vec_znx_arithmetic.h" 5 | #include "mod_q120.h" 6 | #include "negacyclic_polynomial.h" 7 | #include "ntt120_dft.h" 8 | #include "test_commons.h" 9 | 10 | struct q120b_vector_view {}; 11 | 12 | struct mod_q120x2 { 13 | mod_q120 value[2]; 14 | mod_q120x2(); 15 | mod_q120x2(const mod_q120& a, const mod_q120& b); 16 | mod_q120x2(__int128_t value); 17 | explicit mod_q120x2(q120x2b* addr); 18 | explicit mod_q120x2(q120x2c* addr); 19 | void save_as(q120x2b* addr) const; 20 | void save_as(q120x2c* addr) const; 21 | static mod_q120x2 random(); 22 | }; 23 | mod_q120x2 operator+(const mod_q120x2& a, const mod_q120x2& b); 24 | mod_q120x2 operator-(const mod_q120x2& a, const mod_q120x2& b); 25 | mod_q120x2 operator*(const mod_q120x2& a, const mod_q120x2& b); 26 | bool operator==(const mod_q120x2& a, const mod_q120x2& b); 27 | bool operator!=(const mod_q120x2& a, const mod_q120x2& b); 28 | mod_q120x2& operator+=(mod_q120x2& a, const mod_q120x2& b); 29 | mod_q120x2& operator-=(mod_q120x2& a, const mod_q120x2& b); 30 | 31 | /** @brief test layout for the VEC_ZNX_DFT */ 32 | struct ntt120_vec_znx_dft_layout { 33 | const uint64_t nn; 34 | const uint64_t size; 35 | VEC_ZNX_DFT* const data; 36 | ntt120_vec_znx_dft_layout(uint64_t n, uint64_t size); 37 | mod_q120x2 get_copy_zext(uint64_t idx, uint64_t blk); 38 | q120_nttvec get_copy_zext(uint64_t idx); 39 | void set(uint64_t idx, const q120_nttvec& v); 40 | q120x2b* get_blk(uint64_t idx, uint64_t blk); 41 | thash content_hash() const; 42 | void fill_random(); 43 | ~ntt120_vec_znx_dft_layout(); 44 | }; 45 | 46 | /** @brief test layout for the 
VEC_ZNX_BIG */ 47 | class ntt120_vec_znx_big_layout { 48 | public: 49 | const uint64_t nn; 50 | const uint64_t size; 51 | VEC_ZNX_BIG* const data; 52 | ntt120_vec_znx_big_layout(uint64_t n, uint64_t size); 53 | 54 | private: 55 | __int128* get_addr(uint64_t index) const; 56 | 57 | public: 58 | znx_i128 get_copy(uint64_t index) const; 59 | znx_i128 get_copy_zext(uint64_t index) const; 60 | void set(uint64_t index, const znx_i128& value); 61 | ~ntt120_vec_znx_big_layout(); 62 | }; 63 | 64 | /** @brief test layout for the VMP_PMAT */ 65 | class ntt120_vmp_pmat_layout { 66 | const uint64_t nn; 67 | const uint64_t nrows; 68 | const uint64_t ncols; 69 | VMP_PMAT* const data; 70 | ntt120_vmp_pmat_layout(uint64_t n, uint64_t nrows, uint64_t ncols); 71 | mod_q120x2 get(uint64_t row, uint64_t col, uint64_t blk) const; 72 | ~ntt120_vmp_pmat_layout(); 73 | }; 74 | 75 | /** @brief test layout for the SVP_PPOL */ 76 | class ntt120_svp_ppol_layout { 77 | const uint64_t nn; 78 | SVP_PPOL* const data; 79 | ntt120_svp_ppol_layout(uint64_t n); 80 | ~ntt120_svp_ppol_layout(); 81 | }; 82 | 83 | /** @brief test layout for the CNV_PVEC_L */ 84 | class ntt120_cnv_left_layout { 85 | const uint64_t nn; 86 | const uint64_t size; 87 | CNV_PVEC_L* const data; 88 | ntt120_cnv_left_layout(uint64_t n, uint64_t size); 89 | mod_q120x2 get(uint64_t idx, uint64_t blk); 90 | ~ntt120_cnv_left_layout(); 91 | }; 92 | 93 | /** @brief test layout for the CNV_PVEC_R */ 94 | class ntt120_cnv_right_layout { 95 | const uint64_t nn; 96 | const uint64_t size; 97 | CNV_PVEC_R* const data; 98 | ntt120_cnv_right_layout(uint64_t n, uint64_t size); 99 | mod_q120x2 get(uint64_t idx, uint64_t blk); 100 | ~ntt120_cnv_right_layout(); 101 | }; 102 | 103 | #endif // SPQLIOS_NTT120_LAYOUTS_H 104 | -------------------------------------------------------------------------------- /test/testlib/polynomial_vector.cpp: -------------------------------------------------------------------------------- 1 | #include 
"polynomial_vector.h" 2 | 3 | #include 4 | 5 | #ifdef VALGRIND_MEM_TESTS 6 | #include "valgrind/memcheck.h" 7 | #endif 8 | 9 | #define CANARY_PADDING (1024) 10 | #define GARBAGE_VALUE (242) 11 | 12 | znx_vec_i64_layout::znx_vec_i64_layout(uint64_t n, uint64_t size, uint64_t slice) : n(n), size(size), slice(slice) { 13 | REQUIRE_DRAMATICALLY(is_pow2(n), "not a power of 2" << n); 14 | REQUIRE_DRAMATICALLY(slice >= n, "slice too small" << slice << " < " << n); 15 | this->region = (uint8_t*)malloc(size * slice * sizeof(int64_t) + 2 * CANARY_PADDING); 16 | this->data_start = (int64_t*)(region + CANARY_PADDING); 17 | // ensure that any invalid value is kind-of garbage 18 | memset(region, GARBAGE_VALUE, size * slice * sizeof(int64_t) + 2 * CANARY_PADDING); 19 | // mark inter-slice memory as non accessible 20 | #ifdef VALGRIND_MEM_TESTS 21 | VALGRIND_MAKE_MEM_NOACCESS(region, CANARY_PADDING); 22 | VALGRIND_MAKE_MEM_NOACCESS(region + size * slice * sizeof(int64_t) + CANARY_PADDING, CANARY_PADDING); 23 | for (uint64_t i = 0; i < size; ++i) { 24 | VALGRIND_MAKE_MEM_UNDEFINED(data_start + i * slice, n * sizeof(int64_t)); 25 | } 26 | if (size != slice) { 27 | for (uint64_t i = 0; i < size; ++i) { 28 | VALGRIND_MAKE_MEM_NOACCESS(data_start + i * slice + n, (slice - n) * sizeof(int64_t)); 29 | } 30 | } 31 | #endif 32 | } 33 | 34 | znx_vec_i64_layout::~znx_vec_i64_layout() { free(region); } 35 | 36 | znx_i64 znx_vec_i64_layout::get_copy_zext(uint64_t index) const { 37 | if (index < size) { 38 | return znx_i64(n, data_start + index * slice); 39 | } else { 40 | return znx_i64::zero(n); 41 | } 42 | } 43 | 44 | znx_i64 znx_vec_i64_layout::get_copy(uint64_t index) const { 45 | REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); 46 | return znx_i64(n, data_start + index * slice); 47 | } 48 | 49 | void znx_vec_i64_layout::set(uint64_t index, const znx_i64& elem) { 50 | REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); 51 | 
REQUIRE_DRAMATICALLY(elem.nn() == n, "incompatible ring dimensions: " << elem.nn() << " / " << n); 52 | elem.save_as(data_start + index * slice); 53 | } 54 | 55 | int64_t* znx_vec_i64_layout::data() { return data_start; } 56 | const int64_t* znx_vec_i64_layout::data() const { return data_start; } 57 | 58 | void znx_vec_i64_layout::fill_random(uint64_t bits) { 59 | for (uint64_t i = 0; i < size; ++i) { 60 | set(i, znx_i64::random_log2bound(n, bits)); 61 | } 62 | } 63 | __uint128_t znx_vec_i64_layout::content_hash() const { 64 | test_hasher hasher; 65 | for (uint64_t i = 0; i < size; ++i) { 66 | hasher.update(data() + i * slice, n * sizeof(int64_t)); 67 | } 68 | return hasher.hash(); 69 | } 70 | -------------------------------------------------------------------------------- /test/testlib/polynomial_vector.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_POLYNOMIAL_VECTOR_H 2 | #define SPQLIOS_POLYNOMIAL_VECTOR_H 3 | 4 | #include "negacyclic_polynomial.h" 5 | #include "test_commons.h" 6 | 7 | /** @brief a test memory layout for znx i64 polynomials vectors */ 8 | class znx_vec_i64_layout { 9 | uint64_t n; 10 | uint64_t size; 11 | uint64_t slice; 12 | int64_t* data_start; 13 | uint8_t* region; 14 | 15 | public: 16 | // NO-COPY structure 17 | znx_vec_i64_layout(const znx_vec_i64_layout&) = delete; 18 | void operator=(const znx_vec_i64_layout&) = delete; 19 | znx_vec_i64_layout(znx_vec_i64_layout&&) = delete; 20 | void operator=(znx_vec_i64_layout&&) = delete; 21 | /** @brief initialises a memory layout */ 22 | znx_vec_i64_layout(uint64_t n, uint64_t size, uint64_t slice); 23 | /** @brief destructor */ 24 | ~znx_vec_i64_layout(); 25 | 26 | /** @brief get a copy of item index index (extended with zeros) */ 27 | znx_i64 get_copy_zext(uint64_t index) const; 28 | /** @brief get a copy of item index index (extended with zeros) */ 29 | znx_i64 get_copy(uint64_t index) const; 30 | /** @brief get a copy of item index index 
(index 2 | #include 3 | 4 | #include "test_commons.h" 5 | 6 | bool is_pow2(uint64_t n) { return !(n & (n - 1)); } 7 | 8 | test_rng& randgen() { 9 | static test_rng gen; 10 | return gen; 11 | } 12 | uint64_t uniform_u64() { 13 | static std::uniform_int_distribution dist64(0, UINT64_MAX); 14 | return dist64(randgen()); 15 | } 16 | 17 | uint64_t uniform_u64_bits(uint64_t nbits) { 18 | if (nbits >= 64) return uniform_u64(); 19 | return uniform_u64() >> (64 - nbits); 20 | } 21 | 22 | int64_t uniform_i64() { 23 | std::uniform_int_distribution dist; 24 | return dist(randgen()); 25 | } 26 | 27 | int64_t uniform_i64_bits(uint64_t nbits) { 28 | int64_t bound = int64_t(1) << nbits; 29 | std::uniform_int_distribution dist(-bound, bound); 30 | return dist(randgen()); 31 | } 32 | 33 | int64_t uniform_i64_bounds(const int64_t lb, const int64_t ub) { 34 | std::uniform_int_distribution dist(lb, ub); 35 | return dist(randgen()); 36 | } 37 | 38 | __int128_t uniform_i128_bounds(const __int128_t lb, const __int128_t ub) { 39 | std::uniform_int_distribution<__int128_t> dist(lb, ub); 40 | return dist(randgen()); 41 | } 42 | 43 | double random_f64_gaussian(double stdev) { 44 | std::normal_distribution dist(0, stdev); 45 | return dist(randgen()); 46 | } 47 | 48 | double uniform_f64_bounds(const double lb, const double ub) { 49 | std::uniform_real_distribution dist(lb, ub); 50 | return dist(randgen()); 51 | } 52 | 53 | double uniform_f64_01() { 54 | return uniform_f64_bounds(0, 1); 55 | } 56 | -------------------------------------------------------------------------------- /test/testlib/reim4_elem.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_REIM4_ELEM_H 2 | #define SPQLIOS_REIM4_ELEM_H 3 | 4 | #include "test_commons.h" 5 | 6 | /** @brief test class representing one single reim4 element */ 7 | class reim4_elem { 8 | public: 9 | /** @brief 8 components (4 real parts followed by 4 imag parts) */ 10 | double value[8]; 11 | /** @brief 
constructs from 4 real parts and 4 imaginary parts */ 12 | reim4_elem(const double* re, const double* im); 13 | /** @brief constructs from 8 components */ 14 | explicit reim4_elem(const double* layout); 15 | /** @brief zero */ 16 | reim4_elem(); 17 | /** @brief saves the real parts to re and the 4 imag to im */ 18 | void save_re_im(double* re, double* im) const; 19 | /** @brief saves the 8 components to reim4 */ 20 | void save_as(double* reim4) const; 21 | static reim4_elem zero(); 22 | }; 23 | 24 | /** @brief checks for equality */ 25 | bool operator==(const reim4_elem& x, const reim4_elem& y); 26 | /** @brief random gaussian reim4 of stdev 1 and mean 0 */ 27 | reim4_elem gaussian_reim4(); 28 | /** @brief addition */ 29 | reim4_elem operator+(const reim4_elem& x, const reim4_elem& y); 30 | reim4_elem& operator+=(reim4_elem& x, const reim4_elem& y); 31 | /** @brief subtraction */ 32 | reim4_elem operator-(const reim4_elem& x, const reim4_elem& y); 33 | reim4_elem& operator-=(reim4_elem& x, const reim4_elem& y); 34 | /** @brief product */ 35 | reim4_elem operator*(const reim4_elem& x, const reim4_elem& y); 36 | std::ostream& operator<<(std::ostream& out, const reim4_elem& x); 37 | /** @brief distance in infty norm */ 38 | double infty_dist(const reim4_elem& x, const reim4_elem& y); 39 | 40 | /** @brief test class representing the view of one reim of m complexes */ 41 | class reim4_array_view { 42 | uint64_t size; ///< size of the reim array 43 | double* data; ///< pointer to the start of the array 44 | public: 45 | /** @brief ininitializes a view at an existing given address */ 46 | reim4_array_view(uint64_t size, double* data); 47 | ; 48 | /** @brief gets the i-th element */ 49 | reim4_elem get(uint64_t i) const; 50 | /** @brief sets the i-th element */ 51 | void set(uint64_t i, const reim4_elem& value); 52 | }; 53 | 54 | /** @brief test class representing the view of one matrix of nrowsxncols reim4's */ 55 | class reim4_matrix_view { 56 | uint64_t nrows; ///< 
number of rows 57 | uint64_t ncols; ///< number of columns 58 | double* data; ///< pointer to the start of the matrix 59 | public: 60 | /** @brief ininitializes a view at an existing given address */ 61 | reim4_matrix_view(uint64_t nrows, uint64_t ncols, double* data); 62 | /** @brief gets the i-th element */ 63 | reim4_elem get(uint64_t row, uint64_t col) const; 64 | /** @brief sets the i-th element */ 65 | void set(uint64_t row, uint64_t col, const reim4_elem& value); 66 | }; 67 | 68 | /** @brief test class representing the view of one reim of m complexes */ 69 | class reim_view { 70 | uint64_t m; ///< (complex) dimension of the reim polynomial 71 | double* data; ///< address of the start of the reim polynomial 72 | public: 73 | /** @brief ininitializes a view at an existing given address */ 74 | reim_view(uint64_t m, double* data); 75 | ; 76 | /** @brief extracts the i-th reim4 block (i 3 | // https://github.com/mjosaarinen/tiny_sha3 4 | // License: MIT 5 | 6 | #ifndef SHA3_H 7 | #define SHA3_H 8 | 9 | #ifdef __cplusplus 10 | extern "C" { 11 | #endif 12 | 13 | #include 14 | #include 15 | 16 | #ifndef KECCAKF_ROUNDS 17 | #define KECCAKF_ROUNDS 24 18 | #endif 19 | 20 | #ifndef ROTL64 21 | #define ROTL64(x, y) (((x) << (y)) | ((x) >> (64 - (y)))) 22 | #endif 23 | 24 | // state context 25 | typedef struct { 26 | union { // state: 27 | uint8_t b[200]; // 8-bit bytes 28 | uint64_t q[25]; // 64-bit words 29 | } st; 30 | int pt, rsiz, mdlen; // these don't overflow 31 | } sha3_ctx_t; 32 | 33 | // Compression function. 
34 | void sha3_keccakf(uint64_t st[25]); 35 | 36 | // OpenSSL - like interfece 37 | int sha3_init(sha3_ctx_t* c, int mdlen); // mdlen = hash output in bytes 38 | int sha3_update(sha3_ctx_t* c, const void* data, size_t len); 39 | int sha3_final(void* md, sha3_ctx_t* c); // digest goes to md 40 | 41 | // compute a sha3 hash (md) of given byte length from "in" 42 | void* sha3(const void* in, size_t inlen, void* md, int mdlen); 43 | 44 | // SHAKE128 and SHAKE256 extensible-output functions 45 | #define shake128_init(c) sha3_init(c, 16) 46 | #define shake256_init(c) sha3_init(c, 32) 47 | #define shake_update sha3_update 48 | 49 | void shake_xof(sha3_ctx_t* c); 50 | void shake_out(sha3_ctx_t* c, void* out, size_t len); 51 | 52 | #ifdef __cplusplus 53 | } 54 | #endif 55 | 56 | #endif // SHA3_H 57 | -------------------------------------------------------------------------------- /test/testlib/test_commons.cpp: -------------------------------------------------------------------------------- 1 | #include "test_commons.h" 2 | 3 | #include 4 | 5 | std::ostream& operator<<(std::ostream& out, __int128_t x) { 6 | char c[35] = {0}; 7 | snprintf(c, 35, "0x%016" PRIx64 "%016" PRIx64, uint64_t(x >> 64), uint64_t(x)); 8 | return out << c; 9 | } 10 | std::ostream& operator<<(std::ostream& out, __uint128_t x) { return out << __int128_t(x); } 11 | -------------------------------------------------------------------------------- /test/testlib/test_commons.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_TEST_COMMONS_H 2 | #define SPQLIOS_TEST_COMMONS_H 3 | 4 | #include 5 | #include 6 | 7 | #include "../../spqlios/commons.h" 8 | 9 | /** @brief macro that crashes if the condition are not met */ 10 | #define REQUIRE_DRAMATICALLY(req_contition, error_msg) \ 11 | do { \ 12 | if (!(req_contition)) { \ 13 | std::cerr << "REQUIREMENT FAILED at " << __FILE__ << ":" << __LINE__ << ": " << error_msg << std::endl; \ 14 | abort(); \ 15 | } \ 16 | } 
while (0) 17 | 18 | typedef std::default_random_engine test_rng; 19 | /** @brief reference to the default test rng */ 20 | test_rng& randgen(); 21 | /** @brief uniformly random 64-bit uint */ 22 | uint64_t uniform_u64(); 23 | /** @brief uniformly random number <= 2^nbits-1 */ 24 | uint64_t uniform_u64_bits(uint64_t nbits); 25 | /** @brief uniformly random signed 64-bit number */ 26 | int64_t uniform_i64(); 27 | /** @brief uniformly random signed |number| <= 2^nbits */ 28 | int64_t uniform_i64_bits(uint64_t nbits); 29 | /** @brief uniformly random signed lb <= number <= ub */ 30 | int64_t uniform_i64_bounds(const int64_t lb, const int64_t ub); 31 | /** @brief uniformly random signed lb <= number <= ub */ 32 | __int128_t uniform_i128_bounds(const __int128_t lb, const __int128_t ub); 33 | /** @brief uniformly random gaussian float64 */ 34 | double random_f64_gaussian(double stdev = 1); 35 | /** @brief uniformly random signed lb <= number <= ub */ 36 | double uniform_f64_bounds(const double lb, const double ub); 37 | /** @brief uniformly random float64 in [0,1] */ 38 | double uniform_f64_01(); 39 | /** @brief random gaussian float64 */ 40 | double random_f64_gaussian(double stdev); 41 | 42 | bool is_pow2(uint64_t n); 43 | 44 | void* alloc64(uint64_t size); 45 | 46 | typedef __uint128_t thash; 47 | /** @brief returns some pseudorandom hash of a contiguous content */ 48 | thash test_hash(const void* data, uint64_t size); 49 | /** @brief class to return a pseudorandom hash of a piecewise-defined content */ 50 | class test_hasher { 51 | void* md; 52 | public: 53 | test_hasher(); 54 | test_hasher(const test_hasher&) = delete; 55 | void operator=(const test_hasher&) = delete; 56 | /** 57 | * @brief append input bytes. 58 | * The final hash only depends on the concatenation of bytes, not on the 59 | * way the content was split into multiple calls to update. 60 | */ 61 | void update(const void* data, uint64_t size); 62 | /** 63 | * @brief returns the final hash. 
64 | * no more calls to update(...) shall be issued after this call. 65 | */ 66 | thash hash(); 67 | ~test_hasher(); 68 | }; 69 | 70 | // not included by default, since it makes some versions of gtest not compile 71 | // std::ostream& operator<<(std::ostream& out, __int128_t x); 72 | // std::ostream& operator<<(std::ostream& out, __uint128_t x); 73 | 74 | #endif // SPQLIOS_TEST_COMMONS_H 75 | -------------------------------------------------------------------------------- /test/testlib/test_hash.cpp: -------------------------------------------------------------------------------- 1 | #include "sha3.h" 2 | #include "test_commons.h" 3 | 4 | /** @brief returns some pseudorandom hash of the content */ 5 | thash test_hash(const void* data, uint64_t size) { 6 | thash res; 7 | sha3(data, size, &res, sizeof(res)); 8 | return res; 9 | } 10 | /** @brief class to return a pseudorandom hash of the content */ 11 | test_hasher::test_hasher() { 12 | md = malloc(sizeof(sha3_ctx_t)); 13 | sha3_init((sha3_ctx_t*)md, 16); 14 | } 15 | 16 | void test_hasher::update(const void* data, uint64_t size) { sha3_update((sha3_ctx_t*)md, data, size); } 17 | 18 | thash test_hasher::hash() { 19 | thash res; 20 | sha3_final(&res, (sha3_ctx_t*)md); 21 | return res; 22 | } 23 | 24 | test_hasher::~test_hasher() { free(md); } 25 | -------------------------------------------------------------------------------- /test/testlib/vec_rnx_layout.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_EXT_VEC_RNX_LAYOUT_H 2 | #define SPQLIOS_EXT_VEC_RNX_LAYOUT_H 3 | 4 | #include "../../spqlios/arithmetic/vec_rnx_arithmetic.h" 5 | #include "fft64_dft.h" 6 | #include "negacyclic_polynomial.h" 7 | #include "reim4_elem.h" 8 | #include "test_commons.h" 9 | 10 | /** @brief a test memory layout for rnx i64 polynomials vectors */ 11 | class rnx_vec_f64_layout { 12 | uint64_t n; 13 | uint64_t size; 14 | uint64_t slice; 15 | double* data_start; 16 | uint8_t* region; 17 | 
18 | public: 19 | // NO-COPY structure 20 | rnx_vec_f64_layout(const rnx_vec_f64_layout&) = delete; 21 | void operator=(const rnx_vec_f64_layout&) = delete; 22 | rnx_vec_f64_layout(rnx_vec_f64_layout&&) = delete; 23 | void operator=(rnx_vec_f64_layout&&) = delete; 24 | /** @brief initialises a memory layout */ 25 | rnx_vec_f64_layout(uint64_t n, uint64_t size, uint64_t slice); 26 | /** @brief destructor */ 27 | ~rnx_vec_f64_layout(); 28 | 29 | /** @brief get a copy of item index index (extended with zeros) */ 30 | rnx_f64 get_copy_zext(uint64_t index) const; 31 | /** @brief get a copy of item index index (extended with zeros) */ 32 | rnx_f64 get_copy(uint64_t index) const; 33 | /** @brief get a copy of item index index (extended with zeros) */ 34 | reim_fft64vec get_dft_copy_zext(uint64_t index) const; 35 | /** @brief get a copy of item index index (extended with zeros) */ 36 | reim_fft64vec get_dft_copy(uint64_t index) const; 37 | 38 | /** @brief get a copy of item index index (index> 5; 14 | const uint64_t rem_ncols = ncols & 31; 15 | uint64_t blk = col >> 5; 16 | uint64_t col_rem = col & 31; 17 | if (blk < nblk) { 18 | // column is part of a full block 19 | return (int32_t*)data + blk * nrows * 32 + row * 32 + col_rem; 20 | } else { 21 | // column is part of the last block 22 | return (int32_t*)data + blk * nrows * 32 + row * rem_ncols + col_rem; 23 | } 24 | } 25 | int32_t zn32_pmat_layout::get(uint64_t row, uint64_t col) const { return *get_addr(row, col); } 26 | int32_t zn32_pmat_layout::get_zext(uint64_t row, uint64_t col) const { 27 | if (row >= nrows || col >= ncols) return 0; 28 | return *get_addr(row, col); 29 | } 30 | void zn32_pmat_layout::set(uint64_t row, uint64_t col, int32_t value) { *get_addr(row, col) = value; } 31 | void zn32_pmat_layout::fill_random() { 32 | int32_t* d = (int32_t*)data; 33 | for (uint64_t i = 0; i < nrows * ncols; ++i) d[i] = uniform_i64_bits(32); 34 | } 35 | thash zn32_pmat_layout::content_hash() const { return test_hash(data, 
nrows * ncols * sizeof(int32_t)); } 36 | 37 | template 38 | std::vector vmp_product(const T* vec, uint64_t vec_size, uint64_t out_size, const zn32_pmat_layout& mat) { 39 | uint64_t rows = std::min(vec_size, mat.nrows); 40 | uint64_t cols = std::min(out_size, mat.ncols); 41 | std::vector res(out_size, 0); 42 | for (uint64_t j = 0; j < cols; ++j) { 43 | for (uint64_t i = 0; i < rows; ++i) { 44 | res[j] += vec[i] * mat.get(i, j); 45 | } 46 | } 47 | return res; 48 | } 49 | 50 | template std::vector vmp_product(const int8_t* vec, uint64_t vec_size, uint64_t out_size, 51 | const zn32_pmat_layout& mat); 52 | template std::vector vmp_product(const int16_t* vec, uint64_t vec_size, uint64_t out_size, 53 | const zn32_pmat_layout& mat); 54 | template std::vector vmp_product(const int32_t* vec, uint64_t vec_size, uint64_t out_size, 55 | const zn32_pmat_layout& mat); 56 | -------------------------------------------------------------------------------- /test/testlib/zn_layouts.h: -------------------------------------------------------------------------------- 1 | #ifndef SPQLIOS_EXT_ZN_LAYOUTS_H 2 | #define SPQLIOS_EXT_ZN_LAYOUTS_H 3 | 4 | #include "../../spqlios/arithmetic/zn_arithmetic.h" 5 | #include "test_commons.h" 6 | 7 | class zn32_pmat_layout { 8 | public: 9 | const uint64_t nrows; 10 | const uint64_t ncols; 11 | ZN32_VMP_PMAT* const data; 12 | zn32_pmat_layout(uint64_t nrows, uint64_t ncols); 13 | 14 | private: 15 | int32_t* get_addr(uint64_t row, uint64_t col) const; 16 | 17 | public: 18 | int32_t get(uint64_t row, uint64_t col) const; 19 | int32_t get_zext(uint64_t row, uint64_t col) const; 20 | void set(uint64_t row, uint64_t col, int32_t value); 21 | void fill_random(); 22 | thash content_hash() const; 23 | ~zn32_pmat_layout(); 24 | }; 25 | 26 | template 27 | std::vector vmp_product(const T* vec, uint64_t vec_size, uint64_t out_size, const zn32_pmat_layout& mat); 28 | 29 | #endif // SPQLIOS_EXT_ZN_LAYOUTS_H 30 | 
--------------------------------------------------------------------------------