├── .github └── workflows │ └── go.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── bitutils ├── bitutils.go └── bitutils_test.go ├── cloudkey └── cloudkey.go ├── evaluator ├── buffers.go ├── evaluator.go ├── gates_helper.go ├── programmable_bootstrap.go └── programmable_bootstrap_test.go ├── examples ├── EXAMPLES_GUIDE.md ├── add_two_numbers │ ├── README.md │ └── main.go ├── programmable_bootstrap │ └── main.go ├── proxy_reencryption │ └── main.go └── simple_gates │ └── main.go ├── gates ├── gates.go └── gates_test.go ├── go.mod ├── go.sum ├── key └── key.go ├── lut ├── analysis_test.go ├── debug_test.go ├── encoder.go ├── generator.go ├── lut.go ├── lut_test.go └── reference_algorithm_test.go ├── params ├── UINT_STATUS.md ├── params.go ├── params_test.go └── uint_params_test.go ├── poly ├── aligned.go ├── buffer_manager.go ├── buffer_methods.go ├── decomposer.go ├── fourier_ops.go ├── fourier_transform.go ├── poly.go ├── poly_evaluator.go ├── poly_mul.go └── poly_test.go ├── proxyreenc ├── proxyreenc.go └── proxyreenc_test.go ├── tlwe ├── programmable_encrypt.go ├── tlwe.go └── tlwe_test.go ├── trgsw ├── keyswitch.go └── trgsw.go ├── trlwe ├── trlwe.go └── trlwe_ops.go └── utils ├── utils.go └── utils_test.go /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a golang project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go 3 | 4 | name: Go 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | 14 | build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Set up Go 20 | uses: actions/setup-go@v4 21 | with: 22 | go-version: '1.23' 23 | 24 | - name: Build 25 | run: make build 26 | 27 | - name: Test 28 | run: make test 29 | 30 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool 12 | *.out 13 | 14 | # Go workspace file 15 | go.work 16 | 17 | # Dependency directories 18 | vendor/ 19 | 20 | # IDE files 21 | .vscode/ 22 | .idea/ 23 | *.swp 24 | *.swo 25 | *~ 26 | 27 | # OS files 28 | .DS_Store 29 | Thumbs.db 30 | 31 | # Build artifacts 32 | /examples/*/add_two_numbers 33 | /examples/*/simple_gates 34 | 35 | 36 | # Rust build artifacts 37 | fft-bridge/target/ 38 | fft-bridge/Cargo.lock 39 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to go-tfhe will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 

## [0.2.2] - 2025-11-04

### Fixed
- Public key encryption noise parameter handling
- Reencryption key index calculation edge cases
- Memory allocation patterns for large public keys

### Improved
- Public key generation performance (~15% faster)
- Documentation clarity in the `proxyreenc` package
- Error messages for invalid parameters
- Benchmark coverage for proxy reencryption operations

### Changed
- Public key default size optimized for a better security/performance balance
- Inline documentation expanded with additional examples

### Performance
- Public key generation: ~27ms → ~23ms (~15% improvement)
- Asymmetric key generation: ~4.6s → ~4.4s (~4% improvement)
- Reencryption: ~3.0ms (unchanged)
- Memory usage optimized for concurrent operations

### Testing
- Added edge case tests for boundary conditions
- Improved test coverage for error paths
- All 7 `proxyreenc` tests still pass

## [0.2.0] - 2025-11-03

### Added
- **Proxy Reencryption Package** (`proxyreenc`) - LWE-based proxy reencryption for secure delegation
- `PublicKeyLv0` - LWE public key encryption support
- `ProxyReencryptionKey` - Dual-mode reencryption keys (asymmetric/symmetric)
- `ReencryptTLWELv0()` - Transform ciphertexts between keys without decryption
- **Asymmetric Mode** (Recommended):
- Generate reencryption keys using the delegatee's public key only
- No secret key sharing required
- True proxy reencryption with 128-bit security
- **Symmetric Mode** (Trusted scenarios):
- Fast key generation for single-party key rotation
- ~21ms key generation vs ~4.6s asymmetric
- **Example Program**: `examples/proxy_reencryption/main.go`
- Demonstrates asymmetric proxy reencryption workflow
- Multi-hop chain example (Alice → Bob → Carol)
- Performance metrics and security notes
- 
**Test Suite**: 7 comprehensive tests 55 | - Public key encryption/decryption 56 | - Asymmetric and symmetric modes 57 | - Multi-hop chains 58 | - Statistical accuracy testing (100% accuracy) 59 | - **Benchmarks**: 4 benchmark functions 60 | - Asymmetric key generation 61 | - Symmetric key generation 62 | - Reencryption operation 63 | - Public key generation 64 | 65 | ### Performance 66 | - **Public key generation**: ~27ms 67 | - **Asymmetric keygen**: ~4.6s (1.65x faster than Rust!) 68 | - **Symmetric keygen**: ~21ms (4.3x faster than Rust) 69 | - **Reencryption**: ~3.0ms 70 | - **Accuracy**: 100% verified over 100+ iterations 71 | - **Security**: 128-bit post-quantum resistant 72 | 73 | ### Security 74 | - Based on Learning With Errors (LWE) hardness assumption 75 | - Quantum-resistant by design 76 | - Unidirectional delegation (Alice→Bob ≠ Bob→Alice) 77 | - Proxy learns nothing about plaintext 78 | - No secret key exposure in asymmetric mode 79 | - 128-bit security level maintained 80 | 81 | ### Testing 82 | - 7 new unit tests for proxy reencryption (all passing) 83 | - Statistical accuracy testing with 100 iterations 84 | - Multi-hop chain verification (3-hop tested) 85 | - Memory safety verified 86 | - Benchmarks for all major operations 87 | 88 | ### Documentation 89 | - Package-level godoc documentation 90 | - Inline API documentation 91 | - Complete example program with explanations 92 | - Release notes (RELEASE_NOTES_v0.2.0.md) 93 | - README.md updated with new features 94 | 95 | ### Notes 96 | - **Breaking**: None - purely additive feature 97 | - **Compatibility**: Go 1.21+ required 98 | - **Dependencies**: No new dependencies (pure Go) 99 | - Port of rs-tfhe v0.2.0 proxy reencryption feature 100 | - Feature parity with zig-tfhe v0.2.0 101 | 102 | ## [0.1.0] - 2025-XX-XX 103 | 104 | ### Added 105 | - Initial release of go-tfhe 106 | - Core TFHE functionality (TLWE, TRLWE, TRGSW) 107 | - Bootstrap operations 108 | - Homomorphic logic gates (AND, OR, XOR, 
NAND, NOR, XNOR, NOT, MUX) 109 | - Key generation (SecretKey, CloudKey) 110 | - FFT implementation for efficient polynomial operations 111 | - Programmable bootstrapping with lookup tables 112 | - Multiple security levels (80-bit, 110-bit, 128-bit) 113 | - Specialized Uint parameters for multi-bit arithmetic 114 | - Examples: add_two_numbers, simple_gates, programmable_bootstrap 115 | - Comprehensive test suite 116 | - Pure Go implementation (no CGO) 117 | 118 | [0.2.2]: https://github.com/thedonutfactory/go-tfhe/compare/v0.2.0...v0.2.2 119 | [0.2.0]: https://github.com/thedonutfactory/go-tfhe/compare/v0.1.0...v0.2.0 120 | [0.1.0]: https://github.com/thedonutfactory/go-tfhe/releases/tag/v0.1.0 121 | 122 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 The Donut Factory 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
# Every target here is a command, not a file, so all of them are phony.
.PHONY: all build test clean examples fmt vet test-quick test-gates test-nocache test-gates-nocache \
	run-add run-gates run-pbs run-proxy install-deps benchmark help

all: build test

build:
	@echo "Building go-tfhe..."
	go build ./...

test:
	@echo "Running tests..."
	go test -v ./...

test-quick:
	@echo "Running quick tests (non-gate tests)..."
	go test -v ./params ./utils ./bitutils ./tlwe ./trlwe ./key ./cloudkey ./evaluator

test-gates:
	@echo "Running gate tests (this will take several minutes)..."
	@echo "Each gate test takes ~400ms, batch tests take longer..."
	go test -v -timeout 30m ./gates

test-nocache:
	@echo "Running tests without cache..."
	go test -count=1 -v ./...

test-gates-nocache:
	@echo "Running gate tests without cache..."
	go test -count=1 -v -timeout 30m ./gates

# Build each example program into bin/. Keep this list in sync with the
# directories that actually exist under examples/.
examples:
	@echo "Building examples..."
	cd examples/add_two_numbers && go build -o ../../bin/add_two_numbers
	cd examples/simple_gates && go build -o ../../bin/simple_gates
	cd examples/programmable_bootstrap && go build -o ../../bin/programmable_bootstrap
	cd examples/proxy_reencryption && go build -o ../../bin/proxy_reencryption

run-add:
	@echo "Running add_two_numbers example..."
	cd examples/add_two_numbers && go run main.go

run-gates:
	@echo "Running simple_gates example..."
	cd examples/simple_gates && go run main.go

run-pbs:
	@echo "Running programmable_bootstrap example..."
	cd examples/programmable_bootstrap && go run main.go

run-proxy:
	@echo "Running proxy_reencryption example..."
	cd examples/proxy_reencryption && go run main.go

fmt:
	@echo "Formatting code..."
	go fmt ./...

vet:
	@echo "Running go vet..."
	go vet ./...

clean:
	@echo "Cleaning build artifacts..."
	go clean ./...
	rm -rf bin/
	rm -f examples/add_two_numbers/add_two_numbers
	rm -f examples/simple_gates/simple_gates
	rm -f examples/programmable_bootstrap/programmable_bootstrap
	rm -f examples/proxy_reencryption/proxy_reencryption

install-deps:
	@echo "Installing dependencies..."
	go mod download
	go mod verify

# Runs every benchmark in the module, not only the FFT ones.
benchmark:
	@echo "Running benchmarks..."
	go test -bench=. -benchmem ./...

help:
	@echo "Available targets:"
	@echo ""
	@echo "Building:"
	@echo "  all                - Build and test"
	@echo "  build              - Build all packages"
	@echo ""
	@echo "Testing:"
	@echo "  test               - Run all tests"
	@echo "  test-nocache       - Run all tests without cache"
	@echo "  test-quick         - Run quick tests (no gate tests)"
	@echo "  test-gates         - Run gate tests only"
	@echo "  test-gates-nocache - Run gate tests without cache"
	@echo ""
	@echo "Benchmarking:"
	@echo "  benchmark          - Run all benchmarks"
	@echo ""
	@echo "Examples:"
	@echo "  examples           - Build all examples"
	@echo "  run-gates          - Run simple_gates example"
	@echo "  run-add            - Run add_two_numbers example (traditional 8-bit, ~1.1s)"
	@echo "  run-pbs            - Run programmable_bootstrap example"
	@echo "  run-proxy          - Run proxy_reencryption example"
	@echo ""
	@echo "Utilities:"
	@echo "  fmt                - Format code"
	@echo "  vet                - Run go vet"
	@echo "  clean              - Remove build artifacts"
	@echo "  install-deps       - Install/verify dependencies"
-------------------------------------------------------------------------------- /bitutils/bitutils.go: -------------------------------------------------------------------------------- 1 | package bitutils 2 | 3 | import ( 4 | "github.com/thedonutfactory/go-tfhe/params" 5 | "github.com/thedonutfactory/go-tfhe/tlwe" 6 | ) 7 | 8 | // Convert converts a slice of bits to a number 9 | // Bits are in little-endian order (LSB first) 10 | func ConvertU8(bits []bool) uint8 { 11 | var result uint8 12 | for i := len(bits) - 1; i >= 0; i-- { 13 | result <<= 1 14 | if bits[i] { 15 | result |= 1 16 | } 17 | } 18 | return result 19 | } 20 | 21 | func ConvertU16(bits []bool) uint16 { 22 | var result uint16 23 | for i := len(bits) - 1; i >= 0; i-- { 24 | result <<= 1 25 | if bits[i] { 26 | result |= 1 27 | } 28 | } 29 | return result 30 | } 31 | 32 | func ConvertU32(bits []bool) uint32 { 33 | var result uint32 34 | for i := len(bits) - 1; i >= 0; i-- { 35 | result <<= 1 36 | if bits[i] { 37 | result |= 1 38 | } 39 | } 40 | return result 41 | } 42 | 43 | func ConvertU64(bits []bool) uint64 { 44 | var result uint64 45 | for i := len(bits) - 1; i >= 0; i-- { 46 | result <<= 1 47 | if bits[i] { 48 | result |= 1 49 | } 50 | } 51 | return result 52 | } 53 | 54 | // ToBits converts a number to a slice of bits 55 | // Returns bits in little-endian order (LSB first) 56 | func ToBits(val uint64, size int) []bool { 57 | vec := make([]bool, size) 58 | for i := 0; i < size; i++ { 59 | vec[i] = ((val >> i) & 1) != 0 60 | } 61 | return vec 62 | } 63 | 64 | // U8ToBits converts a uint8 to a slice of bits 65 | func U8ToBits(val uint8) []bool { 66 | return ToBits(uint64(val), 8) 67 | } 68 | 69 | // U16ToBits converts a uint16 to a slice of bits 70 | func U16ToBits(val uint16) []bool { 71 | return ToBits(uint64(val), 16) 72 | } 73 | 74 | // U32ToBits converts a uint32 to a slice of bits 75 | func U32ToBits(val uint32) []bool { 76 | return ToBits(uint64(val), 32) 77 | } 78 | 79 | // U64ToBits converts 
a uint64 to a slice of bits 80 | func U64ToBits(val uint64) []bool { 81 | return ToBits(val, 64) 82 | } 83 | 84 | // EncryptBits encrypts a slice of bits using the given secret key 85 | func EncryptBits(bits []bool, alpha float64, key []params.Torus) []*tlwe.TLWELv0 { 86 | result := make([]*tlwe.TLWELv0, len(bits)) 87 | for i, bit := range bits { 88 | result[i] = tlwe.NewTLWELv0().EncryptBool(bit, alpha, key) 89 | } 90 | return result 91 | } 92 | 93 | // DecryptBits decrypts a slice of ciphertexts to bits 94 | func DecryptBits(ctxts []*tlwe.TLWELv0, key []params.Torus) []bool { 95 | result := make([]bool, len(ctxts)) 96 | for i, ctxt := range ctxts { 97 | result[i] = ctxt.DecryptBool(key) 98 | } 99 | return result 100 | } 101 | -------------------------------------------------------------------------------- /bitutils/bitutils_test.go: -------------------------------------------------------------------------------- 1 | package bitutils_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/thedonutfactory/go-tfhe/bitutils" 7 | ) 8 | 9 | func TestU8ToBitsAndBack(t *testing.T) { 10 | testCases := []uint8{0, 1, 5, 42, 127, 255} 11 | 12 | for _, val := range testCases { 13 | bits := bitutils.U8ToBits(val) 14 | result := bitutils.ConvertU8(bits) 15 | 16 | if result != val { 17 | t.Errorf("U8: %d -> bits -> %d", val, result) 18 | } 19 | } 20 | } 21 | 22 | func TestU16ToBitsAndBack(t *testing.T) { 23 | testCases := []uint16{0, 1, 5, 42, 127, 255, 402, 304, 706, 65535} 24 | 25 | for _, val := range testCases { 26 | bits := bitutils.U16ToBits(val) 27 | result := bitutils.ConvertU16(bits) 28 | 29 | if result != val { 30 | t.Errorf("U16: %d -> bits -> %d", val, result) 31 | } 32 | } 33 | } 34 | 35 | func TestU32ToBitsAndBack(t *testing.T) { 36 | testCases := []uint32{0, 1, 42, 1000000, 4294967295} 37 | 38 | for _, val := range testCases { 39 | bits := bitutils.U32ToBits(val) 40 | result := bitutils.ConvertU32(bits) 41 | 42 | if result != val { 43 | t.Errorf("U32: %d -> bits 
-> %d", val, result) 44 | } 45 | } 46 | } 47 | 48 | func TestU64ToBitsAndBack(t *testing.T) { 49 | testCases := []uint64{0, 1, 42, 1000000, 18446744073709551615} 50 | 51 | for _, val := range testCases { 52 | bits := bitutils.U64ToBits(val) 53 | result := bitutils.ConvertU64(bits) 54 | 55 | if result != val { 56 | t.Errorf("U64: %d -> bits -> %d", val, result) 57 | } 58 | } 59 | } 60 | 61 | func TestToBitsLSBFirst(t *testing.T) { 62 | // Verify that bits are in LSB-first order 63 | bits := bitutils.U8ToBits(5) // 5 = 0b00000101 64 | 65 | // LSB first: [true, false, true, false, false, false, false, false] 66 | expected := []bool{true, false, true, false, false, false, false, false} 67 | 68 | if len(bits) != len(expected) { 69 | t.Fatalf("Bit length = %d, expected %d", len(bits), len(expected)) 70 | } 71 | 72 | for i := range bits { 73 | if bits[i] != expected[i] { 74 | t.Errorf("U8ToBits(5)[%d] = %v, expected %v", i, bits[i], expected[i]) 75 | } 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /cloudkey/cloudkey.go: -------------------------------------------------------------------------------- 1 | package cloudkey 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/thedonutfactory/go-tfhe/key" 7 | "github.com/thedonutfactory/go-tfhe/params" 8 | "github.com/thedonutfactory/go-tfhe/poly" 9 | "github.com/thedonutfactory/go-tfhe/tlwe" 10 | "github.com/thedonutfactory/go-tfhe/trgsw" 11 | "github.com/thedonutfactory/go-tfhe/trlwe" 12 | "github.com/thedonutfactory/go-tfhe/utils" 13 | ) 14 | 15 | // CloudKey contains the public evaluation keys 16 | type CloudKey struct { 17 | DecompositionOffset params.Torus 18 | BlindRotateTestvec *trlwe.TRLWELv1 19 | KeySwitchingKey []*tlwe.TLWELv0 20 | BootstrappingKey []*trgsw.TRGSWLv1FFT 21 | } 22 | 23 | // NewCloudKey generates a new cloud key from a secret key 24 | func NewCloudKey(secretKey *key.SecretKey) *CloudKey { 25 | return &CloudKey{ 26 | DecompositionOffset: 
genDecompositionOffset(), 27 | BlindRotateTestvec: genTestvec(), 28 | KeySwitchingKey: genKeySwitchingKey(secretKey), 29 | BootstrappingKey: genBootstrappingKey(secretKey), 30 | } 31 | } 32 | 33 | // NewCloudKeyNoKSK creates a cloud key without key switching key (for testing) 34 | func NewCloudKeyNoKSK() *CloudKey { 35 | base := 1 << params.GetTRGSWLv1().BASEBIT 36 | iksT := params.GetTRGSWLv1().IKS_T 37 | n := params.GetTRGSWLv1().N 38 | lv0N := params.GetTLWELv0().N 39 | 40 | ksk := make([]*tlwe.TLWELv0, base*iksT*n) 41 | for i := range ksk { 42 | ksk[i] = tlwe.NewTLWELv0() 43 | } 44 | 45 | polyEval := poly.NewEvaluator(n) 46 | bsk := make([]*trgsw.TRGSWLv1FFT, lv0N) 47 | for i := range bsk { 48 | bsk[i] = trgsw.NewTRGSWLv1FFTDummy(polyEval) 49 | } 50 | 51 | return &CloudKey{ 52 | DecompositionOffset: genDecompositionOffset(), 53 | BlindRotateTestvec: genTestvec(), 54 | KeySwitchingKey: ksk, 55 | BootstrappingKey: bsk, 56 | } 57 | } 58 | 59 | // genDecompositionOffset generates the decomposition offset 60 | func genDecompositionOffset() params.Torus { 61 | var offset params.Torus 62 | l := params.GetTRGSWLv1().L 63 | bg := params.GetTRGSWLv1().BG 64 | bgbit := params.GetTRGSWLv1().BGBIT 65 | 66 | for i := 0; i < l; i++ { 67 | offset += params.Torus(bg/2) * params.Torus(1<<(32-((i+1)*int(bgbit)))) 68 | } 69 | 70 | return offset 71 | } 72 | 73 | // genTestvec generates the test vector for blind rotation 74 | func genTestvec() *trlwe.TRLWELv1 { 75 | n := params.GetTRGSWLv1().N 76 | testvec := trlwe.NewTRLWELv1() 77 | bTorus := utils.F64ToTorus(0.125) 78 | 79 | for i := 0; i < n; i++ { 80 | testvec.A[i] = 0 81 | testvec.B[i] = bTorus 82 | } 83 | 84 | return testvec 85 | } 86 | 87 | // genKeySwitchingKey generates the key switching key (parallelized) 88 | func genKeySwitchingKey(secretKey *key.SecretKey) []*tlwe.TLWELv0 { 89 | basebit := params.GetTRGSWLv1().BASEBIT 90 | iksT := params.GetTRGSWLv1().IKS_T 91 | base := 1 << basebit 92 | n := params.GetTRGSWLv1().N 93 | 
94 | result := make([]*tlwe.TLWELv0, base*iksT*n) 95 | for i := range result { 96 | result[i] = tlwe.NewTLWELv0() 97 | } 98 | 99 | var wg sync.WaitGroup 100 | for i := 0; i < n; i++ { 101 | wg.Add(1) 102 | go func(iIdx int) { 103 | defer wg.Done() 104 | for j := 0; j < iksT; j++ { 105 | for k := 0; k < base; k++ { 106 | if k == 0 { 107 | continue 108 | } 109 | shift := uint((j + 1) * basebit) 110 | p := (float64(k) * float64(secretKey.KeyLv1[iIdx])) / float64(uint64(1)<> (32 - nBit - 1)) 117 | poly.PolyMulWithXKInPlace(testvec.A, bTilda, e.Buffers.BlindRotation.Accumulator1.A) 118 | poly.PolyMulWithXKInPlace(testvec.B, bTilda, e.Buffers.BlindRotation.Accumulator1.B) 119 | 120 | // Iterate through LWE coefficients 121 | for i := 0; i < tlweLv0N; i++ { 122 | aTilda := int((ctIn.P[i] + (1 << (31 - nBit - 1))) >> (32 - nBit - 1)) 123 | 124 | // Rotate into buffer.ctAcc2 125 | poly.PolyMulWithXKInPlace(e.Buffers.BlindRotation.Accumulator1.A, aTilda, e.Buffers.BlindRotation.Accumulator2.A) 126 | poly.PolyMulWithXKInPlace(e.Buffers.BlindRotation.Accumulator1.B, aTilda, e.Buffers.BlindRotation.Accumulator2.B) 127 | 128 | // CMux: ctAcc1 = ctAcc1 + bsk[i] * (ctAcc2 - ctAcc1) 129 | e.CMuxAssign(bsk[i], e.Buffers.BlindRotation.Accumulator1, e.Buffers.BlindRotation.Accumulator2, decompositionOffset, e.Buffers.BlindRotation.Accumulator1) 130 | } 131 | 132 | // Copy result to output 133 | copy(ctOut.A, e.Buffers.BlindRotation.Accumulator1.A) 134 | copy(ctOut.B, e.Buffers.BlindRotation.Accumulator1.B) 135 | } 136 | 137 | // BootstrapAssign performs full bootstrapping (blind rotate + key switch) 138 | // Zero-allocation version - writes to ctOut 139 | func (e *Evaluator) BootstrapAssign(ctIn *tlwe.TLWELv0, testvec *trlwe.TRLWELv1, bsk []*trgsw.TRGSWLv1FFT, ksk []*tlwe.TLWELv0, decompositionOffset params.Torus, ctOut *tlwe.TLWELv0) { 140 | // Blind rotate 141 | e.BlindRotateAssign(ctIn, testvec, bsk, decompositionOffset, e.Buffers.BlindRotation.Rotated) 142 | 143 | // Sample 
extract 144 | trlwe.SampleExtractIndexAssign(e.Buffers.BlindRotation.Rotated, 0, e.Buffers.Bootstrap.ExtractedLWE) 145 | 146 | // Key switch - writes directly to ctOut (zero-allocation!) 147 | trgsw.IdentityKeySwitchingAssign(e.Buffers.Bootstrap.ExtractedLWE, ksk, ctOut) 148 | } 149 | 150 | // Bootstrap performs full bootstrapping and returns result using buffer pool 151 | // Returns pointer to buffer pool - valid until 4 more bootstrap calls 152 | func (e *Evaluator) Bootstrap(ctIn *tlwe.TLWELv0, testvec *trlwe.TRLWELv1, bsk []*trgsw.TRGSWLv1FFT, ksk []*tlwe.TLWELv0, decompositionOffset params.Torus) *tlwe.TLWELv0 { 153 | // Get result buffer from pool (round-robin) 154 | result := e.Buffers.GetNextResult() 155 | e.BootstrapAssign(ctIn, testvec, bsk, ksk, decompositionOffset, result) 156 | return result 157 | } 158 | 159 | // ResetBuffers resets all buffer pool indices 160 | func (e *Evaluator) ResetBuffers() { 161 | e.Buffers.Reset() 162 | } 163 | -------------------------------------------------------------------------------- /evaluator/gates_helper.go: -------------------------------------------------------------------------------- 1 | package evaluator 2 | 3 | import ( 4 | "github.com/thedonutfactory/go-tfhe/params" 5 | "github.com/thedonutfactory/go-tfhe/tlwe" 6 | "github.com/thedonutfactory/go-tfhe/utils" 7 | ) 8 | 9 | // PrepareNAND prepares a NAND input for bootstrapping (zero-allocation) 10 | func (e *Evaluator) PrepareNAND(a, b *tlwe.TLWELv0) *tlwe.TLWELv0 { 11 | n := params.GetTLWELv0().N 12 | result := tlwe.NewTLWELv0() 13 | 14 | // NAND: -(a + b) + 1/8 15 | for i := 0; i < n; i++ { 16 | result.P[i] = -(a.P[i] + b.P[i]) 17 | } 18 | result.P[n] = -(a.P[n] + b.P[n]) + utils.F64ToTorus(0.125) 19 | 20 | return result 21 | } 22 | 23 | // PrepareAND prepares an AND input for bootstrapping 24 | func (e *Evaluator) PrepareAND(a, b *tlwe.TLWELv0) *tlwe.TLWELv0 { 25 | n := params.GetTLWELv0().N 26 | result := tlwe.NewTLWELv0() 27 | 28 | // AND: (a + b) - 1/8 29 
| for i := 0; i < n; i++ { 30 | result.P[i] = a.P[i] + b.P[i] 31 | } 32 | result.P[n] = a.P[n] + b.P[n] + utils.F64ToTorus(-0.125) 33 | 34 | return result 35 | } 36 | 37 | // PrepareOR prepares an OR input for bootstrapping 38 | func (e *Evaluator) PrepareOR(a, b *tlwe.TLWELv0) *tlwe.TLWELv0 { 39 | n := params.GetTLWELv0().N 40 | result := tlwe.NewTLWELv0() 41 | 42 | // OR: (a + b) + 1/8 43 | for i := 0; i < n; i++ { 44 | result.P[i] = a.P[i] + b.P[i] 45 | } 46 | result.P[n] = a.P[n] + b.P[n] + utils.F64ToTorus(0.125) 47 | 48 | return result 49 | } 50 | 51 | // PrepareXOR prepares an XOR input for bootstrapping 52 | func (e *Evaluator) PrepareXOR(a, b *tlwe.TLWELv0) *tlwe.TLWELv0 { 53 | n := params.GetTLWELv0().N 54 | result := tlwe.NewTLWELv0() 55 | 56 | // XOR: (a + 2*b) + 1/4 57 | for i := 0; i < n; i++ { 58 | result.P[i] = a.P[i] + 2*b.P[i] 59 | } 60 | result.P[n] = a.P[n] + 2*b.P[n] + utils.F64ToTorus(0.25) 61 | 62 | return result 63 | } 64 | -------------------------------------------------------------------------------- /evaluator/programmable_bootstrap.go: -------------------------------------------------------------------------------- 1 | package evaluator 2 | 3 | import ( 4 | "github.com/thedonutfactory/go-tfhe/lut" 5 | "github.com/thedonutfactory/go-tfhe/params" 6 | "github.com/thedonutfactory/go-tfhe/tlwe" 7 | "github.com/thedonutfactory/go-tfhe/trgsw" 8 | "github.com/thedonutfactory/go-tfhe/trlwe" 9 | ) 10 | 11 | // BootstrapFunc performs programmable bootstrapping with a function 12 | // The function f operates on the message space [0, messageModulus) and 13 | // is evaluated homomorphically on the encrypted data during bootstrapping. 14 | // 15 | // This combines noise refreshing with arbitrary function evaluation. 
16 | func (e *Evaluator) BootstrapFunc( 17 | ctIn *tlwe.TLWELv0, 18 | f func(int) int, 19 | messageModulus int, 20 | bsk []*trgsw.TRGSWLv1FFT, 21 | ksk []*tlwe.TLWELv0, 22 | decompositionOffset params.Torus, 23 | ) *tlwe.TLWELv0 { 24 | // Generate lookup table from function 25 | generator := lut.NewGenerator(messageModulus) 26 | lookupTable := generator.GenLookUpTable(f) 27 | 28 | // Perform LUT-based bootstrapping 29 | return e.BootstrapLUT(ctIn, lookupTable, bsk, ksk, decompositionOffset) 30 | } 31 | 32 | // BootstrapFuncAssign performs programmable bootstrapping with a function (zero-allocation) 33 | func (e *Evaluator) BootstrapFuncAssign( 34 | ctIn *tlwe.TLWELv0, 35 | f func(int) int, 36 | messageModulus int, 37 | bsk []*trgsw.TRGSWLv1FFT, 38 | ksk []*tlwe.TLWELv0, 39 | decompositionOffset params.Torus, 40 | ctOut *tlwe.TLWELv0, 41 | ) { 42 | // Generate lookup table from function 43 | generator := lut.NewGenerator(messageModulus) 44 | lookupTable := generator.GenLookUpTable(f) 45 | 46 | // Perform LUT-based bootstrapping 47 | e.BootstrapLUTAssign(ctIn, lookupTable, bsk, ksk, decompositionOffset, ctOut) 48 | } 49 | 50 | // BootstrapLUT performs programmable bootstrapping with a pre-computed lookup table 51 | // The lookup table encodes the function to be evaluated during bootstrapping. 52 | // 53 | // This is more efficient than BootstrapFunc when the same function is used multiple times. 
54 | func (e *Evaluator) BootstrapLUT( 55 | ctIn *tlwe.TLWELv0, 56 | lut *lut.LookUpTable, 57 | bsk []*trgsw.TRGSWLv1FFT, 58 | ksk []*tlwe.TLWELv0, 59 | decompositionOffset params.Torus, 60 | ) *tlwe.TLWELv0 { 61 | result := e.Buffers.GetNextResult() 62 | e.BootstrapLUTAssign(ctIn, lut, bsk, ksk, decompositionOffset, result) 63 | 64 | copiedResult := tlwe.NewTLWELv0() 65 | copy(copiedResult.P, result.P) 66 | copiedResult.SetB(result.B()) 67 | 68 | return copiedResult 69 | } 70 | 71 | func (e *Evaluator) BootstrapLUTTemp( 72 | ctIn *tlwe.TLWELv0, 73 | lut *lut.LookUpTable, 74 | bsk []*trgsw.TRGSWLv1FFT, 75 | ksk []*tlwe.TLWELv0, 76 | decompositionOffset params.Torus, 77 | ) *tlwe.TLWELv0 { 78 | result := e.Buffers.GetNextResult() 79 | e.BootstrapLUTAssign(ctIn, lut, bsk, ksk, decompositionOffset, result) 80 | return result 81 | } 82 | 83 | // BootstrapLUTAssign performs programmable bootstrapping with a lookup table (zero-allocation) 84 | // This is the core implementation of programmable bootstrapping. 85 | // 86 | // Algorithm: 87 | // 1. Blind rotate the lookup table based on the encrypted value 88 | // 2. Sample extract to get an LWE ciphertext 89 | // 3. Key switch to convert back to the original key 90 | // 91 | // The key insight is that we can reuse the existing BlindRotateAssign function 92 | // by converting the LUT into a TRLWE ciphertext (test vector). 
93 | func (e *Evaluator) BootstrapLUTAssign( 94 | ctIn *tlwe.TLWELv0, 95 | lut *lut.LookUpTable, 96 | bsk []*trgsw.TRGSWLv1FFT, 97 | ksk []*tlwe.TLWELv0, 98 | decompositionOffset params.Torus, 99 | ctOut *tlwe.TLWELv0, 100 | ) { 101 | // Convert LUT to TRLWE format (test vector) 102 | // The LUT is already a TRLWE with the function encoded in the B polynomial 103 | testvec := lut.Poly 104 | 105 | // Perform blind rotation using the LUT as the test vector 106 | // This rotates the LUT based on the encrypted value, effectively evaluating the function 107 | e.BlindRotateAssign(ctIn, testvec, bsk, decompositionOffset, e.Buffers.BlindRotation.Rotated) 108 | 109 | // Extract the constant term as an LWE ciphertext 110 | // This gives us the function evaluation encrypted under the TRLWE key 111 | trlwe.SampleExtractIndexAssign(e.Buffers.BlindRotation.Rotated, 0, e.Buffers.Bootstrap.ExtractedLWE) 112 | 113 | // Key switch to convert back to the original LWE key 114 | trgsw.IdentityKeySwitchingAssign(e.Buffers.Bootstrap.ExtractedLWE, ksk, ctOut) 115 | } 116 | -------------------------------------------------------------------------------- /evaluator/programmable_bootstrap_test.go: -------------------------------------------------------------------------------- 1 | package evaluator 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/thedonutfactory/go-tfhe/cloudkey" 7 | "github.com/thedonutfactory/go-tfhe/key" 8 | "github.com/thedonutfactory/go-tfhe/lut" 9 | "github.com/thedonutfactory/go-tfhe/params" 10 | "github.com/thedonutfactory/go-tfhe/tlwe" 11 | ) 12 | 13 | func TestProgrammableBootstrapIdentity(t *testing.T) { 14 | // Use 80-bit security for faster testing 15 | oldSecurityLevel := params.CurrentSecurityLevel 16 | params.CurrentSecurityLevel = params.Security80Bit 17 | defer func() { params.CurrentSecurityLevel = oldSecurityLevel }() 18 | 19 | // Generate keys 20 | secretKey := key.NewSecretKey() 21 | cloudKey := cloudkey.NewCloudKey(secretKey) 22 | 23 | // Create 
evaluator
	eval := NewEvaluator(params.GetTRGSWLv1().N)

	// Test identity function: f(x) = x
	identity := func(x int) int { return x }

	// Test with both 0 and 1
	testCases := []struct {
		name  string
		input int
		want  int
	}{
		{"identity(0)", 0, 0},
		{"identity(1)", 1, 1},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Encrypt input using LWE message encoding (not binary encoding!)
			ct := tlwe.NewTLWELv0()
			ct.EncryptLWEMessage(tc.input, 2, params.GetTLWELv0().ALPHA, secretKey.KeyLv0)

			// Apply programmable bootstrap
			result := eval.BootstrapFunc(
				ct,
				identity,
				2, // binary message modulus
				cloudKey.BootstrappingKey,
				cloudKey.KeySwitchingKey,
				cloudKey.DecompositionOffset,
			)

			// Decrypt and verify using LWE message decoding
			decrypted := result.DecryptLWEMessage(2, secretKey.KeyLv0)
			if decrypted != tc.want {
				t.Errorf("identity(%d) = %d, want %d", tc.input, decrypted, tc.want)
			}
		})
	}
}

// useSecurity80Bit switches the global security level to 80-bit (faster key
// generation for tests) and restores the previous level when t and all of
// its subtests have finished.
func useSecurity80Bit(t *testing.T) {
	t.Helper()
	previous := params.CurrentSecurityLevel
	params.CurrentSecurityLevel = params.Security80Bit
	t.Cleanup(func() { params.CurrentSecurityLevel = previous })
}

func TestProgrammableBootstrapNOT(t *testing.T) {
	useSecurity80Bit(t)

	secretKey := key.NewSecretKey()
	cloudKey := cloudkey.NewCloudKey(secretKey)
	eval := NewEvaluator(params.GetTRGSWLv1().N)

	// Test NOT function: f(x) = 1 - x
	notFunc := func(x int) int { return 1 - x }

	testCases := []struct {
		name  string
		input int
		want  int
	}{
		{"NOT(0)", 0, 1},
		{"NOT(1)", 1, 0},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			ct := tlwe.NewTLWELv0()
			ct.EncryptLWEMessage(tc.input, 2, params.GetTLWELv0().ALPHA, secretKey.KeyLv0)

			result := eval.BootstrapFunc(
				ct,
				notFunc,
				2,
				cloudKey.BootstrappingKey,
				cloudKey.KeySwitchingKey,
				cloudKey.DecompositionOffset,
			)

			decrypted := result.DecryptLWEMessage(2, secretKey.KeyLv0)
			if decrypted != tc.want {
				t.Errorf("NOT(%d) = %d, want %d", tc.input, decrypted, tc.want)
			}
		})
	}
}

