├── .clang-format
├── .github
└── workflows
│ ├── auto-release.yml
│ ├── cmake-build-test-arm64.yml
│ ├── cmake-build-test-darwin.yml
│ ├── cmake-build-test-win64.yml
│ └── cmake-build-test.yml
├── .gitignore
├── CMakeLists.txt
├── CONTRIBUTING.md
├── Changelog.md
├── LICENSE
├── README.md
├── docs
├── api-full.svg
├── logo-inpher1.png
├── logo-inpher2.png
├── logo-sandboxaq-black.svg
└── logo-sandboxaq-white.svg
├── manifest.yaml
├── scripts
├── auto-release.sh
├── ci-pkg
└── prepare-release
├── spqlios
├── CMakeLists.txt
├── arithmetic
│ ├── module_api.c
│ ├── scalar_vector_product.c
│ ├── vec_rnx_api.c
│ ├── vec_rnx_approxdecomp_avx.c
│ ├── vec_rnx_approxdecomp_ref.c
│ ├── vec_rnx_arithmetic.c
│ ├── vec_rnx_arithmetic.h
│ ├── vec_rnx_arithmetic_avx.c
│ ├── vec_rnx_arithmetic_plugin.h
│ ├── vec_rnx_arithmetic_private.h
│ ├── vec_rnx_conversions_ref.c
│ ├── vec_rnx_svp_ref.c
│ ├── vec_rnx_vmp_avx.c
│ ├── vec_rnx_vmp_ref.c
│ ├── vec_znx.c
│ ├── vec_znx_arithmetic.h
│ ├── vec_znx_arithmetic_private.h
│ ├── vec_znx_avx.c
│ ├── vec_znx_big.c
│ ├── vec_znx_dft.c
│ ├── vec_znx_dft_avx2.c
│ ├── vector_matrix_product.c
│ ├── vector_matrix_product_avx.c
│ ├── zn_api.c
│ ├── zn_approxdecomp_ref.c
│ ├── zn_arithmetic.h
│ ├── zn_arithmetic_plugin.h
│ ├── zn_arithmetic_private.h
│ ├── zn_conversions_ref.c
│ ├── zn_vmp_int16_avx.c
│ ├── zn_vmp_int16_ref.c
│ ├── zn_vmp_int32_avx.c
│ ├── zn_vmp_int32_ref.c
│ ├── zn_vmp_int8_avx.c
│ ├── zn_vmp_int8_ref.c
│ ├── zn_vmp_ref.c
│ └── znx_small.c
├── coeffs
│ ├── coeffs_arithmetic.c
│ ├── coeffs_arithmetic.h
│ └── coeffs_arithmetic_avx.c
├── commons.c
├── commons.h
├── commons_private.c
├── commons_private.h
├── cplx
│ ├── README.md
│ ├── cplx_common.c
│ ├── cplx_conversions.c
│ ├── cplx_conversions_avx2_fma.c
│ ├── cplx_execute.c
│ ├── cplx_fallbacks_aarch64.c
│ ├── cplx_fft.h
│ ├── cplx_fft16_avx_fma.s
│ ├── cplx_fft16_avx_fma_win32.s
│ ├── cplx_fft_asserts.c
│ ├── cplx_fft_avx2_fma.c
│ ├── cplx_fft_avx512.c
│ ├── cplx_fft_internal.h
│ ├── cplx_fft_private.h
│ ├── cplx_fft_ref.c
│ ├── cplx_fft_sse.c
│ ├── cplx_fftvec_avx2_fma.c
│ ├── cplx_fftvec_ref.c
│ ├── cplx_ifft16_avx_fma.s
│ ├── cplx_ifft16_avx_fma_win32.s
│ ├── cplx_ifft_avx2_fma.c
│ ├── cplx_ifft_ref.c
│ └── spqlios_cplx_fft.c
├── ext
│ └── neon_accel
│ │ ├── macrof.h
│ │ └── macrofx4.h
├── q120
│ ├── q120_arithmetic.h
│ ├── q120_arithmetic_avx2.c
│ ├── q120_arithmetic_private.h
│ ├── q120_arithmetic_ref.c
│ ├── q120_arithmetic_simple.c
│ ├── q120_common.h
│ ├── q120_fallbacks_aarch64.c
│ ├── q120_ntt.c
│ ├── q120_ntt.h
│ ├── q120_ntt_avx2.c
│ └── q120_ntt_private.h
├── reim
│ ├── reim_conversions.c
│ ├── reim_conversions_avx.c
│ ├── reim_execute.c
│ ├── reim_fallbacks_aarch64.c
│ ├── reim_fft.h
│ ├── reim_fft16_avx_fma.s
│ ├── reim_fft16_avx_fma_win32.s
│ ├── reim_fft4_avx_fma.c
│ ├── reim_fft8_avx_fma.c
│ ├── reim_fft_avx2.c
│ ├── reim_fft_core_template.h
│ ├── reim_fft_ifft.c
│ ├── reim_fft_internal.h
│ ├── reim_fft_neon.c
│ ├── reim_fft_private.h
│ ├── reim_fft_ref.c
│ ├── reim_fftvec_addmul_fma.c
│ ├── reim_fftvec_addmul_ref.c
│ ├── reim_ifft16_avx_fma.s
│ ├── reim_ifft16_avx_fma_win32.s
│ ├── reim_ifft4_avx_fma.c
│ ├── reim_ifft8_avx_fma.c
│ ├── reim_ifft_avx2.c
│ ├── reim_ifft_ref.c
│ ├── reim_to_tnx_avx.c
│ └── reim_to_tnx_ref.c
└── reim4
│ ├── reim4_arithmetic.h
│ ├── reim4_arithmetic_avx2.c
│ ├── reim4_arithmetic_ref.c
│ ├── reim4_execute.c
│ ├── reim4_fallbacks_aarch64.c
│ ├── reim4_fftvec_addmul_fma.c
│ ├── reim4_fftvec_addmul_ref.c
│ ├── reim4_fftvec_conv_fma.c
│ ├── reim4_fftvec_conv_ref.c
│ ├── reim4_fftvec_internal.h
│ ├── reim4_fftvec_private.h
│ └── reim4_fftvec_public.h
└── test
├── CMakeLists.txt
├── spqlios_coeffs_arithmetic_test.cpp
├── spqlios_cplx_conversions_test.cpp
├── spqlios_cplx_fft_bench.cpp
├── spqlios_cplx_test.cpp
├── spqlios_q120_arithmetic_bench.cpp
├── spqlios_q120_arithmetic_test.cpp
├── spqlios_q120_ntt_bench.cpp
├── spqlios_q120_ntt_test.cpp
├── spqlios_reim4_arithmetic_bench.cpp
├── spqlios_reim4_arithmetic_test.cpp
├── spqlios_reim_conversions_test.cpp
├── spqlios_reim_test.cpp
├── spqlios_svp_product_test.cpp
├── spqlios_svp_test.cpp
├── spqlios_test.cpp
├── spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp
├── spqlios_vec_rnx_conversions_test.cpp
├── spqlios_vec_rnx_ppol_test.cpp
├── spqlios_vec_rnx_test.cpp
├── spqlios_vec_rnx_vmp_test.cpp
├── spqlios_vec_znx_big_test.cpp
├── spqlios_vec_znx_dft_test.cpp
├── spqlios_vec_znx_test.cpp
├── spqlios_vmp_product_test.cpp
├── spqlios_zn_approxdecomp_test.cpp
├── spqlios_zn_conversions_test.cpp
├── spqlios_zn_vmp_test.cpp
├── spqlios_znx_small_test.cpp
└── testlib
├── fft64_dft.cpp
├── fft64_dft.h
├── fft64_layouts.cpp
├── fft64_layouts.h
├── mod_q120.cpp
├── mod_q120.h
├── negacyclic_polynomial.cpp
├── negacyclic_polynomial.h
├── negacyclic_polynomial_impl.h
├── ntt120_dft.cpp
├── ntt120_dft.h
├── ntt120_layouts.cpp
├── ntt120_layouts.h
├── polynomial_vector.cpp
├── polynomial_vector.h
├── random.cpp
├── reim4_elem.cpp
├── reim4_elem.h
├── sha3.c
├── sha3.h
├── test_commons.cpp
├── test_commons.h
├── test_hash.cpp
├── vec_rnx_layout.cpp
├── vec_rnx_layout.h
├── zn_layouts.cpp
└── zn_layouts.h
/.clang-format:
--------------------------------------------------------------------------------
1 | # Use the Google style in this project.
2 | BasedOnStyle: Google
3 |
4 | # Some folks prefer to write "int& foo" while others prefer "int &foo". The
5 | # Google Style Guide only asks for consistency within a project, we chose
6 | # "int& foo" for this project:
7 | DerivePointerAlignment: false
8 | PointerAlignment: Left
9 |
10 | # The Google Style Guide only asks for consistency w.r.t. "east const" vs.
11 | # "const west" alignment of cv-qualifiers. In this project we use "east const".
12 | QualifierAlignment: Left
13 |
14 | ColumnLimit: 120
15 |
--------------------------------------------------------------------------------
/.github/workflows/auto-release.yml:
--------------------------------------------------------------------------------
1 | name: Auto-Release
2 |
3 | on:
4 | workflow_dispatch:
5 | push:
6 | branches: [ "main" ]
7 |
8 | jobs:
9 | build:
10 | name: Auto-Release
11 | runs-on: ubuntu-latest
12 |
13 | steps:
14 | - uses: actions/checkout@v3
15 | with:
16 | fetch-depth: 3
17 | # sparse-checkout: manifest.yaml scripts/auto-release.sh
18 |
19 | - run:
20 | ${{github.workspace}}/scripts/auto-release.sh
21 |
--------------------------------------------------------------------------------
/.github/workflows/cmake-build-test-arm64.yml:
--------------------------------------------------------------------------------
1 | name: CMake-Build-Test-Arm64
2 |
3 | on:
4 | workflow_dispatch:
5 | push:
6 | branches: [ "main" ]
7 | pull_request:
8 | branches: [ "main" ]
9 | types: [ labeled, opened, synchronize, reopened ]
10 |
11 | jobs:
12 | build-arm64:
13 | name: CMake-Build-Test-Arm64
14 | if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'check-on-arm64')
15 | runs-on: self-hosted-arm64
16 |
17 | steps:
18 | - uses: actions/checkout@v4
19 |
20 | - name: Configure CMake
21 | run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=Release -DENABLE_TESTING=ON -DWARNING_PARANOID=ON -DDEVMODE_INSTALL=ON
22 |
23 | - name: Build
24 | run: cmake --build ${{github.workspace}}/build
25 |
26 | - name: Test
27 | run: cd ${{github.workspace}}/build; ctest
28 |
29 | - name: Ci Package
30 | if: github.event_name != 'pull_request'
31 | env:
32 | CI_CREDS: ${{ secrets.CICREDS }}
33 | run: ./scripts/ci-pkg create
34 |
35 |
--------------------------------------------------------------------------------
/.github/workflows/cmake-build-test-darwin.yml:
--------------------------------------------------------------------------------
1 | name: CMake-Build-Test-Darwin
2 |
3 | on:
4 | workflow_dispatch:
5 | push:
6 | branches: [ "main" ]
7 | pull_request:
8 | branches: [ "main" ]
9 |
10 | jobs:
11 | build-darwin:
12 | name: CMake-Build-Test-Darwin
13 | runs-on: macos-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v4
17 |
18 | - name: Brew
19 | run: brew install cmake googletest google-benchmark
20 |
21 | - name: Configure CMake
22 | run: cmake -B build -DCMAKE_BUILD_TYPE=Release -DENABLE_TESTING=ON -DWARNING_PARANOID=ON -DDEVMODE_INSTALL=ON
23 |
24 | - name: Build
25 | run: cmake --build build
26 |
27 | - name: Test
28 | run: cd build; ctest
29 |
30 | - name: Package
31 | if: github.event_name != 'pull_request'
32 | env:
33 | CI_CREDS: ${{ secrets.CICREDS }}
34 | run: ./scripts/ci-pkg create
35 |
36 |
--------------------------------------------------------------------------------
/.github/workflows/cmake-build-test-win64.yml:
--------------------------------------------------------------------------------
1 | name: CMake-Build-Test-Win64
2 |
3 | on:
4 | workflow_dispatch:
5 | push:
6 | branches: [ "main" ]
7 | pull_request:
8 | branches: [ "main" ]
9 |
10 | jobs:
11 | build-msys-clang64:
12 | name: CMake-Build-Test-Win64
13 | runs-on: windows-latest
14 | defaults:
15 | run:
16 | shell: msys2 {0}
17 |
18 | steps:
19 | - uses: actions/checkout@v4
20 |
21 | - uses: msys2/setup-msys2@v2
22 | with:
23 | msystem: CLANG64
24 | #update: true
25 | install: git
26 | pacboy: toolchain:p cmake:p gtest:p benchmark:p
27 |
28 | - name: Configure CMake
29 | run: cmake -B build -DCMAKE_BUILD_TYPE=Release -DENABLE_TESTING=ON -DWARNING_PARANOID=ON -DDEVMODE_INSTALL=ON
30 |
31 | - name: Build
32 | run: cmake --build build
33 |
34 | - name: Test
35 | working-directory: build
36 | run: ctest --output-on-failure
37 |
38 | - name: CI Package
39 | if: github.event_name != 'pull_request'
40 | env:
41 | CI_CREDS: ${{ secrets.CICREDS }}
42 | run: ./scripts/ci-pkg create
43 |
44 |
--------------------------------------------------------------------------------
/.github/workflows/cmake-build-test.yml:
--------------------------------------------------------------------------------
1 | name: CMake-Build-Test
2 |
3 | on:
4 | workflow_dispatch:
5 | push:
6 | branches: [ "main" ]
7 | pull_request:
8 | branches: [ "main" ]
9 |
10 | jobs:
11 | build:
12 | name: CMake-Build-Test
13 | runs-on: ubuntu-latest
14 | container: ngama75/spqlios-ci:latest
15 |
16 | steps:
17 | - uses: actions/checkout@v4
18 |
19 | - name: Configure CMake
20 | run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=Release -DENABLE_TESTING=ON -DWARNING_PARANOID=ON -DDEVMODE_INSTALL=ON
21 |
22 | - name: Build
23 | run: cmake --build ${{github.workspace}}/build
24 |
25 | - name: Test
26 | run: cd ${{github.workspace}}/build; ctest
27 |
28 | - name: Package
29 | if: github.event_name != 'pull_request'
30 | env:
31 | CI_CREDS: ${{ secrets.CICREDS }}
32 | run: ./scripts/ci-pkg create
33 |
34 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | cmake-build-*
2 | .idea
3 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.8)
2 | project(spqlios)
3 |
4 | # read the current version from the manifest file
5 | file(READ "manifest.yaml" manifest)
6 | string(REGEX MATCH "version: +(([0-9]+)\\.([0-9]+)\\.([0-9]+))" SPQLIOS_VERSION_BLAH ${manifest})
7 | #message(STATUS "Version: ${SPQLIOS_VERSION_BLAH}")
8 | set(SPQLIOS_VERSION ${CMAKE_MATCH_1})
9 | set(SPQLIOS_VERSION_MAJOR ${CMAKE_MATCH_2})
10 | set(SPQLIOS_VERSION_MINOR ${CMAKE_MATCH_3})
11 | set(SPQLIOS_VERSION_PATCH ${CMAKE_MATCH_4})
12 | message(STATUS "Compiling spqlios-fft version: ${SPQLIOS_VERSION_MAJOR}.${SPQLIOS_VERSION_MINOR}.${SPQLIOS_VERSION_PATCH}")
13 |
14 | #set(ENABLE_SPQLIOS_F128 ON CACHE BOOL "Enable float128 via libquadmath")
15 | set(WARNING_PARANOID ON CACHE BOOL "Treat all warnings as errors")
16 | set(ENABLE_TESTING ON CACHE BOOL "Compiles unittests and integration tests")
17 | set(DEVMODE_INSTALL OFF CACHE BOOL "Install private headers and testlib (mainly for CI)")
18 |
19 | if (NOT CMAKE_BUILD_TYPE OR CMAKE_BUILD_TYPE STREQUAL "")
20 | set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type: Release or Debug" FORCE)
21 | endif()
22 | message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
23 |
24 | if (WARNING_PARANOID)
25 | add_compile_options(-Wall -Werror -Wno-unused-command-line-argument)
26 | endif()
27 |
28 | message(STATUS "CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}")
29 | message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
30 | message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
31 |
32 | if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
33 | set(X86 ON)
34 | set(AARCH64 OFF)
35 | else ()
36 | set(X86 OFF)
37 | # set(ENABLE_SPQLIOS_F128 OFF) # float128 are only supported for x86 targets
38 | endif ()
39 | if (CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64)|(arm64)")
40 | set(AARCH64 ON)
41 | endif ()
42 |
43 | if (CMAKE_SYSTEM_NAME MATCHES "(Windows)|(MSYS)")
44 | set(WIN32 ON)
45 | endif ()
46 | if (WIN32)
47 | #overrides for win32
48 | set(X86 OFF)
49 | set(AARCH64 OFF)
50 | set(X86_WIN32 ON)
51 | else()
52 | set(X86_WIN32 OFF)
53 | set(WIN32 OFF)
54 | endif (WIN32)
55 |
56 | message(STATUS "--> WIN32: ${WIN32}")
57 | message(STATUS "--> X86_WIN32: ${X86_WIN32}")
58 | message(STATUS "--> X86_LINUX: ${X86}")
59 | message(STATUS "--> AARCH64: ${AARCH64}")
60 |
61 | # compiles the main library in spqlios
62 | add_subdirectory(spqlios)
63 |
64 | # compiles and activates unittests and itests
65 | if (${ENABLE_TESTING})
66 | enable_testing()
67 | add_subdirectory(test)
68 | endif()
69 |
70 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to SPQlios-fft
2 |
3 | The spqlios-fft team encourages contributions.
4 | We encourage users to fix bugs, improve the documentation, write tests and to enhance the code, or ask for new features.
5 | We encourage researchers to contribute with implementations of their FFT or NTT algorithms.
6 | In the following we are trying to give some guidance on how to contribute effectively.
7 |
8 | ## Communication ##
9 |
10 | Communication in the spqlios-fft project happens mainly on [GitHub](https://github.com/tfhe/spqlios-fft/issues).
11 |
12 | All communications are public, so please make sure to maintain professional behaviour in
13 | all published comments. See [Code of Conduct](https://www.contributor-covenant.org/version/2/1/code_of_conduct/) for
14 | guidelines.
15 |
16 | ## Reporting Bugs or Requesting features ##
17 |
 18 | Bugs should be filed at [https://github.com/tfhe/spqlios-fft/issues](https://github.com/tfhe/spqlios-fft/issues).
19 |
20 | Features can also be requested there, in this case, please ensure that the features you request are self-contained,
21 | easy to define, and generic enough to be used in different use-cases. Please provide an example of use-cases if
22 | possible.
23 |
24 | ## Setting up topic branches and generating pull requests
25 |
26 | This section applies to people that already have write access to the repository. Specific instructions for pull-requests
27 | from public forks will be given later.
28 |
29 | To implement some changes, please follow these steps:
30 |
31 | - Create a "topic branch". Usually, the branch name should be `username/small-title`
32 | or better `username/issuenumber-small-title` where `issuenumber` is the number of
33 | the github issue number that is tackled.
34 | - Push any needed commits to your branch. Make sure it compiles in `CMAKE_BUILD_TYPE=Debug` and `=Release`, with `-DWARNING_PARANOID=ON`.
 35 | - When the branch is nearly ready for review, please open a pull request, and add the label `check-on-arm64`
36 | - Do as many commits as necessary until all CI checks pass and all PR comments have been resolved.
37 |
 38 | > _During the process, you may optionally use `git rebase -i` to clean up your commit history. If you elect to do so,
39 | please at the very least make sure that nobody else is working or has forked from your branch: the conflicts it would generate
40 | and the human hours to fix them are not worth it. `Git merge` remains the preferred option._
41 |
42 | - Finally, when all reviews are positive and all CI checks pass, you may merge your branch via the github webpage.
43 |
44 | ### Keep your pull requests limited to a single issue
45 |
46 | Pull requests should be as small/atomic as possible.
47 |
48 | ### Coding Conventions
49 |
50 | * Please make sure that your code is formatted according to the `.clang-format` file and
51 | that all files end with a newline character.
52 | * Please make sure that all the functions declared in the public api have relevant doxygen comments.
53 | Preferably, functions in the private apis should also contain a brief doxygen description.
54 |
55 | ### Versions and History
56 |
57 | * **Stable API** The project uses semantic versioning on the functions that are listed as `stable` in the documentation. A version has
58 | the form `x.y.z`
59 | * a patch release that increments `z` does not modify the stable API.
60 | * a minor release that increments `y` adds a new feature to the stable API.
61 | * In the unlikely case where we need to change or remove a feature, we will trigger a major release that
62 | increments `x`.
63 |
64 | > _If any, we will mark those features as deprecated at least six months before the major release._
65 |
66 | * **Experimental API** Features that are not part of the stable section in the documentation are experimental features: you may test them at
67 | your own risk,
68 | but keep in mind that semantic versioning does not apply to them.
69 |
70 | > _If you have a use-case that uses an experimental feature, we encourage
 71 | > you to tell us about it, so that this feature reaches the stable section faster!_
72 |
73 | * **Version history** The current version is reported in `manifest.yaml`, any change of version comes up with a tag on the main branch, and the history between releases is summarized in `Changelog.md`. It is the main source of truth for anyone who wishes to
74 | get insight about
75 | the history of the repository (not the commit graph).
76 |
 77 | > Note: _The commit graph of git is for git's internal use only. Its main purpose is to reduce potential merge conflicts to a minimum, even in scenarios where multiple features are developed in parallel: it may therefore be non-linear. If, as humans, we like to see a linear history, please read `Changelog.md` instead!_
78 |
--------------------------------------------------------------------------------
/Changelog.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | All notable changes to this project will be documented in this file.
 4 | This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
5 |
6 | ## [2.0.0] - 2024-08-21
7 |
 8 | - Initial release of the `vec_znx` (except convolution products), `vec_rnx` and `zn` APIs.
9 | - Hardware acceleration available: AVX2 (most parts)
10 | - APIs are documented in the wiki and are in "beta mode": during the 2.x -> 3.x transition, functions whose API is satisfactory in test projects will pass in "stable mode".
11 |
12 | ## [1.0.0] - 2023-07-18
13 |
14 | - Initial release of the double precision fft on the reim and cplx backends
15 | - Coeffs-space conversions cplx <-> znx32 and tnx32
16 | - FFT-space conversions cplx <-> reim4 layouts
17 | - FFT-space multiplications on the cplx, reim and reim4 layouts.
18 | - In this first release, the only platform supported is linux x86_64 (generic C code, and avx2/fma). It compiles on arm64, but without any acceleration.
19 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SPQlios library
2 |
3 |
4 |
5 | The SPQlios library provides fast arithmetic for Fully Homomorphic Encryption, and other lattice constructions that arise in post quantum cryptography.
6 |
7 |
8 |
9 | Namely, it is divided into 4 sections:
10 |
 11 | * The low-level DFT section supports FFT over 64-bit floats, as well as NTT modulo one fixed 120-bit modulus. It is an upgrade of the original spqlios-fft module embedded in the TFHE library since 2016. The DFT section exposes the traditional DFT, inverse-DFT, and coefficient-wise multiplications in DFT space.
 12 | * The VEC_ZNX section exposes fast algebra over vectors of small integer polynomials modulo $X^N+1$. It proposes in particular efficient (prepared) vector-matrix products, scalar-vector products, convolution products, and element-wise products, operations that naturally occur on gadget-decomposed Ring-LWE coordinates.
13 | * The RNX section is a simpler variant of VEC_ZNX, to represent single polynomials modulo $X^N+1$ (over the reals or over the torus) when the coefficient precision fits on 64-bit doubles. The small vector-matrix API of the RNX section is particularly adapted to reproducing the fastest CGGI-based bootstrappings.
 14 | * The ZN section focuses on vector and matrix algebra over scalars (used by scalar LWE, or scalar key-switches, but also on non-ring schemes like Frodo, FrodoPIR, and SimplePIR).
15 |
16 | ### A high value target for hardware accelerations
17 |
18 | SPQlios is more than a library, it is also a good target for hardware developers.
 19 | On one hand, the arithmetic operations that are defined in the library have a clear standalone mathematical definition. And at the same time, the amount of work in each operation is sufficiently large so that meaningful functions only require a few of these.
20 |
21 | This makes the SPQlios API a high value target for hardware acceleration, that targets FHE.
22 |
 23 | ### SPQlios is not an FHE library, but a huge enabler
24 |
25 | SPQlios itself is not an FHE library: there is no ciphertext, plaintext or key. It is a mathematical library that exposes efficient algebra over polynomials. Using the functions exposed, it is possible to quickly build efficient FHE libraries, with support for the main schemes based on Ring-LWE: BFV, BGV, CGGI, DM, CKKS.
26 |
27 |
28 | ## Dependencies
29 |
30 | The SPQLIOS-FFT library is a C library that can be compiled with a standard C compiler, and depends only on libc and libm. The API
31 | interface can be used in a regular C code, and any other language via classical foreign APIs.
32 |
33 | The unittests and integration tests are in an optional part of the code, and are written in C++. These tests rely on
34 | [```benchmark```](https://github.com/google/benchmark), and [```gtest```](https://github.com/google/googletest) libraries, and therefore require a C++17 compiler.
35 |
36 | Currently, the project has been tested with the gcc,g++ >= 11.3.0 compiler under Linux (x86_64). In the future, we plan to
37 | extend the compatibility to other compilers, platforms and operating systems.
38 |
39 |
40 | ## Installation
41 |
42 | The library uses a classical ```cmake``` build mechanism: use ```cmake``` to create a ```build``` folder in the top level directory and run ```make``` from inside it. This assumes that the standard tool ```cmake``` is already installed on the system, and an up-to-date c++ compiler (i.e. g++ >=11.3.0) as well.
43 |
44 | It will compile the shared library in optimized mode, and ```make install``` install it to the desired prefix folder (by default ```/usr/local/lib```).
45 |
46 | If you want to choose additional compile options (i.e. other installation folder, debug mode, tests), you need to run cmake manually and pass the desired options:
47 | ```
48 | mkdir build
49 | cd build
 50 | cmake .. -DCMAKE_INSTALL_PREFIX=/usr/
51 | make
52 | ```
53 | The available options are the following:
54 |
55 | | Variable Name | values |
56 | | -------------------- | ------------------------------------------------------------ |
57 | | CMAKE_INSTALL_PREFIX | */usr/local* installation folder (libs go in lib/ and headers in include/) |
 58 | | WARNING_PARANOID     | All warnings are shown and treated as errors. On by default  |
59 | | ENABLE_TESTING | Compiles unit tests and integration tests |
60 |
61 | ------
62 |
63 |
64 |
65 |
66 |
--------------------------------------------------------------------------------
/docs/logo-inpher1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfhe/spqlios-arithmetic/7ea875f61c51c67687ec9c15ae0eae04cda961f8/docs/logo-inpher1.png
--------------------------------------------------------------------------------
/docs/logo-inpher2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfhe/spqlios-arithmetic/7ea875f61c51c67687ec9c15ae0eae04cda961f8/docs/logo-inpher2.png
--------------------------------------------------------------------------------
/manifest.yaml:
--------------------------------------------------------------------------------
1 | library: spqlios-fft
2 | version: 2.0.0
3 |
--------------------------------------------------------------------------------
/scripts/auto-release.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # this script generates one tag if there is a version change in manifest.yaml
4 | cd `dirname $0`/..
5 | if [ "v$1" = "v-y" ]; then
6 | echo "production mode!";
7 | fi
8 | changes=`git diff HEAD~1..HEAD -- manifest.yaml | grep 'version:'`
9 | oldversion=$(echo "$changes" | grep '^-version:' | cut '-d ' -f2)
10 | version=$(echo "$changes" | grep '^+version:' | cut '-d ' -f2)
11 | echo "Versions: $oldversion --> $version"
12 | if [ "v$oldversion" = "v$version" ]; then
13 | echo "Same version - nothing to do"; exit 0;
14 | fi
15 | if [ "v$1" = "v-y" ]; then
16 | git config user.name github-actions
17 | git config user.email github-actions@github.com
18 | git tag -a "v$version" -m "Version $version"
19 | git push origin "v$version"
20 | else
21 | cat </dev/null
75 | rm -f "$DIR/$FNAME" 2>/dev/null
76 | DESTDIR="$DIR/dist" cmake --install build || exit 1
77 | if [ -d "$DIR/dist$CI_INSTALL_PREFIX" ]; then
78 | tar -C "$DIR/dist" -cvzf "$DIR/$FNAME" .
79 | else
80 | # fix since msys can mess up the paths
81 | REAL_DEST=`find "$DIR/dist" -type d -exec test -d "{}$CI_INSTALL_PREFIX" \; -print`
82 | echo "REAL_DEST: $REAL_DEST"
83 | [ -d "$REAL_DEST$CI_INSTALL_PREFIX" ] && tar -C "$REAL_DEST" -cvzf "$DIR/$FNAME" .
84 | fi
85 | [ -f "$DIR/$FNAME" ] || { echo "failed to create $DIR/$FNAME"; exit 1; }
86 | [ "x$CI_CREDS" = "x" ] && { echo "CI_CREDS is not set: not uploading"; exit 1; }
87 | curl -u "$CI_CREDS" -T "$DIR/$FNAME" "$CI_REPO_URL/$FNAME"
88 | fi
89 |
90 | if [ "x$1" = "xinstall" ]; then
91 | [ "x$CI_CREDS" = "x" ] && { echo "CI_CREDS is not set: not downloading"; exit 1; }
92 | # cleaning
93 | rm -rf "$DESTDIR$CI_INSTALL_PREFIX"/* 2>/dev/null
94 | rm -f "$DIR/$FNAME" 2>/dev/null
95 | # downloading
96 | curl -u "$CI_CREDS" -o "$DIR/$FNAME" "$CI_REPO_URL/$FNAME"
97 | [ -f "$DIR/$FNAME" ] || { echo "failed to download $DIR/$FNAME"; exit 0; }
98 | # installing
99 | mkdir -p $DESTDIR
100 | tar -C "$DESTDIR" -xvzf "$DIR/$FNAME"
101 | exit 0
102 | fi
103 |
--------------------------------------------------------------------------------
/spqlios/arithmetic/scalar_vector_product.c:
--------------------------------------------------------------------------------
 1 | #include <string.h>
2 |
3 | #include "vec_znx_arithmetic_private.h"
4 |
5 | EXPORT uint64_t bytes_of_svp_ppol(const MODULE* module) { return module->func.bytes_of_svp_ppol(module); }
6 |
7 | EXPORT uint64_t fft64_bytes_of_svp_ppol(const MODULE* module) { return module->nn * sizeof(double); }
8 |
9 | EXPORT SVP_PPOL* new_svp_ppol(const MODULE* module) { return spqlios_alloc(bytes_of_svp_ppol(module)); }
10 |
11 | EXPORT void delete_svp_ppol(SVP_PPOL* ppol) { spqlios_free(ppol); }
12 |
13 | // public wrappers
14 | EXPORT void svp_prepare(const MODULE* module, // N
15 | SVP_PPOL* ppol, // output
16 | const int64_t* pol // a
17 | ) {
18 | module->func.svp_prepare(module, ppol, pol);
19 | }
20 |
21 | /** @brief prepares a svp polynomial */
22 | EXPORT void fft64_svp_prepare_ref(const MODULE* module, // N
23 | SVP_PPOL* ppol, // output
24 | const int64_t* pol // a
25 | ) {
26 | reim_from_znx64(module->mod.fft64.p_conv, ppol, pol);
27 | reim_fft(module->mod.fft64.p_fft, (double*)ppol);
28 | }
29 |
30 | EXPORT void svp_apply_dft(const MODULE* module, // N
31 | const VEC_ZNX_DFT* res, uint64_t res_size, // output
32 | const SVP_PPOL* ppol, // prepared pol
33 | const int64_t* a, uint64_t a_size, uint64_t a_sl) {
34 | module->func.svp_apply_dft(module, // N
35 | res,
36 | res_size, // output
37 | ppol, // prepared pol
38 | a, a_size, a_sl);
39 | }
40 |
41 | // result = ppol * a
42 | EXPORT void fft64_svp_apply_dft_ref(const MODULE* module, // N
43 | const VEC_ZNX_DFT* res, uint64_t res_size, // output
44 | const SVP_PPOL* ppol, // prepared pol
45 | const int64_t* a, uint64_t a_size, uint64_t a_sl // a
46 | ) {
47 | const uint64_t nn = module->nn;
48 | double* const dres = (double*)res;
49 | double* const dppol = (double*)ppol;
50 |
51 | const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
52 | for (uint64_t i = 0; i < auto_end_idx; ++i) {
53 | const int64_t* a_ptr = a + i * a_sl;
54 | double* const res_ptr = dres + i * nn;
55 | // copy the polynomial to res, apply fft in place, call fftvec_mul in place.
56 | reim_from_znx64(module->mod.fft64.p_conv, res_ptr, a_ptr);
57 | reim_fft(module->mod.fft64.p_fft, res_ptr);
58 | reim_fftvec_mul(module->mod.fft64.mul_fft, res_ptr, res_ptr, dppol);
59 | }
60 |
61 | // then extend with zeros
62 | memset(dres + auto_end_idx * nn, 0, (res_size - auto_end_idx) * nn * sizeof(double));
63 | }
64 |
--------------------------------------------------------------------------------
/spqlios/arithmetic/vec_rnx_approxdecomp_avx.c:
--------------------------------------------------------------------------------
 1 | #include <string.h>
2 |
3 | #include "immintrin.h"
4 | #include "vec_rnx_arithmetic_private.h"
5 |
/** @brief sets res = gadget_decompose(a) */
// AVX2 variant: processes 4 coefficients per iteration (requires nn >= 4 and
// nn a multiple of 4; smaller dimensions fall back to the scalar _ref code).
EXPORT void rnx_approxdecomp_from_tnxdbl_avx( //
    const MOD_RNX* module,                           // N
    const TNXDBL_APPROXDECOMP_GADGET* gadget,        // output base 2^K
    double* res, uint64_t res_size, uint64_t res_sl, // res
    const double* a                                  // a
) {
  const uint64_t nn = module->n;
  // cannot fill a 256-bit lane: delegate to the scalar implementation
  if (nn < 4) return rnx_approxdecomp_from_tnxdbl_ref(module, gadget, res, res_size, res_sl, a);
  const uint64_t ell = gadget->ell;
  // broadcast the gadget constants (see new_tnxdbl_approxdecomp_gadget for
  // how add_cst/and_mask/or_mask/sub_cst encode the base-2^K decomposition)
  const __m256i k = _mm256_set1_epi64x(gadget->k);
  const __m256d add_cst = _mm256_set1_pd(gadget->add_cst);
  const __m256i and_mask = _mm256_set1_epi64x(gadget->and_mask);
  const __m256i or_mask = _mm256_set1_epi64x(gadget->or_mask);
  const __m256d sub_cst = _mm256_set1_pd(gadget->sub_cst);
  // number of result slices actually produced from a
  const uint64_t msize = res_size <= ell ? res_size : ell;
  // gadget decompose column by column
  if (msize == ell) {
    // this is the main scenario when msize == ell
    double* const last_r = res + (msize - 1) * res_sl;
    for (uint64_t j = 0; j < nn; j += 4) {
      double* rr = last_r + j;
      const double* aa = a + j;
      // adding add_cst fixes the position of the K-bit digits of a inside
      // the double mantissa, which is then manipulated as a raw uint64x4
      __m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst);
      __m256i t_int = _mm256_castpd_si256(t_dbl);
      // extract digits from least to most significant, walking res backwards
      do {
        __m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask);
        _mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst));
        t_int = _mm256_srlv_epi64(t_int, k);
        rr -= res_sl;
      } while (rr >= res);
    }
  } else if (msize > 0) {
    // otherwise, if msize < ell: there is one additional rshift
    // (drops the ell-msize least significant digits that have no slice)
    const __m256i first_rsh = _mm256_set1_epi64x((ell - msize) * gadget->k);
    double* const last_r = res + (msize - 1) * res_sl;
    for (uint64_t j = 0; j < nn; j += 4) {
      double* rr = last_r + j;
      const double* aa = a + j;
      __m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst);
      __m256i t_int = _mm256_srlv_epi64(_mm256_castpd_si256(t_dbl), first_rsh);
      do {
        __m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask);
        _mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst));
        t_int = _mm256_srlv_epi64(t_int, k);
        rr -= res_sl;
      } while (rr >= res);
    }
  }
  // zero-out the last slices (if any)
  for (uint64_t i = msize; i < res_size; ++i) {
    memset(res + i * res_sl, 0, nn * sizeof(double));
  }
}
60 |
--------------------------------------------------------------------------------
/spqlios/arithmetic/vec_rnx_approxdecomp_ref.c:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include "vec_rnx_arithmetic_private.h"
4 |
// bit-level view of a double: C union punning lets the decomposition read and
// shift the raw IEEE-754 representation without strict-aliasing violations
typedef union di {
  double dv;
  uint64_t uv;
} di_t;
9 |
10 | /** @brief new gadget: delete with delete_tnxdbl_approxdecomp_gadget */
11 | EXPORT TNXDBL_APPROXDECOMP_GADGET* new_tnxdbl_approxdecomp_gadget( //
12 | const MOD_RNX* module, // N
13 | uint64_t k, uint64_t ell // base 2^K and size
14 | ) {
15 | if (k * ell > 50) return spqlios_error("gadget requires a too large fp precision");
16 | TNXDBL_APPROXDECOMP_GADGET* res = spqlios_alloc(sizeof(TNXDBL_APPROXDECOMP_GADGET));
17 | res->k = k;
18 | res->ell = ell;
19 | // double add_cst; // double(3.2^(51-ell.K) + 1/2.(sum 2^(-iK)) for i=[0,ell[)
20 | union di add_cst;
21 | add_cst.dv = UINT64_C(3) << (51 - ell * k);
22 | for (uint64_t i = 0; i < ell; ++i) {
23 | add_cst.uv |= UINT64_C(1) << ((i + 1) * k - 1);
24 | }
25 | res->add_cst = add_cst.dv;
26 | // uint64_t and_mask; // uint64(2^(K)-1)
27 | res->and_mask = (UINT64_C(1) << k) - 1;
28 | // uint64_t or_mask; // double(2^52)
29 | union di or_mask;
30 | or_mask.dv = (UINT64_C(1) << 52);
31 | res->or_mask = or_mask.uv;
32 | // double sub_cst; // double(2^52 + 2^(K-1))
33 | res->sub_cst = ((UINT64_C(1) << 52) + (UINT64_C(1) << (k - 1)));
34 | return res;
35 | }
36 |
37 | EXPORT void delete_tnxdbl_approxdecomp_gadget(TNXDBL_APPROXDECOMP_GADGET* gadget) { spqlios_free(gadget); }
38 |
/** @brief sets res = gadget_decompose(a) */
// Scalar reference implementation. Adding add_cst to a[j] fixes the position
// of the ell K-bit digits inside the double's mantissa (see
// new_tnxdbl_approxdecomp_gadget): each digit is then masked out, OR-ed with
// the bit pattern of 2^52 (or_mask) and recentered by subtracting sub_cst,
// which yields the signed digit exactly as a double.
EXPORT void rnx_approxdecomp_from_tnxdbl_ref(        //
    const MOD_RNX* module,                           // N
    const TNXDBL_APPROXDECOMP_GADGET* gadget,        // output base 2^K
    double* res, uint64_t res_size, uint64_t res_sl, // res
    const double* a                                  // a
) {
  const uint64_t nn = module->n;
  const uint64_t k = gadget->k;
  const uint64_t ell = gadget->ell;
  const double add_cst = gadget->add_cst;
  const uint64_t and_mask = gadget->and_mask;
  const uint64_t or_mask = gadget->or_mask;
  const double sub_cst = gadget->sub_cst;
  // number of result slices actually produced from a
  const uint64_t msize = res_size <= ell ? res_size : ell;
  // when res is shorter than the gadget, the lowest digits are dropped
  const uint64_t first_rsh = (ell - msize) * k;
  // gadget decompose column by column
  if (msize > 0) {
    double* const last_r = res + (msize - 1) * res_sl;
    for (uint64_t j = 0; j < nn; ++j) {
      double* rr = last_r + j;
      di_t t = {.dv = a[j] + add_cst};
      if (msize < ell) t.uv >>= first_rsh;
      // extract digits from least to most significant, walking res backwards
      do {
        di_t u;
        u.uv = (t.uv & and_mask) | or_mask;
        *rr = u.dv - sub_cst;
        t.uv >>= k;
        rr -= res_sl;
      } while (rr >= res);
    }
  }
  // zero-out the last slices (if any)
  for (uint64_t i = msize; i < res_size; ++i) {
    memset(res + i * res_sl, 0, nn * sizeof(double));
  }
}
76 |
--------------------------------------------------------------------------------
/spqlios/arithmetic/vec_rnx_conversions_ref.c:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include "vec_rnx_arithmetic_private.h"
4 | #include "zn_arithmetic_private.h"
5 |
6 | EXPORT void vec_rnx_to_znx32_ref( //
7 | const MOD_RNX* module, // N
8 | int32_t* res, uint64_t res_size, uint64_t res_sl, // res
9 | const double* a, uint64_t a_size, uint64_t a_sl // a
10 | ) {
11 | const uint64_t nn = module->n;
12 | const uint64_t msize = res_size < a_size ? res_size : a_size;
13 | for (uint64_t i = 0; i < msize; ++i) {
14 | dbl_round_to_i32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
15 | }
16 | for (uint64_t i = msize; i < res_size; ++i) {
17 | memset(res + i * res_sl, 0, nn * sizeof(int32_t));
18 | }
19 | }
20 |
21 | EXPORT void vec_rnx_from_znx32_ref( //
22 | const MOD_RNX* module, // N
23 | double* res, uint64_t res_size, uint64_t res_sl, // res
24 | const int32_t* a, uint64_t a_size, uint64_t a_sl // a
25 | ) {
26 | const uint64_t nn = module->n;
27 | const uint64_t msize = res_size < a_size ? res_size : a_size;
28 | for (uint64_t i = 0; i < msize; ++i) {
29 | i32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
30 | }
31 | for (uint64_t i = msize; i < res_size; ++i) {
32 | memset(res + i * res_sl, 0, nn * sizeof(int32_t));
33 | }
34 | }
35 | EXPORT void vec_rnx_to_tnx32_ref( //
36 | const MOD_RNX* module, // N
37 | int32_t* res, uint64_t res_size, uint64_t res_sl, // res
38 | const double* a, uint64_t a_size, uint64_t a_sl // a
39 | ) {
40 | const uint64_t nn = module->n;
41 | const uint64_t msize = res_size < a_size ? res_size : a_size;
42 | for (uint64_t i = 0; i < msize; ++i) {
43 | dbl_to_tn32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
44 | }
45 | for (uint64_t i = msize; i < res_size; ++i) {
46 | memset(res + i * res_sl, 0, nn * sizeof(int32_t));
47 | }
48 | }
49 | EXPORT void vec_rnx_from_tnx32_ref( //
50 | const MOD_RNX* module, // N
51 | double* res, uint64_t res_size, uint64_t res_sl, // res
52 | const int32_t* a, uint64_t a_size, uint64_t a_sl // a
53 | ) {
54 | const uint64_t nn = module->n;
55 | const uint64_t msize = res_size < a_size ? res_size : a_size;
56 | for (uint64_t i = 0; i < msize; ++i) {
57 | tn32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
58 | }
59 | for (uint64_t i = msize; i < res_size; ++i) {
60 | memset(res + i * res_sl, 0, nn * sizeof(int32_t));
61 | }
62 | }
63 |
64 | static void dbl_to_tndbl_ref( //
65 | const void* UNUSED, // N
66 | double* res, uint64_t res_size, // res
67 | const double* a, uint64_t a_size // a
68 | ) {
69 | static const double OFF_CST = INT64_C(3) << 51;
70 | const uint64_t msize = res_size < a_size ? res_size : a_size;
71 | for (uint64_t i = 0; i < msize; ++i) {
72 | double ai = a[i] + OFF_CST;
73 | res[i] = a[i] - (ai - OFF_CST);
74 | }
75 | memset(res + msize, 0, (res_size - msize) * sizeof(double));
76 | }
77 |
78 | EXPORT void vec_rnx_to_tnxdbl_ref( //
79 | const MOD_RNX* module, // N
80 | double* res, uint64_t res_size, uint64_t res_sl, // res
81 | const double* a, uint64_t a_size, uint64_t a_sl // a
82 | ) {
83 | const uint64_t nn = module->n;
84 | const uint64_t msize = res_size < a_size ? res_size : a_size;
85 | for (uint64_t i = 0; i < msize; ++i) {
86 | dbl_to_tndbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn);
87 | }
88 | for (uint64_t i = msize; i < res_size; ++i) {
89 | memset(res + i * res_sl, 0, nn * sizeof(int32_t));
90 | }
91 | }
92 |
--------------------------------------------------------------------------------
/spqlios/arithmetic/vec_rnx_svp_ref.c:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include "../coeffs/coeffs_arithmetic.h"
4 | #include "vec_rnx_arithmetic_private.h"
5 |
6 | EXPORT uint64_t fft64_bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return module->n * sizeof(double); }
7 |
8 | EXPORT RNX_SVP_PPOL* new_rnx_svp_ppol(const MOD_RNX* module) { return spqlios_alloc(bytes_of_rnx_svp_ppol(module)); }
9 |
10 | EXPORT void delete_rnx_svp_ppol(RNX_SVP_PPOL* ppol) { spqlios_free(ppol); }
11 |
/** @brief prepares a svp polynomial */
// Scales the input by 1/m and stores its forward FFT in ppol, so that
// fft64_rnx_svp_apply_ref can multiply coefficient-wise in the fft domain
// without renormalizing.
EXPORT void fft64_rnx_svp_prepare_ref(const MOD_RNX* module,  // N
                                      RNX_SVP_PPOL* ppol,     // output
                                      const double* pol       // a
) {
  double* const dppol = (double*)ppol;
  // divide by m up-front: apply() then only needs fft/mul/ifft
  rnx_divide_by_m_ref(module->n, module->m, dppol, pol);
  reim_fft(module->precomp.fft64.p_fft, dppol);
}
21 |
// Multiplies each slice of a by the prepared polynomial ppol: the slice is
// copied into res, transformed in place, multiplied pointwise by the
// (pre-scaled, pre-transformed) ppol, and transformed back. Result slices
// with no matching input slice are zeroed.
EXPORT void fft64_rnx_svp_apply_ref(                  //
    const MOD_RNX* module,                            // N
    double* res, uint64_t res_size, uint64_t res_sl,  // output
    const RNX_SVP_PPOL* ppol,                         // prepared pol
    const double* a, uint64_t a_size, uint64_t a_sl   // a
) {
  const uint64_t nn = module->n;
  double* const dppol = (double*)ppol;

  const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
  for (uint64_t i = 0; i < auto_end_idx; ++i) {
    const double* a_ptr = a + i * a_sl;
    double* const res_ptr = res + i * res_sl;
    // copy the polynomial to res, apply fft in place, call fftvec
    // _mul, apply ifft in place.
    memcpy(res_ptr, a_ptr, nn * sizeof(double));
    reim_fft(module->precomp.fft64.p_fft, (double*)res_ptr);
    reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, res_ptr, res_ptr, dppol);
    reim_ifft(module->precomp.fft64.p_ifft, res_ptr);
  }

  // then extend with zeros
  for (uint64_t i = auto_end_idx; i < res_size; ++i) {
    memset(res + i * res_sl, 0, nn * sizeof(double));
  }
}
48 |
--------------------------------------------------------------------------------
/spqlios/arithmetic/vec_znx_avx.c:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include "../coeffs/coeffs_arithmetic.h"
4 | #include "../reim4/reim4_arithmetic.h"
5 | #include "vec_znx_arithmetic_private.h"
6 |
7 | // specialized function (ref)
8 |
9 | // Note: these functions do not have an avx variant.
10 | #define znx_copy_i64_avx znx_copy_i64_ref
11 | #define znx_zero_i64_avx znx_zero_i64_ref
12 |
13 | EXPORT void vec_znx_add_avx(const MODULE* module, // N
14 | int64_t* res, uint64_t res_size, uint64_t res_sl, // res
15 | const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
16 | const int64_t* b, uint64_t b_size, uint64_t b_sl // b
17 | ) {
18 | const uint64_t nn = module->nn;
19 | if (a_size <= b_size) {
20 | const uint64_t sum_idx = res_size < a_size ? res_size : a_size;
21 | const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
22 | // add up to the smallest dimension
23 | for (uint64_t i = 0; i < sum_idx; ++i) {
24 | znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
25 | }
26 | // then copy to the largest dimension
27 | for (uint64_t i = sum_idx; i < copy_idx; ++i) {
28 | znx_copy_i64_avx(nn, res + i * res_sl, b + i * b_sl);
29 | }
30 | // then extend with zeros
31 | for (uint64_t i = copy_idx; i < res_size; ++i) {
32 | znx_zero_i64_avx(nn, res + i * res_sl);
33 | }
34 | } else {
35 | const uint64_t sum_idx = res_size < b_size ? res_size : b_size;
36 | const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
37 | // add up to the smallest dimension
38 | for (uint64_t i = 0; i < sum_idx; ++i) {
39 | znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
40 | }
41 | // then copy to the largest dimension
42 | for (uint64_t i = sum_idx; i < copy_idx; ++i) {
43 | znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl);
44 | }
45 | // then extend with zeros
46 | for (uint64_t i = copy_idx; i < res_size; ++i) {
47 | znx_zero_i64_avx(nn, res + i * res_sl);
48 | }
49 | }
50 | }
51 |
52 | EXPORT void vec_znx_sub_avx(const MODULE* module, // N
53 | int64_t* res, uint64_t res_size, uint64_t res_sl, // res
54 | const int64_t* a, uint64_t a_size, uint64_t a_sl, // a
55 | const int64_t* b, uint64_t b_size, uint64_t b_sl // b
56 | ) {
57 | const uint64_t nn = module->nn;
58 | if (a_size <= b_size) {
59 | const uint64_t sub_idx = res_size < a_size ? res_size : a_size;
60 | const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
61 | // subtract up to the smallest dimension
62 | for (uint64_t i = 0; i < sub_idx; ++i) {
63 | znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
64 | }
65 | // then negate to the largest dimension
66 | for (uint64_t i = sub_idx; i < copy_idx; ++i) {
67 | znx_negate_i64_avx(nn, res + i * res_sl, b + i * b_sl);
68 | }
69 | // then extend with zeros
70 | for (uint64_t i = copy_idx; i < res_size; ++i) {
71 | znx_zero_i64_avx(nn, res + i * res_sl);
72 | }
73 | } else {
74 | const uint64_t sub_idx = res_size < b_size ? res_size : b_size;
75 | const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
76 | // subtract up to the smallest dimension
77 | for (uint64_t i = 0; i < sub_idx; ++i) {
78 | znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
79 | }
80 | // then copy to the largest dimension
81 | for (uint64_t i = sub_idx; i < copy_idx; ++i) {
82 | znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl);
83 | }
84 | // then extend with zeros
85 | for (uint64_t i = copy_idx; i < res_size; ++i) {
86 | znx_zero_i64_avx(nn, res + i * res_sl);
87 | }
88 | }
89 | }
90 |
91 | EXPORT void vec_znx_negate_avx(const MODULE* module, // N
92 | int64_t* res, uint64_t res_size, uint64_t res_sl, // res
93 | const int64_t* a, uint64_t a_size, uint64_t a_sl // a
94 | ) {
95 | uint64_t nn = module->nn;
96 | uint64_t smin = res_size < a_size ? res_size : a_size;
97 | for (uint64_t i = 0; i < smin; ++i) {
98 | znx_negate_i64_avx(nn, res + i * res_sl, a + i * a_sl);
99 | }
100 | for (uint64_t i = smin; i < res_size; ++i) {
101 | znx_zero_i64_ref(nn, res + i * res_sl);
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
/spqlios/arithmetic/vec_znx_dft_avx2.c:
--------------------------------------------------------------------------------
1 | #include "vec_znx_arithmetic_private.h"
2 |
--------------------------------------------------------------------------------
/spqlios/arithmetic/zn_approxdecomp_ref.c:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include "zn_arithmetic_private.h"
4 |
5 | EXPORT TNDBL_APPROXDECOMP_GADGET* new_tndbl_approxdecomp_gadget(const MOD_Z* module, //
6 | uint64_t k, uint64_t ell) {
7 | if (k * ell > 50) {
8 | return spqlios_error("approx decomposition requested is too precise for doubles");
9 | }
10 | if (k < 1) {
11 | return spqlios_error("approx decomposition supports k>=1");
12 | }
13 | TNDBL_APPROXDECOMP_GADGET* res = malloc(sizeof(TNDBL_APPROXDECOMP_GADGET));
14 | memset(res, 0, sizeof(TNDBL_APPROXDECOMP_GADGET));
15 | res->k = k;
16 | res->ell = ell;
17 | double add_cst = INT64_C(3) << (51 - k * ell);
18 | for (uint64_t i = 0; i < ell; ++i) {
19 | add_cst += pow(2., -(double)(i * k + 1));
20 | }
21 | res->add_cst = add_cst;
22 | res->and_mask = (UINT64_C(1) << k) - 1;
23 | res->sub_cst = UINT64_C(1) << (k - 1);
24 | for (uint64_t i = 0; i < ell; ++i) res->rshifts[i] = (ell - 1 - i) * k;
25 | return res;
26 | }
27 | EXPORT void delete_tndbl_approxdecomp_gadget(TNDBL_APPROXDECOMP_GADGET* ptr) { free(ptr); }
28 |
// In-place initializer for the default backend. Currently a no-op stub that
// reports success (0): all precomputation is done by
// new_tndbl_approxdecomp_gadget.
EXPORT int default_init_tndbl_approxdecomp_gadget(const MOD_Z* module,             //
                                                  TNDBL_APPROXDECOMP_GADGET* res,  //
                                                  uint64_t k, uint64_t ell) {
  return 0;
}
34 |
// bit-level view of a double (union punning), used to read the raw mantissa
// of a + add_cst in the decomposition below
typedef union {
  double dv;
  uint64_t uv;
} du_t;
39 |
// Shared body of the i8/i16/i32 decomposition variants below (a // comment
// inside the macro would swallow the continuation lines, hence this header):
// for each coefficient, adding add_cst aligns its ell k-bit digits at fixed
// mantissa positions; digit i is extracted with rshifts[i]/and_mask and
// recentered by sub_cst. Digits of one coefficient are stored contiguously,
// so res must have size exactly a_size * ell (anything else is rejected).
#define IMPL_ixx_approxdecomp_from_tndbl_ref(ITYPE)               \
  if (res_size != a_size * gadget->ell) NOT_IMPLEMENTED();        \
  const uint64_t ell = gadget->ell;                               \
  const double add_cst = gadget->add_cst;                         \
  const uint8_t* const rshifts = gadget->rshifts;                 \
  const ITYPE and_mask = gadget->and_mask;                        \
  const ITYPE sub_cst = gadget->sub_cst;                          \
  ITYPE* rr = res;                                                \
  const double* aa = a;                                           \
  const double* aaend = a + a_size;                               \
  while (aa < aaend) {                                            \
    du_t t = {.dv = *aa + add_cst};                               \
    for (uint64_t i = 0; i < ell; ++i) {                          \
      ITYPE v = (ITYPE)(t.uv >> rshifts[i]);                      \
      *rr = (v & and_mask) - sub_cst;                             \
      ++rr;                                                       \
    }                                                             \
    ++aa;                                                         \
  }
59 |
60 | /** @brief sets res = gadget_decompose(a) (int8_t* output) */
61 | EXPORT void default_i8_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
62 | const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
63 | int8_t* res, uint64_t res_size, // res (in general, size ell.a_size)
64 | const double* a, uint64_t a_size //
65 | ){IMPL_ixx_approxdecomp_from_tndbl_ref(int8_t)}
66 |
67 | /** @brief sets res = gadget_decompose(a) (int16_t* output) */
68 | EXPORT void default_i16_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N
69 | const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget
70 | int16_t* res, uint64_t res_size, // res
71 | const double* a, uint64_t a_size // a
72 | ){IMPL_ixx_approxdecomp_from_tndbl_ref(int16_t)}
73 |
/** @brief sets res = gadget_decompose(a) (int32_t* output) */
// see IMPL_ixx_approxdecomp_from_tndbl_ref for the shared implementation
EXPORT void default_i32_approxdecomp_from_tndbl_ref(const MOD_Z* module,                      // N
                                                    const TNDBL_APPROXDECOMP_GADGET* gadget,  // gadget
                                                    int32_t* res, uint64_t res_size,          // res
                                                    const double* a, uint64_t a_size          // a
) {
  IMPL_ixx_approxdecomp_from_tndbl_ref(int32_t)
}
82 |
--------------------------------------------------------------------------------
/spqlios/arithmetic/zn_arithmetic_plugin.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_ZN_ARITHMETIC_PLUGIN_H
2 | #define SPQLIOS_ZN_ARITHMETIC_PLUGIN_H
3 |
4 | #include "zn_arithmetic.h"
5 |
// function types for each vtable slot, derived with typeof (GNU extension,
// standardized in C23) from the public prototypes in zn_arithmetic.h
typedef typeof(i8_approxdecomp_from_tndbl) I8_APPROXDECOMP_FROM_TNDBL_F;
typedef typeof(i16_approxdecomp_from_tndbl) I16_APPROXDECOMP_FROM_TNDBL_F;
typedef typeof(i32_approxdecomp_from_tndbl) I32_APPROXDECOMP_FROM_TNDBL_F;
typedef typeof(bytes_of_zn32_vmp_pmat) BYTES_OF_ZN32_VMP_PMAT_F;
typedef typeof(zn32_vmp_prepare_contiguous) ZN32_VMP_PREPARE_CONTIGUOUS_F;
typedef typeof(zn32_vmp_apply_i32) ZN32_VMP_APPLY_I32_F;
typedef typeof(zn32_vmp_apply_i16) ZN32_VMP_APPLY_I16_F;
typedef typeof(zn32_vmp_apply_i8) ZN32_VMP_APPLY_I8_F;
typedef typeof(dbl_to_tn32) DBL_TO_TN32_F;
typedef typeof(tn32_to_dbl) TN32_TO_DBL_F;
typedef typeof(dbl_round_to_i32) DBL_ROUND_TO_I32_F;
typedef typeof(i32_to_dbl) I32_TO_DBL_F;
typedef typeof(dbl_round_to_i64) DBL_ROUND_TO_I64_F;
typedef typeof(i64_to_dbl) I64_TO_DBL_F;
20 |
typedef struct z_module_vtable_t Z_MODULE_VTABLE;
// Backend dispatch table for the Z-module API: each MOD_Z implementation
// fills every slot with its own variant of the corresponding public function.
struct z_module_vtable_t {
  // approximate gadget decomposition, one entry per output integer width
  I8_APPROXDECOMP_FROM_TNDBL_F* i8_approxdecomp_from_tndbl;
  I16_APPROXDECOMP_FROM_TNDBL_F* i16_approxdecomp_from_tndbl;
  I32_APPROXDECOMP_FROM_TNDBL_F* i32_approxdecomp_from_tndbl;
  // vector-matrix products over prepared (ZN32_VMP_PMAT) matrices
  BYTES_OF_ZN32_VMP_PMAT_F* bytes_of_zn32_vmp_pmat;
  ZN32_VMP_PREPARE_CONTIGUOUS_F* zn32_vmp_prepare_contiguous;
  ZN32_VMP_APPLY_I32_F* zn32_vmp_apply_i32;
  ZN32_VMP_APPLY_I16_F* zn32_vmp_apply_i16;
  ZN32_VMP_APPLY_I8_F* zn32_vmp_apply_i8;
  // element-wise conversions between double and fixed-width representations
  DBL_TO_TN32_F* dbl_to_tn32;
  TN32_TO_DBL_F* tn32_to_dbl;
  DBL_ROUND_TO_I32_F* dbl_round_to_i32;
  I32_TO_DBL_F* i32_to_dbl;
  DBL_ROUND_TO_I64_F* dbl_round_to_i64;
  I64_TO_DBL_F* i64_to_dbl;
};
38 |
39 | #endif // SPQLIOS_ZN_ARITHMETIC_PLUGIN_H
40 |
--------------------------------------------------------------------------------
/spqlios/arithmetic/zn_vmp_int16_avx.c:
--------------------------------------------------------------------------------
// template instantiation: re-compiles zn_vmp_int32_avx.c with int16_t input
#define INTTYPE int16_t
#define INTSN i16

