├── .gitignore ├── .gitmodules ├── sources ├── shuffle_ptx.cu ├── CUDA_init.cu ├── device_info.cu ├── CMakeLists.txt ├── kernels.cu ├── finitefield.cu ├── ecc_exp.cu ├── square_256_to_512.cu ├── test_framework.cu ├── cuda_exports.cu ├── basic_arithmetic.cu ├── mul_128_to_256.cu ├── host_funcs.cpp ├── ell_point.cu ├── FFT.cu ├── Groth16_prover.cu └── mont_mul.cu ├── utilities ├── check_CUDA_arch.cu ├── root_of_unity_generator.py ├── CIOS_MONT-mul.py ├── mont_asm_generator.py ├── test_multiexp_framework.py ├── test_framework.py ├── test_DFT.py ├── mul_with_shuffle_generator.py └── asm_generator.py ├── LICENSE ├── CMakeLists.txt ├── README.md └── include ├── cuda_export_headers.h ├── cuda_macros.h └── cuda_structs.h /.gitignore: -------------------------------------------------------------------------------- 1 | *vscode* 2 | build 3 | benches.txt 4 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "depends/catch"] 2 | path = depends/catch 3 | url = https://github.com/catchorg/Catch2.git 4 | [submodule "depends/CUB"] 5 | path = depends/CUB 6 | url = https://github.com/NVlabs/cub 7 | -------------------------------------------------------------------------------- /sources/shuffle_ptx.cu: -------------------------------------------------------------------------------- 1 | __constant__ unsigned A[32]; 2 | 3 | __global__ void I_wanna_understand_shuffles() 4 | { 5 | unsigned c = A[threadIdx.x]; 6 | unsigned b = __shfl_down_sync(0xffffffff, c, 1, 8); 7 | 8 | } 9 | 10 | -------------------------------------------------------------------------------- /utilities/check_CUDA_arch.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <cuda_runtime_api.h> 3 | 4 | int main(int argc, char **argv) 5 | { 6 | cudaDeviceProp dP; 7 | 8 | int rc = cudaGetDeviceProperties(&dP, 0); 9 | if(rc != cudaSuccess) 10 | { 11 | cudaError_t error = cudaGetLastError(); 12 | printf("CUDA error: %s\n", cudaGetErrorString(error)); 13 | return rc; /* Failure */ 14 | } 15 | else 16 | { 17 | printf("%d%d", dP.major, dP.minor); 18 | return 0; /* Success */ 19 | } 20 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2019 Alex Vlasov 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License.
14 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.8 FATAL_ERROR) 2 | project(test_cuda LANGUAGES CXX CUDA) 3 | 4 | #Determine the compute capability of the device installed on the target system 5 | 6 | set(OUTPUTFILE ${CMAKE_BINARY_DIR}/check_CUDA_arch) 7 | set(CUDAFILE ${CMAKE_CURRENT_SOURCE_DIR}/utilities/check_CUDA_arch.cu) 8 | execute_process(COMMAND nvcc -lcuda ${CUDAFILE} -o ${OUTPUTFILE}) 9 | execute_process(COMMAND ${OUTPUTFILE} RESULT_VARIABLE CUDA_RETURN_CODE OUTPUT_VARIABLE GPU_ARCH) 10 | 11 | if(${CUDA_RETURN_CODE} EQUAL 0) 12 | message(STATUS "GPU Architecture: ${GPU_ARCH}") 13 | 14 | string(APPEND CMAKE_CUDA_FLAGS " -gencode arch=compute_${GPU_ARCH},code=sm_${GPU_ARCH}") 15 | else() 16 | message( FATAL_ERROR ${GPU_ARCH}) 17 | endif() 18 | 19 | include_directories(include) 20 | add_subdirectory(sources) -------------------------------------------------------------------------------- /utilities/root_of_unity_generator.py: -------------------------------------------------------------------------------- 1 | p = 21888242871839275222246405745257275088548364400416034343698204186575808495617 2 | max_k = 28 3 | 4 | field = GF(p) 5 | 6 | gen = field.multiplicative_generator() 7 | 8 | R = field(2^256) 9 | 10 | a = (p - 1) / (2 ^ max_k) 11 | 12 | root_of_unity = gen ^ a 13 | 14 | def splitter(x): 15 | x = hex(int(x * R))[2:-1] 16 | str_len = len(x) 17 | if str_len % 8 != 0: 18 | x = "0" * (8 - (str_len % 8)) + x 19 | 20 | res = ["0x" + x[i:i+8] for i in range(0, len(x), 8)] 21 | return res[::-1] 22 | 23 | def printer(x): 24 | res = "{ " 25 | for j in xrange(8): 26 | res += x[j] 27 | if j != 7: 28 | res += ", " 29 | res += " };" 30 | return res 31 | 32 | x = root_of_unity 33 | 34 | for i in xrange(max_k): 35 | ww = splitter(x) 36 | print printer(ww) 37 | x *= x -------------------------------------------------------------------------------- /utilities/CIOS_MONT-mul.py: -------------------------------------------------------------------------------- 1 | P = [ 0xd87cfd47, 0x3c208c16, 0x6871ca8d, 0x97816a91, 0x8181585d, 0xb85045b6, 0xe131a029, 0x30644e72 ] 2 | N = 0xe4866389 3 | 4 | word_len = 2^32 5 | 6 | 7 | 8 | def CIOS(A, B): 9 | S = [0, 0, 0, 0, 0, 0, 0, 0] 10 | 11 | for j in xrange(2): 12 | for i in xrange(8): 13 | S[i] += A[i] * B[j] 14 | 15 | q = (S[0] * N) % word_len 16 | for i in xrange(8): 17 | S[i] += q * P[i] 18 | 19 | for i in xrange(7): 20 | S[i] = (S[i] >> 32) + (S[i+1] % word_len) 21 | 22 | S[7] = (S[7] >> 32) 23 | 24 | temp = 0 25 | res = [] 26 | for i in xrange(8): 27 | S[i] = S[i] + temp 28 | res.append(hex(S[i] % word_len)) 29 | temp = (S[i] >> 32) 30 | 31 | return res 32 | 33 | 34 | def splitter(num): 35 | res = [] 36 | for i in xrange(8): 37 | res.append(num % word_len) 38 | num = num >> 32 39 | 40 | return res -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # About 2 | 3 | This repository contains alpha-stage (highly WIP) primitives for the operations typically required by pairing-based proof systems: finite field arithmetic, FFT, elliptic curve point arithmetic and multiexponentiation. All operations are currently implemented only for the G1 group of the pairing-friendly curve BN254 (the Ethereum curve).
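For reference, these are the concrete BN254 parameters assumed throughout the code; the constants below are the ones hard-coded in `utilities/test_framework.py` (a short self-check sketch):

```python
# BN254 (alt_bn128): G1 is y^2 = x^3 + 3 over GF(p), with generator (1, 2)
p = 21888242871839275222246405745257275088696311157297823662689037894645226208583  # base field modulus
r = 21888242871839275222246405745257275088548364400416034343698204186575808495617  # order of G1
assert (1**3 + 3 - 2**2) % p == 0  # the generator lies on the curve
```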
4 | 5 | There is also an implementation of about 70% of the work required for a Groth16 prover, with an interface for integration with the `bellman_ce` Rust crate. 6 | 7 | ## Instructions 8 | 9 | - Install CUDA 10.0 and CMake 3.9+ 10 | - `cmake -DCMAKE_BUILD_TYPE=Release .` 11 | - `make` 12 | - in a separate folder, clone the `ff`, `pairing` and `bellman` repositories from Matter Labs 13 | - check out the `gpu` branch in each of them 14 | - copy `sources/libcuda.so` into the `bellman` directory 15 | - run `cargo test --release -- --nocapture test_mimc_bn256_gpu_all` to get some benchmark results and validity checks 16 | - you can also change `const MIMC_ROUNDS: usize = 16000000;` in the file `bellman/tests/mimc.rs` to reduce the number of constraints in the test circuit if you run out of memory (RAM or GPU memory) -------------------------------------------------------------------------------- /sources/CUDA_init.cu: -------------------------------------------------------------------------------- 1 | #include "cuda_structs.h" 2 | 3 | #include <stdio.h> 4 | 5 | bool CUDA_init() 6 | { 7 | //first find a suitable CUDA device 8 | //TBD: or split between several CUDA devices if possible 9 | int device_count; 10 | cudaError_t cudaStatus = cudaGetDeviceCount(&device_count); 11 | if (cudaStatus != cudaSuccess) { 12 | fprintf(stderr, "cudaGetDeviceCount failed!"); 13 | return false; 14 | } 15 | if (device_count == 0) 16 | { 17 | fprintf(stderr, "No suitable CUDA devices were found!"); 18 | return false; 19 | } 20 | 21 | cudaDeviceProp prop; 22 | cudaStatus = cudaGetDeviceProperties(&prop, 0); 23 | 24 | if (cudaStatus != cudaSuccess) 25 | { 26 | fprintf(stderr, "cudaGetDeviceProperties failed!"); 27 | return false; 28 | } 29 | 30 | printf("Compute capability: %d.%d\n", prop.major, prop.minor); 31 | 32 | //TODO: check if there is enough constant memory, and other additional checks 33 | //set appropriate device 34 | // Choose which GPU to run on, change this on a multi-GPU system. 35 | cudaStatus = cudaSetDevice(0); 36 | if (cudaStatus != cudaSuccess) 37 | { 38 | fprintf(stderr, "cudaSetDevice failed!
Do you have a CUDA-capable GPU installed?"); 39 | return false; 40 | } 41 | 42 | return true; 43 | } 44 | 45 | 46 | -------------------------------------------------------------------------------- /sources/device_info.cu: -------------------------------------------------------------------------------- 1 | #include "cuda_structs.h" 2 | 3 | #include <iostream> 4 | 5 | void get_device_info() 6 | { 7 | cudaDeviceProp prop; 8 | cudaGetDeviceProperties(&prop, 0); 9 | 10 | uint32_t sm_count = prop.multiProcessorCount; 11 | uint32_t warp_size = prop.warpSize; 12 | uint32_t shared_mem_per_block = prop.sharedMemPerBlock; 13 | uint32_t shared_mem_per_multiprocessor = prop.sharedMemPerMultiprocessor; 14 | uint32_t regs_per_block = prop.regsPerBlock; 15 | uint32_t regs_per_multiprocessor = prop.regsPerMultiprocessor; 16 | uint32_t max_threads_per_block = prop.maxThreadsPerBlock; 17 | uint32_t max_threads_per_multiprocessor = prop.maxThreadsPerMultiProcessor; 18 | 19 | std::cout << "SM count: " << sm_count << std::endl; 20 | std::cout << "warp size: " << warp_size << std::endl; 21 | std::cout << "Shared memory per block (in bytes): " << shared_mem_per_block << std::endl; 22 | std::cout << "Shared memory per multiprocessor (in bytes): " << shared_mem_per_multiprocessor << std::endl; 23 | std::cout << "Number of 32-bit registers per block: " << regs_per_block << std::endl; 24 | std::cout << "Number of 32-bit registers per multiprocessor: " << regs_per_multiprocessor << std::endl; 25 | std::cout << "Max number of threads per block: " << max_threads_per_block << std::endl; 26 | std::cout << "Max number of threads per multiprocessor: " << max_threads_per_multiprocessor << std::endl; 27 | } -------------------------------------------------------------------------------- /utilities/mont_asm_generator.py: -------------------------------------------------------------------------------- 1 | def generate_mont_asm_Listing(asm_len): 2 | printed_asm = "" 3 | for i in xrange(asm_len): 4 | printed_asm += 'mul.lo.u32 m, a{:d}, q;\\n\\t\"\n'.format(i) 5 | first = True 6 | for j in xrange(asm_len): 7 | if first: 8 | printed_asm += "mad." 9 | first = False 10 | else: 11 | printed_asm += "madc." 12 | printed_asm += 'lo.cc.u32 a{:d}, m, n{:d}, a{:d};\\n\\t\"\n'.format(i+j, j, i+j) 13 | j = i + asm_len 14 | while (j < 2 * asm_len): 15 | if (j < 2 * asm_len - 1): 16 | printed_asm += 'addc.cc.u32 a{:d}, a{:d}, 0;\\n\\t\"\n'.format(j, j) 17 | else: 18 | printed_asm += 'add.cc.u32 a{:d}, a{:d}, 0;\\n\\t\"\n'.format(j, j) 19 | j = j + 1 20 | first = True 21 | for j in xrange(asm_len): 22 | if first: 23 | printed_asm += "mad." 24 | first = False 25 | else: 26 | printed_asm += "madc."
27 | printed_asm += 'hi.cc.u32 a{:d}, m, n{:d}, a{:d};\\n\\t\"\n'.format(i+j+1, j, i+j+1) 28 | j = i + asm_len + 1 29 | while (j < 2 * asm_len): 30 | if (j < 2 * asm_len - 1): 31 | printed_asm += 'addc.cc.u32 a{:d}, a{:d}, 0;\\n\\t\"\n'.format(j, j) 32 | else: 33 | printed_asm += 'add.cc.u32 a{:d}, a{:d}, 0;\\n\\t\"\n'.format(j, j) 34 | j = j + 1 35 | return printed_asm 36 | 37 | print generate_mont_asm_Listing(8) -------------------------------------------------------------------------------- /utilities/test_multiexp_framework.py: -------------------------------------------------------------------------------- 1 | p = 21888242871839275222246405745257275088696311157297823662689037894645226208583 2 | r = 21888242871839275222246405745257275088548364400416034343698204186575808495617 3 | 4 | base_field = GF(p) 5 | curve = EllipticCurve(base_field, [0, 3]); 6 | G = curve(1, 2, 1) 7 | 8 | R = base_field(2 ^ 256) 9 | 10 | def to_mont_form(x): 11 | return x * R 12 | 13 | def from_mont_form(x): 14 | return x / R 15 | 16 | def parse_affine_point(line1, line2): 17 | x = int(line1.split('=')[1], 0x10) 18 | y = int(line2.split('=')[1], 0x10) 19 | 20 | return curve(x, y, R) 21 | 22 | def parse_bignum(line, base = 0x10): 23 | return int(line, base) 24 | 25 | def parse_ec_point(line1, line2, line3): 26 | x = int(line1.split('=')[1], 0x10) 27 | y = int(line2.split('=')[1], 0x10) 28 | z = int(line3.split('=')[1], 0x10) 29 | 30 | return curve(x, y, z) 31 | 32 | 33 | def extractKBits(num,k,p): 34 | num = num >> p 35 | num = num & (2^k - 1) 36 | return num 37 | 38 | 39 | pt_arr = [] 40 | 41 | FILE_LOCATION = "/home/k/TestCuda3/benches.txt" 42 | 43 | file = open(FILE_LOCATION, "r") 44 | 45 | bench_len = parse_bignum(file.readline().split("=")[1][:-1], 10) 46 | print bench_len 47 | 48 | 49 | num_of_results = 2 50 | 51 | for _ in xrange(num_of_results): 52 | file.readline() 53 | file.readline() 54 | 55 | x = file.readline()[:-1] 56 | y = file.readline()[:-1] 57 | z = file.readline()[:-1] 58 | file.readline() 59 | C = parse_ec_point(x, y, z) 60 | pt_arr.append(C) 61 | 62 | print pt_arr[0] == pt_arr[1] 63 | print len(set(pt_arr)) == 1 64 | 65 | time1 = 85375141521 66 | time2 = 12400215642 67 | 68 | print float(time1) / time2 -------------------------------------------------------------------------------- /utilities/test_framework.py: -------------------------------------------------------------------------------- 1 | p = 21888242871839275222246405745257275088696311157297823662689037894645226208583 2 | r = 21888242871839275222246405745257275088548364400416034343698204186575808495617 3 | 4 | base_field = GF(p) 5 | curve = EllipticCurve(base_field, [0, 3]); 6 | G = curve(1, 2, 1) 7 | 8 | R = base_field(2 ^ 256) 9 | 10 | def to_mont_form(x): 11 | return x * R 12 | 13 | def from_mont_form(x): 14 | return x / R 15 | 16 | def parse_affine_point(line1, line2): 17 | x = int(line1.split('=')[1], 0x10) 18 | y = int(line2.split('=')[1], 0x10) 19 | 20 | return curve(x, y, R) 21 | 22 | def parse_bignum(line): 23 | return int(line, 0x10) 24 | 25 | def parse_ec_point(line1, line2, line3): 26 | x = int(line1.split('=')[1], 0x10) 27 | y = int(line2.split('=')[1], 0x10) 28 | z = int(line3.split('=')[1], 0x10) 29 | 30 | return curve(x, y, z) 31 | 32 | A_arr = [] 33 | B_arr = [] 34 | C_arr = [] 35 | bench_len = 10 36 | FILE_LOCATION = "/home/k/TestCuda3/benches.txt" 37 | 38 | file = open(FILE_LOCATION, "r") 39 | 40 | print file.readline() 41 | 42 | for _ in xrange(bench_len): 43 | x = file.readline()[:-1] 44 | y = file.readline()[:-1]
45 | file.readline() 46 | 47 | A_arr.append(parse_affine_point(x, y)) 48 | 49 | print file.readline() 50 | print file.readline() 51 | 52 | for _ in xrange(bench_len): 53 | num = file.readline()[:-1] 54 | B_arr.append(parse_bignum(num)) 55 | 56 | print file.readline() 57 | print file.readline() 58 | 59 | for _ in xrange(bench_len): 60 | x = file.readline()[:-1] 61 | y = file.readline()[:-1] 62 | z = file.readline()[:-1] 63 | file.readline() 64 | 65 | C_arr.append(parse_ec_point(x, y, z)) 66 | 67 | 68 | for i in xrange(bench_len): 69 | print (B_arr[i] * A_arr[i] == C_arr[i]) 70 | 71 | 72 | -------------------------------------------------------------------------------- /utilities/test_DFT.py: -------------------------------------------------------------------------------- 1 | #FFT 2 | 3 | p = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 4 | field = GF(p) 5 | R = field(0xe0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb) 6 | root_of_unity = field(0x1860ef942963f9e756452ac01eb203d8a22bf3742445ffd6636e735580d13d9c) / R 7 | FILE_LOCATION = "/home/k/TestCuda3/benches.txt" 8 | 9 | unity_order = root_of_unity.multiplicative_order() 10 | 11 | #all elements of arr are given in Montgomery form 12 | 13 | def from_mont_form(arr): 14 | return [x / R for x in arr] 15 | 16 | 17 | def to_mont_form(arr): 18 | return [x * R for x in arr] 19 | 20 | 21 | def DFT(arr): 22 | n = len(arr) 23 | omega = root_of_unity ^ (unity_order / n) 24 | res = [] 25 | 26 | temp_arr = from_mont_form(arr) 27 | 28 | 29 | for i in xrange(n): 30 | temp = field(0) 31 | for j, elem in enumerate(temp_arr): 32 | temp += elem * omega ^ (i * j) 33 | res.append(temp) 34 | 35 | return to_mont_form(res) 36 | #return [hex(int(x)) for x in to_mont_form(res)] 37 | 38 | 39 | def parse_bignum(line, base = 0x10): 40 | return int(line, base) 41 | 42 | 43 | def read_arr(bench_len, file): 44 | arr = [] 45 | 46 | file.readline() 47 | file.readline() 48 | 49 | for _ in xrange(bench_len): 50 | line = file.readline() 51 | num = parse_bignum(line) 52 | arr.append(field(num)) 53 | 54 | return arr 55 | 56 | 57 | 58 | def sample_from_file(): 59 | file = open(FILE_LOCATION, "r") 60 | 61 | bench_len = parse_bignum(file.readline().split("=")[1][:-1], 10) 62 | print bench_len 63 | 64 | A = read_arr(bench_len, file) 65 | B = read_arr(bench_len, file) 66 | C = read_arr(bench_len, file) 67 | 68 | D = DFT(A) 69 | for i in xrange(bench_len): 70 | if D[i] != C[i]: 71 | print "arrs are different" 72 | 73 | print "finish" 74 | 75 | 76 | 77 | 78 | sample_from_file() 79 | 80 | -------------------------------------------------------------------------------- /sources/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(cuda_benches main.cu basic_arithmetic.cu constants.cu kernels.cu mont_mul.cu mul_128_to_256.cu 2 | mul_256_to_512.cu square_256_to_512.cu ell_point.cu ecc_exp.cu CUDA_init.cu finitefield.cu multiexp.cu FFT.cu experiemental.cu 3 | device_info.cu) 4 | 5 | add_library(cuda SHARED basic_arithmetic.cu constants.cu mont_mul.cu mul_128_to_256.cu 6 | mul_256_to_512.cu square_256_to_512.cu ell_point.cu ecc_exp.cu CUDA_init.cu finitefield.cu multiexp.cu device_info.cu 7 | Groth16_prover.cu) 8 | target_compile_features(cuda_benches PUBLIC cxx_std_14) 9 | target_compile_features(cuda PUBLIC cxx_std_14) 10 | 11 | #target_compile_definitions(cuda_benches PRIVATE PRINT_BENCHES_INPUT) 12 | #target_compile_definitions(cuda_benches PRIVATE PRINT_BENCHES_OUTPUT) 13 | 
target_compile_definitions(cuda_benches PRIVATE ZERO_BANK_CONFLICTS) 14 | target_compile_definitions(cuda PRIVATE ZERO_BANK_CONFLICTS) 15 | 16 | # We need to explicitly state that we need all CUDA files in the 17 | # cuda and cuda_benches targets to be built with -dc, as the member functions 18 | # could be called by other libraries and executables 19 | set_target_properties(cuda_benches PROPERTIES CUDA_SEPARABLE_COMPILATION ON) 20 | set_target_properties(cuda PROPERTIES CUDA_SEPARABLE_COMPILATION ON) 21 | #set_property(TARGET cuda PROPERTY POSITION_INDEPENDENT_CODE ON) 22 | 23 | #target_include_directories(cuda_benches PRIVATE ${PROJECT_SOURCE_DIR}/depends/CUB) 24 | 25 | #Checking all of the routines for correctness 26 | option(WITH_CORRECTNESS_CHECK "check correctness of all routines" OFF) 27 | 28 | set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3") 29 | 30 | add_library(CudaPTX OBJECT shuffle_ptx.cu) 31 | set_property(TARGET CudaPTX PROPERTY CUDA_PTX_COMPILATION ON) 32 | 33 | install(TARGETS CudaPTX 34 | OBJECTS DESTINATION bin/ptx 35 | ) 36 | 37 | if (WITH_CORRECTNESS_CHECK) 38 | # dependencies 39 | # add_subdirectory(lib/Catch2) 40 | # target_link_libraries(tests Catch2::Catch2) 41 | 42 | # executable 43 | # add_executable(correctness_checks test.cpp) 44 | # target_compile_definitions(correctness_checks PRIVATE Py_LIMITED_API) 45 | # target_include_directories(correctness_checks PUBLIC ${PYTHON_INCLUDE_DIRS}) 46 | 47 | # target_link_libraries(correctness_checks PRIVATE ${PYTHON_LIBRARIES}) 48 | endif() 49 | 50 | -------------------------------------------------------------------------------- /sources/kernels.cu: -------------------------------------------------------------------------------- 1 | #include "cuda_structs.h" 2 | #include "cuda_macros.h" 3 | 4 | #include <iostream> 5 | 6 | GENERAL_TEST_2_ARGS_1_TYPE(add_uint256_naive, uint256_g) 7 | GENERAL_TEST_2_ARGS_1_TYPE(add_uint256_asm, uint256_g) 8 | GENERAL_TEST_2_ARGS_1_TYPE(sub_uint256_naive, uint256_g) 9 | GENERAL_TEST_2_ARGS_1_TYPE(sub_uint256_asm, uint256_g) 10 | 11 | GENERAL_TEST_2_ARGS_2_TYPES(mul_uint256_to_512_asm, uint256_g, uint512_g) 12 | GENERAL_TEST_2_ARGS_2_TYPES(mul_uint256_to_512_naive, uint256_g, uint512_g) 13 | GENERAL_TEST_2_ARGS_2_TYPES(mul_uint256_to_512_asm_with_allocation, uint256_g, uint512_g) 14 | GENERAL_TEST_2_ARGS_2_TYPES(mul_uint256_to_512_asm_longregs, uint256_g, uint512_g) 15 | GENERAL_TEST_2_ARGS_2_TYPES(mul_uint256_to_512_Karatsuba, uint256_g, uint512_g) 16 | GENERAL_TEST_2_ARGS_2_TYPES(mul_uint256_to_512_asm_with_shuffle, uint256_g, uint512_g) 17 | 18 | GENERAL_TEST_1_ARG_2_TYPES(square_uint256_to_512_naive, uint256_g, uint512_g) 19 | GENERAL_TEST_1_ARG_2_TYPES(square_uint256_to_512_asm, uint256_g, uint512_g) 20 | 21 | GENERAL_TEST_2_ARGS_1_TYPE(mont_mul_256_naive_SOS, uint256_g) 22 | GENERAL_TEST_2_ARGS_1_TYPE(mont_mul_256_naive_CIOS, uint256_g) 23 | GENERAL_TEST_2_ARGS_1_TYPE(mont_mul_256_asm_SOS, uint256_g) 24 | GENERAL_TEST_2_ARGS_1_TYPE(mont_mul_256_asm_CIOS, uint256_g) 25 | 26 | GENERAL_TEST_1_ARG_1_TYPE(FIELD_MUL_INV, uint256_g) 27 | 28 | GENERAL_TEST_2_ARGS_1_TYPE(ECC_ADD_PROJ, ec_point); 29 | GENERAL_TEST_2_ARGS_1_TYPE(ECC_SUB_PROJ, ec_point); 30 | GENERAL_TEST_1_ARG_1_TYPE(ECC_DOUBLE_PROJ, ec_point); 31 | GENERAL_TEST_1_ARG_2_TYPES(IS_ON_CURVE_PROJ, ec_point, bool); 32 | 33 | GENERAL_TEST_2_ARGS_1_TYPE(ECC_ADD_JAC, ec_point); 34 | GENERAL_TEST_2_ARGS_1_TYPE(ECC_SUB_JAC, ec_point); 35 | GENERAL_TEST_1_ARG_1_TYPE(ECC_DOUBLE_JAC, ec_point); 36 | GENERAL_TEST_1_ARG_2_TYPES(IS_ON_CURVE_JAC, ec_point, bool); 37 | 38 |
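// NB: each GENERAL_TEST_* instantiation in this file expands (see include/cuda_macros.h)
// into a grid-stride <name>_kernel plus a <name>_driver launcher; the group below
// covers the exponentiation routines.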
GENERAL_TEST_2_ARGS_3_TYPES(ECC_double_and_add_exp_PROJ, ec_point, uint256_g, ec_point); 39 | GENERAL_TEST_2_ARGS_3_TYPES(ECC_ternary_expansion_exp_PROJ, ec_point, uint256_g, ec_point); 40 | GENERAL_TEST_2_ARGS_3_TYPES(ECC_double_and_add_exp_JAC, ec_point, uint256_g, ec_point); 41 | GENERAL_TEST_2_ARGS_3_TYPES(ECC_ternary_expansion_exp_JAC, ec_point, uint256_g, ec_point); 42 | GENERAL_TEST_2_ARGS_3_TYPES(ECC_wNAF_exp_PROJ, ec_point, uint256_g, ec_point); 43 | GENERAL_TEST_2_ARGS_3_TYPES(ECC_wNAF_exp_JAC, ec_point, uint256_g, ec_point); 44 | 45 | 46 | GENERAL_TEST_2_ARGS_3_TYPES(ECC_double_and_add_affine_exp_PROJ, affine_point, uint256_g, ec_point); 47 | GENERAL_TEST_2_ARGS_3_TYPES(ECC_double_and_add_affine_exp_JAC, affine_point, uint256_g, ec_point); -------------------------------------------------------------------------------- /include/cuda_export_headers.h: -------------------------------------------------------------------------------- 1 | #ifndef CUDA_EXPORT_HEADERS 2 | #define CUDA_EXPORT_HEADERS 3 | 4 | #define EXPORT __attribute__((visibility("default"))) 5 | 6 | struct EXPORT embedded_field; 7 | struct EXPORT ec_point; 8 | struct EXPORT affine_point; 9 | struct EXPORT polynomial; 10 | 11 | //----------------------------------------------------------------------------------------------------------------------------------------------- 12 | //export basic parallel routines: finite field addition, subtraction and multiplication; elliptic curve addition and subtraction 13 | //----------------------------------------------------------------------------------------------------------------------------------------------- 14 | 15 | void EXPORT field_add(const embedded_field*, const embedded_field*, embedded_field*, uint32_t); 16 | void EXPORT field_sub(const embedded_field*, const embedded_field*, embedded_field*, uint32_t); 17 | void EXPORT field_mul(const embedded_field*, const embedded_field*, embedded_field*, uint32_t); 18 | 19 | void EXPORT ec_point_add(const ec_point*, const ec_point*, ec_point*, uint32_t); 20 | void EXPORT ec_point_sub(const ec_point*, const ec_point*, ec_point*, uint32_t); 21 | 22 | //----------------------------------------------------------------------------------------------------------------------------------------------- 23 | //Multiexponentiation (based on Pippenger's algorithm) 24 | //----------------------------------------------------------------------------------------------------------------------------------------------- 25 | 26 | ec_point EXPORT ec_multiexp(const affine_point*, const uint256_g*, uint32_t); 27 | 28 | //----------------------------------------------------------------------------------------------------------------------------------------------- 29 | //FFT routines 30 | //----------------------------------------------------------------------------------------------------------------------------------------------- 31 | 32 | void EXPORT FFT(const embedded_field*, embedded_field*, uint32_t); 33 | 34 | void EXPORT iFFT(const embedded_field*, embedded_field*, uint32_t, const embedded_field&); 35 | 36 | //------------------------------------------------------------------------------------------------------------------------------------------------ 37 | //polynomial arithmetic 38 | //------------------------------------------------------------------------------------------------------------------------------------------------ 39 | 40 | // polynomial EXPORT poly_add(const polynomial&, const polynomial&); 41 | // polynomial EXPORT
poly_sub(const polynomial&, const polynomial&); 42 | // polynomial EXPORT poly_mul(const polynomial&, const polynomial&); 43 | 44 | #endif -------------------------------------------------------------------------------- /include/cuda_macros.h: -------------------------------------------------------------------------------- 1 | #ifndef CUDA_MACROS_H 2 | #define CUDA_MACROS_H 3 | 4 | #define GENERAL_TEST_2_ARGS_3_TYPES(func_name, A_type, B_type, C_type) \ 5 | __global__ void func_name##_kernel(A_type *a_arr, B_type *b_arr, C_type *c_arr, size_t arr_len)\ 6 | {\ 7 | size_t tid = threadIdx.x + blockIdx.x * blockDim.x;\ 8 | while (tid < arr_len)\ 9 | {\ 10 | c_arr[tid] = func_name(a_arr[tid], b_arr[tid]);\ 11 | tid += blockDim.x * gridDim.x;\ 12 | }\ 13 | }\ 14 | \ 15 | void func_name##_driver(A_type *a_arr, B_type *b_arr, C_type *c_arr, size_t arr_len)\ 16 | {\ 17 | int blockSize;\ 18 | int minGridSize;\ 19 | int realGridSize;\ 20 | int maxActiveBlocks;\ 21 | \ 22 | cudaOccupancyMaxPotentialBlockSize( &minGridSize, &blockSize, func_name##_kernel, 0, 0);\ 23 | realGridSize = (arr_len + blockSize - 1) / blockSize;\ 24 | \ 25 | cudaDeviceProp prop;\ 26 | cudaGetDeviceProperties(&prop, 0);\ 27 | uint32_t smCount = prop.multiProcessorCount;\ 28 | cudaError_t error = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&maxActiveBlocks, func_name##_kernel, blockSize, 0);\ 29 | if (error == cudaSuccess)\ 30 | realGridSize = maxActiveBlocks * smCount;\ 31 | \ 32 | std::cout << "Grid size: " << realGridSize << ", min grid size: " << minGridSize << ", blockSize: " << blockSize << std::endl;\ 33 | func_name##_kernel<<<realGridSize, blockSize>>>(a_arr, b_arr, c_arr, arr_len);\ 34 | } 35 | 36 | #define GENERAL_TEST_2_ARGS_2_TYPES(func_name, input_type, output_type) GENERAL_TEST_2_ARGS_3_TYPES(func_name, input_type, input_type, output_type) 37 | #define GENERAL_TEST_2_ARGS_1_TYPE(func_name, type) GENERAL_TEST_2_ARGS_3_TYPES(func_name, type, type, type) 38 | 39 | #define GENERAL_TEST_1_ARG_2_TYPES(func_name, input_type, output_type) \ 40 | __global__ void func_name##_kernel(input_type *a_arr, input_type *b_arr, output_type *c_arr, size_t arr_len)\ 41 | {\ 42 | size_t tid = threadIdx.x + blockIdx.x * blockDim.x;\ 43 | while (tid < arr_len)\ 44 | {\ 45 | c_arr[tid] = func_name(a_arr[tid]);\ 46 | tid += blockDim.x * gridDim.x;\ 47 | }\ 48 | }\ 49 | \ 50 | void func_name##_driver(input_type *a_arr, input_type *b_arr, output_type *c_arr, size_t arr_len)\ 51 | {\ 52 | int blockSize;\ 53 | int minGridSize;\ 54 | int realGridSize;\ 55 | \ 56 | cudaOccupancyMaxPotentialBlockSize( &minGridSize, &blockSize, func_name##_kernel, 0, 0);\ 57 | realGridSize = (arr_len + blockSize - 1) / blockSize;\ 58 | \ 59 | std::cout << "Grid size: " << realGridSize << ", min grid size: " << minGridSize << ", blockSize: " << blockSize << std::endl;\ 60 | \ 61 | func_name##_kernel<<<realGridSize, blockSize>>>(a_arr, b_arr, c_arr, arr_len);\ 62 | } 63 | 64 | #define GENERAL_TEST_1_ARG_1_TYPE(func_name, type) GENERAL_TEST_1_ARG_2_TYPES(func_name, type, type) 65 | 66 | #endif 67 | -------------------------------------------------------------------------------- /utilities/mul_with_shuffle_generator.py: -------------------------------------------------------------------------------- 1 | def gen_tables(array_len): 2 | even_table = {} 3 | odd_table = {} 4 | even_table_size = 0 5 | odd_table_size = 0 6 | 7 | #initialize dictionaries 8 | for i in xrange(array_len): 9 | even_table[i] = [] 10 | odd_table[i] = [] 11 | 12 | for a in xrange(array_len): 13 | for b in xrange(array_len): 14 | idx = int((a + b)
/ 2) 15 | if (a + b) % 2 == 0: 16 | even_table[idx].append((a, b)) 17 | even_table_size = even_table_size + 1 18 | else: 19 | odd_table[idx].append((a, b)) 20 | odd_table_size = odd_table_size + 1 21 | 22 | return (even_table, odd_table, even_table_size, odd_table_size) 23 | 24 | from collections import namedtuple 25 | 26 | #c = a*b + d 27 | AsmInsn = namedtuple("AsmInsn", "op_type gen_carry use_carry dest_op first_op second_op") 28 | op_mul = 0 29 | op_add = 1 30 | op_unknown = 2 31 | 32 | 33 | def find_largest_sublist(table): 34 | count = 0 35 | index = -1 36 | for j in sorted(table): 37 | if len(table[j]) > count: 38 | count = len(table[j]) 39 | index = j 40 | 41 | assert index >= 0, "Index should be greater than zero" 42 | return index 43 | 44 | 45 | def to_op(index, letter): 46 | return letter + str(index) 47 | 48 | 49 | def check_if_in_chain(table, table_size, index): 50 | if table_size == 0: 51 | return False 52 | 53 | new_index = find_largest_sublist(table) 54 | 55 | if (new_index == index + 1): 56 | return True 57 | elif ((new_index == index + 2) and len(table[index + 1]) > 0): 58 | return True 59 | else: 60 | return False 61 | 62 | 63 | def gen_asm_for_table(table, table_size, array_len, even_flag): 64 | 65 | AsmListing = [] 66 | main_reg = "r" if even_flag else "t" 67 | temp_reg = "t" if even_flag else "s" 68 | 69 | for i in xrange(array_len): 70 | if (len(table[i]) > 0): 71 | (a, b) = table[i][0] 72 | table[i].pop(0) 73 | 74 | insn = AsmInsn(op_mul, False, False, to_op(i, main_reg), to_op(a, "a"), to_op(b, "b")) 75 | AsmListing.append(insn) 76 | table_size = table_size - 1 77 | 78 | possible_overflow = [] 79 | for i in xrange(array_len): 80 | possible_overflow.append(False) 81 | 82 | while(table_size > 0): 83 | 84 | index = find_largest_sublist(table) 85 | 86 | (a, b) = table[index][0] 87 | table[index].pop(0) 88 | table_size = table_size - 1 89 | 90 | insn = AsmInsn(op_mul, False, False, to_op(index, temp_reg), to_op(a, "a"), to_op(b, "b")) 91 | AsmListing.append(insn) 92 | 93 | start_index = index 94 | 95 | #append all muls 96 | 97 | while (check_if_in_chain(table, table_size, index)): 98 | index = index + 1 99 | (a, b) = table[index][0] 100 | table[index].pop(0) 101 | table_size = table_size - 1 102 | 103 | insn = AsmInsn(op_mul, False, False, to_op(index, temp_reg), to_op(a, "a"), to_op(b, "b")) 104 | AsmListing.append(insn) 105 | 106 | #append all additions 107 | use_carry = False 108 | while (start_index <= index): 109 | 110 | insn = AsmInsn(op_add, True, use_carry, to_op(start_index, main_reg), to_op(start_index, main_reg), 111 | to_op(start_index, temp_reg)) 112 | possible_overflow[start_index] = True 113 | AsmListing.append(insn) 114 | use_carry = True 115 | start_index = start_index + 1 116 | 117 | #NB: this is a small hack 118 | 119 | if (not even_flag and start_index == array_len - 1): 120 | insn = AsmInsn(op_add, False, True, to_op(start_index, main_reg), "0", "0") 121 | AsmListing.append(insn) 122 | else: 123 | cycle_flag = True 124 | while(cycle_flag): 125 | gen_carry = (start_index < array_len - 1) and possible_overflow[start_index] 126 | insn = AsmInsn(op_add, gen_carry, True, to_op(start_index, main_reg), to_op(start_index, main_reg), "0") 127 | AsmListing.append(insn) 128 | 129 | cycle_flag = gen_carry 130 | start_index = start_index + 1 131 | 132 | use_carry = False 133 | 134 | return AsmListing 135 | 136 | def gen_asm(array_len): 137 | even_table, odd_table, even_table_size, odd_table_size = gen_tables(array_len) 138 | 139 | AsmListing = 
gen_asm_for_table(even_table, even_table_size, array_len, True) 140 | AsmListing += gen_asm_for_table(odd_table, odd_table_size, array_len, False) 141 | 142 | return AsmListing 143 | 144 | 145 | def generate_printable_asm(AsmListing): 146 | printed_asm = "" 147 | ending = ";\\n\\t\"\n" 148 | 149 | for elem in AsmListing: 150 | if (elem.op_type == op_mul): 151 | #"mul.wide.u16 s0, a1, b0;\n\t" 152 | 153 | printed_asm += "\"mul.wide.u16 " 154 | printed_asm += elem.dest_op + ", " + elem.first_op + ", " + elem.second_op + ending 155 | 156 | elif (elem.op_type == op_add): 157 | printed_asm += "\"add" 158 | if (elem.use_carry): 159 | printed_asm += "c" 160 | printed_asm += "." 161 | if (elem.gen_carry): 162 | printed_asm += "cc." 163 | printed_asm += "u32 " 164 | 165 | printed_asm += elem.dest_op + ", " + elem.first_op + ", " + elem.second_op + ending 166 | 167 | else: 168 | raise ValueError('Incorrect operand type.') 169 | 170 | return printed_asm 171 | 172 | 173 | ARR_LEN = 16 174 | AsmListing = gen_asm(ARR_LEN) 175 | print len(AsmListing) 176 | print generate_printable_asm(AsmListing) -------------------------------------------------------------------------------- /utilities/asm_generator.py: -------------------------------------------------------------------------------- 1 | def is_in_range(x, shift, array_len): 2 | flag = (x >= shift and x < array_len + shift) 3 | return flag 4 | 5 | def checks_passed(a, b, array_len, is_squaring): 6 | flag = is_in_range(a, 2 * array_len, array_len) and is_in_range(b, 3 * array_len, array_len) 7 | if is_squaring: 8 | flag = flag and (b - 3 * array_len > a - 2 * array_len) 9 | return flag 10 | 11 | def gen_sublists(idx, array_len, is_squaring): 12 | sublists = [] 13 | if (idx > 0): 14 | for j in xrange(min(idx, array_len)): 15 | a = idx - 1 - j + 2 * array_len 16 | b = j + 3 * array_len 17 | if checks_passed(a, b, array_len, is_squaring): 18 | sublists.append((a, b, True)) 19 | for j in xrange(min(idx, array_len) + 1): 20 | a = idx - j + 2 * array_len 21 | b = j + 3 * array_len 22 | if checks_passed(a, b, array_len, is_squaring): 23 | sublists.append((a, b, False)) 24 | return sublists 25 | 26 | from collections import namedtuple 27 | 28 | #c = a*b + d 29 | AsmInsn = namedtuple("AsmInsn", "gen_carry use_carry op_type is_high a b c d") 30 | op_mul = 0 31 | op_mad = 1 32 | op_add = 2 33 | op_unknown = 3 34 | 35 | def gen_table(array_len, squaring): 36 | table = {} 37 | table_len = 0 38 | for i in xrange(array_len * 2): 39 | arr = gen_sublists(i, array_len, squaring) 40 | table_len += len(arr) 41 | table[i] = arr 42 | return table, table_len 43 | 44 | def gen_asm(array_len, squaring = False): 45 | 46 | carry_arr = [False] * (2 * array_len) 47 | AsmListing = [] 48 | 49 | table, table_len = gen_table(array_len, squaring) 50 | 51 | lowest_index = 0 52 | cur_index = 0 53 | while not table[lowest_index]: 54 | lowest_index = lowest_index + 1 55 | cur_index = lowest_index 56 | 57 | use_carry = False 58 | while(table_len > 0 or use_carry): 59 | if cur_index >= 2 * array_len: 60 | use_carry = False 61 | 62 | if table_len == 0: 63 | break 64 | #try to find next suitable index 65 | while not table[lowest_index]: 66 | lowest_index = lowest_index + 1 67 | cur_index = lowest_index 68 | 69 | elif carry_arr[cur_index]: 70 | gen_carry = True 71 | 72 | if table[cur_index]: 73 | (a, b, is_high) = table[cur_index][0] 74 | table[cur_index].pop(0) 75 | table_len = table_len - 1 76 | op_type = op_mad 77 | else: 78 | (a, b, is_high) = (cur_index, cur_index, False) 79 | op_type = 
op_add 80 | 81 | insn = AsmInsn(gen_carry, use_carry, op_type, is_high, a, b, cur_index, cur_index) 82 | AsmListing.append(insn) 83 | 84 | use_carry = True 85 | cur_index = cur_index + 1 86 | 87 | else: 88 | with_addition = use_carry 89 | gen_carry = False 90 | 91 | (a, b, is_high) = table[cur_index][0] 92 | table[cur_index].pop(0) 93 | table_len = table_len - 1 94 | 95 | insn = AsmInsn(gen_carry, use_carry, with_addition, is_high, a, b, cur_index, -1) 96 | AsmListing.append(insn) 97 | 98 | carry_arr[cur_index] = True 99 | use_carry = False 100 | 101 | if table_len == 0: 102 | break 103 | #try to find next suitable index 104 | while not table[lowest_index]: 105 | lowest_index = lowest_index + 1 106 | cur_index = lowest_index 107 | 108 | return AsmListing 109 | 110 | 111 | def generate_printable_asm(AsmListing): 112 | printed_asm = "" 113 | #print len(AsmListing) 114 | for elem in AsmListing: 115 | high_low = "hi." if elem.is_high else "lo." 116 | if (elem.op_type == op_mul): 117 | printed_asm += "\"mul." + high_low + "u32"; 118 | printed_asm += ' %{:d}, %{:d}, %{:d};\\n\\t\"\n'.format(elem.c, elem.a, elem.b) 119 | elif (elem.op_type == op_mad): 120 | printed_asm += "\"mad" 121 | if (elem.use_carry): 122 | printed_asm += "c" 123 | printed_asm += "." + high_low 124 | if (elem.gen_carry): 125 | printed_asm += "cc." 126 | printed_asm += "u32" + ' %{:d}, %{:d}, %{:d}, '.format(elem.c, elem.a, elem.b) 127 | ending = "0;\\n\\t\"\n" if elem.d == -1 else '%{:d};\\n\\t\"\n'.format(elem.d) 128 | printed_asm += ending 129 | elif (elem.op_type == op_add): 130 | printed_asm += "\"add" 131 | if (elem.use_carry): 132 | printed_asm += "c" 133 | printed_asm += "." 134 | if (elem.gen_carry): 135 | printed_asm += "cc." 136 | printed_asm += "u32" + ' %{:d}, %{:d}, 0;\\n\\t\"\n'.format(elem.c, elem.a) 137 | else: 138 | raise ValueError('Incorrect operand type.') 139 | 140 | return printed_asm 141 | 142 | 143 | def generate_printable_asm_reg_squaring(AsmListing, arr_len): 144 | printed_asm = "" 145 | #print len(AsmListing) 146 | for elem in AsmListing: 147 | high_low = "hi." if elem.is_high else "lo." 148 | if (elem.op_type == op_mul): 149 | printed_asm += "\"mul." + high_low + "u32"; 150 | printed_asm += ' r{:d}, a{:d}, a{:d};\\n\\t\"\n'.format(elem.c, elem.a - arr_len * 2, elem.b - arr_len * 3) 151 | elif (elem.op_type == op_mad): 152 | printed_asm += "\"mad" 153 | if (elem.use_carry): 154 | printed_asm += "c" 155 | printed_asm += "." + high_low 156 | if (elem.gen_carry): 157 | printed_asm += "cc." 158 | printed_asm += "u32" + ' r{:d}, a{:d}, a{:d}, '.format(elem.c, elem.a - arr_len * 2, elem.b - arr_len * 3) 159 | ending = "0;\\n\\t\"\n" if elem.d == -1 else 'r{:d};\\n\\t\"\n'.format(elem.d) 160 | printed_asm += ending 161 | elif (elem.op_type == op_add): 162 | printed_asm += "\"add" 163 | if (elem.use_carry): 164 | printed_asm += "c" 165 | printed_asm += "." 166 | if (elem.gen_carry): 167 | printed_asm += "cc." 
168 | printed_asm += "u32" + ' r{:d}, r{:d}, 0;\\n\\t\"\n'.format(elem.c, elem.a) 169 | else: 170 | raise ValueError('Incorrect operand type.') 171 | 172 | return printed_asm 173 | 174 | 175 | 176 | 177 | ARR_LEN = 8 178 | AsmListing = gen_asm(ARR_LEN, True) 179 | print generate_printable_asm_reg_squaring(AsmListing, ARR_LEN) -------------------------------------------------------------------------------- /sources/finitefield.cu: -------------------------------------------------------------------------------- 1 | #include "cuda_structs.h" 2 | 3 | //This module contains functions required for finite field arithmetic 4 | //---------------------------------------------------------------------------------------------------------------------------------------------------- 5 | //---------------------------------------------------------------------------------------------------------------------------------------------------- 6 | //---------------------------------------------------------------------------------------------------------------------------------------------------- 7 | 8 | DEVICE_FUNC uint256_g FIELD_ADD_INV(const uint256_g& elem) 9 | { 10 | if (!is_zero(elem)) 11 | return SUB(BASE_FIELD_P, elem); 12 | else 13 | return elem; 14 | } 15 | 16 | DEVICE_FUNC uint256_g FIELD_ADD(const uint256_g& a, const uint256_g& b ) 17 | { 18 | uint256_g w = ADD(a, b); 19 | if (CMP(w, BASE_FIELD_P) >= 0) 20 | return SUB(w, BASE_FIELD_P); 21 | return w; 22 | } 23 | 24 | DEVICE_FUNC uint256_g FIELD_SUB(const uint256_g& a, const uint256_g& b) 25 | { 26 | if (CMP(a, b) >= 0) 27 | return SUB(a, b); 28 | else 29 | { 30 | uint256_g t = ADD(a, BASE_FIELD_P); 31 | return SUB(t, b); 32 | } 33 | } 34 | 35 | //We are using https://www.researchgate.net/publication/3387259_Improved_Montgomery_modular_inverse_algorithm (algorithm 5) 36 | //the description of The Almost Montgomery Inverse (so-called phase 1) is taken from 37 | //http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.75.8377&rep=rep1&type=pdf 38 | 39 | struct stage_one_data 40 | { 41 | uint256_g almost_mont_inverse; 42 | uint32_t k; 43 | }; 44 | 45 | #include <stdio.h> 46 | 47 | __device__ void print2_uint256(const uint256_g& val) 48 | { 49 | printf("%x %x %x %x %x %x %x %x\n", val.n[7], val.n[6], val.n[5], val.n[4], val.n[3], val.n[2], val.n[1], val.n[0]); 50 | } 51 | 52 | static DEVICE_FUNC inline stage_one_data stage_one_mul_inv(const uint256_g& elem) 53 | { 54 | uint256_g U = BASE_FIELD_P; 55 | uint256_g V = elem; 56 | uint256_g R = uint256_g{0, 0, 0, 0, 0, 0, 0, 0}; 57 | uint256_g S = uint256_g{1, 0, 0, 0, 0, 0, 0, 0}; 58 | 59 | uint32_t k = 0; 60 | 61 | while (!is_zero(V)) 62 | { 63 | if (is_even(U)) 64 | { 65 | U = SHIFT_RIGHT(U, 1); 66 | S = SHIFT_LEFT(S, 1); 67 | } 68 | else if (is_even(V)) 69 | { 70 | V = SHIFT_RIGHT(V, 1); 71 | R = SHIFT_LEFT(R, 1); 72 | } 73 | else if (CMP(U, V) > 0) 74 | { 75 | U = SHIFT_RIGHT(SUB(U, V), 1); 76 | R = ADD(R, S); 77 | S = SHIFT_LEFT(S, 1); 78 | } 79 | else 80 | { 81 | V = SHIFT_RIGHT(SUB(V, U), 1); 82 | S = ADD(R, S); 83 | R = SHIFT_LEFT(R, 1); 84 | } 85 | 86 | k++; 87 | } 88 | 89 | if (CMP(R, BASE_FIELD_P) >= 0) 90 | R = SUB(R, BASE_FIELD_P); 91 | 92 | R = SUB(BASE_FIELD_P, R); 93 | 94 | return stage_one_data{R, k}; 95 | } 96 | 97 | DEVICE_FUNC uint256_g FIELD_MUL_INV(const uint256_g& elem) 98 | { 99 | auto data = stage_one_mul_inv(elem); 100 | if (data.k == R_LOG) 101 | { 102 | return MONT_MUL(data.almost_mont_inverse, BASE_FIELD_R2); 103 | } 104 | else 105 | { 106 | uint32_t n = 2 * R_LOG - data.k; 107 | auto res
= uint256_g{0, 0, 0, 0, 0, 0, 0, 0}; 108 | 109 | if (n < R_LOG) 110 | { 111 | set_bit(res, n); 112 | res = MONT_MUL(data.almost_mont_inverse, res); 113 | } 114 | else if (n == R_LOG + 1) 115 | { 116 | res = MONT_MUL(data.almost_mont_inverse, BASE_FIELD_R2); 117 | } 118 | else 119 | { 120 | //here n == R_LOG + 2 121 | res = MONT_MUL(data.almost_mont_inverse, BASE_FIELD_R4); 122 | } 123 | 124 | return MONT_MUL(res, BASE_FIELD_R_SQUARED); 125 | } 126 | } 127 | 128 | //batch inversion - simultaneously (in place) invert all non-zero elements in the array. 129 | //NB: we assume that all elements in the array are non-zero 130 | 131 | 132 | DEVICE_FUNC void BATCH_FIELD_MUL_INV(uint256_g* vec, size_t vec_size) 133 | { 134 | 135 | } 136 | 137 | 138 | //this is a field embedded into a group of points on an elliptic curve 139 | 140 | 141 | DEVICE_FUNC embedded_field::embedded_field(const uint256_g rep): rep_(rep) {} 142 | 143 | DEVICE_FUNC embedded_field::embedded_field() {} 144 | 145 | DEVICE_FUNC bool embedded_field::operator==(const embedded_field& other) const 146 | { 147 | return EQUAL(rep_, other.rep_); 148 | } 149 | 150 | DEVICE_FUNC embedded_field embedded_field::zero() 151 | { 152 | uint256_g x; 153 | #pragma unroll 154 | for(uint32_t i = 0; i < N; i++) 155 | x.n[i] = 0; 156 | 157 | return embedded_field(x); 158 | } 159 | 160 | DEVICE_FUNC embedded_field embedded_field::one() 161 | { 162 | return embedded_field(EMBEDDED_FIELD_R); 163 | } 164 | 165 | DEVICE_FUNC bool embedded_field::operator!=(const embedded_field& other) const 166 | { 167 | return !EQUAL(rep_, other.rep_); 168 | } 169 | 170 | DEVICE_FUNC embedded_field::operator uint256_g() const 171 | { 172 | return rep_; 173 | } 174 | 175 | DEVICE_FUNC embedded_field embedded_field::operator-() const 176 | { 177 | if (!is_zero(rep_)) 178 | return embedded_field(SUB(EMBEDDED_FIELD_P, rep_)); 179 | else 180 | return *this; 181 | } 182 | 183 | //NB: for now we assume that the highest limb bit of the field modulus is zero 184 | DEVICE_FUNC embedded_field& embedded_field::operator+=(const embedded_field& other) 185 | { 186 | rep_ = ADD(rep_, other.rep_); 187 | if (CMP(rep_, EMBEDDED_FIELD_P) >= 0) 188 | rep_ = SUB(rep_, EMBEDDED_FIELD_P); 189 | return *this; 190 | } 191 | 192 | DEVICE_FUNC embedded_field& embedded_field::operator-=(const embedded_field& other) 193 | { 194 | if (CMP(rep_, other.rep_) >= 0) 195 | rep_ = SUB(rep_, other.rep_); 196 | else 197 | { 198 | uint256_g t = ADD(rep_, EMBEDDED_FIELD_P); 199 | rep_ = SUB(t, other.rep_); 200 | } 201 | return *this; 202 | } 203 | 204 | //here we mean Montgomery multiplication 205 | 206 | DEVICE_FUNC embedded_field& embedded_field::operator*=(const embedded_field& other) 207 | { 208 | uint256_g T; 209 | uint256_g u = rep_; 210 | uint256_g v = other.rep_; 211 | 212 | #pragma unroll 213 | for (uint32_t j = 0; j < N; j++) 214 | T.n[j] = 0; 215 | 216 | uint32_t prefix_low = 0, prefix_high = 0, m; 217 | uint32_t high_word, low_word; 218 | 219 | #pragma unroll 220 | for (uint32_t i = 0; i < N; i++) 221 | { 222 | uint32_t carry = 0; 223 | #pragma unroll 224 | for (uint32_t j = 0; j < N; j++) 225 | { 226 | low_word = device_long_mul(u.n[j], v.n[i], &high_word); 227 | low_word = device_fused_add(low_word, T.n[j], &high_word); 228 | low_word = device_fused_add(low_word, carry, &high_word); 229 | carry = high_word; 230 | T.n[j] = low_word; 231 | } 232 | 233 | //TODO: maybe we actually require less space?
(only one additional limb instead of two) 234 | prefix_high = 0; 235 | prefix_low = device_fused_add(prefix_low, carry, &prefix_high); 236 | 237 | m = T.n[0] * EMBEDDED_FIELD_N; 238 | low_word = device_long_mul(EMBEDDED_FIELD_P.n[0], m, &high_word); 239 | low_word = device_fused_add(low_word, T.n[0], &high_word); 240 | carry = high_word; 241 | 242 | #pragma unroll 243 | for (uint32_t j = 1; j < N; j++) 244 | { 245 | low_word = device_long_mul(EMBEDDED_FIELD_P.n[j], m, &high_word); 246 | low_word = device_fused_add(low_word, T.n[j], &high_word); 247 | low_word = device_fused_add(low_word, carry, &high_word); 248 | T.n[j-1] = low_word; 249 | carry = high_word; 250 | } 251 | 252 | T.n[N-1] = device_fused_add(prefix_low, carry, &prefix_high); 253 | prefix_low = prefix_high; 254 | } 255 | 256 | if (CMP(T, EMBEDDED_FIELD_P) >= 0) 257 | { 258 | //TODO: maybe better to change to an in-place version of SUB? 259 | T = SUB(T, EMBEDDED_FIELD_P); 260 | } 261 | 262 | rep_ = T; 263 | return *this; 264 | } 265 | 266 | 267 | DEVICE_FUNC embedded_field operator+(const embedded_field& left, const embedded_field& right) 268 | { 269 | embedded_field result(left); 270 | result += right; 271 | return result; 272 | } 273 | 274 | DEVICE_FUNC embedded_field operator-(const embedded_field& left, const embedded_field& right) 275 | { 276 | embedded_field result(left); 277 | result -= right; 278 | return result; 279 | } 280 | 281 | DEVICE_FUNC embedded_field operator*(const embedded_field& left, const embedded_field& right) 282 | { 283 | embedded_field result(left); 284 | result *= right; 285 | return result; 286 | } 287 | 288 | DEVICE_FUNC void gen_random_elem(embedded_field& x, curandState& state) 289 | { 290 | for (int i = 0; i < N; i++) 291 | { 292 | x.rep_.n[i] = curand(&state); 293 | } 294 | 295 | x.rep_.n[N - 1] >>= 3; 296 | } 297 | -------------------------------------------------------------------------------- /sources/ecc_exp.cu: -------------------------------------------------------------------------------- 1 | #include "cuda_structs.h" 2 | 3 | 4 | //classical double-and-add algorithm: 5 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 6 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 7 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 8 | 9 | //NB: maybe we can achieve an additional speedup by using special logic for += 10 | 11 | #define DOUBLE_AND_ADD_EXP(SUFFIX) \ 12 | DEVICE_FUNC ec_point ECC_double_and_add_exp##SUFFIX(const ec_point& pt, const uint256_g& power)\ 13 | {\ 14 | ec_point R = pt;\ 15 | ec_point Q = point_at_infty();\ 16 | \ 17 | for (size_t i = 0; i < N_BITLEN; i++)\ 18 | {\ 19 | bool flag = get_bit(power, i);\ 20 | if (flag)\ 21 | {\ 22 | Q = ECC_ADD##SUFFIX(Q, R);\ 23 | }\ 24 | R = ECC_DOUBLE##SUFFIX(R);\ 25 | }\ 26 | return Q;\ 27 | } 28 | 29 | DOUBLE_AND_ADD_EXP(_PROJ) 30 | DOUBLE_AND_ADD_EXP(_JAC) 31 | 32 | //algorithm with ternary expansion (TODO: have a look at Pomerance's prime numbers book) 33 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 34 |
//------------------------------------------------------------------------------------------------------------------------------------------------------ 35 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 36 | 37 | #define TERNARY_EXPANSION_EXP(SUFFIX) \ 38 | DEVICE_FUNC ec_point ECC_ternary_expansion_exp##SUFFIX(const ec_point& pt, const uint256_g& power)\ 39 | {\ 40 | ec_point R = pt;\ 41 | ec_point Q = point_at_infty();\ 42 | \ 43 | bool x = false;\ 44 | bool y = get_bit(power, 0);\ 45 | bool z;\ 46 | \ 47 | for (size_t i = 0; i < N_BITLEN; i++)\ 48 | {\ 49 | z = get_bit(power, i + 1);\ 50 | if (y)\ 51 | {\ 52 | if (x && !z)\ 53 | {\ 54 | y = 0;\ 55 | z = 1;\ 56 | }\ 57 | else if (!x)\ 58 | {\ 59 | ec_point temp = (z ? INV(R) : R);\ 60 | Q = ECC_ADD##SUFFIX(Q, temp);\ 61 | }\ 62 | }\ 63 | \ 64 | x = y;\ 65 | y = z;\ 66 | R = ECC_DOUBLE##SUFFIX(R);\ 67 | }\ 68 | \ 69 | if (y)\ 70 | Q = ECC_ADD##SUFFIX(Q, R);\ 71 | return Q;\ 72 | } 73 | 74 | TERNARY_EXPANSION_EXP(_PROJ) 75 | TERNARY_EXPANSION_EXP(_JAC) 76 | 77 | 78 | //Decreasing version of the double-and-add algorithm, in order to be able to use mixed addition 79 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 80 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 81 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 82 | 83 | #define DOUBLE_AND_ADD_AFFINE_EXP(SUFFIX) \ 84 | DEVICE_FUNC ec_point ECC_double_and_add_affine_exp##SUFFIX(const affine_point& pt, const uint256_g& power)\ 85 | {\ 86 | ec_point Q = point_at_infty();\ 87 | \ 88 | for (int i = N_BITLEN - 1; i >= 0; i--)\ 89 | {\ 90 | Q = ECC_DOUBLE##SUFFIX(Q);\ 91 | bool flag = get_bit(power, i);\ 92 | if (flag)\ 93 | {\ 94 | Q = ECC_ADD_MIXED##SUFFIX(Q, pt);\ 95 | }\ 96 | }\ 97 | return Q;\ 98 | } 99 | 100 | DOUBLE_AND_ADD_AFFINE_EXP(_PROJ) 101 | DOUBLE_AND_ADD_AFFINE_EXP(_JAC) 102 | 103 | 104 | //wNAF method 105 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 106 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 107 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 108 | 109 | //TODO: align this struct in order to require as little space as possible 110 | struct wnaf_auxiliary_data 111 | { 112 | //NB: value is signed!
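//'value' is an odd signed digit with |value| < 2^(WINDOW_SIZE - 1);
//'gap' counts the doublings separating this digit from the previous (lower) one
//(see convert_to_non_adjacent_form and ECC_wNAF_exp below)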
113 | int8_t value; 114 | uint8_t gap; 115 | }; 116 | 117 | //NB: WINDOW_SIZE should be >= 4 118 | #define WINDOW_SIZE 4 119 | #define EXP2(w) (1 << w) 120 | #define EXP2_MINUS_1(w) (1 << (w - 1)) 121 | 122 | static constexpr uint32_t PRECOMPUTED_ARRAY_LEN = (1 << (WINDOW_SIZE - 2)); 123 | static constexpr uint32_t MAX_WNAF_DATA_ARRAY_LEN = 2 * (N_BITLEN / (WINDOW_SIZE + 1) ) + 1; 124 | 125 | //NB: we assume that power is nonzero here 126 | //returns the number of wnaf_auxiliary_data entries written to the form array 127 | //we also assume that the bit-length of w is less than the word size 128 | 129 | using clock_value_t = long long; 130 | 131 | __device__ void sleep(clock_value_t sleep_cycles) 132 | { 133 | clock_value_t start = clock64(); 134 | clock_value_t cycles_elapsed; 135 | do { cycles_elapsed = clock64() - start; } 136 | while (cycles_elapsed < sleep_cycles); 137 | } 138 | 139 | #include <stdio.h> 140 | 141 | __device__ void print_uint256(const uint256_g& val) 142 | { 143 | printf("%x %x %x %x %x %x %x %x\n", val.n[7], val.n[6], val.n[5], val.n[4], val.n[3], val.n[2], val.n[1], val.n[0]); 144 | } 145 | 146 | DEVICE_FUNC static inline uint32_t convert_to_non_adjacent_form(const uint256_g& power, wnaf_auxiliary_data* form) 147 | { 148 | uint32_t elem_count = 0; 149 | uint256_g d = power; 150 | uint8_t current_gap = 0; 151 | 152 | while (!is_zero(d)) 153 | { 154 | uint32_t pos = __ffs(d.n[0]); 155 | uint32_t shift; 156 | if (pos == 1) 157 | { 158 | int8_t val = d.n[0] & (EXP2(WINDOW_SIZE) - 1); 159 | if (val >= EXP2_MINUS_1(WINDOW_SIZE)) 160 | { 161 | val -= EXP2(WINDOW_SIZE); 162 | ADD_UINT(d, -val); 163 | } 164 | else 165 | { 166 | SUB_UINT(d, val); 167 | } 168 | 169 | form[elem_count++] = {val, current_gap}; 170 | current_gap = WINDOW_SIZE; 171 | shift = WINDOW_SIZE; 172 | } 173 | else 174 | { 175 | shift = min(pos - 1, 32); 176 | current_gap += shift; 177 | } 178 | 179 | d = SHIFT_RIGHT(d, shift); 180 | } 181 | 182 | //printf("%d : %d\n\n", elem_count, MAX_WNAF_DATA_ARRAY_LEN); 183 | // for(int i = 0; i < elem_count; i++) 184 | // { 185 | // printf("%i: %u\n", form[i].value, form[i].gap); 186 | // } 187 | 188 | return elem_count; 189 | } 190 | 191 | #define ECC_WNAF_EXP(SUFFIX) \ 192 | DEVICE_FUNC ec_point ECC_wNAF_exp##SUFFIX(const ec_point& pt, const uint256_g& power)\ 193 | {\ 194 | if (is_zero(power))\ 195 | return point_at_infty();\ 196 | \ 197 | ec_point precomputed[PRECOMPUTED_ARRAY_LEN];\ 198 | wnaf_auxiliary_data wnaf_arr[MAX_WNAF_DATA_ARRAY_LEN];\ 199 | \ 200 | ec_point pt_doubled = ECC_DOUBLE##SUFFIX(pt);\ 201 | precomputed[0] = pt;\ 202 | \ 203 | for (uint32_t i = 1; i < PRECOMPUTED_ARRAY_LEN; i++)\ 204 | {\ 205 | precomputed[i] = ECC_ADD##SUFFIX(precomputed[i-1], pt_doubled);\ 206 | }\ 207 | \ 208 | auto count = convert_to_non_adjacent_form(power, wnaf_arr);\ 209 | ec_point Q = point_at_infty();\ 210 | \ 211 | for (int j = count - 1; j >= 0; j--)\ 212 | {\ 213 | auto& wnaf_data = wnaf_arr[j];\ 214 | int8_t abs_offset;\ 215 | bool is_negative;\ 216 | if (wnaf_data.value >= 0)\ 217 | {\ 218 | abs_offset = wnaf_data.value;\ 219 | is_negative = false;\ 220 | }\ 221 | else\ 222 | {\ 223 | abs_offset = -wnaf_data.value;\ 224 | is_negative = true;\ 225 | }\ 226 | \ 227 | ec_point temp = precomputed[(abs_offset - 1) / 2];\ 228 | if (is_negative)\ 229 | temp = INV(temp);\ 230 | \ 231 | Q = ECC_ADD##SUFFIX(Q, temp);\ 232 | \ 233 | for(uint8_t k = 0; k < wnaf_data.gap; k++)\ 234 | Q = ECC_DOUBLE##SUFFIX(Q);\ 235 | }\ 236 | \ 237 | return Q;\ 238 | } 239 | 240 | ECC_WNAF_EXP(_PROJ) 241 | ECC_WNAF_EXP(_JAC) 242 | 243 |
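The recoding produced by convert_to_non_adjacent_form is easy to cross-check on the host. Below is a minimal Python sketch in the spirit of the utilities/ scripts (a hypothetical helper, not part of the repository): wnaf_encode mirrors the device routine with w playing the role of WINDOW_SIZE, and wnaf_decode replays the digit/gap stream exactly the way ECC_wNAF_exp consumes it (add the digit, then double 'gap' times).

def wnaf_encode(k, w = 4):
    # digits are odd with |digit| < 2^(w-1); each entry is (digit, gap),
    # where 'gap' is the number of doublings sitting below this digit
    form = []
    gap = 0
    while k > 0:
        if k & 1:
            digit = k & ((1 << w) - 1)
            if digit >= 1 << (w - 1):
                digit -= 1 << w
            k -= digit
            form.append((digit, gap))
            gap = w
            k >>= w
        else:
            k >>= 1
            gap += 1
    return form

def wnaf_decode(form):
    # replay most-significant digit first, as ECC_wNAF_exp does
    k = 0
    for digit, gap in reversed(form):
        k = (k + digit) << gap
    return k

assert wnaf_decode(wnaf_encode(0x30644e72e131a029)) == 0x30644e72e131a029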
244 | -------------------------------------------------------------------------------- /sources/square_256_to_512.cu: -------------------------------------------------------------------------------- 1 | #include "cuda_structs.h" 2 | 3 | //we use (The Yang–Hseih–Laih Algorithm) described in 4 | //https://www.sciencedirect.com/science/article/pii/S0898122109000509 5 | 6 | DEVICE_FUNC uint512_g square_uint256_to_512_naive(const uint256_g& u) 7 | { 8 | uint512_g w; 9 | #pragma unroll 10 | for (uint32_t j = 0; j < N_DOUBLED; j++) 11 | w.n[j] = 0; 12 | 13 | uint32_t k, temp, temp2; 14 | 15 | #pragma unroll 16 | for (uint32_t i = 0; i < N; i++) 17 | { 18 | k=0; 19 | #pragma unroll 20 | for (uint32_t j = i + 1; j < N; j++) 21 | { 22 | uint32_t high_word = 0; 23 | uint32_t low_word = 0; 24 | low_word = device_long_mul(u.n[i], u.n[j], &high_word); 25 | low_word = device_fused_add(low_word, w.n[i + j], &high_word); 26 | low_word = device_fused_add(low_word, k, &high_word); 27 | k = high_word; 28 | w.n[i + j] = low_word; 29 | } 30 | 31 | w.n[N + i] = k; 32 | } 33 | 34 | k = 0; 35 | temp = 0; 36 | #pragma unroll 37 | for (uint32_t i = 0; i < N_DOUBLED; i++) 38 | { 39 | temp2 = w.n[i] >> 31; 40 | w.n[i] <<= 1; 41 | w.n[i] += temp; 42 | temp = temp2; 43 | } 44 | 45 | #pragma unroll 46 | for (uint32_t i = 0; i < N; i++) 47 | { 48 | uint32_t high_word = 0; 49 | uint32_t low_word = 0; 50 | low_word = device_long_mul(u.n[i], u.n[i], &high_word); 51 | low_word = device_fused_add(low_word, w.n[i + i], &high_word); 52 | low_word = device_fused_add(low_word, k, &high_word); 53 | w.n[i + i] = low_word; 54 | k = 0; 55 | w.n[i+i+1] = device_fused_add(w.n[i+i+1], high_word, &k); 56 | } 57 | 58 | return w; 59 | } 60 | 61 | DEVICE_FUNC uint512_g square_uint256_to_512_asm(const uint256_g& u) 62 | { 63 | uint512_g w; 64 | 65 | asm ( ".reg .u32 r0, r1, r2, r3, r4, r5, r6, r7, r8;\n\t" 66 | ".reg .u32 r9, r10, r11, r12, r13, r14, r15;\n\t" 67 | ".reg .u32 a0, a1, a2, a3, a4, a5, a6, a7;\n\t" 68 | #if (__CUDA_ARCH__ < 500) 69 | ".reg .u32 temp;\n\t" 70 | #endif 71 | //unpacking operands 72 | "mov.b64 {a0,a1}, %8;\n\t" 73 | "mov.b64 {a2,a3}, %9;\n\t" 74 | "mov.b64 {a4,a5}, %10;\n\t" 75 | "mov.b64 {a6,a7}, %11;\n\t" 76 | // multiplication - first stage 77 | "mul.lo.u32 r1, a0, a1;\n\t" 78 | "mul.hi.u32 r2, a0, a1;\n\t" 79 | "mad.lo.cc.u32 r2, a0, a2, r2;\n\t" 80 | "madc.hi.u32 r3, a0, a2, 0;\n\t" 81 | "mad.lo.cc.u32 r3, a1, a2, r3;\n\t" 82 | "madc.hi.u32 r4, a1, a2, 0;\n\t" 83 | "mad.lo.cc.u32 r3, a0, a3, r3;\n\t" 84 | "madc.hi.cc.u32 r4, a0, a3, r4;\n\t" 85 | "madc.hi.u32 r5, a1, a3, 0;\n\t" 86 | "mad.lo.cc.u32 r4, a1, a3, r4;\n\t" 87 | "madc.hi.cc.u32 r5, a0, a4, r5;\n\t" 88 | "madc.hi.u32 r6, a2, a3, 0;\n\t" 89 | "mad.lo.cc.u32 r4, a0, a4, r4;\n\t" 90 | "madc.lo.cc.u32 r5, a2, a3, r5;\n\t" 91 | "madc.hi.cc.u32 r6, a1, a4, r6;\n\t" 92 | "madc.hi.u32 r7, a2, a4, 0;\n\t" 93 | "mad.lo.cc.u32 r5, a1, a4, r5;\n\t" 94 | "madc.hi.cc.u32 r6, a0, a5, r6;\n\t" 95 | "madc.hi.cc.u32 r7, a1, a5, r7;\n\t" 96 | "madc.hi.u32 r8, a3, a4, 0;\n\t" 97 | "mad.lo.cc.u32 r5, a0, a5, r5;\n\t" 98 | "madc.lo.cc.u32 r6, a2, a4, r6;\n\t" 99 | "madc.hi.cc.u32 r7, a0, a6, r7;\n\t" 100 | "madc.hi.cc.u32 r8, a2, a5, r8;\n\t" 101 | "madc.hi.u32 r9, a3, a5, 0;\n\t" 102 | "mad.lo.cc.u32 r6, a1, a5, r6;\n\t" 103 | "madc.lo.cc.u32 r7, a3, a4, r7;\n\t" 104 | "madc.hi.cc.u32 r8, a1, a6, r8;\n\t" 105 | "madc.hi.cc.u32 r9, a2, a6, r9;\n\t" 106 | "madc.hi.u32 r10, a4, a5, 0;\n\t" 107 | "mad.lo.cc.u32 r6, a0, a6, r6;\n\t" 108 | "madc.lo.cc.u32 r7, a2, a5, r7;\n\t" 
109 | "madc.hi.cc.u32 r8, a0, a7, r8;\n\t" 110 | "madc.hi.cc.u32 r9, a1, a7, r9;\n\t" 111 | "madc.hi.cc.u32 r10, a3, a6, r10;\n\t" 112 | "madc.hi.u32 r11, a4, a6, 0;\n\t" 113 | "mad.lo.cc.u32 r7, a1, a6, r7;\n\t" 114 | "madc.lo.cc.u32 r8, a3, a5, r8;\n\t" 115 | "madc.lo.cc.u32 r9, a4, a5, r9;\n\t" 116 | "madc.hi.cc.u32 r10, a2, a7, r10;\n\t" 117 | "madc.hi.cc.u32 r11, a3, a7, r11;\n\t" 118 | "madc.hi.u32 r12, a5, a6, 0;\n\t" 119 | "mad.lo.cc.u32 r7, a0, a7, r7;\n\t" 120 | "madc.lo.cc.u32 r8, a2, a6, r8;\n\t" 121 | "madc.lo.cc.u32 r9, a3, a6, r9;\n\t" 122 | "madc.lo.cc.u32 r10, a4, a6, r10;\n\t" 123 | "madc.lo.cc.u32 r11, a5, a6, r11;\n\t" 124 | "madc.hi.cc.u32 r12, a4, a7, r12;\n\t" 125 | "madc.hi.u32 r13, a5, a7, 0;\n\t" 126 | "mad.lo.cc.u32 r8, a1, a7, r8;\n\t" 127 | "madc.lo.cc.u32 r9, a2, a7, r9;\n\t" 128 | "madc.lo.cc.u32 r10, a3, a7, r10;\n\t" 129 | "madc.lo.cc.u32 r11, a4, a7, r11;\n\t" 130 | "madc.lo.cc.u32 r12, a5, a7, r12;\n\t" 131 | "madc.lo.cc.u32 r13, a6, a7, r13;\n\t" 132 | "madc.hi.u32 r14, a6, a7, 0;\n\t" 133 | //shifting 134 | #if (__CUDA_ARCH__ >= 500) 135 | "shf.l.clamp.b32 r15, r14, r15, 1;\n\t" 136 | "shf.l.clamp.b32 r14, r13, r14, 1;\n\t" 137 | "shf.l.clamp.b32 r13, r12, r13, 1;\n\t" 138 | "shf.l.clamp.b32 r12, r11, r12, 1;\n\t" 139 | "shf.l.clamp.b32 r11, r10, r11, 1;\n\t" 140 | "shf.l.clamp.b32 r10, r9, r10, 1;\n\t" 141 | "shf.l.clamp.b32 r9, r8, r9, 1;\n\t" 142 | "shf.l.clamp.b32 r8, r7, r8, 1;\n\t" 143 | "shf.l.clamp.b32 r7, r6, r7, 1;\n\t" 144 | "shf.l.clamp.b32 r6, r5, r6, 1;\n\t" 145 | "shf.l.clamp.b32 r5, r4, r5, 1;\n\t" 146 | "shf.l.clamp.b32 r4, r3, r4, 1;\n\t" 147 | "shf.l.clamp.b32 r3, r2, r3, 1;\n\t" 148 | "shf.l.clamp.b32 r2, r1, r2, 1;\n\t" 149 | "shl.b32 r1, r1, 1;\n\t" 150 | #else 151 | "shr.b32 r15, r14, 31;\n\t" 152 | "shl.b32 r14, r14, 1;\n\t" 153 | "shr.b32 temp, r13, 31;\n\t" 154 | "or.b32 r14, r14, temp;\n\t" 155 | "shl.b32 r13, r13, 1;\n\t" 156 | "shr.b32 temp, r12, 31;\n\t" 157 | "or.b32 r13, r13, temp;\n\t" 158 | "shl.b32 r12, r12, 1;\n\t" 159 | "shr.b32 temp, r11, 31;\n\t" 160 | "or.b32 r12, r12, temp;\n\t" 161 | "shl.b32 r11, r11, 1;\n\t" 162 | "shr.b32 temp, r10, 31;\n\t" 163 | "or.b32 r11, r11, temp;\n\t" 164 | "shl.b32 r10, r10, 1;\n\t" 165 | "shr.b32 temp, r9, 31;\n\t" 166 | "or.b32 r10, r10, temp;\n\t" 167 | "shl.b32 r9, r9, 1;\n\t" 168 | "shr.b32 temp, r8, 31;\n\t" 169 | "or.b32 r9, r9, temp;\n\t" 170 | "shl.b32 r8, r8, 1;\n\t" 171 | "shr.b32 temp, r7, 31;\n\t" 172 | "or.b32 r8, r8, temp;\n\t" 173 | "shl.b32 r7, r7, 1;\n\t" 174 | "shr.b32 temp, r6, 31;\n\t" 175 | "or.b32 r7, r7, temp;\n\t" 176 | "shl.b32 r6, r6, 1;\n\t" 177 | "shr.b32 temp, r5, 31;\n\t" 178 | "or.b32 r6, r6, temp;\n\t" 179 | "shl.b32 r5, r5, 1;\n\t" 180 | "shr.b32 temp, r4, 31;\n\t" 181 | "or.b32 r5, r5, temp;\n\t" 182 | "shl.b32 r4, r4, 1;\n\t" 183 | "shr.b32 temp, r3, 31;\n\t" 184 | "or.b32 r4, r4, temp;\n\t" 185 | "shl.b32 r3, r3, 1;\n\t" 186 | "shr.b32 temp, r2, 31;\n\t" 187 | "or.b32 r3, r3, temp;\n\t" 188 | "shl.b32 r2, r2, 1;\n\t" 189 | "shr.b32 temp, r1, 31;\n\t" 190 | "or.b32 r2, r2, temp;\n\t" 191 | "shl.b32 r1, r1, 1;\n\t" 192 | #endif 193 | //final multiplication 194 | "mad.lo.cc.u32 r0, a0, a0, 0;\n\t" 195 | "madc.hi.cc.u32 r1, a0, a0, r1;\n\t" 196 | "madc.lo.cc.u32 r2, a1, a1, r2;\n\t" 197 | "madc.hi.cc.u32 r3, a1, a1, r3;\n\t" 198 | "madc.lo.cc.u32 r4, a2, a2, r4;\n\t" 199 | "madc.hi.cc.u32 r5, a2, a2, r5;\n\t" 200 | "madc.lo.cc.u32 r6, a3, a3, r6;\n\t" 201 | "madc.hi.cc.u32 r7, a3, a3, r7;\n\t" 202 | "madc.lo.cc.u32 r8, a4, a4, r8;\n\t" 203 | 
"madc.hi.cc.u32 r9, a4, a4, r9;\n\t" 204 | "madc.lo.cc.u32 r10, a5, a5, r10;\n\t" 205 | "madc.hi.cc.u32 r11, a5, a5, r11;\n\t" 206 | "madc.lo.cc.u32 r12, a6, a6, r12;\n\t" 207 | "madc.hi.cc.u32 r13, a6, a6, r13;\n\t" 208 | "madc.lo.cc.u32 r14, a7, a7, r14;\n\t" 209 | "madc.hi.cc.u32 r15, a7, a7, r15;\n\t" 210 | //packing result 211 | "mov.b64 %0, {r0,r1};\n\t" 212 | "mov.b64 %1, {r2,r3};\n\t" 213 | "mov.b64 %2, {r4,r5};\n\t" 214 | "mov.b64 %3, {r6,r7};\n\t" 215 | "mov.b64 %4, {r8,r9};\n\t" 216 | "mov.b64 %5, {r10,r11};\n\t" 217 | "mov.b64 %6, {r12,r13};\n\t" 218 | "mov.b64 %7, {r14,r15};\n\t" 219 | : "=l"(w.nn[0]), "=l"(w.nn[1]), "=l"(w.nn[2]), "=l"(w.nn[3]), 220 | "=l"(w.nn[4]), "=l"(w.nn[5]), "=l"(w.nn[6]), "=l"(w.nn[7]) 221 | : "l"(u.nn[0]), "l"(u.nn[1]), "l"(u.nn[2]), "l"(u.nn[3])); 222 | 223 | return w; 224 | } 225 | -------------------------------------------------------------------------------- /sources/test_framework.cu: -------------------------------------------------------------------------------- 1 | #include "cuda_structs.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #define CATCH_CONFIG_MAIN 13 | #include "catch.hpp" 14 | 15 | static constexpr size_t BYTES_PER_LIMB = 4; 16 | 17 | template 18 | T unpack_from_string(const std::string&) 19 | { 20 | static constexpr size_t chars_per_limb = 2 * BYTES_PER_LIMB; 21 | 22 | const size_t str_len = str.size(); 23 | 24 | size_t LIMB_COUNT = (std::is_same::value ? 8 : 16); 25 | 26 | assert(str_len <= 2 * bytes_per_limb * LIMB_COUNT); 27 | 28 | T res; 29 | for (size_t i = 0; i < LIMB_COUNT; i++) 30 | res.n[i] = 0; 31 | 32 | if (str_len == 0) 33 | return; 34 | 35 | boost::cnv::cstream ccnv; 36 | ccnv(std::hex)(std::skipws); 37 | 38 | size_t i = str_len; 39 | limb_index_t limb_index = 0; 40 | while (i > 0) 41 | { 42 | size_t j = (2 * bytes_per_limb > i ? 0 : i - 2 * bytes_per_limb); 43 | std::string_view str_view(str.c_str() + j, i - j); 44 | i -= (i > 2 * bytes_per_limb ? 
2 * bytes_per_limb : i); 45 | auto opt_val = boost::convert(str_view, ccnv); 46 | if (opt_val) 47 | res.n[limb_index++] = opt_val.get(); 48 | else 49 | throw std::runtime_error("Incorrect conversion"); 50 | } 51 | 52 | return res; 53 | } 54 | 55 | template<> 56 | ec_point unpack_from_string(const std::string& str) 57 | { 58 | std::vector strings; 59 | std::istringstream str_stream(str); 60 | std::string s; 61 | while (getline(str_stream, s, ',')) 62 | strings.push_back(s); 63 | 64 | assert(string.size() == 3); 65 | 66 | ec_point res; 67 | res.x = unpack_from_string(strings[0]); 68 | res.y = unpack_from_string(strings[1]); 69 | res.z = unpack_from_string(strings[2]); 70 | 71 | return res; 72 | } 73 | 74 | template 75 | using kernel_func_ptr = void (*)(Atype*, Btype*, Ctype*, size_t); 76 | 77 | template 78 | using kernel_func_vec_t = std::vector>>; 79 | 80 | template 81 | bool test_framework(kernel_func_vec_t func_vec, Atype* A_host_arr, Btype* B_host_arr, Ctype* C_host_arr 82 | std::vector& results, const std::vector& data, size_t bench_len) 83 | { 84 | Atype* A_dev_arr = nullptr; 85 | Btype* B_dev_arr = nullptr; 86 | Ctype* C_dev_arr = nullptr; 87 | 88 | bool is_successful = true; 89 | 90 | auto num_of_kernels = func_vec.size(); 91 | cudaError_t cudaStatus; 92 | 93 | //fill in host arrays 94 | size_t data_index = 0; 95 | for (size_t i = 0; i < bench_len; i++) 96 | { 97 | A_host_arr[i] = unpack_from_string(data[data_index++]); 98 | B_host_arr[i] = unpack_from_string(data[data_index++]); 99 | C_host_arr[i] = unpack_from_string(data[data_index++]); 100 | } 101 | 102 | cudaStatus = cudaMalloc(&A_dev_arr, bench_len * sizeof(Atype)); 103 | if (cudaStatus != cudaSuccess) 104 | { 105 | fprintf(stderr, "cudaMalloc (A_dev_arr) failed!\n"); 106 | is_successful = false; 107 | goto Error; 108 | } 109 | 110 | cudaStatus = cudaMalloc(&B_dev_arr, bench_len * sizeof(Btype)); 111 | if (cudaStatus != cudaSuccess) 112 | { 113 | fprintf(stderr, "cudaMalloc (B_dev_arr) failed!\n"); 114 | is_successful = false; 115 | goto Error; 116 | } 117 | 118 | cudaStatus = cudaMalloc(&C_dev_arr, bench_len * sizeof(Ctype)); 119 | if (cudaStatus != cudaSuccess) 120 | { 121 | fprintf(stderr, "cudaMalloc (C_dev_arr) failed!\n"); 122 | is_successful = false; 123 | goto Error; 124 | } 125 | 126 | cudaStatus = cudaMemcpy(A_dev_arr, A_host_arr, bench_len * sizeof(Atype), cudaMemcpyHostToDevice); 127 | if (cudaStatus != cudaSuccess) 128 | { 129 | fprintf(stderr, "cudaMemcpy (A_arrs) failed!\n"); 130 | is_successful = false; 131 | goto Error; 132 | } 133 | 134 | cudaStatus = cudaMemcpy(B_dev_arr, B_host_arr, bench_len * sizeof(Btype), cudaMemcpyHostToDevice); 135 | if (cudaStatus != cudaSuccess) 136 | { 137 | fprintf(stderr, "cudaMemcpy (B_arrs) failed!\n"); 138 | is_successful = false; 139 | goto Error; 140 | } 141 | 142 | //run_kernels! 
    //---------------------------------------------------------------------------------------------------------------------------------
    for (size_t i = 0; i < num_of_kernels; i++)
    {
        auto f = func_vec[i].second;
        auto message = func_vec[i].first;

        std::cout << "Launching kernel: " << message << std::endl;

        f(A_dev_arr, B_dev_arr, C_dev_arr, bench_len);

        // Check for any errors launching the kernel
        cudaStatus = cudaGetLastError();
        if (cudaStatus != cudaSuccess)
        {
            fprintf(stderr, "kernel launch failed: %s\n", cudaGetErrorString(cudaStatus));
            is_successful = false;
            goto Error;
        }

        // cudaDeviceSynchronize waits for the kernel to finish, and returns
        // any errors encountered during the launch.
        cudaStatus = cudaDeviceSynchronize();
        if (cudaStatus != cudaSuccess)
        {
            fprintf(stderr, "cudaDeviceSynchronize returned error code %d after launching the kernel!\n", cudaStatus);
            is_successful = false;
            goto Error;
        }

        cudaStatus = cudaMemcpy(results[i], C_dev_arr, bench_len * sizeof(Ctype), cudaMemcpyDeviceToHost);
        if (cudaStatus != cudaSuccess)
        {
            fprintf(stderr, "cudaMemcpy (C_arrs) failed!\n");
            is_successful = false;
            goto Error;
        }
    }

Error:
    cudaFree(A_dev_arr);
    cudaFree(B_dev_arr);
    cudaFree(C_dev_arr);

    return is_successful;
}

//------------------------------------------------------------------------------------------------------------------------------------------------------
//------------------------------------------------------------------------------------------------------------------------------------------------------
//------------------------------------------------------------------------------------------------------------------------------------------------------

//We require a bunch of comparison functions (they are all implemented in host_funcs.cpp)

bool equal_host(const uint256_g& a, const uint256_g& b);
bool equal_host(const uint512_g& a, const uint512_g& b);
bool equal_proj_host(const ec_point& a, const ec_point& b);
bool equal_jac_host(const ec_point& a, const ec_point& b);

//------------------------------------------------------------------------------------------------------------------------------------------------------
//------------------------------------------------------------------------------------------------------------------------------------------------------
//------------------------------------------------------------------------------------------------------------------------------------------------------

//check addition test

using general_func_vec_t = kernel_func_vec_t<uint256_g, uint256_g, uint256_g>;

void add_uint256_naive_driver(uint256_g*, uint256_g*, uint256_g*, size_t);
void add_uint256_asm_driver(uint256_g*, uint256_g*, uint256_g*, size_t);

general_func_vec_t addition_bench = {
    {"naive approach", add_uint256_naive_driver},
    {"asm", add_uint256_asm_driver}
};

static constexpr size_t FIXED_SIZE_ADD_TEST_SIZE = 6;
static const std::vector<std::string> FIXED_SIZE_ADD_DATA = {}; //TODO: test vectors still to be filled in

TEST_CASE( "fixed_size addition" , "[basic]" )
{
    const std::vector<std::string>& data = FIXED_SIZE_ADD_DATA;
    static constexpr size_t test_size = FIXED_SIZE_ADD_TEST_SIZE;

    using A_type = uint256_g;
    using B_type = uint256_g;
    using C_type = uint256_g;
    auto& func_vec = addition_bench;

    auto num_of_kernels = func_vec.size();

    A_type a_arr[test_size];
    B_type b_arr[test_size];
    C_type c_arr[test_size];

    std::vector<C_type*> results_ptr;
    std::vector<C_type> results(num_of_kernels * test_size);

    for (size_t i = 0; i < num_of_kernels; i++)
    {
        results_ptr.push_back(results.data() + i * test_size);
    }

    bool flag = test_framework(func_vec, a_arr, b_arr, c_arr, results_ptr, data, test_size);
    REQUIRE(flag);

    for (size_t i = 0; i < num_of_kernels; i++)
    {
        auto message = func_vec[i].first;
        INFO( "Checking kernel: " << message);

        for (size_t j = 0; j < test_size; j++)
        {
            CHECK(equal_host(results_ptr[i][j], c_arr[j]));
        }
    }
}

//------------------------------------------------------------------------------------------------------------------------------------------------------
//------------------------------------------------------------------------------------------------------------------------------------------------------
//------------------------------------------------------------------------------------------------------------------------------------------------------

-------------------------------------------------------------------------------- /sources/cuda_exports.cu: --------------------------------------------------------------------------------
#include "cuda_structs.h"
#include "cuda_export_headers.h"

struct geometry_local
{
    int gridSize;
    int blockSize;
};

template<typename T>
geometry_local find_suitable_geometry_local(T func, uint shared_memory_used, uint32_t smCount)
{
    int gridSize;
    int blockSize;
    int maxActiveBlocks;

    cudaOccupancyMaxPotentialBlockSize(&gridSize, &blockSize, func, shared_memory_used, 0);
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&maxActiveBlocks, func, blockSize, shared_memory_used);
    gridSize = maxActiveBlocks * smCount;

    return geometry_local{gridSize, blockSize};
}


__global__ void field_add_kernel(const embedded_field* a_arr, const embedded_field* b_arr, embedded_field* c_arr, size_t arr_len)
{
    size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
    while (tid < arr_len)
    {
        c_arr[tid] = a_arr[tid] + b_arr[tid];
        tid += blockDim.x * gridDim.x;
    }
}

__global__ void field_sub_kernel(const embedded_field* a_arr, const embedded_field* b_arr, embedded_field* c_arr, size_t arr_len)
{
    size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
    while (tid < arr_len)
    {
        c_arr[tid] = a_arr[tid] - b_arr[tid];
        tid += blockDim.x * gridDim.x;
    }
}

__global__ void field_mul_kernel(const embedded_field* a_arr, const embedded_field* b_arr, embedded_field* c_arr, size_t arr_len)
{
    size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
    while (tid < arr_len)
    {
        c_arr[tid] = a_arr[tid] * b_arr[tid];
        tid += blockDim.x * gridDim.x;
    }
}

using field_kernel_t = void(const embedded_field*, const embedded_field*, embedded_field*, size_t);


void field_func_invoke(const embedded_field* a_host_arr, const embedded_field* b_host_arr,
    embedded_field* c_host_arr, uint32_t arr_len,
    field_kernel_t func)
{
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    uint32_t smCount = prop.multiProcessorCount;

    geometry_local geometry = find_suitable_geometry_local(func, 0, smCount);

    embedded_field* a_dev_arr = nullptr;
    embedded_field* b_dev_arr = nullptr;
    embedded_field* c_dev_arr = nullptr;

    cudaMalloc((void **)&a_dev_arr, arr_len * sizeof(embedded_field));
    cudaMalloc((void **)&b_dev_arr, arr_len * sizeof(embedded_field));
    cudaMalloc((void **)&c_dev_arr, arr_len * sizeof(embedded_field));

    cudaMemcpy(a_dev_arr, a_host_arr, arr_len * sizeof(embedded_field), cudaMemcpyHostToDevice);
    cudaMemcpy(b_dev_arr, b_host_arr, arr_len * sizeof(embedded_field), cudaMemcpyHostToDevice);

    (*func)<<<geometry.gridSize, geometry.blockSize>>>(a_dev_arr, b_dev_arr, c_dev_arr, arr_len);

    cudaMemcpy(c_host_arr, c_dev_arr, arr_len * sizeof(embedded_field), cudaMemcpyDeviceToHost);

    cudaFree(a_dev_arr);
    cudaFree(b_dev_arr);
    cudaFree(c_dev_arr);
}

void field_add(const embedded_field* a_arr, const embedded_field* b_arr, embedded_field* c_arr, uint32_t arr_len)
{
    field_func_invoke(a_arr, b_arr, c_arr, arr_len, field_add_kernel);
}

void field_sub(const embedded_field* a_arr, const embedded_field* b_arr, embedded_field* c_arr, uint32_t arr_len)
{
    field_func_invoke(a_arr, b_arr, c_arr, arr_len, field_sub_kernel);
}

void field_mul(const embedded_field* a_arr, const embedded_field* b_arr, embedded_field* c_arr, uint32_t arr_len)
{
    field_func_invoke(a_arr, b_arr, c_arr, arr_len, field_mul_kernel);
}


__global__ void ec_add_kernel(const ec_point* a_arr, const ec_point* b_arr, ec_point* c_arr, uint32_t arr_len)
{
    size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
    while (tid < arr_len)
    {
        c_arr[tid] = ECC_ADD(a_arr[tid], b_arr[tid]);
        tid += blockDim.x * gridDim.x;
    }
}

__global__ void ec_sub_kernel(const ec_point* a_arr, const ec_point* b_arr, ec_point* c_arr, uint32_t arr_len)
{
    size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
    while (tid < arr_len)
    {
        c_arr[tid] = ECC_SUB(a_arr[tid], b_arr[tid]);
        tid += blockDim.x * gridDim.x;
    }
}

using ec_kernel_t = void(const ec_point*, const ec_point*, ec_point*, uint32_t);

void ec_func_invoke(const ec_point* a_host_arr, const ec_point* b_host_arr, ec_point* c_host_arr, uint32_t arr_len,
    ec_kernel_t func)
{
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    uint32_t smCount = prop.multiProcessorCount;

    geometry_local geometry = find_suitable_geometry_local(func, 0, smCount);

    ec_point* a_dev_arr = nullptr;
    ec_point* b_dev_arr = nullptr;
    ec_point* c_dev_arr = nullptr;

    cudaMalloc((void **)&a_dev_arr, arr_len * sizeof(ec_point));
    cudaMalloc((void **)&b_dev_arr, arr_len * sizeof(ec_point));
    cudaMalloc((void **)&c_dev_arr, arr_len * sizeof(ec_point));

    cudaMemcpy(a_dev_arr, a_host_arr, arr_len * sizeof(ec_point), cudaMemcpyHostToDevice);
    cudaMemcpy(b_dev_arr, b_host_arr, arr_len * sizeof(ec_point), cudaMemcpyHostToDevice);

    (*func)<<<geometry.gridSize, geometry.blockSize>>>(a_dev_arr, b_dev_arr, c_dev_arr, arr_len);

    cudaMemcpy(c_host_arr, c_dev_arr, arr_len * sizeof(ec_point), cudaMemcpyDeviceToHost);

    cudaFree(a_dev_arr);
    cudaFree(b_dev_arr);
    cudaFree(c_dev_arr);
}

void ec_point_add(ec_point* a_arr, ec_point* b_arr, ec_point* c_arr, uint32_t arr_len)
{
    ec_func_invoke(a_arr, b_arr, c_arr, arr_len, ec_add_kernel);
}

void ec_point_sub(ec_point* a_arr, ec_point* b_arr, ec_point* c_arr, uint32_t arr_len)
{
    ec_func_invoke(a_arr, b_arr, c_arr, arr_len, ec_sub_kernel);
}

//-----------------------------------------------------------------------------------------------------------------------------------------------
//Multiexponentiation (based on Pippenger realization)
//-----------------------------------------------------------------------------------------------------------------------------------------------

void large_Pippenger_driver(affine_point*, uint256_g*, ec_point*, size_t);

ec_point ec_multiexp(affine_point* points, uint256_g* powers, uint32_t arr_len)
{
    affine_point* dev_points = nullptr;
    uint256_g* dev_powers = nullptr;
    ec_point* dev_res = nullptr;

    ec_point res;

    cudaMalloc((void **)&dev_points, arr_len * sizeof(affine_point));
    cudaMalloc((void **)&dev_powers, arr_len * sizeof(uint256_g));
    cudaMalloc((void **)&dev_res, sizeof(ec_point));

    cudaMemcpy(dev_points, points, arr_len * sizeof(affine_point), cudaMemcpyHostToDevice);
    cudaMemcpy(dev_powers, powers, arr_len * sizeof(uint256_g), cudaMemcpyHostToDevice);

    large_Pippenger_driver(dev_points, dev_powers, dev_res, arr_len);

    cudaMemcpy(&res, dev_res, sizeof(ec_point), cudaMemcpyDeviceToHost);

    cudaFree(dev_points);
    cudaFree(dev_powers);
    cudaFree(dev_res);

    return res;
}

//-----------------------------------------------------------------------------------------------------------------------------------------------
//FFT routines
//-----------------------------------------------------------------------------------------------------------------------------------------------

void naive_fft_driver(embedded_field*, embedded_field*, uint32_t, bool);

//NB: kernel arguments are passed through constant memory anyway, so the multiplier may simply be passed by value
__global__ void mult_by_const_kernel(embedded_field* arr, size_t arr_len, const embedded_field elem)
{
    size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
    while (tid < arr_len)
    {
        arr[tid] *= elem;
        tid += blockDim.x * gridDim.x;
    }
}

void FFT_invoke(embedded_field* input_arr, embedded_field* output_arr, uint32_t arr_len, bool is_inverse, const embedded_field* inversed = nullptr)
{
    embedded_field* dev_input_arr = nullptr;
    embedded_field* dev_output_arr = nullptr;

    cudaMalloc((void **)&dev_input_arr, arr_len * sizeof(embedded_field));
    cudaMalloc((void **)&dev_output_arr, arr_len * sizeof(embedded_field));

    cudaMemcpy(dev_input_arr, input_arr, arr_len * sizeof(embedded_field), cudaMemcpyHostToDevice);
    naive_fft_driver(dev_input_arr, dev_output_arr, arr_len, is_inverse);

    if (is_inverse)
    {
        //scale the result by 1/n on the device before copying it back
        size_t block_size = 256;
        size_t grid_size = (arr_len + block_size - 1) / block_size;
        mult_by_const_kernel<<<grid_size, block_size>>>(dev_output_arr, arr_len, *inversed);
    }

    cudaMemcpy(output_arr, dev_output_arr, arr_len * sizeof(embedded_field), cudaMemcpyDeviceToHost);

    cudaFree(dev_input_arr);
    cudaFree(dev_output_arr);
}

void EXPORT FFT(embedded_field* input_arr,
    embedded_field* output_arr, uint32_t arr_len)
{
    FFT_invoke(input_arr, output_arr, arr_len, false);
}

void EXPORT iFFT(embedded_field* input_arr, embedded_field* output_arr, uint32_t arr_len, const embedded_field& n_inv)
{
    FFT_invoke(input_arr, output_arr, arr_len, true, &n_inv);
}

//------------------------------------------------------------------------------------------------------------------------------------------------
//polynomial arithmetic
//------------------------------------------------------------------------------------------------------------------------------------------------

// polynomial _polynomial_multiplication_on_fft(const polynomial&, const polynomial&);

// polynomial EXPORT poly_add(const polynomial&, const polynomial&);
// polynomial EXPORT poly_sub(const polynomial&, const polynomial&);
// polynomial EXPORT poly_mul(const polynomial&, const polynomial&);

-------------------------------------------------------------------------------- /sources/basic_arithmetic.cu: --------------------------------------------------------------------------------
1 | #include "cuda_structs.h"
2 | 
3 | //128 bit addition & subtraction:
4 | //------------------------------------------------------------------------------------------------------------------------------------------------------
5 | //------------------------------------------------------------------------------------------------------------------------------------------------------
6 | //------------------------------------------------------------------------------------------------------------------------------------------------------
7 | 
8 | //NB: https://devtalk.nvidia.com/default/topic/948014/forward-looking-gpu-integer-performance/?offset=14
9 | //It seems CUDA has no native support for 64-bit integer arithmetic (and 16-bit would obviously be far too slow)
10 | 
11 | //in order to implement Karatsuba multiplication we need addition with carry!
12 | //There is no way to read the carry flag directly from CUDA C,
13 | //so we fall back to inline PTX and let add.cc/addc chains propagate it for us.
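To make the carry-flag situation concrete, here is a minimal standalone sketch (not a file from this source tree, assuming only standard CUDA; the helper names are hypothetical) contrasting carry recovery in plain CUDA C with the add.cc/addc chain the asm routines below rely on:

#include <cstdint>

//plain CUDA C: the carry out of a 32-bit addition has to be reconstructed by a comparison
__device__ uint32_t add32_with_carry_out(uint32_t a, uint32_t b, uint32_t* carry_out)
{
    uint32_t s = a + b;
    *carry_out = (s < a); //wrap-around means a carry was produced
    return s;
}

//inline PTX: add.cc sets the hardware carry flag and addc consumes it,
//so multi-limb additions need no explicit comparisons at all
__device__ void add64_via_ptx(uint32_t a_lo, uint32_t a_hi, uint32_t b_lo, uint32_t b_hi,
    uint32_t* r_lo, uint32_t* r_hi)
{
    asm ( "add.cc.u32 %0, %2, %3;\n\t"
          "addc.u32 %1, %4, %5;\n\t"
          : "=r"(*r_lo), "=r"(*r_hi)
          : "r"(a_lo), "r"(b_lo), "r"(a_hi), "r"(b_hi));
}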
14 | 
15 | DEVICE_FUNC uint128_with_carry_g add_uint128_with_carry_asm(const uint128_g& lhs, const uint128_g& rhs)
16 | {
17 | uint128_with_carry_g result;
18 | asm ( "add.cc.u32 %0, %5, %9;\n\t"
19 | "addc.cc.u32 %1, %6, %10;\n\t"
20 | "addc.cc.u32 %2, %7, %11;\n\t"
21 | "addc.cc.u32 %3, %8, %12;\n\t"
22 | "addc.u32 %4, 0, 0;\n\t"
23 | : "=r"(result.val.n[0]), "=r"(result.val.n[1]), "=r"(result.val.n[2]), "=r"(result.val.n[3]), "=r"(result.carry)
24 | : "r"(lhs.n[0]), "r"(lhs.n[1]), "r"(lhs.n[2]), "r"(lhs.n[3]),
25 | "r"(rhs.n[0]), "r"(rhs.n[1]), "r"(rhs.n[2]), "r"(rhs.n[3]));
26 | 
27 | return result;
28 | }
29 | 
30 | 
31 | DEVICE_FUNC uint128_g sub_uint128_asm(const uint128_g& lhs, const uint128_g& rhs)
32 | {
33 | uint128_g result;
34 | asm ( "sub.cc.u32 %0, %4, %8;\n\t"
35 | "subc.cc.u32 %1, %5, %9;\n\t"
36 | "subc.cc.u32 %2, %6, %10;\n\t"
37 | "subc.u32 %3, %7, %11;\n\t"
38 | : "=r"(result.n[0]), "=r"(result.n[1]), "=r"(result.n[2]), "=r"(result.n[3])
39 | : "r"(lhs.n[0]), "r"(lhs.n[1]), "r"(lhs.n[2]), "r"(lhs.n[3]),
40 | "r"(rhs.n[0]), "r"(rhs.n[1]), "r"(rhs.n[2]), "r"(rhs.n[3]));
41 | 
42 | return result;
43 | }
44 | 
45 | //256 bit addition & subtraction
46 | //------------------------------------------------------------------------------------------------------------------------------------------------------
47 | //------------------------------------------------------------------------------------------------------------------------------------------------------
48 | //------------------------------------------------------------------------------------------------------------------------------------------------------
49 | 
50 | DEVICE_FUNC uint256_g add_uint256_naive(const uint256_g& lhs, const uint256_g& rhs)
51 | {
52 | uint32_t carry = 0;
53 | uint256_g result;
54 | #pragma unroll
55 | for (uint32_t i = 0; i < N; i++)
56 | {
57 | uint32_t tmp = rhs.n[i] + carry; result.n[i] = lhs.n[i] + tmp; //NB: rhs.n[i] + carry may wrap to zero
58 | carry = (tmp < carry) || (result.n[i] < tmp);
59 | }
60 | return result;
61 | }
62 | 
63 | DEVICE_FUNC uint256_g add_uint256_asm(const uint256_g& lhs, const uint256_g& rhs)
64 | {
65 | uint256_g result;
66 | asm ( "add.cc.u32 %0, %8, %16;\n\t"
67 | "addc.cc.u32 %1, %9, %17;\n\t"
68 | "addc.cc.u32 %2, %10, %18;\n\t"
69 | "addc.cc.u32 %3, %11, %19;\n\t"
70 | "addc.cc.u32 %4, %12, %20;\n\t"
71 | "addc.cc.u32 %5, %13, %21;\n\t"
72 | "addc.cc.u32 %6, %14, %22;\n\t"
73 | "addc.u32 %7, %15, %23;\n\t"
74 | : "=r"(result.n[0]), "=r"(result.n[1]), "=r"(result.n[2]), "=r"(result.n[3]),
75 | "=r"(result.n[4]), "=r"(result.n[5]), "=r"(result.n[6]), "=r"(result.n[7])
76 | : "r"(lhs.n[0]), "r"(lhs.n[1]), "r"(lhs.n[2]), "r"(lhs.n[3]),
77 | "r"(lhs.n[4]), "r"(lhs.n[5]), "r"(lhs.n[6]), "r"(lhs.n[7]),
78 | "r"(rhs.n[0]), "r"(rhs.n[1]), "r"(rhs.n[2]), "r"(rhs.n[3]),
79 | "r"(rhs.n[4]), "r"(rhs.n[5]), "r"(rhs.n[6]), "r"(rhs.n[7]));
80 | 
81 | return result;
82 | }
83 | 
84 | DEVICE_FUNC uint256_g sub_uint256_naive(const uint256_g& lhs, const uint256_g& rhs)
85 | {
86 | uint32_t borrow = 0;
87 | uint256_g result;
88 | 
89 | #pragma unroll
90 | for (uint32_t i = 0; i < N; i++)
91 | 
92 | {
93 | uint32_t a = lhs.n[i], b = rhs.n[i];
94 | result.n[i] = a - borrow;
95 | if (b == 0)
96 | {
97 | borrow = ( result.n[i] > a ? 1 : 0);
98 | }
99 | else
100 | {
101 | result.n[i] -= b;
102 | borrow = ( result.n[i] >= a ? 
1 : 0); 103 | } 104 | } 105 | 106 | return result; 107 | } 108 | 109 | DEVICE_FUNC uint256_g sub_uint256_asm(const uint256_g& lhs, const uint256_g& rhs) 110 | { 111 | uint256_g result; 112 | 113 | asm ( "sub.cc.u32 %0, %8, %16;\n\t" 114 | "subc.cc.u32 %1, %9, %17;\n\t" 115 | "subc.cc.u32 %2, %10, %18;\n\t" 116 | "subc.cc.u32 %3, %11, %19;\n\t" 117 | "subc.cc.u32 %4, %12, %20;\n\t" 118 | "subc.cc.u32 %5, %13, %21;\n\t" 119 | "subc.cc.u32 %6, %14, %22;\n\t" 120 | "subc.u32 %7, %15, %23;\n\t" 121 | : "=r"(result.n[0]), "=r"(result.n[1]), "=r"(result.n[2]), "=r"(result.n[3]), 122 | "=r"(result.n[4]), "=r"(result.n[5]), "=r"(result.n[6]), "=r"(result.n[7]) 123 | : "r"(lhs.n[0]), "r"(lhs.n[1]), "r"(lhs.n[2]), "r"(lhs.n[3]), 124 | "r"(lhs.n[4]), "r"(lhs.n[5]), "r"(lhs.n[6]), "r"(lhs.n[7]), 125 | "r"(rhs.n[0]), "r"(rhs.n[1]), "r"(rhs.n[2]), "r"(rhs.n[3]), 126 | "r"(rhs.n[4]), "r"(rhs.n[5]), "r"(rhs.n[6]), "r"(rhs.n[7])); 127 | 128 | return result; 129 | } 130 | 131 | //NB: here addition and substraction is done in place! 132 | 133 | DEVICE_FUNC void add_uint_uint256_asm(uint256_g& elem, uint32_t num) 134 | { 135 | asm( "add.cc.u32 %0, %0, %8;\n\t" 136 | "addc.cc.u32 %1, %1, 0;\n\t" 137 | "addc.cc.u32 %2, %2, 0;\n\t" 138 | "addc.cc.u32 %3, %3, 0;\n\t" 139 | "addc.cc.u32 %4, %4, 0;\n\t" 140 | "addc.cc.u32 %5, %5, 0;\n\t" 141 | "addc.cc.u32 %6, %6, 0;\n\t" 142 | "addc.u32 %7, %7, 0;\n\t" 143 | : "+r"(elem.n[0]), "+r"(elem.n[1]), "+r"(elem.n[2]), "+r"(elem.n[3]), 144 | "+r"(elem.n[4]), "+r"(elem.n[5]), "+r"(elem.n[6]), "+r"(elem.n[7]) 145 | : "r"(num)); 146 | } 147 | 148 | DEVICE_FUNC void sub_uint_uint256_asm(uint256_g& elem, uint32_t num) 149 | { 150 | asm( "sub.cc.u32 %0, %0, %8;\n\t" 151 | "subc.cc.u32 %1, %1, 0;\n\t" 152 | "subc.cc.u32 %2, %2, 0;\n\t" 153 | "subc.cc.u32 %3, %3, 0;\n\t" 154 | "subc.cc.u32 %4, %4, 0;\n\t" 155 | "subc.cc.u32 %5, %5, 0;\n\t" 156 | "subc.cc.u32 %6, %6, 0;\n\t" 157 | "subc.u32 %7, %7, 0;\n\t" 158 | : "+r"(elem.n[0]), "+r"(elem.n[1]), "+r"(elem.n[2]), "+r"(elem.n[3]), 159 | "+r"(elem.n[4]), "+r"(elem.n[5]), "+r"(elem.n[6]), "+r"(elem.n[7]) 160 | : "r"(num)); 161 | } 162 | 163 | //256 bit shifts 164 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 165 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 166 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 167 | 168 | DEVICE_FUNC uint256_g shift_right_asm(const uint256_g& elem, uint32_t shift) 169 | { 170 | uint256_g result; 171 | 172 | asm ( "\n\t" 173 | #if (__CUDA_ARCH__ >= 500) 174 | ".reg .u32 x;\n\t" 175 | "mov.u32 x, %16;\n\t" 176 | "shf.r.clamp.b32 %0, %8, %9, x;\n\t" 177 | "shf.r.clamp.b32 %1, %9, %10, x;\n\t" 178 | "shf.r.clamp.b32 %2, %10, %11, x;\n\t" 179 | "shf.r.clamp.b32 %3, %11, %12, x;\n\t" 180 | "shf.r.clamp.b32 %4, %12, %13, x;\n\t" 181 | "shf.r.clamp.b32 %5, %13, %14, x;\n\t" 182 | "shf.r.clamp.b32 %6, %14, %15, x;\n\t" 183 | "shr.b32 %7, %15, x;\n\t" 184 | #else 185 | //Not implemented yet 186 | #endif 187 | : "=r"(result.n[0]), "=r"(result.n[1]), "=r"(result.n[2]), "=r"(result.n[3]), 188 | "=r"(result.n[4]), "=r"(result.n[5]), "=r"(result.n[6]), "=r"(result.n[7]) 189 | : "r"(elem.n[0]), "r"(elem.n[1]), "r"(elem.n[2]), "r"(elem.n[3]), 190 | "r"(elem.n[4]), "r"(elem.n[5]), "r"(elem.n[6]), 
"r"(elem.n[7]), "r"(shift)); 191 | 192 | return result; 193 | } 194 | 195 | 196 | DEVICE_FUNC uint256_g shift_left_asm(const uint256_g& elem, uint32_t shift) 197 | { 198 | uint256_g result; 199 | 200 | asm ( "\n\t" 201 | #if (__CUDA_ARCH__ >= 500) 202 | ".reg .u32 x;\n\t" 203 | "mov.u32 x, %16;\n\t" 204 | "shf.l.clamp.b32 %7, %14, %15, x;\n\t" 205 | "shf.l.clamp.b32 %6, %13, %14, x;\n\t" 206 | "shf.l.clamp.b32 %5, %12, %13, x;\n\t" 207 | "shf.l.clamp.b32 %4, %11, %12, x;\n\t" 208 | "shf.l.clamp.b32 %3, %10, %11, x;\n\t" 209 | "shf.l.clamp.b32 %2, %9, %10, x;\n\t" 210 | "shf.l.clamp.b32 %1, %8, %9, x;\n\t" 211 | "shl.b32 %0, %8, x;\n\t" 212 | #else 213 | //Not implemented yet 214 | #endif 215 | : "=r"(result.n[0]), "=r"(result.n[1]), "=r"(result.n[2]), "=r"(result.n[3]), 216 | "=r"(result.n[4]), "=r"(result.n[5]), "=r"(result.n[6]), "=r"(result.n[7]) 217 | : "r"(elem.n[0]), "r"(elem.n[1]), "r"(elem.n[2]), "r"(elem.n[3]), 218 | "r"(elem.n[4]), "r"(elem.n[5]), "r"(elem.n[6]), "r"(elem.n[7]), "r"(shift)); 219 | 220 | return result; 221 | } 222 | 223 | 224 | //256 bit comparison and zero equality 225 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 226 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 227 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 228 | 229 | DEVICE_FUNC int cmp_uint256_naive(const uint256_g& lhs, const uint256_g& rhs) 230 | { 231 | #pragma unroll 232 | for (int32_t i = N -1 ; i >= 0; i--) 233 | { 234 | if (lhs.n[i] > rhs.n[i]) 235 | return 1; 236 | else if (lhs.n[i] < rhs.n[i]) 237 | return -1; 238 | } 239 | return 0; 240 | } 241 | 242 | DEVICE_FUNC bool is_zero(const uint256_g& x) 243 | { 244 | #pragma unroll 245 | for (int32_t i = N -1 ; i >= 0; i--) 246 | { 247 | if (x.n[i] != 0) 248 | return false; 249 | } 250 | return true; 251 | } 252 | 253 | DEVICE_FUNC bool is_even(const uint256_g& x) 254 | { 255 | return !CHECK_BIT(x.n[0], 0); 256 | } 257 | 258 | 259 | DEVICE_FUNC void gen_random_elem(uint256_g& x, curandState& state) 260 | { 261 | for (int i = 0; i < N; i++) 262 | { 263 | x.n[i] = curand(&state); 264 | } 265 | 266 | x.n[N - 1] >>= 3; 267 | } 268 | 269 | 270 | -------------------------------------------------------------------------------- /sources/mul_128_to_256.cu: -------------------------------------------------------------------------------- 1 | #include "cuda_structs.h" 2 | 3 | DEVICE_FUNC uint256_g mul_uint128_to_256_naive(const uint128_g& u, const uint128_g& v) 4 | { 5 | uint256_g w; 6 | 7 | #pragma unroll 8 | for (uint32_t j = 0; j < HALF_N; j++) 9 | { 10 | uint32_t k = 0; 11 | 12 | #pragma unroll 13 | for (uint32_t i = 0; i < HALF_N; i++) 14 | { 15 | uint32_t high_word = 0; 16 | uint32_t low_word = 0; 17 | low_word = device_long_mul(u.n[i], v.n[j], &high_word); 18 | low_word = device_fused_add(low_word, w.n[i + j], &high_word); 19 | low_word = device_fused_add(low_word, k, &high_word); 20 | k = high_word; 21 | w.n[i + j] = low_word; 22 | } 23 | 24 | w.n[HALF_N + j] = k; 25 | } 26 | 27 | return w; 28 | } 29 | 30 | //the following two samples of optimized asm multiplication code is taken from: 31 | //https://devtalk.nvidia.com/default/topic/1017754/long-integer-multiplication-mul-wide-u64-and-mul-wide-u128/ 32 | 33 | 34 | // multiply two 
unsigned 128-bit integers into an unsigned 256-bit product 35 | DEVICE_FUNC uint256_g mul_uint128_to_256_asm_ver1(const uint128_g& a, const uint128_g& b) 36 | { 37 | uint256_g res; 38 | asm ("{\n\t" 39 | ".reg .u32 r0, r1, r2, r3, r4, r5, r6, r7;\n\t" 40 | ".reg .u32 a0, a1, a2, a3, b0, b1, b2, b3;\n\t" 41 | "mov.b64 {a0,a1}, %4;\n\t" 42 | "mov.b64 {a2,a3}, %5;\n\t" 43 | "mov.b64 {b0,b1}, %6;\n\t" 44 | "mov.b64 {b2,b3}, %7;\n\t" 45 | "mul.lo.u32 r0, a0, b0;\n\t" 46 | "mul.hi.u32 r1, a0, b0;\n\t" 47 | "mad.lo.cc.u32 r1, a0, b1, r1;\n\t" 48 | "madc.hi.u32 r2, a0, b1, 0;\n\t" 49 | "mad.lo.cc.u32 r1, a1, b0, r1;\n\t" 50 | "madc.hi.cc.u32 r2, a1, b0, r2;\n\t" 51 | "madc.hi.u32 r3, a0, b2, 0;\n\t" 52 | "mad.lo.cc.u32 r2, a0, b2, r2;\n\t" 53 | "madc.hi.cc.u32 r3, a1, b1, r3;\n\t" 54 | "madc.hi.u32 r4, a0, b3, 0;\n\t" 55 | "mad.lo.cc.u32 r2, a1, b1, r2;\n\t" 56 | "madc.hi.cc.u32 r3, a2, b0, r3;\n\t" 57 | "madc.hi.cc.u32 r4, a1, b2, r4;\n\t" 58 | "madc.hi.u32 r5, a1, b3, 0;\n\t" 59 | "mad.lo.cc.u32 r2, a2, b0, r2;\n\t" 60 | "madc.lo.cc.u32 r3, a0, b3, r3;\n\t" 61 | "madc.hi.cc.u32 r4, a2, b1, r4;\n\t" 62 | "madc.hi.cc.u32 r5, a2, b2, r5;\n\t" 63 | "madc.hi.u32 r6, a2, b3, 0;\n\t" 64 | "mad.lo.cc.u32 r3, a1, b2, r3;\n\t" 65 | "madc.hi.cc.u32 r4, a3, b0, r4;\n\t" 66 | "madc.hi.cc.u32 r5, a3, b1, r5;\n\t" 67 | "madc.hi.cc.u32 r6, a3, b2, r6;\n\t" 68 | "madc.hi.u32 r7, a3, b3, 0;\n\t" 69 | "mad.lo.cc.u32 r3, a2, b1, r3;\n\t" 70 | "madc.lo.cc.u32 r4, a1, b3, r4;\n\t" 71 | "madc.lo.cc.u32 r5, a2, b3, r5;\n\t" 72 | "madc.lo.cc.u32 r6, a3, b3, r6;\n\t" 73 | "addc.u32 r7, r7, 0;\n\t" 74 | "mad.lo.cc.u32 r3, a3, b0, r3;\n\t" 75 | "madc.lo.cc.u32 r4, a2, b2, r4;\n\t" 76 | "madc.lo.cc.u32 r5, a3, b2, r5;\n\t" 77 | "addc.cc.u32 r6, r6, 0;\n\t" 78 | "addc.u32 r7, r7, 0;\n\t" 79 | "mad.lo.cc.u32 r4, a3, b1, r4;\n\t" 80 | "addc.cc.u32 r5, r5, 0;\n\t" 81 | "addc.cc.u32 r6, r6, 0;\n\t" 82 | "addc.u32 r7, r7, 0;\n\t" 83 | "mov.b64 %0, {r0,r1};\n\t" 84 | "mov.b64 %1, {r2,r3};\n\t" 85 | "mov.b64 %2, {r4,r5};\n\t" 86 | "mov.b64 %3, {r6,r7};\n\t" 87 | "}" 88 | : "=l"(res.nn[0]), "=l"(res.nn[1]), "=l"(res.nn[2]), "=l"(res.nn[3]) 89 | : "l"(a.low), "l"(a.high), "l"(b.low), "l"(b.high)); 90 | 91 | return res; 92 | } 93 | 94 | //NB: I do not have enough CUDA capabilities to benchmark this implementation! 
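Hand-scheduled PTX like the version above (and the sm_50+ version below) is easy to break when editing, so a plain host-side reference is useful for spot checks. The following is a standalone sketch, not a file from this repo; the name mul_uint128_to_256_ref is hypothetical, and it assumes the limb convention used throughout these sources (n[0] is the least significant 32-bit limb):

#include <cstdint>

//schoolbook 128x128 -> 256 multiplication over 32-bit limbs with 64-bit intermediates;
//results can be compared limb-by-limb against the asm versions
void mul_uint128_to_256_ref(const uint32_t a[4], const uint32_t b[4], uint32_t w[8])
{
    for (int i = 0; i < 8; i++)
        w[i] = 0;

    for (int j = 0; j < 4; j++)
    {
        uint32_t k = 0;
        for (int i = 0; i < 4; i++)
        {
            //product + partial word + carry never overflows 64 bits:
            //(2^32-1)^2 + 2*(2^32-1) = 2^64 - 1
            uint64_t t = (uint64_t)a[i] * b[j] + w[i + j] + k;
            w[i + j] = (uint32_t)t;
            k = (uint32_t)(t >> 32);
        }
        w[4 + j] = k;
    }
}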
95 | 96 | #if (__CUDA_ARCH__ >= 500) 97 | DEVICE_FUNC uint256_g mul_uint128_to_256_asm_ver2(const uint128_g& a, const uint128_g& b) 98 | { 99 | uint256_g res; 100 | asm ("{\n\t" 101 | ".reg .u32 aa0, aa1, aa2, aa3, bb0, bb1, bb2, bb3;\n\t" 102 | ".reg .u32 r0, r1, r2, r3, r4, r5, r6, r7;\n\t" 103 | ".reg .u32 s0, s1, s2, s3, s4, s5, s6, s7;\n\t" 104 | ".reg .u32 t0, t1, t2, t3, t4, t5, t6, t7;\n\t" 105 | ".reg .u16 a0, a1, a2, a3, a4, a5, a6, a7;\n\t" 106 | ".reg .u16 b0, b1, b2, b3, b4, b5, b6, b7;\n\t" 107 | // unpack source operands 108 | "mov.b64 {aa0,aa1}, %4;\n\t" 109 | "mov.b64 {aa2,aa3}, %5;\n\t" 110 | "mov.b64 {bb0,bb1}, %6;\n\t" 111 | "mov.b64 {bb2,bb3}, %7;\n\t" 112 | "mov.b32 {a0,a1}, aa0;\n\t" 113 | "mov.b32 {a2,a3}, aa1;\n\t" 114 | "mov.b32 {a4,a5}, aa2;\n\t" 115 | "mov.b32 {a6,a7}, aa3;\n\t" 116 | "mov.b32 {b0,b1}, bb0;\n\t" 117 | "mov.b32 {b2,b3}, bb1;\n\t" 118 | "mov.b32 {b4,b5}, bb2;\n\t" 119 | "mov.b32 {b6,b7}, bb3;\n\t" 120 | // compute first partial sum 121 | "mul.wide.u16 r0, a0, b0;\n\t" 122 | "mul.wide.u16 r1, a0, b2;\n\t" 123 | "mul.wide.u16 r2, a0, b4;\n\t" 124 | "mul.wide.u16 r3, a0, b6;\n\t" 125 | "mul.wide.u16 r4, a1, b7;\n\t" 126 | "mul.wide.u16 r5, a3, b7;\n\t" 127 | "mul.wide.u16 r6, a5, b7;\n\t" 128 | "mul.wide.u16 r7, a7, b7;\n\t" 129 | "mul.wide.u16 t3, a1, b5;\n\t" 130 | "mul.wide.u16 t4, a2, b6;\n\t" 131 | "add.cc.u32 r3, r3, t3;\n\t" 132 | "addc.cc.u32 r4, r4, t4;\n\t" 133 | "addc.u32 r5, r5, 0;\n\t" 134 | "mul.wide.u16 t3, a2, b4;\n\t" 135 | "mul.wide.u16 t4, a3, b5;\n\t" 136 | "add.cc.u32 r3, r3, t3;\n\t" 137 | "addc.cc.u32 r4, r4, t4;\n\t" 138 | "addc.u32 r5, r5, 0;\n\t" 139 | "mul.wide.u16 t2, a1, b3;\n\t" 140 | "mul.wide.u16 t3, a3, b3;\n\t" 141 | "mul.wide.u16 t4, a4, b4;\n\t" 142 | "mul.wide.u16 t5, a4, b6;\n\t" 143 | "add.cc.u32 r2, r2, t2;\n\t" 144 | "addc.cc.u32 r3, r3, t3;\n\t" 145 | "addc.cc.u32 r4, r4, t4;\n\t" 146 | "addc.cc.u32 r5, r5, t5;\n\t" 147 | "addc.u32 r6, r6, 0;\n\t" 148 | "mul.wide.u16 t2, a2, b2;\n\t" 149 | "mul.wide.u16 t3, a4, b2;\n\t" 150 | "mul.wide.u16 t4, a5, b3;\n\t" 151 | "mul.wide.u16 t5, a5, b5;\n\t" 152 | "add.cc.u32 r2, r2, t2;\n\t" 153 | "addc.cc.u32 r3, r3, t3;\n\t" 154 | "addc.cc.u32 r4, r4, t4;\n\t" 155 | "addc.cc.u32 r5, r5, t5;\n\t" 156 | "addc.u32 r6, r6, 0;\n\t" 157 | "mul.wide.u16 t1, a1, b1;\n\t" 158 | "mul.wide.u16 t2, a3, b1;\n\t" 159 | "mul.wide.u16 t3, a5, b1;\n\t" 160 | "mul.wide.u16 t4, a6, b2;\n\t" 161 | "mul.wide.u16 t5, a6, b4;\n\t" 162 | "mul.wide.u16 t6, a6, b6;\n\t" 163 | "add.cc.u32 r1, r1, t1;\n\t" 164 | "addc.cc.u32 r2, r2, t2;\n\t" 165 | "addc.cc.u32 r3, r3, t3;\n\t" 166 | "addc.cc.u32 r4, r4, t4;\n\t" 167 | "addc.cc.u32 r5, r5, t5;\n\t" 168 | "addc.cc.u32 r6, r6, t6;\n\t" 169 | "addc.u32 r7, r7, 0;\n\t" 170 | "mul.wide.u16 t1, a2, b0;\n\t" 171 | "mul.wide.u16 t2, a4, b0;\n\t" 172 | "mul.wide.u16 t3, a6, b0;\n\t" 173 | "mul.wide.u16 t4, a7, b1;\n\t" 174 | "mul.wide.u16 t5, a7, b3;\n\t" 175 | "mul.wide.u16 t6, a7, b5;\n\t" 176 | "add.cc.u32 r1, r1, t1;\n\t" 177 | "addc.cc.u32 r2, r2, t2;\n\t" 178 | "addc.cc.u32 r3, r3, t3;\n\t" 179 | "addc.cc.u32 r4, r4, t4;\n\t" 180 | "addc.cc.u32 r5, r5, t5;\n\t" 181 | "addc.cc.u32 r6, r6, t6;\n\t" 182 | "addc.u32 r7, r7, 0;\n\t" 183 | // compute second partial sum 184 | "mul.wide.u16 t0, a0, b1;\n\t" 185 | "mul.wide.u16 t1, a0, b3;\n\t" 186 | "mul.wide.u16 t2, a0, b5;\n\t" 187 | "mul.wide.u16 t3, a0, b7;\n\t" 188 | "mul.wide.u16 t4, a2, b7;\n\t" 189 | "mul.wide.u16 t5, a4, b7;\n\t" 190 | "mul.wide.u16 t6, a6, b7;\n\t" 191 | "mul.wide.u16 s3, a1, b6;\n\t" 
192 | "add.cc.u32 t3, t3, s3;\n\t" 193 | "addc.u32 t4, t4, 0;\n\t" 194 | "mul.wide.u16 s3, a2, b5;\n\t" 195 | "add.cc.u32 t3, t3, s3;\n\t" 196 | "addc.u32 t4, t4, 0;\n\t" 197 | "mul.wide.u16 s2, a1, b4;\n\t" 198 | "mul.wide.u16 s3, a3, b4;\n\t" 199 | "mul.wide.u16 s4, a3, b6;\n\t" 200 | "add.cc.u32 t2, t2, s2;\n\t" 201 | "addc.cc.u32 t3, t3, s3;\n\t" 202 | "addc.cc.u32 t4, t4, s4;\n\t" 203 | "addc.u32 t5, t5, 0;\n\t" 204 | "mul.wide.u16 s2, a2, b3;\n\t" 205 | "mul.wide.u16 s3, a4, b3;\n\t" 206 | "mul.wide.u16 s4, a4, b5;\n\t" 207 | "add.cc.u32 t2, t2, s2;\n\t" 208 | "addc.cc.u32 t3, t3, s3;\n\t" 209 | "addc.cc.u32 t4, t4, s4;\n\t" 210 | "addc.u32 t5, t5, 0;\n\t" 211 | "mul.wide.u16 s1, a1, b2;\n\t" 212 | "mul.wide.u16 s2, a3, b2;\n\t" 213 | "mul.wide.u16 s3, a5, b2;\n\t" 214 | "mul.wide.u16 s4, a5, b4;\n\t" 215 | "mul.wide.u16 s5, a5, b6;\n\t" 216 | "add.cc.u32 t1, t1, s1;\n\t" 217 | "addc.cc.u32 t2, t2, s2;\n\t" 218 | "addc.cc.u32 t3, t3, s3;\n\t" 219 | "addc.cc.u32 t4, t4, s4;\n\t" 220 | "addc.cc.u32 t5, t5, s5;\n\t" 221 | "addc.u32 t6, t6, 0;\n\t" 222 | "mul.wide.u16 s1, a2, b1;\n\t" 223 | "mul.wide.u16 s2, a4, b1;\n\t" 224 | "mul.wide.u16 s3, a6, b1;\n\t" 225 | "mul.wide.u16 s4, a6, b3;\n\t" 226 | "mul.wide.u16 s5, a6, b5;\n\t" 227 | "add.cc.u32 t1, t1, s1;\n\t" 228 | "addc.cc.u32 t2, t2, s2;\n\t" 229 | "addc.cc.u32 t3, t3, s3;\n\t" 230 | "addc.cc.u32 t4, t4, s4;\n\t" 231 | "addc.cc.u32 t5, t5, s5;\n\t" 232 | "addc.u32 t6, t6, 0;\n\t" 233 | "mul.wide.u16 s0, a1, b0;\n\t" 234 | "mul.wide.u16 s1, a3, b0;\n\t" 235 | "mul.wide.u16 s2, a5, b0;\n\t" 236 | "mul.wide.u16 s3, a7, b0;\n\t" 237 | "mul.wide.u16 s4, a7, b2;\n\t" 238 | "mul.wide.u16 s5, a7, b4;\n\t" 239 | "mul.wide.u16 s6, a7, b6;\n\t" 240 | "add.cc.u32 t0, t0, s0;\n\t" 241 | "addc.cc.u32 t1, t1, s1;\n\t" 242 | "addc.cc.u32 t2, t2, s2;\n\t" 243 | "addc.cc.u32 t3, t3, s3;\n\t" 244 | "addc.cc.u32 t4, t4, s4;\n\t" 245 | "addc.cc.u32 t5, t5, s5;\n\t" 246 | "addc.cc.u32 t6, t6, s6;\n\t" 247 | "addc.u32 t7, 0, 0;\n\t" 248 | // offset second partial sum by 16 bits 249 | "shf.l.clamp.b32 s7, t6, t7, 16;\n\t" 250 | "shf.l.clamp.b32 s6, t5, t6, 16;\n\t" 251 | "shf.l.clamp.b32 s5, t4, t5, 16;\n\t" 252 | "shf.l.clamp.b32 s4, t3, t4, 16;\n\t" 253 | "shf.l.clamp.b32 s3, t2, t3, 16;\n\t" 254 | "shf.l.clamp.b32 s2, t1, t2, 16;\n\t" 255 | "shf.l.clamp.b32 s1, t0, t1, 16;\n\t" 256 | "shf.l.clamp.b32 s0, 0, t0, 16;\n\t" 257 | // add partial sums 258 | "add.cc.u32 r0, r0, s0;\n\t" 259 | "addc.cc.u32 r1, r1, s1;\n\t" 260 | "addc.cc.u32 r2, r2, s2;\n\t" 261 | "addc.cc.u32 r3, r3, s3;\n\t" 262 | "addc.cc.u32 r4, r4, s4;\n\t" 263 | "addc.cc.u32 r5, r5, s5;\n\t" 264 | "addc.cc.u32 r6, r6, s6;\n\t" 265 | "addc.u32 r7, r7, s7;\n\t" 266 | // pack up result 267 | "mov.b64 %0, {r0,r1};\n\t" 268 | "mov.b64 %1, {r2,r3};\n\t" 269 | "mov.b64 %2, {r4,r5};\n\t" 270 | "mov.b64 %3, {r6,r7};\n\t" 271 | "}" 272 | : "=l"(res.nn[0]), "=l"(res.nn[1]), "=l"(res.nn[2]), "=l"(res.nn[3]) 273 | : "l"(a.low), "l"(a.high), "l"(b.low), "l"(b.high)); 274 | 275 | return res; 276 | } 277 | #endif 278 | -------------------------------------------------------------------------------- /sources/host_funcs.cpp: -------------------------------------------------------------------------------- 1 | #include "cuda_structs.h" 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | static constexpr uint256_g MODULUS = { 11 | 0xd87cfd47, 12 | 0x3c208c16, 13 | 0x6871ca8d, 14 | 0x97816a91, 15 | 0x8181585d, 16 | 0xb85045b6, 17 | 0xe131a029, 18 | 0x30644e72 19 | }; 20 | 21 | 
static constexpr uint256_g R = 
22 | {
23 | 0xc58f0d9d,
24 | 0xd35d438d,
25 | 0xf5c70b3d,
26 | 0x0a78eb28,
27 | 0x7879462c,
28 | 0x666ea36f,
29 | 0x9a07df2f,
30 | 0xe0a77c1
31 | };
32 | 
33 | static constexpr uint32_t N_INV = 0xe4866389;
34 | 
35 | //this file primarily contains code for generating a random point on the BN-curve
36 | 
37 | inline uint256_g add_uint256_host(const uint256_g& lhs, const uint256_g& rhs)
38 | {
39 | uint256_g result;
40 | uint32_t carry = 0;
41 | 
42 | for (uint32_t i = 0; i < N; i++)
43 | {
44 | uint32_t tmp = rhs.n[i] + carry; result.n[i] = lhs.n[i] + tmp; //NB: rhs.n[i] + carry may wrap to zero
45 | carry = (tmp < carry) || (result.n[i] < tmp);
46 | }
47 | return result;
48 | }
49 | 
50 | 
51 | inline uint256_g sub_uint256_host(const uint256_g& lhs, const uint256_g& rhs)
52 | {
53 | uint32_t borrow = 0;
54 | uint256_g result;
55 | 
56 | for (uint32_t i = 0; i < N; i++)
57 | {
58 | uint32_t a = lhs.n[i], b = rhs.n[i];
59 | result.n[i] = a - borrow;
60 | if (b == 0)
61 | {
62 | borrow = ( result.n[i] > a ? 1 : 0);
63 | }
64 | else
65 | {
66 | result.n[i] -= b;
67 | borrow = ( result.n[i] >= a ? 1 : 0);
68 | }
69 | }
70 | 
71 | return result;
72 | }
73 | 
74 | 
75 | inline int cmp_uint256_host(const uint256_g& lhs, const uint256_g& rhs)
76 | {
77 | for (int32_t i = N -1 ; i >= 0; i--)
78 | {
79 | if (lhs.n[i] > rhs.n[i])
80 | return 1;
81 | else if (lhs.n[i] < rhs.n[i])
82 | return -1;
83 | }
84 | return 0;
85 | }
86 | 
87 | inline bool is_zero_host(const uint256_g& x)
88 | {
89 | for (int32_t i = 0 ; i < N; i++)
90 | {
91 | if (x.n[i] != 0)
92 | return false;
93 | }
94 | return true;
95 | }
96 | 
97 | 
98 | inline uint32_t host_long_mul(uint32_t x, uint32_t y, uint32_t* high_ptr)
99 | {
100 | uint64_t res = (uint64_t)x * (uint64_t)y;
101 | *high_ptr = (res >> 32);
102 | return res;
103 | }
104 | 
105 | inline uint32_t host_fused_add(uint32_t x, uint32_t y, uint32_t* high_ptr)
106 | {
107 | uint32_t z = x + y;
108 | if (z < x)
109 | (*high_ptr)++;
110 | return z;
111 | }
112 | 
113 | 
114 | inline uint256_g mont_mul_256_host(const uint256_g& u, const uint256_g& v)
115 | {
116 | uint256_g T;
117 | 
118 | for (uint32_t j = 0; j < N; j++)
119 | T.n[j] = 0;
120 | 
121 | uint32_t prefix_low = 0, prefix_high = 0, m;
122 | uint32_t high_word, low_word;
123 | 
124 | for (uint32_t i = 0; i < N; i++)
125 | {
126 | uint32_t carry = 0;
127 | for (uint32_t j = 0; j < N; j++)
128 | {
129 | low_word = host_long_mul(u.n[j], v.n[i], &high_word);
130 | low_word = host_fused_add(low_word, T.n[j], &high_word);
131 | low_word = host_fused_add(low_word, carry, &high_word);
132 | carry = high_word;
133 | T.n[j] = low_word;
134 | }
135 | 
136 | //TODO: maybe we actually require less space? (only one additional limb instead of two)
137 | prefix_high = 0;
138 | prefix_low = host_fused_add(prefix_low, carry, &prefix_high);
139 | 
140 | m = T.n[0] * N_INV;
141 | low_word = host_long_mul(MODULUS.n[0], m, &high_word);
142 | low_word = host_fused_add(low_word, T.n[0], &high_word);
143 | carry = high_word;
144 | 
145 | #pragma unroll
146 | for (uint32_t j = 1; j < N; j++)
147 | {
148 | low_word = host_long_mul(MODULUS.n[j], m, &high_word);
149 | low_word = host_fused_add(low_word, T.n[j], &high_word);
150 | low_word = host_fused_add(low_word, carry, &high_word);
151 | T.n[j-1] = low_word;
152 | carry = high_word;
153 | }
154 | 
155 | T.n[N-1] = host_fused_add(prefix_low, carry, &prefix_high);
156 | prefix_low = prefix_high;
157 | }
158 | 
159 | if (cmp_uint256_host(T, MODULUS) >= 0)
160 | {
161 | //TODO: maybe it is better to switch to an in-place version of sub? 
162 | T = sub_uint256_host(T, MODULUS); 163 | } 164 | 165 | return T; 166 | } 167 | 168 | //It's safe: cause we are going to use this class only on the host 169 | 170 | class Field 171 | { 172 | private: 173 | uint256_g rep_; 174 | public: 175 | static Field zero() 176 | { 177 | return Field(0); 178 | } 179 | 180 | Field(uint32_t n = 0) 181 | { 182 | for (size_t i = 1; i < N; i++) 183 | rep_.n[i] = 0; 184 | rep_.n[0] = n; 185 | } 186 | 187 | explicit Field(uint256_g n) 188 | { 189 | for (size_t i = 0; i < N; i++) 190 | rep_.n[i] = n.n[i]; 191 | } 192 | 193 | Field(const Field& other) = default; 194 | Field(Field&& other) = default; 195 | Field& operator=(const Field&) = default; 196 | Field& operator=(Field&&) = default; 197 | 198 | bool operator==(const Field& other) const 199 | { 200 | return cmp_uint256_host(rep_, other.rep_) == 0; 201 | } 202 | 203 | bool operator!=(const Field& other) const 204 | { 205 | return cmp_uint256_host(rep_, other.rep_) != 0; 206 | } 207 | 208 | Field operator-() 209 | { 210 | uint256_g ans = (is_zero_host(rep_) ? zero().rep_ : sub_uint256_host(MODULUS, rep_)); 211 | return Field(ans); 212 | } 213 | 214 | //NB: for now we assume that highest possible limb bit is zero for the field modulus 215 | Field& operator+=(const Field& other) 216 | { 217 | rep_ = add_uint256_host(rep_, other.rep_); 218 | if (cmp_uint256_host(rep_, MODULUS) >= 0) 219 | rep_ = sub_uint256_host(rep_, MODULUS); 220 | return *this; 221 | } 222 | 223 | Field& operator-=(const Field& other) 224 | { 225 | if (cmp_uint256_host(rep_, other.rep_) < 0) 226 | rep_ = add_uint256_host(rep_, MODULUS); 227 | rep_ = sub_uint256_host(rep_, other.rep_); 228 | return *this; 229 | } 230 | 231 | //here we mean montgomery multiplication 232 | Field& operator*=(const Field& other) 233 | { 234 | rep_ = mont_mul_256_host(rep_, other.rep_); 235 | return *this; 236 | } 237 | 238 | uint256_g get_raw_rep() const 239 | { 240 | return rep_; 241 | } 242 | 243 | friend Field operator+(const Field& left, const Field& right); 244 | friend Field operator-(const Field& left, const Field& right); 245 | friend Field operator*(const Field& left, const Field& right); 246 | 247 | friend std::ostream& operator<<(std::ostream& os, const Field& elem); 248 | }; 249 | 250 | Field operator+(const Field& left, const Field& right) 251 | { 252 | Field result(left); 253 | result += right; 254 | return result; 255 | } 256 | 257 | Field operator-(const Field& left, const Field& right) 258 | { 259 | Field result(left); 260 | result -= right; 261 | return result; 262 | } 263 | 264 | Field operator*(const Field& left, const Field& right) 265 | { 266 | Field result(left); 267 | result *= right; 268 | return result; 269 | } 270 | 271 | std::ostream& operator<<(std::ostream& os, const Field& elem) 272 | { 273 | os << "0x"; 274 | for (int i = 7; i >= 0; i--) 275 | { 276 | os << std::setfill('0') << std::hex << std::setw(8) << elem.rep_.n[i]; 277 | } 278 | return os; 279 | } 280 | 281 | 282 | inline bool get_bit_host(const uint256_g& x, size_t index) 283 | { 284 | auto num = x.n[index / 32]; 285 | auto pos = index % 32; 286 | return CHECK_BIT(num, pos); 287 | } 288 | 289 | inline Field exp_host(const Field& elem, const uint256_g& power) 290 | { 291 | Field S = elem; 292 | Field Q = Field(R); 293 | 294 | for (size_t i = 0; i < N_BITLEN; i++) 295 | { 296 | bool flag = get_bit_host(power, i); 297 | if (flag) 298 | { 299 | Q *= S; 300 | } 301 | 302 | S *= S; 303 | } 304 | return Q; 305 | } 306 | 307 | //We are not able to compile with C++ 17 standard 
308 | 
309 | struct none_t{};
310 | static constexpr none_t NONE_OPT;
311 | 
312 | template<typename T>
313 | class optional
314 | {
315 | private:
316 | bool flag_;
317 | T val_;
318 | 
319 | static_assert(std::is_default_constructible<T>::value, "Inner type of optional should be default constructible!");
320 | public:
321 | optional(const T& val): flag_(true), val_(val) {}
322 | optional(const none_t& none): flag_(false) {}
323 | optional(): flag_(false) {}
324 | 
325 | optional(const optional& other) = default;
326 | optional(optional&& other) = default;
327 | optional& operator=(const optional&) = default;
328 | optional& operator=(optional&&) = default;
329 | 
330 | operator bool() const
331 | {
332 | return flag_;
333 | }
334 | 
335 | const T& get_val() const
336 | {
337 | if (!flag_)
338 | throw std::runtime_error("Retrieving value of empty optional!");
339 | return val_;
340 | }
341 | };
342 | 
343 | //The following algorithm is taken from the 1st edition of
344 | //Jeffrey Hoffstein, Jill Pipher, J.H. Silverman - An introduction to mathematical cryptography
345 | //Proposition 2.27 on page 84
346 | //if p = 3 (mod 4), which is true for the base field of the BN256-curve, and x^2 = a is satisfiable, then
347 | //x = a ^ ((p + 1)/4)
348 | //NB: the equation x^2 = a may have no solutions at all, so after computing x we must check that it is indeed a solution
349 | //NB: MAGIC_POWER = (P+1)/4 is constant, so we are able to precompute it
350 | //NB: the magic constant should be given in standard form (i.e. NON MONTGOMERY)
351 | 
352 | static constexpr uint256_g MAGIC_CONSTANT = 
353 | {
354 | 0xb61f3f52,
355 | 0x4f082305,
356 | 0x5a1c72a3,
357 | 0x65e05aa4,
358 | 0xa0605617,
359 | 0x6e14116d,
360 | 0xb84c680a,
361 | 0xc19139c
362 | };
363 | 
364 | optional<Field> square_root_host(const Field& x)
365 | {
366 | Field candidate = exp_host(x, MAGIC_CONSTANT);
367 | 
368 | using X = optional<Field>;
369 | return (candidate * candidate == x ? X(candidate) : X(NONE_OPT));
370 | }
371 | 
372 | //NB: we don't need to check that our random point belongs to the right subgroup:
373 | //more precisely, to the one generated by G = [1, 2, 1] -
374 | //this is because the cofactor is 1!
375 | 
376 | //NB: for our elliptic curve a = 0, b = 3, but we are working in Montgomery form, so these coefficients are also taken in Montgomery form
377 | //NB: after getting x, y the coordinates of the point in projective form are [x, y, mont(1)], and the same holds for Jacobian!
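A quick toy illustration of the square-root criterion above (an editorial sketch with a tiny prime, not repo code): take p = 11, so p = 3 (mod 4) and (p + 1)/4 = 3. For a = 5 the candidate is 5^3 = 125 = 4 (mod 11), and 4^2 = 16 = 5 (mod 11), so 4 really is a square root; for the non-residue a = 2 the candidate 2^3 = 8 fails the check, since 8^2 = 64 = 9 != 2 (mod 11).

#include <cstdio>
#include <cstdint>

//exhaustive check of the p = 3 (mod 4) square-root rule for p = 11: candidate = a^((p+1)/4) = a^3
int main()
{
    const uint32_t p = 11;
    for (uint32_t a = 1; a < p; a++)
    {
        uint32_t x = (a * a % p) * a % p;  //a^3 mod p
        bool is_root = (x * x % p == a);   //the verification step is mandatory: a may be a non-residue
        printf("a = %2u: candidate = %2u (%s)\n", a, x, is_root ? "square root" : "non-residue");
    }
    return 0;
}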
378 | 
379 | static constexpr uint256_g A = {
380 | 0, 0, 0, 0, 0, 0, 0, 0
381 | };
382 | 
383 | static constexpr uint256_g B = 
384 | {
385 | 0x50ad28d7,
386 | 0x7a17caa9,
387 | 0xe15521b9,
388 | 0x1f6ac17a,
389 | 0x696bd284,
390 | 0x334bea4e,
391 | 0xce179d8e,
392 | 0x2a1f6744
393 | };
394 | 
395 | 
396 | Field get_random_field_elem()
397 | {
398 | uint256_g res;
399 | for (uint32_t i =0; i < N; i++)
400 | res.n[i] = rand();
401 | res.n[N - 1] &= 0x1fffffff;
402 | return Field(res);
403 | }
404 | 
405 | 
406 | ec_point get_random_point_host()
407 | {
408 | //equation in Weierstrass form: y^2 = x^3 + a * x + b
409 | //generate random x and compute the right hand side
410 | //if this is not a square - repeat, again and again, until we are successful
411 | Field x;
412 | optional<Field> y_opt;
413 | while (!y_opt)
414 | {
415 | x = get_random_field_elem();
416 | Field righthandside = x * x * x + Field(A) * x + Field(B);
417 | y_opt = square_root_host(righthandside);
418 | }
419 | 
420 | Field y = y_opt.get_val();
421 | 
422 | if (rand() % 2)
423 | y = -y;
424 | 
425 | return ec_point{x.get_raw_rep(), y.get_raw_rep(), R};
426 | }
427 | 
428 | bool equal_host(const uint256_g& a, const uint256_g& b)
429 | {
430 | for (int32_t i = 0 ; i < N; i++)
431 | {
432 | if (a.n[i] != b.n[i])
433 | return false;
434 | }
435 | return true;
436 | }
437 | 
438 | bool equal_host(const uint512_g& a, const uint512_g& b)
439 | {
440 | for (int32_t i = 0 ; i < N_DOUBLED; i++)
441 | {
442 | if (a.n[i] != b.n[i])
443 | return false;
444 | }
445 | return true;
446 | }
447 | 
448 | bool equal_proj_host(const ec_point& a, const ec_point& b)
449 | {
450 | auto x1 = Field(a.x);
451 | auto y1 = Field(a.y);
452 | auto z1 = Field(a.z);
453 | 
454 | auto x2 = Field(b.x);
455 | auto y2 = Field(b.y);
456 | auto z2 = Field(b.z);
457 | 
458 | bool first_comp = (x1 * y2 == x2 * y1);
459 | bool second_comp = (x1 * z2 == z1 * x2);
460 | bool third_comp = (y1 * z2 == z1 * y2);
461 | 
462 | return first_comp && second_comp && third_comp;
463 | }
464 | 
465 | bool is_infinity_host(const ec_point& point)
466 | {
467 | return is_zero_host(point.z);
468 | }
469 | 
470 | bool equal_jac_host(const ec_point& a, const ec_point& b)
471 | {
472 | if (is_infinity_host(a) ^ is_infinity_host(b))
473 | return false;
474 | if (is_infinity_host(a) && is_infinity_host(b))
475 | return true;
476 | 
477 | auto x1 = Field(a.x);
478 | auto y1 = Field(a.y);
479 | auto z1 = Field(a.z);
480 | 
481 | auto x2 = Field(b.x);
482 | auto y2 = Field(b.y);
483 | auto z2 = Field(b.z);
484 | 
485 | bool first_comp = (x1 * z2 * z2 == x2 * z1 * z1);
486 | bool second_comp = (y1 * z2 * z2 * z2 == y2 * z1 * z1 * z1);
487 | 
488 | return (first_comp && second_comp);
489 | }
-------------------------------------------------------------------------------- /sources/ell_point.cu: --------------------------------------------------------------------------------
1 | #include "cuda_structs.h"
2 | 
3 | //Arithmetic in projective coordinates (Jacobian coordinates should be faster and we are going to check it!)
4 | //TODO: we may also use BN specific optimizations (for example use, that a = 0) 5 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 6 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 7 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 8 | 9 | DEVICE_FUNC ec_point ECC_DOUBLE_PROJ(const ec_point& pt) 10 | { 11 | if (is_zero(pt.y) || is_infinity(pt)) 12 | return point_at_infty(); 13 | else 14 | { 15 | uint256_g temp, temp2; 16 | uint256_g W, S, B, H, S2; 17 | ec_point res; 18 | 19 | #ifdef BN256_SPECIFIC_OPTIMIZATION 20 | temp = MONT_SQUARE(pt.x); 21 | W = MONT_MUL(temp, R3_g); 22 | #else 23 | temp = MONT_SQUARE(pt.x); 24 | temp = MONT_MUL(temp, BASE_FIELD_R3); 25 | temp2 = MONT_SQUARE(pt.z); 26 | temp2 = MONT_MUL(temp2, CURVE_A_COEFF); 27 | W = FIELD_ADD(temp, temp2); 28 | #endif 29 | S = MONT_MUL(pt.y, pt.z); 30 | temp = MONT_MUL(pt.x, pt.y); 31 | B = MONT_MUL(temp, S); 32 | res.x = W; 33 | 34 | temp = MONT_SQUARE(W); 35 | temp2 = MONT_MUL(BASE_FIELD_R8, B); 36 | H = FIELD_SUB(temp, temp2); 37 | 38 | temp = MONT_MUL(BASE_FIELD_R2, H); 39 | res.x = MONT_MUL(temp, S); 40 | 41 | //NB: here result is also equal to one of the operands and hence may be reused!!! 42 | //NB: this is in fact another possibility for optimization! 43 | S2 = MONT_SQUARE(S); 44 | temp = MONT_MUL(BASE_FIELD_R4, B); 45 | temp = FIELD_SUB(temp, H); 46 | temp = MONT_MUL(W, temp); 47 | 48 | temp2 = MONT_SQUARE(pt.y); 49 | temp2 = MONT_MUL(BASE_FIELD_R8, temp2); 50 | temp2 = MONT_MUL(temp2, S2); 51 | res.y = FIELD_SUB(temp, temp2); 52 | 53 | temp = MONT_MUL(BASE_FIELD_R8, S); 54 | res.z = MONT_MUL(temp, S2); 55 | 56 | return res; 57 | } 58 | } 59 | 60 | //for debug purposes only: check if point is indeed on curve 61 | DEVICE_FUNC bool IS_ON_CURVE_PROJ(const ec_point& pt) 62 | { 63 | //y^{2} * z = x^{3} + A *x * z^{2} + B * z^{3} 64 | uint256_g temp1, temp2, z2; 65 | z2 = MONT_SQUARE(pt.z); 66 | temp1 = MONT_SQUARE(pt.x); 67 | temp1 = MONT_MUL(temp1, pt.x); 68 | temp2 = MONT_MUL(CURVE_A_COEFF, pt.x); 69 | temp2 = MONT_MUL(temp2, z2); 70 | temp1 = FIELD_ADD(temp1, temp2); 71 | temp2 = MONT_MUL(CURVE_B_COEFF, pt.z); 72 | temp2 = MONT_MUL(temp2, z2); 73 | temp1 = FIELD_ADD(temp1, temp2); 74 | temp2 = MONT_SQUARE(pt.y); 75 | temp2 = MONT_MUL(temp2, pt.z); 76 | 77 | return EQUAL(temp1, temp2); 78 | } 79 | 80 | DEVICE_FUNC bool EQUAL_PROJ(const ec_point& pt1, const ec_point& pt2) 81 | { 82 | //check all of the following equations: 83 | //X_1 * Y_2 = Y_1 * X_2; 84 | //X_1 * Z_2 = X_2 * Y_1; 85 | //Y_1 * Z_2 = Z_1 * Y_2; 86 | 87 | uint256_g temp1, temp2; 88 | 89 | temp1 = MONT_MUL(pt1.x, pt2.y); 90 | temp2 = MONT_MUL(pt1.y, pt2.x); 91 | bool first_check = EQUAL(temp1, temp2); 92 | 93 | temp1 = MONT_MUL(pt1.y, pt2.z); 94 | temp2 = MONT_MUL(pt1.z, pt2.y); 95 | bool second_check = EQUAL(temp1, temp2); 96 | 97 | temp1 = MONT_MUL(pt1.x, pt2.z); 98 | temp2 = MONT_MUL(pt1.z, pt2.x); 99 | bool third_check = EQUAL(temp1, temp2); 100 | 101 | return (first_check && second_check && third_check); 102 | } 103 | 104 | DEVICE_FUNC ec_point ECC_ADD_PROJ(const ec_point& left, const ec_point& right) 105 | { 106 | if (is_infinity(left)) 107 | return right; 108 | if (is_infinity(right)) 109 | return left; 110 | 111 | uint256_g U1, U2, V1, 
V2; 112 | U1 = MONT_MUL(left.z, right.y); 113 | U2 = MONT_MUL(left.y, right.z); 114 | V1 = MONT_MUL(left.z, right.x); 115 | V2 = MONT_MUL(left.x, right.z); 116 | 117 | ec_point res; 118 | 119 | if (EQUAL(V1, V2)) 120 | { 121 | if (!EQUAL(U1, U2)) 122 | return point_at_infty(); 123 | else 124 | return ECC_DOUBLE_PROJ(left); 125 | } 126 | 127 | uint256_g U = FIELD_SUB(U1, U2); 128 | uint256_g V = FIELD_SUB(V1, V2); 129 | uint256_g W = MONT_MUL(left.z, right.z); 130 | uint256_g Vsq = MONT_SQUARE(V); 131 | uint256_g Vcube = MONT_MUL(Vsq, V); 132 | 133 | uint256_g temp1, temp2; 134 | temp1 = MONT_SQUARE(U); 135 | temp1 = MONT_MUL(temp1, W); 136 | temp1 = FIELD_SUB(temp1, Vcube); 137 | temp2 = MONT_MUL(BASE_FIELD_R2, Vsq); 138 | temp2 = MONT_MUL(temp2, V2); 139 | uint256_g A = FIELD_SUB(temp1, temp2); 140 | res.x = MONT_MUL(V, A); 141 | 142 | temp1 = MONT_MUL(Vsq, V2); 143 | temp1 = FIELD_SUB(temp1, A); 144 | temp1 = MONT_MUL(U, temp1); 145 | temp2 = MONT_MUL(Vcube, U2); 146 | res.y = FIELD_SUB(temp1, temp2); 147 | 148 | res.z = MONT_MUL(Vcube, W); 149 | return res; 150 | } 151 | 152 | DEVICE_FUNC ec_point ECC_SUB_PROJ(const ec_point& left, const ec_point& right) 153 | { 154 | return ECC_ADD_PROJ(left, INV(right)); 155 | } 156 | 157 | DEVICE_FUNC ec_point ECC_ADD_MIXED_PROJ(const ec_point& left, const affine_point& right) 158 | { 159 | if (is_infinity(left)) 160 | return ec_point{right.x, right.y, BASE_FIELD_R}; 161 | 162 | uint256_g U1, V1; 163 | U1 = MONT_MUL(left.z, right.y); 164 | V1 = MONT_MUL(left.z, right.x); 165 | 166 | ec_point res; 167 | 168 | if (EQUAL(V1, left.x)) 169 | { 170 | if (!EQUAL(U1, left.y)) 171 | return point_at_infty(); 172 | else 173 | return ECC_DOUBLE_PROJ(left); 174 | } 175 | 176 | uint256_g U = FIELD_SUB(U1, left.y); 177 | uint256_g V = FIELD_SUB(V1, left.x); 178 | uint256_g Vsq = MONT_SQUARE(V); 179 | uint256_g Vcube = MONT_MUL(Vsq, V); 180 | 181 | uint256_g temp1, temp2; 182 | temp1 = MONT_SQUARE(U); 183 | temp1 = MONT_MUL(temp1, left.z); 184 | temp1 = FIELD_SUB(temp1, Vcube); 185 | temp2 = MONT_MUL(BASE_FIELD_R2, Vsq); 186 | temp2 = MONT_MUL(temp2, left.x); 187 | uint256_g A = FIELD_SUB(temp1, temp2); 188 | res.x = MONT_MUL(V, A); 189 | 190 | temp1 = MONT_MUL(Vsq, left.x); 191 | temp1 = FIELD_SUB(temp1, A); 192 | temp1 = MONT_MUL(U, temp1); 193 | temp2 = MONT_MUL(Vcube, left.y); 194 | res.y = FIELD_SUB(temp1, temp2); 195 | 196 | res.z = MONT_MUL(Vcube, left.z); 197 | return res; 198 | } 199 | 200 | // Arithmetic in Jacobian coordinates (Jacobian coordinates should be faster and we are going to check it!) 201 | // TODO: we may also use BN specific optimizations (for example use, that a = 0) 202 | // ------------------------------------------------------------------------------------------------------------------------------------------------------ 203 | // ------------------------------------------------------------------------------------------------------------------------------------------------------ 204 | // ------------------------------------------------------------------------------------------------------------------------------------------------------ 205 | 206 | //TODO: An alternative repeated doubling routine with costs (4m)M + (4m+2)S for any value a can be derived from the Modified Jacobian doubling routine. 207 | // For small values a (say 0 or -3) the costs reduce to (4m-1)M + (4m+2)S, competing nicely with the algorithm showed above. 
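// Worked example for the Jacobian formulas used below (an illustrative aside,
// with plain integers over the toy field p = 23 in place of Montgomery residues,
// and the curve y^2 = x^3 + 3, i.e. the same a = 0, b = 3 as BN254):
// a Jacobian triple (X, Y, Z) represents the affine point (X / Z^2, Y / Z^3),
// and doubling computes
//     S = 4*X*Y^2,  M = 3*X^2 + a*Z^4,
//     X3 = M^2 - 2*S,  Y3 = M*(S - X3) - 8*Y^4,  Z3 = 2*Y*Z.
// Doubling P = (1, 2) affinely gives (0, 16); the Jacobian route agrees:
//     S = 16, M = 3, X3 = 9 - 32 = 0 (mod 23),
//     Y3 = 3*16 - 8*16 = -80 = 12 (mod 23), Z3 = 4,
// and converting back: X3 / Z3^2 = 0 and Y3 / Z3^3 = 12 * 9 = 16 (mod 23),
// since Z3^3 = 64 = 18 and 18^(-1) = 9 (mod 23).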
208 | 
209 | 
210 | DEVICE_FUNC ec_point ECC_DOUBLE_JAC(const ec_point& pt)
211 | {
212 |     if (is_zero(pt.y) || is_infinity(pt))
213 |         return point_at_infty();
214 |     else
215 |     {
216 |         uint256_g temp1, temp2;
217 |         temp1 = MONT_MUL(BASE_FIELD_R4, pt.x);
218 |         uint256_g Ysq = MONT_SQUARE(pt.y);
219 |         uint256_g S = MONT_MUL(temp1, Ysq);
220 | 
221 |         //TODO: here we may also use BN-SPECIFIC optimizations, because A = 0
222 | 
223 |         temp1 = MONT_SQUARE(pt.x);
224 |         temp1 = MONT_MUL(BASE_FIELD_R3, temp1);
225 |         temp2 = MONT_SQUARE(pt.z);
226 |         temp2 = MONT_SQUARE(temp2);
227 |         temp2 = MONT_MUL(temp2, CURVE_A_COEFF);
228 |         uint256_g M = FIELD_ADD(temp1, temp2);
229 | 
230 |         temp1 = MONT_SQUARE(M);
231 |         temp2 = MONT_MUL(BASE_FIELD_R2, S);
232 |         uint256_g res_x = FIELD_SUB(temp1, temp2);
233 | 
234 |         temp1 = FIELD_SUB(S, res_x);
235 |         temp1 = MONT_MUL(M, temp1);
236 |         temp2 = MONT_SQUARE(Ysq);
237 |         temp2 = MONT_MUL(BASE_FIELD_R8, temp2);
238 |         uint256_g res_y = FIELD_SUB(temp1, temp2);
239 | 
240 |         temp1 = MONT_MUL(BASE_FIELD_R2, pt.y);
241 |         uint256_g res_z = MONT_MUL(temp1, pt.z);
242 | 
243 |         return ec_point{res_x, res_y, res_z};
244 |     }
245 | }
246 | 
247 | DEVICE_FUNC bool IS_ON_CURVE_JAC(const ec_point& pt)
248 | {
249 |     //y^2 = x^3 + a * x * z^4 + b * z^6
250 |     uint256_g lefthandside = MONT_SQUARE(pt.y);
251 |     uint256_g temp1;
252 | 
253 |     uint256_g Zsq = MONT_SQUARE(pt.z);
254 |     uint256_g Z4 = MONT_SQUARE(Zsq);
255 | 
256 |     temp1 = MONT_SQUARE(pt.x);
257 |     uint256_g righthandside = MONT_MUL(temp1, pt.x);
258 |     temp1 = MONT_MUL(CURVE_A_COEFF, pt.x);
259 |     temp1 = MONT_MUL(temp1, Z4);
260 |     righthandside = FIELD_ADD(righthandside, temp1);
261 |     temp1 = MONT_MUL(CURVE_B_COEFF, Zsq);
262 |     temp1 = MONT_MUL(temp1, Z4);
263 |     righthandside = FIELD_ADD(righthandside, temp1);
264 | 
265 |     return EQUAL(lefthandside, righthandside);
266 | }
267 | 
268 | DEVICE_FUNC bool EQUAL_JAC(const ec_point& pt1, const ec_point& pt2)
269 | {
270 |     if (is_infinity(pt1) ^ is_infinity(pt2))
271 |         return false;
272 |     if (is_infinity(pt1) & is_infinity(pt2))
273 |         return true;
274 | 
275 |     //now both points are not points at infinity.
276 | 
277 |     uint256_g Z1sq = MONT_SQUARE(pt1.z);
278 |     uint256_g Z2sq = MONT_SQUARE(pt2.z);
279 | 
280 |     uint256_g temp1 = MONT_MUL(pt1.x, Z2sq);
281 |     uint256_g temp2 = MONT_MUL(pt2.x, Z1sq);
282 |     bool first_check = EQUAL(temp1, temp2);
283 | 
284 |     temp1 = MONT_MUL(pt1.y, Z2sq);
285 |     temp1 = MONT_MUL(temp1, pt2.z);
286 |     temp2 = MONT_MUL(pt2.y, Z1sq);
287 |     temp2 = MONT_MUL(temp2, pt1.z);
288 |     bool second_check = EQUAL(temp1, temp2);
289 | 
290 |     return (first_check && second_check);
291 | }
292 | 
293 | DEVICE_FUNC ec_point ECC_ADD_JAC(const ec_point& left, const ec_point& right)
294 | {
295 |     if (is_infinity(left))
296 |         return right;
297 |     if (is_infinity(right))
298 |         return left;
299 | 
300 |     uint256_g U1, U2;
301 | 
302 |     uint256_g Z2sq = MONT_SQUARE(right.z);
303 |     U1 = MONT_MUL(left.x, Z2sq);
304 | 
305 |     uint256_g Z1sq = MONT_SQUARE(left.z);
306 |     U2 = MONT_MUL(right.x, Z1sq);
307 | 
308 |     uint256_g S1 = MONT_MUL(left.y, Z2sq);
309 |     S1 = MONT_MUL(S1, right.z);
310 | 
311 |     uint256_g S2 = MONT_MUL(right.y, Z1sq);
312 |     S2 = MONT_MUL(S2, left.z);
313 | 
314 |     if (EQUAL(U1, U2))
315 |     {
316 |         if (!EQUAL(S1, S2))
317 |             return point_at_infty();
318 |         else
319 |             return ECC_DOUBLE_JAC(left);
320 |     }
321 | 
322 |     uint256_g H = FIELD_SUB(U2, U1);
323 |     uint256_g R = FIELD_SUB(S2, S1);
324 |     uint256_g Hsq = MONT_SQUARE(H);
325 |     uint256_g Hcube = MONT_MUL(Hsq, H);
326 |     uint256_g T = MONT_MUL(U1, Hsq);
327 | 
328 |     uint256_g res_x = MONT_SQUARE(R);
329 |     res_x = FIELD_SUB(res_x, Hcube);
330 |     uint256_g temp = MONT_MUL(BASE_FIELD_R2, T);
331 |     res_x = FIELD_SUB(res_x, temp);
332 | 
333 |     uint256_g res_y = FIELD_SUB(T, res_x);
334 |     res_y = MONT_MUL(R, res_y);
335 |     temp = MONT_MUL(S1, Hcube);
336 |     res_y = FIELD_SUB(res_y, temp);
337 | 
338 |     uint256_g res_z = MONT_MUL(H, left.z);
339 |     res_z = MONT_MUL(res_z, right.z);
340 | 
341 |     return ec_point{res_x, res_y, res_z};
342 | }
343 | 
344 | DEVICE_FUNC ec_point ECC_SUB_JAC(const ec_point& left, const ec_point& right)
345 | {
346 |     return ECC_ADD_JAC(left, INV(right));
347 | }
348 | 
349 | DEVICE_FUNC ec_point ECC_ADD_MIXED_JAC(const ec_point& left, const affine_point& right)
350 | {
351 |     if (is_infinity(left))
352 |         return ec_point{right.x, right.y, BASE_FIELD_R};
353 | 
354 |     uint256_g U2;
355 | 
356 |     uint256_g Z1sq = MONT_SQUARE(left.z);
357 |     U2 = MONT_MUL(right.x, Z1sq);
358 | 
359 |     uint256_g S2 = MONT_MUL(right.y, Z1sq);
360 |     S2 = MONT_MUL(S2, left.z);
361 | 
362 |     if (EQUAL(left.x, U2))
363 |     {
364 |         if (!EQUAL(left.y, S2))
365 |             return point_at_infty();
366 |         else
367 |             return ECC_DOUBLE_JAC(left);
368 |     }
369 | 
370 |     uint256_g H = FIELD_SUB(U2, left.x);
371 |     uint256_g R = FIELD_SUB(S2, left.y);
372 |     uint256_g Hsq = MONT_SQUARE(H);
373 |     uint256_g Hcube = MONT_MUL(Hsq, H);
374 |     uint256_g T = MONT_MUL(left.x, Hsq);
375 | 
376 |     uint256_g res_x = MONT_SQUARE(R);
377 |     res_x = FIELD_SUB(res_x, Hcube);
378 |     uint256_g temp = MONT_MUL(BASE_FIELD_R2, T);
379 |     res_x = FIELD_SUB(res_x, temp);
380 | 
381 |     uint256_g res_y = FIELD_SUB(T, res_x);
382 |     res_y = MONT_MUL(R, res_y);
383 |     temp = MONT_MUL(left.y, Hcube);
384 |     res_y = FIELD_SUB(res_y, temp);
385 | 
386 |     uint256_g res_z = MONT_MUL(H, left.z);
387 | 
388 |     return ec_point{res_x, res_y, res_z};
389 | }
390 | 
391 | //TODO: what about repeated doubling (m-fold doubling) for Jacobian coordinates?
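//One possible shape of an answer (an illustrative sketch only, not used anywhere):
//m-fold doubling can simply iterate ECC_DOUBLE_JAC; a tuned routine would keep the
//intermediate Z^4 / Y^4 values alive across iterations instead of recomputing them,
//which is where the (4m-1)M + (4m+2)S cost mentioned above comes from.
DEVICE_FUNC ec_point ECC_m_fold_double_JAC_sketch(ec_point pt, uint32_t m)
{
    for (uint32_t i = 0; i < m; i++)
        pt = ECC_DOUBLE_JAC(pt);   //placeholder for the fused, state-carrying version
    return pt;
}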
392 | 393 | //random number generators 394 | //--------------------------------------------------------------------------------------------------------------------------------------------------------- 395 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 396 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 397 | 398 | static DEVICE_FUNC inline uint256_g field_exp(const uint256_g& elem, const uint256_g& power) 399 | { 400 | uint256_g S = elem; 401 | uint256_g Q = BASE_FIELD_R; 402 | 403 | for (size_t i = 0; i < N_BITLEN; i++) 404 | { 405 | bool flag = get_bit(power, i); 406 | if (flag) 407 | { 408 | Q = MONT_MUL(Q, S); 409 | } 410 | 411 | S = MONT_SQUARE(S); 412 | } 413 | return Q; 414 | } 415 | 416 | //The following algorithm is taken from 1st edition of 417 | //Jeffrey Hoffstein, Jill Pipher, J.H. Silverman - An introduction to mathematical cryptography 418 | //Proposition 2.27 on page 84 419 | 420 | static DEVICE_FUNC inline optional field_square_root(const uint256_g& x) 421 | { 422 | uint256_g candidate = field_exp(x, MAGIC_CONSTANT); 423 | 424 | using X = optional; 425 | return (EQUAL(MONT_SQUARE(candidate), x) ? X(candidate) : X(NONE_OPT)); 426 | } 427 | 428 | DEVICE_FUNC void gen_random_elem(affine_point& pt, curandState& state) 429 | { 430 | //consider equation in short Weierstrass form: y^2 = x^3 + a * x + b 431 | //generate random x and compute right hand side 432 | //if this is not a square - repeat, again and again, until we are successful 433 | uint256_g x; 434 | optional y_opt; 435 | while (!y_opt) 436 | { 437 | gen_random_elem(x, state); 438 | 439 | //compute righthandside 440 | 441 | uint256_g righthandside = MONT_SQUARE(x); 442 | righthandside = MONT_MUL(righthandside, x); 443 | 444 | uint256_g temp = MONT_MUL(CURVE_A_COEFF, x); 445 | righthandside = FIELD_ADD(righthandside, temp); 446 | righthandside = FIELD_ADD(righthandside, CURVE_B_COEFF); 447 | 448 | y_opt = field_square_root(righthandside); 449 | } 450 | 451 | uint256_g y = y_opt.get_val(); 452 | 453 | if (curand(&state) % 2) 454 | y = FIELD_ADD_INV(y); 455 | 456 | pt = affine_point{x, y}; 457 | } 458 | 459 | DEVICE_FUNC void gen_random_elem(ec_point& pt, curandState& state) 460 | { 461 | affine_point temp; 462 | gen_random_elem(temp, state); 463 | pt = ec_point{temp.x, temp.y, BASE_FIELD_R}; 464 | 465 | //check if generated point is valid 466 | 467 | assert(IS_ON_CURVE(pt)); 468 | } 469 | 470 | -------------------------------------------------------------------------------- /sources/FFT.cu: -------------------------------------------------------------------------------- 1 | #include "cuda_structs.h" 2 | 3 | //FFT (we propose very naive realization) 4 | //---------------------------------------------------------------------------------------------------------------------------------------------------- 5 | //---------------------------------------------------------------------------------------------------------------------------------------------------- 6 | //---------------------------------------------------------------------------------------------------------------------------------------------------- 7 | 8 | //Sources of inspiration: 9 | //http://www.staff.science.uu.nl/~bisse101/Articles/preprint1138.pdf 10 | 
//https://cs.wmich.edu/gupta/teaching/cs5260/5260Sp15web/studentProjects/tiba&hussein/03278999.pdf
11 | //http://users.umiacs.umd.edu/~ramani/cmsc828e_gpusci/DeSpain_FFT_Presentation.pdf
12 | //http://www.bealto.com/gpu-fft_intro.html
13 | //https://github.com/mmajko/FFT-cuda/blob/master/src/fft-cuda.cu
14 | //Also have a look at GPU gems
15 | 
16 | //NB: the length of arr should be a power of two
17 | 
18 | 
19 | //common FFT routines
20 | //------------------------------------------------------------------------------------------------------------------------------------------------------------
21 | //------------------------------------------------------------------------------------------------------------------------------------------------------------
22 | //------------------------------------------------------------------------------------------------------------------------------------------------------------
23 | 
24 | struct field_pair
25 | {
26 |     embedded_field a;
27 |     embedded_field b;
28 | };
29 | 
30 | DEVICE_FUNC field_pair __inline__ fft_buttefly(const embedded_field& x, const embedded_field& y, const embedded_field& root_of_unity)
31 | {
32 |     embedded_field temp = y * root_of_unity;
33 |     return field_pair{ x + temp, x - temp};
34 | }
35 | 
36 | DEVICE_FUNC embedded_field __inline__ get_root_of_unity(uint32_t index, uint32_t omega_idx_coeff = 1, bool inverse = false)
37 | {
38 |     embedded_field result(EMBEDDED_FIELD_R);
39 |     uint32_t real_idx = index * omega_idx_coeff;
40 |     if (inverse)
41 |         real_idx = (1 << ROOTS_OF_UNTY_ARR_LEN) - real_idx;
42 |     for (unsigned k = 0; k < ROOTS_OF_UNTY_ARR_LEN; k++)
43 |     {
44 |         if (CHECK_BIT(real_idx, k))
45 |             result *= embedded_field(EMBEDDED_FIELD_ROOTS_OF_UNITY[k]);
46 |     }
47 |     return result;
48 | }
49 | 
50 | struct geometry
51 | {
52 |     int gridSize;
53 |     int blockSize;
54 | };
55 | 
56 | template <typename T>
57 | geometry find_suitable_geometry(T func, uint shared_memory_used, uint32_t smCount)
58 | {
59 |     int gridSize;
60 |     int blockSize;
61 |     int maxActiveBlocks;
62 | 
63 |     cudaOccupancyMaxPotentialBlockSize(&gridSize, &blockSize, func, shared_memory_used, 0);
64 |     cudaOccupancyMaxActiveBlocksPerMultiprocessor(&maxActiveBlocks, func, blockSize, shared_memory_used);
65 |     gridSize = maxActiveBlocks * smCount;
66 | 
67 |     return geometry{gridSize, blockSize};
68 | }
69 | 
70 | //Naive FFT-realization
71 | //--------------------------------------------------------------------------------------------------------------------------------------------------------
72 | //--------------------------------------------------------------------------------------------------------------------------------------------------------
73 | //--------------------------------------------------------------------------------------------------------------------------------------------------------
74 | 
75 | __global__ void FFT_shuffle(embedded_field* __restrict__ input_arr, embedded_field* __restrict__ output_arr, uint32_t arr_len, uint32_t log_arr_len)
76 | {
77 |     uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x;
78 |     while (tid < arr_len)
79 |     {
80 |         output_arr[tid] = input_arr[__brev(tid) >> (32 - log_arr_len)];
81 |         tid += blockDim.x * gridDim.x;
82 |     }
83 | }
84 | 
85 | __global__ void FFT_iteration(embedded_field* __restrict__ input_arr, embedded_field* __restrict__ output_arr, 
86 |     uint32_t arr_len, uint32_t log_arr_len, uint32_t step)
87 | {
88 |     uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
89 |     uint32_t k = (1 << step);
90 |     uint32_t l = 2 * k;
91 |     size_t omega_coeff = 1 << (ROOTS_OF_UNTY_ARR_LEN - log_arr_len);
92 |     while (i < arr_len / 2)
93 |     {
94 |         uint32_t first_index = l * (i / k) + (i % k);
95 |         uint32_t second_index = first_index + k;
96 | 
97 |         uint32_t root_of_unity_index = (1 << (log_arr_len - step - 1)) * (i % k);
98 |         embedded_field omega = get_root_of_unity(root_of_unity_index, omega_coeff);
99 | 
100 |         field_pair ops = fft_buttefly(input_arr[first_index], input_arr[second_index], omega);
101 | 
102 |         output_arr[first_index] = ops.a;
103 |         output_arr[second_index] = ops.b;
104 | 
105 |         i += blockDim.x * gridDim.x;
106 |     }
107 | }
108 | 
109 | #include <iostream>
110 | 
111 | void naive_fft_driver(embedded_field* input_arr, embedded_field* output_arr, uint32_t arr_len, bool is_inverse_FFT = false)
112 | {
113 |     //first check that arr_len is a power of 2
114 | 
115 |     uint log_arr_len = BITS_PER_LIMB - __builtin_clz(arr_len) - 1;
116 |     std::cout << "Log arr len: " << log_arr_len << std::endl;
117 |     assert(arr_len == (1 << log_arr_len));
118 | 
119 |     //find optimal geometry
120 | 
121 |     cudaDeviceProp prop;
122 |     cudaGetDeviceProperties(&prop, 0);
123 |     uint32_t smCount = prop.multiProcessorCount;
124 | 
125 |     geometry FFT_shuffle_geometry = find_suitable_geometry(FFT_shuffle, 0, smCount);
126 |     geometry FFT_iter_geometry = find_suitable_geometry(FFT_iteration, 0, smCount);
127 | 
128 |     //allocate additional memory
129 | 
130 |     embedded_field* additional_device_memory = nullptr;
131 |     cudaError_t cudaStatus = cudaMalloc((void **)&additional_device_memory, arr_len * sizeof(embedded_field));
132 | 
133 |     //FFT shuffle;
134 | 
135 |     embedded_field* temp_output_arr = (log_arr_len % 2 ? additional_device_memory : output_arr);
136 |     embedded_field* temp_input_arr = (log_arr_len % 2 ? output_arr : additional_device_memory);
137 |     FFT_shuffle<<<FFT_shuffle_geometry.gridSize, FFT_shuffle_geometry.blockSize>>>(input_arr, temp_output_arr, arr_len, log_arr_len);
138 | 
139 |     //FFT main cycle
140 | 
141 |     for (uint32_t step = 0; step < log_arr_len; step++)
142 |     {
143 |         //swap input and output arrays
144 | 
145 |         embedded_field* swap_arr = temp_input_arr;
146 |         temp_input_arr = temp_output_arr;
147 |         temp_output_arr = swap_arr;
148 | 
149 |         FFT_iteration<<<FFT_iter_geometry.gridSize, FFT_iter_geometry.blockSize>>>(temp_input_arr, temp_output_arr, arr_len, log_arr_len, step);
150 |     }
151 | 
152 |     //clean up
153 |     cudaFree(additional_device_memory);
154 | }
155 | 
156 | 
157 | //Bellman FFT-realization
158 | //--------------------------------------------------------------------------------------------------------------------------------------------------------
159 | //--------------------------------------------------------------------------------------------------------------------------------------------------------
160 | //--------------------------------------------------------------------------------------------------------------------------------------------------------
161 | 
162 | //TODO: make the same things using shuffle instructions and shared memory
163 | 
164 | DEVICE_FUNC void _basic_serial_radix2_FFT(embedded_field* arr, size_t log_arr_len, size_t omega_idx_coeff, bool is_inverse_FFT)
165 | {
166 |     size_t tid = threadIdx.x;
167 |     size_t arr_len = 1 << log_arr_len;
168 | 
169 |     for(size_t i = tid; i < arr_len; i+= blockDim.x)
170 |     {
171 |         size_t rk = __brev(i) >> (32 - log_arr_len);
172 |         if (i < rk)
173 |         {
174 |             embedded_field temp = arr[i];
175 |             arr[i] = arr[rk];
176 |             arr[rk] = temp;
177 |         }
178 |     }
179 | 
180 |     __syncthreads();
181 | 
182 |     for (size_t step = 0; step < log_arr_len; ++step)
183 |     {
184 |         uint32_t i = tid;
185 |         uint32_t k = (1 << step);
186 |         uint32_t l = 2 * k;
187 |         while (i < 
arr_len / 2) 188 | { 189 | uint32_t first_index = l * (i / k) + (i % k); 190 | uint32_t second_index = first_index + k; 191 | 192 | uint32_t omega_idx = (1 << (log_arr_len - step - 1)) * (i % k); 193 | embedded_field omega = get_root_of_unity(omega_idx, omega_idx_coeff, is_inverse_FFT); 194 | 195 | field_pair ops = fft_buttefly(arr[first_index], arr[second_index], omega); 196 | 197 | arr[first_index] = ops.a; 198 | arr[second_index] = ops.b; 199 | 200 | i += blockDim.x; 201 | } 202 | 203 | __syncthreads(); 204 | } 205 | } 206 | 207 | __global__ void _basic_parallel_radix2_FFT(const embedded_field* input_arr, embedded_field* output_arr, embedded_field* temp_arr_base, 208 | size_t log_arr_len, size_t log_num_subblocks, bool is_inverse_FFT) 209 | { 210 | assert( log_arr_len <= ROOTS_OF_UNTY_ARR_LEN && "the size of array is too large for FFT"); 211 | 212 | size_t omega_coeff = 1 << (ROOTS_OF_UNTY_ARR_LEN - log_arr_len); 213 | size_t L = 1 << (log_arr_len - log_num_subblocks); 214 | size_t NUM_SUBBLOCKS = 1 << log_num_subblocks; 215 | 216 | embedded_field* temp_arr = temp_arr_base + L * blockIdx.x; 217 | 218 | embedded_field omega_step = get_root_of_unity(blockIdx.x * L, omega_coeff, is_inverse_FFT); 219 | 220 | for (size_t i = threadIdx.x; i < L; i+= blockDim.x) 221 | { 222 | embedded_field omega_init = get_root_of_unity(blockIdx.x * i, omega_coeff, is_inverse_FFT); 223 | temp_arr[i] = embedded_field::zero(); 224 | for (size_t s = 0; s < NUM_SUBBLOCKS; ++s) 225 | { 226 | size_t idx = i + s * L; 227 | temp_arr[i] += input_arr[idx] * omega_init; 228 | omega_init *= omega_step; 229 | } 230 | } 231 | 232 | __syncthreads(); 233 | 234 | _basic_serial_radix2_FFT(temp_arr, log_arr_len - log_num_subblocks, NUM_SUBBLOCKS * omega_coeff, is_inverse_FFT); 235 | 236 | for (size_t i = threadIdx.x; i < L; i+= blockDim.x) 237 | output_arr[i * NUM_SUBBLOCKS + blockIdx.x] = temp_arr[i]; 238 | } 239 | 240 | __global__ void _radix2_one_block_FFT(const embedded_field* input_arr, embedded_field* output_arr, size_t log_arr_len, bool is_inverse_FFT) 241 | { 242 | extern __shared__ embedded_field temp_arr[]; 243 | size_t arr_len = 1 << log_arr_len; 244 | size_t omega_coeff = 1 << (ROOTS_OF_UNTY_ARR_LEN - log_arr_len); 245 | 246 | 247 | for (size_t i = threadIdx.x; i < arr_len; i+= blockDim.x) 248 | { 249 | temp_arr[i] = input_arr[i]; 250 | } 251 | 252 | _basic_serial_radix2_FFT(temp_arr, log_arr_len, omega_coeff, is_inverse_FFT); 253 | 254 | for (size_t i = threadIdx.x; i < arr_len; i+= blockDim.x) 255 | output_arr[i] = temp_arr[i]; 256 | } 257 | 258 | geometry find_geometry_for_advanced_FFT(uint arr_len) 259 | { 260 | //TODO: this particular values are customized for my architecture 261 | 262 | size_t DEFAULT_FFT_GRID_SIZE = 8; 263 | size_t DEFAULT_FFT_BLOCK_SIZE = 512; 264 | 265 | geometry res; 266 | 267 | if (arr_len < 2 * DEFAULT_FFT_BLOCK_SIZE) 268 | { 269 | res.gridSize = 1; 270 | res.blockSize = max(arr_len / 2, 1); 271 | } 272 | else 273 | { 274 | res.gridSize = min(DEFAULT_FFT_GRID_SIZE, arr_len / (2 * DEFAULT_FFT_BLOCK_SIZE)); 275 | res.blockSize = min(DEFAULT_FFT_BLOCK_SIZE, (size_t)(arr_len / (2 * res.gridSize))); 276 | } 277 | 278 | std::cout << "grid_size: " << res.gridSize << ", block size: " << res.blockSize << std::endl; 279 | return res; 280 | } 281 | 282 | void advanced_fft_driver(embedded_field* input_arr, embedded_field* output_arr, uint32_t arr_len, bool is_inverse_FFT = false) 283 | { 284 | //first check that arr_len is a power of 2 285 | 286 | uint log_arr_len = BITS_PER_LIMB - 
__builtin_clz(arr_len) - 1;
287 |     assert(arr_len == (1 << log_arr_len));
288 | 
289 |     geometry kernel_geometry = find_geometry_for_advanced_FFT(arr_len);
290 | 
291 |     if (kernel_geometry.gridSize == 1)
292 |     {
293 |         std::cout << "one-block FFT - serial" << std::endl;
294 | 
295 |         _radix2_one_block_FFT<<<1, kernel_geometry.blockSize, kernel_geometry.blockSize * 2 * sizeof(embedded_field)>>>(input_arr, output_arr, 
296 |             log_arr_len, is_inverse_FFT);
297 |         cudaDeviceSynchronize();
298 | 
299 |         return;
300 |     }
301 | 
302 |     size_t num_of_blocks = kernel_geometry.gridSize;
303 |     uint log_num_subblocks = BITS_PER_LIMB - __builtin_clz(num_of_blocks) - 1;
304 |     size_t block_size = 1 << (log_arr_len - log_num_subblocks);
305 | 
306 |     //allocate temporary memory
307 |     embedded_field* temp_memory = nullptr;
308 |     cudaError_t cudaStatus = cudaMalloc((void **)&temp_memory, num_of_blocks * block_size * sizeof(embedded_field));
309 | 
310 |     _basic_parallel_radix2_FFT<<<kernel_geometry.gridSize, kernel_geometry.blockSize>>>(input_arr, output_arr, temp_memory, 
311 |         log_arr_len, log_num_subblocks, is_inverse_FFT);
312 |     cudaDeviceSynchronize();
313 | 
314 |     cudaFree(temp_memory);
315 | }
316 | 
317 | #define FFT_DRIVER(input_arr, output_arr, arr_len, is_inverse_FFT) advanced_fft_driver(input_arr, output_arr, arr_len, is_inverse_FFT)
318 | 
319 | 
320 | //polynomial multiplication via FFT
321 | 
322 | struct polynomial
323 | {
324 |     size_t deg;
325 |     embedded_field* coeffs;
326 | };
327 | 
328 | size_t get_power_of_two(size_t n)
329 | {
330 |     n--;
331 |     n |= n >> 1;
332 |     n |= n >> 2;
333 |     n |= n >> 4;
334 |     n |= n >> 8;
335 |     n |= n >> 16;
336 |     n++;
337 | 
338 |     return n;
339 | }
340 | 
341 | __global__ void _mul_vecs(const embedded_field* a_arr, const embedded_field* b_arr, embedded_field* c_arr, size_t arr_len)
342 | {
343 |     size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
344 |     while (tid < arr_len)
345 |     {
346 |         c_arr[tid] = (a_arr[tid] * b_arr[tid]);
347 |         tid += blockDim.x * gridDim.x;
348 |     }
349 | }
350 | 
351 | void _mul_vecs_driver(const embedded_field* a_arr, const embedded_field* b_arr, embedded_field* c_arr, size_t arr_len)
352 | {
353 |     int blockSize;
354 |     int minGridSize;
355 |     int realGridSize;
356 |     int maxActiveBlocks;
357 | 
358 |     cudaOccupancyMaxPotentialBlockSize(&minGridSize, &blockSize, _mul_vecs, 0, 0);
359 |     realGridSize = (arr_len + blockSize - 1) / blockSize;
360 | 
361 |     cudaDeviceProp prop;
362 |     cudaGetDeviceProperties(&prop, 0);
363 |     uint32_t smCount = prop.multiProcessorCount;
364 |     cudaError_t error = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&maxActiveBlocks, _mul_vecs, blockSize, 0);
365 |     if (error == cudaSuccess)
366 |         realGridSize = maxActiveBlocks * smCount;
367 | 
368 |     _mul_vecs<<<realGridSize, blockSize>>>(a_arr, b_arr, c_arr, arr_len);
369 | }
370 | 
371 | polynomial _polynomial_multiplication_on_fft(const polynomial& A, const polynomial& B)
372 | {
373 |     size_t n = get_power_of_two(A.deg + B.deg);
374 |     polynomial C;
375 |     C.deg = A.deg + B.deg;
376 | 
377 |     embedded_field* temp_memory1 = nullptr;
378 |     embedded_field* temp_memory2 = nullptr;
379 |     cudaError_t cudaStatus;
380 | 
381 |     cudaStatus = cudaMalloc((void **)&temp_memory1, n * sizeof(embedded_field));
382 |     cudaStatus = cudaMalloc((void **)&temp_memory2, n * sizeof(embedded_field));
383 |     cudaStatus = cudaMalloc((void **)&C.coeffs, n * sizeof(embedded_field));
384 | 
385 |     cudaMemcpy(temp_memory1, A.coeffs, A.deg * sizeof(embedded_field), cudaMemcpyDeviceToDevice);
386 |     cudaMemset(temp_memory1 + A.deg, 0, (n - A.deg) * sizeof(embedded_field));
387 |     cudaMemcpy(temp_memory2, B.coeffs, B.deg * sizeof(embedded_field), cudaMemcpyDeviceToDevice);
388 |     cudaMemset(temp_memory2 + B.deg, 0, (n - B.deg) * sizeof(embedded_field));
389 | 
390 |     FFT_DRIVER(temp_memory1, temp_memory1, n, false);
391 |     FFT_DRIVER(temp_memory2, temp_memory2, n, false);
392 | 
393 |     _mul_vecs_driver(temp_memory1, temp_memory2, C.coeffs, n);
394 |     FFT_DRIVER(C.coeffs, C.coeffs, n, true);
395 |     //_mul_elem_driver(C.coeffs, get_inv(n), n);
396 | 
397 |     cudaFree(temp_memory1);
398 |     cudaFree(temp_memory2);
399 | 
400 |     return C;
401 | }
402 | 
403 | #define POLY_MUL(X, Y) _polynomial_multiplication_on_fft(X, Y)
404 | 
405 | 
406 | //these drivers are used only for test purposes
407 | //-------------------------------------------------------------------------------------------------------------------------------------------------------------
408 | //-------------------------------------------------------------------------------------------------------------------------------------------------------------
409 | //-------------------------------------------------------------------------------------------------------------------------------------------------------------
410 | 
411 | void naive_FFT_test_driver(uint256_g* A, uint256_g* B, uint256_g* C, size_t arr_len)
412 | {
413 |     naive_fft_driver(reinterpret_cast<embedded_field*>(A), reinterpret_cast<embedded_field*>(C), arr_len);
414 | }
415 | 
416 | void advanced_fft_test_driver(uint256_g* A, uint256_g* B, uint256_g* C, size_t arr_len)
417 | {
418 |     advanced_fft_driver(reinterpret_cast<embedded_field*>(A), reinterpret_cast<embedded_field*>(C), arr_len);
419 | }
420 | 
-------------------------------------------------------------------------------- /include/cuda_structs.h: --------------------------------------------------------------------------------
1 | #ifndef CUDA_STRUCTS_H
2 | #define CUDA_STRUCTS_H
3 | 
4 | #include <cstdint>
5 | #include <cassert>
6 | #include <type_traits>
7 | 
8 | #ifdef __CUDACC__
9 | #include <cuda_runtime.h>
10 | #include <curand_kernel.h>
11 | #define DEVICE_FUNC __device__
12 | #define HOST_DEVICE_FUNC __host__ __device__
13 | #define DEVICE_VAR __device__
14 | #define HOST_DEVICE_VAR __host__ __device__
15 | #define CONST_MEMORY __constant__
16 | #else
17 | #define DEVICE_FUNC
18 | #define HOST_DEVICE_FUNC
19 | #define DEVICE_VAR
20 | #define HOST_DEVICE_VAR
21 | #define CONST_MEMORY
22 | #endif
23 | 
24 | #define HALF_N 4
25 | #define N 8
26 | #define N_DOUBLED 16
27 | #define N_BITLEN 254
28 | #define R_LOG 256
29 | #define BITS_PER_LIMB 32
30 | 
31 | #define USE_PROJECTIVE_COORDINATES
32 | 
33 | #define WARP_SIZE 32
34 | #define DEFAUL_NUM_OF_THREADS_PER_BLOCK 256
35 | 
36 | #define CHECK_BIT(var,pos) ((var) & (1<<(pos)))
37 | #define SET_BIT(var,pos) ((var) |= (1<<(pos)))
38 | 
39 | struct uint64_g
40 | {
41 |     union
42 |     {
43 |         uint64_t as_long;
44 |         struct
45 |         {
46 |             uint32_t low;
47 |             uint32_t high;
48 |         };
49 |     };
50 | };
51 | 
52 | struct uint128_g
53 | {
54 |     union
55 |     {
56 |         uint32_t n[4];
57 |         struct
58 |         {
59 |             uint64_t low;
60 |             uint64_t high;
61 |         };
62 |     };
63 | };
64 | 
65 | struct uint128_with_carry_g
66 | {
67 |     uint128_g val;
68 |     uint32_t carry;
69 | };
70 | 
71 | //NB: maybe this should somehow help?
72 | //https://stackoverflow.com/questions/10297067/in-a-cuda-kernel-how-do-i-store-an-array-in-local-thread-memory 73 | 74 | struct uint256_g 75 | { 76 | union 77 | { 78 | uint32_t n[8]; 79 | uint64_t nn[4]; 80 | struct 81 | { 82 | uint128_g low; 83 | uint128_g high; 84 | }; 85 | }; 86 | }; 87 | 88 | struct uint512_g 89 | { 90 | union 91 | { 92 | uint32_t n[16]; 93 | uint64_t nn[8]; 94 | uint256_g l[2]; 95 | }; 96 | }; 97 | 98 | struct ec_point 99 | { 100 | uint256_g x; 101 | uint256_g y; 102 | uint256_g z; 103 | }; 104 | 105 | struct affine_point 106 | { 107 | uint256_g x; 108 | uint256_g y; 109 | }; 110 | 111 | //this is a field embedded into a group of points on elliptic curve 112 | 113 | struct embedded_field 114 | { 115 | uint256_g rep_; 116 | 117 | DEVICE_FUNC explicit embedded_field(const uint256_g rep); 118 | DEVICE_FUNC embedded_field(); 119 | 120 | static DEVICE_FUNC embedded_field zero(); 121 | static DEVICE_FUNC embedded_field one(); 122 | 123 | DEVICE_FUNC bool operator==(const embedded_field& other) const; 124 | DEVICE_FUNC bool operator!=(const embedded_field& other) const; 125 | 126 | DEVICE_FUNC operator uint256_g() const; 127 | DEVICE_FUNC embedded_field operator-() const; 128 | 129 | //NB: for now we assume that highest possible limb bit is zero for the field modulus 130 | DEVICE_FUNC embedded_field& operator+=(const embedded_field& other); 131 | DEVICE_FUNC embedded_field& operator-=(const embedded_field& other); 132 | 133 | //here we mean montgomery multiplication 134 | DEVICE_FUNC embedded_field& operator*=(const embedded_field& other); 135 | 136 | friend DEVICE_FUNC embedded_field operator+(const embedded_field& left, const embedded_field& right); 137 | friend DEVICE_FUNC embedded_field operator-(const embedded_field& left, const embedded_field& right); 138 | friend DEVICE_FUNC embedded_field operator*(const embedded_field& left, const embedded_field& right); 139 | }; 140 | 141 | DEVICE_FUNC embedded_field operator+(const embedded_field& left, const embedded_field& right); 142 | DEVICE_FUNC embedded_field operator-(const embedded_field& left, const embedded_field& right); 143 | DEVICE_FUNC embedded_field operator*(const embedded_field& left, const embedded_field& right); 144 | 145 | 146 | //miscellaneous helpful staff 147 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 148 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 149 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 150 | 151 | DEVICE_FUNC inline bool get_bit(const uint256_g& x, uint32_t index) 152 | { 153 | auto num = x.n[index / 32]; 154 | auto pos = index % 32; 155 | return CHECK_BIT(num, pos); 156 | } 157 | 158 | DEVICE_FUNC inline void set_bit(uint256_g& x, uint32_t index) 159 | { 160 | auto& num = x.n[index / 32]; 161 | auto pos = index % 32; 162 | num |= (1 << pos); 163 | //SET_BIT(num, pos); 164 | } 165 | 166 | //initialization function 167 | bool CUDA_init(); 168 | void get_device_info(); 169 | 170 | 171 | #ifdef __CUDACC__ 172 | 173 | //sone global constants 174 | //-------------------------------------------------------------------------------------------------------------------------------------------------------- 175 | 
//--------------------------------------------------------------------------------------------------------------------------------------------------------
176 | //--------------------------------------------------------------------------------------------------------------------------------------------------------
177 | 
178 | //TODO: it's better to embed these constants at compile time rather than taking them from constant memory
179 | //Sounds like a way for optimization!
180 | 
181 | //curve order field
182 | 
183 | extern DEVICE_VAR CONST_MEMORY uint256_g EMBEDDED_FIELD_P;
184 | extern DEVICE_VAR CONST_MEMORY uint256_g EMBEDDED_FIELD_R;
185 | extern DEVICE_VAR CONST_MEMORY uint256_g EMBEDDED_FIELD_R_inv;
186 | extern DEVICE_VAR CONST_MEMORY uint32_t EMBEDDED_FIELD_N;
187 | 
188 | //base field
189 | 
190 | extern DEVICE_VAR CONST_MEMORY uint256_g BASE_FIELD_P;
191 | extern DEVICE_VAR CONST_MEMORY uint32_t BASE_FIELD_N;
192 | extern DEVICE_VAR CONST_MEMORY uint256_g BASE_FIELD_R;
193 | extern DEVICE_VAR CONST_MEMORY uint256_g BASE_FIELD_R2;
194 | extern DEVICE_VAR CONST_MEMORY uint256_g BASE_FIELD_R3;
195 | extern DEVICE_VAR CONST_MEMORY uint256_g BASE_FIELD_R4;
196 | extern DEVICE_VAR CONST_MEMORY uint256_g BASE_FIELD_R8;
197 | 
198 | //NB: MAGIC_POWER = (P+1)/4 is constant, so we are able to precompute it (it is needed for exponentiation in a finite field)
199 | //NB: the magic constant should be given in standard form (i.e. NON-MONTGOMERY)
200 | 
201 | extern DEVICE_VAR CONST_MEMORY uint256_g MAGIC_CONSTANT;
202 | 
203 | //elliptic curve params
204 | 
205 | //A = 0
206 | extern DEVICE_VAR CONST_MEMORY uint256_g CURVE_A_COEFF;
207 | //B = 3
208 | extern DEVICE_VAR CONST_MEMORY uint256_g CURVE_B_COEFF;
209 | // generator G = [1, 2, 1]
210 | extern DEVICE_VAR CONST_MEMORY ec_point CURVE_G;
211 | 
212 | //this constant is used in the Kaliski algorithm, i.e. fast field inversion in Montgomery form
213 | 
214 | extern DEVICE_VAR CONST_MEMORY uint256_g BASE_FIELD_R_SQUARED;
215 | 
216 | //this is required for the experimental version of mont mul
217 | 
218 | extern DEVICE_VAR CONST_MEMORY uint256_g BASE_FIELD_N_LARGE;
219 | 
220 | //these are used for FFT
221 | 
222 | extern DEVICE_FUNC size_t ROOTS_OF_UNTY_ARR_LEN;
223 | extern DEVICE_FUNC CONST_MEMORY uint256_g EMBEDDED_FIELD_ROOTS_OF_UNITY[];
224 | 
225 | extern DEVICE_FUNC size_t MULT_GEN_ARR_LEN;
226 | extern DEVICE_FUNC CONST_MEMORY uint256_g EMBEDDED_FIELD_MULT_GEN_ARR[];
227 | 
228 | extern DEVICE_FUNC size_t MULT_GEN_INV_ARR_LEN;
229 | extern DEVICE_FUNC CONST_MEMORY uint256_g EMBEDDED_FIELD_MULT_GEN_INV_ARR[];
230 | 
231 | //a bunch of helpful structs
232 | //--------------------------------------------------------------------------------------------------------------------------------------------------------
233 | //--------------------------------------------------------------------------------------------------------------------------------------------------------
234 | //--------------------------------------------------------------------------------------------------------------------------------------------------------
235 | 
236 | //We are not able to compile with the C++17 standard, hence this hand-rolled optional
237 | 
238 | struct none_t{};
239 | extern DEVICE_VAR CONST_MEMORY none_t NONE_OPT;
240 | 
241 | template <typename T>
242 | class optional
243 | {
244 | private:
245 |     bool flag_;
246 |     T val_;
247 | 
248 |     static_assert(std::is_default_constructible<T>::value, "Inner type of optional should be constructible!");
249 | public:
250 |     DEVICE_FUNC optional(const T& val): flag_(true), val_(val) {}
251 | DEVICE_FUNC optional(const none_t& none): flag_(false) {} 252 | DEVICE_FUNC optional(): flag_(false) {} 253 | 254 | DEVICE_FUNC operator bool() const 255 | { 256 | return flag_; 257 | } 258 | 259 | DEVICE_FUNC const T& get_val() const 260 | { 261 | assert(flag_); 262 | return val_; 263 | } 264 | }; 265 | 266 | 267 | //device specific functions 268 | //-------------------------------------------------------------------------------------------------------------------------------------------------------- 269 | //-------------------------------------------------------------------------------------------------------------------------------------------------------- 270 | //-------------------------------------------------------------------------------------------------------------------------------------------------------- 271 | 272 | 273 | DEVICE_FUNC uint128_with_carry_g add_uint128_with_carry_asm(const uint128_g&, const uint128_g&); 274 | DEVICE_FUNC uint128_g sub_uint128_asm(const uint128_g&, const uint128_g&); 275 | DEVICE_FUNC uint256_g add_uint256_naive(const uint256_g&, const uint256_g&); 276 | DEVICE_FUNC uint256_g add_uint256_asm(const uint256_g&, const uint256_g&); 277 | DEVICE_FUNC uint256_g sub_uint256_naive(const uint256_g&, const uint256_g&); 278 | DEVICE_FUNC uint256_g sub_uint256_asm(const uint256_g&, const uint256_g&); 279 | DEVICE_FUNC int cmp_uint256_naive(const uint256_g&, const uint256_g&); 280 | 281 | DEVICE_FUNC void add_uint_uint256_asm(uint256_g&, uint32_t); 282 | DEVICE_FUNC void sub_uint_uint256_asm(uint256_g&, uint32_t); 283 | 284 | DEVICE_FUNC bool is_zero(const uint256_g&); 285 | DEVICE_FUNC bool is_even(const uint256_g&); 286 | 287 | DEVICE_FUNC uint256_g shift_right_asm(const uint256_g&, uint32_t); 288 | DEVICE_FUNC uint256_g shift_left_asm(const uint256_g&, uint32_t); 289 | 290 | #define CMP(a, b) cmp_uint256_naive(a, b) 291 | #define ADD(a, b) add_uint256_asm(a, b) 292 | #define SUB(a, b) sub_uint256_asm(a, b) 293 | #define SHIFT_LEFT(a, b) shift_left_asm(a, b) 294 | #define SHIFT_RIGHT(a, b) shift_right_asm(a, b) 295 | #define ADD_UINT(a, b) add_uint_uint256_asm(a, b) 296 | #define SUB_UINT(a, b) sub_uint_uint256_asm(a, b) 297 | 298 | DEVICE_FUNC inline bool EQUAL(const uint256_g& lhs, const uint256_g& rhs) 299 | { 300 | return CMP(lhs, rhs) == 0; 301 | } 302 | 303 | //helper functions for naive multiplication 304 | 305 | DEVICE_FUNC inline uint32_t device_long_mul(uint32_t x, uint32_t y, uint32_t* high_ptr) 306 | { 307 | uint32_t high = __umulhi(x, y); 308 | *high_ptr = high; 309 | return x * y; 310 | } 311 | 312 | DEVICE_FUNC inline uint32_t device_fused_add(uint32_t x, uint32_t y, uint32_t* high_ptr) 313 | { 314 | uint32_t z = x + y; 315 | if (z < x) 316 | (*high_ptr)++; 317 | return z; 318 | } 319 | 320 | DEVICE_FUNC uint256_g mul_uint128_to_256_naive(const uint128_g&, const uint128_g&); 321 | DEVICE_FUNC uint256_g mul_uint128_to_256_asm_ver1(const uint128_g&, const uint128_g&); 322 | 323 | #if (__CUDA_ARCH__ >= 500) 324 | DEVICE_FUNC uint256_g mul_uint128_to_256_asm_ver2(const uint128_g&, const uint128_g&); 325 | #endif 326 | 327 | #define MUL_SHORT(a, b) mul_uint128_to_256_asm_ver1(a, b) 328 | 329 | DEVICE_FUNC uint512_g mul_uint256_to_512_naive(const uint256_g&, const uint256_g&); 330 | DEVICE_FUNC uint512_g mul_uint256_to_512_asm(const uint256_g&, const uint256_g&); 331 | DEVICE_FUNC uint512_g mul_uint256_to_512_asm_with_allocation(const uint256_g&, const uint256_g&); 332 | DEVICE_FUNC uint512_g mul_uint256_to_512_asm_longregs(const uint256_g&, 
const uint256_g&); 333 | DEVICE_FUNC uint512_g mul_uint256_to_512_Karatsuba(const uint256_g&, const uint256_g&); 334 | DEVICE_FUNC uint512_g mul_uint256_to_512_asm_with_shuffle(const uint256_g&, const uint256_g&); 335 | 336 | #define MUL(a, b) mul_uint256_to_512_asm_with_allocation(a, b) 337 | 338 | DEVICE_FUNC uint512_g square_uint256_to_512_naive(const uint256_g&); 339 | DEVICE_FUNC uint512_g square_uint256_to_512_asm(const uint256_g&); 340 | 341 | DEVICE_FUNC uint256_g mont_mul_256_naive_SOS(const uint256_g&, const uint256_g&); 342 | DEVICE_FUNC uint256_g mont_mul_256_naive_CIOS(const uint256_g&, const uint256_g&); 343 | DEVICE_FUNC uint256_g mont_mul_256_asm_SOS(const uint256_g&, const uint256_g&); 344 | DEVICE_FUNC uint256_g mont_mul_256_asm_CIOS(const uint256_g&, const uint256_g&); 345 | 346 | #define MONT_SQUARE(a) mont_mul_256_asm_SOS(a, a) 347 | #define MONT_MUL(a,b) mont_mul_256_asm_CIOS(a, b) 348 | 349 | DEVICE_FUNC uint256_g FIELD_ADD(const uint256_g&, const uint256_g&); 350 | DEVICE_FUNC uint256_g FIELD_SUB(const uint256_g&, const uint256_g&); 351 | DEVICE_FUNC uint256_g FIELD_ADD_INV(const uint256_g&); 352 | DEVICE_FUNC uint256_g FIELD_MUL_INV(const uint256_g&); 353 | 354 | //Implementation of these routines doesn't depend on whether we consider prokective or jacobian coordinates 355 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 356 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 357 | //------------------------------------------------------------------------------------------------------------------------------------------------------ 358 | 359 | DEVICE_FUNC inline bool is_infinity(const ec_point& point) 360 | { 361 | return is_zero(point.z); 362 | } 363 | 364 | DEVICE_FUNC inline ec_point point_at_infty() 365 | { 366 | ec_point pt; 367 | 368 | //TD: may be we should use asm and xor here) 369 | #pragma unroll 370 | for (int32_t i = 0 ; i < N; i++) 371 | { 372 | pt.x.n[i] = 0; 373 | } 374 | pt.y.n[0] = 1; 375 | #pragma unroll 376 | for (int32_t i= 1 ; i < N; i++) 377 | { 378 | pt.y.n[i] = 0; 379 | } 380 | #pragma unroll 381 | for (int32_t i = 0 ; i < N; i++) 382 | { 383 | pt.z.n[i] = 0; 384 | } 385 | 386 | return pt; 387 | } 388 | 389 | DEVICE_FUNC inline ec_point INV(const ec_point& pt) 390 | { 391 | return {pt.x, FIELD_ADD_INV(pt.y), pt.z}; 392 | } 393 | 394 | DEVICE_FUNC ec_point ECC_DOUBLE_PROJ(const ec_point&); 395 | DEVICE_FUNC bool IS_ON_CURVE_PROJ(const ec_point&); 396 | DEVICE_FUNC bool EQUAL_PROJ(const ec_point&, const ec_point&); 397 | DEVICE_FUNC ec_point ECC_ADD_PROJ(const ec_point&, const ec_point&); 398 | DEVICE_FUNC ec_point ECC_SUB_PROJ(const ec_point&, const ec_point&); 399 | DEVICE_FUNC ec_point ECC_ADD_MIXED_PROJ(const ec_point&, const affine_point&); 400 | 401 | DEVICE_FUNC ec_point ECC_DOUBLE_JAC(const ec_point&); 402 | DEVICE_FUNC bool IS_ON_CURVE_JAC(const ec_point&); 403 | DEVICE_FUNC bool EQUAL_JAC(const ec_point&, const ec_point&); 404 | DEVICE_FUNC ec_point ECC_ADD_JAC(const ec_point&, const ec_point&); 405 | DEVICE_FUNC ec_point ECC_SUB_JAC(const ec_point&, const ec_point&); 406 | DEVICE_FUNC ec_point ECC_ADD_MIXED_JAC(const ec_point&, const affine_point&); 407 | 408 | DEVICE_FUNC ec_point ECC_double_and_add_exp_PROJ(const ec_point&, const uint256_g&); 409 | DEVICE_FUNC ec_point ECC_ternary_expansion_exp_PROJ(const ec_point&, const 
uint256_g&); 410 | DEVICE_FUNC ec_point ECC_double_and_add_exp_JAC(const ec_point&, const uint256_g&); 411 | DEVICE_FUNC ec_point ECC_ternary_expansion_exp_JAC(const ec_point&, const uint256_g&); 412 | DEVICE_FUNC ec_point ECC_double_and_add_affine_exp_PROJ(const affine_point&, const uint256_g&); 413 | DEVICE_FUNC ec_point ECC_double_and_add_affine_exp_JAC(const affine_point&, const uint256_g&); 414 | DEVICE_FUNC ec_point ECC_wNAF_exp_PROJ(const ec_point&, const uint256_g&); 415 | DEVICE_FUNC ec_point ECC_wNAF_exp_JAC(const ec_point&, const uint256_g&); 416 | 417 | 418 | #ifdef USE_PROJECTIVE_COORDINATES 419 | 420 | #define ECC_ADD(a, b) ECC_ADD_PROJ(a, b) 421 | #define ECC_SUB(a, b) ECC_SUB_PROJ(a, b) 422 | #define ECC_DOUBLE(a) ECC_DOUBLE_PROJ(a) 423 | #define ECC_EXP(p, d) ECC_double_and_add_affine_exp_PROJ(p, d) 424 | #define IS_ON_CURVE(p) IS_ON_CURVE_PROJ(p) 425 | #define ECC_MIXED_ADD(a, b) ECC_ADD_MIXED_PROJ(a, b) 426 | 427 | #elif defined USE_JACOBIAN_COORDINATES 428 | 429 | #define ECC_ADD(a, b) ECC_ADD_JAC(a, b) 430 | #define ECC_SUB(a, b) ECC_SUB_JAC(a, b) 431 | #define ECC_DOUBLE(a) ECC_DOUBLE_JAC(a) 432 | #define ECC_EXP(p, d) ECC_double_and_add_affine_exp_JAC(p, d) 433 | #define IS_ON_CURVE(p) IS_ON_CURVE_JAC(p) 434 | #define ECC_MIXED_ADD(a, b) ECC_ADD_MIXED_JAC(a, b) 435 | 436 | #else 437 | #error The form of elliptic curve coordinates should be explicitely specified 438 | #endif 439 | 440 | //random elements generators 441 | 442 | DEVICE_FUNC void gen_random_elem(uint256_g&, curandState&); 443 | DEVICE_FUNC void gen_random_elem(embedded_field&, curandState&); 444 | DEVICE_FUNC void gen_random_elem(ec_point&, curandState&); 445 | DEVICE_FUNC void gen_random_elem(affine_point&, curandState&); 446 | 447 | template 448 | __global__ void gen_random_array_kernel(T* elems, size_t arr_len, curandState* state, int seed) 449 | { 450 | size_t tid = threadIdx.x + blockIdx.x * blockDim.x; 451 | /* Each thread gets same seed, a different sequence 452 | number, no offset */ 453 | curand_init(seed + tid, 0, 0, &state[tid]); 454 | 455 | curandState localState = state[tid]; 456 | 457 | while (tid < arr_len) 458 | { 459 | gen_random_elem(elems[tid], localState); 460 | tid += blockDim.x * gridDim.x; 461 | } 462 | } 463 | 464 | #endif 465 | 466 | #endif -------------------------------------------------------------------------------- /sources/Groth16_prover.cu: -------------------------------------------------------------------------------- 1 | #include "cuda_structs.h" 2 | 3 | #include 4 | #include 5 | 6 | #define EXPORT __attribute__((visibility("default"))) 7 | 8 | //---------------------------------------------------------------------------------------------------------------------------------------------- 9 | //In order to simplify testing we will have all required functions in one file (NB: it leads to awful duplicated in code, but who cares!) 
10 | //---------------------------------------------------------------------------------------------------------------------------------------------- 11 | 12 | struct Geometry 13 | { 14 | int gridSize; 15 | int blockSize; 16 | }; 17 | 18 | template 19 | Geometry find_suitable_geometry(T func, uint shared_memory_used, uint32_t smCount) 20 | { 21 | int gridSize; 22 | int blockSize; 23 | int maxActiveBlocks; 24 | 25 | cudaOccupancyMaxPotentialBlockSize(&gridSize, &blockSize, func, shared_memory_used, 0); 26 | cudaOccupancyMaxActiveBlocksPerMultiprocessor(&maxActiveBlocks, func, blockSize, shared_memory_used); 27 | gridSize = maxActiveBlocks * smCount; 28 | 29 | return Geometry{gridSize, blockSize}; 30 | } 31 | 32 | static void HandleError(cudaError_t err, const char *file, int line ) 33 | { 34 | if (err != cudaSuccess) 35 | { 36 | std::cout << cudaGetErrorString( err ) << " in " << file << " at line " << line << std::endl; 37 | exit( EXIT_FAILURE ); 38 | } 39 | } 40 | 41 | #define HANDLE_ERROR( err ) (HandleError( err, __FILE__, __LINE__ )) 42 | 43 | __constant__ uint256_g elems[3]; 44 | __constant__ uint256_g tau[1]; 45 | 46 | //---------------------------------------------------------------------------------------------------------------------------------------------- 47 | //vector operations 48 | //---------------------------------------------------------------------------------------------------------------------------------------------- 49 | 50 | __global__ void field_sub_inplace_kernel(embedded_field* a_arr, const embedded_field* b_arr, size_t arr_len) 51 | { 52 | size_t tid = threadIdx.x + blockIdx.x * blockDim.x; 53 | while (tid < arr_len) 54 | { 55 | a_arr[tid] -= b_arr[tid]; 56 | tid += blockDim.x * gridDim.x; 57 | } 58 | } 59 | 60 | __global__ void field_mul_inplace_kernel(embedded_field* a_arr, const embedded_field* b_arr, size_t arr_len) 61 | { 62 | size_t tid = threadIdx.x + blockIdx.x * blockDim.x; 63 | while (tid < arr_len) 64 | { 65 | a_arr[tid] *= b_arr[tid]; 66 | tid += blockDim.x * gridDim.x; 67 | } 68 | } 69 | 70 | using field_kernel_t = void(embedded_field*, const embedded_field*, size_t); 71 | 72 | void field_func_invoke(embedded_field* a_arr, const embedded_field* b_arr, uint32_t arr_len, cudaStream_t& stream, 73 | uint32_t smCount, field_kernel_t func) 74 | { 75 | Geometry geometry = find_suitable_geometry(func, 0, smCount); 76 | 77 | (*func)<<>>(a_arr, b_arr, arr_len); 78 | } 79 | 80 | __global__ void field_fused_mul_sub_inplace_kernel(embedded_field* a_arr, const embedded_field* b_arr, 81 | const embedded_field* c_arr, size_t arr_len) 82 | { 83 | size_t tid = threadIdx.x + blockIdx.x * blockDim.x; 84 | while (tid < arr_len) 85 | { 86 | a_arr[tid] *= b_arr[tid]; 87 | a_arr[tid] -= c_arr[tid]; 88 | tid += blockDim.x * gridDim.x; 89 | } 90 | } 91 | 92 | void fused_mul_sub(embedded_field* a_arr, const embedded_field* b_arr, const embedded_field* c_arr, uint32_t arr_len, uint32_t smCount) 93 | { 94 | Geometry geometry = find_suitable_geometry(field_fused_mul_sub_inplace_kernel, 0, smCount); 95 | field_fused_mul_sub_inplace_kernel<<>>(a_arr, b_arr, c_arr, arr_len); 96 | } 97 | 98 | //---------------------------------------------------------------------------------------------------------------------------------------------- 99 | //FFT 100 | //---------------------------------------------------------------------------------------------------------------------------------------------- 101 | 102 | struct field_pair 103 | { 104 | embedded_field a; 105 | embedded_field b; 106 
| }; 107 | 108 | DEVICE_FUNC field_pair __inline__ fft_buttefly(const embedded_field& x, const embedded_field& y, const embedded_field& root_of_unity) 109 | { 110 | embedded_field temp = y * root_of_unity; 111 | return field_pair{ x + temp, x - temp}; 112 | } 113 | 114 | DEVICE_FUNC embedded_field __inline__ get_root_of_unity(uint32_t index, uint32_t omega_idx_coeff = 1, bool inverse = false) 115 | { 116 | embedded_field result(EMBEDDED_FIELD_R); 117 | uint32_t real_idx = index * omega_idx_coeff; 118 | if (inverse) 119 | real_idx = (1 << (ROOTS_OF_UNTY_ARR_LEN)) - real_idx; 120 | for (unsigned k = 0; k < ROOTS_OF_UNTY_ARR_LEN; k++) 121 | { 122 | if (CHECK_BIT(real_idx, k)) 123 | result *= embedded_field(EMBEDDED_FIELD_ROOTS_OF_UNITY[k]); 124 | } 125 | return result; 126 | } 127 | 128 | __global__ void fft_shuffle(embedded_field* arr, uint32_t arr_len, uint32_t log_arr_len) 129 | { 130 | uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x; 131 | while (tid < arr_len) 132 | { 133 | uint32_t first_idx = tid; 134 | uint32_t second_idx = __brev(first_idx) >> (32 - log_arr_len); 135 | if (first_idx < second_idx) 136 | { 137 | //swap values! 138 | embedded_field temp = arr[first_idx]; 139 | arr[first_idx] = arr[second_idx]; 140 | arr[second_idx] = temp; 141 | } 142 | 143 | tid += blockDim.x * gridDim.x; 144 | } 145 | } 146 | 147 | __global__ void fft_iteration(embedded_field* arr, uint32_t arr_len, uint32_t log_arr_len, uint32_t step, bool is_inverse) 148 | { 149 | uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; 150 | uint32_t k = (1 << step); 151 | uint32_t l = 2 * k; 152 | size_t omega_coeff = 1 << (ROOTS_OF_UNTY_ARR_LEN - log_arr_len); 153 | while (i < arr_len / 2) 154 | { 155 | uint32_t first_index = l * (i / k) + (i % k); 156 | uint32_t second_index = first_index + k; 157 | 158 | uint32_t root_of_unity_index = (1 << (log_arr_len - step - 1)) * (i % k); 159 | embedded_field omega = get_root_of_unity(root_of_unity_index, omega_coeff, is_inverse); 160 | 161 | field_pair ops = fft_buttefly(arr[first_index], arr[second_index], omega); 162 | 163 | arr[first_index] = ops.a; 164 | arr[second_index] = ops.b; 165 | 166 | i += blockDim.x * gridDim.x; 167 | } 168 | } 169 | 170 | void fft_impl(embedded_field* arr, uint32_t arr_len, bool is_inverse_FFT, const Geometry& geometry, cudaStream_t& stream) 171 | { 172 | uint log_arr_len = BITS_PER_LIMB - __builtin_clz(arr_len) - 1; 173 | fft_shuffle<<>>(arr, arr_len, log_arr_len); 174 | 175 | //FFT main cycle 176 | for (uint32_t step = 0; step < log_arr_len; step++) 177 | { 178 | fft_iteration<<>>(arr, arr_len, log_arr_len, step, is_inverse_FFT); 179 | } 180 | } 181 | 182 | 183 | __global__ void mul_by_const_kernel(embedded_field* arr, size_t arr_len, const uint32_t index) 184 | { 185 | const embedded_field elem = (index == 0 ? 
embedded_field(elems[index]) : embedded_field(tau[0])); 186 | 187 | size_t tid = threadIdx.x + blockIdx.x * blockDim.x; 188 | while (tid < arr_len) 189 | { 190 | arr[tid] *= elem; 191 | tid += blockDim.x * gridDim.x; 192 | } 193 | } 194 | 195 | void mul_by_const(embedded_field* arr, size_t arr_len, const uint256_g& elem, const Geometry& geometry, 196 | cudaStream_t& stream, uint32_t index) 197 | { 198 | mul_by_const_kernel<<>>(arr, arr_len, index); 199 | } 200 | 201 | __global__ void mont_reduce_kernel(embedded_field* arr, size_t arr_len) 202 | { 203 | const embedded_field elem = embedded_field(EMBEDDED_FIELD_R_inv); 204 | 205 | size_t tid = threadIdx.x + blockIdx.x * blockDim.x; 206 | while (tid < arr_len) 207 | { 208 | arr[tid] *= elem; 209 | tid += blockDim.x * gridDim.x; 210 | } 211 | } 212 | 213 | void mont_reduce(embedded_field* arr, size_t arr_len, const Geometry& geometry, cudaStream_t& stream) 214 | { 215 | mont_reduce_kernel<<>>(arr, arr_len); 216 | } 217 | 218 | DEVICE_FUNC embedded_field __inline__ get_gen_power(size_t index) 219 | { 220 | embedded_field result(EMBEDDED_FIELD_R); 221 | 222 | //TODO: fix - index may be longer than 32 bits 223 | for (unsigned k = 0; k < 32; k++) 224 | { 225 | if (CHECK_BIT(index, k)) 226 | result *= embedded_field(EMBEDDED_FIELD_MULT_GEN_ARR[k]); 227 | } 228 | return result; 229 | } 230 | 231 | DEVICE_FUNC embedded_field __inline__ get_gen_inv_power(size_t index) 232 | { 233 | embedded_field result(EMBEDDED_FIELD_R); 234 | 235 | for (unsigned k = 0; k < 32; k++) 236 | { 237 | if (CHECK_BIT(index, k)) 238 | result *= embedded_field(EMBEDDED_FIELD_MULT_GEN_INV_ARR[k]); 239 | } 240 | return result; 241 | } 242 | 243 | __global__ void distribute_powers_kernel(embedded_field* arr, size_t arr_len, bool is_inv) 244 | { 245 | size_t tid = threadIdx.x + blockIdx.x * blockDim.x; 246 | 247 | while (tid < arr_len) 248 | { 249 | embedded_field elem = (is_inv ? get_gen_inv_power(tid) : get_gen_power(tid)); 250 | arr[tid] *= elem; 251 | 252 | tid += blockDim.x * gridDim.x; 253 | } 254 | } 255 | 256 | void distribute_powers(embedded_field* arr, size_t arr_len, bool is_inv, const Geometry& geometry, cudaStream_t& stream) 257 | { 258 | distribute_powers_kernel<<>>(arr, arr_len, is_inv); 259 | } 260 | 261 | 262 | void FFT(embedded_field* arr, uint32_t arr_len, const Geometry& geometry, cudaStream_t& stream) 263 | { 264 | fft_impl(arr, arr_len, false, geometry, stream); 265 | } 266 | 267 | void iFFT(embedded_field* arr, uint32_t arr_len, const uint256_g& inv, const Geometry& geometry, 268 | cudaStream_t& stream, uint32_t index) 269 | { 270 | fft_impl(arr, arr_len, true, geometry, stream); 271 | mul_by_const(arr, arr_len, inv, geometry, stream, index); 272 | } 273 | 274 | void cosetFFT(embedded_field* arr, uint32_t arr_len, const Geometry& geometry, cudaStream_t& stream) 275 | { 276 | distribute_powers(arr, arr_len, false, geometry, stream); 277 | fft_impl(arr, arr_len, false, geometry, stream); 278 | } 279 | 280 | void icosetFFT(embedded_field* arr, uint32_t arr_len, const uint256_g& inv, const Geometry& geometry, 281 | cudaStream_t& stream, uint32_t index) 282 | { 283 | fft_impl(arr, arr_len, true, geometry, stream); 284 | mul_by_const(arr, arr_len, inv, geometry, stream, index); 285 | distribute_powers(arr, arr_len, true, geometry, stream); 286 | } 287 | 288 | 289 | 290 | //---------------------------------------------------------------------------------------------------------------------------------------------- 291 | //Groth 16 prover! 
//----------------------------------------------------------------------------------------------------------------------------------------------
//Groth 16 prover! (at least part of it)
//----------------------------------------------------------------------------------------------------------------------------------------------

//NB: the lengths of all these arrays should be equal!

void large_Pippenger_driver(affine_point*, uint256_g*, ec_point*, size_t);

struct Groth16_prover_data
{
    const uint8_t* a_arr;
    size_t a_len;

    const uint8_t* b_arr;
    size_t b_len;

    const uint8_t* c_arr;
    size_t c_len;

    const uint8_t* m_inv;
    const uint8_t* h_arr;
    const uint8_t* tau_inv;
    const uint8_t* check_arr;
};

//round len up to the next power of two
size_t calc_domain_len(size_t len)
{
    size_t log_domain_len = BITS_PER_LIMB - __builtin_clz(len) - 1;
    size_t domain_len = (1 << log_domain_len);
    if (domain_len < len)
        domain_len *= 2;

    return domain_len;
}

affine_point Groth16_proof(const Groth16_prover_data* pr_data)
{
    cudaDeviceProp prop;
    HANDLE_ERROR(cudaGetDeviceProperties(&prop, 0));

    //the pipeline below relies on overlapping kernel execution with async memory transfers
    if (!prop.deviceOverlap)
    {
        exit(EXIT_FAILURE);
    }

    cudaStream_t stream1, stream2, stream3;
    HANDLE_ERROR(cudaStreamCreate(&stream1));
    HANDLE_ERROR(cudaStreamCreate(&stream2));
    HANDLE_ERROR(cudaStreamCreate(&stream3));

    const uint256_g* m_inv = (const uint256_g*)pr_data->m_inv;
    const uint256_g* tau_inv = (const uint256_g*)pr_data->tau_inv;

    HANDLE_ERROR(cudaMemcpyToSymbol(elems, m_inv, sizeof(uint256_g), 0, cudaMemcpyHostToDevice));
    HANDLE_ERROR(cudaMemcpyToSymbol(tau, tau_inv, sizeof(uint256_g), 0, cudaMemcpyHostToDevice));

    size_t a_domain_len = calc_domain_len(pr_data->a_len);
    size_t b_domain_len = calc_domain_len(pr_data->b_len);
    size_t c_domain_len = calc_domain_len(pr_data->c_len);

    assert(a_domain_len == b_domain_len);
    assert(b_domain_len == c_domain_len);

    size_t domain_len = a_domain_len;

    //lock memory and copy asynchronously to device
    //TODO: mlock calls inside assert() disappear when NDEBUG is defined
    assert(mlock(pr_data->a_arr, pr_data->a_len * sizeof(embedded_field)) == 0);
    assert(mlock(pr_data->b_arr, pr_data->b_len * sizeof(embedded_field)) == 0);
    assert(mlock(pr_data->c_arr, pr_data->c_len * sizeof(embedded_field)) == 0);
    assert(mlock(pr_data->h_arr, (domain_len - 1) * sizeof(affine_point)) == 0);

    embedded_field* dev_a = nullptr, *dev_b = nullptr, *dev_c = nullptr;

    HANDLE_ERROR(cudaMalloc((void**)&dev_a, domain_len * sizeof(embedded_field)));
    //dev_b is deliberately oversized: it is reused below as the buffer for the affine_point array h_arr
    HANDLE_ERROR(cudaMalloc((void**)&dev_b, domain_len * sizeof(affine_point)));
    HANDLE_ERROR(cudaMalloc((void**)&dev_c, domain_len * sizeof(embedded_field)));

    HANDLE_ERROR(cudaMemcpyAsync(dev_a, pr_data->a_arr, pr_data->a_len * sizeof(embedded_field), cudaMemcpyHostToDevice, stream1));
    HANDLE_ERROR(cudaMemcpyAsync(dev_b, pr_data->b_arr, pr_data->b_len * sizeof(embedded_field), cudaMemcpyHostToDevice, stream2));
    HANDLE_ERROR(cudaMemcpyAsync(dev_c, pr_data->c_arr, pr_data->c_len * sizeof(embedded_field), cudaMemcpyHostToDevice, stream3));

    //zero-pad each array up to the full domain length
    HANDLE_ERROR(cudaMemsetAsync(dev_a + pr_data->a_len, 0, (domain_len - pr_data->a_len) * sizeof(embedded_field), stream1));
    HANDLE_ERROR(cudaMemsetAsync(dev_b + pr_data->b_len, 0, (domain_len - pr_data->b_len) * sizeof(embedded_field), stream2));
    HANDLE_ERROR(cudaMemsetAsync(dev_c + pr_data->c_len, 0, (domain_len - pr_data->c_len) * sizeof(embedded_field), stream3));
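    //What follows computes the evaluations of h(x) = (A(x)*B(x) - C(x)) / Z(x) on a coset of the
    //FFT domain H (|H| = domain_len = n, Z(x) = x^n - 1). Z vanishes on H itself, so the division
    //must be carried out off H: on the coset g*H the vanishing polynomial is constant, since
    //Z(g*w^i) = g^n * w^(i*n) - 1 = g^n - 1 (because w^n = 1). Dividing by Z therefore amounts to
    //a single multiplication by (g^n - 1)^{-1}, which is the constant passed in as tau_inv
    //(z_inv on the host side). Explanatory note; the code below follows exactly this plan.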
    Geometry FFT_geometry = find_suitable_geometry(fft_iteration, 0, prop.multiProcessorCount);

    //interpolate a, b, c over H and re-evaluate them on the coset g*H (one stream per array)
    iFFT(dev_a, domain_len, *m_inv, FFT_geometry, stream1, 0);
    cosetFFT(dev_a, domain_len, FFT_geometry, stream1);

    iFFT(dev_b, domain_len, *m_inv, FFT_geometry, stream2, 0);
    cosetFFT(dev_b, domain_len, FFT_geometry, stream2);

    iFFT(dev_c, domain_len, *m_inv, FFT_geometry, stream3, 0);
    cosetFFT(dev_c, domain_len, FFT_geometry, stream3);

    HANDLE_ERROR( cudaStreamSynchronize( stream1 ) );
    HANDLE_ERROR( cudaStreamSynchronize( stream2 ) );
    HANDLE_ERROR( cudaStreamSynchronize( stream3 ) );

    //a <- a * b - c (pointwise, on the coset)
    fused_mul_sub(dev_a, dev_b, dev_c, domain_len, prop.multiProcessorCount);
    cudaDeviceSynchronize();

    //divide by the (constant) vanishing polynomial and return to coefficient form
    mul_by_const(dev_a, domain_len, *tau_inv, FFT_geometry, stream1, 1);
    icosetFFT(dev_a, domain_len, *m_inv, FFT_geometry, stream1, 0);
    Geometry mont_reduce_geometry = find_suitable_geometry(mont_reduce_kernel, 0, prop.multiProcessorCount);
    //convert the scalars out of Montgomery form for the multiexponentiation
    mont_reduce(dev_a, domain_len - 1, mont_reduce_geometry, stream1);

    //dev_b is reused here as the buffer for the affine points
    HANDLE_ERROR(cudaMemcpyAsync(dev_b, pr_data->h_arr, (domain_len - 1) * sizeof(affine_point), cudaMemcpyHostToDevice, stream2));

    HANDLE_ERROR( cudaStreamSynchronize( stream1 ) );
    HANDLE_ERROR( cudaStreamSynchronize( stream2 ) );
    HANDLE_ERROR( cudaStreamSynchronize( stream3 ) );

    HANDLE_ERROR( cudaStreamDestroy( stream1 ) );
    HANDLE_ERROR( cudaStreamDestroy( stream2 ) );
    HANDLE_ERROR( cudaStreamDestroy( stream3 ) );

    munlock(pr_data->a_arr, pr_data->a_len * sizeof(embedded_field));
    munlock(pr_data->b_arr, pr_data->b_len * sizeof(embedded_field));
    munlock(pr_data->c_arr, pr_data->c_len * sizeof(embedded_field));
    munlock(pr_data->h_arr, (domain_len - 1) * sizeof(affine_point));

    large_Pippenger_driver((affine_point*)dev_b, (uint256_g*)dev_a, (ec_point*)dev_c, domain_len - 1);

    affine_point res;
    HANDLE_ERROR(cudaMemcpy(&res, dev_c, sizeof(affine_point), cudaMemcpyDeviceToHost));

    HANDLE_ERROR(cudaFree(dev_a));
    HANDLE_ERROR(cudaFree(dev_b));
    HANDLE_ERROR(cudaFree(dev_c));

    return res;
}

extern "C"
{

int EXPORT evaluate_h(size_t a_len, size_t b_len, size_t c_len, size_t h_len, const uint8_t* a_repr, const uint8_t* b_repr,
    const uint8_t* c_repr, const uint8_t* h_repr, const uint8_t* z_inv, const uint8_t* m_inv, uint8_t* result_ptr)
{
    Groth16_prover_data pr_data;

    pr_data.a_arr = a_repr;
    pr_data.a_len = a_len;

    pr_data.b_arr = b_repr;
    pr_data.b_len = b_len;

    pr_data.c_arr = c_repr;
    pr_data.c_len = c_len;

    pr_data.m_inv = m_inv;
    pr_data.h_arr = h_repr;
    pr_data.tau_inv = z_inv;

    affine_point res = Groth16_proof(&pr_data);

    memcpy(result_ptr, &res, sizeof(affine_point));
    return 0;
}
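//Illustrative host-side call of the exported entry point above (a sketch, not part of the
//original file; all buffer names and lengths are hypothetical). Each scalar in the *_repr
//buffers occupies 32 bytes (eight 32-bit limbs), each affine point in h_repr occupies 64 bytes:
//
//    std::vector<uint8_t> a(a_len * 32), b(b_len * 32), c(c_len * 32), h((n - 1) * 64);
//    uint8_t z_inv[32], m_inv[32], result[64];
//    /* ... fill the buffers on the host ... */
//    int rc = evaluate_h(a_len, b_len, c_len, n - 1, a.data(), b.data(), c.data(),
//                        h.data(), z_inv, m_inv, result);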
//if repr_flag is true then the array of powers is in Montgomery form and all the numbers should be
//converted to standard form inside the CUDA kernel

int EXPORT dense_multiexp(size_t len, const uint8_t* power_repr, const uint8_t* point_repr, bool repr_flag, uint8_t* result_ptr)
{
    affine_point* dev_point_arr = nullptr;
    uint256_g* dev_power_arr = nullptr;
    ec_point* dev_res = nullptr;

    HANDLE_ERROR(cudaMalloc((void**)&dev_point_arr, len * sizeof(affine_point)));
    HANDLE_ERROR(cudaMalloc((void**)&dev_power_arr, len * sizeof(uint256_g)));
    HANDLE_ERROR(cudaMalloc((void**)&dev_res, sizeof(ec_point)));

    HANDLE_ERROR(cudaMemcpy(dev_point_arr, point_repr, len * sizeof(affine_point), cudaMemcpyHostToDevice));
    HANDLE_ERROR(cudaMemcpy(dev_power_arr, power_repr, len * sizeof(uint256_g), cudaMemcpyHostToDevice));

    if (repr_flag)
    {
        cudaDeviceProp prop;
        HANDLE_ERROR(cudaGetDeviceProperties(&prop, 0));

        Geometry mont_reduce_geometry = find_suitable_geometry(mont_reduce_kernel, 0, prop.multiProcessorCount);
        cudaStream_t stream = 0;
        mont_reduce((embedded_field*)dev_power_arr, len, mont_reduce_geometry, stream);
    }

    large_Pippenger_driver(dev_point_arr, dev_power_arr, dev_res, len);

    affine_point res;
    HANDLE_ERROR(cudaMemcpy(&res, dev_res, sizeof(affine_point), cudaMemcpyDeviceToHost));

    HANDLE_ERROR(cudaFree(dev_point_arr));
    HANDLE_ERROR(cudaFree(dev_power_arr));
    HANDLE_ERROR(cudaFree(dev_res));

    memcpy(result_ptr, &res, sizeof(affine_point));
    return 0;
}

}


int main(int argc, char* argv[])
{
    return 0;
}
--------------------------------------------------------------------------------
/sources/mont_mul.cu:
--------------------------------------------------------------------------------
#include "cuda_structs.h"

//multiplication in Montgomery form
//------------------------------------------------------------------------------------------------------------------------------------------------------
//------------------------------------------------------------------------------------------------------------------------------------------------------
//------------------------------------------------------------------------------------------------------------------------------------------------------

DEVICE_FUNC uint256_g mont_mul_256_naive_SOS(const uint256_g& u, const uint256_g& v)
{
    uint512_g T = MUL(u, v);
    uint256_g res;

#pragma unroll
    for (uint32_t i = 0; i < N; i++)
    {
        uint32_t carry = 0;
        uint32_t m = T.n[i] * BASE_FIELD_N;

#pragma unroll
        for (uint32_t j = 0; j < N; j++)
        {
            uint32_t high_word = 0;
            uint32_t low_word = device_long_mul(m, BASE_FIELD_P.n[j], &high_word);
            low_word = device_fused_add(low_word, T.n[i + j], &high_word);
            low_word = device_fused_add(low_word, carry, &high_word);

            T.n[i + j] = low_word;
            carry = high_word;
        }
        //continue carrying
        uint32_t j = N;
        while (carry)
        {
            uint32_t new_carry = 0;
            T.n[i + j] = device_fused_add(T.n[i + j], carry, &new_carry);
            j++;
            carry = new_carry;
        }
    }

#pragma unroll
    for (uint32_t i = 0; i < N; i++)
    {
        res.n[i] = T.n[i + N];
    }

    if (CMP(res, BASE_FIELD_P) >= 0)
    {
        //TODO: maybe better to change to unary version of sub?
        res = SUB(res, BASE_FIELD_P);
    }

    return res;
}
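//A note on why the SOS loop above works (explanatory sketch). BASE_FIELD_N is -p^{-1} mod 2^32,
//with p = BASE_FIELD_P. At step i we pick m = T_i * BASE_FIELD_N mod 2^32, so that adding
//m * p * 2^(32*i) leaves T unchanged modulo p while zeroing its i-th limb:
//
//    m * p == T_i * (-p^{-1}) * p == -T_i  (mod 2^32).
//
//After N = 8 such steps the low 256 bits of T are zero, and the high half equals
//u * v * R^{-1} mod p (up to one multiple of p), where R = 2^256 - exactly the Montgomery
//product. The final conditional subtraction brings the result below p.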
DEVICE_FUNC uint256_g mont_mul_256_naive_CIOS(const uint256_g& u, const uint256_g& v)
{
    uint256_g T;

#pragma unroll
    for (uint32_t j = 0; j < N; j++)
        T.n[j] = 0;

    uint32_t prefix_low = 0, prefix_high = 0, m;
    uint32_t high_word, low_word;

#pragma unroll
    for (uint32_t i = 0; i < N; i++)
    {
        //multiplication step: T += u * v[i]
        uint32_t carry = 0;
#pragma unroll
        for (uint32_t j = 0; j < N; j++)
        {
            low_word = device_long_mul(u.n[j], v.n[i], &high_word);
            low_word = device_fused_add(low_word, T.n[j], &high_word);
            low_word = device_fused_add(low_word, carry, &high_word);
            carry = high_word;
            T.n[j] = low_word;
        }

        //TODO: maybe we actually require less space? (only one additional limb instead of two)
        prefix_high = 0;
        prefix_low = device_fused_add(prefix_low, carry, &prefix_high);

        //reduction step: add m * p so the lowest limb vanishes, then shift the window down
        m = T.n[0] * BASE_FIELD_N;
        low_word = device_long_mul(BASE_FIELD_P.n[0], m, &high_word);
        low_word = device_fused_add(low_word, T.n[0], &high_word);
        carry = high_word;

#pragma unroll
        for (uint32_t j = 1; j < N; j++)
        {
            low_word = device_long_mul(BASE_FIELD_P.n[j], m, &high_word);
            low_word = device_fused_add(low_word, T.n[j], &high_word);
            low_word = device_fused_add(low_word, carry, &high_word);
            T.n[j-1] = low_word;
            carry = high_word;
        }

        T.n[N-1] = device_fused_add(prefix_low, carry, &prefix_high);
        prefix_low = prefix_high;
    }

    if (CMP(T, BASE_FIELD_P) >= 0)
    {
        //TODO: maybe better to change to unary version of sub?
        T = SUB(T, BASE_FIELD_P);
    }

    return T;
}
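//Host-side reference for the same CIOS recurrence (an illustrative sketch, not part of the
//original file; handy for cross-checking device results against utilities/CIOS_MONT-mul.py).
//`p` is the modulus in eight little-endian 32-bit limbs and `n0` is -p^{-1} mod 2^32 - the same
//constants the device code reads as BASE_FIELD_P and BASE_FIELD_N.
static void host_mont_mul_cios_reference(const uint32_t a[8], const uint32_t b[8],
                                         const uint32_t p[8], uint32_t n0, uint32_t out[8])
{
    uint32_t t[10] = {0};                       //8 limbs plus two limbs of overflow
    for (int i = 0; i < 8; i++)
    {
        //multiplication step: t += a * b[i]
        uint64_t carry = 0;
        for (int j = 0; j < 8; j++)
        {
            uint64_t cur = (uint64_t)t[j] + (uint64_t)a[j] * b[i] + carry;
            t[j] = (uint32_t)cur;
            carry = cur >> 32;
        }
        uint64_t top = (uint64_t)t[8] + carry;
        t[8] = (uint32_t)top;
        t[9] = (uint32_t)(top >> 32);

        //reduction step: choose m so that t + m * p is divisible by 2^32, then shift down one limb
        uint32_t m = t[0] * n0;
        uint64_t cur = (uint64_t)t[0] + (uint64_t)m * p[0];
        carry = cur >> 32;
        for (int j = 1; j < 8; j++)
        {
            cur = (uint64_t)t[j] + (uint64_t)m * p[j] + carry;
            t[j - 1] = (uint32_t)cur;
            carry = cur >> 32;
        }
        top = (uint64_t)t[8] + carry;
        t[7] = (uint32_t)top;
        t[8] = t[9] + (uint32_t)(top >> 32);
        t[9] = 0;
    }
    //final conditional subtraction; for a modulus like BN254 (p < 2^254) the overflow limb t[8]
    //ends at zero and the 8-limb result is below 2p
    uint64_t borrow = 0;
    uint32_t r[8];
    for (int j = 0; j < 8; j++)
    {
        uint64_t diff = (uint64_t)t[j] - p[j] - borrow;
        r[j] = (uint32_t)diff;
        borrow = (diff >> 32) & 1;
    }
    for (int j = 0; j < 8; j++)
        out[j] = borrow ? t[j] : r[j];
}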
"madc.hi.cc.u32 a6, m, n5, a6;\n\t" 159 | "madc.hi.cc.u32 a7, m, n6, a7;\n\t" 160 | "madc.hi.cc.u32 a8, m, n7, a8;\n\t" 161 | "addc.cc.u32 a9, a9, carry;\n\t" 162 | "addc.u32 carry, 0, 0;\n\t" 163 | 164 | "mul.lo.u32 m, a1, q;\n\t" 165 | "mad.lo.cc.u32 a1, m, n0, a1;\n\t" 166 | "madc.lo.cc.u32 a2, m, n1, a2;\n\t" 167 | "madc.lo.cc.u32 a3, m, n2, a3;\n\t" 168 | "madc.lo.cc.u32 a4, m, n3, a4;\n\t" 169 | "madc.lo.cc.u32 a5, m, n4, a5;\n\t" 170 | "madc.lo.cc.u32 a6, m, n5, a6;\n\t" 171 | "madc.lo.cc.u32 a7, m, n6, a7;\n\t" 172 | "madc.lo.cc.u32 a8, m, n7, a8;\n\t" 173 | "addc.cc.u32 a9, a9, 0;\n\t" 174 | "addc.u32 carry, carry, 0;\n\t" 175 | 176 | "mad.hi.cc.u32 a2, m, n0, a2;\n\t" 177 | "madc.hi.cc.u32 a3, m, n1, a3;\n\t" 178 | "madc.hi.cc.u32 a4, m, n2, a4;\n\t" 179 | "madc.hi.cc.u32 a5, m, n3, a5;\n\t" 180 | "madc.hi.cc.u32 a6, m, n4, a6;\n\t" 181 | "madc.hi.cc.u32 a7, m, n5, a7;\n\t" 182 | "madc.hi.cc.u32 a8, m, n6, a8;\n\t" 183 | "madc.hi.cc.u32 a9, m, n7, a9;\n\t" 184 | "addc.cc.u32 a10, a10, carry;\n\t" 185 | "addc.u32 carry, 0, 0;\n\t" 186 | 187 | "mul.lo.u32 m, a2, q;\n\t" 188 | "mad.lo.cc.u32 a2, m, n0, a2;\n\t" 189 | "madc.lo.cc.u32 a3, m, n1, a3;\n\t" 190 | "madc.lo.cc.u32 a4, m, n2, a4;\n\t" 191 | "madc.lo.cc.u32 a5, m, n3, a5;\n\t" 192 | "madc.lo.cc.u32 a6, m, n4, a6;\n\t" 193 | "madc.lo.cc.u32 a7, m, n5, a7;\n\t" 194 | "madc.lo.cc.u32 a8, m, n6, a8;\n\t" 195 | "madc.lo.cc.u32 a9, m, n7, a9;\n\t" 196 | "addc.cc.u32 a10, a10, 0;\n\t" 197 | "addc.u32 carry, carry, 0;\n\t" 198 | 199 | "mad.hi.cc.u32 a3, m, n0, a3;\n\t" 200 | "madc.hi.cc.u32 a4, m, n1, a4;\n\t" 201 | "madc.hi.cc.u32 a5, m, n2, a5;\n\t" 202 | "madc.hi.cc.u32 a6, m, n3, a6;\n\t" 203 | "madc.hi.cc.u32 a7, m, n4, a7;\n\t" 204 | "madc.hi.cc.u32 a8, m, n5, a8;\n\t" 205 | "madc.hi.cc.u32 a9, m, n6, a9;\n\t" 206 | "madc.hi.cc.u32 a10, m, n7, a10;\n\t" 207 | "addc.cc.u32 a11, a11, carry;\n\t" 208 | "addc.u32 carry, 0, 0;\n\t" 209 | 210 | "mul.lo.u32 m, a3, q;\n\t" 211 | "mad.lo.cc.u32 a3, m, n0, a3;\n\t" 212 | "madc.lo.cc.u32 a4, m, n1, a4;\n\t" 213 | "madc.lo.cc.u32 a5, m, n2, a5;\n\t" 214 | "madc.lo.cc.u32 a6, m, n3, a6;\n\t" 215 | "madc.lo.cc.u32 a7, m, n4, a7;\n\t" 216 | "madc.lo.cc.u32 a8, m, n5, a8;\n\t" 217 | "madc.lo.cc.u32 a9, m, n6, a9;\n\t" 218 | "madc.lo.cc.u32 a10, m, n7, a10;\n\t" 219 | "addc.cc.u32 a11, a11, 0;\n\t" 220 | "addc.u32 carry, carry, 0;\n\t" 221 | 222 | "mad.hi.cc.u32 a4, m, n0, a4;\n\t" 223 | "madc.hi.cc.u32 a5, m, n1, a5;\n\t" 224 | "madc.hi.cc.u32 a6, m, n2, a6;\n\t" 225 | "madc.hi.cc.u32 a7, m, n3, a7;\n\t" 226 | "madc.hi.cc.u32 a8, m, n4, a8;\n\t" 227 | "madc.hi.cc.u32 a9, m, n5, a9;\n\t" 228 | "madc.hi.cc.u32 a10, m, n6, a10;\n\t" 229 | "madc.hi.cc.u32 a11, m, n7, a11;\n\t" 230 | "addc.cc.u32 a12, a12, carry;\n\t" 231 | "addc.u32 carry, 0, 0;\n\t" 232 | 233 | "mul.lo.u32 m, a4, q;\n\t" 234 | "mad.lo.cc.u32 a4, m, n0, a4;\n\t" 235 | "madc.lo.cc.u32 a5, m, n1, a5;\n\t" 236 | "madc.lo.cc.u32 a6, m, n2, a6;\n\t" 237 | "madc.lo.cc.u32 a7, m, n3, a7;\n\t" 238 | "madc.lo.cc.u32 a8, m, n4, a8;\n\t" 239 | "madc.lo.cc.u32 a9, m, n5, a9;\n\t" 240 | "madc.lo.cc.u32 a10, m, n6, a10;\n\t" 241 | "madc.lo.cc.u32 a11, m, n7, a11;\n\t" 242 | "addc.cc.u32 a12, a12, 0;\n\t" 243 | "addc.u32 carry, carry, 0;\n\t" 244 | 245 | "mad.hi.cc.u32 a5, m, n0, a5;\n\t" 246 | "madc.hi.cc.u32 a6, m, n1, a6;\n\t" 247 | "madc.hi.cc.u32 a7, m, n2, a7;\n\t" 248 | "madc.hi.cc.u32 a8, m, n3, a8;\n\t" 249 | "madc.hi.cc.u32 a9, m, n4, a9;\n\t" 250 | "madc.hi.cc.u32 a10, m, n5, a10;\n\t" 251 | "madc.hi.cc.u32 a11, m, n6, a11;\n\t" 252 
| "madc.hi.cc.u32 a12, m, n7, a12;\n\t" 253 | "addc.cc.u32 a13, a13, carry;\n\t" 254 | "addc.u32 carry, 0, 0;\n\t" 255 | 256 | "mul.lo.u32 m, a5, q;\n\t" 257 | "mad.lo.cc.u32 a5, m, n0, a5;\n\t" 258 | "madc.lo.cc.u32 a6, m, n1, a6;\n\t" 259 | "madc.lo.cc.u32 a7, m, n2, a7;\n\t" 260 | "madc.lo.cc.u32 a8, m, n3, a8;\n\t" 261 | "madc.lo.cc.u32 a9, m, n4, a9;\n\t" 262 | "madc.lo.cc.u32 a10, m, n5, a10;\n\t" 263 | "madc.lo.cc.u32 a11, m, n6, a11;\n\t" 264 | "madc.lo.cc.u32 a12, m, n7, a12;\n\t" 265 | "addc.cc.u32 a13, a13, 0;\n\t" 266 | "addc.u32 carry, carry, 0;\n\t" 267 | 268 | "mad.hi.cc.u32 a6, m, n0, a6;\n\t" 269 | "madc.hi.cc.u32 a7, m, n1, a7;\n\t" 270 | "madc.hi.cc.u32 a8, m, n2, a8;\n\t" 271 | "madc.hi.cc.u32 a9, m, n3, a9;\n\t" 272 | "madc.hi.cc.u32 a10, m, n4, a10;\n\t" 273 | "madc.hi.cc.u32 a11, m, n5, a11;\n\t" 274 | "madc.hi.cc.u32 a12, m, n6, a12;\n\t" 275 | "madc.hi.cc.u32 a13, m, n7, a13;\n\t" 276 | "addc.cc.u32 a14, a14, carry;\n\t" 277 | "addc.u32 a15, a15, 0;\n\t" 278 | 279 | "mul.lo.u32 m, a6, q;\n\t" 280 | "mad.lo.cc.u32 a6, m, n0, a6;\n\t" 281 | "madc.lo.cc.u32 a7, m, n1, a7;\n\t" 282 | "madc.lo.cc.u32 a8, m, n2, a8;\n\t" 283 | "madc.lo.cc.u32 a9, m, n3, a9;\n\t" 284 | "madc.lo.cc.u32 a10, m, n4, a10;\n\t" 285 | "madc.lo.cc.u32 a11, m, n5, a11;\n\t" 286 | "madc.lo.cc.u32 a12, m, n6, a12;\n\t" 287 | "madc.lo.cc.u32 a13, m, n7, a13;\n\t" 288 | "addc.cc.u32 a14, a14, 0;\n\t" 289 | "addc.u32 a15, a15, 0;\n\t" 290 | 291 | "mad.hi.cc.u32 a7, m, n0, a7;\n\t" 292 | "madc.hi.cc.u32 a8, m, n1, a8;\n\t" 293 | "madc.hi.cc.u32 a9, m, n2, a9;\n\t" 294 | "madc.hi.cc.u32 a10, m, n3, a10;\n\t" 295 | "madc.hi.cc.u32 a11, m, n4, a11;\n\t" 296 | "madc.hi.cc.u32 a12, m, n5, a12;\n\t" 297 | "madc.hi.cc.u32 a13, m, n6, a13;\n\t" 298 | "madc.hi.cc.u32 a14, m, n7, a14;\n\t" 299 | "addc.u32 a15, a15, 0;\n\t" 300 | 301 | "mul.lo.u32 m, a7, q;\n\t" 302 | "mad.lo.cc.u32 a7, m, n0, a7;\n\t" 303 | "madc.lo.cc.u32 a8, m, n1, a8;\n\t" 304 | "madc.lo.cc.u32 a9, m, n2, a9;\n\t" 305 | "madc.lo.cc.u32 a10, m, n3, a10;\n\t" 306 | "madc.lo.cc.u32 a11, m, n4, a11;\n\t" 307 | "madc.lo.cc.u32 a12, m, n5, a12;\n\t" 308 | "madc.lo.cc.u32 a13, m, n6, a13;\n\t" 309 | "madc.lo.cc.u32 a14, m, n7, a14;\n\t" 310 | "addc.u32 a15, a15, 0;\n\t" 311 | 312 | "mad.hi.cc.u32 a8, m, n0, a8;\n\t" 313 | "madc.hi.cc.u32 a9, m, n1, a9;\n\t" 314 | "madc.hi.cc.u32 a10, m, n2, a10;\n\t" 315 | "madc.hi.cc.u32 a11, m, n3, a11;\n\t" 316 | "madc.hi.cc.u32 a12, m, n4, a12;\n\t" 317 | "madc.hi.cc.u32 a13, m, n5, a13;\n\t" 318 | "madc.hi.cc.u32 a14, m, n6, a14;\n\t" 319 | "madc.hi.u32 a15, m, n7, a15;\n\t" 320 | //pack result back 321 | "mov.b64 %0, {a8,a9};\n\t" 322 | "mov.b64 %1, {a10,a11};\n\t" 323 | "mov.b64 %2, {a12,a13};\n\t" 324 | "mov.b64 %3, {a14,a15};\n\t" 325 | : "=l"(w.nn[0]), "=l"(w.nn[1]), "=l"(w.nn[2]), "=l"(w.nn[3]) 326 | : "l"(T.nn[0]), "l"(T.nn[1]), "l"(T.nn[2]), "l"(T.nn[3]), 327 | "l"(T.nn[4]), "l"(T.nn[5]), "l"(T.nn[6]), "l"(T.nn[7])); 328 | 329 | 330 | if (CMP(w, BASE_FIELD_P) >= 0) 331 | { 332 | //TODO: may be better change to inary version of sub? 
#define STR_VALUE(arg) #arg

//This block will be repeated - again and again
#define ASM_REDUCTION_BLOCK \
    "mul.lo.u32 m, r0, q;\n\t" \
    "mad.lo.cc.u32 r0, m, n0, r0;\n\t" \
    "madc.hi.cc.u32 r1, m, n0, r1;\n\t" \
    "madc.hi.cc.u32 r2, m, n1, r2;\n\t" \
    "madc.hi.cc.u32 r3, m, n2, r3;\n\t" \
    "madc.hi.cc.u32 r4, m, n3, r4;\n\t" \
    "madc.hi.cc.u32 r5, m, n4, r5;\n\t" \
    "madc.hi.cc.u32 r6, m, n5, r6;\n\t" \
    "madc.hi.cc.u32 r7, m, n6, r7;\n\t" \
    "madc.hi.cc.u32 prefix_low, m, n7, prefix_low;\n\t" \
    "addc.u32 prefix_high, 0, 0;\n\t" \
    "mad.lo.cc.u32 r0, m, n1, r1;\n\t" \
    "madc.lo.cc.u32 r1, m, n2, r2;\n\t" \
    "madc.lo.cc.u32 r2, m, n3, r3;\n\t" \
    "madc.lo.cc.u32 r3, m, n4, r4;\n\t" \
    "madc.lo.cc.u32 r4, m, n5, r5;\n\t" \
    "madc.lo.cc.u32 r5, m, n6, r6;\n\t" \
    "madc.lo.cc.u32 r6, m, n7, r7;\n\t" \
    "addc.cc.u32 r7, prefix_low, 0;\n\t" \
    "addc.u32 prefix_low, prefix_high, 0;\n\t"

//This block will also be repeated - but with a rising index of a: a1, a2, ..., a7
#define ASM_MUL_BLOCK(idx) \
    "mad.lo.cc.u32 r0, a"#idx", b0, r0;\n\t" \
    "madc.lo.cc.u32 r1, a"#idx", b1, r1;\n\t" \
    "madc.lo.cc.u32 r2, a"#idx", b2, r2;\n\t" \
    "madc.lo.cc.u32 r3, a"#idx", b3, r3;\n\t" \
    "madc.lo.cc.u32 r4, a"#idx", b4, r4;\n\t" \
    "madc.lo.cc.u32 r5, a"#idx", b5, r5;\n\t" \
    "madc.lo.cc.u32 r6, a"#idx", b6, r6;\n\t" \
    "madc.lo.cc.u32 r7, a"#idx", b7, r7;\n\t" \
    "addc.u32 prefix_low, prefix_low, 0;\n\t" \
    "mad.hi.cc.u32 r1, a"#idx", b0, r1;\n\t" \
    "madc.hi.cc.u32 r2, a"#idx", b1, r2;\n\t" \
    "madc.hi.cc.u32 r3, a"#idx", b2, r3;\n\t" \
    "madc.hi.cc.u32 r4, a"#idx", b3, r4;\n\t" \
    "madc.hi.cc.u32 r5, a"#idx", b4, r5;\n\t" \
    "madc.hi.cc.u32 r6, a"#idx", b5, r6;\n\t" \
    "madc.hi.cc.u32 r7, a"#idx", b6, r7;\n\t" \
    "madc.hi.cc.u32 prefix_low, a"#idx", b7, prefix_low;\n\t" \
    "addc.u32 prefix_high, 0, 0;\n\t"

//NB: look carefully at line 11 on page 31 of http://eprints.utar.edu.my/2494/1/CS-2017-1401837-1.pdf
//and find an opportunity for additional speedup
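//How the two macros compose (explanatory note): the CIOS routine below alternates one
//ASM_MUL_BLOCK - accumulate a_idx * b into the 8-limb window r0..r7, with overflow spilling
//into prefix_low/prefix_high - with one ASM_REDUCTION_BLOCK, which picks m = r0 * q so that
//r0 + lo(m * n0) vanishes mod 2^32, adds m * n, and shifts the whole window down by one limb.
//After eight mul/reduce rounds r0..r7 hold the Montgomery product u * v * R^{-1} mod p,
//up to one final conditional subtraction.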
"mul.lo.u32 r4, a0, b4;\n\t" 420 | "mul.lo.u32 r5, a0, b5;\n\t" 421 | "mul.lo.u32 r6, a0, b6;\n\t" 422 | "mul.lo.u32 r7, a0, b7;\n\t" 423 | "mad.hi.cc.u32 r1, a0, b0, r1;\n\t" 424 | "madc.hi.cc.u32 r2, a0, b1, r2;\n\t" 425 | "madc.hi.cc.u32 r3, a0, b2, r3;\n\t" 426 | "madc.hi.cc.u32 r4, a0, b3, r4;\n\t" 427 | "madc.hi.cc.u32 r5, a0, b4, r5;\n\t" 428 | "madc.hi.cc.u32 r6, a0, b5, r6;\n\t" 429 | "madc.hi.cc.u32 r7, a0, b6, r7;\n\t" 430 | "madc.hi.cc.u32 prefix_low, a0, b7, 0;\n\t" 431 | 432 | ASM_REDUCTION_BLOCK 433 | ASM_MUL_BLOCK(1) 434 | ASM_REDUCTION_BLOCK 435 | ASM_MUL_BLOCK(2) 436 | ASM_REDUCTION_BLOCK 437 | ASM_MUL_BLOCK(3) 438 | ASM_REDUCTION_BLOCK 439 | ASM_MUL_BLOCK(4) 440 | ASM_REDUCTION_BLOCK 441 | ASM_MUL_BLOCK(5) 442 | ASM_REDUCTION_BLOCK 443 | ASM_MUL_BLOCK(6) 444 | ASM_REDUCTION_BLOCK 445 | ASM_MUL_BLOCK(7) 446 | ASM_REDUCTION_BLOCK 447 | 448 | //pack result back 449 | "mov.b64 %0, {r0,r1};\n\t" 450 | "mov.b64 %1, {r2,r3};\n\t" 451 | "mov.b64 %2, {r4,r5};\n\t" 452 | "mov.b64 %3, {r6,r7};\n\t" 453 | : "=l"(w.nn[0]), "=l"(w.nn[1]), "=l"(w.nn[2]), "=l"(w.nn[3]) 454 | : "l"(u.nn[0]), "l"(u.nn[1]), "l"(u.nn[2]), "l"(u.nn[3]), 455 | "l"(v.nn[0]), "l"(v.nn[1]), "l"(v.nn[2]), "l"(v.nn[3])); 456 | 457 | //NB: we can chain several montgomety muls without the below reduction 458 | //It also results in no warp divergence! 459 | 460 | if (CMP(w, BASE_FIELD_P) >= 0) 461 | { 462 | //TODO: may be better change to inary version of sub? 463 | w = SUB(w, BASE_FIELD_P); 464 | } 465 | 466 | return w; 467 | } --------------------------------------------------------------------------------