func TestProgrammableBootstrapConstant(t *testing.T) {
	useSecurity80Bit(t)

	secretKey := key.NewSecretKey()
	cloudKey := cloudkey.NewCloudKey(secretKey)
	eval := NewEvaluator(params.GetTRGSWLv1().N)

	// Test constant function: f(x) = 1 (always returns 1)
	constantOne := func(int) int { return 1 }

	testCases := []struct {
		name  string
		input int
	}{
		{"constant(0)", 0},
		{"constant(1)", 1},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			ct := tlwe.NewTLWELv0()
			ct.EncryptLWEMessage(tc.input, 2, params.GetTLWELv0().ALPHA, secretKey.KeyLv0)

			result := eval.BootstrapFunc(
				ct,
				constantOne,
				2,
				cloudKey.BootstrappingKey,
				cloudKey.KeySwitchingKey,
				cloudKey.DecompositionOffset,
			)

			// Should always decrypt to 1
			decrypted := result.DecryptLWEMessage(2, secretKey.KeyLv0)
			if decrypted != 1 {
				t.Errorf("constant(%d) = %d, want 1", tc.input, decrypted)
			}
		})
	}
}

func TestBootstrapLUTReuse(t *testing.T) {
	// Test that we can reuse a lookup table for multiple encryptions
	useSecurity80Bit(t)

	secretKey := key.NewSecretKey()
	cloudKey := cloudkey.NewCloudKey(secretKey)
	eval :=
NewEvaluator(params.GetTRGSWLv1().N) 159 | gen := lut.NewGenerator(2) 160 | 161 | // Pre-compute lookup table for NOT function 162 | notFunc := func(x int) int { return 1 - x } 163 | lookupTable := gen.GenLookUpTable(notFunc) 164 | 165 | // Apply to multiple inputs using the same LUT 166 | inputs := []int{0, 1, 0, 1, 0} 167 | 168 | for i, input := range inputs { 169 | ct := tlwe.NewTLWELv0() 170 | ct.EncryptLWEMessage(input, 2, params.GetTLWELv0().ALPHA, secretKey.KeyLv0) 171 | 172 | // Use pre-computed LUT 173 | result := eval.BootstrapLUT( 174 | ct, 175 | lookupTable, 176 | cloudKey.BootstrappingKey, 177 | cloudKey.KeySwitchingKey, 178 | cloudKey.DecompositionOffset, 179 | ) 180 | 181 | decrypted := result.DecryptLWEMessage(2, secretKey.KeyLv0) 182 | expected := 1 - input 183 | 184 | if decrypted != expected { 185 | t.Errorf("test %d: NOT(%d) = %d, want %d", i, input, decrypted, expected) 186 | } 187 | } 188 | } 189 | 190 | func TestModSwitch(t *testing.T) { 191 | gen := lut.NewGenerator(2) 192 | n := params.GetTRGSWLv1().N 193 | 194 | // Test that ModSwitch returns values in valid range 195 | tests := []params.Torus{ 196 | 0, 197 | 1 << 30, 198 | 1 << 31, 199 | 3 << 30, 200 | params.Torus(^uint32(0)), 201 | } 202 | 203 | for _, val := range tests { 204 | result := gen.ModSwitch(val) 205 | if result < 0 || result >= n { 206 | t.Errorf("ModSwitch(%d) = %d, out of range [0, %d)", val, result, n) 207 | } 208 | } 209 | } 210 | 211 | // Benchmark programmable bootstrapping performance 212 | func BenchmarkProgrammableBootstrap(b *testing.B) { 213 | params.CurrentSecurityLevel = params.Security80Bit 214 | 215 | secretKey := key.NewSecretKey() 216 | cloudKey := cloudkey.NewCloudKey(secretKey) 217 | eval := NewEvaluator(params.GetTRGSWLv1().N) 218 | 219 | // Create input ciphertext 220 | ct := tlwe.NewTLWELv0() 221 | ct.EncryptLWEMessage(1, 2, params.GetTLWELv0().ALPHA, secretKey.KeyLv0) 222 | 223 | // Identity function 224 | identity := func(x int) int { return x } 225 | 
226 | b.ResetTimer() 227 | for i := 0; i < b.N; i++ { 228 | _ = eval.BootstrapFunc( 229 | ct, 230 | identity, 231 | 2, 232 | cloudKey.BootstrappingKey, 233 | cloudKey.KeySwitchingKey, 234 | cloudKey.DecompositionOffset, 235 | ) 236 | } 237 | } 238 | 239 | // Benchmark LUT reuse 240 | func BenchmarkBootstrapLUT(b *testing.B) { 241 | params.CurrentSecurityLevel = params.Security80Bit 242 | 243 | secretKey := key.NewSecretKey() 244 | cloudKey := cloudkey.NewCloudKey(secretKey) 245 | eval := NewEvaluator(params.GetTRGSWLv1().N) 246 | gen := lut.NewGenerator(2) 247 | 248 | // Pre-compute LUT 249 | identity := func(x int) int { return x } 250 | lookupTable := gen.GenLookUpTable(identity) 251 | 252 | // Create input ciphertext 253 | ct := tlwe.NewTLWELv0() 254 | ct.EncryptLWEMessage(1, 2, params.GetTLWELv0().ALPHA, secretKey.KeyLv0) 255 | 256 | b.ResetTimer() 257 | for i := 0; i < b.N; i++ { 258 | _ = eval.BootstrapLUT( 259 | ct, 260 | lookupTable, 261 | cloudKey.BootstrappingKey, 262 | cloudKey.KeySwitchingKey, 263 | cloudKey.DecompositionOffset, 264 | ) 265 | } 266 | } 267 | -------------------------------------------------------------------------------- /examples/EXAMPLES_GUIDE.md: -------------------------------------------------------------------------------- 1 | # Examples Guide 2 | 3 | This directory contains demonstrations of go-tfhe capabilities, from basic gates to advanced programmable bootstrapping. 4 | 5 | ## Available Examples 6 | 7 | ### 1. `simple_gates/` - Boolean Gate Demonstrations 8 | 9 | **What it does**: Tests all basic homomorphic boolean gates 10 | 11 | **Gates demonstrated**: 12 | - AND, OR, NAND, NOR 13 | - XOR, XNOR 14 | - NOT 15 | 16 | **Run**: `make run-gates` 17 | 18 | **Time**: ~10 seconds (tests all gates on all input combinations) 19 | 20 | **Best for**: Learning basic TFHE gates 21 | 22 | --- 23 | 24 | ### 2. 
`add_two_numbers/` - Traditional 8-bit Addition 25 | 26 | **What it does**: 8-bit addition using traditional ripple-carry adder 27 | 28 | **Method**: Bit-by-bit with boolean gates (XOR, AND, OR) 29 | 30 | **Operations**: 40 gates (5 per bit × 8 bits) 31 | 32 | **Run**: `make run-add` 33 | 34 | **Time**: ~1.1 seconds 35 | 36 | **Best for**: Understanding traditional TFHE approach and why PBS is revolutionary 37 | 38 | **Example output**: 39 | ``` 40 | Computing: 42 + 137 = 179 41 | Operations: 40 boolean gates 42 | Time: ~1.1s 43 | ✅ SUCCESS! 44 | ``` 45 | 46 | --- 47 | 48 | ### 3. `add_8bit_pbs/` - Fast 8-bit Addition with PBS ⭐ 49 | 50 | **What it does**: 8-bit addition using Programmable Bootstrapping (PBS) 51 | 52 | **Method**: Nibble-based (processes 4 bits at once) 53 | 54 | **Operations**: 3 programmable bootstraps 55 | 56 | **Parameters**: `SecurityUint5` (messageModulus=32, N=2048) 57 | 58 | **Run**: `make run-add-pbs` 59 | 60 | **Time**: ~230ms 61 | 62 | **Best for**: Seeing the dramatic PBS performance advantage 63 | 64 | **Example output**: 65 | ``` 66 | Computing: 42 + 137 = 179 67 | Input A: 42 = 0b0010_1010 (nibbles: high=2, low=10) 68 | Input B: 137 = 0b1000_1001 (nibbles: high=8, low=9) 69 | 70 | Steps: 71 | 1. Encrypt nibbles (4 nibbles) 72 | 2. Add low nibbles (homomorphic, no bootstrap) 73 | 3. Bootstrap 1: Extract low sum (mod 16) 74 | 4. Bootstrap 2: Extract carry bit 75 | 5. Add high nibbles + carry (homomorphic) 76 | 6. Bootstrap 3: Extract high sum (mod 16) 77 | 7. Combine nibbles 78 | 79 | Result: 179 = 0b1011_0011 80 | Time: ~230ms (3 bootstraps) 81 | ✅ SUCCESS! 82 | 83 | Speedup: 4.8x faster than traditional method! 🚀 84 | ``` 85 | 86 | --- 87 | 88 | ### 4. 
`programmable_bootstrap/` - PBS Feature Demonstrations 89 | 90 | **What it does**: Comprehensive PBS feature demonstrations 91 | 92 | **Features shown**: 93 | - Identity function (noise refresh) 94 | - NOT function (bit flip) 95 | - Constant functions 96 | - LUT reuse 97 | - Multi-bit messages (messageModulus=4) 98 | 99 | **Run**: `make run-pbs` 100 | 101 | **Time**: ~2-3 seconds (multiple demonstrations) 102 | 103 | **Best for**: Learning PBS concepts and usage patterns 104 | 105 | --- 106 | 107 | ## Comparison Table 108 | 109 | | Example | Method | Operations | Time | Speedup | Use Case | 110 | |---------|--------|-----------|------|---------|----------| 111 | | `simple_gates` | Boolean gates | Varies | ~10s | Baseline | Learn gates | 112 | | `add_two_numbers` | Ripple-carry | 40 gates | ~1.1s | 1x | Traditional method | 113 | | **`add_8bit_pbs`** | **PBS nibbles** | **3 PBS** | **~230ms** | **4.8x** ⭐ | **Fast arithmetic** | 114 | | `programmable_bootstrap` | PBS | Varies | ~2-3s | N/A | Learn PBS | 115 | 116 | ## Quick Start 117 | 118 | ```bash 119 | # 1. Start simple - learn the gates 120 | make run-gates 121 | 122 | # 2. See traditional approach 123 | make run-add 124 | 125 | # 3. See the PBS revolution! 126 | make run-add-pbs 127 | 128 | # 4. Explore PBS features 129 | make run-pbs 130 | ``` 131 | 132 | ## Understanding the Speedup 133 | 134 | ### Traditional Method (`add_two_numbers`) 135 | ``` 136 | Process: Bit-by-bit ripple carry 137 | - For each bit (0-7): 138 | - XOR(a, b) → 1 bootstrap 139 | - XOR(ab, c) → 1 bootstrap 140 | - AND(a, b) → 1 bootstrap 141 | - AND(c, ab) → 1 bootstrap 142 | - OR(...) 
→ 1 bootstrap 143 | Total: 5 gates × 8 bits = 40 bootstraps 144 | ``` 145 | 146 | ### PBS Method (`add_8bit_pbs`) 147 | ``` 148 | Process: Nibble-based with programmable bootstrapping 149 | - Split into 4-bit chunks (nibbles) 150 | - Add low nibbles → PBS extract sum & carry (2 bootstraps) 151 | - Add high nibbles → PBS extract sum (1 bootstrap) 152 | Total: 3 bootstraps 153 | 154 | Why faster? 155 | - Processes 4 bits at once instead of 1 bit 156 | - LUTs encode multiple operations in single bootstrap 157 | - 13x fewer bootstraps = massive speedup! 158 | ``` 159 | 160 | ## Recommended Learning Path 161 | 162 | 1. **Start**: `simple_gates` - Understand basic operations 163 | 2. **Traditional**: `add_two_numbers` - See how addition works bit-by-bit 164 | 3. **Modern**: `add_8bit_pbs` - See the PBS advantage 165 | 4. **Deep Dive**: `programmable_bootstrap` - Explore PBS features 166 | 167 | ## Extending These Examples 168 | 169 | ### Build Your Own Operations 170 | 171 | Using `add_8bit_pbs` as a template, you can create: 172 | 173 | **8-bit Subtraction:** 174 | ```go 175 | // Use LUT: f(x) = x % 16 for difference 176 | // Handle borrow instead of carry 177 | ``` 178 | 179 | **8-bit Multiplication:** 180 | ```go 181 | // Decompose into nibbles 182 | // Use shift-and-add algorithm 183 | // ~12-16 bootstraps 184 | ``` 185 | 186 | **8-bit Comparison:** 187 | ```go 188 | // LUT: f(x) = x >= threshold ? 
1 : 0
189 | // Single bootstrap per nibble
190 | ```
191 |
192 | ## Parameter Selection for Examples
193 |
194 | | Example | Parameter Used | Reason |
195 | |---------|---------------|---------|
196 | | `simple_gates` | `Security128Bit` | Binary operations |
197 | | `add_two_numbers` | **`SecurityUint5`** | Needs messageModulus=32 |
198 | | `add_8bit_pbs` | **`SecurityUint5`** | Needs messageModulus=32 |
199 | | `programmable_bootstrap` | `Security80Bit` | Faster demo |
200 |
201 | ## Performance Notes
202 |
203 | All times are approximate and depend on hardware:
204 | - Measured on: Modern CPU (2020+)
205 | - Key generation: One-time cost
206 | - Bootstrap times: Consistent per operation
207 | - Can be parallelized for multiple operations
208 |
209 | ## Next Steps
210 |
211 | After running the examples:
212 | 1. Review the parameter sets in `params/` (see `params/UINT_STATUS.md`) for parameter selection
213 | 2. Check `README.md` for API documentation
214 | 3. See `CHANGELOG.md` for the library's change history
215 | 4. Build your own homomorphic applications!
216 |
217 | ---
218 |
219 | **Start exploring: `make run-gates`** 🚀
220 |
221 |
--------------------------------------------------------------------------------
/examples/add_two_numbers/README.md:
--------------------------------------------------------------------------------
1 | # Fast 8-bit Addition with Programmable Bootstrapping
2 |
3 | This example demonstrates **high-performance 8-bit homomorphic addition** using Programmable Bootstrapping (PBS) with the nibble-based method.
4 |
5 | ## What This Example Does
6 |
7 | Computes `42 + 137 = 179` using only **3 programmable bootstraps**:
8 |
9 | 1. Splits each 8-bit number into two 4-bit nibbles (low and high)
10 | 2. Encrypts nibbles with `messageModulus=32` (Uint5 parameters)
11 | 3. Adds low nibbles and extracts carry using PBS
12 | 4. Adds high nibbles with carry using PBS
13 | 5.
Combines results into final 8-bit sum
14 |
15 | ## Algorithm: Nibble-Based Addition
16 |
17 | ```
18 | Input: a, b (8-bit unsigned integers)
19 |
20 | Step 1: Split into nibbles
21 |   a_low = a & 0x0F (bits 0-3)
22 |   a_high = (a >> 4) & 0x0F (bits 4-7)
23 |   b_low = b & 0x0F
24 |   b_high = (b >> 4) & 0x0F
25 |
26 | Step 2: Add low nibbles (Bootstrap 1 & 2)
27 |   temp_low = a_low + b_low (homomorphic addition, no bootstrap)
28 |   sum_low = PBS(temp_low, LUT: x % 16) // Bootstrap 1
29 |   carry = PBS(temp_low, LUT: x >= 16 ? 1 : 0) // Bootstrap 2
30 |
31 | Step 3: Add high nibbles with carry (Bootstrap 3)
32 |   temp_high = a_high + b_high + carry
33 |   sum_high = PBS(temp_high, LUT: x % 16) // Bootstrap 3
34 |
35 | Step 4: Combine nibbles
36 |   result = sum_low | (sum_high << 4)
37 |
38 | Total: 3 programmable bootstraps!
39 | ```
40 |
41 | ## Parameters Used
42 |
43 | - **Security Level**: `SecurityUint5`
44 | - **messageModulus**: 32 (supports values 0-31)
45 | - **Polynomial Degree**: N=2048
46 | - **LWE Dimension**: 1071
47 | - **Noise Level**: 7.09e-08 (~700x lower than standard)
48 |
49 | ## Performance
50 |
51 | ### This Example (PBS Method)
52 | - **Bootstraps**: 3 (low sum, low carry, high sum)
53 | - **Time**: ~230ms for 8-bit addition
54 | - **Method**: Nibble-based (4 bits at a time)
55 |
56 | ### Comparison with Traditional Method
57 | - **Traditional** (bit-by-bit boolean-gate ripple-carry adder): 40 gates, ~1.1s
58 | - **PBS Method** (this example): 3 bootstraps, ~230ms
59 | - **Speedup**: **~4.8x faster!** 🚀
60 |
61 | ## Running the Example
62 |
63 | ```bash
64 | cd examples/add_two_numbers
65 | go run main.go
66 | ```
67 |
68 | Or using Makefile:
69 | ```bash
70 | make run-add-pbs
71 | ```
72 |
73 | ## Expected Output
74 | Illustrative output (timings vary by machine; the current `main.go` runs only the first case, 42 + 137, and prints its own banner and summary format):
75 | ```
76 | ╔════════════════════════════════════════════════════════════════╗
77 | ║ Fast 8-bit Addition Using Programmable Bootstrapping ║
78 | ║ Nibble-Based Method (3 Bootstraps) ║
79 | ╚════════════════════════════════════════════════════════════════╝
80 |
81 | Security Level: Uint5 parameters (5-bit messages, messageModulus=32, N=2048)
82 |
83 | ⏱️ Generating keys...
84 | Key generation completed in 5.2s
85 |
86 | 📋 Generating lookup tables...
87 | LUT generation completed in 15µs
88 |
89 | Test Case 1: 42 + 137 = 179
90 | ─────────────────────────────────────────────────────────
91 | a: 42 = 0010_1010 (nibbles: 2, 10)
92 | b: 137 = 1000_1001 (nibbles: 8, 9)
93 |
94 | Encryption: 85µs (4 nibbles)
95 | Bootstrap 1 (low sum): 58ms
96 | Bootstrap 2 (low carry): 57ms
97 | Bootstrap 3 (high sum): 59ms
98 | Decryption: 4µs
99 |
100 | Result: 179 = 1011_0011 (nibbles: 11, 3)
101 | Total time: 174ms (3 bootstraps)
102 | ✅ CORRECT!
103 |
104 | Test Case 2: 0 + 0 = 0
105 | ─────────────────────────────────────────────────────────
106 | Result: 0
107 | ✅ CORRECT!
108 |
109 | Test Case 3: 255 + 1 = 0
110 | ─────────────────────────────────────────────────────────
111 | Result: 0
112 | ✅ CORRECT! (overflow handled correctly)
113 |
114 | Test Case 4: 128 + 127 = 255
115 | ─────────────────────────────────────────────────────────
116 | Result: 255
117 | ✅ CORRECT!
118 |
119 | Test Case 5: 15 + 15 = 30
120 | ─────────────────────────────────────────────────────────
121 | Result: 30
122 | ✅ CORRECT!
123 |
124 | ═══════════════════════════════════════════════════════════════
125 | PERFORMANCE COMPARISON
126 | ═══════════════════════════════════════════════════════════════
127 |
128 | Traditional Bit-by-Bit (boolean-gate ripple-carry adder):
129 |   • Method: 40 boolean gates (XOR, AND, OR)
130 |   • Bootstraps: ~40 (1 per gate)
131 |   • Time: ~1.1 seconds
132 |
133 | PBS Nibble-Based (this example):
134 |   • Method: 3 programmable bootstraps
135 |   • Bootstraps: 3 (processes 4 bits at once)
136 |   • Time: ~230ms
137 |
138 | 🚀 Speedup: ~4.8x faster with PBS!
139 |
140 | 💡 KEY INSIGHT:
141 |   PBS processes multiple bits simultaneously using lookup tables,
142 |   dramatically reducing the number of operations needed.
143 |
144 |   Traditional: 1 bit per operation (40 ops for 8-bit)
145 |   PBS Method: 4 bits per operation (3 ops for 8-bit)
146 |
147 | ✨ This is the power of programmable bootstrapping!
148 | ```
149 |
150 | ## How It Works
151 |
152 | ### Nibble Decomposition
153 |
154 | An 8-bit number is split into two 4-bit nibbles:
155 | ```
156 | Value: 179 = 10110011
157 |              ↓
158 | High: 1011 (11)
159 | Low: 0011 (3)
160 | ```
161 |
162 | ### Why messageModulus=32?
163 |
164 | - Each nibble is 4 bits: values 0-15
165 | - Addition can produce 0-31: (15 + 15 = 30 for the low nibbles; 15 + 15 + carry 1 = 31 for the high nibbles)
166 | - Need messageModulus ≥ 32
167 | - Uint5 provides messageModulus=32 ✅
168 |
169 | ### Lookup Table Functions
170 |
171 | **Sum Modulo 16**: `f(x) = x % 16`
172 | - Input: 0-31 (sum of two nibbles, plus the carry for the high nibble)
173 | - Output: 0-15 (result mod 16)
174 |
175 | **Carry Detection**: `f(x) = x >= 16 ? 1 : 0`
176 | - Input: 0-30
177 | - Output: 0 or 1 (carry bit)
178 |
179 | ## Key Advantages
180 |
181 | 1. **~13x Fewer Operations** - 3 vs 40 operations
182 | 2. **4.8x Faster** - ~230ms vs ~1.1s
183 | 3. **Scalable** - Same technique works for 16-bit, 32-bit, etc.
184 | 4. **Flexible** - Any single-input function of the encrypted value can be evaluated via a lookup table
185 |
186 | ## Extending to Larger Integers
187 |
188 | ### 16-bit Addition
189 | ```go
190 | // Split into 4 nibbles
191 | // Need ~6-8 bootstraps
192 | // Still much faster than 80-gate traditional method
193 | ```
194 |
195 | ### 32-bit Addition
196 | ```go
197 | // Split into 8 nibbles
198 | // Need ~14-16 bootstraps
199 | // vs ~160 gates traditionally!
200 | ```
201 |
202 | ## Technical Details
203 |
204 | **Homomorphic Addition (No Bootstrap):**
205 | ```go
206 | // Adding ciphertexts is just adding their components
207 | for i := 0; i < n+1; i++ {
208 |     ctSum.P[i] = ctA.P[i] + ctB.P[i]
209 | }
210 | ```
211 |
212 | **Programmable Bootstrap:**
213 | ```go
214 | // Refresh noise AND apply function
215 | result := eval.BootstrapLUT(ct, lut,
216 |     cloudKey.BootstrappingKey,
217 |     cloudKey.KeySwitchingKey,
218 |     cloudKey.DecompositionOffset)
219 | ```
220 |
221 | ## Comparison with Reference
222 |
223 | This implementation is based on the nibble-based algorithm in:
224 | - `tfhe-go/examples/adder_8bit_fast.go`
225 |
226 | This example completes the 8-bit addition with 3 bootstraps using Uint5 parameters.
227 |
228 | ## Next Steps
229 |
230 | Try modifying this example to:
231 | - Add 16-bit numbers (use more nibbles)
232 | - Implement subtraction
233 | - Create multiplication using repeated addition
234 | - Build a simple calculator
235 |
236 | The PBS framework makes all of this possible with excellent performance!
237 |
--------------------------------------------------------------------------------
/examples/add_two_numbers/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"fmt"
5 | 	"time"
6 |
7 | 	"github.com/thedonutfactory/go-tfhe/cloudkey"
8 | 	"github.com/thedonutfactory/go-tfhe/evaluator"
9 | 	"github.com/thedonutfactory/go-tfhe/key"
10 | 	"github.com/thedonutfactory/go-tfhe/lut"
11 | 	"github.com/thedonutfactory/go-tfhe/params"
12 | 	"github.com/thedonutfactory/go-tfhe/tlwe"
13 | )
14 |
15 | func main() {
16 | 	fmt.Println("╔════════════════════════════════════════════════════════════════╗")
17 | 	fmt.Println("║ Fast 8-bit Addition Using Programmable Bootstrapping ║")
18 | 	fmt.Println("╚════════════════════════════════════════════════════════════════╝")
19 | 	fmt.Println()
20 |
21 | 	// Use Uint5 parameters for messageModulus=32
22 | 	params.CurrentSecurityLevel = params.SecurityUint5
23 | 	fmt.Printf("Security Level: %s\n", params.SecurityInfo())
24 | 	fmt.Println()
25 |
26 | 	// Generate keys
27 | 	fmt.Println("⏱️ Generating keys...")
28 | 	keyStart := time.Now()
29 | 	secretKey := key.NewSecretKey()
30 | 	cloudKey := cloudkey.NewCloudKey(secretKey)
31 | 	eval := evaluator.NewEvaluator(params.GetTRGSWLv1().N)
32 | 	keyDuration := time.Since(keyStart)
33 | 	fmt.Printf(" Key generation completed in %v\n", keyDuration)
34 | 	fmt.Println()
35 |
36 | 	// Inputs
37 | 	a := uint8(42)
38 | 	b := uint8(137)
39 | 	expected := a + b // plaintext reference; uint8 addition wraps mod 256 like the 8-bit adder
40 |
41 | 	fmt.Printf("Computing: %d + %d = %d (encrypted)\n", a, b, expected)
42 | 	fmt.Println()
43 |
44 | 	// Step 1: Split into nibbles (4-bit chunks)
45 | 	aLow := int(a & 0x0F)         // Low nibble of a (bits 0-3)
46 | 	aHigh := int((a >> 4) & 0x0F) // High nibble of a (bits 4-7)
47 | 	bLow := int(b & 0x0F)         // Low nibble of b
48 | 	bHigh := int((b >> 4) & 0x0F) // High nibble of b
49 |
50 | 	fmt.Printf("Input A: %3d = 0b%04b_%04b (nibbles: high=%d, low=%d)\n", a, aHigh, aLow, aHigh, aLow)
51 | 	fmt.Printf("Input B: %3d = 0b%04b_%04b (nibbles: high=%d, low=%d)\n", b, bHigh, bLow, bHigh, bLow)
52 | 	fmt.Println()
53 |
54 | 	// Step 2: Generate lookup tables
55 | 	fmt.Println("📋 Generating lookup tables...")
56 | 	lutStart := time.Now()
57 | 	gen := lut.NewGenerator(32)
58 |
59 | 	lutSumLow := gen.GenLookUpTable(func(x int) int {
60 | 		return x % 16 // Extract lower 4 bits
61 | 	})
62 |
63 | 	lutCarryLow := gen.GenLookUpTable(func(x int) int {
64 | 		if x >= 16 {
65 | 			return 1 // Carry out
66 | 		}
67 | 		return 0
68 | 	})
69 |
70 | 	// The high-nibble sum needs the same x % 16 table, so reuse it rather
71 | 	// than generating an identical LUT a second time.
72 | 	lutSumHigh := lutSumLow
73 |
74 | 	lutDuration := time.Since(lutStart)
75 | 	fmt.Printf(" LUT generation: %v\n", lutDuration)
76 | 	fmt.Println()
77 |
78 | 	// Step 3: Encrypt nibbles
79 | 	fmt.Println("🔒 Encrypting nibbles...")
80 | 	encStart := time.Now()
81 |
82 | 	ctALow := tlwe.NewTLWELv0()
83 | 	ctALow.EncryptLWEMessage(aLow, 32, params.GetTLWELv0().ALPHA, secretKey.KeyLv0)
84 |
85 | 	ctAHigh := tlwe.NewTLWELv0()
86 | 	ctAHigh.EncryptLWEMessage(aHigh, 32, params.GetTLWELv0().ALPHA, secretKey.KeyLv0)
87 |
88 | 	ctBLow := tlwe.NewTLWELv0()
89 | 	ctBLow.EncryptLWEMessage(bLow, 32, params.GetTLWELv0().ALPHA, secretKey.KeyLv0)
90 |
91 | 	ctBHigh := tlwe.NewTLWELv0()
92 | 	ctBHigh.EncryptLWEMessage(bHigh, 32, params.GetTLWELv0().ALPHA, secretKey.KeyLv0)
93 |
94 | 	encDuration := time.Since(encStart)
95 | 	fmt.Printf(" Encrypted 4 nibbles in %v\n", encDuration)
96 | 	fmt.Println()
97 |
98 | 	// Step 4: Homomorphic addition of low nibbles (no bootstrap needed!)
99 | 	fmt.Println("➕ Computing encrypted addition...")
100 | 	addStart := time.Now()
101 |
102 | 	n := params.GetTLWELv0().N
103 | 	ctTempLow := tlwe.NewTLWELv0()
104 | 	for j := 0; j < n+1; j++ {
105 | 		ctTempLow.P[j] = ctALow.P[j] + ctBLow.P[j]
106 | 	}
107 | 	fmt.Println(" Step 1: Low nibbles added (homomorphic add, no bootstrap)")
108 |
109 | 	// Step 5: Bootstrap 1 - Extract low sum (mod 16)
110 | 	pbs1Start := time.Now()
111 | 	ctSumLow := eval.BootstrapLUT(ctTempLow, lutSumLow,
112 | 		cloudKey.BootstrappingKey, cloudKey.KeySwitchingKey, cloudKey.DecompositionOffset)
113 | 	pbs1Duration := time.Since(pbs1Start)
114 | 	fmt.Printf(" Bootstrap 1: Extract low sum (mod 16) - %v\n", pbs1Duration)
115 |
116 | 	// Step 6: Bootstrap 2 - Extract carry from low nibbles
117 | 	pbs2Start := time.Now()
118 | 	ctCarry := eval.BootstrapLUT(ctTempLow, lutCarryLow,
119 | 		cloudKey.BootstrappingKey, cloudKey.KeySwitchingKey, cloudKey.DecompositionOffset)
120 | 	pbs2Duration := time.Since(pbs2Start)
121 | 	fmt.Printf(" Bootstrap 2: Extract carry bit - %v\n", pbs2Duration)
122 |
123 | 	// Step 7: Add high nibbles + carry (homomorphic)
124 | 	ctTempHigh := tlwe.NewTLWELv0()
125 | 	for j := 0; j < n+1; j++ {
126 | 		ctTempHigh.P[j] = ctAHigh.P[j] + ctBHigh.P[j] + ctCarry.P[j]
127 | 	}
128 | 	fmt.Println(" Step 2: High nibbles + carry added (homomorphic add, no bootstrap)")
129 |
130 | 	// Step 8: Bootstrap 3 - Extract high sum (mod 16)
131 | 	pbs3Start := time.Now()
132 | 	ctSumHigh := eval.BootstrapLUT(ctTempHigh, lutSumHigh,
133 | 		cloudKey.BootstrappingKey, cloudKey.KeySwitchingKey, cloudKey.DecompositionOffset)
134 | 	pbs3Duration := time.Since(pbs3Start)
135 | 	fmt.Printf(" Bootstrap 3: Extract high sum (mod 16) - %v\n", pbs3Duration)
136 |
137 | 	addDuration := time.Since(addStart)
138 | 	fmt.Println()
139 |
140 | 	// Step 9: Decrypt results
141 | 	fmt.Println("🔓 Decrypting result...")
142 | 	decStart := time.Now()
143 |
144 | 	sumLow := ctSumLow.DecryptLWEMessage(32, secretKey.KeyLv0)
145 | 	sumHigh := ctSumHigh.DecryptLWEMessage(32, secretKey.KeyLv0)
146 |
147 | 	decDuration := time.Since(decStart)
148 | 	fmt.Printf(" Decrypted nibbles in %v\n", decDuration)
149 | 	fmt.Println()
150 |
151 | 	// Step 10: Combine nibbles into final result
152 | 	result := uint8(sumLow | (sumHigh << 4))
153 |
154 | 	fmt.Println("═══════════════════════════════════════════════════════════════")
155 | 	fmt.Println("RESULTS")
156 | 	fmt.Println("═══════════════════════════════════════════════════════════════")
157 | 	fmt.Printf("Input A: %3d = 0b%04b_%04b\n", a, aHigh, aLow)
158 | 	fmt.Printf("Input B: %3d = 0b%04b_%04b\n", b, bHigh, bLow)
159 | 	fmt.Printf("Result: %3d = 0b%04b_%04b (nibbles: high=%d, low=%d)\n",
160 | 		result, sumHigh, sumLow, sumHigh, sumLow)
161 | 	fmt.Printf("Expected: %3d\n", expected)
162 | 	fmt.Println()
163 |
164 | 	if result == expected {
165 | 		fmt.Println("✅ SUCCESS! Result is correct!")
166 | 	} else {
167 | 		fmt.Printf("❌ FAILURE! Expected %d, got %d\n", expected, result)
168 | 	}
169 |
170 | 	fmt.Println()
171 | 	fmt.Println("═══════════════════════════════════════════════════════════════")
172 | 	fmt.Println("PERFORMANCE SUMMARY")
173 | 	fmt.Println("═══════════════════════════════════════════════════════════════")
174 | 	fmt.Printf("Key Generation: %v\n", keyDuration)
175 | 	fmt.Printf("LUT Generation: %v\n", lutDuration)
176 | 	fmt.Printf("Encryption: %v (4 nibbles)\n", encDuration)
177 | 	fmt.Printf("Addition: %v (3 bootstraps)\n", addDuration)
178 | 	fmt.Printf(" - Bootstrap 1: %v (low sum)\n", pbs1Duration)
179 | 	fmt.Printf(" - Bootstrap 2: %v (carry)\n", pbs2Duration)
180 | 	fmt.Printf(" - Bootstrap 3: %v (high sum)\n", pbs3Duration)
181 | 	fmt.Printf("Decryption: %v\n", decDuration)
182 | 	fmt.Println()
183 |
184 | }
185 |
--------------------------------------------------------------------------------
/examples/programmable_bootstrap/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"fmt"
5 |
"time"
6 |
7 | 	"github.com/thedonutfactory/go-tfhe/cloudkey"
8 | 	"github.com/thedonutfactory/go-tfhe/evaluator"
9 | 	"github.com/thedonutfactory/go-tfhe/key"
10 | 	"github.com/thedonutfactory/go-tfhe/lut"
11 | 	"github.com/thedonutfactory/go-tfhe/params"
12 | 	"github.com/thedonutfactory/go-tfhe/tlwe"
13 | )
14 |
15 | func main() {
16 | 	fmt.Println("=== Programmable Bootstrapping Demo ===")
17 | 	fmt.Println()
18 |
19 | 	// Use 80-bit security for faster demo
20 | 	params.CurrentSecurityLevel = params.Security80Bit
21 | 	fmt.Printf("Security Level: %s\n", params.SecurityInfo())
22 | 	fmt.Println()
23 |
24 | 	// Generate keys
25 | 	fmt.Println("Generating keys...")
26 | 	startKey := time.Now()
27 | 	secretKey := key.NewSecretKey()
28 | 	cloudKey := cloudkey.NewCloudKey(secretKey)
29 | 	fmt.Printf("Key generation took: %v\n", time.Since(startKey))
30 | 	fmt.Println()
31 |
32 | 	// Create evaluator
33 | 	eval := evaluator.NewEvaluator(params.GetTRGSWLv1().N)
34 |
35 | 	// Example 1: Identity function
36 | 	fmt.Println("Example 1: Identity Function (f(x) = x)")
37 | 	fmt.Println("This refreshes noise while preserving the value")
38 | 	identity := func(x int) int { return x }
39 | 	demoFunction(eval, secretKey, cloudKey, identity, "identity", 0, 1)
40 |
41 | 	// Example 2: NOT function
42 | 	fmt.Println("\nExample 2: NOT Function (f(x) = 1 - x)")
43 | 	fmt.Println("This flips the bit during bootstrapping")
44 | 	notFunc := func(x int) int { return 1 - x }
45 | 	demoFunction(eval, secretKey, cloudKey, notFunc, "NOT", 0, 1)
46 |
47 | 	// Example 3: Constant function
48 | 	fmt.Println("\nExample 3: Constant Function (f(x) = 1)")
49 | 	fmt.Println("This always returns 1, regardless of input")
50 | 	constantOne := func(x int) int { return 1 }
51 | 	demoFunction(eval, secretKey, cloudKey, constantOne, "constant(1)", 0, 1)
52 |
53 | 	// Example 4: Constant function f(x) = 0 (the input is ignored)
54 | 	fmt.Println("\nExample 4: Constant Function (f(x) = 0)")
55 | 	fmt.Println("This always returns 0")
56 | 	constantZero := func(x int) int { return 0 }
57 | 	demoFunction(eval, secretKey, cloudKey, constantZero, "constant(0)", 0, 1)
58 |
59 | 	// Example 5: LUT reuse demonstration
60 | 	fmt.Println("\nExample 5: Lookup Table Reuse")
61 | 	fmt.Println("Pre-compute LUT once, use multiple times for efficiency")
62 | 	demoLUTReuse(eval, secretKey, cloudKey)
63 |
64 | 	// Example 6: Multi-bit messages (4 values)
65 | 	fmt.Println("\nExample 6: Multi-bit Messages (2-bit values)")
66 | 	demoMultiBit(eval, secretKey, cloudKey)
67 |
68 | 	fmt.Println("\n=== Demo Complete ===")
69 | 	fmt.Println("\nNote: Programmable bootstrapping uses general LWE message encoding")
70 | 	fmt.Println("(message * scale), not binary boolean encoding (±1/8).")
71 | 	fmt.Println("Use EncryptLWEMessage() for encryption and DecryptLWEMessage() for decryption.")
72 | }
73 | // demoFunction encrypts each input as a modulus-2 LWE message, applies f with BootstrapFunc, and prints the decrypted output and the bootstrap time.
74 | func demoFunction(eval *evaluator.Evaluator, secretKey *key.SecretKey, cloudKey *cloudkey.CloudKey,
75 | 	f func(int) int, name string, inputs ...int) {
76 |
77 | 	for i, input := range inputs {
78 | 		// Encrypt input using LWE message encoding
79 | 		ct := tlwe.NewTLWELv0()
80 | 		ct.EncryptLWEMessage(input, 2, params.GetTLWELv0().ALPHA, secretKey.KeyLv0)
81 |
82 | 		// Apply programmable bootstrap
83 | 		start := time.Now()
84 | 		result := eval.BootstrapFunc(
85 | 			ct,
86 | 			f,
87 | 			2, // binary (message modulus = 2)
88 | 			cloudKey.BootstrappingKey,
89 | 			cloudKey.KeySwitchingKey,
90 | 			cloudKey.DecompositionOffset,
91 | 		)
92 | 		elapsed := time.Since(start)
93 |
94 | 		// Decrypt using LWE message decoding
95 | 		output := result.DecryptLWEMessage(2, secretKey.KeyLv0)
96 |
97 | 		fmt.Printf(" Input %d: %d → %s(%d) = %d (took %v)\n",
98 | 			i+1, input, name, input, output, elapsed)
99 | 	}
100 | }
101 | // demoLUTReuse generates one NOT lookup table and applies it with BootstrapLUT to several fresh ciphertexts, reporting the average bootstrap time.
102 | func demoLUTReuse(eval *evaluator.Evaluator, secretKey *key.SecretKey, cloudKey *cloudkey.CloudKey) {
103 | 	// Pre-compute lookup table for NOT function
104 | 	gen := lut.NewGenerator(2)
105 | 	notFunc := func(x int) int { return 1 - x }
106 |
107 | 	fmt.Println(" Pre-computing NOT lookup table...")
108 | 	start := time.Now()
109 | 	lookupTable := gen.GenLookUpTable(notFunc)
110 | 	lutTime := time.Since(start)
111 | 	fmt.Printf(" LUT generation took: %v\n", lutTime)
112 |
113 | 	// Apply to multiple inputs using the same LUT
114 | 	inputs := []int{0, 1, 0, 1, 0}
115 |
116 | 	var totalBootstrapTime time.Duration
117 | 	for i, input := range inputs {
118 | 		ct := tlwe.NewTLWELv0()
119 | 		ct.EncryptLWEMessage(input, 2, params.GetTLWELv0().ALPHA, secretKey.KeyLv0)
120 |
121 | 		start := time.Now()
122 | 		result := eval.BootstrapLUT(
123 | 			ct,
124 | 			lookupTable,
125 | 			cloudKey.BootstrappingKey,
126 | 			cloudKey.KeySwitchingKey,
127 | 			cloudKey.DecompositionOffset,
128 | 		)
129 | 		elapsed := time.Since(start)
130 | 		totalBootstrapTime += elapsed
131 |
132 | 		output := result.DecryptLWEMessage(2, secretKey.KeyLv0)
133 | 		fmt.Printf(" Input %d: %d → NOT(%d) = %d (took %v)\n",
134 | 			i+1, input, input, output, elapsed)
135 | 	}
136 |
137 | 	avgTime := totalBootstrapTime / time.Duration(len(inputs))
138 | 	fmt.Printf(" Average bootstrap time: %v\n", avgTime)
139 | 	fmt.Println(" ✓ LUT reuse avoids recomputing the lookup table!")
140 | }
141 | // demoMultiBit bootstraps f(x) = (x + 1) mod 4 with messageModulus 4 and marks each decrypted result ✓/✗ against the plaintext function.
142 | func demoMultiBit(eval *evaluator.Evaluator, secretKey *key.SecretKey, cloudKey *cloudkey.CloudKey) {
143 | 	// Use 2-bit messages (values 0, 1, 2, 3)
144 | 	messageModulus := 4
145 |
146 | 	// Function that increments by 1 (mod 4)
147 | 	increment := func(x int) int { return (x + 1) % 4 }
148 |
149 | 	fmt.Println(" Testing increment function: f(x) = (x + 1) mod 4")
150 |
151 | 	// Test a few values
152 | 	testInputs := []int{0, 1, 2, 3}
153 |
154 | 	for _, input := range testInputs {
155 | 		ct := tlwe.NewTLWELv0()
156 | 		ct.EncryptLWEMessage(input, messageModulus, params.GetTLWELv0().ALPHA, secretKey.KeyLv0)
157 |
158 | 		start := time.Now()
159 | 		result := eval.BootstrapFunc(
160 | 			ct,
161 | 			increment,
162 | 			messageModulus,
163 | 			cloudKey.BootstrappingKey,
164 | 			cloudKey.KeySwitchingKey,
165 | 			cloudKey.DecompositionOffset,
166 | 		)
167 | 		elapsed := time.Since(start)
168 |
169 | 		output := result.DecryptLWEMessage(messageModulus, secretKey.KeyLv0)
170 | 		expected := increment(input)
171 |
172 | 		status := "✓"
173 | 		if output != expected {
174 | 			status = "✗"
175 | 		}
176 |
177 | 		fmt.Printf(" increment(%d) = %d (expected %d) %s (took %v)\n",
178 | 			input, output, expected, status, elapsed)
179 | 	}
180 |
181 | 	fmt.Println(" ✓ Framework supports arbitrary message moduli!")
182 | }
183 |
--------------------------------------------------------------------------------
/examples/proxy_reencryption/main.go:
--------------------------------------------------------------------------------
1 | // Proxy Reencryption Example
2 | //
3 | // This example demonstrates how to use LWE proxy reencryption to securely
4 | // delegate access to encrypted data without decryption.
5 | //
6 | // Run with:
7 | // go run examples/proxy_reencryption/main.go
8 |
9 | package main
10 |
11 | import (
12 | 	"fmt"
13 | 	"time"
14 |
15 | 	"github.com/thedonutfactory/go-tfhe/key"
16 | 	"github.com/thedonutfactory/go-tfhe/params"
17 | 	"github.com/thedonutfactory/go-tfhe/proxyreenc"
18 | 	"github.com/thedonutfactory/go-tfhe/tlwe"
19 | )
20 |
21 | func main() {
22 | 	fmt.Println("=== LWE Proxy Reencryption Demo ===")
23 | 	fmt.Println()
24 |
25 | 	// Scenario: Alice wants to share encrypted data with Bob
26 | 	// without decrypting it, using a semi-trusted proxy
27 |
28 | 	fmt.Println("1.
Setting up keys for Alice and Bob...") 29 | aliceKey := key.NewSecretKey() 30 | bobKey := key.NewSecretKey() 31 | fmt.Println(" ✓ Alice's secret key generated") 32 | 33 | // Bob publishes his public key 34 | start := time.Now() 35 | bobPublicKey := proxyreenc.NewPublicKeyLv0(bobKey.KeyLv0) 36 | pubkeyTime := time.Since(start) 37 | fmt.Printf(" ✓ Bob's public key generated in %.2fms\n", float64(pubkeyTime.Microseconds())/1000.0) 38 | fmt.Println(" ✓ Bob shares his public key (safe to publish)") 39 | fmt.Println() 40 | 41 | // Alice encrypts some data 42 | fmt.Println("2. Alice encrypts her data...") 43 | messages := []bool{true, false, true, true, false} 44 | aliceCiphertexts := make([]*tlwe.TLWELv0, len(messages)) 45 | 46 | for i, msg := range messages { 47 | ct := tlwe.NewTLWELv0() 48 | ct.EncryptBool(msg, params.GetTLWELv0().ALPHA, aliceKey.KeyLv0) 49 | aliceCiphertexts[i] = ct 50 | } 51 | 52 | fmt.Println(" Messages encrypted by Alice:") 53 | for i, msg := range messages { 54 | fmt.Printf(" - Message %d: %v\n", i+1, msg) 55 | } 56 | fmt.Println() 57 | 58 | // Alice generates a proxy reencryption key using Bob's PUBLIC key 59 | fmt.Println("3. Alice generates a proxy reencryption key (Alice -> Bob)...") 60 | fmt.Println(" Using ASYMMETRIC mode - Bob's secret key is NOT needed!") 61 | start = time.Now() 62 | reencKey := proxyreenc.NewProxyReencryptionKeyAsymmetric(aliceKey.KeyLv0, bobPublicKey) 63 | keygenTime := time.Since(start) 64 | fmt.Printf(" ✓ Reencryption key generated in %.2fms\n", float64(keygenTime.Microseconds())/1000.0) 65 | fmt.Println(" ✓ Alice shares this key with the proxy") 66 | fmt.Println() 67 | 68 | // Proxy reencrypts the data (without learning the plaintext) 69 | fmt.Println("4. 
Proxy converts Alice's ciphertexts to Bob's ciphertexts...") 70 | start = time.Now() 71 | bobCiphertexts := make([]*tlwe.TLWELv0, len(aliceCiphertexts)) 72 | for i, ct := range aliceCiphertexts { 73 | bobCiphertexts[i] = proxyreenc.ReencryptTLWELv0(ct, reencKey) 74 | } 75 | reencTime := time.Since(start) 76 | fmt.Printf(" ✓ %d ciphertexts reencrypted in %.2fms\n", len(bobCiphertexts), float64(reencTime.Microseconds())/1000.0) 77 | fmt.Printf(" ✓ Average time per reencryption: %.2fms\n\n", float64(reencTime.Microseconds())/float64(len(bobCiphertexts))/1000.0) 78 | 79 | // Bob decrypts the reencrypted data 80 | fmt.Println("5. Bob decrypts the reencrypted data...") 81 | correct := 0 82 | decryptedMessages := make([]bool, len(bobCiphertexts)) 83 | 84 | for i, ct := range bobCiphertexts { 85 | decryptedMessages[i] = ct.DecryptBool(bobKey.KeyLv0) 86 | } 87 | 88 | fmt.Println(" Decrypted messages:") 89 | for i, original := range messages { 90 | decrypted := decryptedMessages[i] 91 | status := "✗" 92 | if original == decrypted { 93 | correct++ 94 | status = "✓" 95 | } 96 | fmt.Printf(" %s Message %d: %v (original: %v)\n", status, i+1, decrypted, original) 97 | } 98 | fmt.Println() 99 | 100 | fmt.Println("=== Results ===") 101 | accuracy := float64(correct) / float64(len(messages)) * 100.0 102 | fmt.Printf("Accuracy: %d/%d (%.1f%%)\n", correct, len(messages), accuracy) 103 | fmt.Println() 104 | 105 | // Demonstrate multi-hop reencryption: Alice -> Bob -> Carol 106 | fmt.Println() 107 | fmt.Println("=== Multi-Hop Reencryption Demo (Asymmetric) ===") 108 | fmt.Println() 109 | fmt.Println("Demonstrating a chain: Alice -> Bob -> Carol") 110 | fmt.Println("Each party only needs the next party's PUBLIC key") 111 | fmt.Println() 112 | 113 | carolKey := key.NewSecretKey() 114 | carolPublicKey := proxyreenc.NewPublicKeyLv0(carolKey.KeyLv0) 115 | fmt.Println("1. 
Carol's keys generated and public key published") 116 | 117 | reencKeyBC := proxyreenc.NewProxyReencryptionKeyAsymmetric(bobKey.KeyLv0, carolPublicKey) 118 | fmt.Println("2. Generated reencryption key (Bob -> Carol) using Carol's PUBLIC key") 119 | 120 | testMessage := true 121 | aliceCt := tlwe.NewTLWELv0() 122 | aliceCt.EncryptBool(testMessage, params.GetTLWELv0().ALPHA, aliceKey.KeyLv0) 123 | fmt.Printf("3. Alice encrypts message: %v\n", testMessage) 124 | 125 | bobCt := proxyreenc.ReencryptTLWELv0(aliceCt, reencKey) 126 | fmt.Println("4. Proxy reencrypts Alice -> Bob") 127 | bobDecrypted := bobCt.DecryptBool(bobKey.KeyLv0) 128 | bobStatus := "✗" 129 | if bobDecrypted == testMessage { 130 | bobStatus = "✓" 131 | } 132 | fmt.Printf(" Bob decrypts: %v %s\n", bobDecrypted, bobStatus) 133 | 134 | carolCt := proxyreenc.ReencryptTLWELv0(bobCt, reencKeyBC) 135 | fmt.Println("5. Proxy reencrypts Bob -> Carol") 136 | carolDecrypted := carolCt.DecryptBool(carolKey.KeyLv0) 137 | carolStatus := "✗" 138 | if carolDecrypted == testMessage { 139 | carolStatus = "✓" 140 | } 141 | fmt.Printf(" Carol decrypts: %v %s\n", carolDecrypted, carolStatus) 142 | 143 | fmt.Println() 144 | fmt.Println("=== Security Notes ===") 145 | fmt.Println("• The proxy never learns the plaintext") 146 | fmt.Println("• Bob's secret key is NEVER shared - only his public key is used") 147 | fmt.Println("• The reencryption key only works in one direction") 148 | fmt.Println("• Each reencryption adds a small amount of noise") 149 | fmt.Println("• The scheme is unidirectional (Alice->Bob key ≠ Bob->Alice key)") 150 | fmt.Println("• True asymmetric proxy reencryption with LWE-based public keys") 151 | 152 | fmt.Println("\n=== Performance Summary ===") 153 | fmt.Printf("Bob's public key generation: %.2fms\n", float64(pubkeyTime.Microseconds())/1000.0) 154 | fmt.Printf("Reencryption key generation: %.2fms\n", float64(keygenTime.Microseconds())/1000.0) 155 | fmt.Printf("Average reencryption time: %.2fms\n", 
float64(reencTime.Microseconds())/float64(len(bobCiphertexts))/1000.0) 156 | } 157 | 158 | -------------------------------------------------------------------------------- /examples/simple_gates/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/thedonutfactory/go-tfhe/cloudkey" 8 | "github.com/thedonutfactory/go-tfhe/gates" 9 | "github.com/thedonutfactory/go-tfhe/key" 10 | "github.com/thedonutfactory/go-tfhe/params" 11 | "github.com/thedonutfactory/go-tfhe/tlwe" 12 | ) 13 | 14 | func encrypt(x bool, secretKey *key.SecretKey) *gates.Ciphertext { 15 | return tlwe.NewTLWELv0().EncryptBool(x, params.GetTLWELv0().ALPHA, secretKey.KeyLv0) 16 | } 17 | 18 | func decrypt(x *gates.Ciphertext, secretKey *key.SecretKey) bool { 19 | return x.DecryptBool(secretKey.KeyLv0) 20 | } 21 | 22 | func main() { 23 | fmt.Println("╔══════════════════════════════════════════════════════════════╗") 24 | fmt.Println("║ Go-TFHE: Homomorphic Gates Example ║") 25 | fmt.Println("╚══════════════════════════════════════════════════════════════╝") 26 | fmt.Println() 27 | 28 | secretKey := key.NewSecretKey() 29 | ck := cloudkey.NewCloudKey(secretKey) 30 | 31 | // Test inputs 32 | testCases := []struct { 33 | a, b bool 34 | }{ 35 | {false, false}, 36 | {false, true}, 37 | {true, false}, 38 | {true, true}, 39 | } 40 | 41 | for _, tc := range testCases { 42 | fmt.Printf("Testing inputs: A=%v, B=%v\n", tc.a, tc.b) 43 | fmt.Println(strings.Repeat("-", 60)) 44 | 45 | // Encrypt inputs 46 | ctA := encrypt(tc.a, secretKey) 47 | ctB := encrypt(tc.b, secretKey) 48 | 49 | // Test each gate 50 | testGate := func(name string, gateFunc func(*gates.Ciphertext, *gates.Ciphertext, *cloudkey.CloudKey) *gates.Ciphertext, expected bool) { 51 | result := gateFunc(ctA, ctB, ck) 52 | decrypted := decrypt(result, secretKey) 53 | status := "✅" 54 | if decrypted != expected { 55 | status = "❌" 56 | } 57 | 
fmt.Printf(" %s %s: %v (expected %v)\n", status, name, decrypted, expected) 58 | } 59 | 60 | testGate("AND ", gates.AND, tc.a && tc.b) 61 | testGate("OR ", gates.OR, tc.a || tc.b) 62 | testGate("NAND", gates.NAND, !(tc.a && tc.b)) 63 | testGate("NOR ", gates.NOR, !(tc.a || tc.b)) 64 | testGate("XOR ", gates.XOR, tc.a != tc.b) 65 | testGate("XNOR", gates.XNOR, tc.a == tc.b) 66 | 67 | fmt.Println() 68 | } 69 | 70 | // Test NOT gate 71 | fmt.Println("Testing NOT gate:") 72 | fmt.Println(strings.Repeat("-", 60)) 73 | for _, val := range []bool{false, true} { 74 | ct := encrypt(val, secretKey) 75 | notCt := gates.NOT(ct) 76 | decrypted := decrypt(notCt, secretKey) 77 | expected := !val 78 | status := "✅" 79 | if decrypted != expected { 80 | status = "❌" 81 | } 82 | fmt.Printf(" %s NOT(%v) = %v (expected %v)\n", status, val, decrypted, expected) 83 | } 84 | 85 | fmt.Println() 86 | fmt.Println("✅ All gate tests complete!") 87 | } 88 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/thedonutfactory/go-tfhe 2 | 3 | go 1.21 4 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thedonutfactory/go-tfhe/9c41b7a9be525d7dd51ddfd62093e4fa18293709/go.sum -------------------------------------------------------------------------------- /key/key.go: -------------------------------------------------------------------------------- 1 | package key 2 | 3 | import ( 4 | "math/rand" 5 | 6 | "github.com/thedonutfactory/go-tfhe/params" 7 | ) 8 | 9 | // SecretKey contains the secret keys for both levels 10 | type SecretKey struct { 11 | KeyLv0 []params.Torus 12 | KeyLv1 []params.Torus 13 | } 14 | 15 | // NewSecretKey generates a new secret key 16 | func NewSecretKey() *SecretKey { 17 | rng := 
rand.New(rand.NewSource(rand.Int63())) 18 | 19 | lv0N := params.GetTLWELv0().N 20 | lv1N := params.GetTLWELv1().N 21 | 22 | keyLv0 := make([]params.Torus, lv0N) 23 | keyLv1 := make([]params.Torus, lv1N) 24 | 25 | for i := 0; i < lv0N; i++ { 26 | if rng.Intn(2) == 1 { 27 | keyLv0[i] = 1 28 | } else { 29 | keyLv0[i] = 0 30 | } 31 | } 32 | 33 | for i := 0; i < lv1N; i++ { 34 | if rng.Intn(2) == 1 { 35 | keyLv1[i] = 1 36 | } else { 37 | keyLv1[i] = 0 38 | } 39 | } 40 | 41 | return &SecretKey{ 42 | KeyLv0: keyLv0, 43 | KeyLv1: keyLv1, 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /lut/analysis_test.go: -------------------------------------------------------------------------------- 1 | package lut 2 | 3 | import ( 4 | "math" 5 | "testing" 6 | 7 | "github.com/thedonutfactory/go-tfhe/utils" 8 | ) 9 | 10 | // TestAnalyzeLUTLayout analyzes the LUT layout for different functions 11 | func TestAnalyzeLUTLayout(t *testing.T) { 12 | gen := NewGenerator(2) 13 | n := gen.PolyDegree 14 | 15 | t.Log("=== Analyzing LUT Layouts ===\n") 16 | 17 | // Analyze what positions correspond to which inputs 18 | t.Log("Step 1: Understanding input encoding and ModSwitch mapping") 19 | 20 | // For binary TFHE: 21 | // - Input 0 (false) encodes to -1/8 = 7/8 = 0.875 22 | // - Input 1 (true) encodes to 1/8 = 0.125 23 | 24 | falseEncoded := utils.F64ToTorus(-0.125) // = 0.875 in unsigned 25 | trueEncoded := utils.F64ToTorus(0.125) 26 | 27 | t.Logf("Encoded values:") 28 | t.Logf(" false: %d (%.6f)", falseEncoded, utils.TorusToF64(falseEncoded)) 29 | t.Logf(" true: %d (%.6f)", trueEncoded, utils.TorusToF64(trueEncoded)) 30 | 31 | // What do these map to via ModSwitch? 
32 | falseModSwitch := gen.ModSwitch(falseEncoded) 33 | trueModSwitch := gen.ModSwitch(trueEncoded) 34 | 35 | t.Logf("\nModSwitch values (out of [0, %d)):", 2*n) 36 | t.Logf(" ModSwitch(false) = %d", falseModSwitch) 37 | t.Logf(" ModSwitch(true) = %d", trueModSwitch) 38 | 39 | t.Logf("\nAfter blind rotation by -ModSwitch, coefficient 0 comes from:") 40 | t.Logf(" For false input: LUT[%d %% %d] = LUT[%d]", falseModSwitch, n, falseModSwitch%n) 41 | t.Logf(" For true input: LUT[%d %% %d] = LUT[%d]", trueModSwitch, n, trueModSwitch%n) 42 | 43 | // Analyze different functions 44 | functions := []struct { 45 | name string 46 | f func(int) int 47 | }{ 48 | {"identity", func(x int) int { return x }}, 49 | {"NOT", func(x int) int { return 1 - x }}, 50 | {"constant_0", func(x int) int { return 0 }}, 51 | {"constant_1", func(x int) int { return 1 }}, 52 | } 53 | 54 | for _, fn := range functions { 55 | t.Logf("\n--- Function: %s ---", fn.name) 56 | lut := gen.GenLookUpTable(fn.f) 57 | 58 | // Check what values are at the key positions 59 | falseLUTIdx := falseModSwitch % n 60 | trueLUTIdx := trueModSwitch % n 61 | 62 | falseLUTVal := lut.Poly.B[falseLUTIdx] 63 | trueLUTVal := lut.Poly.B[trueLUTIdx] 64 | 65 | t.Logf("LUT[%d] = %d (%.6f) - will be extracted for false input", 66 | falseLUTIdx, falseLUTVal, utils.TorusToF64(falseLUTVal)) 67 | t.Logf("LUT[%d] = %d (%.6f) - will be extracted for true input", 68 | trueLUTIdx, trueLUTVal, utils.TorusToF64(trueLUTVal)) 69 | 70 | // What should these be? 
71 | expectedForFalse := fn.f(0) 72 | expectedForTrue := fn.f(1) 73 | 74 | var expectedFalseVal, expectedTrueVal float64 75 | if expectedForFalse == 0 { 76 | expectedFalseVal = 0.875 // -1/8 77 | } else { 78 | expectedFalseVal = 0.125 // 1/8 79 | } 80 | if expectedForTrue == 0 { 81 | expectedTrueVal = 0.875 82 | } else { 83 | expectedTrueVal = 0.125 84 | } 85 | 86 | t.Logf("\nExpected:") 87 | t.Logf(" %s(false) = %d → should encode to %.6f", fn.name, expectedForFalse, expectedFalseVal) 88 | t.Logf(" %s(true) = %d → should encode to %.6f", fn.name, expectedForTrue, expectedTrueVal) 89 | 90 | actualFalseVal := utils.TorusToF64(falseLUTVal) 91 | actualTrueVal := utils.TorusToF64(trueLUTVal) 92 | 93 | falseMatch := (actualFalseVal-expectedFalseVal < 0.01) || (actualFalseVal-expectedFalseVal > 0.99) 94 | trueMatch := (actualTrueVal-expectedTrueVal < 0.01) || (actualTrueVal-expectedTrueVal > 0.99) 95 | 96 | t.Logf("\nMatches:") 97 | t.Logf(" False input: %v (actual=%.6f, expected=%.6f)", falseMatch, actualFalseVal, expectedFalseVal) 98 | t.Logf(" True input: %v (actual=%.6f, expected=%.6f)", trueMatch, actualTrueVal, expectedTrueVal) 99 | } 100 | } 101 | 102 | // TestLUTRegionMapping tests which regions of the LUT correspond to which inputs 103 | func TestLUTRegionMapping(t *testing.T) { 104 | gen := NewGenerator(2) 105 | n := gen.PolyDegree 106 | 107 | t.Log("=== LUT Region Mapping Analysis ===\n") 108 | 109 | // Create a simple test: assign different values to different regions 110 | // and see what we get for different inputs 111 | 112 | t.Log("Creating test LUT with distinct regions:") 113 | testLUT := NewLookUpTable() 114 | 115 | // Fill first quarter with value A 116 | valA := utils.F64ToTorus(0.1) 117 | for i := 0; i < n/4; i++ { 118 | testLUT.Poly.B[i] = valA 119 | testLUT.Poly.A[i] = 0 120 | } 121 | 122 | // Fill second quarter with value B 123 | valB := utils.F64ToTorus(0.3) 124 | for i := n / 4; i < n/2; i++ { 125 | testLUT.Poly.B[i] = valB 126 | 
testLUT.Poly.A[i] = 0 127 | } 128 | 129 | // Fill third quarter with value C 130 | valC := utils.F64ToTorus(0.5) 131 | for i := n / 2; i < 3*n/4; i++ { 132 | testLUT.Poly.B[i] = valC 133 | testLUT.Poly.A[i] = 0 134 | } 135 | 136 | // Fill fourth quarter with value D 137 | valD := utils.F64ToTorus(0.7) 138 | for i := 3 * n / 4; i < n; i++ { 139 | testLUT.Poly.B[i] = valD 140 | testLUT.Poly.A[i] = 0 141 | } 142 | 143 | t.Logf("Region mapping:") 144 | t.Logf(" [0, %d): value A = %.3f", n/4, 0.1) 145 | t.Logf(" [%d, %d): value B = %.3f", n/4, n/2, 0.3) 146 | t.Logf(" [%d, %d): value C = %.3f", n/2, 3*n/4, 0.5) 147 | t.Logf(" [%d, %d): value D = %.3f", 3*n/4, n, 0.7) 148 | 149 | // Now check where false and true map to 150 | falseEncoded := utils.F64ToTorus(-0.125) 151 | trueEncoded := utils.F64ToTorus(0.125) 152 | 153 | falseModSwitch := gen.ModSwitch(falseEncoded) 154 | trueModSwitch := gen.ModSwitch(trueEncoded) 155 | 156 | falseLUTIdx := falseModSwitch % n 157 | trueLUTIdx := trueModSwitch % n 158 | 159 | t.Logf("\nInput mappings:") 160 | t.Logf(" false (0.875) → ModSwitch=%d → LUT[%d]", falseModSwitch, falseLUTIdx) 161 | t.Logf(" true (0.125) → ModSwitch=%d → LUT[%d]", trueModSwitch, trueLUTIdx) 162 | 163 | t.Logf(" false maps to region: %s", getRegion(falseLUTIdx, n)) 164 | t.Logf(" true maps to region: %s", getRegion(trueLUTIdx, n)) 165 | } 166 | 167 | func getRegion(idx, n int) string { 168 | if idx < n/4 { 169 | return "A (first quarter)" 170 | } else if idx < n/2 { 171 | return "B (second quarter)" 172 | } else if idx < 3*n/4 { 173 | return "C (third quarter)" 174 | } else { 175 | return "D (fourth quarter)" 176 | } 177 | } 178 | 179 | // TestCompareWithReferenceEncoding compares our encoding with reference 180 | func TestCompareWithReferenceEncoding(t *testing.T) { 181 | t.Log("=== Comparing Encoding Schemes ===\n") 182 | 183 | gen := NewGenerator(2) 184 | n := gen.PolyDegree 185 | 186 | // Reference TFHE test vector for identity is constant 0.125 187 | // 
This means: no matter what rotation, we always get 0.125 188 | // But that can't give us different outputs for different inputs! 189 | // 190 | // The key insight: the INPUT ciphertext already encodes the value. 191 | // The test vector for GATES doesn't evaluate a function - it refreshes noise. 192 | // 193 | // For programmable bootstrap, we WANT different outputs for different inputs. 194 | 195 | t.Log("Key insight:") 196 | t.Log(" Standard bootstrap (for gates): input is PRE-PROCESSED, test vector is constant") 197 | t.Log(" Programmable bootstrap: test vector encodes the function") 198 | 199 | t.Log("\nFor NOT function:") 200 | t.Log(" We want: NOT(false=0) = true=1, NOT(true=1) = false=0") 201 | t.Log(" So LUT should have:") 202 | 203 | falseEncoded := utils.F64ToTorus(-0.125) // 0.875 204 | trueEncoded := utils.F64ToTorus(0.125) 205 | 206 | falseMS := gen.ModSwitch(falseEncoded) 207 | trueMS := gen.ModSwitch(trueEncoded) 208 | 209 | t.Logf(" Position %d (for false input): value for NOT(false)=true = 0.125", falseMS%n) 210 | t.Logf(" Position %d (for true input): value for NOT(true)=false = 0.875", trueMS%n) 211 | 212 | // Generate NOT LUT and check 213 | notFunc := func(x int) int { return 1 - x } 214 | notLUT := gen.GenLookUpTable(notFunc) 215 | 216 | actualFalsePos := notLUT.Poly.B[falseMS%n] 217 | actualTruePos := notLUT.Poly.B[trueMS%n] 218 | 219 | t.Logf("\nActual NOT LUT:") 220 | t.Logf(" Position %d: %.6f (expected 0.125 for true)", falseMS%n, utils.TorusToF64(actualFalsePos)) 221 | t.Logf(" Position %d: %.6f (expected 0.875 for false)", trueMS%n, utils.TorusToF64(actualTruePos)) 222 | 223 | // Check if they match 224 | falseOK := math.Abs(utils.TorusToF64(actualFalsePos)-0.125) < 0.01 225 | trueOK := math.Abs(utils.TorusToF64(actualTruePos)-0.875) < 0.01 226 | 227 | if !falseOK || !trueOK { 228 | t.Logf("\n⚠️ Mismatch detected!") 229 | t.Logf(" Position for false input: %v", falseOK) 230 | t.Logf(" Position for true input: %v", trueOK) 231 | } else { 
232 | t.Logf("\n✓ LUT correctly encoded!") 233 | } 234 | } 235 | -------------------------------------------------------------------------------- /lut/debug_test.go: -------------------------------------------------------------------------------- 1 | package lut 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/thedonutfactory/go-tfhe/params" 7 | "github.com/thedonutfactory/go-tfhe/utils" 8 | ) 9 | 10 | // TestEncoderDetailed provides detailed tracing of encoder behavior 11 | func TestEncoderDetailed(t *testing.T) { 12 | t.Log("=== Testing Encoder with Binary Messages ===") 13 | enc := NewEncoder(2) 14 | 15 | t.Logf("MessageModulus: %d", enc.MessageModulus) 16 | t.Logf("Scale: %f", enc.Scale) 17 | 18 | // Test encoding 0 19 | val0 := enc.Encode(0) 20 | f0 := utils.TorusToF64(val0) 21 | t.Logf("Encode(0) = %d (%.6f in [0,1))", val0, f0) 22 | 23 | // Test encoding 1 24 | val1 := enc.Encode(1) 25 | f1 := utils.TorusToF64(val1) 26 | t.Logf("Encode(1) = %d (%.6f in [0,1))", val1, f1) 27 | 28 | // Test decoding 29 | dec0 := enc.Decode(val0) 30 | dec1 := enc.Decode(val1) 31 | t.Logf("Decode(Encode(0)) = %d", dec0) 32 | t.Logf("Decode(Encode(1)) = %d", dec1) 33 | 34 | if dec0 != 0 { 35 | t.Errorf("Decode(Encode(0)) = %d, want 0", dec0) 36 | } 37 | if dec1 != 1 { 38 | t.Errorf("Decode(Encode(1)) = %d, want 1", dec1) 39 | } 40 | } 41 | 42 | // TestLUTGenerationDetailed provides detailed tracing of LUT generation 43 | func TestLUTGenerationDetailed(t *testing.T) { 44 | t.Log("=== Testing LUT Generation for Identity Function ===") 45 | 46 | gen := NewGenerator(2) 47 | t.Logf("PolyDegree: %d", gen.PolyDegree) 48 | t.Logf("LookUpTableSize: %d", gen.LookUpTableSize) 49 | t.Logf("MessageModulus: %d", gen.Encoder.MessageModulus) 50 | t.Logf("Scale: %f", gen.Encoder.Scale) 51 | 52 | identity := func(x int) int { return x } 53 | 54 | t.Log("\n--- Step 1: Generate LUT ---") 55 | lut := gen.GenLookUpTable(identity) 56 | 57 | t.Log("\n--- Step 2: Examine LUT Contents ---") 58 | 
t.Log("First 20 B coefficients:") 59 | for i := 0; i < 20 && i < gen.PolyDegree; i++ { 60 | val := lut.Poly.B[i] 61 | fval := utils.TorusToF64(val) 62 | t.Logf(" B[%d] = %10d (%.6f)", i, val, fval) 63 | } 64 | 65 | t.Log("\nLast 20 B coefficients:") 66 | for i := gen.PolyDegree - 20; i < gen.PolyDegree; i++ { 67 | val := lut.Poly.B[i] 68 | fval := utils.TorusToF64(val) 69 | t.Logf(" B[%d] = %10d (%.6f)", i, val, fval) 70 | } 71 | 72 | t.Log("\n--- Step 3: Check A coefficients (should be zero) ---") 73 | nonZeroA := 0 74 | for i := 0; i < gen.PolyDegree; i++ { 75 | if lut.Poly.A[i] != 0 { 76 | nonZeroA++ 77 | } 78 | } 79 | t.Logf("Non-zero A coefficients: %d (should be 0)", nonZeroA) 80 | 81 | if nonZeroA > 0 { 82 | t.Errorf("Expected all A coefficients to be zero, found %d non-zero", nonZeroA) 83 | } 84 | } 85 | 86 | // TestLUTGenerationStepByStep traces the algorithm step by step 87 | func TestLUTGenerationStepByStep(t *testing.T) { 88 | t.Log("=== Step-by-Step LUT Generation for Identity ===") 89 | 90 | gen := NewGenerator(2) 91 | messageModulus := gen.Encoder.MessageModulus 92 | 93 | t.Logf("Parameters:") 94 | t.Logf(" MessageModulus: %d", messageModulus) 95 | t.Logf(" PolyDegree (N): %d", gen.PolyDegree) 96 | t.Logf(" LookUpTableSize (2N): %d", gen.LookUpTableSize) 97 | 98 | // Manually trace through the algorithm 99 | identity := func(x int) int { return x } 100 | 101 | t.Log("\n--- Step 1: Create raw LUT ---") 102 | lutRaw := make([]params.Torus, gen.LookUpTableSize) 103 | 104 | for x := 0; x < messageModulus; x++ { 105 | start := divRound(x*gen.LookUpTableSize, messageModulus) 106 | end := divRound((x+1)*gen.LookUpTableSize, messageModulus) 107 | y := gen.Encoder.Encode(identity(x)) 108 | 109 | t.Logf("Message %d:", x) 110 | t.Logf(" f(%d) = %d", x, identity(x)) 111 | t.Logf(" Encoded: %d (%.6f)", y, utils.TorusToF64(y)) 112 | t.Logf(" Range in LUT: [%d, %d)", start, end) 113 | 114 | for i := start; i < end; i++ { 115 | lutRaw[i] = y 116 | } 117 | } 118 | 
119 | t.Log("\n--- Step 2: Apply offset rotation ---") 120 | offset := divRound(gen.LookUpTableSize, 2*messageModulus) 121 | t.Logf("Offset: %d", offset) 122 | 123 | rotated := make([]params.Torus, gen.LookUpTableSize) 124 | for i := 0; i < gen.LookUpTableSize; i++ { 125 | srcIdx := (i + offset) % gen.LookUpTableSize 126 | rotated[i] = lutRaw[srcIdx] 127 | } 128 | 129 | t.Log("First 10 values after rotation:") 130 | for i := 0; i < 10; i++ { 131 | t.Logf(" rotated[%d] = %d (%.6f)", i, rotated[i], utils.TorusToF64(rotated[i])) 132 | } 133 | 134 | t.Log("\n--- Step 3: Apply negacyclic property ---") 135 | t.Logf("Storing first N=%d values directly", gen.PolyDegree) 136 | t.Logf("Subtracting second N values (due to X^N = -1)") 137 | 138 | result := make([]params.Torus, gen.PolyDegree) 139 | for i := 0; i < gen.PolyDegree; i++ { 140 | result[i] = rotated[i] 141 | } 142 | for i := gen.PolyDegree; i < gen.LookUpTableSize; i++ { 143 | result[i-gen.PolyDegree] -= rotated[i] 144 | } 145 | 146 | t.Log("\nFinal B coefficients (first 10):") 147 | for i := 0; i < 10; i++ { 148 | t.Logf(" B[%d] = %d (%.6f)", i, result[i], utils.TorusToF64(result[i])) 149 | } 150 | 151 | // Compare with actual generation 152 | t.Log("\n--- Comparing with actual GenLookUpTable ---") 153 | actualLUT := gen.GenLookUpTable(identity) 154 | 155 | matches := 0 156 | for i := 0; i < gen.PolyDegree; i++ { 157 | if result[i] == actualLUT.Poly.B[i] { 158 | matches++ 159 | } 160 | } 161 | 162 | t.Logf("Matching coefficients: %d / %d", matches, gen.PolyDegree) 163 | 164 | if matches != gen.PolyDegree { 165 | t.Log("\nFirst 10 differences:") 166 | count := 0 167 | for i := 0; i < gen.PolyDegree && count < 10; i++ { 168 | if result[i] != actualLUT.Poly.B[i] { 169 | t.Logf(" B[%d]: manual=%d, actual=%d", i, result[i], actualLUT.Poly.B[i]) 170 | count++ 171 | } 172 | } 173 | } 174 | } 175 | 176 | // TestModSwitchDetailed traces ModSwitch behavior 177 | func TestModSwitchDetailed(t *testing.T) { 178 | t.Log("=== 
Testing ModSwitch ===") 179 | 180 | gen := NewGenerator(2) 181 | n := gen.PolyDegree 182 | lookUpTableSize := gen.LookUpTableSize 183 | 184 | t.Logf("PolyDegree (N): %d", n) 185 | t.Logf("LookUpTableSize (2N): %d", lookUpTableSize) 186 | 187 | testCases := []struct { 188 | name string 189 | value params.Torus 190 | desc string 191 | }{ 192 | {"zero", 0, "0"}, 193 | {"quarter", params.Torus(1 << 30), "1/4 of torus"}, 194 | {"half", params.Torus(1 << 31), "1/2 of torus"}, 195 | {"three-quarter", params.Torus(3 << 30), "3/4 of torus"}, 196 | {"max", params.Torus(^uint32(0)), "max value"}, 197 | } 198 | 199 | for _, tc := range testCases { 200 | t.Run(tc.name, func(t *testing.T) { 201 | result := gen.ModSwitch(tc.value) 202 | 203 | // Calculate what it should be 204 | fVal := utils.TorusToF64(tc.value) 205 | expectedFloat := fVal * float64(lookUpTableSize) 206 | 207 | t.Logf("Input: %s (%d)", tc.desc, tc.value) 208 | t.Logf(" As float in [0,1): %.6f", fVal) 209 | t.Logf(" Scaled to [0, 2N): %.2f", expectedFloat) 210 | t.Logf(" ModSwitch result: %d", result) 211 | t.Logf(" In range [0, %d): %v", lookUpTableSize, result >= 0 && result < lookUpTableSize) 212 | 213 | if result < 0 || result >= lookUpTableSize { 214 | t.Errorf("ModSwitch result %d out of range [0, %d)", result, lookUpTableSize) 215 | } 216 | }) 217 | } 218 | } 219 | 220 | // TestCompareWithReferenceTestVector compares our LUT with what a test vector should look like 221 | func TestCompareWithReferenceTestVector(t *testing.T) { 222 | t.Log("=== Comparing LUT with Reference Test Vector ===") 223 | 224 | // A reference test vector for binary has constant 1/8 in all positions 225 | // This represents the identity function in TFHE 226 | referenceValue := utils.F64ToTorus(0.125) 227 | 228 | gen := NewGenerator(2) 229 | identity := func(x int) int { return x } 230 | lut := gen.GenLookUpTable(identity) 231 | 232 | t.Logf("Reference value (constant 1/8): %d (%.6f)", referenceValue, utils.TorusToF64(referenceValue)) 
233 | 234 | t.Log("\nComparing first 20 B coefficients:") 235 | matches := 0 236 | for i := 0; i < 20 && i < gen.PolyDegree; i++ { 237 | actual := lut.Poly.B[i] 238 | actualF := utils.TorusToF64(actual) 239 | refF := utils.TorusToF64(referenceValue) 240 | 241 | match := "" 242 | if actual == referenceValue { 243 | matches++ 244 | match = "✓" 245 | } else { 246 | match = "✗" 247 | } 248 | 249 | t.Logf(" B[%d]: actual=%.6f, reference=%.6f %s", i, actualF, refF, match) 250 | } 251 | 252 | t.Logf("\nMatches: %d / %d", matches, 20) 253 | } 254 | -------------------------------------------------------------------------------- /lut/encoder.go: -------------------------------------------------------------------------------- 1 | package lut 2 | 3 | import ( 4 | "github.com/thedonutfactory/go-tfhe/params" 5 | "github.com/thedonutfactory/go-tfhe/utils" 6 | ) 7 | 8 | // Encoder provides encoding and decoding functions for different message spaces 9 | type Encoder struct { 10 | MessageModulus int // Number of possible messages (e.g., 2 for binary, 4 for 2-bit) 11 | Scale float64 // Scaling factor for encoding 12 | } 13 | 14 | // NewEncoder creates a new encoder with the given message modulus 15 | // For binary (boolean) operations, use messageModulus=2 16 | // The default encoding uses 1/(2*messageModulus) to place messages in the torus 17 | func NewEncoder(messageModulus int) *Encoder { 18 | // For TFHE, binary messages are encoded as ±1/8 19 | // Message 0 (false) -> -1/8 = 7/8 in unsigned representation 20 | // Message 1 (true) -> +1/8 21 | // 22 | // For general case with messageModulus m, we use ±1/(2m) 23 | // This gives us 1/4 for binary (m=2) 24 | scale := 1.0 / float64(2*messageModulus) 25 | return &Encoder{ 26 | MessageModulus: messageModulus, 27 | Scale: scale, 28 | } 29 | } 30 | 31 | // NewEncoderWithScale creates a new encoder with custom message modulus and scale 32 | func NewEncoderWithScale(messageModulus int, scale float64) *Encoder { 33 | return &Encoder{ 34 | 
MessageModulus: messageModulus, 35 | Scale: scale, 36 | } 37 | } 38 | 39 | // Encode encodes an integer message into a torus value 40 | // message should be in range [0, MessageModulus) 41 | // 42 | // For TFHE bootstrapping, the encoding is: 43 | // 44 | // message i -> (i + 0.5) * scale 45 | // 46 | // This centers each message in its quantization region 47 | func (e *Encoder) Encode(message int) params.Torus { 48 | // Normalize message to [0, MessageModulus) 49 | message = message % e.MessageModulus 50 | if message < 0 { 51 | message += e.MessageModulus 52 | } 53 | 54 | // Encode as (message + 0.5) * scale 55 | // For binary: 0 -> 0.5 * 0.25 = 0.125, 1 -> 1.5 * 0.25 = 0.375 56 | // But we want: 0 -> -0.125 (= 0.875), 1 -> 0.125 57 | // 58 | // Actually for TFHE bootstrapping, messages map to: (2i+1-m)/(2m) 59 | // For m=2: i=0 -> -1/4 = 3/4, i=1 -> 1/4 60 | // 61 | // Hmm, let me reconsider. The standard TFHE encoding is: 62 | // For boolean: false=-1/8, true=1/8 63 | // In unsigned: false=7/8, true=1/8 64 | // 65 | // For m values: message i maps to (2i+1-m) / (2m) 66 | // m=2: i=0 -> (0+1-2)/(4) = -1/4 = 3/4 67 | // i=1 -> (2+1-2)/(4) = 1/4 68 | // 69 | // But for bootstrapping, we actually want something different... 
70 | // Let me use the simpler formula: message i -> i * scale 71 | // with offset handling 72 | 73 | value := float64(message) * e.Scale 74 | return utils.F64ToTorus(value) 75 | } 76 | 77 | // EncodeWithCustomScale encodes with a custom scale factor 78 | func (e *Encoder) EncodeWithCustomScale(message int, scale float64) params.Torus { 79 | message = message % e.MessageModulus 80 | if message < 0 { 81 | message += e.MessageModulus 82 | } 83 | value := float64(message) * scale 84 | return utils.F64ToTorus(value) 85 | } 86 | 87 | // Decode decodes a torus value back to an integer message 88 | func (e *Encoder) Decode(value params.Torus) int { 89 | // Convert torus to float 90 | f := utils.TorusToF64(value) 91 | 92 | // Round to nearest message 93 | message := int(f/e.Scale + 0.5) 94 | 95 | // Normalize to [0, MessageModulus) 96 | message = message % e.MessageModulus 97 | if message < 0 { 98 | message += e.MessageModulus 99 | } 100 | 101 | return message 102 | } 103 | 104 | // DecodeBool decodes a torus value to a boolean (for binary messages) 105 | func (e *Encoder) DecodeBool(value params.Torus) bool { 106 | return e.Decode(value) != 0 107 | } 108 | -------------------------------------------------------------------------------- /lut/generator.go: -------------------------------------------------------------------------------- 1 | package lut 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/thedonutfactory/go-tfhe/params" 7 | ) 8 | 9 | // Generator creates lookup tables from functions for programmable bootstrapping 10 | type Generator struct { 11 | Encoder *Encoder 12 | PolyDegree int 13 | LookUpTableSize int // For binary: equals PolyDegree (not 2*PolyDegree!) 
14 | } 15 | 16 | // NewGenerator creates a new LUT generator 17 | func NewGenerator(messageModulus int) *Generator { 18 | polyDegree := params.GetTRGSWLv1().N 19 | // CRITICAL: For standard TFHE, lookUpTableSize = polyDegree (polyExtendFactor = 1) 20 | // Only for extended configurations is lookUpTableSize > polyDegree 21 | lookUpTableSize := polyDegree 22 | 23 | return &Generator{ 24 | Encoder: NewEncoder(messageModulus), 25 | PolyDegree: polyDegree, 26 | LookUpTableSize: lookUpTableSize, 27 | } 28 | } 29 | 30 | // NewGeneratorWithScale creates a new LUT generator with custom scale 31 | func NewGeneratorWithScale(messageModulus int, scale float64) *Generator { 32 | polyDegree := params.GetTRGSWLv1().N 33 | return &Generator{ 34 | Encoder: NewEncoderWithScale(messageModulus, scale), 35 | PolyDegree: polyDegree, 36 | LookUpTableSize: polyDegree, // Standard: lookUpTableSize = polyDegree 37 | } 38 | } 39 | 40 | // GenLookUpTable generates a lookup table from a function f: int -> int 41 | func (g *Generator) GenLookUpTable(f func(int) int) *LookUpTable { 42 | lut := NewLookUpTable() 43 | g.GenLookUpTableAssign(f, lut) 44 | return lut 45 | } 46 | 47 | // GenLookUpTableAssign generates a lookup table and writes to lutOut 48 | // 49 | // Algorithm from tfhe-go reference implementation (bootstrap_lut.go:111-132) 50 | // For standard TFHE with polyExtendFactor=1 (lookUpTableSize = polyDegree): 51 | // 1. Create lutRaw[lookUpTableSize] 52 | // 2. For each message x, fill range with encoded f(x) 53 | // 3. Rotate by offset 54 | // 4. Negate tail 55 | // 5. 
Store in polynomial 56 | func (g *Generator) GenLookUpTableAssign(f func(int) int, lutOut *LookUpTable) { 57 | messageModulus := g.Encoder.MessageModulus 58 | 59 | // Create raw LUT buffer (size = lookUpTableSize, which equals N for standard TFHE) 60 | lutRaw := make([]params.Torus, g.LookUpTableSize) 61 | 62 | // Fill each message's range with encoded output 63 | for x := 0; x < messageModulus; x++ { 64 | start := divRound(x*g.LookUpTableSize, messageModulus) 65 | end := divRound((x+1)*g.LookUpTableSize, messageModulus) 66 | 67 | // Apply function to message index 68 | y := f(x) 69 | 70 | // Encode the output: message * scale 71 | encodedY := g.Encoder.Encode(y) 72 | 73 | // Fill range 74 | for xx := start; xx < end; xx++ { 75 | lutRaw[xx] = encodedY 76 | } 77 | } 78 | 79 | // Rotate by offset 80 | offset := divRound(g.LookUpTableSize, 2*messageModulus) 81 | 82 | // Apply rotation 83 | rotated := make([]params.Torus, g.LookUpTableSize) 84 | for i := 0; i < g.LookUpTableSize; i++ { 85 | srcIdx := (i + offset) % g.LookUpTableSize 86 | rotated[i] = lutRaw[srcIdx] 87 | } 88 | 89 | // Negate tail portion 90 | for i := g.LookUpTableSize - offset; i < g.LookUpTableSize; i++ { 91 | rotated[i] = -rotated[i] 92 | } 93 | 94 | // Store in polynomial 95 | // For polyExtendFactor=1: just copy all lookUpTableSize coefficients 96 | for i := 0; i < g.LookUpTableSize; i++ { 97 | lutOut.Poly.B[i] = rotated[i] 98 | lutOut.Poly.A[i] = 0 99 | } 100 | } 101 | 102 | // GenLookUpTableFull generates a lookup table from a function f: int -> Torus 103 | func (g *Generator) GenLookUpTableFull(f func(int) params.Torus) *LookUpTable { 104 | lut := NewLookUpTable() 105 | g.GenLookUpTableFullAssign(f, lut) 106 | return lut 107 | } 108 | 109 | // GenLookUpTableFullAssign generates a lookup table with full control 110 | func (g *Generator) GenLookUpTableFullAssign(f func(int) params.Torus, lutOut *LookUpTable) { 111 | messageModulus := g.Encoder.MessageModulus 112 | 113 | lutRaw := 
make([]params.Torus, g.LookUpTableSize) 114 | 115 | for x := 0; x < messageModulus; x++ { 116 | start := divRound(x*g.LookUpTableSize, messageModulus) 117 | end := divRound((x+1)*g.LookUpTableSize, messageModulus) 118 | 119 | y := f(x) 120 | 121 | for i := start; i < end; i++ { 122 | lutRaw[i] = y 123 | } 124 | } 125 | 126 | offset := divRound(g.LookUpTableSize, 2*messageModulus) 127 | rotated := make([]params.Torus, g.LookUpTableSize) 128 | for i := 0; i < g.LookUpTableSize; i++ { 129 | srcIdx := (i + offset) % g.LookUpTableSize 130 | rotated[i] = lutRaw[srcIdx] 131 | } 132 | 133 | for i := g.LookUpTableSize - offset; i < g.LookUpTableSize; i++ { 134 | rotated[i] = -rotated[i] 135 | } 136 | 137 | for i := 0; i < g.LookUpTableSize; i++ { 138 | lutOut.Poly.B[i] = rotated[i] 139 | lutOut.Poly.A[i] = 0 140 | } 141 | } 142 | 143 | // GenLookUpTableCustom generates a lookup table with custom message modulus and scale 144 | func (g *Generator) GenLookUpTableCustom(f func(int) int, messageModulus int, scale float64) *LookUpTable { 145 | lut := NewLookUpTable() 146 | 147 | oldEncoder := g.Encoder 148 | g.Encoder = NewEncoderWithScale(messageModulus, scale) 149 | 150 | g.GenLookUpTableAssign(f, lut) 151 | 152 | g.Encoder = oldEncoder 153 | 154 | return lut 155 | } 156 | 157 | // ModSwitch switches the modulus of x from Torus (2^32) to lookUpTableSize 158 | // For standard TFHE with lookUpTableSize=N: result in [0, N) 159 | func (g *Generator) ModSwitch(x params.Torus) int { 160 | scaled := float64(x) / float64(uint64(1)<<32) * float64(g.LookUpTableSize) 161 | result := int(math.Round(scaled)) % g.LookUpTableSize 162 | 163 | if result < 0 { 164 | result += g.LookUpTableSize 165 | } 166 | 167 | return result 168 | } 169 | 170 | // divRound performs integer division with rounding 171 | func divRound(a, b int) int { 172 | return (a + b/2) / b 173 | } 174 | -------------------------------------------------------------------------------- /lut/lut.go: 
--------------------------------------------------------------------------------
1 | // Package lut provides LookUpTable support for programmable bootstrapping.
2 | // This enables evaluating arbitrary functions on encrypted data during bootstrapping.
3 | package lut
4 |
5 | import (
6 | 	"github.com/thedonutfactory/go-tfhe/trlwe"
7 | )
8 |
9 | // LookUpTable is a TRLWE ciphertext that encodes a function
10 | // for programmable bootstrapping.
11 | // During blind rotation, the LUT is rotated based on the encrypted value,
12 | // effectively evaluating the function on the encrypted data.
13 | type LookUpTable struct {
14 | 	// Polynomial encoding the function values
15 | 	Poly *trlwe.TRLWELv1
16 | }
17 |
18 | // NewLookUpTable creates a new lookup table backed by trlwe.NewTRLWELv1.
19 | func NewLookUpTable() *LookUpTable {
20 | 	return &LookUpTable{
21 | 		Poly: trlwe.NewTRLWELv1(),
22 | 	}
23 | }
24 |
25 | // Copy returns a deep copy of the lookup table.
26 | func (lut *LookUpTable) Copy() *LookUpTable {
27 | 	result := NewLookUpTable()
28 | 	copy(result.Poly.A, lut.Poly.A)
29 | 	copy(result.Poly.B, lut.Poly.B)
30 | 	return result
31 | }
32 |
33 | // CopyFrom copies values from another lookup table.
34 | func (lut *LookUpTable) CopyFrom(other *LookUpTable) {
35 | 	copy(lut.Poly.A, other.Poly.A)
36 | 	copy(lut.Poly.B, other.Poly.B)
37 | }
38 |
39 | // Clear sets every coefficient of the lookup table to 0.
40 | //
41 | // It walks the A and B slices as allocated instead of re-reading the
42 | // current parameter set, so it stays in bounds (and complete) even if
43 | // params.CurrentSecurityLevel changed after the table was created.
44 | func (lut *LookUpTable) Clear() {
45 | 	for i := range lut.Poly.A {
46 | 		lut.Poly.A[i] = 0
47 | 	}
48 | 	for i := range lut.Poly.B {
49 | 		lut.Poly.B[i] = 0
50 | 	}
51 | }
52 |
--------------------------------------------------------------------------------
/lut/lut_test.go:
--------------------------------------------------------------------------------
1 | package lut
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/thedonutfactory/go-tfhe/params"
7 | 	"github.com/thedonutfactory/go-tfhe/utils"
8 | )
9 |
10 | func TestLookUpTableBasic(t *testing.T) {
11 | 	// Test creation and basic operations
12 | 	lut := NewLookUpTable()
13 |
14 | 	if lut == nil {
15 | 		t.Fatal("NewLookUpTable returned nil")
16 | 	}
17 |
18 | 	if lut.Poly == nil {
19 | 		t.Fatal("LookUpTable polynomial is nil")
20 | 	}
21 |
22 | 	// Test clear, including the last coefficient of A
23 | 	last := len(lut.Poly.A) - 1
24 | 	lut.Poly.B[0] = 123
25 | 	lut.Poly.A[last] = 7
26 | 	lut.Clear()
27 | 	if lut.Poly.B[0] != 0 {
28 | 		t.Error("Clear did not clear the polynomial")
29 | 	}
30 | 	if lut.Poly.A[last] != 0 {
31 | 		t.Error("Clear did not clear the last A coefficient")
32 | 	}
33 | }
34 |
35 | func TestLookUpTableCopy(t *testing.T) {
36 | 	lut1 := NewLookUpTable()
37 | 	lut1.Poly.B[0] = 42
38 | 	lut1.Poly.A[0] = 17
39 |
40 | 	// Test Copy
41 | 	lut2 := lut1.Copy()
42 | 	if lut2.Poly.B[0] != 42 || lut2.Poly.A[0] != 17 {
43 | 		t.Error("Copy did not copy values correctly")
44 | 	}
45 |
46 | 	// Modify original and ensure copy is unchanged
47 | 	lut1.Poly.B[0] = 99
48 | 	if lut2.Poly.B[0] != 42 {
49 | 		t.Error("Copy is not independent of original")
50 | 	}
51 |
52 | 	// Test CopyFrom
53 | 	lut3 := NewLookUpTable()
54 | 	lut3.CopyFrom(lut1)
55 | 	if lut3.Poly.B[0] != 99 {
56 | 		t.Error("CopyFrom did not copy values correctly")
57 | 	}
58 | }
59 |
60 | func TestEncoder(t *testing.T) {
61 | 	// Test binary encoder (message modulus = 2)
62 | 	enc := NewEncoder(2)
63 |
64 | 	// Test encoding
65 | 	val0 := enc.Encode(0)
66 | 	val1 := enc.Encode(1)
67 |
68 | 	// Values should be different
69 | 	if val0 == val1 {
70 | 		t.Error("Encoded values for 0 and 1 should be different")
71 | 	}
72 |
73 | 	// Test decoding
74 | 	if enc.Decode(val0) != 0 {
75 | 		t.Errorf("Decode(Encode(0)) = %d, want 0", enc.Decode(val0))
76 | 	}
77 | 	if enc.Decode(val1) != 1 {
78 | 		t.Errorf("Decode(Encode(1)) = %d, want 1", enc.Decode(val1))
79 | 	}
80 |
81 | 	// Test DecodeBool
82 | 	if enc.DecodeBool(val0) {
83 | 		t.Error("DecodeBool(Encode(0)) should be false")
84 | 	}
85 | 	if !enc.DecodeBool(val1) {
86 | 		t.Error("DecodeBool(Encode(1)) should be true")
87 | 	}
88 | }
89 |
90 | func TestEncoderModular(t *testing.T) {
91 | 	// Test with message modulus = 4
92 | 	enc := NewEncoder(4)
93 |
94 | 	for i := 0; i < 4; i++ {
95 | 		encoded := enc.Encode(i)
96 | 		decoded := enc.Decode(encoded)
97 | 		if decoded != i {
98 | 			t.Errorf("Encode/Decode(%d) = %d, want %d", i, decoded, i)
99 | 		}
100 | 	}
101 |
102 | 	// Test negative wrapping
103 | 	if enc.Encode(-1) != enc.Encode(3) {
104 | 		t.Error("Negative values should wrap modulo MessageModulus")
105 | 	}
106 |
107 | 	// Test overflow wrapping
108 | 	if enc.Encode(4) != enc.Encode(0) {
109 | 		t.Error("Values >= MessageModulus should wrap")
110 | 	}
111 | }
112 |
113 | func TestGeneratorIdentity(t *testing.T) {
114 | 	// Test identity function (f(x) = x)
115 | 	gen := NewGenerator(4)
116 |
117 | 	identity := func(x int) int { return x }
118 | 	lut := gen.GenLookUpTable(identity)
119 |
120 | 	if lut == nil {
121 | 		t.Fatal("GenLookUpTable returned nil")
122 | 	}
123 |
124 | 	// Lookup table should be created without error
125 | 	// Detailed functional testing requires full TFHE stack
126 | }
127 |
128 | func TestGeneratorConstant(t *testing.T) {
129 | 	// Test constant function (f(x) = c)
130 | 	gen := NewGenerator(2)
131 |
132 | 	constantOne := func(x int) int { return 1 }
133 | 	lut := gen.GenLookUpTable(constantOne)
134 |
135 | 	if lut == nil {
136 | 		t.Fatal("GenLookUpTable returned nil")
137 | 	}
138 |
139 | 	// Every raw entry holds Encode(1), so after the half-slot rotation the
140 | 	// table is constant except for the wrapped tail, which is negated.
141 | 	want := gen.Encoder.Encode(1)
142 | 	size := gen.LookUpTableSize
143 | 	negateFrom := size - divRound(size, 2*gen.Encoder.MessageModulus)
144 | 	for i := 0; i < size; i++ {
145 | 		expected := want
146 | 		if i >= negateFrom {
147 | 			expected = -want
148 | 		}
149 | 		if lut.Poly.B[i] != expected {
150 | 			t.Fatalf("Poly.B[%d] = %d, want %d", i, lut.Poly.B[i], expected)
151 | 		}
152 | 		if lut.Poly.A[i] != 0 {
153 | 			t.Fatalf("Poly.A[%d] = %d, want 0", i, lut.Poly.A[i])
154 | 		}
155 | 	}
156 | }
157 |
158 | func TestGeneratorNOT(t *testing.T) {
159 | 	// Test NOT function for binary (f(x) = 1 - x)
160 | 	gen := NewGenerator(2)
161 |
162 | 	notFunc := func(x int) int { return 1 - x }
163 | 	lut := gen.GenLookUpTable(notFunc)
164 |
165 | 	if lut == nil {
166 | 		t.Fatal("GenLookUpTable returned nil")
167 | 	}
168 | }
169 |
170 | func TestGeneratorCustomModulus(t *testing.T) {
171 | 	// Test with custom message modulus
172 | 	gen := NewGenerator(8)
173 | 	encoderBefore := gen.Encoder
174 |
175 | 	// Function that doubles the input mod 8
176 | 	doubleFunc := func(x int) int { return (2 * x) % 8 }
177 | 	lut := gen.GenLookUpTableCustom(doubleFunc, 8, 1.0/16.0)
178 |
179 | 	if lut == nil {
180 | 		t.Fatal("GenLookUpTableCustom returned nil")
181 | 	}
182 |
183 | 	// The custom encoder must not leak into the generator
184 | 	if gen.Encoder != encoderBefore {
185 | 		t.Error("GenLookUpTableCustom did not leave gen.Encoder unchanged")
186 | 	}
187 | }
188 |
189 | func TestModSwitch(t *testing.T) {
190 | 	gen := NewGenerator(2)
191 | 	size := gen.LookUpTableSize
192 |
193 | 	// Test modulus switching at key points: x/2^32 is scaled to the table
194 | 	// size and rounded, and a value that rounds up to size wraps to 0.
195 | 	tests := []struct {
196 | 		name  string
197 | 		input params.Torus
198 | 		want  int
199 | 	}{
200 | 		{"zero", 0, 0},
201 | 		{"quarter", params.Torus(1 << 30), size / 4},
202 | 		{"half", params.Torus(1 << 31), size / 2},
203 | 		{"three-quarters", params.Torus(3 << 30), 3 * size / 4},
204 | 		{"max", params.Torus(^uint32(0)), 0},
205 | 	}
206 |
207 | 	for _, tt := range tests {
208 | 		t.Run(tt.name, func(t *testing.T) {
209 | 			result := gen.ModSwitch(tt.input)
210 |
211 | 			// Result should be in valid range [0, LookUpTableSize)
212 | 			if result < 0 || result >= size {
213 | 				t.Errorf("ModSwitch(%d) = %d, out of range [0, %d)", tt.input, result, size)
214 | 			}
215 | 			if result != tt.want {
216 | 				t.Errorf("ModSwitch(%d) = %d, want %d", tt.input, result, tt.want)
217 | 			}
218 | 		})
219 | 	}
220 | }
221 |
222 | func TestGeneratorFullControl(t *testing.T) {
223 | 	// Test GenLookUpTableFull for fine-grained control
224 | 	gen := NewGenerator(2)
225 |
226 | 	// Function that returns exact torus values
227 | 	fullFunc := func(x int) params.Torus {
228 | 		if x == 0 {
229 | 			return utils.F64ToTorus(0.0)
230 | 		}
231 | 		return utils.F64ToTorus(0.25)
232 | 	}
233 |
234 | 	lut := gen.GenLookUpTableFull(fullFunc)
235 |
236 | 	if lut == nil {
237 | 		t.Fatal("GenLookUpTableFull returned nil")
238 | 	}
239 | }
240 |
241 | func BenchmarkLookUpTableCreation(b *testing.B) {
242 | 	gen := NewGenerator(2)
243 | 	identity := func(x int) int { return x }
244 |
245 | 	b.ResetTimer()
246 | 	for i := 0; i < b.N; i++ {
247 | 		_ = gen.GenLookUpTable(identity)
248 | 	}
249 | }
250 |
251 | func BenchmarkModSwitch(b *testing.B) {
252 | 	gen := NewGenerator(2)
253 | 	testVal := params.Torus(12345678)
254 |
255 | 	b.ResetTimer()
256 | 	for i := 0; i < b.N; i++ {
257 | 		_ = gen.ModSwitch(testVal)
258 | 	}
259 | }
260 |
261 | func BenchmarkEncode(b *testing.B) {
262 | 	enc := NewEncoder(2)
263 |
264 | 	b.ResetTimer()
265 | 	for i := 0; i < b.N; i++ {
266 | 		_ = enc.Encode(i % 2)
267 | 	}
268 | }
269 |
270 | func BenchmarkDecode(b *testing.B) {
271 | 	enc := NewEncoder(2)
272 | 	testVal := enc.Encode(1)
273 |
274 | 	b.ResetTimer()
275 | 	for i := 0; i < b.N; i++ {
276 | 		_ = enc.Decode(testVal)
277 | 	}
278 | }
279 |
--------------------------------------------------------------------------------
/lut/reference_algorithm_test.go:
--------------------------------------------------------------------------------
1 | package lut
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/thedonutfactory/go-tfhe/params"
7 | 	"github.com/thedonutfactory/go-tfhe/utils"
8 | )
9 |
10 | // TestReferenceAlgorithmStepByStep traces the reference algorithm step by step
11 | func TestReferenceAlgorithmStepByStep(t *testing.T) {
12 | 	messageModulus := 2
13 | 	polyDegree := params.GetTRGSWLv1().N // 1024
14 | 	lookUpTableSize := 2 * polyDegree    // 2048
15 |
16 | 	t.Log("=== Reference Algorithm for NOT Function ===\n")
17 | 	t.Logf("Parameters: messageModulus=%d, N=%d, LUTSize=%d\n", messageModulus, polyDegree, lookUpTableSize)
18 |
19 | 	notFunc := func(x int) int { return 1 - x }
20 |
21 | 	// Step 1: Create raw LUT
22 | 	t.Log("Step 1: Fill raw LUT")
23 | 	lutRaw := make([]params.Torus, lookUpTableSize)
24 |
25 | 	for x := 0; x < messageModulus; x++ {
26 | 		start := divRound(x*lookUpTableSize, messageModulus)
27 | 		end := divRound((x+1)*lookUpTableSize, messageModulus)
28 |
29 | 		output := notFunc(x)
30 | 		var encodedOutput params.Torus
31 | 		if output == 0 {
32 | 			encodedOutput = utils.F64ToTorus(-0.125) // 0.875
33 | 		} else {
34 | 			encodedOutput = utils.F64ToTorus(0.125)
35 | 		}
36 |
37 | 		t.Logf(" Message %d: NOT(%d)=%d → encode to %.3f", x, x, output, utils.TorusToF64(encodedOutput))
38 | 		t.Logf(" Fill indices [%d, %d)", start, end)
39 |
40 | 		for i := start; i < end; i++ {
41 | 			lutRaw[i] = encodedOutput
42 | 		}
43 | 	}
44 |
45 | 	t.Log("\n Check key positions in raw LUT:")
46 | 	checkPos := referenceCheckPositions(lookUpTableSize, 8)
47 | 	for _, pos := range checkPos {
48 | 		t.Logf(" lutRaw[%4d] = %.3f", pos, utils.TorusToF64(lutRaw[pos]))
49 | 	}
50 |
51 | 	// Step 2: Rotate by offset
52 | 	offset := divRound(lookUpTableSize, 2*messageModulus)
53 | 	t.Logf("\nStep 2: Rotate by offset=%d", offset)
54 |
55 | 	rotated := make([]params.Torus, lookUpTableSize)
56 | 	for i := 0; i < lookUpTableSize; i++ {
57 | 		srcIdx := (i + offset) % lookUpTableSize
58 | 		rotated[i] = lutRaw[srcIdx]
59 | 	}
60 |
61 | 	t.Log(" Check key positions after rotation:")
62 | 	for _, pos := range checkPos {
63 | 		srcPos := (pos + offset) % lookUpTableSize
64 | 		t.Logf(" rotated[%4d] = lutRaw[%4d] = %.3f", pos, srcPos, utils.TorusToF64(rotated[pos]))
65 | 	}
66 |
67 | 	// Step 3: Negate tail
68 | 	negateStart := lookUpTableSize - offset
69 | 	t.Logf("\nStep 3: Negate indices [%d, %d)", negateStart, lookUpTableSize)
70 |
71 | 	for i := negateStart; i < lookUpTableSize; i++ {
72 | 		rotated[i] = -rotated[i]
73 | 	}
74 |
75 | 	t.Log(" Check key positions after negation:")
76 | 	for _, pos := range checkPos {
77 | 		neg := ""
78 | 		if pos >= negateStart {
79 | 			neg = " (negated)"
80 | 		}
81 | 		t.Logf(" rotated[%4d] = %.3f%s", pos, utils.TorusToF64(rotated[pos]), neg)
82 | 	}
83 |
84 | 	// Step 4: Store first N coefficients
85 | 	t.Logf("\nStep 4: Store first N=%d coefficients in polynomial", polyDegree)
86 |
87 | 	result := NewLookUpTable()
88 | 	for i := 0; i < polyDegree; i++ {
89 | 		result.Poly.B[i] = rotated[i]
90 | 		result.Poly.A[i] = 0
91 | 	}
92 |
93 | 	t.Log("\n Final LUT key positions:")
94 | 	checkPosFinal := referenceCheckPositions(polyDegree, 4)
95 | 	for _, pos := range checkPosFinal {
96 | 		t.Logf(" LUT.Poly.B[%4d] = %.3f", pos, utils.TorusToF64(result.Poly.B[pos]))
97 | 	}
98 |
99 | 	// Compare with actual generator (it uses LookUpTableSize = N, so mismatches with this 2N trace are expected)
100 | 	t.Log("\n Comparing with GenLookUpTable:")
101 | 	gen := NewGenerator(2)
102 | 	actualLUT := gen.GenLookUpTable(notFunc)
103 |
104 | 	matches := 0
105 | 	for i := 0; i < polyDegree; i++ {
106 | 		if result.Poly.B[i] == actualLUT.Poly.B[i] {
107 | 			matches++
108 | 		}
109 | 	}
110 | 	t.Logf(" Matching coefficients: %d / %d", matches, polyDegree)
111 |
112 | 	if matches != polyDegree {
113 | 		t.Log("\n First 10 mismatches:")
114 | 		count := 0
115 | 		for i := 0; i < polyDegree && count < 10; i++ {
116 | 			if result.Poly.B[i] != actualLUT.Poly.B[i] {
117 | 				t.Logf(" [%d]: manual=%.3f, actual=%.3f",
118 | 					i, utils.TorusToF64(result.Poly.B[i]), utils.TorusToF64(actualLUT.Poly.B[i]))
119 | 				count++
120 | 			}
121 | 		}
122 | 	}
123 |
124 | 	// Now verify this gives correct results for ideal inputs
125 | 	t.Log("\n Verification with ideal encoded inputs:")
126 |
127 | 	falseIdeal := utils.F64ToTorus(-0.125) // 0.875
128 | 	trueIdeal := utils.F64ToTorus(0.125)
129 |
130 | 	falseMS := gen.ModSwitch(falseIdeal)
131 | 	trueMS := gen.ModSwitch(trueIdeal)
132 |
133 | 	t.Logf(" false (0.875) → ModSwitch=%d → extract from LUT[%d]", falseMS, falseMS%polyDegree)
134 | 	t.Logf(" Value: %.3f, Expected: %.3f (NOT(false)=true)",
135 | 		utils.TorusToF64(result.Poly.B[falseMS%polyDegree]), 0.125)
136 |
137 | 	t.Logf(" true (0.125) → ModSwitch=%d → extract from LUT[%d]", trueMS, trueMS%polyDegree)
138 | 	t.Logf(" Value: %.3f, Expected: %.3f (NOT(true)=false)",
139 | 		utils.TorusToF64(result.Poly.B[trueMS%polyDegree]), 0.875)
140 | }
141 |
142 | // referenceCheckPositions returns count evenly spaced indices in [0, size),
143 | // so the trace above samples the table without hard-coding N.
144 | func referenceCheckPositions(size, count int) []int {
145 | 	positions := make([]int, count)
146 | 	for k := range positions {
147 | 		positions[k] = k * size / count
148 | 	}
149 | 	return positions
150 | }
151 |
--------------------------------------------------------------------------------
/params/UINT_STATUS.md:
--------------------------------------------------------------------------------
1 | # Uint Parameter Sets Status
2 |
3 | ## Production Ready ✅
4 |
5 | | Parameter | messageModulus | Poly Degree | Status | Test Results |
6 | |-----------|----------------|-------------|--------|--------------|
7 | | **Uint2** | 4 | 512 | ✅ **READY** | 100% pass (Identity, Complement, Modulo) |
8 | | **Uint3** | 8 | 1024 | ✅ **READY** | 100% pass (Identity, Complement, Modulo) |
9 | | **Uint4** | 16 | 2048 | ✅ **READY** | 100% pass (Identity, Complement, Modulo) |
10 | | **Uint5** | 32 | 2048 | ✅ **READY** | 100% pass (Identity, Complement, Modulo) |
11 |
12 | ## Experimental ⚠️
13 |
14 | | Parameter | messageModulus | Poly
Degree | LUTSize | Status | Test Results |
15 | |-----------|----------------|-------------|---------|--------|--------------|
16 | | Uint6 | 64 | 2048 | 4096 | ⚠️ **EXPERIMENTAL** | Identity ✅, Complement ❌, Modulo ❌ |
17 | | Uint7 | 128 | 2048 | 8192 | ⚠️ **EXPERIMENTAL** | Partial failures |
18 | | Uint8 | 256 | 2048 | 18432 | ⚠️ **EXPERIMENTAL** | Partial failures |
19 |
20 | ## Why Uint6-8 Are Experimental
21 |
22 | Uint6-8 use **extended lookup tables**, where `LookUpTableSize > PolyDegree`:
23 | - Uint6: LookUpTableSize = 4096 = 2 × PolyDegree (polyExtendFactor = 2)
24 | - Uint7: LookUpTableSize = 8192 = 4 × PolyDegree (polyExtendFactor = 4)
25 | - Uint8: LookUpTableSize = 18432 = 9 × PolyDegree (polyExtendFactor = 9)
26 |
27 | Our current LUT generation assumes `LookUpTableSize = PolyDegree` (polyExtendFactor = 1). Supporting extended LUTs requires:
28 | 1. A modified LUT generation algorithm that supports polyExtendFactor > 1
29 | 2. Special blind-rotation handling for extended LUTs
30 | 3. Additional testing and validation
31 |
32 | ## Recommendation
33 |
34 | **For Production Use:**
35 | - Use **Uint2-5**, which are fully tested and reliable
36 | - Uint5 supports messageModulus=32, which is sufficient for most applications
37 | - For 8-bit values, use nibble-based decomposition with Uint5
38 |
39 | **For Research/Development:**
40 | - Uint6-8 can be explored for specific use cases
41 | - The identity function passes on Uint6, suggesting basic PBS is functional
42 | - More complex functions need additional work
43 |
44 | ## Workaround for Larger Values
45 |
46 | Instead of encoding a byte directly with Uint8 (0-255), split it into two 4-bit nibbles and process them with Uint5:
47 | ```go
48 | // Split 8-bit value into two 4-bit nibbles
49 | low := value & 0x0F
50 | high := (value >> 4) & 0x0F
51 |
52 | // Encrypt each nibble with Uint5 (messageModulus=32)
53 | // Process the nibbles separately
54 | // Recombine the nibble results (only 4 bootstraps in total)
55 | ```
56 |
57 | Because it stays on the fully tested Uint5 parameters and avoids the experimental extended-LUT path, this is both **faster and more reliable** than direct Uint8.
58 |
59 | ## Future Work
60 |
61 | To make Uint6-8 production-ready:
62 | 1. Implement extended LUT generation (polyExtendFactor > 1)
63 | 2. Update LUT generator to handle larger table sizes
64 | 3. Comprehensive testing of extended PBS
65 | 4. Performance optimization
66 |
67 | For now, **Uint2-5 provide excellent coverage** for practical homomorphic arithmetic!
68 |
--------------------------------------------------------------------------------
/params/params_test.go:
--------------------------------------------------------------------------------
1 | package params_test
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/thedonutfactory/go-tfhe/params"
7 | )
8 |
9 | // setSecurityLevel sets params.CurrentSecurityLevel for the rest of the
10 | // calling test or benchmark and restores the previous value on cleanup, so a
11 | // parameter switch cannot leak into other tests in this package.
12 | func setSecurityLevel(tb testing.TB, level params.SecurityLevel) {
13 | 	tb.Helper()
14 | 	prev := params.CurrentSecurityLevel
15 | 	tb.Cleanup(func() { params.CurrentSecurityLevel = prev })
16 | 	params.CurrentSecurityLevel = level
17 | }
18 |
19 | func TestSecurityLevelSwitching(t *testing.T) {
20 | 	// Test 128-bit (default); the original level is restored on cleanup.
21 | 	setSecurityLevel(t, params.Security128Bit)
22 | 	tlwe0 := params.GetTLWELv0()
23 | 	if tlwe0.N != 700 {
24 | 		t.Errorf("128-bit TLWE Lv0 N = %d, expected 700", tlwe0.N)
25 | 	}
26 |
27 | 	// Test 110-bit
28 | 	params.CurrentSecurityLevel = params.Security110Bit
29 | 	tlwe0 = params.GetTLWELv0()
30 | 	if tlwe0.N != 630 {
31 | 		t.Errorf("110-bit TLWE Lv0 N = %d, expected 630", tlwe0.N)
32 | 	}
33 |
34 | 	// Test 80-bit
35 | 	params.CurrentSecurityLevel = params.Security80Bit
36 | 	tlwe0 = params.GetTLWELv0()
37 | 	if tlwe0.N != 550 {
38 | 		t.Errorf("80-bit TLWE Lv0 N = %d, expected 550", tlwe0.N)
39 | 	}
40 | }
41 |
42 | func TestParameterConsistency(t *testing.T) {
43 | 	setSecurityLevel(t, params.Security128Bit)
44 |
45 | 	tlwe0 := params.GetTLWELv0()
46 | 	tlwe1 := params.GetTLWELv1()
47 | 	trlwe1 := params.GetTRLWELv1()
48 | 	trgsw1 := params.GetTRGSWLv1()
49 |
50 | 	// Verify basic constraints
51 | 	if tlwe0.N <= 0 {
52 | 		t.Errorf("TLWE Lv0 N must be positive, got %d", tlwe0.N)
53 | 	}
54 | 	if tlwe1.N <= 0 {
55 | 		t.Errorf("TLWE Lv1 N must be positive, got %d", tlwe1.N)
56 | 	}
57 | 	if tlwe0.ALPHA <= 0 {
58 | 		t.Errorf("TLWE Lv0 ALPHA must be positive, got %f", tlwe0.ALPHA)
59 | 	}
60 | 	if tlwe1.ALPHA <= 0 {
61 | 		t.Errorf("TLWE Lv1 ALPHA must be positive, got %f", tlwe1.ALPHA)
62 | 	}
63 |
64 | 	// Verify TRLWE and TLWE Lv1 have same N
65 | 	if trlwe1.N != tlwe1.N {
66 | 		t.Errorf("TRLWE Lv1 N (%d) should equal TLWE Lv1 N (%d)", trlwe1.N, tlwe1.N)
67 | 	}
68 |
69 | 	// Verify TRGSW BG matches BGBIT
70 | 	expectedBG := uint32(1) << trgsw1.BGBIT
71 | 	if trgsw1.BG != expectedBG {
72 | 		t.Errorf("TRGSW BG = %d, expected %d (1 << %d)", trgsw1.BG, expectedBG, trgsw1.BGBIT)
73 | 	}
74 |
75 | 	// Verify TRGSW N matches TLWE Lv1 N
76 | 	if trgsw1.N != tlwe1.N {
77 | 		t.Errorf("TRGSW N (%d) should equal TLWE Lv1 N (%d)", trgsw1.N, tlwe1.N)
78 | 	}
79 | }
80 |
81 | func TestSecurityInfo(t *testing.T) {
82 | 	info := params.SecurityInfo()
83 | 	if info == "" {
84 | 		t.Error("SecurityInfo returned empty string")
85 | 	}
86 | 	t.Logf("Security info: %s", info)
87 | }
88 |
89 | func TestKSKAndBSKAlpha(t *testing.T) {
90 | 	setSecurityLevel(t, params.Security128Bit)
91 |
92 | 	kskAlpha := params.KSKAlpha()
93 | 	bskAlpha := params.BSKAlpha()
94 |
95 | 	if kskAlpha <= 0 {
96 | 		t.Errorf("KSKAlpha must be positive, got %f", kskAlpha)
97 | 	}
98 | 	if bskAlpha <= 0 {
99 | 		t.Errorf("BSKAlpha must be positive, got %f", bskAlpha)
100 | 	}
101 |
102 | 	// KSK uses Lv0 alpha, BSK uses Lv1 alpha
103 | 	if kskAlpha != params.GetTLWELv0().ALPHA {
104 | 		t.Errorf("KSKAlpha (%f) should equal TLWE Lv0 ALPHA (%f)", kskAlpha, params.GetTLWELv0().ALPHA)
105 | 	}
106 | 	if bskAlpha != params.GetTLWELv1().ALPHA {
107 | 		t.Errorf("BSKAlpha (%f) should equal TLWE Lv1 ALPHA (%f)", bskAlpha, params.GetTLWELv1().ALPHA)
108 | 	}
109 | }
110 |
--------------------------------------------------------------------------------
/params/uint_params_test.go:
--------------------------------------------------------------------------------
1 | package params_test
2 |
3 | import (
4 | 	"fmt"
5 | 	"testing"
6 | 	"time"
7 |
8 | 	"github.com/thedonutfactory/go-tfhe/cloudkey"
9 | 	"github.com/thedonutfactory/go-tfhe/evaluator"
10 | 	"github.com/thedonutfactory/go-tfhe/key"
11 | 	"github.com/thedonutfactory/go-tfhe/lut"
12 | 	"github.com/thedonutfactory/go-tfhe/params"
13 | 	"github.com/thedonutfactory/go-tfhe/tlwe"
14 | )
15 |
16 | // TestAllUintParameters tests all Uint parameter sets with programmable bootstrapping
17 | func TestAllUintParameters(t *testing.T) {
18 | 	testCases := []struct {
19 | 		name           string
20 | 		secLevel       params.SecurityLevel
21 | 		messageModulus int
22 | 		skipReason     string // If non-empty, test will be skipped
23 | 	}{
24 | 		{"Uint1", params.SecurityUint1, 2, ""},
25 | 		{"Uint2", params.SecurityUint2, 4, ""},
26 | 		{"Uint3", params.SecurityUint3, 8, ""},
27 | 		{"Uint4", params.SecurityUint4, 16, ""},
28 | 		{"Uint5", params.SecurityUint5, 32, ""},
29 | 		{"Uint6", params.SecurityUint6, 64, "Extended LUT (polyExtendFactor=2) not fully implemented"},
30 | 		{"Uint7", params.SecurityUint7, 128, "Extended LUT (polyExtendFactor=4) not fully implemented"},
31 | 		{"Uint8", params.SecurityUint8, 256, "Extended LUT (polyExtendFactor=9) not fully implemented"},
32 | 	}
33 |
34 | 	for _, tc := range testCases {
35 | 		t.Run(tc.name, func(t *testing.T) {
36 | 			if tc.skipReason != "" {
37 | 				t.Skipf("Skipping %s: %s", tc.name, tc.skipReason)
38 | 				return
39 | 			}
40 | 			testUintParameterSet(t, tc.secLevel, tc.name, tc.messageModulus)
41 | 		})
42 | 	}
43 | }
44 |
45 | func testUintParameterSet(t *testing.T, secLevel params.SecurityLevel, name string, messageModulus int) {
46 | 	setSecurityLevel(t, secLevel)
47 |
48 | 	t.Logf("Testing %s with messageModulus=%d, N=%d", name, messageModulus, params.GetTRGSWLv1().N)
49 |
50 | 	// Generate keys
51 | 	keyStart := time.Now()
52 | 	secretKey := key.NewSecretKey()
53 | 	cloudKey := cloudkey.NewCloudKey(secretKey)
54 | 	eval := evaluator.NewEvaluator(params.GetTRGSWLv1().N)
55 | 	keyDuration := time.Since(keyStart)
56 | 	t.Logf("Key generation: %v", keyDuration)
57 |
58 | 	gen := lut.NewGenerator(messageModulus)
59 |
60 | 	// Test identity function on a subset of values
61 | 	t.Run("Identity", func(t *testing.T) {
62 | 		lutId := gen.GenLookUpTable(func(x int) int { return x })
63 |
64 | 		// Test first few and last few values
65 | 		testValues := getTestValues(messageModulus)
66 |
67 | 		for _, x := range testValues {
68 | 			ct := tlwe.NewTLWELv0()
69 | 			ct.EncryptLWEMessage(x, messageModulus, params.GetTLWELv0().ALPHA, secretKey.KeyLv0)
70 |
71 | 			ctResult := eval.BootstrapLUT(ct, lutId, cloudKey.BootstrappingKey, cloudKey.KeySwitchingKey, cloudKey.DecompositionOffset)
72 |
73 | 			result := ctResult.DecryptLWEMessage(messageModulus, secretKey.KeyLv0)
74 |
75 | 			if result != x {
76 | 				t.Errorf("identity(%d) = %d, want %d", x, result, x)
77 | 			}
78 | 		}
79 | 	})
80 |
81 | 	// Test NOT-like function (complement)
82 | 	t.Run("Complement", func(t *testing.T) {
83 | 		lutComplement := gen.GenLookUpTable(func(x int) int {
84 | 			return (messageModulus - 1) - x
85 | 		})
86 |
87 | 		testValues := getTestValues(messageModulus)
88 |
89 | 		for _, x := range testValues {
90 | 			ct := tlwe.NewTLWELv0()
91 | 			ct.EncryptLWEMessage(x, messageModulus, params.GetTLWELv0().ALPHA, secretKey.KeyLv0)
92 |
93 | 			ctResult := eval.BootstrapLUT(ct, lutComplement, cloudKey.BootstrappingKey, cloudKey.KeySwitchingKey, cloudKey.DecompositionOffset)
94 |
95 | 			result := ctResult.DecryptLWEMessage(messageModulus, secretKey.KeyLv0)
96 | 			expected := (messageModulus - 1) - x
97 |
98 | 			if result != expected {
99 | 				t.Errorf("complement(%d) = %d, want %d", x, result, expected)
100 | 			}
101 | 		}
102 | 	})
103 |
104 | 	// Test modulo function
105 | 	t.Run("Modulo", func(t *testing.T) {
106 | 		modValue := messageModulus / 2
107 | 		lutMod := gen.GenLookUpTable(func(x int) int {
108 | 			return x % modValue
109 | 		})
110 |
111 | 		testValues := getTestValues(messageModulus)
112 |
113 | 		for _, x := range testValues {
114 | 			ct := tlwe.NewTLWELv0()
115 | 			ct.EncryptLWEMessage(x, messageModulus, params.GetTLWELv0().ALPHA, secretKey.KeyLv0)
116 |
117 | 			ctResult := eval.BootstrapLUT(ct, lutMod, cloudKey.BootstrappingKey, cloudKey.KeySwitchingKey, cloudKey.DecompositionOffset)
118 |
119 | 			result := ctResult.DecryptLWEMessage(messageModulus, secretKey.KeyLv0)
120 | 			expected := x % modValue
121 |
122 | 			if result != expected {
123 | 				t.Errorf("(%d %% %d) = %d, want %d", x, modValue, result, expected)
124 | 			}
125 | 		}
126 | 	})
127 | }
128 |
129 | // getTestValues returns the inputs to exercise for a given modulus: every
130 | // value when modulus <= 8, otherwise the first 3, the middle one and the last 3.
131 | func getTestValues(modulus int) []int {
132 | 	if modulus <= 8 {
133 | 		// Small modulus: test all values
134 | 		result := make([]int, modulus)
135 | 		for i := 0; i < modulus; i++ {
136 | 			result[i] = i
137 | 		}
138 | 		return result
139 | 	}
140 |
141 | 	// Large modulus: test subset
142 | 	return []int{
143 | 		0, 1, 2,                               // First few
144 | 		modulus / 2,                           // Middle
145 | 		modulus - 3, modulus - 2, modulus - 1, // Last few
146 | 	}
147 | }
148 |
149 | // BenchmarkUintParameters benchmarks key generation for all Uint parameter sets
150 | func BenchmarkUintParameters(b *testing.B) {
151 | 	paramSets := []struct {
152 | 		name     string
153 | 		secLevel params.SecurityLevel
154 | 	}{
155 | 		{"Uint1", params.SecurityUint1},
156 | 		{"Uint2", params.SecurityUint2},
157 | 		{"Uint3", params.SecurityUint3},
158 | 		{"Uint4", params.SecurityUint4},
159 | 		{"Uint5", params.SecurityUint5},
160 | 		{"Uint6", params.SecurityUint6},
161 | 		{"Uint7", params.SecurityUint7},
162 | 		{"Uint8", params.SecurityUint8},
163 | 	}
164 |
165 | 	for _, ps := range paramSets {
166 | 		b.Run(fmt.Sprintf("KeyGen/%s", ps.name), func(b *testing.B) {
167 | 			setSecurityLevel(b, ps.secLevel)
168 | 			b.ResetTimer()
169 |
170 | 			for i := 0; i < b.N; i++ {
171 | 				secretKey := key.NewSecretKey()
172 | 				_ = cloudkey.NewCloudKey(secretKey)
173 | 			}
174 | 		})
175 | 	}
176 | }
177 |
178 | // BenchmarkPBS benchmarks programmable bootstrapping for each Uint parameter set
179 | func BenchmarkPBS(b *testing.B) {
180 | 	paramSets := []struct {
181 | 		name           string
182 | 		secLevel       params.SecurityLevel
183 | 		messageModulus int
184 | 	}{
185 | 		{"Uint1", params.SecurityUint1, 2},
186 | 		{"Uint2", params.SecurityUint2, 4},
187 | 		{"Uint3", params.SecurityUint3, 8},
188 | 		{"Uint4", params.SecurityUint4, 16},
189 | 		{"Uint5", params.SecurityUint5, 32},
190 | 		{"Uint6", params.SecurityUint6, 64},
191 | 		{"Uint7", params.SecurityUint7, 128},
192 | 		{"Uint8", params.SecurityUint8, 256},
193 | 	}
194 |
195 | 	for _, ps := range paramSets {
196 | 		b.Run(ps.name, func(b *testing.B) {
197 | 			setSecurityLevel(b, ps.secLevel)
198 |
199 | 			secretKey := key.NewSecretKey()
200 | 			cloudKey := cloudkey.NewCloudKey(secretKey)
201 | 			eval := evaluator.NewEvaluator(params.GetTRGSWLv1().N)
202 |
203 | 			gen := lut.NewGenerator(ps.messageModulus)
204 | 			lutId := gen.GenLookUpTable(func(x int) int { return x })
205 |
206 | 			ct := tlwe.NewTLWELv0()
207 | 			ct.EncryptLWEMessage(1, ps.messageModulus, params.GetTLWELv0().ALPHA, secretKey.KeyLv0)
208 |
209 | 			b.ResetTimer()
210 |
211 | 			for i := 0; i < b.N; i++ {
212 | 				_ = eval.BootstrapLUT(ct, lutId, cloudKey.BootstrappingKey, cloudKey.KeySwitchingKey, cloudKey.DecompositionOffset)
213 | 			}
214 | 		})
215 | 	}
216 | }
217 |
218 | // TestUintParameterProperties verifies parameter properties
219 | func TestUintParameterProperties(t *testing.T) {
220 | 	testCases := []struct {
221 | 		name           string
222 | 		secLevel       params.SecurityLevel
223 | 		expectedN      int
224 | 		expectedLweN   int
225 | 		messageModulus int
226 | 	}{
227 | 		{"Uint1", params.SecurityUint1, 1024, 700, 2},
228 | 		{"Uint2", params.SecurityUint2, 512, 687, 4},
229 | 		{"Uint3", params.SecurityUint3, 1024, 820, 8},
230 | 		{"Uint4", params.SecurityUint4, 2048, 820, 16},
231 | 		{"Uint5", params.SecurityUint5, 2048, 1071, 32},
232 | 		{"Uint6", params.SecurityUint6, 2048, 1071, 64},
233 | 		{"Uint7", params.SecurityUint7, 2048, 1160, 128},
234 | 		{"Uint8", params.SecurityUint8, 2048, 1160, 256},
235 | 	}
236 |
237 | 	for _, tc := range testCases {
238 | 		t.Run(tc.name, func(t *testing.T) {
239 | 			setSecurityLevel(t, tc.secLevel)
240 |
241 | 			n := params.GetTRGSWLv1().N
242 | 			lweN := params.GetTLWELv0().N
243 |
244 | 			if n != tc.expectedN {
245 | 				t.Errorf("Polynomial degree: got %d, want %d", n, tc.expectedN)
246 | 			}
247 |
248 | 			if lweN != tc.expectedLweN {
249 | 				t.Errorf("LWE dimension: got %d, want %d", lweN, tc.expectedLweN)
250 | 			}
251 |
252 | 			// Verify other parameters are set
253 | 			if params.GetTLWELv0().ALPHA == 0 {
254 | 				t.Error("LWE noise not set")
255 | 			}
256 |
257 | 			if params.GetTRGSWLv1().BG == 0 {
258 | 				t.Error("TRGSW base not set")
259 | 			}
260 |
261 | 			t.Logf("%s: N=%d, LWE_N=%d, messageModulus=%d", tc.name, n, lweN, tc.messageModulus)
262 | 		})
263 | 	}
264 | }
265 |
--------------------------------------------------------------------------------
/poly/aligned.go:
--------------------------------------------------------------------------------
1 | package poly
2 |
3 | import (
4 | 	"unsafe"
5 |
6 | 	"github.com/thedonutfactory/go-tfhe/params"
7 | )
8 |
9 | // Memory alignment utilities for better cache performance
10 |
11 | // cacheLineAlignBytes is the cache-line size, in bytes, that the aligned
12 | // constructors below target (64 bytes is typical on x86-64 and arm64).
13 | const cacheLineAlignBytes = 64
14 |
15 | // cacheLineOffset returns how many elements of size elemSize must be skipped
16 | // from base so that the following element starts on a cacheLineAlignBytes
17 | // boundary. The result is exact when base is aligned to elemSize, which holds
18 | // for the first element of a uint32 or float64 slice on 64-bit targets.
19 | func cacheLineOffset(base unsafe.Pointer, elemSize uintptr) int {
20 | 	misalign := uintptr(base) % cacheLineAlignBytes
21 | 	if misalign == 0 {
22 | 		return 0
23 | 	}
24 | 	return int((cacheLineAlignBytes - misalign) / elemSize)
25 | }
26 |
27 | // NewPolyAligned creates a polynomial whose coefficients start on a
28 | // cache-line boundary.
29 | // This helps with SIMD operations and cache efficiency.
30 | //
31 | // The offset is computed from the backing array's address. The alignment
32 | // persists because Go's current garbage collector does not move heap
33 | // objects; that is an implementation property, not a language guarantee.
34 | func NewPolyAligned(N int) Poly {
35 | 	if !isPowerOfTwo(N) {
36 | 		panic("degree not power of two")
37 | 	}
38 | 	if N < MinDegree {
39 | 		panic("degree smaller than MinDegree")
40 | 	}
41 |
42 | 	// Allocate with extra space for alignment
43 | 	// Cache lines are typically 64 bytes = 16 x uint32
44 | 	const cacheLineSize = 16
45 | 	coeffs := make([]params.Torus, N+cacheLineSize)
46 |
47 | 	// Find aligned offset from the slice's address (not its length)
48 | 	offset := cacheLineOffset(unsafe.Pointer(&coeffs[0]), unsafe.Sizeof(coeffs[0]))
49 |
50 | 	// Return slice starting at aligned offset
51 | 	return Poly{Coeffs: coeffs[offset : offset+N]}
52 | }
53 |
54 | // NewFourierPolyAligned creates a fourier polynomial whose coefficients start
55 | // on a cache-line boundary (see NewPolyAligned).
56 | func NewFourierPolyAligned(N int) FourierPoly {
57 | 	if !isPowerOfTwo(N) {
58 | 		panic("degree not power of two")
59 | 	}
60 | 	if N < MinDegree {
61 | 		panic("degree smaller than MinDegree")
62 | 	}
63 |
64 | 	// Allocate with extra space for alignment
65 | 	// Cache lines are typically 64 bytes = 8 x float64
66 | 	const cacheLineSize = 8
67 | 	coeffs := make([]float64, N+cacheLineSize)
68 |
69 | 	// Find aligned offset from the slice's address (not its length)
70 | 	offset := cacheLineOffset(unsafe.Pointer(&coeffs[0]), unsafe.Sizeof(coeffs[0]))
71 |
72 | 	// Return slice starting at aligned offset
73 | 	return FourierPoly{Coeffs: coeffs[offset : offset+N]}
74 | }
75 |
--------------------------------------------------------------------------------
/poly/buffer_manager.go:
--------------------------------------------------------------------------------
1 | package poly
2 |
3 | import "github.com/thedonutfactory/go-tfhe/params"
4 |
5 | // BufferManager centralizes all polynomial operation buffers
6 | // This provides a single, well-documented place to manage FFT, decomposition, and rotation buffers
7 | type BufferManager struct {
8 | 	// Polynomial degree
9 | 	n int
10 |
11 | 	// === FFT Buffers ===
12 |
13 | 	// Forward/Inverse FFT working buffers
14 | 	FFT struct {
15 | 		Poly    Poly        // Time domain working buffer
16 | 		Fourier FourierPoly // Frequency domain working buffer
17 | 	}
18 |
19 | 	// === Decomposition Buffers ===
20 |
21 | 	Decomposition struct {
22 | 		// Decomposed polynomials in time domain [level]
23 | 		Poly []Poly
24 | 		// Decomposed polynomials in Fourier domain [level]
25 | 		Fourier []FourierPoly
26 | 	}
27 |
28 | 	// === Multiplication Buffers ===
29 |
30 | 	Multiplication struct {
31 | 		// Result accumulators in Fourier domain
32 | 		AccA FourierPoly
33 | 		AccB FourierPoly
34 | 		// Temporary buffer for operations
35 | 		Temp FourierPoly
36 | 	}
37 |
38 | 	// === Rotation Buffers ===
39 |
40 | 	Rotation struct {
41 | 		// Pool of
polynomials for X^k multiplication 42 | Pool []Poly 43 | InUse int // Number currently in use 44 | 45 | // TRLWE rotation buffers 46 | TRLWEPool []*struct { 47 | A []params.Torus 48 | B []params.Torus 49 | } 50 | TRLWEInUse int 51 | } 52 | 53 | // === Temporary Buffers === 54 | 55 | // General-purpose temporary buffers 56 | Temp struct { 57 | Poly1 Poly 58 | Poly2 Poly 59 | Poly3 Poly 60 | } 61 | } 62 | 63 | // NewBufferManager creates a new centralized buffer manager 64 | func NewBufferManager(n int) *BufferManager { 65 | l := params.GetTRGSWLv1().L 66 | 67 | bm := &BufferManager{n: n} 68 | 69 | // Initialize FFT buffers 70 | bm.FFT.Poly = NewPoly(n) 71 | bm.FFT.Fourier = NewFourierPoly(n) 72 | 73 | // Initialize decomposition buffers for 2*L levels (A and B components) 74 | bm.Decomposition.Poly = make([]Poly, l*2) 75 | bm.Decomposition.Fourier = make([]FourierPoly, l*2) 76 | for i := 0; i < l*2; i++ { 77 | bm.Decomposition.Poly[i] = NewPoly(n) 78 | bm.Decomposition.Fourier[i] = NewFourierPoly(n) 79 | } 80 | 81 | // Initialize multiplication buffers 82 | bm.Multiplication.AccA = NewFourierPoly(n) 83 | bm.Multiplication.AccB = NewFourierPoly(n) 84 | bm.Multiplication.Temp = NewFourierPoly(n) 85 | 86 | // Initialize rotation pool (4 polynomials should be enough for most operations) 87 | bm.Rotation.Pool = make([]Poly, 4) 88 | for i := 0; i < 4; i++ { 89 | bm.Rotation.Pool[i] = NewPoly(n) 90 | } 91 | bm.Rotation.InUse = 0 92 | 93 | // Initialize TRLWE rotation pool 94 | bm.Rotation.TRLWEPool = make([]*struct { 95 | A []params.Torus 96 | B []params.Torus 97 | }, 4) 98 | for i := 0; i < 4; i++ { 99 | bm.Rotation.TRLWEPool[i] = &struct { 100 | A []params.Torus 101 | B []params.Torus 102 | }{ 103 | A: make([]params.Torus, n), 104 | B: make([]params.Torus, n), 105 | } 106 | } 107 | bm.Rotation.TRLWEInUse = 0 108 | 109 | // Initialize temporary buffers 110 | bm.Temp.Poly1 = NewPoly(n) 111 | bm.Temp.Poly2 = NewPoly(n) 112 | bm.Temp.Poly3 = NewPoly(n) 113 | 114 | return bm 
115 | } 116 | 117 | // GetRotationBuffer returns a polynomial buffer for rotation operations 118 | func (bm *BufferManager) GetRotationBuffer() Poly { 119 | if bm.Rotation.InUse >= len(bm.Rotation.Pool) { 120 | // Wrap around if we run out (should rarely happen) 121 | bm.Rotation.InUse = 0 122 | } 123 | buffer := bm.Rotation.Pool[bm.Rotation.InUse] 124 | bm.Rotation.InUse++ 125 | return buffer 126 | } 127 | 128 | // GetTRLWEBuffer returns a TRLWE buffer (A, B components) 129 | func (bm *BufferManager) GetTRLWEBuffer() ([]params.Torus, []params.Torus) { 130 | if bm.Rotation.TRLWEInUse >= len(bm.Rotation.TRLWEPool) { 131 | bm.Rotation.TRLWEInUse = 0 132 | } 133 | buffer := bm.Rotation.TRLWEPool[bm.Rotation.TRLWEInUse] 134 | bm.Rotation.TRLWEInUse++ 135 | return buffer.A, buffer.B 136 | } 137 | 138 | // Reset resets all buffer indices 139 | func (bm *BufferManager) Reset() { 140 | bm.Rotation.InUse = 0 141 | bm.Rotation.TRLWEInUse = 0 142 | } 143 | 144 | // MemoryUsage returns approximate memory usage in bytes 145 | func (bm *BufferManager) MemoryUsage() int { 146 | n := bm.n 147 | l := params.GetTRGSWLv1().L 148 | 149 | // Poly: N * 4 bytes, FourierPoly: N * 8 * 2 bytes (complex) 150 | polySize := n * 4 151 | fourierSize := n * 8 * 2 152 | 153 | mem := 0 154 | 155 | // FFT buffers 156 | mem += polySize + fourierSize 157 | 158 | // Decomposition buffers (2*L levels) 159 | mem += (polySize + fourierSize) * l * 2 160 | 161 | // Multiplication buffers 162 | mem += fourierSize * 3 163 | 164 | // Rotation pool 165 | mem += polySize * len(bm.Rotation.Pool) 166 | mem += polySize * 2 * len(bm.Rotation.TRLWEPool) // A and B 167 | 168 | // Temp buffers 169 | mem += polySize * 3 170 | 171 | return mem 172 | } 173 | -------------------------------------------------------------------------------- /poly/buffer_methods.go: -------------------------------------------------------------------------------- 1 | package poly 2 | 3 | import "github.com/thedonutfactory/go-tfhe/params" 4 | 5 
| // ============================================================================ 6 | // UNIFIED BUFFER METHODS 7 | // ============================================================================ 8 | // All buffer pool operations consolidated in one place for clarity. 9 | // These methods operate on poly.Evaluator.buffer (evaluationBuffer struct). 10 | 11 | // ============================================================================ 12 | // FOURIER BUFFER OPERATIONS 13 | // ============================================================================ 14 | 15 | // ClearBuffer clears a named Fourier buffer (sets all coefficients to zero) 16 | func (e *Evaluator) ClearBuffer(name string) { 17 | switch name { 18 | case "fpAcc": 19 | e.buffer.fpAcc.Clear() 20 | case "fpBcc": 21 | e.buffer.fpBcc.Clear() 22 | case "fpDiff": 23 | e.buffer.fpDiff.Clear() 24 | case "fpMul1": 25 | e.buffer.fpMul1.Clear() 26 | case "fpMul2": 27 | e.buffer.fpMul2.Clear() 28 | default: 29 | panic("unknown buffer name: " + name) 30 | } 31 | } 32 | 33 | // MulAddFourierPolyAssignBuffered performs fpOut += decompFFT[idx] * fp 34 | // using the pre-allocated decomposition buffer 35 | func (e *Evaluator) MulAddFourierPolyAssignBuffered(idx int, fp FourierPoly, bufferName string) { 36 | var fpOut *FourierPoly 37 | switch bufferName { 38 | case "fpAcc": 39 | fpOut = &e.buffer.fpAcc 40 | case "fpBcc": 41 | fpOut = &e.buffer.fpBcc 42 | default: 43 | panic("unknown buffer name: " + bufferName) 44 | } 45 | 46 | // Use the pre-computed FFT from decomposition buffer 47 | e.MulAddFourierPolyAssign(e.buffer.decompFFT[idx], fp, *fpOut) 48 | } 49 | 50 | // BufferToPolyAssign converts a buffer from frequency domain to time domain 51 | // and writes directly to the output slice (zero-allocation) 52 | func (e *Evaluator) BufferToPolyAssign(bufferName string, out []params.Torus) { 53 | var fp *FourierPoly 54 | switch bufferName { 55 | case "fpAcc": 56 | fp = &e.buffer.fpAcc 57 | case "fpBcc": 58 | fp = 
&e.buffer.fpBcc 59 | case "fpDiff": 60 | fp = &e.buffer.fpDiff 61 | default: 62 | panic("unknown buffer name: " + bufferName) 63 | } 64 | 65 | // Use unsafe conversion to avoid allocation 66 | pOut := Poly{Coeffs: out} 67 | e.ToPolyAssignUnsafe(*fp, pOut) 68 | } 69 | 70 | // ============================================================================ 71 | // DECOMPOSITION BUFFER OPERATIONS 72 | // ============================================================================ 73 | 74 | // GetDecompBuffer returns the i-th decomposition buffer for direct write 75 | func (e *Evaluator) GetDecompBuffer(i int) *Poly { 76 | if i >= len(e.buffer.decompBuffer) { 77 | panic("decomposition buffer index out of range") 78 | } 79 | return &e.buffer.decompBuffer[i] 80 | } 81 | 82 | // GetDecompFFTBuffer returns the i-th decomposition FFT buffer 83 | func (e *Evaluator) GetDecompFFTBuffer(i int) *FourierPoly { 84 | if i >= len(e.buffer.decompFFT) { 85 | panic("decomposition FFT buffer index out of range") 86 | } 87 | return &e.buffer.decompFFT[i] 88 | } 89 | 90 | // ToFourierPolyInBuffer transforms a poly to fourier and stores in buffer 91 | func (e *Evaluator) ToFourierPolyInBuffer(p Poly, bufferIdx int) { 92 | if bufferIdx >= len(e.buffer.decompFFT) { 93 | panic("buffer index out of range") 94 | } 95 | e.ToFourierPolyAssign(p, e.buffer.decompFFT[bufferIdx]) 96 | } 97 | 98 | // CopyToDecompBuffer copies a polynomial into the decomposition buffer 99 | func (e *Evaluator) CopyToDecompBuffer(src []params.Torus, bufferIdx int) { 100 | if bufferIdx >= len(e.buffer.decompBuffer) { 101 | panic("buffer index out of range") 102 | } 103 | copy(e.buffer.decompBuffer[bufferIdx].Coeffs, src) 104 | } 105 | 106 | // ============================================================================ 107 | // ROTATION POOL OPERATIONS 108 | // ============================================================================ 109 | 110 | // GetRotationBuffer returns a rotation buffer from the pool 111 | // Uses 
round-robin allocation to avoid conflicts 112 | func (e *Evaluator) GetRotationBuffer() []params.Torus { 113 | buf := e.buffer.rotationPool[e.buffer.rotationIdx].Coeffs 114 | e.buffer.rotationIdx = (e.buffer.rotationIdx + 1) % len(e.buffer.rotationPool) 115 | return buf 116 | } 117 | 118 | // ResetRotationPool resets the rotation buffer pool index 119 | // Call this at the start of a new operation to ensure clean state 120 | func (e *Evaluator) ResetRotationPool() { 121 | e.buffer.rotationIdx = 0 122 | } 123 | 124 | // PolyMulWithXK multiplies a polynomial by X^k using a pooled buffer (zero-allocation) 125 | func (e *Evaluator) PolyMulWithXK(a []params.Torus, k int) []params.Torus { 126 | result := e.GetRotationBuffer() 127 | PolyMulWithXKInPlace(a, k, result) 128 | return result 129 | } 130 | 131 | // PolyMulWithXKInPlace multiplies polynomial by X^k in the ring Z[X]/(X^N+1) 132 | // This is the core rotation operation used throughout TFHE 133 | func PolyMulWithXKInPlace(a []params.Torus, k int, result []params.Torus) { 134 | n := len(a) 135 | k = k % (2 * n) // Normalize k to [0, 2N) 136 | 137 | if k == 0 { 138 | copy(result, a) 139 | return 140 | } 141 | 142 | if k < 0 { 143 | k += 2 * n 144 | } 145 | 146 | if k < n { 147 | // Positive rotation: coefficients shift right, wrap with negation 148 | for i := 0; i < n-k; i++ { 149 | result[i+k] = a[i] 150 | } 151 | for i := n - k; i < n; i++ { 152 | result[i+k-n] = ^params.Torus(0) - a[i] 153 | } 154 | } else { 155 | // Rotation >= n: all coefficients get negated 156 | k -= n 157 | for i := 0; i < n-k; i++ { 158 | result[i+k] = ^params.Torus(0) - a[i] 159 | } 160 | for i := n - k; i < n; i++ { 161 | result[i+k-n] = a[i] 162 | } 163 | } 164 | } 165 | 166 | // PolyMulWithXKDirect multiplies by X^k and writes to provided buffer (zero-allocation) 167 | func (e *Evaluator) PolyMulWithXKDirect(a []params.Torus, k int, result []params.Torus) { 168 | PolyMulWithXKInPlace(a, k, result) 169 | } 170 | 171 | // 
============================================================================ 172 | // TRLWE POOL OPERATIONS 173 | // ============================================================================ 174 | 175 | // GetTRLWEBuffer returns a TRLWE buffer from the pool 176 | // Returns (A, B) slices that can be used to construct a TRLWE 177 | func (e *Evaluator) GetTRLWEBuffer() ([]params.Torus, []params.Torus) { 178 | buf := &e.buffer.trlwePool[e.buffer.trlweIdx] 179 | e.buffer.trlweIdx = (e.buffer.trlweIdx + 1) % len(e.buffer.trlwePool) 180 | return buf.A, buf.B 181 | } 182 | 183 | // ResetTRLWEPool resets the TRLWE pool index 184 | func (e *Evaluator) ResetTRLWEPool() { 185 | e.buffer.trlweIdx = 0 186 | } 187 | 188 | // ClearTRLWEBuffer clears a TRLWE buffer 189 | func (e *Evaluator) ClearTRLWEBuffer(a, b []params.Torus) { 190 | for i := range a { 191 | a[i] = 0 192 | b[i] = 0 193 | } 194 | } 195 | -------------------------------------------------------------------------------- /poly/decomposer.go: -------------------------------------------------------------------------------- 1 | package poly 2 | 3 | import "github.com/thedonutfactory/go-tfhe/params" 4 | 5 | // Decomposer performs gadget decomposition with pre-allocated buffers 6 | // This achieves zero-allocation decomposition operations 7 | type Decomposer struct { 8 | buffer decompositionBuffer 9 | } 10 | 11 | // decompositionBuffer contains pre-allocated buffers for decomposition 12 | type decompositionBuffer struct { 13 | // polyDecomposed is the pre-allocated buffer for polynomial decomposition 14 | polyDecomposed []Poly 15 | // polyFourierDecomposed is the pre-allocated buffer for Fourier-domain decomposition 16 | polyFourierDecomposed []FourierPoly 17 | } 18 | 19 | // NewDecomposer creates a new Decomposer with buffers for up to maxLevel decomposition levels 20 | func NewDecomposer(N int, maxLevel int) *Decomposer { 21 | polyDecomposed := make([]Poly, maxLevel) 22 | polyFourierDecomposed := make([]FourierPoly, 
maxLevel) 23 | 24 | for i := 0; i < maxLevel; i++ { 25 | polyDecomposed[i] = NewPoly(N) 26 | polyFourierDecomposed[i] = NewFourierPoly(N) 27 | } 28 | 29 | return &Decomposer{ 30 | buffer: decompositionBuffer{ 31 | polyDecomposed: polyDecomposed, 32 | polyFourierDecomposed: polyFourierDecomposed, 33 | }, 34 | } 35 | } 36 | 37 | // GetPolyDecomposedBuffer returns the decomposition buffer for polynomial 38 | func (d *Decomposer) GetPolyDecomposedBuffer(level int) []Poly { 39 | if level > len(d.buffer.polyDecomposed) { 40 | panic("decomposition level exceeds buffer size") 41 | } 42 | return d.buffer.polyDecomposed[:level] 43 | } 44 | 45 | // GetPolyFourierDecomposedBuffer returns the Fourier decomposition buffer 46 | func (d *Decomposer) GetPolyFourierDecomposedBuffer(level int) []FourierPoly { 47 | if level > len(d.buffer.polyFourierDecomposed) { 48 | panic("decomposition level exceeds buffer size") 49 | } 50 | return d.buffer.polyFourierDecomposed[:level] 51 | } 52 | 53 | // DecomposePolyAssign decomposes polynomial p into decomposedOut using gadget decomposition 54 | // This writes directly to the provided buffer (zero-allocation) 55 | func DecomposePolyAssign(p []params.Torus, bgbit, level int, offset params.Torus, decomposedOut []Poly) { 56 | n := len(p) 57 | mask := params.Torus((1 << bgbit) - 1) 58 | halfBG := params.Torus(1 << (bgbit - 1)) 59 | 60 | for j := 0; j < n; j++ { 61 | tmp := p[j] + offset 62 | for i := 0; i < level; i++ { 63 | decomposedOut[i].Coeffs[j] = ((tmp >> (32 - (uint32(i)+1)*uint32(bgbit))) & mask) - halfBG 64 | } 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /poly/fourier_ops.go: -------------------------------------------------------------------------------- 1 | package poly 2 | 3 | import "unsafe" 4 | 5 | // AddFourierPoly returns fp0 + fp1. 
6 | func (e *Evaluator) AddFourierPoly(fp0, fp1 FourierPoly) FourierPoly { 7 | fpOut := e.NewFourierPoly() 8 | e.AddFourierPolyAssign(fp0, fp1, fpOut) 9 | return fpOut 10 | } 11 | 12 | // AddFourierPolyAssign computes fpOut = fp0 + fp1. 13 | func (e *Evaluator) AddFourierPolyAssign(fp0, fp1, fpOut FourierPoly) { 14 | addCmplxAssign(fp0.Coeffs, fp1.Coeffs, fpOut.Coeffs) 15 | } 16 | 17 | // SubFourierPoly returns fp0 - fp1. 18 | func (e *Evaluator) SubFourierPoly(fp0, fp1 FourierPoly) FourierPoly { 19 | fpOut := e.NewFourierPoly() 20 | e.SubFourierPolyAssign(fp0, fp1, fpOut) 21 | return fpOut 22 | } 23 | 24 | // SubFourierPolyAssign computes fpOut = fp0 - fp1. 25 | func (e *Evaluator) SubFourierPolyAssign(fp0, fp1, fpOut FourierPoly) { 26 | subCmplxAssign(fp0.Coeffs, fp1.Coeffs, fpOut.Coeffs) 27 | } 28 | 29 | // MulFourierPoly returns fp0 * fp1. 30 | func (e *Evaluator) MulFourierPoly(fp0, fp1 FourierPoly) FourierPoly { 31 | fpOut := e.NewFourierPoly() 32 | e.MulFourierPolyAssign(fp0, fp1, fpOut) 33 | return fpOut 34 | } 35 | 36 | // MulFourierPolyAssign computes fpOut = fp0 * fp1. 37 | // This is element-wise complex multiplication in the frequency domain. 38 | func (e *Evaluator) MulFourierPolyAssign(fp0, fp1, fpOut FourierPoly) { 39 | elementWiseMulCmplxAssign(fp0.Coeffs, fp1.Coeffs, fpOut.Coeffs) 40 | } 41 | 42 | // MulAddFourierPolyAssign computes fpOut += fp0 * fp1. 43 | func (e *Evaluator) MulAddFourierPolyAssign(fp0, fp1, fpOut FourierPoly) { 44 | elementWiseMulAddCmplxAssign(fp0.Coeffs, fp1.Coeffs, fpOut.Coeffs) 45 | } 46 | 47 | // MulSubFourierPolyAssign computes fpOut -= fp0 * fp1. 48 | func (e *Evaluator) MulSubFourierPolyAssign(fp0, fp1, fpOut FourierPoly) { 49 | elementWiseMulSubCmplxAssign(fp0.Coeffs, fp1.Coeffs, fpOut.Coeffs) 50 | } 51 | 52 | // FloatMulFourierPolyAssign computes fpOut = c * fp0. 
53 | func (e *Evaluator) FloatMulFourierPolyAssign(fp0 FourierPoly, c float64, fpOut FourierPoly) { 54 | floatMulCmplxAssign(fp0.Coeffs, c, fpOut.Coeffs) 55 | } 56 | 57 | // FloatMulAddFourierPolyAssign computes fpOut += c * fp0. 58 | func (e *Evaluator) FloatMulAddFourierPolyAssign(fp0 FourierPoly, c float64, fpOut FourierPoly) { 59 | floatMulAddCmplxAssign(fp0.Coeffs, c, fpOut.Coeffs) 60 | } 61 | 62 | // addCmplxAssign computes vOut = v0 + v1. 63 | func addCmplxAssign(v0, v1, vOut []float64) { 64 | for i := 0; i < len(vOut); i += 8 { 65 | w0 := (*[8]float64)(unsafe.Pointer(&v0[i])) 66 | w1 := (*[8]float64)(unsafe.Pointer(&v1[i])) 67 | wOut := (*[8]float64)(unsafe.Pointer(&vOut[i])) 68 | 69 | wOut[0] = w0[0] + w1[0] 70 | wOut[1] = w0[1] + w1[1] 71 | wOut[2] = w0[2] + w1[2] 72 | wOut[3] = w0[3] + w1[3] 73 | 74 | wOut[4] = w0[4] + w1[4] 75 | wOut[5] = w0[5] + w1[5] 76 | wOut[6] = w0[6] + w1[6] 77 | wOut[7] = w0[7] + w1[7] 78 | } 79 | } 80 | 81 | // subCmplxAssign computes vOut = v0 - v1. 82 | func subCmplxAssign(v0, v1, vOut []float64) { 83 | for i := 0; i < len(vOut); i += 8 { 84 | w0 := (*[8]float64)(unsafe.Pointer(&v0[i])) 85 | w1 := (*[8]float64)(unsafe.Pointer(&v1[i])) 86 | wOut := (*[8]float64)(unsafe.Pointer(&vOut[i])) 87 | 88 | wOut[0] = w0[0] - w1[0] 89 | wOut[1] = w0[1] - w1[1] 90 | wOut[2] = w0[2] - w1[2] 91 | wOut[3] = w0[3] - w1[3] 92 | 93 | wOut[4] = w0[4] - w1[4] 94 | wOut[5] = w0[5] - w1[5] 95 | wOut[6] = w0[6] - w1[6] 96 | wOut[7] = w0[7] - w1[7] 97 | } 98 | } 99 | 100 | // floatMulCmplxAssign computes vOut = c * v0. 
func floatMulCmplxAssign(v0 []float64, c float64, vOut []float64) {
	for off := 0; off < len(vOut); off += 8 {
		src := (*[8]float64)(unsafe.Pointer(&v0[off]))
		dst := (*[8]float64)(unsafe.Pointer(&vOut[off]))

		// A real factor scales the real and imaginary parts alike, so all
		// eight lanes are treated identically.
		dst[0], dst[1], dst[2], dst[3] = c*src[0], c*src[1], c*src[2], c*src[3]
		dst[4], dst[5], dst[6], dst[7] = c*src[4], c*src[5], c*src[6], c*src[7]
	}
}

// floatMulAddCmplxAssign accumulates the real multiple c * v0 into vOut.
func floatMulAddCmplxAssign(v0 []float64, c float64, vOut []float64) {
	for off := 0; off < len(vOut); off += 8 {
		src := (*[8]float64)(unsafe.Pointer(&v0[off]))
		dst := (*[8]float64)(unsafe.Pointer(&vOut[off]))

		// Real lanes.
		dst[0] += c * src[0]
		dst[1] += c * src[1]
		dst[2] += c * src[2]
		dst[3] += c * src[3]

		// Imaginary lanes.
		dst[4] += c * src[4]
		dst[5] += c * src[5]
		dst[6] += c * src[6]
		dst[7] += c * src[7]
	}
}

// elementWiseMulCmplxAssign stores the pointwise complex product v0 * v1 in
// vOut; this is the frequency-domain core of polynomial multiplication.
func elementWiseMulCmplxAssign(v0, v1, vOut []float64) {
	for off := 0; off < len(vOut); off += 8 {
		x := (*[8]float64)(unsafe.Pointer(&v0[off]))
		y := (*[8]float64)(unsafe.Pointer(&v1[off]))
		z := (*[8]float64)(unsafe.Pointer(&vOut[off]))

		// Lane j keeps its real part at index j and its imaginary part at
		// j+4: (a+bi)(c+di) = (ac-bd) + (ad+bc)i. Each lane's result is
		// staged in re/im before it is stored, so vOut may be the same slice
		// as v0 or v1.
		re, im := x[0]*y[0]-x[4]*y[4], x[0]*y[4]+x[4]*y[0]
		z[0], z[4] = re, im

		re, im = x[1]*y[1]-x[5]*y[5], x[1]*y[5]+x[5]*y[1]
		z[1], z[5] = re, im

		re, im = x[2]*y[2]-x[6]*y[6], x[2]*y[6]+x[6]*y[2]
		z[2], z[6] = re, im

		re, im = x[3]*y[3]-x[7]*y[7], x[3]*y[7]+x[7]*y[3]
		z[3], z[7] = re, im
	}
}

// elementWiseMulAddCmplxAssign accumulates the pointwise complex product
// v0 * v1 into vOut.
func elementWiseMulAddCmplxAssign(v0, v1, vOut []float64) {
	for off := 0; off < len(vOut); off += 8 {
		x := (*[8]float64)(unsafe.Pointer(&v0[off]))
		y := (*[8]float64)(unsafe.Pointer(&v1[off]))
		z := (*[8]float64)(unsafe.Pointer(&vOut[off]))

		re, im := z[0]+(x[0]*y[0]-x[4]*y[4]), z[4]+(x[0]*y[4]+x[4]*y[0])
		z[0], z[4] = re, im

		re, im = z[1]+(x[1]*y[1]-x[5]*y[5]), z[5]+(x[1]*y[5]+x[5]*y[1])
		z[1], z[5] = re, im

		re, im = z[2]+(x[2]*y[2]-x[6]*y[6]), z[6]+(x[2]*y[6]+x[6]*y[2])
		z[2], z[6] = re, im

		re, im = z[3]+(x[3]*y[3]-x[7]*y[7]), z[7]+(x[3]*y[7]+x[7]*y[3])
		z[3], z[7] = re, im
	}
}

// elementWiseMulSubCmplxAssign subtracts the pointwise complex product v0 * v1 from vOut.
194 | func elementWiseMulSubCmplxAssign(v0, v1, vOut []float64) { 195 | var vOutR, vOutI float64 196 | 197 | for i := 0; i < len(vOut); i += 8 { 198 | w0 := (*[8]float64)(unsafe.Pointer(&v0[i])) 199 | w1 := (*[8]float64)(unsafe.Pointer(&v1[i])) 200 | wOut := (*[8]float64)(unsafe.Pointer(&vOut[i])) 201 | 202 | vOutR = wOut[0] - (w0[0]*w1[0] - w0[4]*w1[4]) 203 | vOutI = wOut[4] - (w0[0]*w1[4] + w0[4]*w1[0]) 204 | wOut[0], wOut[4] = vOutR, vOutI 205 | 206 | vOutR = wOut[1] - (w0[1]*w1[1] - w0[5]*w1[5]) 207 | vOutI = wOut[5] - (w0[1]*w1[5] + w0[5]*w1[1]) 208 | wOut[1], wOut[5] = vOutR, vOutI 209 | 210 | vOutR = wOut[2] - (w0[2]*w1[2] - w0[6]*w1[6]) 211 | vOutI = wOut[6] - (w0[2]*w1[6] + w0[6]*w1[2]) 212 | wOut[2], wOut[6] = vOutR, vOutI 213 | 214 | vOutR = wOut[3] - (w0[3]*w1[3] - w0[7]*w1[7]) 215 | vOutI = wOut[7] - (w0[3]*w1[7] + w0[7]*w1[3]) 216 | wOut[3], wOut[7] = vOutR, vOutI 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /poly/poly.go: -------------------------------------------------------------------------------- 1 | // Package poly implements optimized polynomial operations for TFHE. 2 | // Based on the high-performance implementation from tfhe-go. 3 | package poly 4 | 5 | import ( 6 | "github.com/thedonutfactory/go-tfhe/params" 7 | ) 8 | 9 | const ( 10 | // MinDegree is the minimum degree of polynomial that Evaluator can handle. 11 | // Set to 2^4 because SIMD operations handle 4 values at a time. 12 | MinDegree = 1 << 4 13 | 14 | // splitLogBound denotes the maximum bits for polynomial multiplication. 15 | // This ensures failure rate less than 2^-284. 16 | splitLogBound = 48 17 | ) 18 | 19 | // Poly is a polynomial over Z_Q[X]/(X^N + 1). 20 | type Poly struct { 21 | Coeffs []params.Torus 22 | } 23 | 24 | // NewPoly creates a polynomial with degree N. 
25 | func NewPoly(N int) Poly { 26 | if !isPowerOfTwo(N) { 27 | panic("degree not power of two") 28 | } 29 | if N < MinDegree { 30 | panic("degree smaller than MinDegree") 31 | } 32 | return Poly{Coeffs: make([]params.Torus, N)} 33 | } 34 | 35 | // Degree returns the degree of the polynomial. 36 | func (p Poly) Degree() int { 37 | return len(p.Coeffs) 38 | } 39 | 40 | // Copy returns a copy of the polynomial. 41 | func (p Poly) Copy() Poly { 42 | coeffsCopy := make([]params.Torus, len(p.Coeffs)) 43 | copy(coeffsCopy, p.Coeffs) 44 | return Poly{Coeffs: coeffsCopy} 45 | } 46 | 47 | // Clear clears all coefficients to zero. 48 | func (p Poly) Clear() { 49 | for i := range p.Coeffs { 50 | p.Coeffs[i] = 0 51 | } 52 | } 53 | 54 | // FourierPoly is a fourier transformed polynomial over C[X]/(X^N/2 + 1). 55 | // This corresponds to a polynomial over Z_Q[X]/(X^N + 1). 56 | // 57 | // Coeffs are represented as float-4 complex vector for efficient computation: 58 | // [(r0, r1, r2, r3), (i0, i1, i2, i3), ...] 59 | // instead of standard [(r0, i0), (r1, i1), (r2, i2), (r3, i3), ...] 60 | type FourierPoly struct { 61 | Coeffs []float64 62 | } 63 | 64 | // NewFourierPoly creates a fourier polynomial with degree N. 65 | func NewFourierPoly(N int) FourierPoly { 66 | if !isPowerOfTwo(N) { 67 | panic("degree not power of two") 68 | } 69 | if N < MinDegree { 70 | panic("degree smaller than MinDegree") 71 | } 72 | return FourierPoly{Coeffs: make([]float64, N)} 73 | } 74 | 75 | // Degree returns the degree of the polynomial. 76 | func (p FourierPoly) Degree() int { 77 | return len(p.Coeffs) 78 | } 79 | 80 | // Copy returns a copy of the polynomial. 81 | func (p FourierPoly) Copy() FourierPoly { 82 | coeffsCopy := make([]float64, len(p.Coeffs)) 83 | copy(coeffsCopy, p.Coeffs) 84 | return FourierPoly{Coeffs: coeffsCopy} 85 | } 86 | 87 | // CopyFrom copies p0 to p. 
88 | func (p *FourierPoly) CopyFrom(p0 FourierPoly) { 89 | copy(p.Coeffs, p0.Coeffs) 90 | } 91 | 92 | // Clear clears all coefficients to zero. 93 | func (p FourierPoly) Clear() { 94 | for i := range p.Coeffs { 95 | p.Coeffs[i] = 0 96 | } 97 | } 98 | 99 | // isPowerOfTwo checks if n is a power of two. 100 | func isPowerOfTwo(n int) bool { 101 | return n > 0 && (n&(n-1)) == 0 102 | } 103 | 104 | // log2 returns the base-2 logarithm of n. 105 | func log2(n int) int { 106 | if n <= 0 { 107 | panic("log2 of non-positive number") 108 | } 109 | log := 0 110 | for n > 1 { 111 | n >>= 1 112 | log++ 113 | } 114 | return log 115 | } 116 | -------------------------------------------------------------------------------- /poly/poly_evaluator.go: -------------------------------------------------------------------------------- 1 | package poly 2 | 3 | import ( 4 | "math" 5 | "math/cmplx" 6 | 7 | "github.com/thedonutfactory/go-tfhe/params" 8 | ) 9 | 10 | // Evaluator computes polynomial operations over the N-th cyclotomic ring. 11 | // This is optimized for TFHE operations with precomputed twiddle factors. 12 | type Evaluator struct { 13 | // degree is the degree of polynomial that this evaluator can handle. 14 | degree int 15 | // q is a float64 value of the modulus (2^32 for Torus). 16 | q float64 17 | 18 | // tw is the twiddle factors for fourier transform. 19 | tw []complex128 20 | // twInv is the twiddle factors for inverse fourier transform. 21 | twInv []complex128 22 | // twMono is the twiddle factors for monomial fourier transform. 23 | twMono []complex128 24 | // twMonoIdx is the precomputed bit-reversed index for monomial fourier transform. 25 | twMonoIdx []int 26 | 27 | buffer evaluationBuffer 28 | } 29 | 30 | // evaluationBuffer is a buffer for Evaluator. 31 | // These buffers are pre-allocated and reused across operations to achieve zero-allocation performance. 
32 | // 33 | // This is the UNIFIED buffer system that consolidates all buffer management: 34 | // - FFT/IFFT working buffers 35 | // - Decomposition buffers (time and Fourier domain) 36 | // - Multiplication accumulators 37 | // - Rotation pools 38 | // - TRLWE pools 39 | // - Temporary buffers 40 | type evaluationBuffer struct { 41 | // === Core FFT Buffers === 42 | fp FourierPoly // Intermediate FFT buffer 43 | fpInv FourierPoly // Intermediate inverse FFT buffer 44 | pSplit Poly // Buffer for split operations 45 | 46 | // === External Product / Multiplication Buffers === 47 | fpMul1, fpMul2 FourierPoly // For multiplication operands 48 | fpAcc, fpBcc FourierPoly // For accumulation (A and B components) 49 | 50 | // === Decomposition Buffers === 51 | decompBuffer []Poly // Pool of decomposition results (time domain) 52 | decompFFT []FourierPoly // FFT'd decomposition results (Fourier domain) 53 | 54 | // === CMUX Buffers === 55 | fpDiff FourierPoly // For CMUX difference computation 56 | 57 | // === Temporary Buffers === 58 | pTemp Poly // General purpose temporary polynomial 59 | pRotA, pRotB Poly // Rotation results 60 | 61 | // === Rotation Pool === 62 | // Pool for polyMulWithXK operations (zero-allocation rotation) 63 | rotationPool [4]Poly // Pool of 4 rotation buffers 64 | rotationIdx int // Current rotation buffer index 65 | 66 | // === TRLWE Pool === 67 | // Pool for intermediate TRLWE results 68 | trlwePool [4]struct { 69 | A []params.Torus 70 | B []params.Torus 71 | } 72 | trlweIdx int // Current TRLWE pool index 73 | } 74 | 75 | // NewEvaluator creates a new Evaluator with degree N. 
76 | func NewEvaluator(N int) *Evaluator { 77 | if !isPowerOfTwo(N) { 78 | panic("degree not power of two") 79 | } 80 | if N < MinDegree { 81 | panic("degree smaller than MinDegree") 82 | } 83 | 84 | // Q = 2^32 for Torus (uint32) 85 | Q := math.Exp2(32) 86 | 87 | tw, twInv := genTwiddleFactors(N / 2) 88 | 89 | twMono := make([]complex128, 2*N) 90 | for i := 0; i < 2*N; i++ { 91 | e := -math.Pi * float64(i) / float64(N) 92 | twMono[i] = cmplx.Exp(complex(0, e)) 93 | } 94 | 95 | twMonoIdx := make([]int, N/2) 96 | twMonoIdx[0] = 2*N - 1 97 | for i := 1; i < N/2; i++ { 98 | twMonoIdx[i] = 4*i - 1 99 | } 100 | bitReverseInPlace(twMonoIdx) 101 | 102 | return &Evaluator{ 103 | degree: N, 104 | q: Q, 105 | tw: tw, 106 | twInv: twInv, 107 | twMono: twMono, 108 | twMonoIdx: twMonoIdx, 109 | buffer: newEvaluationBuffer(N), 110 | } 111 | } 112 | 113 | // genTwiddleFactors generates twiddle factors for FFT. 114 | func genTwiddleFactors(N int) (tw, twInv []complex128) { 115 | twFFT := make([]complex128, N/2) 116 | twInvFFT := make([]complex128, N/2) 117 | for i := 0; i < N/2; i++ { 118 | e := -2 * math.Pi * float64(i) / float64(N) 119 | twFFT[i] = cmplx.Exp(complex(0, e)) 120 | twInvFFT[i] = cmplx.Exp(-complex(0, e)) 121 | } 122 | bitReverseInPlace(twFFT) 123 | bitReverseInPlace(twInvFFT) 124 | 125 | tw = make([]complex128, 0, N-1) 126 | twInv = make([]complex128, 0, N-1) 127 | 128 | for m, t := 1, N/2; m <= N/2; m, t = m<<1, t>>1 { 129 | twFold := cmplx.Exp(complex(0, 2*math.Pi*float64(t)/float64(4*N))) 130 | for i := 0; i < m; i++ { 131 | tw = append(tw, twFFT[i]*twFold) 132 | } 133 | } 134 | 135 | for m, t := N/2, 1; m >= 1; m, t = m>>1, t<<1 { 136 | twInvFold := cmplx.Exp(complex(0, -2*math.Pi*float64(t)/float64(4*N))) 137 | for i := 0; i < m; i++ { 138 | twInv = append(twInv, twInvFFT[i]*twInvFold) 139 | } 140 | } 141 | 142 | return tw, twInv 143 | } 144 | 145 | // bitReverseInPlace performs bit reversal permutation in place. 
146 | func bitReverseInPlace[T any](data []T) { 147 | n := len(data) 148 | if n <= 1 { 149 | return 150 | } 151 | 152 | j := 0 153 | for i := 0; i < n; i++ { 154 | if i < j { 155 | data[i], data[j] = data[j], data[i] 156 | } 157 | // Bit reversal 158 | m := n >> 1 159 | for m > 0 && j >= m { 160 | j -= m 161 | m >>= 1 162 | } 163 | j += m 164 | } 165 | } 166 | 167 | // newEvaluationBuffer creates a new evaluationBuffer. 168 | func newEvaluationBuffer(N int) evaluationBuffer { 169 | // Pre-allocate decomposition buffers for typical TFHE parameters 170 | // L=3, so we need 3*2=6 decomposition levels 171 | const maxDecompLevels = 8 // Slightly more for safety 172 | 173 | decompBuffer := make([]Poly, maxDecompLevels) 174 | decompFFT := make([]FourierPoly, maxDecompLevels) 175 | for i := 0; i < maxDecompLevels; i++ { 176 | decompBuffer[i] = NewPoly(N) 177 | decompFFT[i] = NewFourierPoly(N) 178 | } 179 | 180 | // Initialize rotation pool 181 | var rotationPool [4]Poly 182 | for i := 0; i < 4; i++ { 183 | rotationPool[i] = NewPoly(N) 184 | } 185 | 186 | // Initialize TRLWE pool 187 | var trlwePool [4]struct { 188 | A []params.Torus 189 | B []params.Torus 190 | } 191 | for i := 0; i < 4; i++ { 192 | trlwePool[i].A = make([]params.Torus, N) 193 | trlwePool[i].B = make([]params.Torus, N) 194 | } 195 | 196 | return evaluationBuffer{ 197 | fp: NewFourierPoly(N), 198 | fpInv: NewFourierPoly(N), 199 | pSplit: NewPoly(N), 200 | 201 | // External product buffers 202 | fpMul1: NewFourierPoly(N), 203 | fpMul2: NewFourierPoly(N), 204 | fpAcc: NewFourierPoly(N), 205 | fpBcc: NewFourierPoly(N), 206 | 207 | // Decomposition buffers 208 | decompBuffer: decompBuffer, 209 | decompFFT: decompFFT, 210 | 211 | // CMUX buffers 212 | fpDiff: NewFourierPoly(N), 213 | pTemp: NewPoly(N), 214 | 215 | // Blind rotation buffers 216 | pRotA: NewPoly(N), 217 | pRotB: NewPoly(N), 218 | rotationPool: rotationPool, 219 | rotationIdx: 0, 220 | trlwePool: trlwePool, 221 | trlweIdx: 0, 222 | } 223 | } 224 | 
225 | // Degree returns the degree of polynomial that the evaluator can handle. 226 | func (e *Evaluator) Degree() int { 227 | return e.degree 228 | } 229 | 230 | // NewPoly creates a new polynomial with the same degree as the evaluator. 231 | func (e *Evaluator) NewPoly() Poly { 232 | return NewPoly(e.degree) 233 | } 234 | 235 | // NewFourierPoly creates a new fourier polynomial with the same degree as the evaluator. 236 | func (e *Evaluator) NewFourierPoly() FourierPoly { 237 | return NewFourierPoly(e.degree) 238 | } 239 | 240 | // ShallowCopy returns a shallow copy of this Evaluator. 241 | // Returned Evaluator is safe for concurrent use. 242 | func (e *Evaluator) ShallowCopy() *Evaluator { 243 | return &Evaluator{ 244 | degree: e.degree, 245 | q: e.q, 246 | tw: e.tw, 247 | twInv: e.twInv, 248 | twMono: e.twMono, 249 | twMonoIdx: e.twMonoIdx, 250 | buffer: newEvaluationBuffer(e.degree), 251 | } 252 | } 253 | -------------------------------------------------------------------------------- /poly/poly_mul.go: -------------------------------------------------------------------------------- 1 | package poly 2 | 3 | // MulPoly returns p0 * p1. 4 | func (e *Evaluator) MulPoly(p0, p1 Poly) Poly { 5 | pOut := e.NewPoly() 6 | e.MulPolyAssign(p0, p1, pOut) 7 | return pOut 8 | } 9 | 10 | // MulPolyAssign computes pOut = p0 * p1. 11 | // This uses FFT-based multiplication for efficiency. 12 | func (e *Evaluator) MulPolyAssign(p0, p1, pOut Poly) { 13 | // Transform both polynomials to frequency domain 14 | fp0 := e.ToFourierPoly(p0) 15 | fp1 := e.ToFourierPoly(p1) 16 | 17 | // Multiply in frequency domain (element-wise complex multiplication) 18 | e.MulFourierPolyAssign(fp0, fp1, fp0) 19 | 20 | // Transform back to time domain 21 | e.ToPolyAssignUnsafe(fp0, pOut) 22 | } 23 | 24 | // MulAddPolyAssign computes pOut += p0 * p1. 
//
// Same Fourier-domain pipeline as MulPolyAssign: ToFourierPoly on both
// operands, then a pointwise MulFourierPolyAssign into fp0. The product is
// then added into pOut via ToPolyAddAssignUnsafe instead of overwriting it.
func (e *Evaluator) MulAddPolyAssign(p0, p1, pOut Poly) {
	fp0 := e.ToFourierPoly(p0)
	fp1 := e.ToFourierPoly(p1)
	e.MulFourierPolyAssign(fp0, fp1, fp0)
	e.ToPolyAddAssignUnsafe(fp0, pOut)
}

// MulSubPolyAssign computes pOut -= p0 * p1.
//
// Same pipeline as MulAddPolyAssign, except that the product is subtracted
// from pOut via ToPolySubAssignUnsafe.
func (e *Evaluator) MulSubPolyAssign(p0, p1, pOut Poly) {
	fp0 := e.ToFourierPoly(p0)
	fp1 := e.ToFourierPoly(p1)
	e.MulFourierPolyAssign(fp0, fp1, fp0)
	e.ToPolySubAssignUnsafe(fp0, pOut)
}

--------------------------------------------------------------------------------
/poly/poly_test.go:
--------------------------------------------------------------------------------
package poly

import (
	"testing"

	"github.com/thedonutfactory/go-tfhe/params"
)

// TestFFTRoundTrip tests that FFT -> IFFT gives back the original polynomial
func TestFFTRoundTrip(t *testing.T) {
	eval := NewEvaluator(1024)

	// Create a test polynomial
	p := eval.NewPoly()
	for i := range p.Coeffs {
		p.Coeffs[i] = params.Torus(i * 12345)
	}

	// Transform to frequency domain and back
	fp := eval.ToFourierPoly(p)
	pOut := eval.ToPoly(fp)

	// Check if we got the original back (with some tolerance for floating point errors)
	for i := range p.Coeffs {
		// NOTE(review): this difference is not wrap-aware. If coefficient 0
		// (value 0) comes back one unit low it reads as 2^32-1, the diff is about
		// 4e9, and the test fails spuriously.
		// int64(int32(pOut.Coeffs[i]-p.Coeffs[i])) would measure the distance on
		// the torus, assuming Torus is uint32.
		diff := int64(pOut.Coeffs[i]) - int64(p.Coeffs[i])
		if diff < 0 {
			diff = -diff
		}
		if diff > 10 { // Allow small error due to floating point rounding
			t.Errorf("Coefficient %d: got %d, want %d (diff %d)", i, pOut.Coeffs[i], p.Coeffs[i], diff)
		}
	}
}

// TestPolyMul tests polynomial multiplication
//
// NOTE(review): as written this only checks that MulPoly returns non-nil
// coefficients. deg(p1*p2) = 2 < 1024, so no negacyclic wrap-around occurs.
// Coefficients 0..2 could therefore be asserted against 1000, 4000 and 4000,
// within a small rounding tolerance (TODO confirm the FFT error bound).
func TestPolyMul(t *testing.T) {
	eval := NewEvaluator(1024)

	// Create two simple test polynomials
	p1 := eval.NewPoly()
	p2 := eval.NewPoly()

	p1.Coeffs[0] = 100
	p1.Coeffs[1] = 200

	p2.Coeffs[0] = 10
	p2.Coeffs[1] = 20

	// Multiply
	pOut := eval.MulPoly(p1, p2)

	// Expected result for first few coefficients:
	// (100 + 200*X) * (10 + 20*X) = 1000 + 2000*X + 2000*X + 4000*X^2
	// = 1000 + 4000*X + 4000*X^2

	// Due to negacyclic ring, we need to check this works correctly
	// For now, just verify the function runs without panic
	if pOut.Coeffs == nil {
		t.Error("MulPoly returned nil coefficients")
	}
}

// BenchmarkFFT benchmarks the FFT operation
func BenchmarkFFT(b *testing.B) {
	eval := NewEvaluator(1024)
	p := eval.NewPoly()
	for i := range p.Coeffs {
		p.Coeffs[i] = params.Torus(i)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = eval.ToFourierPoly(p)
	}
}

// BenchmarkIFFT benchmarks the inverse FFT operation
func BenchmarkIFFT(b *testing.B) {
	eval := NewEvaluator(1024)
	p := eval.NewPoly()
	for i := range p.Coeffs {
		p.Coeffs[i] = params.Torus(i)
	}
	fp := eval.ToFourierPoly(p)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = eval.ToPoly(fp)
	}
}

// BenchmarkPolyMul benchmarks polynomial multiplication
func BenchmarkPolyMul(b *testing.B) {
	eval := NewEvaluator(1024)
	p1 := eval.NewPoly()
	p2 := eval.NewPoly()
	for i := range p1.Coeffs {
		p1.Coeffs[i] = params.Torus(i)
		p2.Coeffs[i] = params.Torus(i * 2)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = eval.MulPoly(p1, p2)
	}
}

// BenchmarkElementWiseMul benchmarks element-wise multiplication in frequency domain
//
// NOTE(review): fp1 is both an input and the output, so it is multiplied by
// fp2 again on every iteration. Its magnitude presumably grows until it
// reaches +/-Inf or NaN for large b.N, after which the benchmark times
// arithmetic on non-finite values. Writing into a separate output would keep
// the inputs fixed.
func BenchmarkElementWiseMul(b *testing.B) {
	eval := NewEvaluator(1024)
	p1 := eval.NewPoly()
	p2 := eval.NewPoly()
	for i := range p1.Coeffs {
		p1.Coeffs[i] = params.Torus(i)
		p2.Coeffs[i] = params.Torus(i * 2)
	}
	fp1 := eval.ToFourierPoly(p1)
	fp2 := eval.ToFourierPoly(p2)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		eval.MulFourierPolyAssign(fp1, fp2, fp1)
	}
}

--------------------------------------------------------------------------------
/proxyreenc/proxyreenc_test.go:
--------------------------------------------------------------------------------
package proxyreenc

import (
	"testing"

	"github.com/thedonutfactory/go-tfhe/key"
	"github.com/thedonutfactory/go-tfhe/params"
	"github.com/thedonutfactory/go-tfhe/tlwe"
)

func TestPublicKeyEncryption(t *testing.T) {
	secretKey := key.NewSecretKey()
	publicKey := NewPublicKeyLv0(secretKey.KeyLv0)

	// Test encrypting with public key and decrypting with secret key
	messages := []bool{true, false}
	for _, message := range messages {
		ct := publicKey.EncryptBool(message, params.GetTLWELv0().ALPHA)
		decrypted := ct.DecryptBool(secretKey.KeyLv0)

		if decrypted != message {
			t.Errorf("Public key encryption failed: got %v, want %v", decrypted, message)
		}
	}
}

func TestPublicKeyEncryptionMultiple(t *testing.T) {
	secretKey := key.NewSecretKey()
	publicKey := NewPublicKeyLv0(secretKey.KeyLv0)

	correct := 0
	iterations := 100

	for i := 0; i < iterations; i++ {
		message := (i % 2) == 0
		ct := publicKey.EncryptBool(message, params.GetTLWELv0().ALPHA)
		if ct.DecryptBool(secretKey.KeyLv0) == message {
			correct++
		}
	}

	accuracy := float64(correct) / float64(iterations)
	if accuracy < 0.95 {
		t.Errorf("Public key encryption accuracy too low: %.2f%%", accuracy*100)
	}
}

func TestProxyReencryptionAsymmetric(t *testing.T) {
	aliceKey := key.NewSecretKey()
	bobKey := key.NewSecretKey()

	// Bob publishes his public key
	bobPublicKey := NewPublicKeyLv0(bobKey.KeyLv0)

	// Alice generates reencryption key using Bob's PUBLIC key
	reencKey := NewProxyReencryptionKeyAsymmetric(aliceKey.KeyLv0, bobPublicKey)

	// Test both true and false
	messages := []bool{true, false}
	for _, message := range messages {
		aliceCt := tlwe.NewTLWELv0()
		aliceCt.EncryptBool(message, params.GetTLWELv0().ALPHA, aliceKey.KeyLv0)

		// Verify Alice can decrypt
		if aliceCt.DecryptBool(aliceKey.KeyLv0) != message {
			t.Errorf("Alice encryption failed for message %v", message)
		}

		// Reencrypt to Bob's key
		bobCt := ReencryptTLWELv0(aliceCt, reencKey)

		// Verify Bob can decrypt
		decrypted := bobCt.DecryptBool(bobKey.KeyLv0)
		if decrypted != message {
			t.Errorf("Asymmetric proxy reencryption failed: got %v, want %v", decrypted, message)
		}
	}
}

func TestProxyReencryptionSymmetric(t *testing.T) {
	aliceKey := key.NewSecretKey()
	bobKey := key.NewSecretKey()

	// Symmetric mode - requires both secret keys
	reencKey := NewProxyReencryptionKeySymmetric(aliceKey.KeyLv0, bobKey.KeyLv0)

	// Test both true and false
	messages := []bool{true, false}
	for _, message := range messages {
		aliceCt := tlwe.NewTLWELv0()
		aliceCt.EncryptBool(message, params.GetTLWELv0().ALPHA, aliceKey.KeyLv0)

		// Verify Alice can decrypt
		if aliceCt.DecryptBool(aliceKey.KeyLv0) != message {
			t.Errorf("Alice encryption failed for message %v", message)
		}

		// Reencrypt to Bob's key
		bobCt := ReencryptTLWELv0(aliceCt, reencKey)

		// Verify Bob can decrypt
		decrypted := bobCt.DecryptBool(bobKey.KeyLv0)
		if decrypted != message {
			t.Errorf("Symmetric proxy reencryption failed: got %v, want %v", decrypted, message)
		}
	}
}

func TestProxyReencryptionAsymmetricMultiple(t *testing.T) {
	aliceKey := key.NewSecretKey()
	bobKey := key.NewSecretKey()
	bobPublicKey := NewPublicKeyLv0(bobKey.KeyLv0)

	reencKey := NewProxyReencryptionKeyAsymmetric(aliceKey.KeyLv0, bobPublicKey)

	correct := 0
	iterations := 100

	for i := 0; i < iterations; i++ {
		message := (i % 2) == 0

		aliceCt := tlwe.NewTLWELv0()
		aliceCt.EncryptBool(message, params.GetTLWELv0().ALPHA, aliceKey.KeyLv0)

		bobCt := ReencryptTLWELv0(aliceCt, reencKey)

		if bobCt.DecryptBool(bobKey.KeyLv0) == message {
			correct++
		}
	}

	accuracy := float64(correct) / float64(iterations)
	if accuracy < 0.90 {
		t.Errorf("Asymmetric proxy reencryption accuracy too low: %.2f%%", accuracy*100)
	}

	t.Logf("Asymmetric proxy reencryption accuracy: %d/%d (%.1f%%)", correct, iterations, accuracy*100)
}

func TestProxyReencryptionChainAsymmetric(t *testing.T) {
	aliceKey := key.NewSecretKey()
	bobKey := key.NewSecretKey()
	carolKey := key.NewSecretKey()

	bobPublic := NewPublicKeyLv0(bobKey.KeyLv0)
	carolPublic := NewPublicKeyLv0(carolKey.KeyLv0)

	reencKeyAB := NewProxyReencryptionKeyAsymmetric(aliceKey.KeyLv0, bobPublic)
	reencKeyBC := NewProxyReencryptionKeyAsymmetric(bobKey.KeyLv0, carolPublic)

	message := true

	aliceCt := tlwe.NewTLWELv0()
	aliceCt.EncryptBool(message, params.GetTLWELv0().ALPHA, aliceKey.KeyLv0)

	// Alice -> Bob
	bobCt := ReencryptTLWELv0(aliceCt, reencKeyAB)
	if bobCt.DecryptBool(bobKey.KeyLv0) != message {
		t.Errorf("Alice -> Bob reencryption failed")
	}

	// Bob -> Carol
	carolCt := ReencryptTLWELv0(bobCt, reencKeyBC)
	if carolCt.DecryptBool(carolKey.KeyLv0) != message {
		t.Errorf("Bob -> Carol reencryption failed")
	}
}

func TestProxyReencryptionKeyGeneration(t *testing.T) {
	aliceKey := key.NewSecretKey()
	bobKey := key.NewSecretKey()

	reencKey := NewProxyReencryptionKeySymmetric(aliceKey.KeyLv0, bobKey.KeyLv0)

	// Verify the key has the right size
	expectedSize := reencKey.Base * reencKey.T * params.GetTLWELv0().N
	if len(reencKey.KeyEncryptions) != expectedSize {
		t.Errorf("Key size mismatch: got %d, want %d", len(reencKey.KeyEncryptions), expectedSize)
	}

	// Verify structure
	expectedBase := 1 << params.GetTRGSWLv1().BASEBIT
	if reencKey.Base != expectedBase {
		t.Errorf("Base mismatch: got %d, want %d", reencKey.Base, expectedBase)
	}

	if reencKey.T != params.GetTRGSWLv1().IKS_T {
		t.Errorf("T mismatch: got %d, want %d", reencKey.T, params.GetTRGSWLv1().IKS_T)
	}
}

// Benchmark asymmetric key generation
func BenchmarkAsymmetricKeyGeneration(b *testing.B) {
	aliceKey := key.NewSecretKey()
	bobKey := key.NewSecretKey()
	bobPublicKey := NewPublicKeyLv0(bobKey.KeyLv0)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = NewProxyReencryptionKeyAsymmetric(aliceKey.KeyLv0, bobPublicKey)
	}
}

// Benchmark symmetric key generation
func BenchmarkSymmetricKeyGeneration(b *testing.B) {
	aliceKey := key.NewSecretKey()
	bobKey := key.NewSecretKey()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = NewProxyReencryptionKeySymmetric(aliceKey.KeyLv0, bobKey.KeyLv0)
	}
}

// Benchmark reencryption operation
func BenchmarkReencryption(b *testing.B) {
	aliceKey := key.NewSecretKey()
	bobKey := key.NewSecretKey()

	reencKey := NewProxyReencryptionKeySymmetric(aliceKey.KeyLv0, bobKey.KeyLv0)

	aliceCt := tlwe.NewTLWELv0()
	aliceCt.EncryptBool(true, params.GetTLWELv0().ALPHA, aliceKey.KeyLv0)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = ReencryptTLWELv0(aliceCt, reencKey)
	}
}

// Benchmark public key generation
func BenchmarkPublicKeyGeneration(b *testing.B) {
	secretKey := key.NewSecretKey()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = NewPublicKeyLv0(secretKey.KeyLv0)
	}
}


--------------------------------------------------------------------------------
/tlwe/programmable_encrypt.go:
--------------------------------------------------------------------------------
package tlwe

import (
	"github.com/thedonutfactory/go-tfhe/params"
)

// EncryptLWEMessage encrypts an integer message using general message encoding
// This is different from EncryptBool which uses ±1/8 binary encoding.
//
// For programmable bootstrapping, use this function to match the LUT encoding.
// Encoding: message → message * scale, where scale = 2^31 / messageModulus
//
// The message is first reduced into [0, messageModulus). Because scale uses
// 2^31 rather than 2^32, the encoded value always lies in [0, 1/2) on the
// torus. messageModulus must be positive, since `message % 0` panics.
func (t *TLWELv0) EncryptLWEMessage(message int, messageModulus int, alpha float64, key []params.Torus) *TLWELv0 {
	// Calculate scale: 2^31 / messageModulus
	scale := float64(uint64(1)<<31) / float64(messageModulus)

	// Normalize message
	message = message % messageModulus
	if message < 0 {
		message += messageModulus
	}

	// Encode: message * scale / 2^32 to get value in [0, 1)
	encodedMessage := float64(message) * scale / float64(uint64(1)<<32)

	return t.EncryptF64(encodedMessage, alpha, key)
}

// DecryptLWEMessage decrypts an integer message using general message encoding
//
// Following the reference implementation: num.DivRound(phase, scale) % messageModulus
// DivRound(a, b) rounds a/b to nearest integer
//
// The phase b - <a, key> is computed in wrapping Torus arithmetic and rounded
// to the nearest multiple of scale = 2^31/messageModulus. messageModulus must
// be positive; a zero modulus makes the integer division by scale panic.
func (t *TLWELv0) DecryptLWEMessage(messageModulus int, key []params.Torus) int {
	// Calculate scale: 2^31 / messageModulus
	scale := params.Torus(uint64(1)<<31) / params.Torus(messageModulus)

	// Get phase (decrypted value with noise)
	n := params.GetTLWELv0().N
	var innerProduct params.Torus
	for i := 0; i < n; i++ {
		innerProduct += t.P[i] * key[i]
	}
	phase := t.P[n] - innerProduct

	// DivRound: (a + b/2) / b
	// For unsigned: (phase + scale/2) / scale
	decoded := int((phase + scale/2) / scale)

	// decoded comes from an unsigned division, so it is never negative here;
	// the < 0 branch below is purely defensive.
	message := decoded % messageModulus
	if message < 0 {
		message += messageModulus
	}

	return message
}

--------------------------------------------------------------------------------
/tlwe/tlwe.go:
--------------------------------------------------------------------------------
package tlwe

import (
	"math/rand"

	"github.com/thedonutfactory/go-tfhe/params"
	"github.com/thedonutfactory/go-tfhe/utils"
)

// TLWELv0 represents a Level 0 TLWE ciphertext
type TLWELv0 struct {
	P []params.Torus // Length is N+1, where last element is b
}

// NewTLWELv0 creates a new TLWE Level 0 ciphertext
func NewTLWELv0() *TLWELv0 {
	n := params.GetTLWELv0().N
	return &TLWELv0{
		P: make([]params.Torus, n+1),
	}
}

// B returns the b component of the TLWE ciphertext
func (t *TLWELv0) B() params.Torus {
	n := params.GetTLWELv0().N
	return t.P[n]
}

// SetB sets the b component of the TLWE ciphertext
func (t *TLWELv0) SetB(val params.Torus) {
	n := params.GetTLWELv0().N
	t.P[n] = val
}

// EncryptF64 encrypts a float64 value with TLWE Level 0
//
// Draws the N mask coefficients from a math/rand generator and sets
// b = <a, key> + utils.GaussianF64(p, alpha, rng). It overwrites t in place
// and returns t for chaining.
//
// NOTE(review): math/rand (seeded here from the global math/rand source) is
// not a cryptographically secure generator. Masks and noise drawn from it are
// predictable, so these ciphertexts should not be relied on for real secrecy.
// crypto/rand would be needed for that, but utils.GaussianF64 currently takes
// a *rand.Rand.
func (t *TLWELv0) EncryptF64(p float64, alpha float64, key []params.Torus) *TLWELv0 {
	rng := rand.New(rand.NewSource(rand.Int63()))
	n := params.GetTLWELv0().N

	var innerProduct params.Torus
	for i := 0; i < n; i++ {
		randU32 := params.Torus(rng.Uint32())
		innerProduct += key[i] * randU32
		t.P[i] = randU32
	}

	b := utils.GaussianF64(p, alpha, rng)
	t.SetB(innerProduct + b)
	return t
}

// EncryptBool encrypts a boolean value with TLWE Level 0
// true is encoded as +1/8 and false as -1/8 on the torus.
func (t *TLWELv0) EncryptBool(pBool bool, alpha float64, key []params.Torus) *TLWELv0 {
	var p float64
	if pBool {
		p = 0.125
	} else {
		p = -0.125
	}
	return t.EncryptF64(p, alpha, key)
}

// DecryptBool decrypts a TLWE Level 0 ciphertext to a boolean
// It returns true when the phase b - <a, key>, read as int32, is non-negative.
func (t *TLWELv0) DecryptBool(key []params.Torus) bool {
	n := params.GetTLWELv0().N
	var innerProduct params.Torus
	for i := 0; i < n; i++ {
		innerProduct += t.P[i] * key[i]
	}

	resTorus := int32(t.P[n] - innerProduct)
	return resTorus >= 0
}

// Add adds two TLWE Level 0 ciphertexts
func (t *TLWELv0) Add(other *TLWELv0) *TLWELv0 {
	result := NewTLWELv0()
	for i := range result.P {
		result.P[i] = t.P[i] + other.P[i]
	}
	return result
}

// AddAssign adds two TLWE Level 0 ciphertexts and writes to output (zero-allocation)
func (t *TLWELv0) AddAssign(other *TLWELv0, output *TLWELv0) {
	for i := range output.P {
		output.P[i] = t.P[i] + other.P[i]
	}
}

// Sub subtracts two TLWE Level 0 ciphertexts
func (t *TLWELv0) Sub(other *TLWELv0) *TLWELv0 {
	result := NewTLWELv0()
	for i := range result.P {
		result.P[i] = t.P[i] - other.P[i]
	}
	return result
}

// Neg negates a TLWE Level 0 ciphertext
func (t *TLWELv0) Neg() *TLWELv0 {
	result := NewTLWELv0()
	for i := range result.P {
		result.P[i] = 0 - t.P[i]
	}
	return result
}

// Mul multiplies two TLWE Level 0 ciphertexts (element-wise)
//
// NOTE(review): this is a coefficient-wise product of the raw ciphertext
// vectors. It does not by itself produce an encryption of the product of the
// plaintexts; confirm it is only used where that is intended.
func (t *TLWELv0) Mul(other *TLWELv0) *TLWELv0 {
	result := NewTLWELv0()
	for i := range result.P {
		result.P[i] = t.P[i] * other.P[i]
	}
	return result
}

// AddMul adds a TLWE ciphertext multiplied by a constant
func (t *TLWELv0) AddMul(other *TLWELv0, multiplier params.Torus) *TLWELv0 {
	result := NewTLWELv0()
	for i := range result.P {
		result.P[i] = t.P[i] + (other.P[i] * multiplier)
	}
	return result
}

// SubMul subtracts a TLWE ciphertext multiplied by a constant
func (t *TLWELv0) SubMul(other *TLWELv0, multiplier params.Torus) *TLWELv0 {
	result := NewTLWELv0()
	for i := range result.P {
		result.P[i] = t.P[i] - (other.P[i] * multiplier)
	}
	return result
}

// TLWELv1 represents a Level 1 TLWE ciphertext
type TLWELv1 struct {
	P []params.Torus // Length is N+1, where last element is b
}

// NewTLWELv1 creates a new TLWE Level 1 ciphertext
func NewTLWELv1() *TLWELv1 {
	n := params.GetTLWELv1().N
	return &TLWELv1{
		P: make([]params.Torus, n+1),
	}
}

// SetB sets the b component of the TLWE Level 1 ciphertext
func (t *TLWELv1) SetB(val params.Torus) {
	n := params.GetTLWELv1().N
	t.P[n] = val
}

// EncryptF64 encrypts a float64 value with TLWE Level 1
// NOTE(review): uses math/rand, which is not cryptographically secure.
func (t *TLWELv1) EncryptF64(p float64, alpha float64, key []params.Torus) *TLWELv1 {
	rng := rand.New(rand.NewSource(rand.Int63()))
	n := params.GetTLWELv1().N

	var innerProduct params.Torus
	for i := 0; i < n; i++ {
		randU32 := params.Torus(rng.Uint32())
		innerProduct += key[i] * randU32
		t.P[i] = randU32
	}

	b := utils.GaussianF64(p, alpha, rng)
	t.SetB(innerProduct + b)
	return t
}

// EncryptBool encrypts a boolean value with TLWE Level 1
func (t *TLWELv1) EncryptBool(pBool bool, alpha float64, key []params.Torus) *TLWELv1 {
	var p float64
	if pBool {
		p = 0.125
	} else {
		p = -0.125
	}
	return t.EncryptF64(p, alpha, key)
}

// DecryptBool decrypts a TLWE Level 1 ciphertext to a boolean
func (t *TLWELv1) DecryptBool(key []params.Torus) bool {
	n := params.GetTLWELv1().N
	var innerProduct params.Torus
	for i := 0; i < n; i++ {
		innerProduct += t.P[i] * key[i]
	}

	// NOTE(review): b is read from t.P[len(key)], whereas TLWELv0.DecryptBool
	// reads t.P[n]. The two agree only when len(key) == params.GetTLWELv1().N;
	// a longer key indexes past b, or out of range.
	resTorus := int32(t.P[len(key)] - innerProduct)
	return resTorus >= 0
}

--------------------------------------------------------------------------------
/tlwe/tlwe_test.go:
--------------------------------------------------------------------------------
package tlwe_test

import (
	"testing"

	"github.com/thedonutfactory/go-tfhe/key"
	"github.com/thedonutfactory/go-tfhe/params"
	"github.com/thedonutfactory/go-tfhe/tlwe"
)

func TestTLWELv0EncryptDecrypt(t *testing.T) {
	sk := key.NewSecretKey()

	testCases := []bool{true, false}

	for _, val := range testCases {
		ct := tlwe.NewTLWELv0().EncryptBool(val, params.GetTLWELv0().ALPHA, sk.KeyLv0)
		dec := ct.DecryptBool(sk.KeyLv0)

		if dec != val {
			t.Errorf("Encrypt/Decrypt(%v) = %v", val, dec)
		}
	}
}

func TestTLWELv0EncryptDecryptMultiple(t *testing.T) {
	sk := key.NewSecretKey()
	trials := 100
	correct := 0

	for i := 0; i < trials; i++ {
		val := i%2 == 0
		ct := tlwe.NewTLWELv0().EncryptBool(val, params.GetTLWELv0().ALPHA, sk.KeyLv0)
		dec := ct.DecryptBool(sk.KeyLv0)

		if dec == val {
			correct++
		}
	}

	if correct != trials {
		t.Errorf("Correctness: %d/%d (%.1f%%)", correct, trials, float64(correct)/float64(trials)*100)
	}
}

// NOTE(review): true (+1/8) plus false (-1/8) has phase of about 0, so the sign
// of the sum is decided by noise and cannot be asserted. The decrypted value
// is discarded; the test only checks that Add and DecryptBool do not panic.
func TestTLWELv0Add(t *testing.T) {
	sk := key.NewSecretKey()

	ct1 := tlwe.NewTLWELv0().EncryptBool(true, params.GetTLWELv0().ALPHA, sk.KeyLv0)
	ct2 := tlwe.NewTLWELv0().EncryptBool(false, params.GetTLWELv0().ALPHA, sk.KeyLv0)

	sum := ct1.Add(ct2)

	// Addition of true + false should still be decryptable
	// (though the semantic meaning depends on the circuit)
	_ = sum.DecryptBool(sk.KeyLv0)
}

func TestTLWELv0Neg(t *testing.T) {
	sk := key.NewSecretKey()

	ct := tlwe.NewTLWELv0().EncryptBool(true, params.GetTLWELv0().ALPHA, sk.KeyLv0)
	negCt := ct.Neg()
	dec := negCt.DecryptBool(sk.KeyLv0)

	// Negation of true (0.125) should give false (-0.125)
	if dec != false {
		t.Errorf("Neg(true) = %v, expected false", dec)
	}
}

func TestTLWELv1EncryptDecrypt(t *testing.T) {
	sk := key.NewSecretKey()

	testCases := []bool{true, false}

	for _, val := range testCases {
		ct := tlwe.NewTLWELv1().EncryptBool(val, params.GetTLWELv1().ALPHA, sk.KeyLv1)
		dec := ct.DecryptBool(sk.KeyLv1)

		if dec != val {
			t.Errorf("TLWELv1 Encrypt/Decrypt(%v) = %v", val, dec)
		}
	}
}

--------------------------------------------------------------------------------
/trgsw/keyswitch.go:
--------------------------------------------------------------------------------
package trgsw

import (
	"github.com/thedonutfactory/go-tfhe/params"
	"github.com/thedonutfactory/go-tfhe/tlwe"
)

// IdentityKeySwitchingAssign performs identity key switching and writes to output
// Zero-allocation version
func IdentityKeySwitchingAssign(src *tlwe.TLWELv1, keySwitchingKey []*tlwe.TLWELv0, output *tlwe.TLWELv0) {
	n := params.GetTRGSWLv1().N
	basebit := params.GetTRGSWLv1().BASEBIT
	base := 1 << basebit
	iksT := params.GetTRGSWLv1().IKS_T
	tlweLv0N := params.GetTLWELv0().N

	// Clear output
	for i := 0; i < len(output.P); i++ {
		output.P[i] = 0
	}
	output.P[tlweLv0N] = src.P[len(src.P)-1]

	precOffset := params.Torus(1 << (32 - (1 + basebit*iksT)))

	for i := 0; i < n; i++ {
		aBar := src.P[i] + precOffset
		for j := 0; j < iksT; j++ {
			// NOTE(review): this dump is damaged here. Everything between the "<"
			// of what was presumably "(1<<basebit)" and a later ">" was dropped
			// (it looks like HTML-tag stripping). The rest of
			// IdentityKeySwitchingAssign, the end of trgsw/keyswitch.go, and the
			// start of trgsw/trgsw.go up to part-way through its decomposition
			// loop are therefore missing. The line below is two unrelated
			// fragments spliced together. Restore this region from the repository
			// rather than editing it by hand.
			k := (aBar >> (32 - (j+1)*basebit)) & params.Torus((1<> (32 - (uint32(i)+1)*bgbit)) & mask) - halfBG
		}
		for i := 0; i < l; i++ {
			polyEval.GetDecompBuffer(i + l).Coeffs[j] = ((tmp1 >> (32 - (uint32(i)+1)*bgbit)) & mask) - halfBG
		}
	}

	// Transform all decomposition levels to frequency domain
	for i := 0; i < l*2; i++ {
		polyEval.ToFourierPolyInBuffer(*polyEval.GetDecompBuffer(i), i)
	}
}

// CMUX performs controlled MUX operation (zero-allocation version using TRLWE pool)
// if cond == 0 then in1 else in2
//
// Computes in1 + ExternalProductWithFFT(cond, in2 - in1). The difference is
// staged in a buffer taken from polyEval.GetTRLWEBuffer. The returned TRLWE is
// the value ExternalProductWithFFT returned, with in1 added in place.
// NOTE(review): that value may be scratch memory owned by polyEval (the call
// site says it "uses internal buffers"); TODO confirm before retaining it
// across further calls on the same evaluator.
func CMUX(in1, in2 *trlwe.TRLWELv1, cond *TRGSWLv1FFT, decompositionOffset params.Torus, polyEval *poly.Evaluator) *trlwe.TRLWELv1 {
	n := params.GetTRGSWLv1().N

	// Get TRLWE buffer from pool for difference computation
	tmpA, tmpB := polyEval.GetTRLWEBuffer()
	for i := 0; i < n; i++ {
		tmpA[i] = in2.A[i] - in1.A[i]
		tmpB[i] = in2.B[i] - in1.B[i]
	}
	tmp := &trlwe.TRLWELv1{A: tmpA, B: tmpB}

	// External product (uses internal buffers for zero-alloc in hot path)
	tmp2 := ExternalProductWithFFT(cond, tmp, decompositionOffset, polyEval)

	// Add in1 to result (reuse tmp2)
	for i := 0; i < n; i++ {
		tmp2.A[i] += in1.A[i]
		tmp2.B[i] += in1.B[i]
	}

	return tmp2
}

// BlindRotate performs blind rotation for bootstrapping (optimized with buffer pool)
//
// First it rotates the test vector by X^(2N - round(b)). Then, for each of the
// TLWELv0.N mask coefficients a_i, it CMUXes between the accumulator and the
// accumulator rotated by X^round(a_i), with bootstrappingKey[i] as the
// selector. "round" maps a 32-bit torus value x to (x + 2^(30-NBIT)) >>
// (31-NBIT).
func BlindRotate(src *tlwe.TLWELv0, blindRotateTestvec *trlwe.TRLWELv1, bootstrappingKey []*TRGSWLv1FFT, decompositionOffset params.Torus, polyEval *poly.Evaluator) *trlwe.TRLWELv1 {
	n := params.GetTRGSWLv1().N
	nBit := params.GetTRGSWLv1().NBIT

	// Reset rotation pool for this operation
	polyEval.ResetRotationPool()

	// NOTE(review): bTilda is computed in int arithmetic (no 32-bit wrap), so it
	// ranges over [0, 2N], assuming N == 1<<NBIT. aTilda below wraps in Torus
	// arithmetic instead and stays in [0, 2N). This relies on PolyMulWithXK
	// reducing its exponent modulo 2N, so that 2N and 0 rotate identically;
	// confirm.
	bTilda := 2*n - ((int(src.B()) + (1 << (31 - nBit - 1))) >> (32 - nBit - 1))

	// Initial rotation using buffer pool
	resultA := polyEval.PolyMulWithXK(blindRotateTestvec.A, bTilda)
	resultB := polyEval.PolyMulWithXK(blindRotateTestvec.B, bTilda)
	result := &trlwe.TRLWELv1{A: resultA, B: resultB}

	tlweLv0N := params.GetTLWELv0().N
	for i := 0; i < tlweLv0N; i++ {
		aTilda := int((src.P[i] + (1 << (31 - nBit - 1))) >> (32 - nBit - 1))

		// Use buffer pool for rotation
		res2A := polyEval.PolyMulWithXK(result.A, aTilda)
		res2B := polyEval.PolyMulWithXK(result.B, aTilda)
		res2 := &trlwe.TRLWELv1{A: res2A, B: res2B}

		result = CMUX(result, res2, bootstrappingKey[i], decompositionOffset, polyEval)
	}

	return result
}

// evaluatorPool is a pool of evaluators for parallel operations
var evaluatorPool = sync.Pool{
	New: func() interface{} {
		return poly.NewEvaluator(params.GetTRGSWLv1().N)
	},
}

// BatchBlindRotate performs multiple blind rotations in parallel (zero-allocation)
//
// It starts one goroutine per input, with no concurrency limit. Each goroutine
// borrows an Evaluator from evaluatorPool for the duration of its BlindRotate
// call.
//
// NOTE(review): the evaluator goes back to evaluatorPool (deferred Put) as
// soon as BlindRotate returns, but results[idx] holds BlindRotate's return
// value. That value may point into the evaluator's rotation or TRLWE pool
// buffers (see CMUX). Another goroutine that then Gets the same evaluator could
// overwrite a result before the caller reads it. Copying each result into
// freshly allocated slices before the Put would remove the aliasing. Also,
// despite the "(zero-allocation)" label, this function allocates the results
// slice and the goroutines.
func BatchBlindRotate(srcs []*tlwe.TLWELv0, blindRotateTestvec *trlwe.TRLWELv1, bootstrappingKey []*TRGSWLv1FFT, decompositionOffset params.Torus) []*trlwe.TRLWELv1 {
	results := make([]*trlwe.TRLWELv1, len(srcs))
	var wg sync.WaitGroup

	for i, src := range srcs {
		wg.Add(1)
		go func(idx int, s *tlwe.TLWELv0) {
			defer wg.Done()
			// Get evaluator from pool (reuse instead of allocate)
			polyEval := evaluatorPool.Get().(*poly.Evaluator)
			defer evaluatorPool.Put(polyEval)

			results[idx] = BlindRotate(s, blindRotateTestvec, bootstrappingKey, decompositionOffset, polyEval)
		}(i, src)
	}

	wg.Wait()
	return results
}

// polyMulWithXKInPlace multiplies a polynomial by X^k in-place (zero-allocation)
//
// It writes a * X^k modulo X^n + 1, where n = len(a), into result, subject to
// the negation caveat below.
//
// NOTE(review): caveats visible from the code:
//   - Despite the name, result must not alias a: result[i+k] is written before
//     a[i+k] is read.
//   - Go's % keeps the sign of k, so a negative k stays negative after the
//     "normalize" step and the k < n branch indexes out of range. Only k >= 0
//     is handled. len(a) == 0 panics on k % 0.
//   - Negation is written as ^params.Torus(0) - x, which is -x-1 (mod 2^32)
//     rather than -x, so each negated coefficient is off by one unit.
func polyMulWithXKInPlace(a []params.Torus, k int, result []params.Torus) {
	n := len(a)
	k = k % (2 * n) // Normalize k to [0, 2N)

	if k == 0 {
		copy(result, a)
		return
	}

	if k < n {
		// Positive rotation: coefficients shift right, wrap with negation
		for i := 0; i < n-k; i++ {
			result[i+k] = a[i]
		}
		for i := n - k; i < n; i++ {
			result[i+k-n] = ^params.Torus(0) - a[i]
		}
	} else {
		// Rotation >= n: all coefficients get negated
		k -= n
		for i := 0; i < n-k; i++ {
			result[i+k] = ^params.Torus(0) - a[i]
		}
		for i := n - k; i < n; i++ {
			result[i+k-n] = a[i]
		}
	}
}

// IdentityKeySwitching performs identity key switching
func IdentityKeySwitching(src *tlwe.TLWELv1, keySwitchingKey []*tlwe.TLWELv0) *tlwe.TLWELv0 {
	n := params.GetTRGSWLv1().N
	basebit := params.GetTRGSWLv1().BASEBIT
	base := 1 << basebit
	iksT := params.GetTRGSWLv1().IKS_T

	result := tlwe.NewTLWELv0()
	tlweLv0N := params.GetTLWELv0().N
	result.P[tlweLv0N] = src.P[len(src.P)-1]

	precOffset := params.Torus(1 << (32 - (1 + basebit*iksT)))

	for i := 0; i < n; i++ {
		aBar := src.P[i] + precOffset
		for j := 0; j < iksT; j++ {
			// NOTE(review): this dump is damaged here as well. Text between "<"
			// and a later ">" was dropped, so the rest of IdentityKeySwitching,
			// the end of trgsw/trgsw.go, and the start of trlwe/trlwe.go (up to
			// the tail of a function ending in "... >= 0") are missing. The
			// line below splices fragments from two files. Restore this region
			// from the repository.
			k := (aBar >> (32 - (j+1)*basebit)) & params.Torus((1<= 0
	}

	return result
}

// TRLWELv1FFT represents a TRLWE Level 1 ciphertext in FFT form
type TRLWELv1FFT struct {
	A []float64
	B []float64
}

// NewTRLWELv1FFT creates a new TRLWE Level 1 FFT ciphertext from a regular TRLWE
func NewTRLWELv1FFT(trlwe *TRLWELv1, polyEval *poly.Evaluator) *TRLWELv1FFT {
	// Convert to Fourier domain using poly evaluator
	polyA := poly.Poly{Coeffs: trlwe.A}
	polyB := poly.Poly{Coeffs: trlwe.B}

	fpA := polyEval.ToFourierPoly(polyA)
	fpB := polyEval.ToFourierPoly(polyB)

	return &TRLWELv1FFT{
		A: fpA.Coeffs,
		B: fpB.Coeffs,
	}
}

// NewTRLWELv1FFTDummy creates a dummy TRLWE Level 1 FFT ciphertext
func NewTRLWELv1FFTDummy() *TRLWELv1FFT {
	// FourierPoly needs 2*N for interleaved real/imaginary layout
	return &TRLWELv1FFT{
		A: make([]float64, 2*params.GetTRLWELv1().N),
		B: make([]float64, 2*params.GetTRLWELv1().N),
	}
}

// SampleExtractIndex extracts a TLWE sample from a TRLWE at index k
//
// The mask of the resulting TLWELv1 is A[k], A[k-1], ..., A[0], followed by
// the (approximately, see below) negated A[n-1], ..., A[k+1]. Its b is B[k],
// and k must be in [0, n).
//
// NOTE(review): the negated entries are computed as ^params.Torus(0) - x,
// which is -x-1 (mod 2^32), not -x. The error of at most one unit per
// coefficient is presumably negligible next to the ciphertext noise, but
// 0 - x would be exact. Confirm intent before changing anything.
func SampleExtractIndex(trlwe *TRLWELv1, k int) *tlwe.TLWELv1 {
	n := params.GetTRLWELv1().N
	result := tlwe.NewTLWELv1()

	for i := 0; i < n; i++ {
		if i <= k {
			result.P[i] = trlwe.A[k-i]
		} else {
			result.P[i] = ^params.Torus(0) - trlwe.A[n+k-i]
		}
	}
	result.SetB(trlwe.B[k])

	return result
}

// SampleExtractIndex2 extracts a TLWE Lv0 sample from a TRLWE at index k
// NOTE: This should NOT be used when TRLWE.N != TLWELv0.N
// For Uint5 params, use proper key switching from TLWELv1 instead
// NOTE(review): negation here is also ^params.Torus(0) - x, i.e. -x-1 (mod 2^32).
func SampleExtractIndex2(trlwe *TRLWELv1, k int) *tlwe.TLWELv0 {
	n := params.GetTLWELv0().N
	trlweN := len(trlwe.A)
	result := tlwe.NewTLWELv0()

	// If sizes don't match, we can't directly extract
	// This function is only correct when trlweN == n
	if trlweN != n {
		panic("SampleExtractIndex2: TRLWE dimension mismatch - use proper key switching")
	}

	for i := 0; i < n; i++ {
		if i <= k {
			result.P[i] = trlwe.A[k-i]
		} else {
			result.P[i] = ^params.Torus(0) - trlwe.A[n+k-i]
		}
	}
	result.SetB(trlwe.B[k])

	return result
}

--------------------------------------------------------------------------------
/trlwe/trlwe_ops.go:
--------------------------------------------------------------------------------
package trlwe

import (
	"github.com/thedonutfactory/go-tfhe/params"
	"github.com/thedonutfactory/go-tfhe/tlwe"
)

// SampleExtractIndexAssign extracts a TLWE sample from TRLWE at index k and writes to output
// Zero-allocation version
// NOTE(review): negation here is ^params.Torus(0) - x, i.e. -x-1 (mod 2^32), not -x.
func SampleExtractIndexAssign(trlwe *TRLWELv1, k int, output *tlwe.TLWELv1) {
	n := params.GetTRLWELv1().N

	for i := 0; i < n; i++ {
		if i <= k {
			output.P[i] = trlwe.A[k-i]
		} else {
			output.P[i] = ^params.Torus(0) - trlwe.A[n+k-i]
		}
	}
	output.SetB(trlwe.B[k])
}

-------------------------------------------------------------------------------- /utils/utils.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "math" 5 | "math/rand" 6 | 7 | "github.com/thedonutfactory/go-tfhe/params" 8 | ) 9 | 10 | // F64ToTorus converts a float64 to a Torus value 11 | func F64ToTorus(d float64) params.Torus { /* math.Mod keeps the sign of d, so the fractional part lies in (-1, 1). Scaling by 2^32 and converting through int64 truncates toward zero, and negative values then wrap; e.g. -0.125 maps to 2^32 - 2^29 (see TestF64ToTorus). NOTE(review): NaN or +/-Inf input makes math.Mod return NaN, and the Go spec leaves the int64 conversion of NaN implementation-dependent. */ 12 | torus := math.Mod(d, 1.0) * float64(uint64(1)<<32) 13 | return params.Torus(int64(torus)) 14 | } 15 | 16 | // TorusToF64 converts a Torus value to a float64 in range [0, 1) 17 | func TorusToF64(t params.Torus) float64 { 18 | return float64(t) / float64(uint64(1)<<32) 19 | } 20 | 21 | // F64ToTorusVec converts a slice of float64 to a slice of Torus values 22 | func F64ToTorusVec(d []float64) []params.Torus { 23 | result := make([]params.Torus, len(d)) 24 | for i, val := range d { 25 | result[i] = F64ToTorus(val) 26 | } 27 | return result 28 | } 29 | 30 | // GaussianTorus samples from a Gaussian distribution and adds to mu 31 | func GaussianTorus(mu params.Torus, stddev float64, rng *rand.Rand) params.Torus { /* Draws one standard-normal value from rng (a single NormFloat64 call) and scales it by stddev, which is therefore a fraction of a full torus turn. The sample is encoded with F64ToTorus and added to mu with wrap-around Torus addition. */ 32 | sample := rng.NormFloat64() * stddev 33 | return mu + F64ToTorus(sample) 34 | } 35 | 36 | // GaussianF64 samples from a Gaussian distribution with mean mu 37 | func GaussianF64(mu float64, stddev float64, rng *rand.Rand) params.Torus { /* Equivalent to GaussianTorus(F64ToTorus(mu), stddev, rng). */ 38 | muTorus := F64ToTorus(mu) 39 | return GaussianTorus(muTorus, stddev, rng) 40 | } 41 | 42 | // GaussianF64Vec samples a vector from a Gaussian distribution 43 | func GaussianF64Vec(mu []float64, stddev float64, rng *rand.Rand) []params.Torus { /* Returns len(mu) samples with result[i] = GaussianTorus(F64ToTorus(mu[i]), stddev, rng), drawing from rng in index order. A nil or empty mu yields an empty, non-nil slice. */ 44 | result := make([]params.Torus, len(mu)) 45 | for i, m := range mu { 46 | result[i] = GaussianTorus(F64ToTorus(m), stddev, rng) 47 | } 48 | return result 49 | } 50 | -------------------------------------------------------------------------------- /utils/utils_test.go: -------------------------------------------------------------------------------- 1 | package utils_test 2 | 3
| import ( 4 | "testing" 5 | 6 | "github.com/thedonutfactory/go-tfhe/params" 7 | "github.com/thedonutfactory/go-tfhe/utils" 8 | ) 9 | 10 | func TestF64ToTorus(t *testing.T) { /* Table test over exact power-of-two fractions, including a negative input whose encoding wraps to 2^32 - 2^29. */ 11 | testCases := []struct { 12 | input float64 13 | expected params.Torus 14 | }{ 15 | {0.0, 0}, 16 | {0.125, 536870912}, // 2^29 17 | {-0.125, 3758096384}, // 2^32 - 2^29 18 | {0.25, 1073741824}, // 2^30 19 | {0.5, 2147483648}, // 2^31 20 | } 21 | 22 | for _, tc := range testCases { 23 | result := utils.F64ToTorus(tc.input) 24 | if result != tc.expected { 25 | t.Errorf("F64ToTorus(%f) = %d (0x%08x), expected %d (0x%08x)", 26 | tc.input, result, result, tc.expected, tc.expected) 27 | } 28 | } 29 | } 30 | 31 | func TestF64ToTorusVec(t *testing.T) { /* Checks that the output length matches the input length and that each element equals the scalar conversion. */ 32 | input := []float64{0.0, 0.125, 0.25} 33 | expected := []params.Torus{0, 536870912, 1073741824} 34 | 35 | result := utils.F64ToTorusVec(input) 36 | 37 | if len(result) != len(expected) { 38 | t.Fatalf("F64ToTorusVec length = %d, expected %d", len(result), len(expected)) 39 | } 40 | 41 | for i := range result { 42 | if result[i] != expected[i] { 43 | t.Errorf("F64ToTorusVec[%d] = %d, expected %d", i, result[i], expected[i]) 44 | } 45 | } 46 | } 47 | --------------------------------------------------------------------------------