#include "zn_vmp_int32_avx.c"
5 |
--------------------------------------------------------------------------------
/spqlios/arithmetic/zn_vmp_int16_ref.c:
--------------------------------------------------------------------------------
// template instantiation: re-compiles zn_vmp_int32_ref.c with int16_t input
#define INTTYPE int16_t
#define INTSN i16

#include "zn_vmp_int32_ref.c"
5 |
--------------------------------------------------------------------------------
/spqlios/arithmetic/zn_vmp_int32_ref.c:
--------------------------------------------------------------------------------
1 | // This file is actually a template: it will be compiled multiple times with
2 | // different INTTYPES
3 | #ifndef INTTYPE
4 | #define INTTYPE int32_t
5 | #define INTSN i32
6 | #endif
7 |
8 | #include
9 |
10 | #include "zn_arithmetic_private.h"
11 |
12 | #define concat_inner(aa, bb, cc) aa##_##bb##_##cc
13 | #define concat(aa, bb, cc) concat_inner(aa, bb, cc)
14 | #define zn32_vec_fn(cc) concat(zn32_vec, INTSN, cc)
15 |
// the ref version shares the same implementation for each fixed column size
// optimized implementations may do something different.
// Clears res, then accumulates res[i] += a[row] * b[row][i] over all rows.
// NOTE(review): the inner loop reads NCOLS entries per row regardless of
// b_sl; when b_sl < NCOLS this reads past the advertised row width — callers
// (see default_zn32_vmp_apply) appear to rely on the pmat layout remaining
// readable there; confirm before reusing this helper with arbitrary b.
static __always_inline void IMPL_zn32_vec_matcols_ref(
    const uint64_t NCOLS,  // fixed number of columns
    uint64_t nrows,        // nrows of b
    int32_t* res,          // result: size NCOLS, only the first min(b_sl, NCOLS) are relevant
    const INTTYPE* a,      // a: nrows-sized vector
    const int32_t* b, uint64_t b_sl  // b: nrows * min(b_sl, NCOLS) matrix
) {
  memset(res, 0, NCOLS * sizeof(int32_t));
  for (uint64_t row = 0; row < nrows; ++row) {
    int32_t ai = a[row];
    const int32_t* bb = b + row * b_sl;
    for (uint64_t i = 0; i < NCOLS; ++i) {
      res[i] += ai * bb[i];
    }
  }
}
34 |
// fixed-width entry points (zn32_vec_fn expands to zn32_vec_<INTSN>_<name>);
// these feed the dispatch table below
void zn32_vec_fn(mat32cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
  IMPL_zn32_vec_matcols_ref(32, nrows, res, a, b, b_sl);
}
void zn32_vec_fn(mat24cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
  IMPL_zn32_vec_matcols_ref(24, nrows, res, a, b, b_sl);
}
void zn32_vec_fn(mat16cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
  IMPL_zn32_vec_matcols_ref(16, nrows, res, a, b, b_sl);
}
void zn32_vec_fn(mat8cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) {
  IMPL_zn32_vec_matcols_ref(8, nrows, res, a, b, b_sl);
}
47 |
// signature shared by the fixed-width kernels above
typedef void (*vm_f)(uint64_t nrows,          //
                     int32_t* res,            //
                     const INTTYPE* a,        //
                     const int32_t* b, uint64_t b_sl  //
);
// dispatch on column count rounded up to a multiple of 8:
// entry (c-1)>>3 handles c in [1..32] columns
static const vm_f zn32_vec_mat8kcols_ref[4] = {  //
    zn32_vec_fn(mat8cols_ref),                   //
    zn32_vec_fn(mat16cols_ref),                  //
    zn32_vec_fn(mat24cols_ref),                  //
    zn32_vec_fn(mat32cols_ref)};
58 |
/** @brief applies a vmp product (int32_t* input) */
// res[0..cols) = row-vector a times the prepared matrix pmat, computed in
// blocks of 32 columns; positions of res beyond the product are zeroed.
EXPORT void concat(default_zn32_vmp_apply, INTSN, ref)(  //
    const MOD_Z* module,                                 //
    int32_t* res, uint64_t res_size,                     //
    const INTTYPE* a, uint64_t a_size,                   //
    const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) {
  // clamp to the data actually available in a and res
  const uint64_t rows = a_size < nrows ? a_size : nrows;
  const uint64_t cols = res_size < ncols ? res_size : ncols;
  const uint64_t ncolblk = cols >> 5;  // full 32-column blocks
  const uint64_t ncolrem = cols & 31;  // leftover columns (< 32)
  // copy the first full blocks
  const uint32_t full_blk_size = nrows * 32;  // i32 entries per 32-col block of pmat
  const int32_t* mat = (int32_t*)pmat;
  int32_t* rr = res;
  for (uint64_t blk = 0;  //
       blk < ncolblk;     //
       ++blk, mat += full_blk_size, rr += 32) {
    zn32_vec_fn(mat32cols_ref)(rows, rr, a, mat, 32);
  }
  // last block
  if (ncolrem) {
    // row stride of the last stored block of pmat: its original width (may
    // exceed ncolrem when res truncates the output), capped at 32
    uint64_t orig_rem = ncols - (ncolblk << 5);
    uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem;
    // compute a whole multiple-of-8 chunk into tmp, keep only ncolrem entries
    int32_t tmp[32];
    zn32_vec_mat8kcols_ref[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl);
    memcpy(rr, tmp, ncolrem * sizeof(int32_t));
  }
  // trailing bytes
  memset(res + cols, 0, (res_size - cols) * sizeof(int32_t));
}
89 |
--------------------------------------------------------------------------------
/spqlios/arithmetic/zn_vmp_int8_avx.c:
--------------------------------------------------------------------------------
// template instantiation: re-compiles zn_vmp_int32_avx.c with int8_t input
#define INTTYPE int8_t
#define INTSN i8

#include "zn_vmp_int32_avx.c"
5 |
--------------------------------------------------------------------------------
/spqlios/arithmetic/zn_vmp_int8_ref.c:
--------------------------------------------------------------------------------
// template instantiation: re-compiles zn_vmp_int32_ref.c with int8_t input
#define INTTYPE int8_t
#define INTSN i8

#include "zn_vmp_int32_ref.c"
5 |
--------------------------------------------------------------------------------
/spqlios/arithmetic/znx_small.c:
--------------------------------------------------------------------------------
1 | #include "vec_znx_arithmetic_private.h"
2 |
/** @brief res = a * b : small integer polynomial product */
// fft64 backend: converts both operands to doubles, multiplies them pointwise
// in the fft domain, and converts the inverse transform back to int64.
// tmp must hold 2*nn doubles (see fft64_znx_small_single_product_tmp_bytes).
// NOTE(review): "small" presumably means coefficients small enough that the
// double roundtrip is exact — confirm the precise bound with the fft64 docs.
EXPORT void fft64_znx_small_single_product(const MODULE* module,  // N
                                           int64_t* res,          // output
                                           const int64_t* a,      // a
                                           const int64_t* b,      // b
                                           uint8_t* tmp) {
  const uint64_t nn = module->nn;
  // tmp is split into two fft buffers of nn doubles each
  double* const ffta = (double*)tmp;
  double* const fftb = ((double*)tmp) + nn;
  reim_from_znx64(module->mod.fft64.p_conv, ffta, a);
  reim_from_znx64(module->mod.fft64.p_conv, fftb, b);
  reim_fft(module->mod.fft64.p_fft, ffta);
  reim_fft(module->mod.fft64.p_fft, fftb);
  // pointwise complex product, result overwrites ffta
  reim_fftvec_mul_simple(module->m, ffta, ffta, fftb);
  reim_ifft(module->mod.fft64.p_ifft, ffta);
  reim_to_znx64(module->mod.fft64.p_reim_to_znx, res, ffta);
}
20 |
21 | /** @brief tmp bytes required for znx_small_single_product */
22 | EXPORT uint64_t fft64_znx_small_single_product_tmp_bytes(const MODULE* module) {
23 | return 2 * module->nn * sizeof(double);
24 | }
25 |
/** @brief res = a * b : small integer polynomial product */
// public entry point: forwards to the backend selected in the module vtable
EXPORT void znx_small_single_product(const MODULE* module,  // N
                                     int64_t* res,          // output
                                     const int64_t* a,      // a
                                     const int64_t* b,      // b
                                     uint8_t* tmp) {
  module->func.znx_small_single_product(module, res, a, b, tmp);
}
34 |
/** @brief tmp bytes required for znx_small_single_product */
// public entry point: forwards to the backend selected in the module vtable
EXPORT uint64_t znx_small_single_product_tmp_bytes(const MODULE* module) {
  return module->func.znx_small_single_product_tmp_bytes(module);
}
39 |
--------------------------------------------------------------------------------
/spqlios/coeffs/coeffs_arithmetic.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_COEFFS_ARITHMETIC_H
2 | #define SPQLIOS_COEFFS_ARITHMETIC_H
3 |
4 | #include "../commons.h"
5 |
6 | /** res = a + b */
7 | EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
8 | EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
9 | /** res = a - b */
10 | EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
11 | EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
12 | /** res = -a */
13 | EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a);
14 | EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a);
15 | /** res = a */
16 | EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a);
17 | /** res = 0 */
18 | EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res);
19 |
20 | /** res = a / m where m is a power of 2 */
21 | EXPORT void rnx_divide_by_m_ref(uint64_t nn, double m, double* res, const double* a);
22 | EXPORT void rnx_divide_by_m_avx(uint64_t nn, double m, double* res, const double* a);
23 |
24 | /**
 * @param res the result: res = X^p * in mod (X^nn + 1)
26 | * @param nn the ring dimension
27 | * @param p a power for the rotation -2nn <= p <= 2nn
28 | * @param in is a rnx/znx vector of dimension nn
29 | */
30 | EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in);
31 | EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
32 | EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res);
33 | EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res);
34 |
35 | /**
36 | * @brief res(X) = in(X^p)
37 | * @param nn the ring dimension
38 | * @param p is odd integer and must be between 0 < p < 2nn
39 | * @param in is a rnx/znx vector of dimension nn
40 | */
41 | EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in);
42 | EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
43 | EXPORT void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res);
44 | EXPORT void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res);
45 |
46 | /**
47 | * @brief res = (X^p-1).in
48 | * @param nn the ring dimension
49 | * @param p must be between -2nn <= p <= 2nn
50 | * @param in is a rnx/znx vector of dimension nn
51 | */
52 | EXPORT void rnx_mul_xp_minus_one(uint64_t nn, int64_t p, double* res, const double* in);
53 | EXPORT void znx_mul_xp_minus_one(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
54 | EXPORT void rnx_mul_xp_minus_one_inplace(uint64_t nn, int64_t p, double* res);
55 |
56 | /**
57 | * @brief Normalize input plus carry mod-2^k. The following
58 | * equality holds @c {in + carry_in == out + carry_out . 2^k}.
59 | *
60 | * @c in must be in [-2^62 .. 2^62]
61 | *
62 | * @c out is in [ -2^(base_k-1), 2^(base_k-1) [.
63 | *
 * @c carry_in and @c carry_out have at most 64+1-k bits.
65 | *
66 | * Null @c carry_in or @c carry_out are ignored.
67 | *
68 | * @param[in] nn the ring dimension
69 | * @param[in] base_k the base k
70 | * @param out output normalized znx
71 | * @param carry_out output carry znx
72 | * @param[in] in input znx
73 | * @param[in] carry_in input carry znx
74 | */
75 | EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in,
76 | const int64_t* carry_in);
77 |
78 | #endif // SPQLIOS_COEFFS_ARITHMETIC_H
79 |
--------------------------------------------------------------------------------
/spqlios/coeffs/coeffs_arithmetic_avx.c:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include "../commons_private.h"
4 | #include "coeffs_arithmetic.h"
5 |
6 | // res = a + b. dimension n must be a power of 2
7 | EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
8 | if (nn <= 2) {
9 | if (nn == 1) {
10 | res[0] = a[0] + b[0];
11 | } else {
12 | _mm_storeu_si128((__m128i*)res, //
13 | _mm_add_epi64( //
14 | _mm_loadu_si128((__m128i*)a), //
15 | _mm_loadu_si128((__m128i*)b)));
16 | }
17 | } else {
18 | const __m256i* aa = (__m256i*)a;
19 | const __m256i* bb = (__m256i*)b;
20 | __m256i* rr = (__m256i*)res;
21 | __m256i* const rrend = (__m256i*)(res + nn);
22 | do {
23 | _mm256_storeu_si256(rr, //
24 | _mm256_add_epi64( //
25 | _mm256_loadu_si256(aa), //
26 | _mm256_loadu_si256(bb)));
27 | ++rr;
28 | ++aa;
29 | ++bb;
30 | } while (rr < rrend);
31 | }
32 | }
33 |
34 | // res = a - b. dimension n must be a power of 2
35 | EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
36 | if (nn <= 2) {
37 | if (nn == 1) {
38 | res[0] = a[0] - b[0];
39 | } else {
40 | _mm_storeu_si128((__m128i*)res, //
41 | _mm_sub_epi64( //
42 | _mm_loadu_si128((__m128i*)a), //
43 | _mm_loadu_si128((__m128i*)b)));
44 | }
45 | } else {
46 | const __m256i* aa = (__m256i*)a;
47 | const __m256i* bb = (__m256i*)b;
48 | __m256i* rr = (__m256i*)res;
49 | __m256i* const rrend = (__m256i*)(res + nn);
50 | do {
51 | _mm256_storeu_si256(rr, //
52 | _mm256_sub_epi64( //
53 | _mm256_loadu_si256(aa), //
54 | _mm256_loadu_si256(bb)));
55 | ++rr;
56 | ++aa;
57 | ++bb;
58 | } while (rr < rrend);
59 | }
60 | }
61 |
62 | EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a) {
63 | if (nn <= 2) {
64 | if (nn == 1) {
65 | res[0] = -a[0];
66 | } else {
67 | _mm_storeu_si128((__m128i*)res, //
68 | _mm_sub_epi64( //
69 | _mm_set1_epi64x(0), //
70 | _mm_loadu_si128((__m128i*)a)));
71 | }
72 | } else {
73 | const __m256i* aa = (__m256i*)a;
74 | __m256i* rr = (__m256i*)res;
75 | __m256i* const rrend = (__m256i*)(res + nn);
76 | do {
77 | _mm256_storeu_si256(rr, //
78 | _mm256_sub_epi64( //
79 | _mm256_set1_epi64x(0), //
80 | _mm256_loadu_si256(aa)));
81 | ++rr;
82 | ++aa;
83 | } while (rr < rrend);
84 | }
85 | }
86 |
// res = a / m where m is a power of 2 (so 1/m is exact); n must be a power of
// 2: sizes 1, 2 and 4 are handled individually, sizes >= 8 (necessarily
// multiples of 8) use an unrolled avx loop, anything else is rejected.
EXPORT void rnx_divide_by_m_avx(uint64_t n, double m, double* res, const double* a) {
  // TODO: see if there is a faster way of dividing by a power of 2?
  const double invm = 1. / m;
  if (n < 8) {
    switch (n) {
      case 1:
        *res = *a * invm;
        break;
      case 2:
        _mm_storeu_pd(res,                         //
                      _mm_mul_pd(_mm_loadu_pd(a),  //
                                 _mm_set1_pd(invm)));
        break;
      case 4:
        _mm256_storeu_pd(res,                               //
                         _mm256_mul_pd(_mm256_loadu_pd(a),  //
                                       _mm256_set1_pd(invm)));
        break;
      default:
        NOT_SUPPORTED(); // non-power of 2
    }
    return;
  }
  // main loop: 8 doubles (two 256-bit lanes) per iteration
  const __m256d invm256 = _mm256_set1_pd(invm);
  double* rr = res;
  const double* aa = a;
  const double* const aaend = a + n;
  do {
    _mm256_storeu_pd(rr,                                //
                     _mm256_mul_pd(_mm256_loadu_pd(aa), //
                                   invm256));
    _mm256_storeu_pd(rr + 4,                                //
                     _mm256_mul_pd(_mm256_loadu_pd(aa + 4), //
                                   invm256));
    rr += 8;
    aa += 8;
  } while (aa < aaend);
}
125 |
--------------------------------------------------------------------------------
/spqlios/commons.h:
--------------------------------------------------------------------------------
#ifndef SPQLIOS_COMMONS_H
#define SPQLIOS_COMMONS_H

#ifdef __cplusplus
// C++ translation units: C++ standard headers, C linkage on the exported API.
// (header names restored from usage of uint*_t / fprintf / abort — confirm upstream)
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#define EXPORT extern "C"
#define EXPORT_DECL extern "C"
#else
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#define EXPORT
#define EXPORT_DECL extern
// NOTE(review): the trailing ';' makes this macro statement-only — e.g.
// `f(nullptr)` expands to `f(0x0;)`, a syntax error. Kept byte-identical here
// because commons_private.h re-defines it with exactly this body (a one-sided
// change would be a mismatched macro redefinition).
#define nullptr 0x0;
#endif

// Abort helpers: print a diagnostic on stderr and terminate the process.
#define UNDEFINED() \
  { \
    fprintf(stderr, "UNDEFINED!!!\n"); \
    abort(); \
  }
#define NOT_IMPLEMENTED() \
  { \
    fprintf(stderr, "NOT IMPLEMENTED!!!\n"); \
    abort(); \
  }
#define FATAL_ERROR(MESSAGE) \
  { \
    fprintf(stderr, "ERROR: %s\n", (MESSAGE)); \
    abort(); \
  }

// Typed always-abort fallbacks, used to populate function-pointer tables for
// targets/configurations that lack a real implementation. The suffix encodes
// the signature (p=void*, dp=double*, vp=void*, i=int32, u=uint32, v=void) —
// presumably; confirm against their definitions in commons.c.
EXPORT void* UNDEFINED_p_ii(int32_t n, int32_t m);
EXPORT void* UNDEFINED_p_uu(uint32_t n, uint32_t m);
EXPORT double* UNDEFINED_dp_pi(const void* p, int32_t n);
EXPORT void* UNDEFINED_vp_pi(const void* p, int32_t n);
EXPORT void* UNDEFINED_vp_pu(const void* p, uint32_t n);
EXPORT void UNDEFINED_v_vpdp(const void* p, double* a);
EXPORT void UNDEFINED_v_vpvp(const void* p, void* a);
EXPORT double* NOT_IMPLEMENTED_dp_i(int32_t n);
EXPORT void* NOT_IMPLEMENTED_vp_i(int32_t n);
EXPORT void* NOT_IMPLEMENTED_vp_u(uint32_t n);
EXPORT void NOT_IMPLEMENTED_v_dp(double* a);
EXPORT void NOT_IMPLEMENTED_v_vp(void* p);
EXPORT void NOT_IMPLEMENTED_v_idpdpdp(int32_t n, double* a, const double* b, const double* c);
EXPORT void NOT_IMPLEMENTED_v_uvpcvpcvp(uint32_t n, void* r, const void* a, const void* b);
EXPORT void NOT_IMPLEMENTED_v_uvpvpcvp(uint32_t n, void* a, void* b, const void* o);

// windows

// glibc provides __always_inline; define it for platforms that do not.
#if defined(_WIN32) || defined(__APPLE__)
#define __always_inline inline __attribute((always_inline))
#endif

// allocation API: spqlios_free releases memory obtained from either allocator.
EXPORT void spqlios_free(void* address);

EXPORT void* spqlios_alloc(uint64_t size);
EXPORT void* spqlios_alloc_custom_align(uint64_t align, uint64_t size);

#define USE_LIBM_SIN_COS
#ifndef USE_LIBM_SIN_COS
// if at some point, we want to remove the libm dependency, we can
// consider this:
EXPORT double internal_accurate_cos(double x);
EXPORT double internal_accurate_sin(double x);
EXPORT void internal_accurate_sincos(double* rcos, double* rsin, double x);
#define m_accurate_cos internal_accurate_cos
#define m_accurate_sin internal_accurate_sin
#else
// let's use libm sin and cos
#define m_accurate_cos cos
#define m_accurate_sin sin
#endif

#endif  // SPQLIOS_COMMONS_H
--------------------------------------------------------------------------------
/spqlios/commons_private.c:
--------------------------------------------------------------------------------
#include "commons_private.h"

#include <stdio.h>
#include <stdlib.h>

#include "commons.h"
7 |
/** prints the message on stderr and aborts; the (unreachable) null return lets
 *  callers write `return spqlios_error("...")` in pointer-returning functions. */
EXPORT void* spqlios_error(const char* error) {
  fprintf(stderr, "%s", error);
  abort();
  return nullptr;  // never reached
}
/** returns ptr2 unchanged; when ptr2 is null, also frees ptr (typical use:
 *  `spqlios_keep_or_free(buf, init(buf, ...))` to release buf on init failure). */
EXPORT void* spqlios_keep_or_free(void* ptr, void* ptr2) {
  if (ptr2) return ptr2;
  free(ptr);
  return ptr2;
}
19 |
/** @brief log2 of a power of two m (aborts via FATAL_ERROR when m has more
 *  than one bit set; equals popcount(m - 1) for powers of two). */
EXPORT uint32_t log2m(uint32_t m) {
  uint32_t v = m - 1;
  if (m & v) FATAL_ERROR("m must be a power of two");
  // branch-free popcount of v (classic SWAR variant)
  v = v - ((v >> 1) & 0x55555555u);
  v = (v & 0x33333333u) + ((v >> 2) & 0x33333333u);
  v = (v + (v >> 4)) & 0x0F0F0F0Fu;
  return (v * 0x01010101u) >> 24;
}
29 |
30 | EXPORT uint64_t is_not_pow2_double(void* doublevalue) { return (*(uint64_t*)doublevalue) & 0x7FFFFFFFFFFFFUL; }
31 |
/** @brief returns the low nbits of value in reversed bit order
 *  (bits above nbits are discarded; nbits == 0 yields 0). */
uint32_t revbits(uint32_t nbits, uint32_t value) {
  uint32_t out = 0;
  while (nbits--) {
    out = (out << 1) | (value & 1u);
    value >>= 1;
  }
  return out;
}
40 |
41 | /**
42 | * @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,...
43 | * essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/
44 | double fracrevbits(uint32_t i) {
45 | if (i == 0) return 0;
46 | if (i == 1) return 0.5;
47 | if (i % 2 == 0)
48 | return fracrevbits(i / 2) / 2.;
49 | else
50 | return fracrevbits((i - 1) / 2) / 2. + 0.5;
51 | }
52 |
53 | uint64_t ceilto64b(uint64_t size) { return (size + UINT64_C(63)) & (UINT64_C(-64)); }
54 |
55 | uint64_t ceilto32b(uint64_t size) { return (size + UINT64_C(31)) & (UINT64_C(-32)); }
56 |
--------------------------------------------------------------------------------
/spqlios/commons_private.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_COMMONS_PRIVATE_H
2 | #define SPQLIOS_COMMONS_PRIVATE_H
3 |
4 | #include "commons.h"
5 |
6 | #ifdef __cplusplus
7 | #include
8 | #include
9 | #include
10 | #else
11 | #include
12 | #include
13 | #include
14 | #define nullptr 0x0;
15 | #endif
16 |
17 | /** @brief log2 of a power of two (UB if m is not a power of two) */
18 | EXPORT uint32_t log2m(uint32_t m);
19 |
20 | /** @brief checks if the doublevalue is a power of two */
21 | EXPORT uint64_t is_not_pow2_double(void* doublevalue);
22 |
23 | #define UNDEFINED() \
24 | { \
25 | fprintf(stderr, "UNDEFINED!!!\n"); \
26 | abort(); \
27 | }
28 | #define NOT_IMPLEMENTED() \
29 | { \
30 | fprintf(stderr, "NOT IMPLEMENTED!!!\n"); \
31 | abort(); \
32 | }
33 | #define NOT_SUPPORTED() \
34 | { \
35 | fprintf(stderr, "NOT SUPPORTED!!!\n"); \
36 | abort(); \
37 | }
38 | #define FATAL_ERROR(MESSAGE) \
39 | { \
40 | fprintf(stderr, "ERROR: %s\n", (MESSAGE)); \
41 | abort(); \
42 | }
43 |
44 | #define STATIC_ASSERT(condition) (void)sizeof(char[-1 + 2 * !!(condition)])
45 |
46 | /** @brief reports the error and returns nullptr */
47 | EXPORT void* spqlios_error(const char* error);
48 | /** @brief if ptr2 is not null, returns ptr, otherwise free ptr and return null */
49 | EXPORT void* spqlios_keep_or_free(void* ptr, void* ptr2);
50 |
51 | #ifdef __x86_64__
52 | #define CPU_SUPPORTS __builtin_cpu_supports
53 | #else
54 | // TODO for now, we do not have any optimization for non x86 targets
55 | #define CPU_SUPPORTS(xxxx) 0
56 | #endif
57 |
58 | /** @brief returns the n bits of value in reversed order */
59 | EXPORT uint32_t revbits(uint32_t nbits, uint32_t value);
60 |
61 | /**
62 | * @brief this computes the sequence: 0,1/2,1/4,3/4,1/8,5/8,3/8,7/8,...
63 | * essentially: the bits of (i+1) in lsb order on the basis (1/2^k) mod 1*/
64 | EXPORT double fracrevbits(uint32_t i);
65 |
66 | /** @brief smallest multiple of 64 higher or equal to size */
67 | EXPORT uint64_t ceilto64b(uint64_t size);
68 |
69 | /** @brief smallest multiple of 32 higher or equal to size */
70 | EXPORT uint64_t ceilto32b(uint64_t size);
71 |
72 | #endif // SPQLIOS_COMMONS_PRIVATE_H
73 |
--------------------------------------------------------------------------------
/spqlios/cplx/README.md:
--------------------------------------------------------------------------------
1 | In this folder, we deal with the full complex FFT in `C[X] mod X^M-i`.
2 | One complex is represented by two consecutive doubles `(real,imag)`
3 | Note that a real polynomial sum_{j=0}^{N-1} p_j.X^j mod X^N+1
4 | corresponds to the complex polynomial of half degree `M=N/2`:
5 | `sum_{j=0}^{M-1} (p_{j} + i.p_{j+M}) X^j mod X^M-i`
6 |
For a complex polynomial `A(X) = sum_{i=0}^{M-1} c_i X^i` (degree at most M-1),
or equivalently a real polynomial `a(X) = sum_{i=0}^{N-1} a_i X^i` (degree at most N-1):
9 |
10 | coefficient space:
11 | a_0,a_M,a_1,a_{M+1},...,a_{M-1},a_{2M-1}
12 | or equivalently
13 | Re(c_0),Im(c_0),Re(c_1),Im(c_1),...Re(c_{M-1}),Im(c_{M-1})
14 |
15 | eval space:
16 | c(omega_{0}),...,c(omega_{M-1})
17 |
18 | where
19 | omega_j = omega^{1+rev_{2N}(j)}
20 | and omega = exp(i.pi/N)
21 |
22 | rev_{2N}(j) is the number that has the log2(2N) bits of j in reverse order.
--------------------------------------------------------------------------------
/spqlios/cplx/cplx_common.c:
--------------------------------------------------------------------------------
1 | #include "cplx_fft_internal.h"
2 |
3 | void cplx_set(CPLX r, const CPLX a) {
4 | r[0] = a[0];
5 | r[1] = a[1];
6 | }
7 | void cplx_neg(CPLX r, const CPLX a) {
8 | r[0] = -a[0];
9 | r[1] = -a[1];
10 | }
11 | void cplx_add(CPLX r, const CPLX a, const CPLX b) {
12 | r[0] = a[0] + b[0];
13 | r[1] = a[1] + b[1];
14 | }
15 | void cplx_sub(CPLX r, const CPLX a, const CPLX b) {
16 | r[0] = a[0] - b[0];
17 | r[1] = a[1] - b[1];
18 | }
19 | void cplx_mul(CPLX r, const CPLX a, const CPLX b) {
20 | double re = a[0] * b[0] - a[1] * b[1];
21 | r[1] = a[0] * b[1] + a[1] * b[0];
22 | r[0] = re;
23 | }
24 |
25 | /**
26 | * @brief splits 2h evaluations of one polynomials into 2 times h evaluations of even/odd polynomial
27 | * Input: Q_0(y),...,Q_{h-1}(y),Q_0(-y),...,Q_{h-1}(-y)
28 | * Output: P_0(z),...,P_{h-1}(z),P_h(z),...,P_{2h-1}(z)
29 | * where Q_i(X)=P_i(X^2)+X.P_{h+i}(X^2) and y^2 = z
30 | * @param h number of "coefficients" h >= 1
31 | * @param data 2h complex coefficients interleaved and 256b aligned
32 | * @param powom y represented as (yre,yim)
33 | */
34 | EXPORT void cplx_split_fft_ref(int32_t h, CPLX* data, const CPLX powom) {
35 | CPLX* d0 = data;
36 | CPLX* d1 = data + h;
37 | for (uint64_t i = 0; i < h; ++i) {
38 | CPLX diff;
39 | cplx_sub(diff, d0[i], d1[i]);
40 | cplx_add(d0[i], d0[i], d1[i]);
41 | cplx_mul(d1[i], diff, powom);
42 | }
43 | }
44 |
45 | /**
46 | * @brief Do two layers of itwiddle (i.e. split).
47 | * Input/output: d0,d1,d2,d3 of length h
48 | * Algo:
49 | * itwiddle(d0,d1,om[0]),itwiddle(d2,d3,i.om[0])
50 | * itwiddle(d0,d2,om[1]),itwiddle(d1,d3,om[1])
51 | * @param h number of "coefficients" h >= 1
52 | * @param data 4h complex coefficients interleaved and 256b aligned
53 | * @param powom om[0] (re,im) and om[1] where om[1]=om[0]^2
54 | */
55 | EXPORT void cplx_bisplit_fft_ref(int32_t h, CPLX* data, const CPLX powom[2]) {
56 | CPLX* d0 = data;
57 | CPLX* d2 = data + 2*h;
58 | const CPLX* om0 = powom;
59 | CPLX iom0;
60 | iom0[0]=powom[0][1];
61 | iom0[1]=-powom[0][0];
62 | const CPLX* om1 = powom+1;
63 | cplx_split_fft_ref(h, d0, *om0);
64 | cplx_split_fft_ref(h, d2, iom0);
65 | cplx_split_fft_ref(2*h, d0, *om1);
66 | }
67 |
68 | /**
69 | * Input: Q(y),Q(-y)
70 | * Output: P_0(z),P_1(z)
71 | * where Q(X)=P_0(X^2)+X.P_1(X^2) and y^2 = z
72 | * @param data 2 complexes coefficients interleaved and 256b aligned
73 | * @param powom (z,-z) interleaved: (zre,zim,-zre,-zim)
74 | */
75 | void split_fft_last_ref(CPLX* data, const CPLX powom) {
76 | CPLX diff;
77 | cplx_sub(diff, data[0], data[1]);
78 | cplx_add(data[0], data[0], data[1]);
79 | cplx_mul(data[1], diff, powom);
80 | }
81 |
--------------------------------------------------------------------------------
/spqlios/cplx/cplx_execute.c:
--------------------------------------------------------------------------------
1 | #include "cplx_fft_internal.h"
2 | #include "cplx_fft_private.h"
3 |
4 | EXPORT void cplx_from_znx32(const CPLX_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a) {
5 | tables->function(tables, r, a);
6 | }
7 | EXPORT void cplx_from_tnx32(const CPLX_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a) {
8 | tables->function(tables, r, a);
9 | }
10 | EXPORT void cplx_to_tnx32(const CPLX_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a) {
11 | tables->function(tables, r, a);
12 | }
13 | EXPORT void cplx_fftvec_mul(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b) {
14 | tables->function(tables, r, a, b);
15 | }
16 | EXPORT void cplx_fftvec_addmul(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b) {
17 | tables->function(tables, r, a, b);
18 | }
19 |
--------------------------------------------------------------------------------
/spqlios/cplx/cplx_fallbacks_aarch64.c:
--------------------------------------------------------------------------------
1 | #include "cplx_fft_internal.h"
2 | #include "cplx_fft_private.h"
3 |
// aarch64 fallback stubs: every x86-specific kernel below aborts via UNDEFINED().
// (reformatted: clang-format had fused several definitions onto shared lines,
// which made the file nearly unreadable; behavior is unchanged.)
EXPORT void cplx_fftvec_addmul_fma(const CPLX_FFTVEC_ADDMUL_PRECOMP* tables, void* r, const void* a, const void* b) {
  UNDEFINED();  // not defined for non x86 targets
}
EXPORT void cplx_fftvec_mul_fma(const CPLX_FFTVEC_MUL_PRECOMP* tables, void* r, const void* a, const void* b) {
  UNDEFINED();
}
EXPORT void cplx_fftvec_addmul_sse(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
  UNDEFINED();
}
EXPORT void cplx_fftvec_addmul_avx512(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a,
                                      const void* b) {
  UNDEFINED();
}
EXPORT void cplx_fft16_avx_fma(void* data, const void* omega) { UNDEFINED(); }
EXPORT void cplx_ifft16_avx_fma(void* data, const void* omega) { UNDEFINED(); }
EXPORT void cplx_from_znx32_avx2_fma(const CPLX_FROM_ZNX32_PRECOMP* precomp, void* r, const int32_t* x) { UNDEFINED(); }
EXPORT void cplx_from_tnx32_avx2_fma(const CPLX_FROM_TNX32_PRECOMP* precomp, void* r, const int32_t* x) { UNDEFINED(); }
EXPORT void cplx_to_tnx32_avx2_fma(const CPLX_TO_TNX32_PRECOMP* precomp, int32_t* x, const void* c) { UNDEFINED(); }
EXPORT void cplx_fft_avx2_fma(const CPLX_FFT_PRECOMP* tables, void* data) { UNDEFINED(); }
EXPORT void cplx_ifft_avx2_fma(const CPLX_IFFT_PRECOMP* itables, void* data) { UNDEFINED(); }
EXPORT void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om) {
  UNDEFINED();
}
EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om) {
  UNDEFINED();
}
EXPORT void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice,
                                      const void* om) {
  UNDEFINED();
}
EXPORT void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice,
                                         const void* om) {
  UNDEFINED();
}
31 |
// DEPRECATED? — kept for link compatibility; aborts via UNDEFINED() if called.
// (reformatted from clang-format-fused definitions; behavior unchanged.)
EXPORT void cplx_fftvec_add_fma(uint32_t m, void* r, const void* a, const void* b) { UNDEFINED(); }
EXPORT void cplx_fftvec_sub2_to_fma(uint32_t m, void* r, const void* a, const void* b) { UNDEFINED(); }
EXPORT void cplx_fftvec_copy_fma(uint32_t m, void* r, const void* a) { UNDEFINED(); }
36 |
37 | // executors
38 | //EXPORT void cplx_ifft(const CPLX_IFFT_PRECOMP* itables, void* data) {
39 | // itables->function(itables, data);
40 | //}
41 | //EXPORT void cplx_fft(const CPLX_FFT_PRECOMP* tables, void* data) { tables->function(tables, data); }
42 |
--------------------------------------------------------------------------------
/spqlios/cplx/cplx_fft_asserts.c:
--------------------------------------------------------------------------------
1 | #include "cplx_fft_private.h"
2 | #include "../commons_private.h"
3 |
// Compile-time layout checks (never called at runtime): the fixed sizes are
// presumably what the hand-written assembly kernels assume when indexing into
// the precomp structs — an 8-byte function pointer plus four 8-byte fields == 40.
// NOTE(review): these equalities only hold on 64-bit targets; confirm that
// 32-bit builds are out of scope.
__always_inline void my_asserts() {
  STATIC_ASSERT(sizeof(FFT_FUNCTION)==8);
  STATIC_ASSERT(sizeof(CPLX_FFT_PRECOMP)==40);
  STATIC_ASSERT(sizeof(CPLX_IFFT_PRECOMP)==40);
}
9 |
--------------------------------------------------------------------------------
/spqlios/cplx/cplx_fft_private.h:
--------------------------------------------------------------------------------
#ifndef SPQLIOS_CPLX_FFT_PRIVATE_H
#define SPQLIOS_CPLX_FFT_PRIVATE_H

#include "cplx_fft.h"

// forward names for the twiddle precomp structs defined at the end of this file
typedef struct cplx_twiddle_precomp CPLX_FFTVEC_TWIDDLE_PRECOMP;
typedef struct cplx_bitwiddle_precomp CPLX_FFTVEC_BITWIDDLE_PRECOMP;

// Implementation function-pointer types: each precomp struct below stores the
// concrete implementation (ref / fma / avx512 / ...) selected at init time.
typedef void (*IFFT_FUNCTION)(const CPLX_IFFT_PRECOMP*, void*);
typedef void (*FFT_FUNCTION)(const CPLX_FFT_PRECOMP*, void*);
// conversions
typedef void (*FROM_ZNX32_FUNCTION)(const CPLX_FROM_ZNX32_PRECOMP*, void*, const int32_t*);
// NOTE(review): TO_ZNX32_FUNCTION takes a CPLX_FROM_ZNX32_PRECOMP* — this looks
// like it was meant to be the TO_ZNX32 precomp type; confirm before changing.
typedef void (*TO_ZNX32_FUNCTION)(const CPLX_FROM_ZNX32_PRECOMP*, int32_t*, const void*);
typedef void (*FROM_TNX32_FUNCTION)(const CPLX_FROM_TNX32_PRECOMP*, void*, const int32_t*);
typedef void (*TO_TNX32_FUNCTION)(const CPLX_TO_TNX32_PRECOMP*, int32_t*, const void*);
typedef void (*FROM_RNX64_FUNCTION)(const CPLX_FROM_RNX64_PRECOMP* precomp, void* r, const double* x);
typedef void (*TO_RNX64_FUNCTION)(const CPLX_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
typedef void (*ROUND_TO_RNX64_FUNCTION)(const CPLX_ROUND_TO_RNX64_PRECOMP* precomp, double* r, const void* x);
// fftvec operations
typedef void (*FFTVEC_MUL_FUNCTION)(const CPLX_FFTVEC_MUL_PRECOMP*, void*, const void*, const void*);
typedef void (*FFTVEC_ADDMUL_FUNCTION)(const CPLX_FFTVEC_ADDMUL_PRECOMP*, void*, const void*, const void*);

typedef void (*FFTVEC_TWIDDLE_FUNCTION)(const CPLX_FFTVEC_TWIDDLE_PRECOMP*, void*, const void*, const void*);
typedef void (*FFTVEC_BITWIDDLE_FUNCTION)(const CPLX_FFTVEC_BITWIDDLE_PRECOMP*, void*, uint64_t, const void*);

// NOTE: cplx_fft_asserts.c pins sizeof(cplx_fft_precomp)==40 and
// sizeof(cplx_ifft_precomp)==40, so the field layout of the two structs below
// must not change without updating those asserts (and whatever relies on them).
struct cplx_ifft_precomp {
  IFFT_FUNCTION function;  // selected implementation
  int64_t m;               // number of complex coefficients
  uint64_t buf_size;       // size of each aligned scratch buffer, in bytes
  double* powomegas;       // precomputed powers of omega (twiddle factors)
  void* aligned_buffers;   // scratch space used by the implementations
};

struct cplx_fft_precomp {
  FFT_FUNCTION function;   // selected implementation
  int64_t m;               // number of complex coefficients
  uint64_t buf_size;       // size of each aligned scratch buffer, in bytes
  double* powomegas;       // precomputed powers of omega (twiddle factors)
  void* aligned_buffers;   // scratch space used by the implementations
};

struct cplx_from_znx32_precomp {
  FROM_ZNX32_FUNCTION function;
  int64_t m;
};

struct cplx_to_znx32_precomp {
  TO_ZNX32_FUNCTION function;
  int64_t m;
  double divisor;  // rescaling applied during the conversion
};

struct cplx_from_tnx32_precomp {
  FROM_TNX32_FUNCTION function;
  int64_t m;
};

struct cplx_to_tnx32_precomp {
  TO_TNX32_FUNCTION function;
  int64_t m;
  double divisor;  // rescaling applied during the conversion
};

struct cplx_from_rnx64_precomp {
  FROM_RNX64_FUNCTION function;
  int64_t m;
};

struct cplx_to_rnx64_precomp {
  TO_RNX64_FUNCTION function;
  int64_t m;
  double divisor;  // rescaling applied during the conversion
};

struct cplx_round_to_rnx64_precomp {
  ROUND_TO_RNX64_FUNCTION function;
  int64_t m;
  double divisor;     // rescaling applied before rounding
  uint32_t log2bound; // presumably a bound on the coefficient magnitude — confirm
};

typedef struct cplx_mul_precomp {
  FFTVEC_MUL_FUNCTION function;
  int64_t m;
} CPLX_FFTVEC_MUL_PRECOMP;

typedef struct cplx_addmul_precomp {
  FFTVEC_ADDMUL_FUNCTION function;
  int64_t m;
} CPLX_FFTVEC_ADDMUL_PRECOMP;

struct cplx_twiddle_precomp {
  FFTVEC_TWIDDLE_FUNCTION function;
  int64_t m;
};

struct cplx_bitwiddle_precomp {
  FFTVEC_BITWIDDLE_FUNCTION function;
  int64_t m;
};

EXPORT void cplx_fftvec_twiddle_fma(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om);
EXPORT void cplx_fftvec_twiddle_avx512(const CPLX_FFTVEC_TWIDDLE_PRECOMP* tables, void* a, void* b, const void* om);
EXPORT void cplx_fftvec_bitwiddle_fma(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice,
                                      const void* om);
EXPORT void cplx_fftvec_bitwiddle_avx512(const CPLX_FFTVEC_BITWIDDLE_PRECOMP* tables, void* a, uint64_t slice,
                                         const void* om);

#endif  // SPQLIOS_CPLX_FFT_PRIVATE_H
--------------------------------------------------------------------------------
/spqlios/cplx/cplx_fftvec_ref.c:
--------------------------------------------------------------------------------
#include <stdlib.h>

#include "../commons_private.h"
#include "cplx_fft_internal.h"
#include "cplx_fft_private.h"
6 |
// rr[i] += aa[i] * bb[i] for i in [0, m): portable reference implementation.
// Inputs are cached before the stores, so r may alias a and/or b.
EXPORT void cplx_fftvec_addmul_ref(const CPLX_FFTVEC_ADDMUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
  const uint32_t m = precomp->m;
  const CPLX* x = (const CPLX*)a;
  const CPLX* y = (const CPLX*)b;
  CPLX* out = (CPLX*)r;
  for (uint32_t k = 0; k != m; ++k) {
    const double xr = x[k][0], xi = x[k][1];
    const double yr = y[k][0], yi = y[k][1];
    out[k][0] += xr * yr - xi * yi;
    out[k][1] += xr * yi + xi * yr;
  }
}
19 |
// rr[i] = aa[i] * bb[i] for i in [0, m): portable reference implementation.
// Inputs are cached before the stores, so r may alias a and/or b.
EXPORT void cplx_fftvec_mul_ref(const CPLX_FFTVEC_MUL_PRECOMP* precomp, void* r, const void* a, const void* b) {
  const uint32_t m = precomp->m;
  const CPLX* x = (const CPLX*)a;
  const CPLX* y = (const CPLX*)b;
  CPLX* out = (CPLX*)r;
  for (uint32_t k = 0; k != m; ++k) {
    const double xr = x[k][0], xi = x[k][1];
    const double yr = y[k][0], yi = y[k][1];
    out[k][0] = xr * yr - xi * yi;
    out[k][1] = xr * yi + xi * yr;
  }
}
32 |
// initializes an addmul precomp for size m; returns r on success, NULL (after
// reporting) on invalid m. Selects the fma kernel when m > 4 and the CPU has it.
EXPORT void* init_cplx_fftvec_addmul_precomp(CPLX_FFTVEC_ADDMUL_PRECOMP* r, uint32_t m) {
  // explicitly reject m == 0: `0 & (0-1) == 0` would otherwise slip through
  if (m == 0 || (m & (m - 1))) return spqlios_error("m must be a power of two");
  r->m = m;
  if (m <= 4) {
    r->function = cplx_fftvec_addmul_ref;
  } else if (CPU_SUPPORTS("fma")) {
    r->function = cplx_fftvec_addmul_fma;
  } else {
    r->function = cplx_fftvec_addmul_ref;
  }
  return r;
}
45 |
// initializes a mul precomp for size m; returns r on success, NULL (after
// reporting) on invalid m. Selects the fma kernel when m > 4 and the CPU has it.
EXPORT void* init_cplx_fftvec_mul_precomp(CPLX_FFTVEC_MUL_PRECOMP* r, uint32_t m) {
  // explicitly reject m == 0: `0 & (0-1) == 0` would otherwise slip through
  if (m == 0 || (m & (m - 1))) return spqlios_error("m must be a power of two");
  r->m = m;
  if (m <= 4) {
    r->function = cplx_fftvec_mul_ref;
  } else if (CPU_SUPPORTS("fma")) {
    r->function = cplx_fftvec_mul_fma;
  } else {
    r->function = cplx_fftvec_mul_ref;
  }
  return r;
}
58 |
// heap-allocates and initializes an addmul precomp; NULL on invalid m.
// Fix: was allocating sizeof(CPLX_FFTVEC_MUL_PRECOMP) — harmless only while
// the two structs happen to share a layout; allocate the correct type.
EXPORT CPLX_FFTVEC_ADDMUL_PRECOMP* new_cplx_fftvec_addmul_precomp(uint32_t m) {
  CPLX_FFTVEC_ADDMUL_PRECOMP* r = malloc(sizeof(CPLX_FFTVEC_ADDMUL_PRECOMP));
  return spqlios_keep_or_free(r, init_cplx_fftvec_addmul_precomp(r, m));
}
63 |
// heap-allocates and initializes a mul precomp; on init failure the buffer is
// freed and NULL is returned (spqlios_keep_or_free contract).
EXPORT CPLX_FFTVEC_MUL_PRECOMP* new_cplx_fftvec_mul_precomp(uint32_t m) {
  CPLX_FFTVEC_MUL_PRECOMP* r = malloc(sizeof(*r));
  return spqlios_keep_or_free(r, init_cplx_fftvec_mul_precomp(r, m));
}
68 |
// convenience wrapper: lazily caches one precomp per power-of-two size m.
// 32 slots cover every power-of-two uint32_t m (log2m(m) in 0..31);
// the previous size of 31 overflowed the array for m == 2^31.
// NOTE(review): the lazy initialization is not thread-safe — confirm callers
// are single-threaded (or pre-initialize before spawning threads).
EXPORT void cplx_fftvec_mul_simple(uint32_t m, void* r, const void* a, const void* b) {
  static CPLX_FFTVEC_MUL_PRECOMP p[32] = {0};
  CPLX_FFTVEC_MUL_PRECOMP* f = p + log2m(m);
  if (!f->function) {
    if (!init_cplx_fftvec_mul_precomp(f, m)) abort();
  }
  f->function(f, r, a, b);
}
77 |
// convenience wrapper: lazily caches one precomp per power-of-two size m.
// 32 slots cover every power-of-two uint32_t m (log2m(m) in 0..31);
// the previous size of 31 overflowed the array for m == 2^31.
// NOTE(review): the lazy initialization is not thread-safe — confirm callers
// are single-threaded (or pre-initialize before spawning threads).
EXPORT void cplx_fftvec_addmul_simple(uint32_t m, void* r, const void* a, const void* b) {
  static CPLX_FFTVEC_ADDMUL_PRECOMP p[32] = {0};
  CPLX_FFTVEC_ADDMUL_PRECOMP* f = p + log2m(m);
  if (!f->function) {
    if (!init_cplx_fftvec_addmul_precomp(f, m)) abort();
  }
  f->function(f, r, a, b);
}
86 |
--------------------------------------------------------------------------------
/spqlios/cplx/spqlios_cplx_fft.c:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfhe/spqlios-arithmetic/7ea875f61c51c67687ec9c15ae0eae04cda961f8/spqlios/cplx/spqlios_cplx_fft.c
--------------------------------------------------------------------------------
/spqlios/ext/neon_accel/macrof.h:
--------------------------------------------------------------------------------
1 | /*
2 | * This file is extracted from the implementation of the FFT on Arm64/Neon
3 | * available in https://github.com/cothan/Falcon-Arm (neon/macrof.h).
4 | * =============================================================================
5 | * Copyright (c) 2022 by Cryptographic Engineering Research Group (CERG)
6 | * ECE Department, George Mason University
7 | * Fairfax, VA, U.S.A.
8 | * @author: Duc Tri Nguyen dnguye69@gmu.edu, cothannguyen@gmail.com
9 | * Licensed under the Apache License, Version 2.0 (the "License");
10 | * =============================================================================
11 | *
12 | * This 64-bit Floating point NEON macro x1 has not been modified and is provided as is.
13 | */
14 |
15 | #ifndef MACROF_H
16 | #define MACROF_H
17 |
#include <arm_neon.h>
19 |
20 | // c <= addr x1
21 | #define vload(c, addr) c = vld1q_f64(addr);
22 | // c <= addr interleave 2
23 | #define vload2(c, addr) c = vld2q_f64(addr);
24 | // c <= addr interleave 4
25 | #define vload4(c, addr) c = vld4q_f64(addr);
26 |
27 | #define vstore(addr, c) vst1q_f64(addr, c);
28 | // addr <= c
29 | #define vstore2(addr, c) vst2q_f64(addr, c);
30 | // addr <= c
31 | #define vstore4(addr, c) vst4q_f64(addr, c);
32 |
33 | // c <= addr x2
34 | #define vloadx2(c, addr) c = vld1q_f64_x2(addr);
35 | // c <= addr x3
36 | #define vloadx3(c, addr) c = vld1q_f64_x3(addr);
37 |
38 | // addr <= c
39 | #define vstorex2(addr, c) vst1q_f64_x2(addr, c);
40 |
41 | // c = a - b
42 | #define vfsub(c, a, b) c = vsubq_f64(a, b);
43 |
44 | // c = a + b
45 | #define vfadd(c, a, b) c = vaddq_f64(a, b);
46 |
47 | // c = a * b
48 | #define vfmul(c, a, b) c = vmulq_f64(a, b);
49 |
50 | // c = a * n (n is constant)
51 | #define vfmuln(c, a, n) c = vmulq_n_f64(a, n);
52 |
53 | // Swap from a|b to b|a
54 | #define vswap(c, a) c = vextq_f64(a, a, 1);
55 |
56 | // c = a * b[i]
57 | #define vfmul_lane(c, a, b, i) c = vmulq_laneq_f64(a, b, i);
58 |
59 | // c = 1/a
60 | #define vfinv(c, a) c = vdivq_f64(vdupq_n_f64(1.0), a);
61 |
62 | // c = -a
63 | #define vfneg(c, a) c = vnegq_f64(a);
64 |
65 | #define transpose_f64(a, b, t, ia, ib, it) \
66 | t.val[it] = a.val[ia]; \
67 | a.val[ia] = vzip1q_f64(a.val[ia], b.val[ib]); \
68 | b.val[ib] = vzip2q_f64(t.val[it], b.val[ib]);
69 |
70 | /*
71 | * c = a + jb
72 | * c[0] = a[0] - b[1]
73 | * c[1] = a[1] + b[0]
74 | */
75 | #define vfcaddj(c, a, b) c = vcaddq_rot90_f64(a, b);
76 |
77 | /*
78 | * c = a - jb
79 | * c[0] = a[0] + b[1]
80 | * c[1] = a[1] - b[0]
81 | */
82 | #define vfcsubj(c, a, b) c = vcaddq_rot270_f64(a, b);
83 |
84 | // c[0] = c[0] + b[0]*a[0], c[1] = c[1] + b[1]*a[0]
85 | #define vfcmla(c, a, b) c = vcmlaq_f64(c, a, b);
86 |
87 | // c[0] = c[0] - b[1]*a[1], c[1] = c[1] + b[0]*a[1]
88 | #define vfcmla_90(c, a, b) c = vcmlaq_rot90_f64(c, a, b);
89 |
90 | // c[0] = c[0] - b[0]*a[0], c[1] = c[1] - b[1]*a[0]
91 | #define vfcmla_180(c, a, b) c = vcmlaq_rot180_f64(c, a, b);
92 |
93 | // c[0] = c[0] + b[1]*a[1], c[1] = c[1] - b[0]*a[1]
94 | #define vfcmla_270(c, a, b) c = vcmlaq_rot270_f64(c, a, b);
95 |
96 | /*
97 | * Complex MUL: c = a*b
98 | * c[0] = a[0]*b[0] - a[1]*b[1]
99 | * c[1] = a[0]*b[1] + a[1]*b[0]
100 | */
101 | #define FPC_CMUL(c, a, b) \
102 | c = vmulq_laneq_f64(b, a, 0); \
103 | c = vcmlaq_rot90_f64(c, a, b);
104 |
105 | /*
106 | * Complex MUL: c = a * conjugate(b) = a * (b[0], -b[1])
107 | * c[0] = b[0]*a[0] + b[1]*a[1]
108 | * c[1] = + b[0]*a[1] - b[1]*a[0]
109 | */
110 | #define FPC_CMUL_CONJ(c, a, b) \
111 | c = vmulq_laneq_f64(a, b, 0); \
112 | c = vcmlaq_rot270_f64(c, b, a);
113 |
114 | #if FMA == 1
115 | // d = c + a *b
116 | #define vfmla(d, c, a, b) d = vfmaq_f64(c, a, b);
117 | // d = c - a * b
118 | #define vfmls(d, c, a, b) d = vfmsq_f64(c, a, b);
119 | // d = c + a * b[i]
120 | #define vfmla_lane(d, c, a, b, i) d = vfmaq_laneq_f64(c, a, b, i);
121 | // d = c - a * b[i]
122 | #define vfmls_lane(d, c, a, b, i) d = vfmsq_laneq_f64(c, a, b, i);
123 |
124 | #else
125 | // d = c + a *b
126 | #define vfmla(d, c, a, b) d = vaddq_f64(c, vmulq_f64(a, b));
127 | // d = c - a *b
128 | #define vfmls(d, c, a, b) d = vsubq_f64(c, vmulq_f64(a, b));
129 | // d = c + a * b[i]
130 | #define vfmla_lane(d, c, a, b, i) \
131 | d = vaddq_f64(c, vmulq_laneq_f64(a, b, i));
132 |
133 | #define vfmls_lane(d, c, a, b, i) \
134 | d = vsubq_f64(c, vmulq_laneq_f64(a, b, i));
135 |
136 | #endif
137 |
138 | #endif
139 |
--------------------------------------------------------------------------------
/spqlios/q120/q120_arithmetic_private.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_Q120_ARITHMETIC_DEF_H
2 | #define SPQLIOS_Q120_ARITHMETIC_DEF_H
3 |
#include <stdint.h>
5 |
// Precomputed constants for the q120 mat1col products; the baa/bbb/bbc suffix
// names the operand representations (see q120_common.h for a/b/c layouts).
// Field semantics are presumably per-prime folding constants used by the
// lazy-reduction accumulators — confirm against q120_arithmetic.c.
typedef struct _q120_mat1col_product_baa_precomp {
  uint64_t h;              // accumulation block length before a reduction step — TODO confirm
  uint64_t h_pow_red[4];   // one reduction constant per prime Q1..Q4
#ifndef NDEBUG
  double res_bit_size;     // debug-only: tracked bit-size bound of the accumulator
#endif
} q120_mat1col_product_baa_precomp;

typedef struct _q120_mat1col_product_bbb_precomp {
  uint64_t h;              // accumulation block length before a reduction step — TODO confirm
  // per-prime constants for the partial sums s1..s4 (l = low half, h = high half)
  uint64_t s1h_pow_red[4];
  uint64_t s2l_pow_red[4];
  uint64_t s2h_pow_red[4];
  uint64_t s3l_pow_red[4];
  uint64_t s3h_pow_red[4];
  uint64_t s4l_pow_red[4];
  uint64_t s4h_pow_red[4];
#ifndef NDEBUG
  double res_bit_size;     // debug-only: tracked bit-size bound of the accumulator
#endif
} q120_mat1col_product_bbb_precomp;

typedef struct _q120_mat1col_product_bbc_precomp {
  uint64_t h;              // accumulation block length before a reduction step — TODO confirm
  uint64_t s2l_pow_red[4];
  uint64_t s2h_pow_red[4];
#ifndef NDEBUG
  double res_bit_size;     // debug-only: tracked bit-size bound of the accumulator
#endif
} q120_mat1col_product_bbc_precomp;
36 |
37 | #endif // SPQLIOS_Q120_ARITHMETIC_DEF_H
38 |
--------------------------------------------------------------------------------
/spqlios/q120/q120_common.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_Q120_COMMON_H
2 | #define SPQLIOS_Q120_COMMON_H
3 |
#include <stdint.h>
5 |
6 | #if !defined(SPQLIOS_Q120_USE_29_BIT_PRIMES) && !defined(SPQLIOS_Q120_USE_30_BIT_PRIMES) && \
7 | !defined(SPQLIOS_Q120_USE_31_BIT_PRIMES)
8 | #define SPQLIOS_Q120_USE_30_BIT_PRIMES
9 | #endif
10 |
11 | /**
12 | * 29-bit primes and 2*2^16 roots of unity
13 | */
14 | #ifdef SPQLIOS_Q120_USE_29_BIT_PRIMES
15 | #define Q1 ((1u << 29) - 2 * (1u << 17) + 1)
16 | #define OMEGA1 78289835
17 | #define Q1_CRT_CST 301701286 // (Q2*Q3*Q4)^-1 mod Q1
18 |
19 | #define Q2 ((1u << 29) - 5 * (1u << 17) + 1)
20 | #define OMEGA2 178519192
21 | #define Q2_CRT_CST 536020447 // (Q1*Q3*Q4)^-1 mod Q2
22 |
23 | #define Q3 ((1u << 29) - 26 * (1u << 17) + 1)
24 | #define OMEGA3 483889678
25 | #define Q3_CRT_CST 86367873 // (Q1*Q2*Q4)^-1 mod Q3
26 |
27 | #define Q4 ((1u << 29) - 35 * (1u << 17) + 1)
28 | #define OMEGA4 239808033
29 | #define Q4_CRT_CST 147030781 // (Q1*Q2*Q3)^-1 mod Q4
30 | #endif
31 |
32 | /**
33 | * 30-bit primes and 2*2^16 roots of unity
34 | */
35 | #ifdef SPQLIOS_Q120_USE_30_BIT_PRIMES
36 | #define Q1 ((1u << 30) - 2 * (1u << 17) + 1)
37 | #define OMEGA1 1070907127
38 | #define Q1_CRT_CST 43599465 // (Q2*Q3*Q4)^-1 mod Q1
39 |
40 | #define Q2 ((1u << 30) - 17 * (1u << 17) + 1)
41 | #define OMEGA2 315046632
42 | #define Q2_CRT_CST 292938863 // (Q1*Q3*Q4)^-1 mod Q2
43 |
44 | #define Q3 ((1u << 30) - 23 * (1u << 17) + 1)
45 | #define OMEGA3 309185662
46 | #define Q3_CRT_CST 594011630 // (Q1*Q2*Q4)^-1 mod Q3
47 |
48 | #define Q4 ((1u << 30) - 42 * (1u << 17) + 1)
49 | #define OMEGA4 846468380
50 | #define Q4_CRT_CST 140177212 // (Q1*Q2*Q3)^-1 mod Q4
51 | #endif
52 |
53 | /**
54 | * 31-bit primes and 2*2^16 roots of unity
55 | */
56 | #ifdef SPQLIOS_Q120_USE_31_BIT_PRIMES
57 | #define Q1 ((1u << 31) - 1 * (1u << 17) + 1)
58 | #define OMEGA1 1615402923
59 | #define Q1_CRT_CST 1811422063 // (Q2*Q3*Q4)^-1 mod Q1
60 |
61 | #define Q2 ((1u << 31) - 4 * (1u << 17) + 1)
62 | #define OMEGA2 1137738560
63 | #define Q2_CRT_CST 2093150204 // (Q1*Q3*Q4)^-1 mod Q2
64 |
65 | #define Q3 ((1u << 31) - 11 * (1u << 17) + 1)
66 | #define OMEGA3 154880552
67 | #define Q3_CRT_CST 164149010 // (Q1*Q2*Q4)^-1 mod Q3
68 |
69 | #define Q4 ((1u << 31) - 23 * (1u << 17) + 1)
70 | #define OMEGA4 558784885
71 | #define Q4_CRT_CST 225197446 // (Q1*Q2*Q3)^-1 mod Q4
72 | #endif
73 |
// the 4-prime RNS basis (q120 = Q1*Q2*Q3*Q4) and the matching roots of unity.
// NOTE(review): `static const` in a header gives every translation unit its own
// 16-byte copy of each table — harmless, but `extern` + one definition would
// deduplicate if it ever matters.
static const uint32_t PRIMES_VEC[4] = {Q1, Q2, Q3, Q4};
static const uint32_t OMEGAS_VEC[4] = {OMEGA1, OMEGA2, OMEGA3, OMEGA4};

#define MAX_ELL 10000  // presumably an upper bound on the supported ell — TODO confirm
78 |
79 | // each number x mod Q120 is represented by uint64_t[4] with (non-unique) values (x mod q1, x mod q2,x mod q3,x mod q4),
80 | // each between [0 and 2^32-1]
81 | typedef struct _q120a q120a;
82 |
83 | // each number x mod Q120 is represented by uint64_t[4] with (non-unique) values (x mod q1, x mod q2,x mod q3,x mod q4),
84 | // each between [0 and 2^64-1]
85 | typedef struct _q120b q120b;
86 |
87 | // each number x mod Q120 is represented by uint32_t[8] with values (x mod q1, 2^32x mod q1, x mod q2, 2^32.x mod q2, x
88 | // mod q3, 2^32.x mod q3, x mod q4, 2^32.x mod q4) each between [0 and 2^32-1]
89 | typedef struct _q120c q120c;
90 |
91 | typedef struct _q120x2b q120x2b;
92 | typedef struct _q120x2c q120x2c;
93 |
94 | #endif // SPQLIOS_Q120_COMMON_H
95 |
--------------------------------------------------------------------------------
/spqlios/q120/q120_fallbacks_aarch64.c:
--------------------------------------------------------------------------------
1 | #include "q120_ntt_private.h"
2 |
// aarch64 fallback stubs: the AVX2 NTT kernels are x86-only. These symbols
// exist only so the library links on aarch64; calling them reports an
// undefined-function error via UNDEFINED().
EXPORT void q120_ntt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data) { UNDEFINED(); }

EXPORT void q120_intt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data) { UNDEFINED(); }
6 |
--------------------------------------------------------------------------------
/spqlios/q120/q120_ntt.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_Q120_NTT_H
2 | #define SPQLIOS_Q120_NTT_H
3 |
4 | #include "../commons.h"
5 | #include "q120_common.h"
6 |
// opaque precomputation table for the q120 NTT (layout in q120_ntt_private.h)
typedef struct _q120_ntt_precomp q120_ntt_precomp;

// Allocates the precomputed tables for a direct NTT of size n
// ("bb": operates on the q120b representation — see q120_common.h).
EXPORT q120_ntt_precomp* q120_new_ntt_bb_precomp(const uint64_t n);
// Releases a table created by q120_new_ntt_bb_precomp.
EXPORT void q120_del_ntt_bb_precomp(q120_ntt_precomp* precomp);

// Same pair for the inverse transform.
EXPORT q120_ntt_precomp* q120_new_intt_bb_precomp(const uint64_t n);
EXPORT void q120_del_intt_bb_precomp(q120_ntt_precomp* precomp);

/**
 * @brief computes a direct ntt in-place over data.
 */
EXPORT void q120_ntt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data);

/**
 * @brief computes an inverse ntt in-place over data.
 */
EXPORT void q120_intt_bb_avx2(const q120_ntt_precomp* const precomp, q120b* const data);
24 |
25 | #endif // SPQLIOS_Q120_NTT_H
26 |
--------------------------------------------------------------------------------
/spqlios/q120/q120_ntt_private.h:
--------------------------------------------------------------------------------
1 | #include "q120_ntt.h"
2 |
3 | #ifndef NDEBUG
4 | #define CHECK_BOUNDS 1
5 | #define VERBOSE
6 | #else
7 | #define CHECK_BOUNDS 0
8 | #endif
9 |
10 | #ifndef VERBOSE
11 | #define LOG(...) ;
12 | #else
13 | #define LOG(...) printf(__VA_ARGS__);
14 | #endif
15 |
16 | typedef struct _q120_ntt_step_precomp {
17 | uint64_t q2bs[4]; // q2bs = 2^{bs-31}.q[k]
18 | uint64_t bs; // inputs at this iterations must be in Q_n
19 | uint64_t half_bs; // == ceil(bs/2)
20 | uint64_t mask; // (1<
2 |
3 | #include "reim_fft_private.h"
4 |
// Converts the 2m int64 coefficients in x to doubles in r (same layout),
// 4 lanes at a time, with the classic exponent trick (AVX2 has no packed
// int64->double conversion): for |x| < 2^51 ("bnd50" suggests inputs bounded
// by 2^50 — TODO confirm), x + 2^51 is a non-negative integer < 2^52, so
// OR-ing in the exponent bits of the double 2^52 produces the exact double
// 2^52 + (x + 2^51); subtracting 3*2^51 then recovers x exactly.
// NOTE(review): the do/while assumes 2m >= 4 and 2m % 4 == 0 — confirm.
void reim_from_znx64_bnd50_fma(const REIM_FROM_ZNX64_PRECOMP* precomp, void* r, const int64_t* x) {
  static const double EXPO = INT64_C(1) << 52;  // double 2^52: supplies the exponent bits
  const int64_t ADD_CST = INT64_C(1) << 51;     // bias making the integer non-negative
  const double SUB_CST = INT64_C(3) << 51;      // 2^52 + 2^51: removes exponent value + bias

  const __m256d SUB_CST_4 = _mm256_set1_pd(SUB_CST);
  const __m256i ADD_CST_4 = _mm256_set1_epi64x(ADD_CST);
  const __m256d EXPO_4 = _mm256_set1_pd(EXPO);

  double(*out)[4] = (double(*)[4])r;
  __m256i* in = (__m256i*)x;
  __m256i* inend = (__m256i*)(x + (precomp->m << 1));
  do {
    // read the next value
    __m256i a = _mm256_loadu_si256(in);
    a = _mm256_add_epi64(a, ADD_CST_4);  // x + 2^51, fits in the 52 mantissa bits
    __m256d ad = _mm256_castsi256_pd(a);
    ad = _mm256_or_pd(ad, EXPO_4);       // splice the integer into the mantissa of 2^52
    ad = _mm256_sub_pd(ad, SUB_CST_4);   // exact subtraction yields x as a double
    // store the next value
    _mm256_storeu_pd(out[0], ad);
    ++out;
    ++in;
  } while (in < inend);
}
30 |
// version where the output norm can be as big as 2^63
// Converts 2m doubles to int64 as round(x / divisor) (ties pushed away from
// zero by the +divisor/2 magnitude offset), reconstructing the integer from
// the sign/exponent/mantissa fields directly so that magnitudes close to
// 2^63 (beyond the safe cvttpd range) stay exact.
// NOTE(review): the exponent arithmetic subtracts the full bit pattern of
// divisor*2^52, i.e. divisor is assumed to be a power of two (zero mantissa)
// — confirm. Also assumes 2m >= 4 and 2m % 4 == 0 (do/while, step 4).
void reim_to_znx64_avx2_bnd63_fma(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* x) {
  static const uint64_t SIGN_MASK = 0x8000000000000000UL;
  static const uint64_t EXPO_MASK = 0x7FF0000000000000UL;
  static const uint64_t MANTISSA_MASK = 0x000FFFFFFFFFFFFFUL;
  static const uint64_t MANTISSA_MSB = 0x0010000000000000UL;
  // bit pattern reference: exponent of divisor * 2^52 (mantissa assumed zero)
  const double divisor_bits = precomp->divisor * ((double)(INT64_C(1) << 52));
  const double offset = precomp->divisor / 2.;  // rounding offset (half a quantum)

  const __m256d SIGN_MASK_4 = _mm256_castsi256_pd(_mm256_set1_epi64x(SIGN_MASK));
  const __m256i EXPO_MASK_4 = _mm256_set1_epi64x(EXPO_MASK);
  const __m256i MANTISSA_MASK_4 = _mm256_set1_epi64x(MANTISSA_MASK);
  const __m256i MANTISSA_MSB_4 = _mm256_set1_epi64x(MANTISSA_MSB);
  const __m256d offset_4 = _mm256_set1_pd(offset);
  const __m256i divi_bits_4 = _mm256_castpd_si256(_mm256_set1_pd(divisor_bits));

  double(*in)[4] = (double(*)[4])x;
  __m256i* out = (__m256i*)r;
  __m256i* outend = (__m256i*)(r + (precomp->m << 1));
  do {
    // read the next value
    __m256d a = _mm256_loadu_pd(in[0]);
    // a += sign(a) * m/2
    __m256d asign = _mm256_and_pd(a, SIGN_MASK_4);
    a = _mm256_add_pd(a, _mm256_or_pd(asign, offset_4));
    // sign: either 0 or -1
    __m256i sign_mask = _mm256_castpd_si256(asign);
    sign_mask = _mm256_sub_epi64(_mm256_set1_epi64x(0), _mm256_srli_epi64(sign_mask, 63));
    // compute the exponents: distance between a's exponent and divisor*2^52's;
    // exactly one of lsh/rsh is < 64 (variable shifts by >= 64 yield 0)
    __m256i a0exp = _mm256_and_si256(_mm256_castpd_si256(a), EXPO_MASK_4);
    __m256i a0lsh = _mm256_sub_epi64(a0exp, divi_bits_4);
    __m256i a0rsh = _mm256_sub_epi64(divi_bits_4, a0exp);
    a0lsh = _mm256_srli_epi64(a0lsh, 52);
    a0rsh = _mm256_srli_epi64(a0rsh, 52);
    // compute the new mantissa (restore the implicit leading 1)
    __m256i a0pos = _mm256_and_si256(_mm256_castpd_si256(a), MANTISSA_MASK_4);
    a0pos = _mm256_or_si256(a0pos, MANTISSA_MSB_4);
    a0lsh = _mm256_sllv_epi64(a0pos, a0lsh);
    a0rsh = _mm256_srlv_epi64(a0pos, a0rsh);
    __m256i final = _mm256_or_si256(a0lsh, a0rsh);  // trunc(|a| / divisor)
    // negate if the sign was negative (conditional two's complement)
    final = _mm256_xor_si256(final, sign_mask);
    final = _mm256_sub_epi64(final, sign_mask);
    // store the next value
    _mm256_storeu_si256(out, final);
    ++out;
    ++in;
  } while (out < outend);
}
80 |
// version where the output norm can be as big as 2^50
// Faster path than the bnd63 variant: adding divisor * 3*2^51 pins the
// exponent so that (for |x/divisor| < 2^51 and divisor a power of two —
// TODO confirm both preconditions) the low 52 mantissa bits of the sum hold
// exactly x/divisor + 2^51; masking them and subtracting 2^51 yields the
// signed rounded quotient. Assumes 2m >= 4 and 2m % 4 == 0.
void reim_to_znx64_avx2_bnd50_fma(const REIM_TO_ZNX64_PRECOMP* precomp, int64_t* r, const void* x) {
  static const uint64_t MANTISSA_MASK = 0x000FFFFFFFFFFFFFUL;
  const int64_t SUB_CST = INT64_C(1) << 51;                            // removes the 2^51 bias
  const double add_cst = precomp->divisor * ((double)(INT64_C(3) << 51));  // exponent-pinning constant

  const __m256i SUB_CST_4 = _mm256_set1_epi64x(SUB_CST);
  const __m256d add_cst_4 = _mm256_set1_pd(add_cst);
  const __m256i MANTISSA_MASK_4 = _mm256_set1_epi64x(MANTISSA_MASK);

  double(*in)[4] = (double(*)[4])x;
  __m256i* out = (__m256i*)r;
  __m256i* outend = (__m256i*)(r + (precomp->m << 1));
  do {
    // read the next value
    __m256d a = _mm256_loadu_pd(in[0]);
    a = _mm256_add_pd(a, add_cst_4);        // rounds and pins the exponent
    __m256i ai = _mm256_castpd_si256(a);
    ai = _mm256_and_si256(ai, MANTISSA_MASK_4);  // keep x/divisor + 2^51
    ai = _mm256_sub_epi64(ai, SUB_CST_4);        // remove the bias
    // store the next value
    _mm256_storeu_si256(out, ai);
    ++out;
    ++in;
  } while (out < outend);
}
107 |
--------------------------------------------------------------------------------
/spqlios/reim/reim_execute.c:
--------------------------------------------------------------------------------
1 | #include "reim_fft_internal.h"
2 | #include "reim_fft_private.h"
3 |
// Thin dispatch wrappers: each precomp table stores a function pointer chosen
// at creation time (reference vs vectorized kernel); the public entry points
// simply forward their arguments through it.
EXPORT void reim_from_znx32(const REIM_FROM_ZNX32_PRECOMP* tables, void* r, const int32_t* a) {
  tables->function(tables, r, a);
}

EXPORT void reim_from_tnx32(const REIM_FROM_TNX32_PRECOMP* tables, void* r, const int32_t* a) {
  tables->function(tables, r, a);
}

EXPORT void reim_to_tnx32(const REIM_TO_TNX32_PRECOMP* tables, int32_t* r, const void* a) {
  tables->function(tables, r, a);
}

EXPORT void reim_fftvec_mul(const REIM_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b) {
  tables->function(tables, r, a, b);
}

EXPORT void reim_fftvec_addmul(const REIM_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, const double* b) {
  tables->function(tables, r, a, b);
}
23 |
--------------------------------------------------------------------------------
/spqlios/reim/reim_fallbacks_aarch64.c:
--------------------------------------------------------------------------------
1 | #include "reim_fft_private.h"
2 |
// aarch64 fallback stubs for the x86 FMA/AVX2 kernels: the dispatch tables
// never select them on this platform; they only satisfy the linker and
// report an undefined-function error via UNDEFINED() if reached.
EXPORT void reim_fftvec_addmul_fma(const REIM_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a,
                                   const double* b) {
  UNDEFINED();
}
EXPORT void reim_fftvec_mul_fma(const REIM_FFTVEC_MUL_PRECOMP* precomp, double* r, const double* a, const double* b) {
  UNDEFINED();
}

EXPORT void reim_fft_avx2_fma(const REIM_FFT_PRECOMP* tables, double* data) { UNDEFINED(); }
EXPORT void reim_ifft_avx2_fma(const REIM_IFFT_PRECOMP* tables, double* data) { UNDEFINED(); }
13 |
14 | //EXPORT void reim_fft(const REIM_FFT_PRECOMP* tables, double* data) { tables->function(tables, data); }
15 | //EXPORT void reim_ifft(const REIM_IFFT_PRECOMP* tables, double* data) { tables->function(tables, data); }
16 |
--------------------------------------------------------------------------------
/spqlios/reim/reim_fft4_avx_fma.c:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | #include "reim_fft_private.h"
6 |
// One radix-2 FFT butterfly on two complex pairs at once (SSE lanes):
//   t = om * b;  b' = a - t;  a' = a + t
// (ra,ia)/(rb,ib) hold the real/imag parts of a and b per lane; omre/omim
// hold the per-lane real/imag parts of the twiddle om. The fmsub/fmadd
// ordering is deliberate: it fixes the rounding of the complex product.
__always_inline void reim_ctwiddle_avx_fma(__m128d* ra, __m128d* rb, __m128d* ia, __m128d* ib, const __m128d omre,
                                           const __m128d omim) {
  // rb * omre - ib * omim;   (real part of om * b)
  __m128d rprod0 = _mm_mul_pd(*ib, omim);
  rprod0 = _mm_fmsub_pd(*rb, omre, rprod0);

  // rb * omim + ib * omre;   (imaginary part of om * b)
  __m128d iprod0 = _mm_mul_pd(*rb, omim);
  iprod0 = _mm_fmadd_pd(*ib, omre, iprod0);

  *rb = _mm_sub_pd(*ra, rprod0);
  *ib = _mm_sub_pd(*ia, iprod0);
  *ra = _mm_add_pd(*ra, rprod0);
  *ia = _mm_add_pd(*ia, iprod0);
}
22 |
// In-place size-4 FFT over split-format data: dre[0..3] = real parts,
// dim[0..3] = imaginary parts. ompv points to packed twiddles:
//   omp[0..1] = (re,im) of the stage-1 twiddle,
//   omp[2..3] = (re,im) of the stage-2 base twiddle om.
EXPORT void reim_fft4_avx_fma(double* dre, double* dim, const void* ompv) {
  const double* omp = (const double*)ompv;

  __m128d ra0 = _mm_loadu_pd(dre);
  __m128d ra2 = _mm_loadu_pd(dre + 2);
  __m128d ia0 = _mm_loadu_pd(dim);
  __m128d ia2 = _mm_loadu_pd(dim + 2);

  // 1: butterflies (0,2) and (1,3), both with the same twiddle omp[0..1]
  {
    // duplicate omegas in precomp?
    __m128d om = _mm_loadu_pd(omp);
    __m128d omre = _mm_permute_pd(om, 0);
    __m128d omim = _mm_permute_pd(om, 3);

    reim_ctwiddle_avx_fma(&ra0, &ra2, &ia0, &ia2, omre, omim);
  }

  // 2: butterflies (0,1) with om and (2,3) with i*om; the (r,-i)/(i,r)
  //    shuffle+sign trick builds both lane twiddles from one load
  {
    const __m128d fft4neg = _mm_castsi128_pd(_mm_set_epi64x(UINT64_C(1) << 63, 0));
    __m128d om = _mm_loadu_pd(omp + 2);      // om: r,i
    __m128d omim = _mm_permute_pd(om, 1);    // omim: i,r
    __m128d omre = _mm_xor_pd(om, fft4neg);  // omre: r,-i

    __m128d rb = _mm_unpackhi_pd(ra0, ra2);  // (r0, r1), (r2, r3) -> (r1, r3)
    __m128d ib = _mm_unpackhi_pd(ia0, ia2);  // (i0, i1), (i2, i3) -> (i1, i3)
    __m128d ra = _mm_unpacklo_pd(ra0, ra2);  // (r0, r1), (r2, r3) -> (r0, r2)
    __m128d ia = _mm_unpacklo_pd(ia0, ia2);  // (i0, i1), (i2, i3) -> (i0, i2)

    reim_ctwiddle_avx_fma(&ra, &rb, &ia, &ib, omre, omim);

    // re-interleave back to natural order
    ra0 = _mm_unpacklo_pd(ra, rb);
    ia0 = _mm_unpacklo_pd(ia, ib);
    ra2 = _mm_unpackhi_pd(ra, rb);
    ia2 = _mm_unpackhi_pd(ia, ib);
  }

  // 4: write back
  _mm_storeu_pd(dre, ra0);
  _mm_storeu_pd(dre + 2, ra2);
  _mm_storeu_pd(dim, ia0);
  _mm_storeu_pd(dim + 2, ia2);
}
67 |
--------------------------------------------------------------------------------
/spqlios/reim/reim_fft8_avx_fma.c:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | #include "reim_fft_private.h"
6 |
// One radix-2 FFT butterfly on four complex pairs at once (AVX lanes):
//   t = om * b;  b' = a - t;  a' = a + t
// (ra,ia)/(rb,ib) hold the real/imag parts of a and b per lane; omre/omim
// hold the per-lane real/imag parts of the twiddle om. The fmsub/fmadd
// ordering is deliberate: it fixes the rounding of the complex product.
__always_inline void reim_ctwiddle_avx_fma(__m256d* ra, __m256d* rb, __m256d* ia, __m256d* ib, const __m256d omre,
                                           const __m256d omim) {
  // rb * omre - ib * omim;   (real part of om * b)
  __m256d rprod0 = _mm256_mul_pd(*ib, omim);
  rprod0 = _mm256_fmsub_pd(*rb, omre, rprod0);

  // rb * omim + ib * omre;   (imaginary part of om * b)
  __m256d iprod0 = _mm256_mul_pd(*rb, omim);
  iprod0 = _mm256_fmadd_pd(*ib, omre, iprod0);

  *rb = _mm256_sub_pd(*ra, rprod0);
  *ib = _mm256_sub_pd(*ia, iprod0);
  *ra = _mm256_add_pd(*ra, rprod0);
  *ia = _mm256_add_pd(*ia, iprod0);
}
22 |
// In-place size-8 FFT over split-format data: dre[0..7] = real parts,
// dim[0..7] = imaginary parts. ompv points to packed twiddles:
//   omp[0..1] = stage-1 twiddle (re,im),
//   omp[2..3] = stage-2 base twiddle om (re,im),
//   omp[4..7] = stage-3 twiddles packed (r0, r1, i0, i1).
EXPORT void reim_fft8_avx_fma(double* dre, double* dim, const void* ompv) {
  const double* omp = (const double*)ompv;

  __m256d ra0 = _mm256_loadu_pd(dre);
  __m256d ra4 = _mm256_loadu_pd(dre + 4);
  __m256d ia0 = _mm256_loadu_pd(dim);
  __m256d ia4 = _mm256_loadu_pd(dim + 4);

  // 1: butterflies (k, k+4), all with the same twiddle omp[0..1]
  {
    // duplicate omegas in precomp?
    __m128d omri = _mm_loadu_pd(omp);
    __m256d omriri = _mm256_set_m128d(omri, omri);
    __m256d omi = _mm256_permute_pd(omriri, 15);
    __m256d omr = _mm256_permute_pd(omriri, 0);

    reim_ctwiddle_avx_fma(&ra0, &ra4, &ia0, &ia4, omr, omi);
  }

  // 2: butterflies (0,2),(1,3) with om and (4,6),(5,7) with i*om;
  //    omr=(r,r,-i,-i) / omi=(i,i,r,r) encode both lane-pair twiddles
  {
    const __m128d fft8neg = _mm_castsi128_pd(_mm_set_epi64x(UINT64_C(1) << 63, 0));
    __m128d omri = _mm_loadu_pd(omp + 2);            // r,i
    __m128d omrmi = _mm_xor_pd(omri, fft8neg);       // r,-i
    __m256d omrirmi = _mm256_set_m128d(omrmi, omri); // r,i,r,-i
    __m256d omi = _mm256_permute_pd(omrirmi, 3);     // i,i,r,r
    __m256d omr = _mm256_permute_pd(omrirmi, 12);    // r,r,-i,-i

    __m256d rb = _mm256_permute2f128_pd(ra0, ra4, 0x31);
    __m256d ib = _mm256_permute2f128_pd(ia0, ia4, 0x31);
    __m256d ra = _mm256_permute2f128_pd(ra0, ra4, 0x20);
    __m256d ia = _mm256_permute2f128_pd(ia0, ia4, 0x20);

    reim_ctwiddle_avx_fma(&ra, &rb, &ia, &ib, omr, omi);

    ra0 = _mm256_permute2f128_pd(ra, rb, 0x20);
    ra4 = _mm256_permute2f128_pd(ra, rb, 0x31);
    ia0 = _mm256_permute2f128_pd(ia, ib, 0x20);
    ia4 = _mm256_permute2f128_pd(ia, ib, 0x31);
  }

  // 3: butterflies (2k, 2k+1) with twiddles om0, om1, i*om0, i*om1
  {
    const __m256d fft8neg2 = _mm256_castsi256_pd(_mm256_set_epi64x(UINT64_C(1) << 63, UINT64_C(1) << 63, 0, 0));
    __m256d om = _mm256_loadu_pd(omp + 4);           // r0,r1,i0,i1
    __m256d omi = _mm256_permute2f128_pd(om, om, 1); // i0,i1,r0,r1
    __m256d omr = _mm256_xor_pd(om, fft8neg2);       // r0,r1,-i0,-i1

    __m256d rb = _mm256_unpackhi_pd(ra0, ra4);
    __m256d ib = _mm256_unpackhi_pd(ia0, ia4);
    __m256d ra = _mm256_unpacklo_pd(ra0, ra4);
    __m256d ia = _mm256_unpacklo_pd(ia0, ia4);

    reim_ctwiddle_avx_fma(&ra, &rb, &ia, &ib, omr, omi);

    ra4 = _mm256_unpackhi_pd(ra, rb);
    ia4 = _mm256_unpackhi_pd(ia, ib);
    ra0 = _mm256_unpacklo_pd(ra, rb);
    ia0 = _mm256_unpacklo_pd(ia, ib);
  }

  // 4: write back
  _mm256_storeu_pd(dre, ra0);
  _mm256_storeu_pd(dre + 4, ra4);
  _mm256_storeu_pd(dim, ia0);
  _mm256_storeu_pd(dim + 4, ia4);
}
90 |
--------------------------------------------------------------------------------
/spqlios/reim/reim_fft_ifft.c:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 |
6 | #include "../commons_private.h"
7 | #include "reim_fft_internal.h"
8 | #include "reim_fft_private.h"
9 |
double accurate_cos(int32_t i, int32_t n) {  // cos(2pi*i/n)
  // Fold the index into [0, n) so negative / oversized indices wrap around
  // periodically, then use quadrant symmetries so that cos() is always
  // evaluated on a small argument measured from the nearest axis, which
  // keeps the result accurate near the zeros of the function.
  i = ((i % n) + n) % n;
  if (i >= 3 * n / 4) {
    // fourth quadrant: cos(2pi*i/n) = cos(2pi*(n-i)/n)
    return cos(2. * M_PI * (n - i) / (double)(n));
  }
  if (i >= 2 * n / 4) {
    // third quadrant: cos(2pi*i/n) = -cos(2pi*(i-n/2)/n)
    return -cos(2. * M_PI * (i - n / 2) / (double)(n));
  }
  if (i >= 1 * n / 4) {
    // second quadrant: cos(2pi*i/n) = -cos(2pi*(n/2-i)/n)
    return -cos(2. * M_PI * (n / 2 - i) / (double)(n));
  }
  // first quadrant: direct evaluation
  return cos(2. * M_PI * (i) / (double)(n));
}
17 |
double accurate_sin(int32_t i, int32_t n) {  // sin(2pi*i/n)
  // Same scheme as accurate_cos: reduce the index to [0, n) and exploit the
  // quadrant symmetries of sine so the library call always receives a small
  // argument measured from the nearest axis.
  i = ((i % n) + n) % n;
  if (i >= 3 * n / 4) {
    // fourth quadrant: sin(2pi*i/n) = -sin(2pi*(n-i)/n)
    return -sin(2. * M_PI * (n - i) / (double)(n));
  }
  if (i >= 2 * n / 4) {
    // third quadrant: sin(2pi*i/n) = -sin(2pi*(i-n/2)/n)
    return -sin(2. * M_PI * (i - n / 2) / (double)(n));
  }
  if (i >= 1 * n / 4) {
    // second quadrant: sin(2pi*i/n) = sin(2pi*(n/2-i)/n)
    return sin(2. * M_PI * (n / 2 - i) / (double)(n));
  }
  // first quadrant: direct evaluation
  return sin(2. * M_PI * (i) / (double)(n));
}
25 |
26 |
// Returns a pointer to the buffer_index-th scratch buffer of the table:
// buffers are laid out contiguously from aligned_buffers, buf_size bytes each.
// No bounds check is performed on buffer_index.
EXPORT double* reim_ifft_precomp_get_buffer(const REIM_IFFT_PRECOMP* tables, uint32_t buffer_index) {
  return (double*)((uint8_t*) tables->aligned_buffers + buffer_index * tables->buf_size);
}
30 |
// Returns a pointer to the buffer_index-th scratch buffer of the table:
// buffers are laid out contiguously from aligned_buffers, buf_size bytes each.
// No bounds check is performed on buffer_index.
EXPORT double* reim_fft_precomp_get_buffer(const REIM_FFT_PRECOMP* tables, uint32_t buffer_index) {
  return (double*)((uint8_t*) tables->aligned_buffers + buffer_index * tables->buf_size);
}
34 |
35 |
// Public entry points: forward to the implementation selected when the
// precomputation tables were created.
EXPORT void reim_fft(const REIM_FFT_PRECOMP* tables, double* data) { tables->function(tables, data); }
EXPORT void reim_ifft(const REIM_IFFT_PRECOMP* tables, double* data) { tables->function(tables, data); }
38 |
--------------------------------------------------------------------------------
/spqlios/reim/reim_fft_private.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_REIM_FFT_PRIVATE_H
2 | #define SPQLIOS_REIM_FFT_PRIVATE_H
3 |
4 | #include "../commons_private.h"
5 | #include "reim_fft.h"
6 |
// Expression-form compile-time assertion: the array size becomes negative
// (a compile error) whenever the condition is false.
#define STATIC_ASSERT(condition) (void)sizeof(char[-1 + 2 * !!(condition)])

typedef struct reim_twiddle_precomp REIM_FFTVEC_TWIDDLE_PRECOMP;
typedef struct reim_bitwiddle_precomp REIM_FFTVEC_BITWIDDLE_PRECOMP;

// Signatures of the per-implementation entry points stored inside each
// precomp table; public functions dispatch through these pointers.
typedef void (*FFT_FUNC)(const REIM_FFT_PRECOMP*, double*);
typedef void (*IFFT_FUNC)(const REIM_IFFT_PRECOMP*, double*);
typedef void (*FFTVEC_MUL_FUNC)(const REIM_FFTVEC_MUL_PRECOMP*, double*, const double*, const double*);
typedef void (*FFTVEC_ADDMUL_FUNC)(const REIM_FFTVEC_ADDMUL_PRECOMP*, double*, const double*, const double*);

typedef void (*FROM_ZNX32_FUNC)(const REIM_FROM_ZNX32_PRECOMP*, void*, const int32_t*);
typedef void (*FROM_ZNX64_FUNC)(const REIM_FROM_ZNX64_PRECOMP*, void*, const int64_t*);
typedef void (*FROM_TNX32_FUNC)(const REIM_FROM_TNX32_PRECOMP*, void*, const int32_t*);
typedef void (*TO_TNX32_FUNC)(const REIM_TO_TNX32_PRECOMP*, int32_t*, const void*);
typedef void (*TO_TNX_FUNC)(const REIM_TO_TNX_PRECOMP*, double*, const double*);
typedef void (*TO_ZNX64_FUNC)(const REIM_TO_ZNX64_PRECOMP*, int64_t*, const void*);
typedef void (*FFTVEC_TWIDDLE_FUNC)(const REIM_FFTVEC_TWIDDLE_PRECOMP*, void*, const void*, const void*);
typedef void (*FFTVEC_BITWIDDLE_FUNC)(const REIM_FFTVEC_BITWIDDLE_PRECOMP*, void*, uint64_t, const void*);
25 |
typedef struct reim_fft_precomp {
  FFT_FUNC function;      ///< implementation selected at creation time
  int64_t m;              ///< complex dimension warning: reim uses n=2N=4m
  uint64_t buf_size;      ///< size of aligned_buffers (multiple of 64B)
  double* powomegas;      ///< 64B aligned (precomputed twiddle factors)
  void* aligned_buffers;  ///< 64B aligned (contiguous scratch buffers, buf_size each)
} REIM_FFT_PRECOMP;

typedef struct reim_ifft_precomp {
  IFFT_FUNC function;     ///< implementation selected at creation time
  int64_t m;              // warning: reim uses n=2N=4m
  uint64_t buf_size;      ///< size of aligned_buffers (multiple of 64B)
  double* powomegas;      ///< precomputed twiddle factors (64B aligned)
  void* aligned_buffers;  ///< scratch buffers (64B aligned)
} REIM_IFFT_PRECOMP;
41 |
typedef struct reim_mul_precomp {
  FFTVEC_MUL_FUNC function;  ///< selected implementation (ref or fma)
  int64_t m;                 ///< complex dimension of the vectors
} REIM_FFTVEC_MUL_PRECOMP;

typedef struct reim_addmul_precomp {
  FFTVEC_ADDMUL_FUNC function;  ///< selected implementation (ref or fma)
  int64_t m;                    ///< complex dimension of the vectors
} REIM_FFTVEC_ADDMUL_PRECOMP;

// Conversion tables: each carries the selected implementation plus the
// complex dimension m; the to_* variants add rescaling constants.

struct reim_from_znx32_precomp {
  FROM_ZNX32_FUNC function;
  int64_t m;
};

struct reim_from_znx64_precomp {
  FROM_ZNX64_FUNC function;
  int64_t m;
};

struct reim_from_tnx32_precomp {
  FROM_TNX32_FUNC function;
  int64_t m;
};

struct reim_to_tnx32_precomp {
  TO_TNX32_FUNC function;
  int64_t m;
  double divisor;  ///< rescaling divisor applied during conversion
};

struct reim_to_tnx_precomp {
  TO_TNX_FUNC function;
  int64_t m;
  double divisor;        ///< rescaling divisor applied during conversion
  uint32_t log2overhead; ///< log2 bound on the input overhead (see init_reim_to_tnx_precomp)
  // exponent-trick constants precomputed by init_reim_to_tnx_precomp and
  // consumed by reim_to_tnx_ref/avx:
  double add_cst;
  uint64_t mask_and;
  uint64_t mask_or;
  double sub_cst;
};

struct reim_to_znx64_precomp {
  TO_ZNX64_FUNC function;
  int64_t m;
  double divisor;  ///< rescaling divisor (see reim_to_znx64_* kernels)
};

struct reim_twiddle_precomp {
  FFTVEC_TWIDDLE_FUNC function;
  int64_t m;
};

struct reim_bitwiddle_precomp {
  FFTVEC_BITWIDDLE_FUNC function;
  int64_t m;
};
100 |
101 | #endif // SPQLIOS_REIM_FFT_PRIVATE_H
102 |
--------------------------------------------------------------------------------
/spqlios/reim/reim_fftvec_addmul_fma.c:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | #include "reim_fft_private.h"
6 |
// r += a * b over m complex coefficients in split layout: the first m doubles
// of each array hold real parts, the next m hold imaginary parts. Processes
// 4 coefficients per iteration.
// NOTE(review): the loop exits on pointer equality with step 4, so m must be
// a multiple of 4 (the precomp constructor only selects this kernel when
// m >= 4) — confirm callers guarantee m % 4 == 0.
EXPORT void reim_fftvec_addmul_fma(const REIM_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a,
                                   const double* b) {
  const uint64_t m = precomp->m;
  double* rr_ptr = r;
  double* ri_ptr = r + m;
  const double* ar_ptr = a;
  const double* ai_ptr = a + m;
  const double* br_ptr = b;
  const double* bi_ptr = b + m;

  const double* const rend_ptr = ri_ptr;
  while (rr_ptr != rend_ptr) {
    __m256d rr = _mm256_loadu_pd(rr_ptr);
    __m256d ri = _mm256_loadu_pd(ri_ptr);
    const __m256d ar = _mm256_loadu_pd(ar_ptr);
    const __m256d ai = _mm256_loadu_pd(ai_ptr);
    const __m256d br = _mm256_loadu_pd(br_ptr);
    const __m256d bi = _mm256_loadu_pd(bi_ptr);

    // double-fmsub accumulation: after both steps
    //   rr = ar*br - (ai*bi - rr) = rr + ar*br - ai*bi   (real part of r + a*b)
    rr = _mm256_fmsub_pd(ai, bi, rr);
    rr = _mm256_fmsub_pd(ar, br, rr);
    //   ri = ri + ar*bi + ai*br                          (imaginary part)
    ri = _mm256_fmadd_pd(ar, bi, ri);
    ri = _mm256_fmadd_pd(ai, br, ri);

    _mm256_storeu_pd(rr_ptr, rr);
    _mm256_storeu_pd(ri_ptr, ri);

    rr_ptr += 4;
    ri_ptr += 4;
    ar_ptr += 4;
    ai_ptr += 4;
    br_ptr += 4;
    bi_ptr += 4;
  }
}
42 |
// r = a * b over m complex coefficients in split layout (first m doubles:
// real parts; next m: imaginary parts), 4 coefficients per iteration.
// NOTE(review): same m % 4 == 0 assumption as reim_fftvec_addmul_fma
// (pointer-equality exit, step 4) — confirm.
EXPORT void reim_fftvec_mul_fma(const REIM_FFTVEC_MUL_PRECOMP* precomp, double* r, const double* a, const double* b) {
  const uint64_t m = precomp->m;
  double* rr_ptr = r;
  double* ri_ptr = r + m;
  const double* ar_ptr = a;
  const double* ai_ptr = a + m;
  const double* br_ptr = b;
  const double* bi_ptr = b + m;

  const double* const rend_ptr = ri_ptr;
  while (rr_ptr != rend_ptr) {
    const __m256d ar = _mm256_loadu_pd(ar_ptr);
    const __m256d ai = _mm256_loadu_pd(ai_ptr);
    const __m256d br = _mm256_loadu_pd(br_ptr);
    const __m256d bi = _mm256_loadu_pd(bi_ptr);

    const __m256d t1 = _mm256_mul_pd(ai, bi);
    const __m256d t2 = _mm256_mul_pd(ar, bi);

    // rr = ar*br - ai*bi, ri = ai*br + ar*bi  (complex product)
    __m256d rr = _mm256_fmsub_pd(ar, br, t1);
    __m256d ri = _mm256_fmadd_pd(ai, br, t2);

    _mm256_storeu_pd(rr_ptr, rr);
    _mm256_storeu_pd(ri_ptr, ri);

    rr_ptr += 4;
    ri_ptr += 4;
    ar_ptr += 4;
    ai_ptr += 4;
    br_ptr += 4;
    bi_ptr += 4;
  }
}
76 |
--------------------------------------------------------------------------------
/spqlios/reim/reim_fftvec_addmul_ref.c:
--------------------------------------------------------------------------------
1 | #include "reim_fft_internal.h"
2 | #include "reim_fft_private.h"
3 |
EXPORT void reim_fftvec_addmul_ref(const REIM_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a,
                                   const double* b) {
  // Accumulates the coefficient-wise complex product a*b into r.
  // Split layout: entries [0,m) are real parts, entries [m,2m) imaginary parts.
  const uint64_t m = precomp->m;
  const double* are = a;
  const double* aim = a + m;
  const double* bre = b;
  const double* bim = b + m;
  for (uint64_t k = 0; k < m; ++k) {
    // (are + i.aim) * (bre + i.bim)
    const double prod_re = are[k] * bre[k] - aim[k] * bim[k];
    const double prod_im = are[k] * bim[k] + aim[k] * bre[k];
    r[k] += prod_re;
    r[k + m] += prod_im;
  }
}
14 |
EXPORT void reim_fftvec_mul_ref(const REIM_FFTVEC_MUL_PRECOMP* precomp, double* r, const double* a, const double* b) {
  // Writes the coefficient-wise complex product a*b into r.
  // Split layout: entries [0,m) are real parts, entries [m,2m) imaginary parts.
  const uint64_t m = precomp->m;
  const double* are = a;
  const double* aim = a + m;
  const double* bre = b;
  const double* bim = b + m;
  for (uint64_t k = 0; k < m; ++k) {
    // (are + i.aim) * (bre + i.bim)
    const double prod_re = are[k] * bre[k] - aim[k] * bim[k];
    const double prod_im = are[k] * bim[k] + aim[k] * bre[k];
    r[k] = prod_re;
    r[k + m] = prod_im;
  }
}
24 |
// Allocates the dispatch table for reim_fftvec_addmul over m complex
// coefficients, selecting the FMA kernel when the CPU supports it and the
// vector is large enough for the 4-wide loop, the reference kernel otherwise.
// Returns NULL on allocation failure.
EXPORT REIM_FFTVEC_ADDMUL_PRECOMP* new_reim_fftvec_addmul_precomp(uint32_t m) {
  REIM_FFTVEC_ADDMUL_PRECOMP* reps = malloc(sizeof(REIM_FFTVEC_ADDMUL_PRECOMP));
  if (!reps) return NULL;  // fix: do not dereference a failed allocation
  reps->m = m;
  // the FMA kernel consumes 4 complex coefficients per iteration
  if (CPU_SUPPORTS("fma") && m >= 4) {
    reps->function = reim_fftvec_addmul_fma;
  } else {
    reps->function = reim_fftvec_addmul_ref;
  }
  return reps;
}
39 |
// Allocates the dispatch table for reim_fftvec_mul over m complex
// coefficients, selecting the FMA kernel when the CPU supports it and the
// vector is large enough for the 4-wide loop, the reference kernel otherwise.
// Returns NULL on allocation failure.
EXPORT REIM_FFTVEC_MUL_PRECOMP* new_reim_fftvec_mul_precomp(uint32_t m) {
  REIM_FFTVEC_MUL_PRECOMP* reps = malloc(sizeof(REIM_FFTVEC_MUL_PRECOMP));
  if (!reps) return NULL;  // fix: do not dereference a failed allocation
  reps->m = m;
  // the FMA kernel consumes 4 complex coefficients per iteration
  if (CPU_SUPPORTS("fma") && m >= 4) {
    reps->function = reim_fftvec_mul_fma;
  } else {
    reps->function = reim_fftvec_mul_ref;
  }
  return reps;
}
54 |
55 |
--------------------------------------------------------------------------------
/spqlios/reim/reim_ifft4_avx_fma.c:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | #include "reim_fft_private.h"
6 |
// Inverse radix-2 butterfly on two complex pairs at once (SSE lanes):
//   a' = a + b;  b' = (a - b) * om
// (ra,ia)/(rb,ib) hold the real/imag parts of a and b per lane; omre/omim
// hold the per-lane real/imag parts of the twiddle om.
__always_inline void reim_invctwiddle_avx_fma(__m128d* ra, __m128d* rb, __m128d* ia, __m128d* ib, const __m128d omre,
                                              const __m128d omim) {
  __m128d rdiff = _mm_sub_pd(*ra, *rb);
  __m128d idiff = _mm_sub_pd(*ia, *ib);
  *ra = _mm_add_pd(*ra, *rb);
  *ia = _mm_add_pd(*ia, *ib);

  // real part of (a - b) * om
  *rb = _mm_mul_pd(idiff, omim);
  *rb = _mm_fmsub_pd(rdiff, omre, *rb);

  // imaginary part of (a - b) * om
  *ib = _mm_mul_pd(rdiff, omim);
  *ib = _mm_fmadd_pd(idiff, omre, *ib);
}
20 |
// In-place size-4 inverse FFT over split-format data (dre = 4 reals,
// dim = 4 imaginaries); mirror of reim_fft4_avx_fma with the stages in
// reverse order. omp[0..1] = stage-1 base twiddle, omp[2..3] = stage-2
// twiddle (re,im).
EXPORT void reim_ifft4_avx_fma(double* dre, double* dim, const void* ompv) {
  const double* omp = (const double*)ompv;

  __m128d ra0 = _mm_loadu_pd(dre);
  __m128d ra2 = _mm_loadu_pd(dre + 2);
  __m128d ia0 = _mm_loadu_pd(dim);
  __m128d ia2 = _mm_loadu_pd(dim + 2);

  // 1: inverse butterflies (0,1) with om and (2,3) with -i*om;
  //    lane twiddles built by the (r,i)/(i,-r) shuffle+sign trick
  {
    const __m128d ifft4neg = _mm_castsi128_pd(_mm_set_epi64x(UINT64_C(1) << 63, 0));
    __m128d omre = _mm_loadu_pd(omp);                              // omre: r,i
    __m128d omim = _mm_xor_pd(_mm_permute_pd(omre, 1), ifft4neg);  // omim: i,-r

    __m128d ra = _mm_unpacklo_pd(ra0, ra2);  // (r0, r1), (r2, r3) -> (r0, r2)
    __m128d ia = _mm_unpacklo_pd(ia0, ia2);  // (i0, i1), (i2, i3) -> (i0, i2)
    __m128d rb = _mm_unpackhi_pd(ra0, ra2);  // (r0, r1), (r2, r3) -> (r1, r3)
    __m128d ib = _mm_unpackhi_pd(ia0, ia2);  // (i0, i1), (i2, i3) -> (i1, i3)

    reim_invctwiddle_avx_fma(&ra, &rb, &ia, &ib, omre, omim);

    // re-interleave back to natural order
    ra0 = _mm_unpacklo_pd(ra, rb);
    ia0 = _mm_unpacklo_pd(ia, ib);
    ra2 = _mm_unpackhi_pd(ra, rb);
    ia2 = _mm_unpackhi_pd(ia, ib);
  }

  // 2: inverse butterflies (0,2) and (1,3) with the single twiddle omp[2..3]
  {
    __m128d om = _mm_loadu_pd(omp + 2);
    __m128d omre = _mm_permute_pd(om, 0);
    __m128d omim = _mm_permute_pd(om, 3);

    reim_invctwiddle_avx_fma(&ra0, &ra2, &ia0, &ia2, omre, omim);
  }

  // 4: write back
  _mm_storeu_pd(dre, ra0);
  _mm_storeu_pd(dre + 2, ra2);
  _mm_storeu_pd(dim, ia0);
  _mm_storeu_pd(dim + 2, ia2);
}
63 |
--------------------------------------------------------------------------------
/spqlios/reim/reim_ifft8_avx_fma.c:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | #include "reim_fft_private.h"
6 |
// Inverse radix-2 butterfly on four complex pairs at once (AVX lanes):
//   a' = a + b;  b' = (a - b) * om
// (ra,ia)/(rb,ib) hold the real/imag parts of a and b per lane; omre/omim
// hold the per-lane real/imag parts of the twiddle om.
__always_inline void reim_invctwiddle_avx_fma(__m256d* ra, __m256d* rb, __m256d* ia, __m256d* ib, const __m256d omre,
                                              const __m256d omim) {
  __m256d rdiff = _mm256_sub_pd(*ra, *rb);
  __m256d idiff = _mm256_sub_pd(*ia, *ib);
  *ra = _mm256_add_pd(*ra, *rb);
  *ia = _mm256_add_pd(*ia, *ib);

  // real part of (a - b) * om
  *rb = _mm256_mul_pd(idiff, omim);
  *rb = _mm256_fmsub_pd(rdiff, omre, *rb);

  // imaginary part of (a - b) * om
  *ib = _mm256_mul_pd(rdiff, omim);
  *ib = _mm256_fmadd_pd(idiff, omre, *ib);
}
20 |
// In-place size-8 inverse FFT over split-format data (dre[0..7] reals,
// dim[0..7] imaginaries); mirror of reim_fft8_avx_fma with the stages in
// reverse order. Twiddle packing at ompv:
//   omp[0..3] = stage-1 twiddles packed (r0, r1, i0, i1),
//   omp[4..5] = stage-2 base twiddle (re,im),
//   omp[6..7] = stage-3 twiddle (re,im).
EXPORT void reim_ifft8_avx_fma(double* dre, double* dim, const void* ompv) {
  const double* omp = (const double*)ompv;

  __m256d ra0 = _mm256_loadu_pd(dre);
  __m256d ra4 = _mm256_loadu_pd(dre + 4);
  __m256d ia0 = _mm256_loadu_pd(dim);
  __m256d ia4 = _mm256_loadu_pd(dim + 4);

  // 1: inverse butterflies (2k, 2k+1) with twiddles om0, om1, -i*om0, -i*om1
  {
    const __m256d fft8neg2 = _mm256_castsi256_pd(_mm256_set_epi64x(UINT64_C(1) << 63, UINT64_C(1) << 63, 0, 0));
    __m256d omr = _mm256_loadu_pd(omp);                    // r0,r1,i0,i1
    __m256d omiirr = _mm256_permute2f128_pd(omr, omr, 1);  // i0,i1,r0,r1
    __m256d omi = _mm256_xor_pd(omiirr, fft8neg2);         // i0,i1,-r0,-r1

    __m256d rb = _mm256_unpackhi_pd(ra0, ra4);
    __m256d ib = _mm256_unpackhi_pd(ia0, ia4);
    __m256d ra = _mm256_unpacklo_pd(ra0, ra4);
    __m256d ia = _mm256_unpacklo_pd(ia0, ia4);

    reim_invctwiddle_avx_fma(&ra, &rb, &ia, &ib, omr, omi);

    ra4 = _mm256_unpackhi_pd(ra, rb);
    ia4 = _mm256_unpackhi_pd(ia, ib);
    ra0 = _mm256_unpacklo_pd(ra, rb);
    ia0 = _mm256_unpacklo_pd(ia, ib);
  }

  // 2: inverse butterflies (0,2),(1,3) with om and (4,6),(5,7) with -i*om;
  //    omr=(r,r,i,i) / omi=(i,i,-r,-r) encode both lane-pair twiddles
  {
    const __m128d ifft8neg = _mm_castsi128_pd(_mm_set_epi64x(0, UINT64_C(1) << 63));
    __m128d omri = _mm_loadu_pd(omp + 4);            // r,i
    __m128d ommri = _mm_xor_pd(omri, ifft8neg);      // -r,i
    __m256d omrimri = _mm256_set_m128d(ommri, omri); // r,i,-r,i
    __m256d omi = _mm256_permute_pd(omrimri, 3);     // i,i,-r,-r
    __m256d omr = _mm256_permute_pd(omrimri, 12);    // r,r,i,i

    __m256d rb = _mm256_permute2f128_pd(ra0, ra4, 0x31);
    __m256d ib = _mm256_permute2f128_pd(ia0, ia4, 0x31);
    __m256d ra = _mm256_permute2f128_pd(ra0, ra4, 0x20);
    __m256d ia = _mm256_permute2f128_pd(ia0, ia4, 0x20);

    reim_invctwiddle_avx_fma(&ra, &rb, &ia, &ib, omr, omi);

    ra0 = _mm256_permute2f128_pd(ra, rb, 0x20);
    ra4 = _mm256_permute2f128_pd(ra, rb, 0x31);
    ia0 = _mm256_permute2f128_pd(ia, ib, 0x20);
    ia4 = _mm256_permute2f128_pd(ia, ib, 0x31);
  }

  // 3: inverse butterflies (k, k+4), all with the same twiddle omp[6..7]
  {
    __m128d omri = _mm_loadu_pd(omp + 6);           // r,i
    __m256d omriri = _mm256_set_m128d(omri, omri);  // r,i,r,i
    __m256d omi = _mm256_permute_pd(omriri, 15);    // i,i,i,i
    __m256d omr = _mm256_permute_pd(omriri, 0);     // r,r,r,r

    reim_invctwiddle_avx_fma(&ra0, &ra4, &ia0, &ia4, omr, omi);
  }

  // 4: write back
  _mm256_storeu_pd(dre, ra0);
  _mm256_storeu_pd(dre + 4, ra4);
  _mm256_storeu_pd(dim, ia0);
  _mm256_storeu_pd(dim + 4, ia4);
}
87 |
--------------------------------------------------------------------------------
/spqlios/reim/reim_to_tnx_avx.c:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | #include "../commons_private.h"
6 | #include "reim_fft_internal.h"
7 | #include "reim_fft_private.h"
8 |
9 | typedef union {double d; uint64_t u;} dblui64_t;
10 |
11 | EXPORT void reim_to_tnx_avx(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x) {
12 | const uint64_t n = tables->m << 1;
13 | const __m256d add_cst = _mm256_set1_pd(tables->add_cst);
14 | const __m256d mask_and = _mm256_castsi256_pd(_mm256_set1_epi64x(tables->mask_and));
15 | const __m256d mask_or = _mm256_castsi256_pd(_mm256_set1_epi64x(tables->mask_or));
16 | const __m256d sub_cst = _mm256_set1_pd(tables->sub_cst);
17 | __m256d reg0,reg1;
18 | for (uint64_t i=0; i
2 | #include
3 |
4 | #include "../commons_private.h"
5 | #include "reim_fft_internal.h"
6 | #include "reim_fft_private.h"
7 |
8 | EXPORT void reim_to_tnx_basic_ref(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x) {
9 | const uint64_t n = tables->m << 1;
10 | double divisor = tables->divisor;
11 | for (uint64_t i=0; im << 1;
21 | double add_cst = tables->add_cst;
22 | uint64_t mask_and = tables->mask_and;
23 | uint64_t mask_or = tables->mask_or;
24 | double sub_cst = tables->sub_cst;
25 | dblui64_t cur;
26 | for (uint64_t i=0; i 52) return spqlios_error("log2overhead is too large");
38 | res->m = m;
39 | res->divisor = divisor;
40 | res->log2overhead = log2overhead;
41 | // 52 + 11 + 1
42 | // ......1.......01(1)|expo|sign
43 | // .......=========(1)|expo|sign msbbits = log2ovh + 2 + 11 + 1
44 | uint64_t nbits = 50 - log2overhead;
45 | dblui64_t ovh_cst;
46 | ovh_cst.d = 0.5 + (6<add_cst = ovh_cst.d * divisor;
48 | res->mask_and = ((UINT64_C(1) << nbits) - 1);
49 | res->mask_or = ovh_cst.u & ((UINT64_C(-1)) << nbits);
50 | res->sub_cst = ovh_cst.d;
51 | // TODO: check selection logic
52 | if (CPU_SUPPORTS("avx2")) {
53 | if (m >= 8) {
54 | res->function = reim_to_tnx_avx;
55 | } else {
56 | res->function = reim_to_tnx_ref;
57 | }
58 | } else {
59 | res->function = reim_to_tnx_ref;
60 | }
61 | return res;
62 | }
63 |
64 | EXPORT REIM_TO_TNX_PRECOMP* new_reim_to_tnx_precomp(uint32_t m, double divisor, uint32_t log2overhead) {
65 | REIM_TO_TNX_PRECOMP* res = malloc(sizeof(*res));
66 | if (!res) return spqlios_error(strerror(errno));
67 | return spqlios_keep_or_free(res, init_reim_to_tnx_precomp(res, m, divisor, log2overhead));
68 | }
69 |
70 | EXPORT void reim_to_tnx(const REIM_TO_TNX_PRECOMP* tables, double* r, const double* x) {
71 | tables->function(tables, r, x);
72 | }
73 |
--------------------------------------------------------------------------------
/spqlios/reim4/reim4_execute.c:
--------------------------------------------------------------------------------
1 | #include "reim4_fftvec_internal.h"
2 | #include "reim4_fftvec_private.h"
3 |
4 | EXPORT void reim4_fftvec_addmul(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a,
5 | const double* b) {
6 | tables->function(tables, r, a, b);
7 | }
8 |
9 | EXPORT void reim4_fftvec_mul(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b) {
10 | tables->function(tables, r, a, b);
11 | }
12 |
13 | EXPORT void reim4_from_cplx(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a) {
14 | tables->function(tables, r, a);
15 | }
16 |
17 | EXPORT void reim4_to_cplx(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a) {
18 | tables->function(tables, r, a);
19 | }
20 |
--------------------------------------------------------------------------------
/spqlios/reim4/reim4_fallbacks_aarch64.c:
--------------------------------------------------------------------------------
1 | #include "reim4_fftvec_private.h"
2 |
// AArch64 fallback stubs for the x86-only FMA kernels. They exist so that the
// symbols referenced by the dispatch-selection code still link on aarch64;
// presumably CPU_SUPPORTS("fma") is false there so they are never selected —
// TODO confirm. Calling any of them triggers UNDEFINED().
EXPORT void reim4_fftvec_mul_fma(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a,
                                 const double* b){UNDEFINED()}

EXPORT void reim4_fftvec_addmul_fma(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a,
                                    const double* b){UNDEFINED()}

EXPORT void reim4_from_cplx_fma(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a){UNDEFINED()}

EXPORT void reim4_to_cplx_fma(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a){UNDEFINED()}
12 |
--------------------------------------------------------------------------------
/spqlios/reim4/reim4_fftvec_addmul_fma.c:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | #include "reim4_fftvec_private.h"
6 |
7 | EXPORT void reim4_fftvec_addmul_fma(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r_ptr, const double* a_ptr,
8 | const double* b_ptr) {
9 | const double* const rend_ptr = r_ptr + (tables->m << 1);
10 | while (r_ptr != rend_ptr) {
11 | __m256d rr = _mm256_loadu_pd(r_ptr);
12 | __m256d ri = _mm256_loadu_pd(r_ptr + 4);
13 | const __m256d ar = _mm256_loadu_pd(a_ptr);
14 | const __m256d ai = _mm256_loadu_pd(a_ptr + 4);
15 | const __m256d br = _mm256_loadu_pd(b_ptr);
16 | const __m256d bi = _mm256_loadu_pd(b_ptr + 4);
17 |
18 | rr = _mm256_fmsub_pd(ai, bi, rr);
19 | rr = _mm256_fmsub_pd(ar, br, rr);
20 | ri = _mm256_fmadd_pd(ar, bi, ri);
21 | ri = _mm256_fmadd_pd(ai, br, ri);
22 |
23 | _mm256_storeu_pd(r_ptr, rr);
24 | _mm256_storeu_pd(r_ptr + 4, ri);
25 |
26 | r_ptr += 8;
27 | a_ptr += 8;
28 | b_ptr += 8;
29 | }
30 | }
31 |
32 | EXPORT void reim4_fftvec_mul_fma(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r_ptr, const double* a_ptr,
33 | const double* b_ptr) {
34 | const double* const rend_ptr = r_ptr + (tables->m << 1);
35 | while (r_ptr != rend_ptr) {
36 | const __m256d ar = _mm256_loadu_pd(a_ptr);
37 | const __m256d ai = _mm256_loadu_pd(a_ptr + 4);
38 | const __m256d br = _mm256_loadu_pd(b_ptr);
39 | const __m256d bi = _mm256_loadu_pd(b_ptr + 4);
40 |
41 | const __m256d t1 = _mm256_mul_pd(ai, bi);
42 | const __m256d t2 = _mm256_mul_pd(ar, bi);
43 |
44 | __m256d rr = _mm256_fmsub_pd(ar, br, t1);
45 | __m256d ri = _mm256_fmadd_pd(ai, br, t2);
46 |
47 | _mm256_storeu_pd(r_ptr, rr);
48 | _mm256_storeu_pd(r_ptr + 4, ri);
49 |
50 | r_ptr += 8;
51 | a_ptr += 8;
52 | b_ptr += 8;
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/spqlios/reim4/reim4_fftvec_addmul_ref.c:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 |
6 | #include "../commons_private.h"
7 | #include "reim4_fftvec_internal.h"
8 | #include "reim4_fftvec_private.h"
9 |
10 | void* init_reim4_fftvec_addmul_precomp(REIM4_FFTVEC_ADDMUL_PRECOMP* res, uint32_t m) {
11 | res->m = m;
12 | if (CPU_SUPPORTS("fma")) {
13 | if (m >= 2) {
14 | res->function = reim4_fftvec_addmul_fma;
15 | } else {
16 | res->function = reim4_fftvec_addmul_ref;
17 | }
18 | } else {
19 | res->function = reim4_fftvec_addmul_ref;
20 | }
21 | return res;
22 | }
23 |
24 | EXPORT REIM4_FFTVEC_ADDMUL_PRECOMP* new_reim4_fftvec_addmul_precomp(uint32_t m) {
25 | REIM4_FFTVEC_ADDMUL_PRECOMP* res = malloc(sizeof(*res));
26 | if (!res) return spqlios_error(strerror(errno));
27 | return spqlios_keep_or_free(res, init_reim4_fftvec_addmul_precomp(res, m));
28 | }
29 |
30 | EXPORT void reim4_fftvec_addmul_ref(const REIM4_FFTVEC_ADDMUL_PRECOMP* precomp, double* r, const double* a,
31 | const double* b) {
32 | const uint64_t m = precomp->m;
33 | for (uint64_t j = 0; j < m / 4; ++j) {
34 | for (uint64_t i = 0; i < 4; ++i) {
35 | double re = a[i] * b[i] - a[i + 4] * b[i + 4];
36 | double im = a[i] * b[i + 4] + a[i + 4] * b[i];
37 | r[i] += re;
38 | r[i + 4] += im;
39 | }
40 | a += 8;
41 | b += 8;
42 | r += 8;
43 | }
44 | }
45 |
46 | EXPORT void reim4_fftvec_addmul_simple(uint32_t m, double* r, const double* a, const double* b) {
47 | static REIM4_FFTVEC_ADDMUL_PRECOMP precomp[32];
48 | REIM4_FFTVEC_ADDMUL_PRECOMP* p = precomp + log2m(m);
49 | if (!p->function) {
50 | if (!init_reim4_fftvec_addmul_precomp(p, m)) abort();
51 | }
52 | p->function(p, r, a, b);
53 | }
54 |
55 | void* init_reim4_fftvec_mul_precomp(REIM4_FFTVEC_MUL_PRECOMP* res, uint32_t m) {
56 | res->m = m;
57 | if (CPU_SUPPORTS("fma")) {
58 | if (m >= 4) {
59 | res->function = reim4_fftvec_mul_fma;
60 | } else {
61 | res->function = reim4_fftvec_mul_ref;
62 | }
63 | } else {
64 | res->function = reim4_fftvec_mul_ref;
65 | }
66 | return res;
67 | }
68 |
69 | EXPORT REIM4_FFTVEC_MUL_PRECOMP* new_reim4_fftvec_mul_precomp(uint32_t m) {
70 | REIM4_FFTVEC_MUL_PRECOMP* res = malloc(sizeof(*res));
71 | if (!res) return spqlios_error(strerror(errno));
72 | return spqlios_keep_or_free(res, init_reim4_fftvec_mul_precomp(res, m));
73 | }
74 |
75 | EXPORT void reim4_fftvec_mul_ref(const REIM4_FFTVEC_MUL_PRECOMP* precomp, double* r, const double* a, const double* b) {
76 | const uint64_t m = precomp->m;
77 | for (uint64_t j = 0; j < m / 4; ++j) {
78 | for (uint64_t i = 0; i < 4; ++i) {
79 | double re = a[i] * b[i] - a[i + 4] * b[i + 4];
80 | double im = a[i] * b[i + 4] + a[i + 4] * b[i];
81 | r[i] = re;
82 | r[i + 4] = im;
83 | }
84 | a += 8;
85 | b += 8;
86 | r += 8;
87 | }
88 | }
89 |
90 | EXPORT void reim4_fftvec_mul_simple(uint32_t m, double* r, const double* a, const double* b) {
91 | static REIM4_FFTVEC_MUL_PRECOMP precomp[32];
92 | REIM4_FFTVEC_MUL_PRECOMP* p = precomp + log2m(m);
93 | if (!p->function) {
94 | if (!init_reim4_fftvec_mul_precomp(p, m)) abort();
95 | }
96 | p->function(p, r, a, b);
97 | }
98 |
--------------------------------------------------------------------------------
/spqlios/reim4/reim4_fftvec_conv_fma.c:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 |
5 | #include "reim4_fftvec_private.h"
6 |
7 | EXPORT void reim4_from_cplx_fma(const REIM4_FROM_CPLX_PRECOMP* tables, double* r_ptr, const void* a) {
8 | const double* const rend_ptr = r_ptr + (tables->m << 1);
9 |
10 | const double* a_ptr = (double*)a;
11 | while (r_ptr != rend_ptr) {
12 | __m256d t1 = _mm256_loadu_pd(a_ptr);
13 | __m256d t2 = _mm256_loadu_pd(a_ptr + 4);
14 |
15 | _mm256_storeu_pd(r_ptr, _mm256_unpacklo_pd(t1, t2));
16 | _mm256_storeu_pd(r_ptr + 4, _mm256_unpackhi_pd(t1, t2));
17 |
18 | r_ptr += 8;
19 | a_ptr += 8;
20 | }
21 | }
22 |
23 | EXPORT void reim4_to_cplx_fma(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a_ptr) {
24 | const double* const aend_ptr = a_ptr + (tables->m << 1);
25 | double* r_ptr = (double*)r;
26 |
27 | while (a_ptr != aend_ptr) {
28 | __m256d t1 = _mm256_loadu_pd(a_ptr);
29 | __m256d t2 = _mm256_loadu_pd(a_ptr + 4);
30 |
31 | _mm256_storeu_pd(r_ptr, _mm256_unpacklo_pd(t1, t2));
32 | _mm256_storeu_pd(r_ptr + 4, _mm256_unpackhi_pd(t1, t2));
33 |
34 | r_ptr += 8;
35 | a_ptr += 8;
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/spqlios/reim4/reim4_fftvec_conv_ref.c:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 |
6 | #include "../commons_private.h"
7 | #include "reim4_fftvec_internal.h"
8 | #include "reim4_fftvec_private.h"
9 |
10 | EXPORT void reim4_from_cplx_ref(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a) {
11 | const double* x = (double*)a;
12 | const uint64_t m = tables->m;
13 | for (uint64_t i = 0; i < m / 4; ++i) {
14 | double r0 = x[0];
15 | double i0 = x[1];
16 | double r1 = x[2];
17 | double i1 = x[3];
18 | double r2 = x[4];
19 | double i2 = x[5];
20 | double r3 = x[6];
21 | double i3 = x[7];
22 | r[0] = r0;
23 | r[1] = r2;
24 | r[2] = r1;
25 | r[3] = r3;
26 | r[4] = i0;
27 | r[5] = i2;
28 | r[6] = i1;
29 | r[7] = i3;
30 | x += 8;
31 | r += 8;
32 | }
33 | }
34 |
35 | void* init_reim4_from_cplx_precomp(REIM4_FROM_CPLX_PRECOMP* res, uint32_t nn) {
36 | res->m = nn / 2;
37 | if (CPU_SUPPORTS("fma")) {
38 | if (nn >= 4) {
39 | res->function = reim4_from_cplx_fma;
40 | } else {
41 | res->function = reim4_from_cplx_ref;
42 | }
43 | } else {
44 | res->function = reim4_from_cplx_ref;
45 | }
46 | return res;
47 | }
48 |
49 | EXPORT REIM4_FROM_CPLX_PRECOMP* new_reim4_from_cplx_precomp(uint32_t m) {
50 | REIM4_FROM_CPLX_PRECOMP* res = malloc(sizeof(*res));
51 | if (!res) return spqlios_error(strerror(errno));
52 | return spqlios_keep_or_free(res, init_reim4_from_cplx_precomp(res, m));
53 | }
54 |
55 | EXPORT void reim4_from_cplx_simple(uint32_t m, double* r, const void* a) {
56 | static REIM4_FROM_CPLX_PRECOMP precomp[32];
57 | REIM4_FROM_CPLX_PRECOMP* p = precomp + log2m(m);
58 | if (!p->function) {
59 | if (!init_reim4_from_cplx_precomp(p, m)) abort();
60 | }
61 | p->function(p, r, a);
62 | }
63 |
64 | EXPORT void reim4_to_cplx_ref(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a) {
65 | double* y = (double*)r;
66 | const uint64_t m = tables->m;
67 | for (uint64_t i = 0; i < m / 4; ++i) {
68 | double r0 = a[0];
69 | double r2 = a[1];
70 | double r1 = a[2];
71 | double r3 = a[3];
72 | double i0 = a[4];
73 | double i2 = a[5];
74 | double i1 = a[6];
75 | double i3 = a[7];
76 | y[0] = r0;
77 | y[1] = i0;
78 | y[2] = r1;
79 | y[3] = i1;
80 | y[4] = r2;
81 | y[5] = i2;
82 | y[6] = r3;
83 | y[7] = i3;
84 | a += 8;
85 | y += 8;
86 | }
87 | }
88 |
89 | void* init_reim4_to_cplx_precomp(REIM4_TO_CPLX_PRECOMP* res, uint32_t m) {
90 | res->m = m;
91 | if (CPU_SUPPORTS("fma")) {
92 | if (m >= 2) {
93 | res->function = reim4_to_cplx_fma;
94 | } else {
95 | res->function = reim4_to_cplx_ref;
96 | }
97 | } else {
98 | res->function = reim4_to_cplx_ref;
99 | }
100 | return res;
101 | }
102 |
103 | EXPORT REIM4_TO_CPLX_PRECOMP* new_reim4_to_cplx_precomp(uint32_t m) {
104 | REIM4_TO_CPLX_PRECOMP* res = malloc(sizeof(*res));
105 | if (!res) return spqlios_error(strerror(errno));
106 | return spqlios_keep_or_free(res, init_reim4_to_cplx_precomp(res, m));
107 | }
108 |
109 | EXPORT void reim4_to_cplx_simple(uint32_t m, void* r, const double* a) {
110 | static REIM4_TO_CPLX_PRECOMP precomp[32];
111 | REIM4_TO_CPLX_PRECOMP* p = precomp + log2m(m);
112 | if (!p->function) {
113 | if (!init_reim4_to_cplx_precomp(p, m)) abort();
114 | }
115 | p->function(p, r, a);
116 | }
117 |
--------------------------------------------------------------------------------
/spqlios/reim4/reim4_fftvec_internal.h:
--------------------------------------------------------------------------------
#ifndef SPQLIOS_REIM4_FFTVEC_INTERNAL_H
#define SPQLIOS_REIM4_FFTVEC_INTERNAL_H

#include "reim4_fftvec_public.h"

// Internal declarations of the concrete kernels behind the dispatched reim4
// fftvec API: each operation ships a portable reference version (_ref) and an
// x86 FMA version (_fma); the init_* routines pick one per precomp.

// r = a * b (coefficient-wise complex product, reim4 layout)
EXPORT void reim4_fftvec_mul_ref(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b);
EXPORT void reim4_fftvec_mul_fma(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b);

// r += a * b (coefficient-wise complex product, reim4 layout)
EXPORT void reim4_fftvec_addmul_ref(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a,
                                    const double* b);
EXPORT void reim4_fftvec_addmul_fma(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a,
                                    const double* b);

// interleaved (re,im) pairs -> reim4 layout
EXPORT void reim4_from_cplx_ref(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a);
EXPORT void reim4_from_cplx_fma(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a);

// reim4 layout -> interleaved (re,im) pairs
EXPORT void reim4_to_cplx_ref(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a);
EXPORT void reim4_to_cplx_fma(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a);

#endif  // SPQLIOS_REIM4_FFTVEC_INTERNAL_H
--------------------------------------------------------------------------------
/spqlios/reim4/reim4_fftvec_private.h:
--------------------------------------------------------------------------------
#ifndef SPQLIOS_REIM4_FFTVEC_PRIVATE_H
#define SPQLIOS_REIM4_FFTVEC_PRIVATE_H

#include "reim4_fftvec_public.h"

// Compile-time assertion usable in expression context (pre-C11 fallback).
#define STATIC_ASSERT(condition) (void)sizeof(char[-1 + 2 * !!(condition)])

// Function-pointer types for the runtime-dispatched kernels (ref vs fma).
typedef void (*R4_FFTVEC_MUL_FUNC)(const REIM4_FFTVEC_MUL_PRECOMP*, double*, const double*, const double*);
typedef void (*R4_FFTVEC_ADDMUL_FUNC)(const REIM4_FFTVEC_ADDMUL_PRECOMP*, double*, const double*, const double*);
typedef void (*R4_FROM_CPLX_FUNC)(const REIM4_FROM_CPLX_PRECOMP*, double*, const void*);
typedef void (*R4_TO_CPLX_FUNC)(const REIM4_TO_CPLX_PRECOMP*, void*, const double*);

struct reim4_mul_precomp {
  R4_FFTVEC_MUL_FUNC function;  // kernel selected by init_* from CPU features and m
  int64_t m;                    // number of complex coefficients (2*m doubles in reim4 layout)
};

struct reim4_addmul_precomp {
  R4_FFTVEC_ADDMUL_FUNC function;  // kernel selected by init_* from CPU features and m
  int64_t m;                       // number of complex coefficients
};

struct reim4_from_cplx_precomp {
  R4_FROM_CPLX_FUNC function;  // kernel selected by init_* from CPU features and m
  int64_t m;                   // number of complexes converted per call
};

struct reim4_to_cplx_precomp {
  R4_TO_CPLX_FUNC function;  // kernel selected by init_* from CPU features and m
  int64_t m;                 // number of complexes converted per call
};

#endif  // SPQLIOS_REIM4_FFTVEC_PRIVATE_H
34 |
--------------------------------------------------------------------------------
/spqlios/reim4/reim4_fftvec_public.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_REIM4_FFTVEC_PUBLIC_H
2 | #define SPQLIOS_REIM4_FFTVEC_PUBLIC_H
3 |
4 | #include "../commons.h"
5 |
6 | typedef struct reim4_addmul_precomp REIM4_FFTVEC_ADDMUL_PRECOMP;
7 | typedef struct reim4_mul_precomp REIM4_FFTVEC_MUL_PRECOMP;
8 | typedef struct reim4_from_cplx_precomp REIM4_FROM_CPLX_PRECOMP;
9 | typedef struct reim4_to_cplx_precomp REIM4_TO_CPLX_PRECOMP;
10 |
11 | EXPORT REIM4_FFTVEC_MUL_PRECOMP* new_reim4_fftvec_mul_precomp(uint32_t m);
12 | EXPORT void reim4_fftvec_mul(const REIM4_FFTVEC_MUL_PRECOMP* tables, double* r, const double* a, const double* b);
13 | #define delete_reim4_fftvec_mul_precomp free
14 |
15 | EXPORT REIM4_FFTVEC_ADDMUL_PRECOMP* new_reim4_fftvec_addmul_precomp(uint32_t nn);
16 | EXPORT void reim4_fftvec_addmul(const REIM4_FFTVEC_ADDMUL_PRECOMP* tables, double* r, const double* a, const double* b);
17 | #define delete_reim4_fftvec_addmul_precomp free
18 |
19 | /**
20 | * @brief prepares a conversion from the cplx fftvec layout to the reim4 layout.
21 | * @param m complex dimension m from C[X] mod X^m-i.
22 | */
23 | EXPORT REIM4_FROM_CPLX_PRECOMP* new_reim4_from_cplx_precomp(uint32_t m);
24 | EXPORT void reim4_from_cplx(const REIM4_FROM_CPLX_PRECOMP* tables, double* r, const void* a);
25 | #define delete_reim4_from_cplx_precomp free
26 |
27 | /**
28 | * @brief prepares a conversion from the reim4 fftvec layout to the cplx layout
29 | * @param m the complex dimension m from C[X] mod X^m-i.
30 | */
31 | EXPORT REIM4_TO_CPLX_PRECOMP* new_reim4_to_cplx_precomp(uint32_t m);
32 | EXPORT void reim4_to_cplx(const REIM4_TO_CPLX_PRECOMP* tables, void* r, const double* a);
33 | #define delete_reim4_to_cplx_precomp free
34 |
35 | /**
36 | * @brief Simpler API for the fftvec multiplication function.
37 | * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
38 | * It is advised to do one dry-run call per desired dimension before using in a multithread environment */
39 | EXPORT void reim4_fftvec_mul_simple(uint32_t m, double* r, const double* a, const double* b);
40 |
41 | /**
42 | * @brief Simpler API for the fftvec addmul function.
43 | * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
44 | * It is advised to do one dry-run call per desired dimension before using in a multithread environment */
45 | EXPORT void reim4_fftvec_addmul_simple(uint32_t m, double* r, const double* a, const double* b);
46 |
47 | /**
48 | * @brief Simpler API for cplx conversion.
49 | * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
50 | * It is advised to do one dry-run call per desired dimension before using in a multithread environment */
51 | EXPORT void reim4_from_cplx_simple(uint32_t m, double* r, const void* a);
52 |
53 | /**
54 | * @brief Simpler API for to cplx conversion.
55 | * For each dimension, the precomputed tables for this dimension are generated automatically the first time.
56 | * It is advised to do one dry-run call per desired dimension before using in a multithread environment */
57 | EXPORT void reim4_to_cplx_simple(uint32_t m, void* r, const double* a);
58 |
59 | #endif // SPQLIOS_REIM4_FFTVEC_PUBLIC_H
--------------------------------------------------------------------------------
/test/spqlios_cplx_conversions_test.cpp:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include
4 |
5 | #include "spqlios/cplx/cplx_fft_internal.h"
6 | #include "spqlios/cplx/cplx_fft_private.h"
7 |
8 | #ifdef __x86_64__
TEST(fft, cplx_from_znx32_ref_vs_fma) {
  // Checks that the AVX2/FMA int32 -> CPLX conversion matches the reference
  // implementation bit-for-bit on random signed 32-bit inputs.
  const uint32_t m = 128;
  // single 32-byte-aligned arena: 2m int32 inputs followed by two CPLX outputs
  int32_t* src = (int32_t*)spqlios_alloc_custom_align(32, 10 * m * sizeof(int32_t));
  CPLX* dst1 = (CPLX*)(src + 2 * m);
  CPLX* dst2 = (CPLX*)(src + 6 * m);
  for (uint64_t i = 0; i < 2 * m; ++i) {
    src[i] = rand() - RAND_MAX / 2;  // roughly centered around 0
  }
  CPLX_FROM_ZNX32_PRECOMP precomp;
  precomp.m = m;
  cplx_from_znx32_ref(&precomp, dst1, src);
  // cplx_from_znx32_simple(m, 32, dst1, src);
  cplx_from_znx32_avx2_fma(&precomp, dst2, src);
  for (uint64_t i = 0; i < m; ++i) {
    ASSERT_EQ(dst1[i][0], dst2[i][0]);  // real part
    ASSERT_EQ(dst1[i][1], dst2[i][1]);  // imaginary part
  }
  spqlios_free(src);
}
28 | #endif
29 |
30 | #ifdef __x86_64__
TEST(fft, cplx_from_tnx32_ref_vs_fma) {
  // Checks that the AVX2/FMA torus32 -> CPLX conversion matches the reference
  // implementation bit-for-bit on random full-range 32-bit inputs.
  const uint32_t m = 128;
  int32_t* src = (int32_t*)spqlios_alloc_custom_align(32, 10 * m * sizeof(int32_t));
  CPLX* dst1 = (CPLX*)(src + 2 * m);
  CPLX* dst2 = (CPLX*)(src + 6 * m);
  for (uint64_t i = 0; i < 2 * m; ++i) {
    // spreads random bits across the whole int32 range.
    // NOTE(review): rand() << 20 can overflow a signed int (UB); consider
    // (int32_t)((uint32_t)rand() << 20) for defined wraparound.
    src[i] = rand() + (rand() << 20);
  }
  CPLX_FROM_TNX32_PRECOMP precomp;
  precomp.m = m;
  cplx_from_tnx32_ref(&precomp, dst1, src);
  // cplx_from_tnx32_simple(m, dst1, src);
  cplx_from_tnx32_avx2_fma(&precomp, dst2, src);
  for (uint64_t i = 0; i < m; ++i) {
    ASSERT_EQ(dst1[i][0], dst2[i][0]);  // real part
    ASSERT_EQ(dst1[i][1], dst2[i][1]);  // imaginary part
  }
  spqlios_free(src);
}
50 | #endif
51 |
52 | #ifdef __x86_64__
53 | TEST(fft, cplx_to_tnx32_ref_vs_fma) {
54 | for (const uint32_t m : {8, 128, 1024, 65536}) {
55 | for (const double divisor : {double(1), double(m), double(0.5)}) {
56 | CPLX* src = (CPLX*)spqlios_alloc_custom_align(32, 10 * m * sizeof(int32_t));
57 | int32_t* dst1 = (int32_t*)(src + m);
58 | int32_t* dst2 = (int32_t*)(src + 2 * m);
59 | for (uint64_t i = 0; i < 2 * m; ++i) {
60 | src[i][0] = (rand() / double(RAND_MAX) - 0.5) * pow(2., 19 - (rand() % 60)) * divisor;
61 | src[i][1] = (rand() / double(RAND_MAX) - 0.5) * pow(2., 19 - (rand() % 60)) * divisor;
62 | }
63 | CPLX_TO_TNX32_PRECOMP precomp;
64 | precomp.m = m;
65 | precomp.divisor = divisor;
66 | cplx_to_tnx32_ref(&precomp, dst1, src);
67 | cplx_to_tnx32_avx2_fma(&precomp, dst2, src);
68 | // cplx_to_tnx32_simple(m, divisor, 18, dst2, src);
69 | for (uint64_t i = 0; i < 2 * m; ++i) {
70 | double truevalue =
71 | (src[i % m][i / m] / divisor - floor(src[i % m][i / m] / divisor + 0.5)) * (INT64_C(1) << 32);
72 | if (fabs(truevalue - floor(truevalue)) == 0.5) {
73 | // ties can differ by 0, 1 or -1
74 | ASSERT_LE(abs(dst1[i] - dst2[i]), 0)
75 | << i << " " << dst1[i] << " " << dst2[i] << " " << truevalue << std::endl;
76 | } else {
77 | // otherwise, we should have equality
78 | ASSERT_LE(abs(dst1[i] - dst2[i]), 0)
79 | << i << " " << dst1[i] << " " << dst2[i] << " " << truevalue << std::endl;
80 | }
81 | }
82 | spqlios_free(src);
83 | }
84 | }
85 | }
86 | #endif
87 |
--------------------------------------------------------------------------------
/test/spqlios_cplx_fft_bench.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 | #include
10 | #include
11 | #include
12 |
13 | #include "../spqlios/cplx/cplx_fft_internal.h"
14 | #include "spqlios/reim/reim_fft.h"
15 |
16 | using namespace std;
17 |
// Fills v[0..n) with pseudo-random values roughly centered around zero.
void init_random_values(uint64_t n, double* v) {
  const double center = double(RAND_MAX >> 1);
  for (uint64_t k = 0; k < n; ++k) {
    v[k] = double(rand()) - center;
  }
}
21 |
// Benchmarks the cplx FFT; the benchmark argument is the real dimension nn,
// the complex dimension passed to the precomp is nn/2.
void benchmark_cplx_fft(benchmark::State& state) {
  const int32_t nn = state.range(0);
  CPLX_FFT_PRECOMP* a = new_cplx_fft_precomp(nn / 2, 1);
  double* c = (double*)cplx_fft_precomp_get_buffer(a, 0);
  init_random_values(nn, c);
  for (auto _ : state) {
    // cplx_fft_simple(nn/2, c);
    cplx_fft(a, c);  // in-place transform of the precomp's own buffer
  }
  delete_cplx_fft_precomp(a);
}
33 |
// Benchmarks the cplx inverse FFT (complex dimension nn/2).
void benchmark_cplx_ifft(benchmark::State& state) {
  const int32_t nn = state.range(0);
  CPLX_IFFT_PRECOMP* a = new_cplx_ifft_precomp(nn / 2, 1);
  double* c = (double*)cplx_ifft_precomp_get_buffer(a, 0);
  init_random_values(nn, c);
  for (auto _ : state) {
    // cplx_ifft_simple(nn/2, c);
    cplx_ifft(a, c);  // in-place transform of the precomp's own buffer
  }
  delete_cplx_ifft_precomp(a);
}
45 |
// Benchmarks the reim FFT (complex dimension m = nn/2).
void benchmark_reim_fft(benchmark::State& state) {
  const int32_t nn = state.range(0);
  const uint32_t m = nn / 2;
  REIM_FFT_PRECOMP* a = new_reim_fft_precomp(m, 1);
  double* c = reim_fft_precomp_get_buffer(a, 0);
  init_random_values(nn, c);
  for (auto _ : state) {
    // cplx_fft_simple(nn/2, c);
    reim_fft(a, c);  // in-place transform of the precomp's own buffer
  }
  delete_reim_fft_precomp(a);
}
58 |
59 | #ifdef __aarch64__
60 | EXPORT REIM_FFT_PRECOMP* new_reim_fft_precomp_neon(uint32_t m, uint32_t num_buffers);
61 | EXPORT void reim_fft_neon(const REIM_FFT_PRECOMP* precomp, double* d);
62 |
// Benchmarks the NEON variant of the reim FFT (aarch64 only; see the
// prototypes declared above in this #ifdef block).
void benchmark_reim_fft_neon(benchmark::State& state) {
  const int32_t nn = state.range(0);
  const uint32_t m = nn / 2;
  REIM_FFT_PRECOMP* a = new_reim_fft_precomp_neon(m, 1);
  double* c = reim_fft_precomp_get_buffer(a, 0);
  init_random_values(nn, c);
  for (auto _ : state) {
    // cplx_fft_simple(nn/2, c);
    reim_fft_neon(a, c);
  }
  delete_reim_fft_precomp(a);
}
75 | #endif
76 |
// Benchmarks the reim inverse FFT (complex dimension m = nn/2).
void benchmark_reim_ifft(benchmark::State& state) {
  const int32_t nn = state.range(0);
  const uint32_t m = nn / 2;
  REIM_IFFT_PRECOMP* a = new_reim_ifft_precomp(m, 1);
  double* c = reim_ifft_precomp_get_buffer(a, 0);
  init_random_values(nn, c);
  for (auto _ : state) {
    // cplx_ifft_simple(nn/2, c);
    reim_ifft(a, c);  // in-place transform of the precomp's own buffer
  }
  delete_reim_ifft_precomp(a);
}
89 |
90 | // #define ARGS Arg(1024)->Arg(8192)->Arg(32768)->Arg(65536)
91 | #define ARGS Arg(64)->Arg(256)->Arg(1024)->Arg(2048)->Arg(4096)->Arg(8192)->Arg(16384)->Arg(32768)->Arg(65536)
92 |
// Registers and runs all FFT benchmarks. Dimensions n below are real-FFT
// sizes (mod X^n+1); the complex dimension m is n/2.
int main(int argc, char** argv) {
  ::benchmark::Initialize(&argc, argv);
  if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return 1;
  std::cout << "Dimensions n in the benchmark below are in \"real FFT\" modulo X^n+1" << std::endl;
  std::cout << "The complex dimension m (modulo X^m-i) is half of it" << std::endl;
  BENCHMARK(benchmark_cplx_fft)->ARGS;
  BENCHMARK(benchmark_cplx_ifft)->ARGS;
  BENCHMARK(benchmark_reim_fft)->ARGS;
#ifdef __aarch64__
  BENCHMARK(benchmark_reim_fft_neon)->ARGS;
#endif
  BENCHMARK(benchmark_reim_ifft)->ARGS;
  // if (CPU_SUPPORTS("avx512f")) {
  //   BENCHMARK(bench_cplx_fftvec_twiddle_avx512)->ARGS;
  //   BENCHMARK(bench_cplx_fftvec_bitwiddle_avx512)->ARGS;
  //}
  ::benchmark::RunSpecifiedBenchmarks();
  ::benchmark::Shutdown();
  return 0;
}
113 |
--------------------------------------------------------------------------------
/test/spqlios_q120_ntt_bench.cpp:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include
4 |
5 | #include "spqlios/q120/q120_ntt.h"
6 |
7 | #define ARGS Arg(1 << 10)->Arg(1 << 11)->Arg(1 << 12)->Arg(1 << 13)->Arg(1 << 14)->Arg(1 << 15)->Arg(1 << 16)
8 |
// NOTE(review): the template parameter list was lost in extraction ("template"
// with no <...>); presumably a non-type parameter `f` holding the NTT function
// to benchmark — restore from the repository.
template
void benchmark_ntt(benchmark::State& state) {
  // Benchmarks a q120 forward NTT of size n on random inputs.
  const uint64_t n = state.range(0);
  q120_ntt_precomp* precomp = q120_new_ntt_bb_precomp(n);

  uint64_t* px = new uint64_t[n * 4];
  for (uint64_t i = 0; i < 4 * n; i++) {
    // NOTE(review): rand() << 31 shifts a signed int by 31 bits — potential
    // signed-overflow UB; consider (uint64_t(rand()) << 31) + rand().
    px[i] = (rand() << 31) + rand();
  }
  for (auto _ : state) {
    f(precomp, (q120b*)px);
  }
  delete[] px;
  q120_del_ntt_bb_precomp(precomp);
}
24 |
// NOTE(review): template parameter list lost in extraction (see benchmark_ntt
// above) — presumably a non-type parameter `f` holding the inverse-NTT function.
template
void benchmark_intt(benchmark::State& state) {
  // Benchmarks a q120 inverse NTT of size n on random inputs.
  const uint64_t n = state.range(0);
  q120_ntt_precomp* precomp = q120_new_intt_bb_precomp(n);

  uint64_t* px = new uint64_t[n * 4];
  for (uint64_t i = 0; i < 4 * n; i++) {
    // NOTE(review): rand() << 31 shifts a signed int by 31 bits — potential
    // signed-overflow UB; consider (uint64_t(rand()) << 31) + rand().
    px[i] = (rand() << 31) + rand();
  }
  for (auto _ : state) {
    f(precomp, (q120b*)px);
  }
  delete[] px;
  q120_del_intt_bb_precomp(precomp);
}
40 |
41 | BENCHMARK(benchmark_ntt)->Name("q120_ntt_bb_avx2")->ARGS;
42 | BENCHMARK(benchmark_intt)->Name("q120_intt_bb_avx2")->ARGS;
43 |
44 | BENCHMARK_MAIN();
45 |
--------------------------------------------------------------------------------
/test/spqlios_reim4_arithmetic_bench.cpp:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include "spqlios/reim4/reim4_arithmetic.h"
4 |
// Fills v[0..n) with pseudo-random values in [-2, 2) with granularity 2^-12.
void init_random_values(uint64_t n, double* v) {
  const double half_range = double(UINT64_C(1) << 13);
  const double scale = double(UINT64_C(1) << 12);
  for (uint64_t k = 0; k < n; ++k) {
    const double q = double(rand() % (UINT64_C(1) << 14));
    v[k] = (q - half_range) / scale;
  }
}
9 |
10 | // Run the benchmark
11 | BENCHMARK_MAIN();
12 |
13 | #undef ARGS
14 | #define ARGS Args({47, 16384})->Args({93, 32768})
15 |
16 | /*
17 | * reim4_vec_mat1col_product
18 | * reim4_vec_mat2col_product
19 | * reim4_vec_mat3col_product
20 | * reim4_vec_mat4col_product
21 | */
22 |
23 | template
25 | void benchmark_reim4_vec_matXcols_product(benchmark::State& state) {
26 | const uint64_t nrows = state.range(0);
27 |
28 | double* u = new double[nrows * 8];
29 | init_random_values(8 * nrows, u);
30 | double* v = new double[nrows * X * 8];
31 | init_random_values(X * 8 * nrows, v);
32 | double* dst = new double[X * 8];
33 |
34 | for (auto _ : state) {
35 | fnc(nrows, dst, u, v);
36 | }
37 |
38 | delete[] dst;
39 | delete[] v;
40 | delete[] u;
41 | }
42 |
43 | #undef ARGS
44 | #define ARGS Arg(128)->Arg(1024)->Arg(4096)
45 |
46 | #ifdef __x86_64__
47 | BENCHMARK(benchmark_reim4_vec_matXcols_product<1, reim4_vec_mat1col_product_avx2>)->ARGS;
48 | // TODO: please remove when fixed:
49 | BENCHMARK(benchmark_reim4_vec_matXcols_product<2, reim4_vec_mat2cols_product_avx2>)->ARGS;
50 | #endif
51 | BENCHMARK(benchmark_reim4_vec_matXcols_product<1, reim4_vec_mat1col_product_ref>)->ARGS;
52 | BENCHMARK(benchmark_reim4_vec_matXcols_product<2, reim4_vec_mat2cols_product_ref>)->ARGS;
53 |
--------------------------------------------------------------------------------
/test/spqlios_svp_product_test.cpp:
--------------------------------------------------------------------------------
#include

#include <cmath>

#include "../spqlios/arithmetic/vec_znx_arithmetic_private.h"
#include "testlib/fft64_dft.h"
#include "testlib/fft64_layouts.h"
#include "testlib/polynomial_vector.h"
7 |
8 | // todo: remove when registered
9 | typedef typeof(fft64_svp_prepare_ref) SVP_PREPARE_F;
10 |
11 | void test_fft64_svp_prepare(SVP_PREPARE_F svp_prepare) {
12 | for (uint64_t n : {2, 4, 8, 64, 128}) {
13 | MODULE* module = new_module_info(n, FFT64);
14 | znx_i64 in = znx_i64::random_log2bound(n, 40);
15 | fft64_svp_ppol_layout out(n);
16 | reim_fft64vec expect = simple_fft64(in);
17 | svp_prepare(module, out.data, in.data());
18 | const double* ed = (double*)expect.data();
19 | const double* ac = (double*)out.data;
20 | for (uint64_t i = 0; i < n; ++i) {
21 | ASSERT_LE(abs(ed[i] - ac[i]), 1e-10) << i << n;
22 | }
23 | delete_module_info(module);
24 | }
25 | }
26 |
// Run the same check against the reference implementation and the dispatched API.
TEST(svp_prepare, fft64_svp_prepare_ref) { test_fft64_svp_prepare(fft64_svp_prepare_ref); }
TEST(svp_prepare, svp_prepare) { test_fft64_svp_prepare(svp_prepare); }
29 |
--------------------------------------------------------------------------------
/test/spqlios_svp_test.cpp:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include "../spqlios/arithmetic/vec_znx_arithmetic_private.h"
4 | #include "testlib/fft64_dft.h"
5 | #include "testlib/fft64_layouts.h"
6 | #include "testlib/polynomial_vector.h"
7 |
// Checks svp_apply_dft: multiplying a prepared scalar polynomial with each
// slice of an input vector must match the product computed via simple_fft64,
// within the expected floating-point precision. Also verifies that the
// inputs are not modified (content hashes).
void test_fft64_svp_apply_dft(SVP_APPLY_DFT_F svp) {
  for (uint64_t n : {2, 4, 8, 64, 128}) {
    MODULE* module = new_module_info(n, FFT64);
    // poly 1 to multiply - create and prepare
    fft64_svp_ppol_layout ppol(n);
    ppol.fill_random(1.);
    for (uint64_t sa : {3, 5, 8}) {    // input vector sizes
      for (uint64_t sr : {3, 5, 8}) {  // result vector sizes
        uint64_t a_sl = n + uniform_u64_bits(2);  // slightly over-sized stride
        // poly 2 to multiply
        znx_vec_i64_layout a(n, sa, a_sl);
        a.fill_random(19);
        // original operation result
        fft64_vec_znx_dft_layout res(n, sr);
        thash hash_a_before = a.content_hash();
        thash hash_ppol_before = ppol.content_hash();
        svp(module, res.data, sr, ppol.data, a.data(), sa, a_sl);
        // inputs must be left untouched
        ASSERT_EQ(a.content_hash(), hash_a_before);
        ASSERT_EQ(ppol.content_hash(), hash_ppol_before);
        // create expected value
        reim_fft64vec ppo = ppol.get_copy();
        // NOTE(review): element type lost in extraction — likely
        // std::vector<reim_fft64vec>; restore from the repository.
        std::vector expect(sr);
        for (uint64_t i = 0; i < sr; ++i) {
          expect[i] = ppo * simple_fft64(a.get_copy_zext(i));
        }
        // this is the largest precision we can safely expect
        double prec_expect = n * pow(2., 19 - 52);
        for (uint64_t i = 0; i < sr; ++i) {
          reim_fft64vec actual = res.get_copy_zext(i);
          ASSERT_LE(infty_dist(actual, expect[i]), prec_expect);
        }
      }
    }

    delete_module_info(module);
  }
}
45 |
// Run the same check against the dispatched API and the reference implementation.
TEST(fft64_svp_apply_dft, svp_apply_dft) { test_fft64_svp_apply_dft(svp_apply_dft); }
TEST(fft64_svp_apply_dft, fft64_svp_apply_dft_ref) { test_fft64_svp_apply_dft(fft64_svp_apply_dft_ref); }
48 |
--------------------------------------------------------------------------------
/test/spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "spqlios/arithmetic/vec_rnx_arithmetic_private.h"
3 | #include "testlib/vec_rnx_layout.h"
4 |
5 | static void test_rnx_approxdecomp(RNX_APPROXDECOMP_FROM_TNXDBL_F approxdec) {
6 | for (const uint64_t nn : {2, 4, 8, 32}) {
7 | MOD_RNX* module = new_rnx_module_info(nn, FFT64);
8 | for (const uint64_t ell : {1, 2, 7}) {
9 | for (const uint64_t k : {2, 5}) {
10 | TNXDBL_APPROXDECOMP_GADGET* gadget = new_tnxdbl_approxdecomp_gadget(module, k, ell);
11 | for (const uint64_t res_size : {ell, ell - 1, ell + 1}) {
12 | const uint64_t res_sl = nn + uniform_u64_bits(2);
13 | rnx_vec_f64_layout in(nn, 1, nn);
14 | in.fill_random(3);
15 | rnx_vec_f64_layout out(nn, res_size, res_sl);
16 | approxdec(module, gadget, out.data(), res_size, res_sl, in.data());
17 | // reconstruct the output
18 | uint64_t msize = std::min(res_size, ell);
19 | double err_bnd = msize == ell ? pow(2., -double(msize * k) - 1) : pow(2., -double(msize * k));
20 | for (uint64_t j = 0; j < nn; ++j) {
21 | double in_j = in.data()[j];
22 | double out_j = 0;
23 | for (uint64_t i = 0; i < res_size; ++i) {
24 | out_j += out.get_copy(i).get_coeff(j) * pow(2., -double((i + 1) * k));
25 | }
26 | double err = out_j - in_j;
27 | double err_abs = fabs(err - rint(err));
28 | ASSERT_LE(err_abs, err_bnd);
29 | }
30 | }
31 | delete_tnxdbl_approxdecomp_gadget(gadget);
32 | }
33 | }
34 | delete_rnx_module_info(module);
35 | }
36 | }
37 |
38 | TEST(vec_rnx, rnx_approxdecomp) { test_rnx_approxdecomp(rnx_approxdecomp_from_tnxdbl); }
39 | TEST(vec_rnx, rnx_approxdecomp_ref) { test_rnx_approxdecomp(rnx_approxdecomp_from_tnxdbl_ref); }
40 | #ifdef __x86_64__
41 | TEST(vec_rnx, rnx_approxdecomp_avx) { test_rnx_approxdecomp(rnx_approxdecomp_from_tnxdbl_avx); }
42 | #endif
43 |
--------------------------------------------------------------------------------
/test/spqlios_vec_rnx_ppol_test.cpp:
--------------------------------------------------------------------------------
#include <gtest/gtest.h>
2 |
3 | #include "spqlios/arithmetic/vec_rnx_arithmetic_private.h"
4 | #include "spqlios/reim/reim_fft.h"
5 | #include "test/testlib/vec_rnx_layout.h"
6 |
7 | static void test_vec_rnx_svp_prepare(RNX_SVP_PREPARE_F* rnx_svp_prepare, BYTES_OF_RNX_SVP_PPOL_F* tmp_bytes) {
8 | for (uint64_t n : {2, 4, 8, 64}) {
9 | MOD_RNX* mod = new_rnx_module_info(n, FFT64);
10 | const double invm = 1. / mod->m;
11 |
12 | rnx_f64 in = rnx_f64::random_log2bound(n, 40);
13 | rnx_f64 in_divide_by_m = rnx_f64::zero(n);
14 | for (uint64_t i = 0; i < n; ++i) {
15 | in_divide_by_m.set_coeff(i, in.get_coeff(i) * invm);
16 | }
17 | fft64_rnx_svp_ppol_layout out(n);
18 | reim_fft64vec expect = simple_fft64(in_divide_by_m);
19 | rnx_svp_prepare(mod, out.data, in.data());
20 | const double* ed = (double*)expect.data();
21 | const double* ac = (double*)out.data;
22 | for (uint64_t i = 0; i < n; ++i) {
23 | ASSERT_LE(abs(ed[i] - ac[i]), 1e-10) << i << n;
24 | }
25 | delete_rnx_module_info(mod);
26 | }
27 | }
28 | TEST(vec_rnx, vec_rnx_svp_prepare) { test_vec_rnx_svp_prepare(rnx_svp_prepare, bytes_of_rnx_svp_ppol); }
29 | TEST(vec_rnx, vec_rnx_svp_prepare_ref) {
30 | test_vec_rnx_svp_prepare(fft64_rnx_svp_prepare_ref, fft64_bytes_of_rnx_svp_ppol);
31 | }
32 |
33 | static void test_vec_rnx_svp_apply(RNX_SVP_APPLY_F* apply) {
34 | for (uint64_t n : {2, 4, 8, 64, 128}) {
35 | MOD_RNX* mod = new_rnx_module_info(n, FFT64);
36 |
37 | // poly 1 to multiply - create and prepare
38 | fft64_rnx_svp_ppol_layout ppol(n);
39 | ppol.fill_random(1.);
40 | for (uint64_t sa : {3, 5, 8}) {
41 | for (uint64_t sr : {3, 5, 8}) {
42 | uint64_t a_sl = n + uniform_u64_bits(2);
43 | uint64_t r_sl = n + uniform_u64_bits(2);
44 | // poly 2 to multiply
45 | rnx_vec_f64_layout a(n, sa, a_sl);
46 | a.fill_random(19);
47 |
48 | // original operation result
49 | rnx_vec_f64_layout res(n, sr, r_sl);
50 | thash hash_a_before = a.content_hash();
51 | thash hash_ppol_before = ppol.content_hash();
52 | apply(mod, res.data(), sr, r_sl, ppol.data, a.data(), sa, a_sl);
53 | ASSERT_EQ(a.content_hash(), hash_a_before);
54 | ASSERT_EQ(ppol.content_hash(), hash_ppol_before);
55 | // create expected value
56 | reim_fft64vec ppo = ppol.get_copy();
57 | std::vector expect(sr);
58 | for (uint64_t i = 0; i < sr; ++i) {
59 | expect[i] = simple_ifft64(ppo * simple_fft64(a.get_copy_zext(i)));
60 | }
61 | // this is the largest precision we can safely expect
62 | double prec_expect = n * pow(2., 19 - 50);
63 | for (uint64_t i = 0; i < sr; ++i) {
64 | rnx_f64 actual = res.get_copy_zext(i);
65 | ASSERT_LE(infty_dist(actual, expect[i]), prec_expect);
66 | }
67 | }
68 | }
69 | delete_rnx_module_info(mod);
70 | }
71 | }
72 | TEST(vec_rnx, vec_rnx_svp_apply) { test_vec_rnx_svp_apply(rnx_svp_apply); }
73 | TEST(vec_rnx, vec_rnx_svp_apply_ref) { test_vec_rnx_svp_apply(fft64_rnx_svp_apply_ref); }
74 |
--------------------------------------------------------------------------------
/test/spqlios_zn_approxdecomp_test.cpp:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "spqlios/arithmetic/zn_arithmetic_private.h"
3 | #include "testlib/test_commons.h"
4 |
5 | template
6 | static void test_tndbl_approxdecomp( //
7 | void (*approxdec)(const MOD_Z*, const TNDBL_APPROXDECOMP_GADGET*, INTTYPE*, uint64_t, const double*, uint64_t) //
8 | ) {
9 | for (const uint64_t nn : {1, 3, 8, 51}) {
10 | MOD_Z* module = new_z_module_info(DEFAULT);
11 | for (const uint64_t ell : {1, 2, 7}) {
12 | for (const uint64_t k : {2, 5}) {
13 | TNDBL_APPROXDECOMP_GADGET* gadget = new_tndbl_approxdecomp_gadget(module, k, ell);
14 | for (const uint64_t res_size : {ell * nn}) {
15 | std::vector in(nn);
16 | std::vector out(res_size);
17 | for (double& x : in) x = uniform_f64_bounds(-10, 10);
18 | approxdec(module, gadget, out.data(), res_size, in.data(), nn);
19 | // reconstruct the output
20 | double err_bnd = pow(2., -double(ell * k) - 1);
21 | for (uint64_t j = 0; j < nn; ++j) {
22 | double in_j = in[j];
23 | double out_j = 0;
24 | for (uint64_t i = 0; i < ell; ++i) {
25 | out_j += out[ell * j + i] * pow(2., -double((i + 1) * k));
26 | }
27 | double err = out_j - in_j;
28 | double err_abs = fabs(err - rint(err));
29 | ASSERT_LE(err_abs, err_bnd);
30 | }
31 | }
32 | delete_tndbl_approxdecomp_gadget(gadget);
33 | }
34 | }
35 | delete_z_module_info(module);
36 | }
37 | }
38 |
39 | TEST(vec_rnx, i8_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(i8_approxdecomp_from_tndbl); }
40 | TEST(vec_rnx, default_i8_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(default_i8_approxdecomp_from_tndbl_ref); }
41 |
42 | TEST(vec_rnx, i16_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(i16_approxdecomp_from_tndbl); }
43 | TEST(vec_rnx, default_i16_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(default_i16_approxdecomp_from_tndbl_ref); }
44 |
45 | TEST(vec_rnx, i32_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(i32_approxdecomp_from_tndbl); }
46 | TEST(vec_rnx, default_i32_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(default_i32_approxdecomp_from_tndbl_ref); }
47 |
--------------------------------------------------------------------------------
/test/spqlios_zn_conversions_test.cpp:
--------------------------------------------------------------------------------
#include <gtest/gtest.h>
#include <cmath>
3 |
4 | #include "testlib/test_commons.h"
5 |
6 | template
7 | static void test_conv(void (*conv_f)(const MOD_Z*, DST_T* res, uint64_t res_size, const SRC_T* a, uint64_t a_size),
8 | DST_T (*ideal_conv_f)(SRC_T x), SRC_T (*random_f)()) {
9 | MOD_Z* module = new_z_module_info(DEFAULT);
10 | for (uint64_t a_size : {0, 1, 2, 42}) {
11 | for (uint64_t res_size : {0, 1, 2, 42}) {
12 | for (uint64_t trials = 0; trials < 100; ++trials) {
13 | std::vector a(a_size);
14 | std::vector res(res_size);
15 | uint64_t msize = std::min(a_size, res_size);
16 | for (SRC_T& x : a) x = random_f();
17 | conv_f(module, res.data(), res_size, a.data(), a_size);
18 | for (uint64_t i = 0; i < msize; ++i) {
19 | DST_T expect = ideal_conv_f(a[i]);
20 | DST_T actual = res[i];
21 | ASSERT_EQ(expect, actual);
22 | }
23 | for (uint64_t i = msize; i < res_size; ++i) {
24 | DST_T expect = 0;
25 | SRC_T actual = res[i];
26 | ASSERT_EQ(expect, actual);
27 | }
28 | }
29 | }
30 | }
31 | delete_z_module_info(module);
32 | }
33 |
// Reference conversion: reduce a double modulo 1 (centered around 0), then
// scale by 2^32 and round to the nearest 32-bit torus element.
static int32_t ideal_dbl_to_tn32(double a) {
  const double two_pow_32 = INT64_C(1) << 32;
  const double centered = a - rint(a);
  const int64_t scaled = rint(centered * two_pow_32);
  return int32_t(scaled);
}
40 |
41 | static double random_f64_10() { return uniform_f64_bounds(-10, 10); }
42 |
43 | static void test_dbl_to_tn32(DBL_TO_TN32_F dbl_to_tn32_f) {
44 | test_conv(dbl_to_tn32_f, ideal_dbl_to_tn32, random_f64_10);
45 | }
46 |
47 | TEST(zn_arithmetic, dbl_to_tn32) { test_dbl_to_tn32(dbl_to_tn32); }
48 | TEST(zn_arithmetic, dbl_to_tn32_ref) { test_dbl_to_tn32(dbl_to_tn32_ref); }
49 |
// Reference conversion: interpret a 32-bit torus element as a real in
// [-1/2, 1/2) by dividing by 2^32.
static double ideal_tn32_to_dbl(int32_t a) {
  const double two_pow_32 = INT64_C(1) << 32;
  return double(a) / two_pow_32;
}
54 |
55 | static int32_t random_t32() { return uniform_i64_bits(32); }
56 |
57 | static void test_tn32_to_dbl(TN32_TO_DBL_F tn32_to_dbl_f) { test_conv(tn32_to_dbl_f, ideal_tn32_to_dbl, random_t32); }
58 |
59 | TEST(zn_arithmetic, tn32_to_dbl) { test_tn32_to_dbl(tn32_to_dbl); }
60 | TEST(zn_arithmetic, tn32_to_dbl_ref) { test_tn32_to_dbl(tn32_to_dbl_ref); }
61 |
62 | static int32_t ideal_dbl_round_to_i32(double a) { return int32_t(rint(a)); }
63 |
64 | static double random_dbl_explaw_18() { return uniform_f64_bounds(-1., 1.) * pow(2., uniform_u64_bits(6) % 19); }
65 |
66 | static void test_dbl_round_to_i32(DBL_ROUND_TO_I32_F dbl_round_to_i32_f) {
67 | test_conv(dbl_round_to_i32_f, ideal_dbl_round_to_i32, random_dbl_explaw_18);
68 | }
69 |
70 | TEST(zn_arithmetic, dbl_round_to_i32) { test_dbl_round_to_i32(dbl_round_to_i32); }
71 | TEST(zn_arithmetic, dbl_round_to_i32_ref) { test_dbl_round_to_i32(dbl_round_to_i32_ref); }
72 |
73 | static double ideal_i32_to_dbl(int32_t a) { return double(a); }
74 |
75 | static int32_t random_i32_explaw_18() { return uniform_i64_bits(uniform_u64_bits(6) % 19); }
76 |
77 | static void test_i32_to_dbl(I32_TO_DBL_F i32_to_dbl_f) {
78 | test_conv(i32_to_dbl_f, ideal_i32_to_dbl, random_i32_explaw_18);
79 | }
80 |
81 | TEST(zn_arithmetic, i32_to_dbl) { test_i32_to_dbl(i32_to_dbl); }
82 | TEST(zn_arithmetic, i32_to_dbl_ref) { test_i32_to_dbl(i32_to_dbl_ref); }
83 |
84 | static int64_t ideal_dbl_round_to_i64(double a) { return rint(a); }
85 |
86 | static double random_dbl_explaw_50() { return uniform_f64_bounds(-1., 1.) * pow(2., uniform_u64_bits(7) % 51); }
87 |
88 | static void test_dbl_round_to_i64(DBL_ROUND_TO_I64_F dbl_round_to_i64_f) {
89 | test_conv(dbl_round_to_i64_f, ideal_dbl_round_to_i64, random_dbl_explaw_50);
90 | }
91 |
92 | TEST(zn_arithmetic, dbl_round_to_i64) { test_dbl_round_to_i64(dbl_round_to_i64); }
93 | TEST(zn_arithmetic, dbl_round_to_i64_ref) { test_dbl_round_to_i64(dbl_round_to_i64_ref); }
94 |
95 | static double ideal_i64_to_dbl(int64_t a) { return double(a); }
96 |
97 | static int64_t random_i64_explaw_50() { return uniform_i64_bits(uniform_u64_bits(7) % 51); }
98 |
99 | static void test_i64_to_dbl(I64_TO_DBL_F i64_to_dbl_f) {
100 | test_conv(i64_to_dbl_f, ideal_i64_to_dbl, random_i64_explaw_50);
101 | }
102 |
103 | TEST(zn_arithmetic, i64_to_dbl) { test_i64_to_dbl(i64_to_dbl); }
104 | TEST(zn_arithmetic, i64_to_dbl_ref) { test_i64_to_dbl(i64_to_dbl_ref); }
105 |
--------------------------------------------------------------------------------
/test/spqlios_zn_vmp_test.cpp:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "spqlios/arithmetic/zn_arithmetic_private.h"
3 | #include "testlib/zn_layouts.h"
4 |
5 | static void test_zn_vmp_prepare(ZN32_VMP_PREPARE_CONTIGUOUS_F prep) {
6 | MOD_Z* module = new_z_module_info(DEFAULT);
7 | for (uint64_t nrows : {1, 2, 5, 15}) {
8 | for (uint64_t ncols : {1, 2, 32, 42, 67}) {
9 | std::vector src(nrows * ncols);
10 | zn32_pmat_layout out(nrows, ncols);
11 | for (int32_t& x : src) x = uniform_i64_bits(32);
12 | prep(module, out.data, src.data(), nrows, ncols);
13 | for (uint64_t i = 0; i < nrows; ++i) {
14 | for (uint64_t j = 0; j < ncols; ++j) {
15 | int32_t in = src[i * ncols + j];
16 | int32_t actual = out.get(i, j);
17 | ASSERT_EQ(actual, in);
18 | }
19 | }
20 | }
21 | }
22 | delete_z_module_info(module);
23 | }
24 |
25 | TEST(zn, zn32_vmp_prepare_contiguous) { test_zn_vmp_prepare(zn32_vmp_prepare_contiguous); }
26 | TEST(zn, default_zn32_vmp_prepare_contiguous_ref) { test_zn_vmp_prepare(default_zn32_vmp_prepare_contiguous_ref); }
27 |
28 | template
29 | static void test_zn_vmp_apply(void (*apply)(const MOD_Z*, int32_t*, uint64_t, const INTTYPE*, uint64_t,
30 | const ZN32_VMP_PMAT*, uint64_t, uint64_t)) {
31 | MOD_Z* module = new_z_module_info(DEFAULT);
32 | for (uint64_t nrows : {1, 2, 5, 15}) {
33 | for (uint64_t ncols : {1, 2, 32, 42, 67}) {
34 | for (uint64_t a_size : {1, 2, 5, 15}) {
35 | for (uint64_t res_size : {1, 2, 32, 42, 67}) {
36 | std::vector a(a_size);
37 | zn32_pmat_layout out(nrows, ncols);
38 | std::vector res(res_size);
39 | for (INTTYPE& x : a) x = uniform_i64_bits(32);
40 | out.fill_random();
41 | std::vector expect = vmp_product(a.data(), a_size, res_size, out);
42 | apply(module, res.data(), res_size, a.data(), a_size, out.data, nrows, ncols);
43 | for (uint64_t i = 0; i < res_size; ++i) {
44 | int32_t exp = expect[i];
45 | int32_t actual = res[i];
46 | ASSERT_EQ(actual, exp);
47 | }
48 | }
49 | }
50 | }
51 | }
52 | delete_z_module_info(module);
53 | }
54 |
55 | TEST(zn, zn32_vmp_apply_i32) { test_zn_vmp_apply(zn32_vmp_apply_i32); }
56 | TEST(zn, zn32_vmp_apply_i16) { test_zn_vmp_apply(zn32_vmp_apply_i16); }
57 | TEST(zn, zn32_vmp_apply_i8) { test_zn_vmp_apply(zn32_vmp_apply_i8); }
58 |
59 | TEST(zn, default_zn32_vmp_apply_i32_ref) { test_zn_vmp_apply(default_zn32_vmp_apply_i32_ref); }
60 | TEST(zn, default_zn32_vmp_apply_i16_ref) { test_zn_vmp_apply(default_zn32_vmp_apply_i16_ref); }
61 | TEST(zn, default_zn32_vmp_apply_i8_ref) { test_zn_vmp_apply(default_zn32_vmp_apply_i8_ref); }
62 |
63 | #ifdef __x86_64__
64 | TEST(zn, default_zn32_vmp_apply_i32_avx) { test_zn_vmp_apply(default_zn32_vmp_apply_i32_avx); }
65 | TEST(zn, default_zn32_vmp_apply_i16_avx) { test_zn_vmp_apply(default_zn32_vmp_apply_i16_avx); }
66 | TEST(zn, default_zn32_vmp_apply_i8_avx) { test_zn_vmp_apply(default_zn32_vmp_apply_i8_avx); }
67 | #endif
68 |
--------------------------------------------------------------------------------
/test/spqlios_znx_small_test.cpp:
--------------------------------------------------------------------------------
#include <gtest/gtest.h>
2 |
3 | #include "../spqlios/arithmetic/vec_znx_arithmetic_private.h"
4 | #include "testlib/negacyclic_polynomial.h"
5 |
6 | static void test_znx_small_single_product(ZNX_SMALL_SINGLE_PRODUCT_F product,
7 | ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F product_tmp_bytes) {
8 | for (const uint64_t nn : {2, 4, 8, 64}) {
9 | MODULE* module = new_module_info(nn, FFT64);
10 | znx_i64 a = znx_i64::random_log2bound(nn, 20);
11 | znx_i64 b = znx_i64::random_log2bound(nn, 20);
12 | znx_i64 expect = naive_product(a, b);
13 | znx_i64 actual(nn);
14 | std::vector tmp(znx_small_single_product_tmp_bytes(module));
15 | fft64_znx_small_single_product(module, actual.data(), a.data(), b.data(), tmp.data());
16 | ASSERT_EQ(actual, expect) << actual.get_coeff(0) << " vs. " << expect.get_coeff(0);
17 | delete_module_info(module);
18 | }
19 | }
20 |
21 | TEST(znx_small, fft64_znx_small_single_product) {
22 | test_znx_small_single_product(fft64_znx_small_single_product, fft64_znx_small_single_product_tmp_bytes);
23 | }
24 | TEST(znx_small, znx_small_single_product) {
25 | test_znx_small_single_product(znx_small_single_product, znx_small_single_product_tmp_bytes);
26 | }
27 |
--------------------------------------------------------------------------------
/test/testlib/fft64_dft.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_FFT64_DFT_H
2 | #define SPQLIOS_FFT64_DFT_H
3 |
4 | #include "negacyclic_polynomial.h"
5 | #include "reim4_elem.h"
6 |
7 | class reim_fft64vec {
8 | std::vector v;
9 |
10 | public:
11 | reim_fft64vec() = default;
12 | explicit reim_fft64vec(uint64_t n);
13 | reim_fft64vec(uint64_t n, const double* data);
14 | uint64_t nn() const;
15 | static reim_fft64vec zero(uint64_t n);
16 | /** random complex coefficients (unstructured) */
17 | static reim_fft64vec random(uint64_t n, double log2bound);
18 | /** random fft of a small int polynomial */
19 | static reim_fft64vec dft_random(uint64_t n, uint64_t log2bound);
20 | double* data();
21 | const double* data() const;
22 | void save_as(double* dest) const;
23 | reim4_elem get_blk(uint64_t blk) const;
24 | void set_blk(uint64_t blk, const reim4_elem& value);
25 | };
26 |
/** @brief element-wise arithmetic on fft64 vectors (declarations only) */
reim_fft64vec operator+(const reim_fft64vec& a, const reim_fft64vec& b);
reim_fft64vec operator-(const reim_fft64vec& a, const reim_fft64vec& b);
reim_fft64vec operator*(const reim_fft64vec& a, const reim_fft64vec& b);
reim_fft64vec operator*(double coeff, const reim_fft64vec& v);
reim_fft64vec& operator+=(reim_fft64vec& a, const reim_fft64vec& b);
reim_fft64vec& operator-=(reim_fft64vec& a, const reim_fft64vec& b);

/** infty distance */
double infty_dist(const reim_fft64vec& a, const reim_fft64vec& b);

/** @brief forward/backward fft64 helpers between polynomial and fft domains */
reim_fft64vec simple_fft64(const znx_i64& polynomial);
znx_i64 simple_rint_ifft64(const reim_fft64vec& fftvec);
rnx_f64 naive_ifft64(const reim_fft64vec& fftvec);
reim_fft64vec simple_fft64(const rnx_f64& polynomial);
rnx_f64 simple_ifft64(const reim_fft64vec& v);
43 | #endif // SPQLIOS_FFT64_DFT_H
44 |
--------------------------------------------------------------------------------
/test/testlib/fft64_layouts.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_FFT64_LAYOUTS_H
2 | #define SPQLIOS_FFT64_LAYOUTS_H
3 |
4 | #include "../../spqlios/arithmetic/vec_znx_arithmetic.h"
5 | #include "fft64_dft.h"
6 | #include "negacyclic_polynomial.h"
7 | #include "reim4_elem.h"
8 |
/** @brief test layout for the VEC_ZNX_DFT */
struct fft64_vec_znx_dft_layout {
 public:
  const uint64_t nn;    // ring dimension
  const uint64_t size;  // number of polynomials in the vector
  VEC_ZNX_DFT* const data;  // backing storage (a destructor is declared below)
  reim_vector_view view;
  /** @brief fill with random double values (unstructured) */
  void fill_random(double log2bound);
  /** @brief fill with random ffts of small int polynomials */
  void fill_dft_random(uint64_t log2bound);
  reim4_elem get(uint64_t idx, uint64_t blk) const;
  reim4_elem get_zext(uint64_t idx, uint64_t blk) const;
  void set(uint64_t idx, uint64_t blk, const reim4_elem& value);
  fft64_vec_znx_dft_layout(uint64_t n, uint64_t size);
  void fill_random_log2bound(uint64_t bits);
  void fill_dft_random_log2bound(uint64_t bits);
  double* get_addr(uint64_t idx);
  const double* get_addr(uint64_t idx) const;
  reim_fft64vec get_copy_zext(uint64_t idx) const;
  void set(uint64_t idx, const reim_fft64vec& value);
  /** @brief hash of the contents, used to detect unwanted input mutation */
  thash content_hash() const;
  ~fft64_vec_znx_dft_layout();
};

/** @brief test layout for the VEC_ZNX_BIG */
class fft64_vec_znx_big_layout {
 public:
  const uint64_t nn;    // ring dimension
  const uint64_t size;  // number of polynomials in the vector
  VEC_ZNX_BIG* const data;
  fft64_vec_znx_big_layout(uint64_t n, uint64_t size);
  void fill_random();
  znx_i64 get_copy(uint64_t index) const;
  znx_i64 get_copy_zext(uint64_t index) const;
  void set(uint64_t index, const znx_i64& value);
  /** @brief hash of the contents, used to detect unwanted input mutation */
  thash content_hash() const;
  ~fft64_vec_znx_big_layout();
};
48 |
/** @brief test layout for the VMP_PMAT */
class fft64_vmp_pmat_layout {
 public:
  const uint64_t nn;     // ring dimension
  const uint64_t nrows;  // matrix rows
  const uint64_t ncols;  // matrix columns
  VMP_PMAT* const data;
  fft64_vmp_pmat_layout(uint64_t n, uint64_t nrows, uint64_t ncols);
  double* get_addr(uint64_t row, uint64_t col, uint64_t blk) const;
  reim4_elem get(uint64_t row, uint64_t col, uint64_t blk) const;
  /** @brief hash of the contents, used to detect unwanted input mutation */
  thash content_hash() const;
  reim4_elem get_zext(uint64_t row, uint64_t col, uint64_t blk) const;
  reim_fft64vec get_zext(uint64_t row, uint64_t col) const;
  void set(uint64_t row, uint64_t col, uint64_t blk, const reim4_elem& v) const;
  void set(uint64_t row, uint64_t col, const reim_fft64vec& value);
  /** @brief fill with random double values (unstructured) */
  void fill_random(double log2bound);
  /** @brief fill with random ffts of small int polynomials */
  void fill_dft_random(uint64_t log2bound);
  ~fft64_vmp_pmat_layout();
};

/** @brief test layout for the SVP_PPOL */
class fft64_svp_ppol_layout {
 public:
  const uint64_t nn;  // ring dimension
  SVP_PPOL* const data;
  fft64_svp_ppol_layout(uint64_t n);
  /** @brief hash of the contents, used to detect unwanted input mutation */
  thash content_hash() const;
  reim_fft64vec get_copy() const;
  void set(const reim_fft64vec&);
  /** @brief fill with random double values (unstructured) */
  void fill_random(double log2bound);
  /** @brief fill with random ffts of small int polynomials */
  void fill_dft_random(uint64_t log2bound);
  ~fft64_svp_ppol_layout();
};
86 |
/** @brief test layout for the CNV_PVEC_L */
// NOTE(review): all members, including the constructor, are private (no
// public: section) -- presumably this layout is not instantiated yet;
// confirm before use.
class fft64_cnv_left_layout {
  const uint64_t nn;    // ring dimension
  const uint64_t size;  // number of polynomials
  CNV_PVEC_L* const data;
  fft64_cnv_left_layout(uint64_t n, uint64_t size);
  reim4_elem get(uint64_t idx, uint64_t blk);
  thash content_hash() const;
  ~fft64_cnv_left_layout();
};

/** @brief test layout for the CNV_PVEC_R */
// NOTE(review): same remark as fft64_cnv_left_layout -- everything private.
class fft64_cnv_right_layout {
  const uint64_t nn;    // ring dimension
  const uint64_t size;  // number of polynomials
  CNV_PVEC_R* const data;
  fft64_cnv_right_layout(uint64_t n, uint64_t size);
  reim4_elem get(uint64_t idx, uint64_t blk);
  thash content_hash() const;
  ~fft64_cnv_right_layout();
};
108 |
109 | #endif // SPQLIOS_FFT64_LAYOUTS_H
110 |
--------------------------------------------------------------------------------
/test/testlib/mod_q120.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_MOD_Q120_H
2 | #define SPQLIOS_MOD_Q120_H
3 |
#include <iostream>
5 |
6 | #include "../../spqlios/q120/q120_common.h"
7 | #include "test_commons.h"
8 |
/** @brief centered modulo q: result in (-q/2, q/2] */
int64_t centermod(int64_t v, int64_t q);
int64_t centermod(uint64_t v, int64_t q);

/** @brief this class represents an integer mod Q120, stored in CRT form
 * as one residue per prime Q1..Q4 */
class mod_q120 {
 public:
  static constexpr int64_t Qi[] = {Q1, Q2, Q3, Q4};
  int64_t a[4];  // residues modulo Q1, Q2, Q3, Q4 respectively
  mod_q120(int64_t a1, int64_t a2, int64_t a3, int64_t a4);
  mod_q120();
  /** @brief CRT-lift to a single 128-bit integer */
  __int128_t to_int128() const;
  // converters from/to the three hardware storage layouts (a, b, c)
  static mod_q120 from_q120a(const void* addr);
  static mod_q120 from_q120b(const void* addr);
  static mod_q120 from_q120c(const void* addr);
  void save_as_q120a(void* dest) const;
  void save_as_q120b(void* dest) const;
  void save_as_q120c(void* dest) const;
};

/** @brief component-wise modular arithmetic on the CRT residues */
mod_q120 operator+(const mod_q120& x, const mod_q120& y);
mod_q120 operator-(const mod_q120& x, const mod_q120& y);
mod_q120 operator*(const mod_q120& x, const mod_q120& y);
mod_q120& operator+=(mod_q120& x, const mod_q120& y);
mod_q120& operator-=(mod_q120& x, const mod_q120& y);
mod_q120& operator*=(mod_q120& x, const mod_q120& y);
std::ostream& operator<<(std::ostream& out, const mod_q120& x);
bool operator==(const mod_q120& x, const mod_q120& y);
/** @brief x^k (k may be negative) */
mod_q120 pow(const mod_q120& x, int32_t k);
/** @brief x/2 mod Q120 */
mod_q120 half(const mod_q120& x);

/** @brief a uniformly drawn number mod Q120 */
mod_q120 uniform_q120();
/** @brief a uniformly random mod Q120 layout A (4 integers < 2^32) */
void uniform_q120a(void* dest);
/** @brief a uniformly random mod Q120 layout B (4 integers < 2^64) */
void uniform_q120b(void* dest);
/** @brief a uniformly random mod Q120 layout C (4 integers repr. x,2^32x) */
void uniform_q120c(void* dest);
--------------------------------------------------------------------------------
/test/testlib/negacyclic_polynomial.cpp:
--------------------------------------------------------------------------------
1 | #include "negacyclic_polynomial_impl.h"
2 |
3 | // explicit instantiation
4 | EXPLICIT_INSTANTIATE_POLYNOMIAL(__int128_t);
5 | EXPLICIT_INSTANTIATE_POLYNOMIAL(int64_t);
6 | EXPLICIT_INSTANTIATE_POLYNOMIAL(double);
7 |
8 | double infty_dist(const rnx_f64& a, const rnx_f64& b) {
9 | const uint64_t nn = a.nn();
10 | const double* aa = a.data();
11 | const double* bb = b.data();
12 | double res = 0.;
13 | for (uint64_t i = 0; i < nn; ++i) {
14 | double d = fabs(aa[i] - bb[i]);
15 | if (d > res) res = d;
16 | }
17 | return res;
18 | }
19 |
--------------------------------------------------------------------------------
/test/testlib/negacyclic_polynomial.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_NEGACYCLIC_POLYNOMIAL_H
2 | #define SPQLIOS_NEGACYCLIC_POLYNOMIAL_H
3 |
#include <vector>
5 |
6 | #include "test_commons.h"
7 |
// BUGFIX: every `template <typename T>` header and most template arguments in
// this header had been lost in extraction; restored them (the surviving
// typedef polynomial<__int128_t> fixes the intended parameterization).
template <typename T>
class polynomial;
typedef polynomial<__int128_t> znx_i128;
typedef polynomial<int64_t> znx_i64;
typedef polynomial<double> rnx_f64;

/** @brief test-side negacyclic polynomial with coefficients of type T */
template <typename T>
class polynomial {
 public:
  std::vector<T> coeffs;
  /** @brief create a polynomial out of existing coeffs */
  polynomial(uint64_t N, const T* c);
  /** @brief zero polynomial of dimension N */
  explicit polynomial(uint64_t N);
  /** @brief empty polynomial (dim 0) */
  polynomial();

  /** @brief ring dimension */
  uint64_t nn() const;
  /** @brief special setter (accept any indexes, and does the negacyclic translation) */
  void set_coeff(int64_t i, T v);
  /** @brief special getter (accept any indexes, and does the negacyclic translation) */
  T get_coeff(int64_t i) const;
  /** @brief returns the coefficient layout */
  T* data();
  /** @brief returns the coefficient layout (const version) */
  const T* data() const;
  /** @brief saves to the layout */
  void save_as(T* dest) const;
  /** @brief zero */
  static polynomial zero(uint64_t n);
  /** @brief random polynomial with coefficients in [-2^log2bounds, 2^log2bounds]*/
  static polynomial random_log2bound(uint64_t n, uint64_t log2bound);
  /** @brief random polynomial with coefficients in [-2^log2bounds, 2^log2bounds]*/
  static polynomial random(uint64_t n);
  /** @brief random polynomial with coefficient in [lb;ub] */
  static polynomial random_bound(uint64_t n, const T lb, const T ub);
};

/** @brief equality operator (used during tests) */
template <typename T>
bool operator==(const polynomial<T>& a, const polynomial<T>& b);

/** @brief addition operator (used during tests) */
template <typename T>
polynomial<T> operator+(const polynomial<T>& a, const polynomial<T>& b);

/** @brief subtraction operator (used during tests) */
template <typename T>
polynomial<T> operator-(const polynomial<T>& a, const polynomial<T>& b);

/** @brief negation operator (used during tests) */
template <typename T>
polynomial<T> operator-(const polynomial<T>& a);

/** @brief naive negacyclic product (used during tests) */
template <typename T>
polynomial<T> naive_product(const polynomial<T>& a, const polynomial<T>& b);

/** @brief distance between two real polynomials (used during tests) */
double infty_dist(const rnx_f64& a, const rnx_f64& b);
68 |
69 | #endif // SPQLIOS_NEGACYCLIC_POLYNOMIAL_H
70 |
--------------------------------------------------------------------------------
/test/testlib/ntt120_dft.cpp:
--------------------------------------------------------------------------------
1 | #include "ntt120_dft.h"
2 |
3 | #include "mod_q120.h"
4 |
5 | // @brief alternative version of the NTT
6 |
/** for all s=k/2^17, root_of_unity(s) = omega_0^k */
// omega_2pow17 is a primitive 2^17-th root of unity mod Q120; the fractional
// power s is rescaled to an integer exponent before calling pow.
static mod_q120 root_of_unity(double s) {
  static mod_q120 omega_2pow17{OMEGA1, OMEGA2, OMEGA3, OMEGA4};
  static double _2pow17 = 1 << 17;
  return pow(omega_2pow17, s * _2pow17);
}
/** inverse twiddle: root_of_unity_inv(s) = root_of_unity(s)^-1 */
static mod_q120 root_of_unity_inv(double s) {
  static mod_q120 omega_2pow17{OMEGA1, OMEGA2, OMEGA3, OMEGA4};
  static double _2pow17 = 1 << 17;
  return pow(omega_2pow17, -s * _2pow17);
}
18 |
/** recursive naive ntt */
// Decimation butterfly: combine data[j] and data[h+j] with the twiddle
// root_of_unity(s), then recurse into both halves. entry_pwr tracks the
// fractional exponent (k/2^t) of the twiddle at the current level.
static void q120_ntt_naive_rec(uint64_t n, double entry_pwr, mod_q120* data) {
  if (n == 1) return;  // size-1 transform is the identity
  const uint64_t h = n / 2;
  const double s = entry_pwr / 2.;
  mod_q120 om = root_of_unity(s);
  for (uint64_t j = 0; j < h; ++j) {
    mod_q120 om_right = data[h + j] * om;
    data[h + j] = data[j] - om_right;
    data[j] = data[j] + om_right;
  }
  q120_ntt_naive_rec(h, s, data);
  q120_ntt_naive_rec(h, s + 0.5, data + h);
}
// Inverse of q120_ntt_naive_rec: recurse first, then undo the butterfly with
// the inverse twiddle and a halving, so that intt(ntt(x)) == x.
static void q120_intt_naive_rec(uint64_t n, double entry_pwr, mod_q120* data) {
  if (n == 1) return;
  const uint64_t h = n / 2;
  const double s = entry_pwr / 2.;
  q120_intt_naive_rec(h, s, data);
  q120_intt_naive_rec(h, s + 0.5, data + h);
  mod_q120 om = root_of_unity_inv(s);
  for (uint64_t j = 0; j < h; ++j) {
    mod_q120 dat_diff = half(data[j] - data[h + j]);
    data[j] = half(data[j] + data[h + j]);
    data[h + j] = dat_diff * om;
  }
}
46 |
/** user friendly version */
// Forward NTT of a small integer polynomial: each coefficient is embedded
// into CRT form (same residue mod each prime), then transformed in place.
q120_nttvec simple_ntt120(const znx_i64& polynomial) {
  const uint64_t n = polynomial.nn();
  q120_nttvec res(n);
  for (uint64_t i = 0; i < n; ++i) {
    int64_t xi = polynomial.get_coeff(i);
    res.v[i] = mod_q120(xi, xi, xi, xi);
  }
  q120_ntt_naive_rec(n, 0.5, res.v.data());
  return res;
}

// Inverse NTT: transforms a copy back and CRT-lifts each value to a 128-bit
// coefficient. Inverse of simple_ntt120 (same entry power 0.5).
znx_i128 simple_intt120(const q120_nttvec& fftvec) {
  const uint64_t n = fftvec.nn();
  q120_nttvec copy = fftvec;
  znx_i128 res(n);
  q120_intt_naive_rec(n, 0.5, copy.v.data());
  for (uint64_t i = 0; i < n; ++i) {
    res.set_coeff(i, copy.v[i].to_int128());
  }
  return res;
}
69 | bool operator==(const q120_nttvec& a, const q120_nttvec& b) { return a.v == b.v; }
70 |
71 | std::vector q120_ntt_naive(const std::vector& x) {
72 | std::vector res = x;
73 | q120_ntt_naive_rec(res.size(), 0.5, res.data());
74 | return res;
75 | }
// Zero-initialized vector of n values mod Q120.
q120_nttvec::q120_nttvec(uint64_t n) : v(n) {}
// Deserializes n values from the q120b layout (4 int64 words per value).
q120_nttvec::q120_nttvec(uint64_t n, const q120b* data) : v(n) {
  int64_t* d = (int64_t*)data;
  for (uint64_t i = 0; i < n; ++i) {
    v[i] = mod_q120::from_q120b(d + 4 * i);
  }
}
// Deserializes n values from the q120c layout (4 int64 words per value).
q120_nttvec::q120_nttvec(uint64_t n, const q120c* data) : v(n) {
  int64_t* d = (int64_t*)data;
  for (uint64_t i = 0; i < n; ++i) {
    v[i] = mod_q120::from_q120c(d + 4 * i);
  }
}
// Number of values (the ring dimension).
uint64_t q120_nttvec::nn() const { return v.size(); }
q120_nttvec q120_nttvec::zero(uint64_t n) { return q120_nttvec(n); }
// Serializers to the three hardware layouts: 4 words per value each.
void q120_nttvec::save_as(q120a* dest) const {
  int64_t* const d = (int64_t*)dest;
  const uint64_t n = nn();
  for (uint64_t i = 0; i < n; ++i) {
    v[i].save_as_q120a(d + 4 * i);
  }
}
void q120_nttvec::save_as(q120b* dest) const {
  int64_t* const d = (int64_t*)dest;
  const uint64_t n = nn();
  for (uint64_t i = 0; i < n; ++i) {
    v[i].save_as_q120b(d + 4 * i);
  }
}
void q120_nttvec::save_as(q120c* dest) const {
  int64_t* const d = (int64_t*)dest;
  const uint64_t n = nn();
  for (uint64_t i = 0; i < n; ++i) {
    v[i].save_as_q120c(d + 4 * i);
  }
}
// Bounds-checked access to one value.
mod_q120 q120_nttvec::get_blk(uint64_t blk) const {
  REQUIRE_DRAMATICALLY(blk < nn(), "blk overflow");
  return v[blk];
}
// Vector of n uniformly random values mod Q120.
q120_nttvec q120_nttvec::random(uint64_t n) {
  q120_nttvec res(n);
  for (uint64_t i = 0; i < n; ++i) {
    res.v[i] = uniform_q120();
  }
  return res;
}
123 |
--------------------------------------------------------------------------------
/test/testlib/ntt120_dft.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_NTT120_DFT_H
2 | #define SPQLIOS_NTT120_DFT_H
3 |
4 | #include
5 |
6 | #include "../../spqlios/q120/q120_arithmetic.h"
7 | #include "mod_q120.h"
8 | #include "negacyclic_polynomial.h"
9 | #include "test_commons.h"
10 |
11 | class q120_nttvec {
12 | public:
13 | std::vector v;
14 | q120_nttvec() = default;
15 | explicit q120_nttvec(uint64_t n);
16 | q120_nttvec(uint64_t n, const q120b* data);
17 | q120_nttvec(uint64_t n, const q120c* data);
18 | uint64_t nn() const;
19 | static q120_nttvec zero(uint64_t n);
20 | static q120_nttvec random(uint64_t n);
21 | void save_as(q120a* dest) const;
22 | void save_as(q120b* dest) const;
23 | void save_as(q120c* dest) const;
24 | mod_q120 get_blk(uint64_t blk) const;
25 | };
26 |
27 | q120_nttvec simple_ntt120(const znx_i64& polynomial);
28 | znx_i128 simple_intt120(const q120_nttvec& fftvec);
29 | bool operator==(const q120_nttvec& a, const q120_nttvec& b);
30 |
31 | #endif // SPQLIOS_NTT120_DFT_H
32 |
--------------------------------------------------------------------------------
/test/testlib/ntt120_layouts.cpp:
--------------------------------------------------------------------------------
1 | #include "ntt120_layouts.h"
2 |
3 | mod_q120x2::mod_q120x2() {}
4 | mod_q120x2::mod_q120x2(const mod_q120& a, const mod_q120& b) {
5 | value[0] = a;
6 | value[1] = b;
7 | }
8 | mod_q120x2::mod_q120x2(q120x2b* addr) {
9 | uint64_t* p = (uint64_t*)addr;
10 | value[0] = mod_q120::from_q120b(p);
11 | value[1] = mod_q120::from_q120b(p + 4);
12 | }
13 |
14 | ntt120_vec_znx_dft_layout::ntt120_vec_znx_dft_layout(uint64_t n, uint64_t size)
15 | : nn(n), //
16 | size(size), //
17 | data((VEC_ZNX_DFT*)alloc64(n * size * 4 * sizeof(uint64_t))) {}
18 |
19 | mod_q120x2 ntt120_vec_znx_dft_layout::get_copy_zext(uint64_t idx, uint64_t blk) {
20 | return mod_q120x2(get_blk(idx, blk));
21 | }
22 | q120x2b* ntt120_vec_znx_dft_layout::get_blk(uint64_t idx, uint64_t blk) {
23 | REQUIRE_DRAMATICALLY(idx < size, "idx overflow");
24 | REQUIRE_DRAMATICALLY(blk < nn / 2, "blk overflow");
25 | uint64_t* d = (uint64_t*)data;
26 | return (q120x2b*)(d + 4 * nn * idx + 8 * blk);
27 | }
28 | ntt120_vec_znx_dft_layout::~ntt120_vec_znx_dft_layout() { spqlios_free(data); }
29 | q120_nttvec ntt120_vec_znx_dft_layout::get_copy_zext(uint64_t idx) {
30 | int64_t* d = (int64_t*)data;
31 | if (idx < size) {
32 | return q120_nttvec(nn, (q120b*)(d + idx * nn * 4));
33 | } else {
34 | return q120_nttvec::zero(nn);
35 | }
36 | }
37 | void ntt120_vec_znx_dft_layout::set(uint64_t idx, const q120_nttvec& value) {
38 | REQUIRE_DRAMATICALLY(idx < size, "index overflow: " << idx << " / " << size);
39 | q120b* dest_addr = (q120b*)((int64_t*)data + idx * nn * 4);
40 | value.save_as(dest_addr);
41 | }
42 | void ntt120_vec_znx_dft_layout::fill_random() {
43 | for (uint64_t i = 0; i < size; ++i) {
44 | set(i, q120_nttvec::random(nn));
45 | }
46 | }
47 | thash ntt120_vec_znx_dft_layout::content_hash() const { return test_hash(data, nn * size * 4 * sizeof(int64_t)); }
48 | ntt120_vec_znx_big_layout::ntt120_vec_znx_big_layout(uint64_t n, uint64_t size)
49 | : nn(n), //
50 | size(size),
51 | data((VEC_ZNX_BIG*)alloc64(n * size * sizeof(__int128_t))) {}
52 |
53 | znx_i128 ntt120_vec_znx_big_layout::get_copy(uint64_t index) const { return znx_i128(nn, get_addr(index)); }
54 | znx_i128 ntt120_vec_znx_big_layout::get_copy_zext(uint64_t index) const {
55 | if (index < size) {
56 | return znx_i128(nn, get_addr(index));
57 | } else {
58 | return znx_i128::zero(nn);
59 | }
60 | }
61 | __int128* ntt120_vec_znx_big_layout::get_addr(uint64_t index) const {
62 | REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size);
63 | return (__int128_t*)data + index * nn;
64 | }
65 | void ntt120_vec_znx_big_layout::set(uint64_t index, const znx_i128& value) { value.save_as(get_addr(index)); }
66 | ntt120_vec_znx_big_layout::~ntt120_vec_znx_big_layout() { spqlios_free(data); }
67 |
--------------------------------------------------------------------------------
/test/testlib/ntt120_layouts.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_NTT120_LAYOUTS_H
2 | #define SPQLIOS_NTT120_LAYOUTS_H
3 |
4 | #include "../../spqlios/arithmetic/vec_znx_arithmetic.h"
5 | #include "mod_q120.h"
6 | #include "negacyclic_polynomial.h"
7 | #include "ntt120_dft.h"
8 | #include "test_commons.h"
9 |
10 | struct q120b_vector_view {};
11 |
12 | struct mod_q120x2 {
13 | mod_q120 value[2];
14 | mod_q120x2();
15 | mod_q120x2(const mod_q120& a, const mod_q120& b);
16 | mod_q120x2(__int128_t value);
17 | explicit mod_q120x2(q120x2b* addr);
18 | explicit mod_q120x2(q120x2c* addr);
19 | void save_as(q120x2b* addr) const;
20 | void save_as(q120x2c* addr) const;
21 | static mod_q120x2 random();
22 | };
23 | mod_q120x2 operator+(const mod_q120x2& a, const mod_q120x2& b);
24 | mod_q120x2 operator-(const mod_q120x2& a, const mod_q120x2& b);
25 | mod_q120x2 operator*(const mod_q120x2& a, const mod_q120x2& b);
26 | bool operator==(const mod_q120x2& a, const mod_q120x2& b);
27 | bool operator!=(const mod_q120x2& a, const mod_q120x2& b);
28 | mod_q120x2& operator+=(mod_q120x2& a, const mod_q120x2& b);
29 | mod_q120x2& operator-=(mod_q120x2& a, const mod_q120x2& b);
30 |
31 | /** @brief test layout for the VEC_ZNX_DFT */
32 | struct ntt120_vec_znx_dft_layout {
33 | const uint64_t nn;
34 | const uint64_t size;
35 | VEC_ZNX_DFT* const data;
36 | ntt120_vec_znx_dft_layout(uint64_t n, uint64_t size);
37 | mod_q120x2 get_copy_zext(uint64_t idx, uint64_t blk);
38 | q120_nttvec get_copy_zext(uint64_t idx);
39 | void set(uint64_t idx, const q120_nttvec& v);
40 | q120x2b* get_blk(uint64_t idx, uint64_t blk);
41 | thash content_hash() const;
42 | void fill_random();
43 | ~ntt120_vec_znx_dft_layout();
44 | };
45 |
46 | /** @brief test layout for the VEC_ZNX_BIG */
47 | class ntt120_vec_znx_big_layout {
48 | public:
49 | const uint64_t nn;
50 | const uint64_t size;
51 | VEC_ZNX_BIG* const data;
52 | ntt120_vec_znx_big_layout(uint64_t n, uint64_t size);
53 |
54 | private:
55 | __int128* get_addr(uint64_t index) const;
56 |
57 | public:
58 | znx_i128 get_copy(uint64_t index) const;
59 | znx_i128 get_copy_zext(uint64_t index) const;
60 | void set(uint64_t index, const znx_i128& value);
61 | ~ntt120_vec_znx_big_layout();
62 | };
63 |
64 | /** @brief test layout for the VMP_PMAT */
65 | class ntt120_vmp_pmat_layout {
66 | const uint64_t nn;
67 | const uint64_t nrows;
68 | const uint64_t ncols;
69 | VMP_PMAT* const data;
70 | ntt120_vmp_pmat_layout(uint64_t n, uint64_t nrows, uint64_t ncols);
71 | mod_q120x2 get(uint64_t row, uint64_t col, uint64_t blk) const;
72 | ~ntt120_vmp_pmat_layout();
73 | };
74 |
75 | /** @brief test layout for the SVP_PPOL */
76 | class ntt120_svp_ppol_layout {
77 | const uint64_t nn;
78 | SVP_PPOL* const data;
79 | ntt120_svp_ppol_layout(uint64_t n);
80 | ~ntt120_svp_ppol_layout();
81 | };
82 |
83 | /** @brief test layout for the CNV_PVEC_L */
84 | class ntt120_cnv_left_layout {
85 | const uint64_t nn;
86 | const uint64_t size;
87 | CNV_PVEC_L* const data;
88 | ntt120_cnv_left_layout(uint64_t n, uint64_t size);
89 | mod_q120x2 get(uint64_t idx, uint64_t blk);
90 | ~ntt120_cnv_left_layout();
91 | };
92 |
93 | /** @brief test layout for the CNV_PVEC_R */
94 | class ntt120_cnv_right_layout {
95 | const uint64_t nn;
96 | const uint64_t size;
97 | CNV_PVEC_R* const data;
98 | ntt120_cnv_right_layout(uint64_t n, uint64_t size);
99 | mod_q120x2 get(uint64_t idx, uint64_t blk);
100 | ~ntt120_cnv_right_layout();
101 | };
102 |
103 | #endif // SPQLIOS_NTT120_LAYOUTS_H
104 |
--------------------------------------------------------------------------------
/test/testlib/polynomial_vector.cpp:
--------------------------------------------------------------------------------
1 | #include "polynomial_vector.h"
2 |
3 | #include
4 |
5 | #ifdef VALGRIND_MEM_TESTS
6 | #include "valgrind/memcheck.h"
7 | #endif
8 |
9 | #define CANARY_PADDING (1024)
10 | #define GARBAGE_VALUE (242)
11 |
12 | znx_vec_i64_layout::znx_vec_i64_layout(uint64_t n, uint64_t size, uint64_t slice) : n(n), size(size), slice(slice) {
13 | REQUIRE_DRAMATICALLY(is_pow2(n), "not a power of 2" << n);
14 | REQUIRE_DRAMATICALLY(slice >= n, "slice too small" << slice << " < " << n);
15 | this->region = (uint8_t*)malloc(size * slice * sizeof(int64_t) + 2 * CANARY_PADDING);
16 | this->data_start = (int64_t*)(region + CANARY_PADDING);
17 | // ensure that any invalid value is kind-of garbage
18 | memset(region, GARBAGE_VALUE, size * slice * sizeof(int64_t) + 2 * CANARY_PADDING);
19 | // mark inter-slice memory as non accessible
20 | #ifdef VALGRIND_MEM_TESTS
21 | VALGRIND_MAKE_MEM_NOACCESS(region, CANARY_PADDING);
22 | VALGRIND_MAKE_MEM_NOACCESS(region + size * slice * sizeof(int64_t) + CANARY_PADDING, CANARY_PADDING);
23 | for (uint64_t i = 0; i < size; ++i) {
24 | VALGRIND_MAKE_MEM_UNDEFINED(data_start + i * slice, n * sizeof(int64_t));
25 | }
26 | if (size != slice) {
27 | for (uint64_t i = 0; i < size; ++i) {
28 | VALGRIND_MAKE_MEM_NOACCESS(data_start + i * slice + n, (slice - n) * sizeof(int64_t));
29 | }
30 | }
31 | #endif
32 | }
33 |
34 | znx_vec_i64_layout::~znx_vec_i64_layout() { free(region); }
35 |
36 | znx_i64 znx_vec_i64_layout::get_copy_zext(uint64_t index) const {
37 | if (index < size) {
38 | return znx_i64(n, data_start + index * slice);
39 | } else {
40 | return znx_i64::zero(n);
41 | }
42 | }
43 |
44 | znx_i64 znx_vec_i64_layout::get_copy(uint64_t index) const {
45 | REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size);
46 | return znx_i64(n, data_start + index * slice);
47 | }
48 |
49 | void znx_vec_i64_layout::set(uint64_t index, const znx_i64& elem) {
50 | REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size);
51 | REQUIRE_DRAMATICALLY(elem.nn() == n, "incompatible ring dimensions: " << elem.nn() << " / " << n);
52 | elem.save_as(data_start + index * slice);
53 | }
54 |
55 | int64_t* znx_vec_i64_layout::data() { return data_start; }
56 | const int64_t* znx_vec_i64_layout::data() const { return data_start; }
57 |
58 | void znx_vec_i64_layout::fill_random(uint64_t bits) {
59 | for (uint64_t i = 0; i < size; ++i) {
60 | set(i, znx_i64::random_log2bound(n, bits));
61 | }
62 | }
63 | __uint128_t znx_vec_i64_layout::content_hash() const {
64 | test_hasher hasher;
65 | for (uint64_t i = 0; i < size; ++i) {
66 | hasher.update(data() + i * slice, n * sizeof(int64_t));
67 | }
68 | return hasher.hash();
69 | }
70 |
--------------------------------------------------------------------------------
/test/testlib/polynomial_vector.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_POLYNOMIAL_VECTOR_H
2 | #define SPQLIOS_POLYNOMIAL_VECTOR_H
3 |
4 | #include "negacyclic_polynomial.h"
5 | #include "test_commons.h"
6 |
7 | /** @brief a test memory layout for znx i64 polynomials vectors */
8 | class znx_vec_i64_layout {
9 | uint64_t n;
10 | uint64_t size;
11 | uint64_t slice;
12 | int64_t* data_start;
13 | uint8_t* region;
14 |
15 | public:
16 | // NO-COPY structure
17 | znx_vec_i64_layout(const znx_vec_i64_layout&) = delete;
18 | void operator=(const znx_vec_i64_layout&) = delete;
19 | znx_vec_i64_layout(znx_vec_i64_layout&&) = delete;
20 | void operator=(znx_vec_i64_layout&&) = delete;
21 | /** @brief initialises a memory layout */
22 | znx_vec_i64_layout(uint64_t n, uint64_t size, uint64_t slice);
23 | /** @brief destructor */
24 | ~znx_vec_i64_layout();
25 |
26 | /** @brief get a copy of item index index (extended with zeros) */
27 | znx_i64 get_copy_zext(uint64_t index) const;
28 | /** @brief get a copy of item index index (extended with zeros) */
29 | znx_i64 get_copy(uint64_t index) const;
30 | /** @brief get a copy of item index index (index
2 | #include
3 |
4 | #include "test_commons.h"
5 |
/** @brief true iff n is a power of two. BUGFIX: 0 is not a power of two (the
 * original `!(n & (n - 1))` returned true for n == 0). */
bool is_pow2(uint64_t n) { return n != 0 && (n & (n - 1)) == 0; }
7 |
8 | test_rng& randgen() {
9 | static test_rng gen;
10 | return gen;
11 | }
12 | uint64_t uniform_u64() {
13 | static std::uniform_int_distribution dist64(0, UINT64_MAX);
14 | return dist64(randgen());
15 | }
16 |
17 | uint64_t uniform_u64_bits(uint64_t nbits) {
18 | if (nbits >= 64) return uniform_u64();
19 | return uniform_u64() >> (64 - nbits);
20 | }
21 |
22 | int64_t uniform_i64() {
23 | std::uniform_int_distribution dist;
24 | return dist(randgen());
25 | }
26 |
27 | int64_t uniform_i64_bits(uint64_t nbits) {
28 | int64_t bound = int64_t(1) << nbits;
29 | std::uniform_int_distribution dist(-bound, bound);
30 | return dist(randgen());
31 | }
32 |
33 | int64_t uniform_i64_bounds(const int64_t lb, const int64_t ub) {
34 | std::uniform_int_distribution dist(lb, ub);
35 | return dist(randgen());
36 | }
37 |
38 | __int128_t uniform_i128_bounds(const __int128_t lb, const __int128_t ub) {
39 | std::uniform_int_distribution<__int128_t> dist(lb, ub);
40 | return dist(randgen());
41 | }
42 |
43 | double random_f64_gaussian(double stdev) {
44 | std::normal_distribution dist(0, stdev);
45 | return dist(randgen());
46 | }
47 |
48 | double uniform_f64_bounds(const double lb, const double ub) {
49 | std::uniform_real_distribution dist(lb, ub);
50 | return dist(randgen());
51 | }
52 |
53 | double uniform_f64_01() {
54 | return uniform_f64_bounds(0, 1);
55 | }
56 |
--------------------------------------------------------------------------------
/test/testlib/reim4_elem.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_REIM4_ELEM_H
2 | #define SPQLIOS_REIM4_ELEM_H
3 |
4 | #include "test_commons.h"
5 |
/** @brief test class representing one single reim4 element */
class reim4_elem {
 public:
  /** @brief 8 doubles: 4 real parts followed by the 4 matching imaginary parts */
  double value[8];
  /** @brief builds from 4 real parts and 4 imaginary parts */
  reim4_elem(const double* re, const double* im);
  /** @brief builds from 8 contiguous components */
  explicit reim4_elem(const double* layout);
  /** @brief the zero element */
  reim4_elem();
  /** @brief writes the 4 real parts to re and the 4 imaginary parts to im */
  void save_re_im(double* re, double* im) const;
  /** @brief writes the 8 components to reim4 */
  void save_as(double* reim4) const;
  static reim4_elem zero();
};

/** @brief checks for equality */
bool operator==(const reim4_elem& x, const reim4_elem& y);
/** @brief random gaussian reim4 of stdev 1 and mean 0 */
reim4_elem gaussian_reim4();
/** @brief addition */
reim4_elem operator+(const reim4_elem& x, const reim4_elem& y);
reim4_elem& operator+=(reim4_elem& x, const reim4_elem& y);
/** @brief subtraction */
reim4_elem operator-(const reim4_elem& x, const reim4_elem& y);
reim4_elem& operator-=(reim4_elem& x, const reim4_elem& y);
/** @brief product */
reim4_elem operator*(const reim4_elem& x, const reim4_elem& y);
std::ostream& operator<<(std::ostream& out, const reim4_elem& x);
/** @brief distance in infinity norm */
double infty_dist(const reim4_elem& x, const reim4_elem& y);
39 |
40 | /** @brief test class representing the view of one reim of m complexes */
41 | class reim4_array_view {
42 | uint64_t size; ///< size of the reim array
43 | double* data; ///< pointer to the start of the array
44 | public:
45 | /** @brief ininitializes a view at an existing given address */
46 | reim4_array_view(uint64_t size, double* data);
47 | ;
48 | /** @brief gets the i-th element */
49 | reim4_elem get(uint64_t i) const;
50 | /** @brief sets the i-th element */
51 | void set(uint64_t i, const reim4_elem& value);
52 | };
53 |
54 | /** @brief test class representing the view of one matrix of nrowsxncols reim4's */
55 | class reim4_matrix_view {
56 | uint64_t nrows; ///< number of rows
57 | uint64_t ncols; ///< number of columns
58 | double* data; ///< pointer to the start of the matrix
59 | public:
60 | /** @brief ininitializes a view at an existing given address */
61 | reim4_matrix_view(uint64_t nrows, uint64_t ncols, double* data);
62 | /** @brief gets the i-th element */
63 | reim4_elem get(uint64_t row, uint64_t col) const;
64 | /** @brief sets the i-th element */
65 | void set(uint64_t row, uint64_t col, const reim4_elem& value);
66 | };
67 |
68 | /** @brief test class representing the view of one reim of m complexes */
69 | class reim_view {
70 | uint64_t m; ///< (complex) dimension of the reim polynomial
71 | double* data; ///< address of the start of the reim polynomial
72 | public:
73 | /** @brief ininitializes a view at an existing given address */
74 | reim_view(uint64_t m, double* data);
75 | ;
76 | /** @brief extracts the i-th reim4 block (i
3 | // https://github.com/mjosaarinen/tiny_sha3
4 | // License: MIT
5 |
6 | #ifndef SHA3_H
7 | #define SHA3_H
8 |
9 | #ifdef __cplusplus
10 | extern "C" {
11 | #endif
12 |
13 | #include
14 | #include
15 |
16 | #ifndef KECCAKF_ROUNDS
17 | #define KECCAKF_ROUNDS 24
18 | #endif
19 |
20 | #ifndef ROTL64
21 | #define ROTL64(x, y) (((x) << (y)) | ((x) >> (64 - (y))))
22 | #endif
23 |
// state context
typedef struct {
  union {            // keccak state, viewed as either:
    uint8_t b[200];  //   200 8-bit bytes
    uint64_t q[25];  //   25 64-bit words
  } st;
  int pt, rsiz, mdlen;  // byte position, rate in bytes, digest length (these don't overflow)
} sha3_ctx_t;
32 |
33 | // Compression function.
34 | void sha3_keccakf(uint64_t st[25]);
35 |
36 | // OpenSSL-like interface
37 | int sha3_init(sha3_ctx_t* c, int mdlen); // mdlen = hash output in bytes
38 | int sha3_update(sha3_ctx_t* c, const void* data, size_t len);
39 | int sha3_final(void* md, sha3_ctx_t* c); // digest goes to md
40 |
41 | // compute a sha3 hash (md) of given byte length from "in"
42 | void* sha3(const void* in, size_t inlen, void* md, int mdlen);
43 |
44 | // SHAKE128 and SHAKE256 extensible-output functions
45 | #define shake128_init(c) sha3_init(c, 16)
46 | #define shake256_init(c) sha3_init(c, 32)
47 | #define shake_update sha3_update
48 |
49 | void shake_xof(sha3_ctx_t* c);
50 | void shake_out(sha3_ctx_t* c, void* out, size_t len);
51 |
52 | #ifdef __cplusplus
53 | }
54 | #endif
55 |
56 | #endif // SHA3_H
57 |
--------------------------------------------------------------------------------
/test/testlib/test_commons.cpp:
--------------------------------------------------------------------------------
1 | #include "test_commons.h"
2 |
3 | #include
4 |
/** @brief prints an __int128_t as a 0x-prefixed, zero-padded, 32-digit lowercase hex literal. */
std::ostream& operator<<(std::ostream& out, __int128_t x) {
  char buf[35] = {0};  // "0x" + 32 hex digits + terminating NUL
  snprintf(buf, sizeof(buf), "0x%016" PRIx64 "%016" PRIx64, uint64_t(x >> 64), uint64_t(x));
  return out << buf;
}
/** @brief prints an __uint128_t in the same hex format. */
std::ostream& operator<<(std::ostream& out, __uint128_t x) { return out << __int128_t(x); }
11 |
--------------------------------------------------------------------------------
/test/testlib/test_commons.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_TEST_COMMONS_H
2 | #define SPQLIOS_TEST_COMMONS_H
3 |
4 | #include
5 | #include
6 |
7 | #include "../../spqlios/commons.h"
8 |
/** @brief macro that prints an error and aborts the process if the condition is not met.
 * error_msg may be a stream expression (chained with <<). */
#define REQUIRE_DRAMATICALLY(req_condition, error_msg)                                                        \
  do {                                                                                                        \
    if (!(req_condition)) {                                                                                   \
      std::cerr << "REQUIREMENT FAILED at " << __FILE__ << ":" << __LINE__ << ": " << error_msg << std::endl; \
      abort();                                                                                                \
    }                                                                                                         \
  } while (0)
17 |
18 | typedef std::default_random_engine test_rng;
19 | /** @brief reference to the default test rng */
20 | test_rng& randgen();
21 | /** @brief uniformly random 64-bit uint */
22 | uint64_t uniform_u64();
23 | /** @brief uniformly random number <= 2^nbits-1 */
24 | uint64_t uniform_u64_bits(uint64_t nbits);
25 | /** @brief uniformly random signed 64-bit number */
26 | int64_t uniform_i64();
27 | /** @brief uniformly random signed |number| <= 2^nbits */
28 | int64_t uniform_i64_bits(uint64_t nbits);
29 | /** @brief uniformly random signed lb <= number <= ub */
30 | int64_t uniform_i64_bounds(const int64_t lb, const int64_t ub);
31 | /** @brief uniformly random signed lb <= number <= ub */
32 | __int128_t uniform_i128_bounds(const __int128_t lb, const __int128_t ub);
33 | /** @brief uniformly random gaussian float64 */
34 | double random_f64_gaussian(double stdev = 1);
35 | /** @brief uniformly random signed lb <= number <= ub */
36 | double uniform_f64_bounds(const double lb, const double ub);
37 | /** @brief uniformly random float64 in [0,1] */
38 | double uniform_f64_01();
39 | /** @brief random gaussian float64 */
40 | double random_f64_gaussian(double stdev);
41 |
42 | bool is_pow2(uint64_t n);
43 |
44 | void* alloc64(uint64_t size);
45 |
46 | typedef __uint128_t thash;
47 | /** @brief returns some pseudorandom hash of a contiguous content */
48 | thash test_hash(const void* data, uint64_t size);
49 | /** @brief class to return a pseudorandom hash of a piecewise-defined content */
50 | class test_hasher {
51 | void* md;
52 | public:
53 | test_hasher();
54 | test_hasher(const test_hasher&) = delete;
55 | void operator=(const test_hasher&) = delete;
56 | /**
57 | * @brief append input bytes.
58 | * The final hash only depends on the concatenation of bytes, not on the
59 | * way the content was split into multiple calls to update.
60 | */
61 | void update(const void* data, uint64_t size);
62 | /**
63 | * @brief returns the final hash.
64 | * no more calls to update(...) shall be issued after this call.
65 | */
66 | thash hash();
67 | ~test_hasher();
68 | };
69 |
70 | // not included by default, since it makes some versions of gtest not compile
71 | // std::ostream& operator<<(std::ostream& out, __int128_t x);
72 | // std::ostream& operator<<(std::ostream& out, __uint128_t x);
73 |
74 | #endif // SPQLIOS_TEST_COMMONS_H
75 |
--------------------------------------------------------------------------------
/test/testlib/test_hash.cpp:
--------------------------------------------------------------------------------
1 | #include "sha3.h"
2 | #include "test_commons.h"
3 |
4 | /** @brief returns some pseudorandom hash of the content */
5 | thash test_hash(const void* data, uint64_t size) {
6 | thash res;
7 | sha3(data, size, &res, sizeof(res));
8 | return res;
9 | }
10 | /** @brief class to return a pseudorandom hash of the content */
11 | test_hasher::test_hasher() {
12 | md = malloc(sizeof(sha3_ctx_t));
13 | sha3_init((sha3_ctx_t*)md, 16);
14 | }
15 |
16 | void test_hasher::update(const void* data, uint64_t size) { sha3_update((sha3_ctx_t*)md, data, size); }
17 |
18 | thash test_hasher::hash() {
19 | thash res;
20 | sha3_final(&res, (sha3_ctx_t*)md);
21 | return res;
22 | }
23 |
24 | test_hasher::~test_hasher() { free(md); }
25 |
--------------------------------------------------------------------------------
/test/testlib/vec_rnx_layout.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_EXT_VEC_RNX_LAYOUT_H
2 | #define SPQLIOS_EXT_VEC_RNX_LAYOUT_H
3 |
4 | #include "../../spqlios/arithmetic/vec_rnx_arithmetic.h"
5 | #include "fft64_dft.h"
6 | #include "negacyclic_polynomial.h"
7 | #include "reim4_elem.h"
8 | #include "test_commons.h"
9 |
10 | /** @brief a test memory layout for rnx i64 polynomials vectors */
11 | class rnx_vec_f64_layout {
12 | uint64_t n;
13 | uint64_t size;
14 | uint64_t slice;
15 | double* data_start;
16 | uint8_t* region;
17 |
18 | public:
19 | // NO-COPY structure
20 | rnx_vec_f64_layout(const rnx_vec_f64_layout&) = delete;
21 | void operator=(const rnx_vec_f64_layout&) = delete;
22 | rnx_vec_f64_layout(rnx_vec_f64_layout&&) = delete;
23 | void operator=(rnx_vec_f64_layout&&) = delete;
24 | /** @brief initialises a memory layout */
25 | rnx_vec_f64_layout(uint64_t n, uint64_t size, uint64_t slice);
26 | /** @brief destructor */
27 | ~rnx_vec_f64_layout();
28 |
29 | /** @brief get a copy of item index index (extended with zeros) */
30 | rnx_f64 get_copy_zext(uint64_t index) const;
31 | /** @brief get a copy of item index index (extended with zeros) */
32 | rnx_f64 get_copy(uint64_t index) const;
33 | /** @brief get a copy of item index index (extended with zeros) */
34 | reim_fft64vec get_dft_copy_zext(uint64_t index) const;
35 | /** @brief get a copy of item index index (extended with zeros) */
36 | reim_fft64vec get_dft_copy(uint64_t index) const;
37 |
38 | /** @brief get a copy of item index index (index> 5;
14 | const uint64_t rem_ncols = ncols & 31;
15 | uint64_t blk = col >> 5;
16 | uint64_t col_rem = col & 31;
17 | if (blk < nblk) {
18 | // column is part of a full block
19 | return (int32_t*)data + blk * nrows * 32 + row * 32 + col_rem;
20 | } else {
21 | // column is part of the last block
22 | return (int32_t*)data + blk * nrows * 32 + row * rem_ncols + col_rem;
23 | }
24 | }
25 | int32_t zn32_pmat_layout::get(uint64_t row, uint64_t col) const { return *get_addr(row, col); }
26 | int32_t zn32_pmat_layout::get_zext(uint64_t row, uint64_t col) const {
27 | if (row >= nrows || col >= ncols) return 0;
28 | return *get_addr(row, col);
29 | }
30 | void zn32_pmat_layout::set(uint64_t row, uint64_t col, int32_t value) { *get_addr(row, col) = value; }
31 | void zn32_pmat_layout::fill_random() {
32 | int32_t* d = (int32_t*)data;
33 | for (uint64_t i = 0; i < nrows * ncols; ++i) d[i] = uniform_i64_bits(32);
34 | }
35 | thash zn32_pmat_layout::content_hash() const { return test_hash(data, nrows * ncols * sizeof(int32_t)); }
36 |
37 | template
38 | std::vector vmp_product(const T* vec, uint64_t vec_size, uint64_t out_size, const zn32_pmat_layout& mat) {
39 | uint64_t rows = std::min(vec_size, mat.nrows);
40 | uint64_t cols = std::min(out_size, mat.ncols);
41 | std::vector res(out_size, 0);
42 | for (uint64_t j = 0; j < cols; ++j) {
43 | for (uint64_t i = 0; i < rows; ++i) {
44 | res[j] += vec[i] * mat.get(i, j);
45 | }
46 | }
47 | return res;
48 | }
49 |
50 | template std::vector vmp_product(const int8_t* vec, uint64_t vec_size, uint64_t out_size,
51 | const zn32_pmat_layout& mat);
52 | template std::vector vmp_product(const int16_t* vec, uint64_t vec_size, uint64_t out_size,
53 | const zn32_pmat_layout& mat);
54 | template std::vector vmp_product(const int32_t* vec, uint64_t vec_size, uint64_t out_size,
55 | const zn32_pmat_layout& mat);
56 |
--------------------------------------------------------------------------------
/test/testlib/zn_layouts.h:
--------------------------------------------------------------------------------
1 | #ifndef SPQLIOS_EXT_ZN_LAYOUTS_H
2 | #define SPQLIOS_EXT_ZN_LAYOUTS_H
3 |
4 | #include "../../spqlios/arithmetic/zn_arithmetic.h"
5 | #include "test_commons.h"
6 |
7 | class zn32_pmat_layout {
8 | public:
9 | const uint64_t nrows;
10 | const uint64_t ncols;
11 | ZN32_VMP_PMAT* const data;
12 | zn32_pmat_layout(uint64_t nrows, uint64_t ncols);
13 |
14 | private:
15 | int32_t* get_addr(uint64_t row, uint64_t col) const;
16 |
17 | public:
18 | int32_t get(uint64_t row, uint64_t col) const;
19 | int32_t get_zext(uint64_t row, uint64_t col) const;
20 | void set(uint64_t row, uint64_t col, int32_t value);
21 | void fill_random();
22 | thash content_hash() const;
23 | ~zn32_pmat_layout();
24 | };
25 |
26 | template
27 | std::vector vmp_product(const T* vec, uint64_t vec_size, uint64_t out_size, const zn32_pmat_layout& mat);
28 |
29 | #endif // SPQLIOS_EXT_ZN_LAYOUTS_H
30 |
--------------------------------------------------------------------------------