├── demo
    ├── .gitignore
    ├── Dockerfile.keygen
    ├── Dockerfile.wg-server
    ├── test-client.js
    ├── Dockerfile.node-server
    ├── setup.sh
    ├── node-server-1.js
    ├── node-server-2.js
    ├── docker-compose.yml
    └── keygen.sh
├── example-wg0.conf
├── version.go
├── go.mod
├── .gitignore
├── example-usage.sh
├── LICENSE
├── demo-policy-routing.sh
├── go.sum
├── ipc.go
├── logger.go
├── .github
    └── workflows
    │   ├── test.yml
    │   └── release.yml
├── socks.go
├── Makefile
├── routing_cli_test.go
├── version_test.go
├── forwarder.go
├── POLICY_ROUTING.md
├── routing_test.go
├── README.md
├── routing.go
├── main.go
├── socks_test.go
├── config.go
├── tunnel.go
├── lib
    └── intercept.c
├── forwarder_test.go
├── logger_test.go
├── ipc_test.go
├── main_test.go
├── tunnel_test.go
└── config_test.go

/demo/.gitignore:
--------------------------------------------------------------------------------
1 | # WireGuard keys and configurations
2 | keys/
3 | configs/
4 | 
5 | # Docker volumes
6 | volumes/
7 | 
8 | # Logs
9 | *.log
10 | 
11 | # Build artifacts
12 | wrapguard-src/
--------------------------------------------------------------------------------
/demo/Dockerfile.keygen:
--------------------------------------------------------------------------------
1 | # Throwaway image that demo/setup.sh builds ("Building key generation
2 | # container") and runs once to execute keygen.sh.
3 | FROM alpine:latest
4 | 
5 | # wireguard-tools provides the `wg` CLI; presumably keygen.sh uses it to create
6 | # the key pairs and needs bash for itself — TODO confirm against keygen.sh.
7 | RUN apk add --no-cache \
8 |     wireguard-tools \
9 |     bash
10 | 
11 | # setup.sh runs this image with -v "$(pwd):/workspace" (the demo directory), so
12 | # anything keygen.sh writes under this WORKDIR appears on the host; setup.sh
13 | # expects keys/ and configs/ to exist there afterwards.
14 | WORKDIR /workspace
15 | 
16 | COPY keygen.sh /keygen.sh
17 | RUN chmod +x /keygen.sh
18 | 
19 | CMD ["/keygen.sh"]
--------------------------------------------------------------------------------
/example-wg0.conf:
--------------------------------------------------------------------------------
1 | [Interface]
2 | PrivateKey = eDsYEfddDm8jE8sUBnfG9GZm0mqTYGJhxbsOjzKvBUo=
3 | Address = 10.150.0.2/24
4 | 
5 | [Peer]
6 | PublicKey = sJwKzKorIGo/ZHAPDnmM5dk0ZmQlkf4aNtRVK6eYInU=
7 | PresharedKey = ve5d5GUSnojL/5mrn7srhnRjhyrVWsTBSPwfpEIT4DA=
8 | Endpoint = 127.0.0.1:51820
9 | AllowedIPs = 10.150.0.0/24
10 | PersistentKeepalive = 25
11 | 
--------------------------------------------------------------------------------
/demo/Dockerfile.wg-server:
--------------------------------------------------------------------------------
1 | FROM alpine:latest
2 | 
3 | RUN apk add --no-cache \
4 |     wireguard-tools \
5 |     iptables \
6 |     ip6tables
7 | 
8 | # Copy the WireGuard configuration
9 | COPY configs/wg-server.conf /etc/wireguard/wg0.conf
10 | 
11 | # Expose WireGuard port
12 | EXPOSE 51820/udp
13 | 
14 | # Start WireGuard
15 | CMD ["sh", "-c", "wg-quick up wg0 && sleep infinity"]
--------------------------------------------------------------------------------
/version.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | // Version is the release version reported by GetVersion and GetFullVersion.
4 | //
5 | // It is deliberately a variable rather than a constant: the Go linker's -X
6 | // flag can only override package-level string variables, so a const could
7 | // never be set at build time. Builds may inject a value with
8 | // `go build -ldflags "-X main.Version=v1.2.3"`; the literal below is the
9 | // default for plain `go build` / `go test` runs.
10 | var Version = "v1.0.0-dev"
11 | 
12 | // AppName is the human-readable application name used in version output.
13 | const AppName = "WrapGuard"
14 | 
15 | // GetVersion returns the version string
16 | func GetVersion() string {
17 | 	return Version
18 | }
19 | 
20 | // GetFullVersion returns the full version string with app name
21 | func GetFullVersion() string {
22 | 	return AppName + " " + Version
23 | }
24 | 
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/puzed/wrapguard
2 | 
3 | go 1.24.0
4 | 
5 | toolchain go1.24.2
6 | 
7 | require (
8 | 	github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
9 | 	golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
10 | )
11 | 
12 | require (
13 | 	golang.org/x/crypto v0.46.0 // indirect
14 | 	golang.org/x/net v0.48.0 // indirect
15 | 	golang.org/x/sys v0.39.0 // indirect
16 | 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
17 | )
18 | 
--------------------------------------------------------------------------------
/demo/test-client.js:
--------------------------------------------------------------------------------
1 | const net = require('net');
2 | 
3 | console.log('Test client starting...');
4 | 
5 | const client = new net.Socket();
6 | 
7 | client.on('connect', () => {
8 |   console.log('Connected!');
9 |   client.write('GET / HTTP/1.1\r\nHost: 10.150.0.3:8002\r\n\r\n');
10 | });
11 | 
12 | client.on('data', (data) => {
13 |   console.log('Received:', data.toString());
14 |   client.destroy();
15 | });
16 | 
17 | client.on('error', (err) => {
18 |   console.log('Error:', err);
19 | });
20 | 
21 | client.on('close', () => {
22 |   console.log('Connection closed');
23 | });
24 | 
25 | console.log('Connecting to 10.150.0.3:8002...');
26 | client.connect(8002, '10.150.0.3');
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Binaries
2 | wrapguard
3 | libwrapguard.so
4 | *.so
5 | *.dylib
6 | *.dll
7 | 
8 | # Go build artifacts
9 | *.exe
10 | *.exe~
11 | *.test
12 | *.out
13 | 
14 | # Go workspace files
15 | go.work
16 | go.work.sum
17 | 
18 | # Dependency directories
19 | vendor/
20 | 
21 | # IDE and editor files
22 | .vscode/
23 | .idea/
24 | *.swp
25 | *.swo
26 | *~
27 | .DS_Store
28 | 
29 | # Test coverage
30 | *.prof
31 | coverage.txt
32 | coverage.html
33 | 
34 | # Debug files
35 | *.log
36 | debug
37 | 
38 | # Temporary files
39 | *.tmp
40 | *.temp
41 | /tmp/
42 | 
43 | # WireGuard config files with real keys
44 | *.conf
45 | !example-wg0.conf
46 | 
47 | # Build directories
48 | /dist/
49 | /build/
50 | /release/
51 | 
52 | # Local environment files
53 | .env
54 | .env.local
55 | 
56 | # OS specific files
57 | Thumbs.db
58 | .Spotlight-V100
59 | .Trashes
60 | 
61 | # Claude AI files (keep these)
62 | # .claude/
63 | 
--------------------------------------------------------------------------------
/example-usage.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Example usage of WrapGuard with routing options
4 | 
5 | echo "Example 1: Using exit node to route all traffic through a specific peer"
6 | echo "wrapguard --config=wg0.conf --exit-node=10.150.0.3 -- curl https://icanhazip.com"
7 | echo ""
8 | 
9 | echo "Example 2: Routing specific subnets through different peers"
10 | echo "wrapguard --config=wg0.conf \\"
11 | echo " --route=192.168.0.0/16:10.150.0.3 \\"
12 | echo " --route=172.16.0.0/12:10.150.0.4 \\"
13 | echo " -- ssh internal.corp.com"
14 | echo ""
15 | 
16 | echo "Example 3: Combining exit node with specific routes"
17 | echo "wrapguard --config=wg0.conf \\"
18 | echo " --exit-node=10.150.0.5 \\"
19 | echo " --route=10.0.0.0/8:10.150.0.3 \\"
20 | echo " -- curl https://example.com"
21 | echo ""
22 | 
23 | echo "Note: The peer IPs (like 10.150.0.3) must be within the AllowedIPs range of the corresponding peer in your config."
--------------------------------------------------------------------------------
/demo/Dockerfile.node-server:
--------------------------------------------------------------------------------
1 | FROM golang:1.24-alpine AS builder
2 | 
3 | # Install build dependencies
4 | RUN apk add --no-cache \
5 |     gcc \
6 |     musl-dev \
7 |     make
8 | 
9 | # Copy wrapguard source
10 | WORKDIR /build
11 | COPY wrapguard-src/ ./
12 | 
13 | # Build wrapguard
14 | RUN make build
15 | 
16 | FROM node:20-alpine
17 | 
18 | # Install runtime dependencies
19 | RUN apk add --no-cache \
20 |     libc6-compat \
21 |     bash \
22 |     curl
23 | 
24 | # Create app directory
25 | WORKDIR /app
26 | 
27 | # Copy wrapguard binary and library from builder
28 | COPY --from=builder /build/wrapguard /usr/local/bin/wrapguard
29 | COPY --from=builder /build/libwrapguard.so /usr/local/bin/libwrapguard.so
30 | 
31 | # Make wrapguard executable
32 | RUN chmod +x /usr/local/bin/wrapguard
33 | 
34 | # Copy the Node.js application
35 | COPY node-server-*.js ./
36 | 
37 | # Expose the application ports
38 | EXPOSE 8001 8002
39 | 
40 | # The CMD will be overridden in docker-compose.yml for each service
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 Mark Wylde
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/demo-policy-routing.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Demo script for testing policy-based routing
4 | #
5 | # NOTE(review): the commands below use example-policy-routing.conf, which does
6 | # not appear in the repository file listing; create it (presumably following
7 | # POLICY_ROUTING.md) before running them.
8 | 
9 | # Run from the repository root (where this script lives) so the ./wrapguard
10 | # check and `make build` below work from any working directory.
11 | cd "$(dirname "$0")" || exit 1
12 | 
13 | echo "=== WrapGuard Policy-Based Routing Demo ==="
14 | echo ""
15 | echo "This demo shows how WrapGuard routes traffic through different peers"
16 | echo "based on destination IP, protocol, and port."
17 | echo ""
18 | 
19 | # Check if wrapguard is built
20 | if [ ! -f "./wrapguard" ]; then
21 |     echo "Building wrapguard..."
22 |     make build
23 | fi
24 | 
25 | # Enable debug logging
26 | export WRAPGUARD_DEBUG=1
27 | 
28 | echo "1. Testing general traffic routing (should go through peer 1):"
29 | echo " Command: wrapguard --config=example-policy-routing.conf --log-level=debug -- curl -s https://icanhazip.com"
30 | echo ""
31 | 
32 | echo "2. Testing port 8080 routing (should go through peer 2):"
33 | echo " Command: wrapguard --config=example-policy-routing.conf --log-level=debug -- curl -s http://example.com:8080"
34 | echo ""
35 | 
36 | echo "3. Testing development network routing (should go through peer 3):"
37 | echo " Command: wrapguard --config=example-policy-routing.conf --log-level=debug -- ping -c 1 10.1.2.3"
38 | echo ""
39 | 
40 | echo "Note: The actual commands won't work without real WireGuard peers configured,"
41 | echo "but the debug logs will show which peer would be selected for each connection."
--------------------------------------------------------------------------------
/demo/setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | set -e
4 | 
5 | # Work relative to this script's directory (demo/) so the rm -rf below and
6 | # the ../ source paths resolve correctly however the script is invoked.
7 | cd "$(dirname "$0")"
8 | 
9 | # Start from a clean slate: remove artifacts from any previous run.
10 | rm -rf keys/ configs/ wrapguard-src
11 | 
12 | echo "🔐 WrapGuard Demo Setup"
13 | echo "======================"
14 | echo ""
15 | echo "This script prepares everything needed for the WrapGuard demo."
16 | echo "No WireGuard installation required on your host!"
17 | echo ""
18 | 
19 | # Step 1: Prepare wrapguard source for Docker build
20 | echo "📦 Preparing wrapguard source for Docker build..."
21 | mkdir -p wrapguard-src
22 | 
23 | # Copy necessary files from parent directory
24 | cp ../*.go wrapguard-src/ 2>/dev/null || true
25 | cp ../go.mod wrapguard-src/
26 | cp ../go.sum wrapguard-src/
27 | cp ../Makefile wrapguard-src/
28 | cp -r ../lib wrapguard-src/
29 | 
30 | echo "✅ Source files copied to wrapguard-src/"
31 | echo ""
32 | 
33 | # Step 2: Build and run the key generation container
34 | echo "🐳 Building key generation container..."
35 | docker build -f Dockerfile.keygen -t wrapguard-keygen .
36 | 
37 | echo ""
38 | echo "🔑 Running key generation..."
39 | docker run --rm -v "$(pwd):/workspace" wrapguard-keygen
40 | 
41 | echo ""
42 | echo "✅ Setup complete! Everything is ready."
43 | echo ""
44 | echo "📁 Generated files:"
45 | echo " - wrapguard-src/ (Source code for building)"
46 | echo " - keys/ (WireGuard private/public keys)"
47 | echo " - configs/ (WireGuard configuration files)"
48 | echo ""
49 | echo "🚀 Next step: docker compose up --build"
--------------------------------------------------------------------------------
/demo/node-server-1.js:
--------------------------------------------------------------------------------
1 | const http = require('http');
2 | const os = require('os');
3 | 
4 | console.log('Node.js version:', process.version);
5 | console.log('Process PID:', process.pid);
6 | 
7 | const PORT = 8001;
8 | const SERVER_NAME = 'Node Server 1';
9 | const MY_IP = '10.150.0.2';
10 | const OTHER_SERVER = 'http://10.150.0.3:8002';
11 | 
12 | // Create HTTP server
13 | const server = http.createServer(async (req, res) => {
14 |   res.end('i am 1')
15 | });
16 | 
17 | // Start server
18 | console.log('About to call server.listen...');
19 | server.listen(PORT, '0.0.0.0', () => {
20 |   console.log(`🚀 ${SERVER_NAME} listening on port ${PORT}`);
21 |   console.log(`📍 WireGuard IP: ${MY_IP}`);
22 |   console.log(`🔗 Other server: ${OTHER_SERVER}`);
23 | 
24 |   console.log('🎯 Server is listening and ready to accept connections');
25 | 
26 |   setInterval(() => {
27 |     console.log('server1: attempting connection')
28 | 
29 |     const req = http.request(OTHER_SERVER, (res) => {
30 |       let data = '';
31 |       res.on('data', (chunk) => {
32 |         data += chunk;
33 |       });
34 |       res.on('end', () => {
35 |         console.log('Response:', data);
36 |         console.log('✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅');
37 |       });
38 |     });
39 | 
40 |     req.on('error', (error) => {
41 |       console.log('Error:', error.message);
42 |     });
43 | 
44 |     req.end();
45 |   }, 3000)
46 | });
47 | 
48 | console.log('Server setup complete, waiting for listen callback...');
49 | 
--------------------------------------------------------------------------------
/demo/node-server-2.js:
--------------------------------------------------------------------------------
1 | const http = require('http');
2 | const os = require('os');
3 | 
4 | console.log('Node.js version:', process.version);
5 | console.log('Process PID:', process.pid);
6 | 
7 | const PORT = 8002;
8 | const SERVER_NAME = 'Node Server 2';
9 | const MY_IP = '10.150.0.3';
10 | const OTHER_SERVER = 'http://10.150.0.2:8001';
11 | 
12 | // Create HTTP server
13 | const server = http.createServer(async (req, res) => {
14 |   res.end('i am 2')
15 | });
16 | 
17 | 
18 | // Start server
19 | console.log('About to call server.listen...');
20 | server.listen(PORT, '0.0.0.0', () => {
21 |   console.log(`🚀 ${SERVER_NAME} listening on port ${PORT}`);
22 |   console.log(`📍 WireGuard IP: ${MY_IP}`);
23 |   console.log(`🔗 Other server: ${OTHER_SERVER}`);
24 | 
25 |   console.log('🎯 Server is listening and ready to accept connections');
26 | 
27 |   setInterval(() => {
28 |     console.log('server2: attempting connection')
29 | 
30 |     const req = http.request(OTHER_SERVER, (res) => {
31 |       let data = '';
32 |       res.on('data', (chunk) => {
33 |         data += chunk;
34 |       });
35 |       res.on('end', () => {
36 |         console.log('Response:', data);
37 |         console.log('✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅');
38 |       });
39 |     });
40 | 
41 |     req.on('error', (error) => {
42 |       console.log('Error:', error.message);
43 |     });
44 | 
45 |     req.end();
46 |   }, 10000)
47 | });
48 | 
49 | 
50 | console.log('Server setup complete, waiting for listen callback...');
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
2 | github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
3 | github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
4 | github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
5 | golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
6 | golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
7 | golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
8 | golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
9 | golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
10 | golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
11 | golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
12 | golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
13 | golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
14 | golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
15 | golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
16 | golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
17 | gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI=
18 | gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
19 | 
--------------------------------------------------------------------------------
/demo/docker-compose.yml:
--------------------------------------------------------------------------------
1 | 
2 | services:
3 |   wg-server:
4 |     build:
5 |       context: .
6 |       dockerfile: Dockerfile.wg-server
7 |     container_name: wg-server
8 |     cap_add:
9 |       - NET_ADMIN
10 |       - SYS_MODULE
11 |     sysctls:
12 |       - net.ipv4.ip_forward=1
13 |     volumes:
14 |       - ./configs/wg-server.conf:/etc/wireguard/wg0.conf:ro
15 |     networks:
16 |       - wrapguard
17 |     ports:
18 |       - "51820:51820/udp"
19 | 
20 |   node-server-1:
21 |     build:
22 |       context: .
23 |       dockerfile: Dockerfile.node-server
24 |     container_name: node-server-1
25 |     networks:
26 |       - wrapguard
27 |     volumes:
28 |       - ./configs/node-server-1.conf:/etc/wireguard/wg0.conf:ro
29 |       - ./node-server-1.js:/app/server.js:ro
30 |     environment:
31 |       - NODE_ENV=production
32 |       - SERVER_ID=1
33 |     command: >
34 |       sh -c "
35 |         echo 'Starting Node Server 1 with wrapguard...';
36 |         wrapguard --config=/etc/wireguard/wg0.conf --log-level=debug -- node /app/server.js
37 |       "
38 |     depends_on:
39 |       - wg-server
40 | 
41 |   node-server-2:
42 |     build:
43 |       context: .
44 |       dockerfile: Dockerfile.node-server
45 |     container_name: node-server-2
46 |     networks:
47 |       - wrapguard
48 |     volumes:
49 |       - ./configs/node-server-2.conf:/etc/wireguard/wg0.conf:ro
50 |       - ./node-server-2.js:/app/server.js:ro
51 |     environment:
52 |       - NODE_ENV=production
53 |       - SERVER_ID=2
54 |     command: >
55 |       sh -c "
56 |         echo 'Starting Node Server 2 with wrapguard...';
57 |         wrapguard --config=/etc/wireguard/wg0.conf --log-level=debug -- node /app/server.js
58 |       "
59 |     depends_on:
60 |       - wg-server
61 | 
62 | networks:
63 |   wrapguard:
64 |     driver: bridge
--------------------------------------------------------------------------------
/ipc.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"bufio"
5 | 	"encoding/json"
6 | 	"fmt"
7 | 	"net"
8 | 	"os"
9 | 	"path/filepath"
10 | )
11 | 
12 | // IPCMessage is one newline-delimited JSON message read from the IPC socket.
13 | type IPCMessage struct {
14 | 	Type string `json:"type"` // "CONNECT" or "BIND"
15 | 	FD   int    `json:"fd"`
16 | 	Port int    `json:"port"`
17 | 	Addr string `json:"addr"`
18 | }
19 | 
20 | // IPCServer listens on a per-process Unix domain socket and forwards every
21 | // successfully decoded IPCMessage to a buffered channel.
22 | type IPCServer struct {
23 | 	listener   net.Listener
24 | 	socketPath string
25 | 	msgChan    chan IPCMessage
26 | }
27 | 
28 | // NewIPCServer listens on <os.TempDir()>/wrapguard-<pid>.sock and starts
29 | // accepting client connections in a background goroutine.
30 | func NewIPCServer() (*IPCServer, error) {
31 | 	// Create socket path in temp directory
32 | 	socketPath := filepath.Join(os.TempDir(), fmt.Sprintf("wrapguard-%d.sock", os.Getpid()))
33 | 
34 | 	// Remove a stale socket file from an earlier run. The error is ignored on
35 | 	// purpose: if a stale file could not be removed, Listen below reports it.
36 | 	os.Remove(socketPath)
37 | 
38 | 	listener, err := net.Listen("unix", socketPath)
39 | 	if err != nil {
40 | 		return nil, fmt.Errorf("failed to create IPC socket: %w", err)
41 | 	}
42 | 
43 | 	server := &IPCServer{
44 | 		listener:   listener,
45 | 		socketPath: socketPath,
46 | 		msgChan:    make(chan IPCMessage, 100),
47 | 	}
48 | 
49 | 	// Start accepting connections
50 | 	go server.acceptConnections()
51 | 
52 | 	return server, nil
53 | }
54 | 
55 | // acceptConnections hands each accepted connection to its own goroutine and
56 | // returns as soon as Accept fails.
57 | func (s *IPCServer) acceptConnections() {
58 | 	for {
59 | 		conn, err := s.listener.Accept()
60 | 		if err != nil {
61 | 			// Accept fails once Close has closed the listener (normal shutdown);
62 | 			// log at debug level so an unexpected failure is not silent.
63 | 			logger.Debugf("IPC: accept loop stopped: %v", err)
64 | 			return
65 | 		}
66 | 
67 | 		// Handle connection in background
68 | 		go s.handleConnection(conn)
69 | 	}
70 | }
71 | 
72 | // handleConnection decodes one JSON IPCMessage per line from conn and queues
73 | // it without blocking. Malformed lines and messages that do not fit in the
74 | // channel buffer are logged and dropped; conn is closed on return.
75 | func (s *IPCServer) handleConnection(conn net.Conn) {
76 | 	defer conn.Close()
77 | 
78 | 	scanner := bufio.NewScanner(conn)
79 | 	for scanner.Scan() {
80 | 		var msg IPCMessage
81 | 		if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
82 | 			logger.Warnf("IPC: Failed to parse message: %v", err)
83 | 			continue
84 | 		}
85 | 
86 | 		// Send message to channel (non-blocking)
87 | 		select {
88 | 		case s.msgChan <- msg:
89 | 		default:
90 | 			logger.Warnf("IPC: Message channel full, dropping message")
91 | 		}
92 | 	}
93 | 
94 | 	// Scan also stops on a read error (e.g. a line longer than the scanner's
95 | 	// buffer); report it instead of dropping it silently.
96 | 	if err := scanner.Err(); err != nil {
97 | 		logger.Warnf("IPC: Failed to read from connection: %v", err)
98 | 	}
99 | }
100 | 
101 | // SocketPath returns the filesystem path of the IPC Unix socket.
102 | func (s *IPCServer) SocketPath() string {
103 | 	return s.socketPath
104 | }
105 | 
106 | // MessageChan returns the channel on which decoded messages are delivered.
107 | func (s *IPCServer) MessageChan() <-chan IPCMessage {
108 | 	return s.msgChan
109 | }
110 | 
111 | // Close closes the listener, which ends the accept loop, and removes the
112 | // socket file. Errors from both steps are ignored and it always returns nil.
113 | func (s *IPCServer) Close() error {
114 | 	if s.listener != nil {
115 | 		s.listener.Close()
116 | 	}
117 | 
118 | 	// Clean up socket file
119 | 	if s.socketPath != "" {
120 | 		os.Remove(s.socketPath)
121 | 	}
122 | 
123 | 	return nil
124 | }
125 | 
--------------------------------------------------------------------------------
/logger.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"encoding/json"
5 | 	"fmt"
6 | 	"io"
7 | 	"os"
8 | 	"strings"
9 | 	"sync"
10 | 	"time"
11 | )
12 | 
13 | type LogLevel int
14 | 
15 | const (
16 | 	LogLevelError LogLevel = iota
17 | 	LogLevelWarn
18 | 	LogLevelInfo
19 | 	LogLevelDebug
20 | )
21 | 
22 | func (l LogLevel) String() string {
23 | 	switch l {
24 | 	case LogLevelError:
25 | 		return "error"
26 | 	case LogLevelWarn:
27 | 		return "warn"
28 | 	case LogLevelInfo:
29 | 		return "info"
30 | 	case LogLevelDebug:
31 | 		return "debug"
32 | 	default:
33 | 		return "unknown"
34 | 	}
35 | }
36 | 
37 | func ParseLogLevel(s string) (LogLevel, error) {
38 | 	switch strings.ToLower(s) {
39 | 	case "error":
40 | 		return LogLevelError, nil
41 | 	case "warn", "warning":
42 | 		return LogLevelWarn, nil
43 | 	case "info":
44 | 		return LogLevelInfo, nil
45 | 	case "debug":
46 | 		return LogLevelDebug, nil
47 | 	default:
48 | 		return LogLevelInfo, fmt.Errorf("invalid log level: %s", s)
49 | 	}
50 | }
51 | 
52 | type Logger struct {
53 | 	level  LogLevel
54 | 	output io.Writer
55 | 	mu     sync.Mutex
56 | }
57 | 
58 | type LogEntry struct {
59 | 	Timestamp string `json:"timestamp"`
60 | 	Level     string `json:"level"`
61 | 	Message   string `json:"message"`
62 | }
63 | 
64 | func NewLogger(level LogLevel, output io.Writer) *Logger {
65 | 	return &Logger{
66 | 		level:  level,
67 | 		output: output,
68 | 	}
69 | }
70 | 
71 | func (l *Logger) log(level LogLevel, format string, args ...any) {
72 | 	if level > l.level {
73 | 		return
74 | 	}
75 | 
76 | 	entry := LogEntry{
77 | 		Timestamp: time.Now().UTC().Format(time.RFC3339),
78 | 		Level:     level.String(),
79 | 		Message:   fmt.Sprintf(format, args...),
80 | 	}
81 | 
82 | 	data, _ := json.Marshal(entry) // LogEntry holds only strings, so Marshal cannot fail
83 | 
84 | 	l.mu.Lock()
85 | 	fmt.Fprintf(l.output, "%s\n", data)
86 | 	l.mu.Unlock()
87 | }
88 | 
89 | func (l *Logger) Errorf(format string, args ...any) {
90 | 	l.log(LogLevelError, format, args...)
91 | }
92 | 
93 | func (l *Logger) Warnf(format string, args ...any) {
94 | 	l.log(LogLevelWarn, format, args...)
95 | }
96 | 
97 | func (l *Logger) Infof(format string, args ...any) {
98 | 	l.log(LogLevelInfo, format, args...)
99 | }
100 | 
101 | func (l *Logger) Debugf(format string, args ...any) {
102 | 	l.log(LogLevelDebug, format, args...)
103 | }
104 | 
105 | // Global logger instance
106 | var logger *Logger
107 | 
108 | func init() {
109 | 	// Default logger to stderr with info level
110 | 	logger = NewLogger(LogLevelInfo, os.Stderr)
111 | }
112 | 
113 | func SetGlobalLogger(l *Logger) {
114 | 	logger = l
115 | }
116 | 
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 | 
3 | on:
4 |   push:
5 |     branches: [ main ]
6 |   pull_request:
7 |     branches: [ main ]
8 | 
9 | jobs:
10 |   test:
11 |     name: Test
12 |     runs-on: ubuntu-latest
13 | 
14 |     steps:
15 |       - name: Check out code
16 |         uses: actions/checkout@v4
17 | 
18 |       - name: Set up Go
19 |         uses: actions/setup-go@v5
20 |         with:
21 |           go-version-file: go.mod  # follow go.mod (go 1.24) instead of a hard-coded, older version
22 | 
23 |       - name: Cache Go modules
24 |         uses: actions/cache@v4
25 |         with:
26 |           path: |
27 |             ~/.cache/go-build
28 |             ~/go/pkg/mod
29 |           key: ${{ runner.os }}-go-${{ hashFiles('**/go.mod', '**/go.sum') }}
30 |           restore-keys: |
31 |             ${{ runner.os }}-go-
32 | 
33 |       - name: Download dependencies
34 |         run: go mod download
35 | 
36 |       - name: Verify dependencies
37 |         run: go mod verify
38 | 
39 |       - name: Run tests
40 |         run: go test -v -race -coverprofile=coverage.out ./...
41 | 
42 |       - name: Coverage summary
43 |         run: go tool cover -func=coverage.out
44 | 
45 |       - name: Upload coverage reports to Codecov
46 |         uses: codecov/codecov-action@v4
47 |         with:
48 |           files: ./coverage.out
49 |           flags: unittests
50 |           name: codecov-umbrella
51 |           fail_ci_if_error: false
52 | 
53 |   lint:
54 |     name: Lint
55 |     runs-on: ubuntu-latest
56 | 
57 |     steps:
58 |       - name: Check out code
59 |         uses: actions/checkout@v4
60 | 
61 |       - name: Set up Go
62 |         uses: actions/setup-go@v5
63 |         with:
64 |           go-version-file: go.mod
65 | 
66 |       - name: Run go vet
67 |         run: go vet ./...
68 | 
69 |       - name: Check formatting
70 |         run: |
71 |           if [ "$(gofmt -s -l . | wc -l)" -gt 0 ]; then
72 |             echo "Go files are not formatted:"
73 |             gofmt -s -d .
74 |             exit 1
75 |           fi
76 | 
77 |   build:
78 |     name: Build
79 |     runs-on: ubuntu-latest
80 | 
81 |     steps:
82 |       - name: Check out code
83 |         uses: actions/checkout@v4
84 | 
85 |       - name: Set up Go
86 |         uses: actions/setup-go@v5
87 |         with:
88 |           go-version-file: go.mod
89 | 
90 |       - name: Install build dependencies
91 |         run: sudo apt-get update && sudo apt-get install -y gcc
92 | 
93 |       - name: Build binary
94 |         run: make build
95 | 
96 |       - name: Verify binary exists
97 |         run: |
98 |           ls -la wrapguard
99 |           ls -la libwrapguard.so
100 |           file wrapguard
101 |           file libwrapguard.so
102 | 
103 |       - name: Test binary runs
104 |         run: |
105 |           ./wrapguard --version
106 |           ./wrapguard --help
107 | 
--------------------------------------------------------------------------------
/socks.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"net"
7 | 	"strconv"
8 | 
9 | 	"github.com/armon/go-socks5"
10 | )
11 | 
12 | type SOCKS5Server struct {
13 | 	server   *socks5.Server
14 | 	listener net.Listener
15 | 	port     int
16 | 	tunnel   *Tunnel
17 | }
18 | 
19 | func NewSOCKS5Server(tunnel *Tunnel) (*SOCKS5Server, error) {
20 | 	// Create SOCKS5 server with custom dialer that routes WireGuard IPs through the tunnel
21 | 	socksConfig := &socks5.Config{
22 | 		Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
23 | 			logger.Debugf("SOCKS5 dial request: %s %s", network, addr)
24 | 
25 | 			// Parse the address to check if it's a WireGuard IP
26 | 			host, port, err := net.SplitHostPort(addr)
27 | 			if err != nil {
28 | 				return nil, fmt.Errorf("invalid address format: %w", err)
29 | 			}
30 | 
31 | 			// Check if this is a WireGuard IP that should be routed through the tunnel
32 | 			ip := net.ParseIP(host)
33 | 			if ip != nil {
34 | 				// Use routing engine to find appropriate peer
35 | 				portNum, _ := strconv.Atoi(port)
36 | 				peer, peerIdx := tunnel.router.FindPeerForDestination(ip, portNum, "tcp")
37 | 				if peer != nil {
38 | 					logger.Debugf("Routing %s through WireGuard tunnel via peer %d (endpoint: %s)", addr, peerIdx, peer.Endpoint)
39 | 					return tunnel.DialWireGuard(ctx, network, host, port)
40 | 				}
41 | 			}
42 | 
43 | 			// For non-WireGuard IPs, use normal dialing
44 | 			logger.Debugf("Using normal dial for %s", addr)
45 | 			dialer := &net.Dialer{}
46 | 			conn, err := dialer.DialContext(ctx, network, addr)
47 | 			if err != nil {
48 | 				logger.Debugf("SOCKS5 dial failed for %s: %v", addr, err)
49 | 			} else {
50 | 				logger.Debugf("SOCKS5 dial succeeded for %s", addr)
51 | 			}
52 | 			return conn, err
53 | 		},
54 | 	}
55 | 
56 | 	server, err := socks5.New(socksConfig)
57 | 	if err != nil {
58 | 		return nil, fmt.Errorf("failed to create SOCKS5 server: %w", err)
59 | 	}
60 | 
61 | 	// Listen on localhost for SOCKS5 connections
62 | 	listener, err := net.Listen("tcp", "127.0.0.1:0")
63 | 	if err != nil {
64 | 		return nil, fmt.Errorf("failed to listen for SOCKS5 connections: %w", err)
65 | 	}
66 | 
67 | 	port := listener.Addr().(*net.TCPAddr).Port
68 | 
69 | 	s := &SOCKS5Server{
70 | 		server:   server,
71 | 		listener: listener,
72 | 		port:     port,
73 | 		tunnel:   tunnel,
74 | 	}
75 | 
76 | 	// Start serving in background
77 | 	go func() {
78 | 		if err := server.Serve(listener); err != nil {
79 | 			// Log error but don't crash - server might be shutting down
80 | 			logger.Debugf("SOCKS5 server stopped: %v", err)
81 | 		}
82 | 	}()
83 | 
84 | 	return s, nil
85 | }
86 | 
87 | func (s *SOCKS5Server) Port() int {
88 | 	return s.port
89 | }
90 | 
91 | func (s *SOCKS5Server) Close() error {
92 | 	if s.listener != nil {
93 | 		return s.listener.Close()
94 | 	}
95 | 	return nil
96 | }
97 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: all build clean test test-coverage debug deps fmt lint build-all build-linux build-darwin demo help
2 | 
3 | # Build variables (VERSION is injected into the main.Version string variable, with a "v" prefix, via -X)
4 | GO_MODULE = github.com/puzed/wrapguard
5 | BINARY_NAME = wrapguard
6 | LIBRARY_NAME = libwrapguard.so
7 | VERSION = 1.0.0-dev
8 | 
9 | # Build flags (gcc recipes pass C_BUILD_FLAGS after the source file so -ldl follows the object that needs it)
10 | GO_BUILD_FLAGS = -ldflags="-s -w -X main.Version=v$(VERSION)"
11 | C_BUILD_FLAGS = -shared -fPIC -ldl
12 | 
13 | # Default target
14 | all: build
15 | 
16 | # Build both Go binary and C library
17 | build: $(BINARY_NAME) $(LIBRARY_NAME)
18 | 
19 | # Build Go binary
20 | $(BINARY_NAME): *.go go.mod go.sum
21 | 	@echo "Building Go binary..."
22 | 	go mod tidy
23 | 	go build $(GO_BUILD_FLAGS) -o $(BINARY_NAME) .
24 | 
25 | # Build C library
26 | $(LIBRARY_NAME): lib/intercept.c
27 | 	@echo "Building C library..."
28 | 	gcc -o $(LIBRARY_NAME) lib/intercept.c $(C_BUILD_FLAGS)
29 | 
30 | # Clean build artifacts
31 | clean:
32 | 	@echo "Cleaning build artifacts..."
33 | 	rm -f $(BINARY_NAME) $(LIBRARY_NAME)
34 | 	go clean
35 | 
36 | # Run tests
37 | test:
38 | 	@echo "Running tests..."
39 | 	go test -v ./...
40 | 
41 | # Run tests with coverage
42 | test-coverage:
43 | 	@echo "Running tests with coverage..."
44 | 	go test -cover ./...
45 | 
46 | # Build debug version
47 | debug: GO_BUILD_FLAGS = -ldflags="-X main.Version=v$(VERSION)-debug"
48 | debug: C_BUILD_FLAGS += -g -O0
49 | debug: build
50 | 
51 | # Install dependencies
52 | deps:
53 | 	@echo "Installing dependencies..."
54 | 	go mod download
55 | 
56 | # Format code
57 | fmt:
58 | 	@echo "Formatting Go code..."
59 | 	go fmt ./...
60 | 
61 | # Run linter
62 | lint:
63 | 	@echo "Running linter..."
64 | 	go vet ./...
65 | 
66 | # Build for multiple platforms (build-darwin's .dylib step uses the host gcc, so run it on a macOS host)
67 | build-all: build-linux build-darwin
68 | 
69 | build-linux:
70 | 	@echo "Building for Linux..."
71 | 	GOOS=linux GOARCH=amd64 go build $(GO_BUILD_FLAGS) -o $(BINARY_NAME)-linux-amd64 .
72 | 	gcc -o libwrapguard-linux-amd64.so lib/intercept.c $(C_BUILD_FLAGS)
73 | 
74 | build-darwin:
75 | 	@echo "Building for macOS..."
76 | 	GOOS=darwin GOARCH=amd64 go build $(GO_BUILD_FLAGS) -o $(BINARY_NAME)-darwin-amd64 .
77 | 	GOOS=darwin GOARCH=arm64 go build $(GO_BUILD_FLAGS) -o $(BINARY_NAME)-darwin-arm64 .
78 | 	gcc -o libwrapguard-darwin.dylib lib/intercept.c $(C_BUILD_FLAGS)
79 | 
80 | # Run demo
81 | demo: build
82 | 	@echo "Running demo..."
83 | 	cd demo && ./setup.sh && docker compose up
84 | 
85 | # Help
86 | help:
87 | 	@echo "Available targets:"
88 | 	@echo " all - Build both binary and library (default)"
89 | 	@echo " build - Build both binary and library"
90 | 	@echo " clean - Clean build artifacts"
91 | 	@echo " test - Run tests"
92 | 	@echo " test-coverage - Run tests with coverage"
93 | 	@echo " debug - Build debug version"
94 | 	@echo " deps - Install dependencies"
95 | 	@echo " fmt - Format Go code"
96 | 	@echo " lint - Run linter"
97 | 	@echo " build-all - Build for multiple platforms"
98 | 	@echo " demo - Run demo"
99 | 	@echo " help - Show this help"
--------------------------------------------------------------------------------
/routing_cli_test.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"testing"
5 | )
6 | 
7 | func TestApplyCLIRoutes(t *testing.T) {
8 | 	// Create a test configuration
9 | 	config := &WireGuardConfig{
10 | 		Interface: InterfaceConfig{
11 | 			PrivateKey: "test-private-key",
12 | 			Address:    "10.0.0.2/24",
13 | 		},
14 | 		Peers: []PeerConfig{
15 | 			{
16 | 				PublicKey:  "peer1-public-key",
17 | 				Endpoint:   "192.168.1.100:51820",
18 | 				AllowedIPs: []string{"10.0.0.0/24"},
19 | 			},
20 | 			{
21 | 				PublicKey:  "peer2-public-key",
22 | 				Endpoint:   "192.168.1.101:51820",
23 | 				AllowedIPs: []string{"10.1.0.0/24"},
24 | 			},
25 | 		},
26 | 	}
27 | 
28 | 	// Test exit node
29 | 	err := ApplyCLIRoutes(config, "10.0.0.3", nil)
30 | 	if err != nil {
31 | 		t.Fatalf("Failed to apply exit node: %v", err)
32 | 	}
33 | 
34 | 	// Check that the routing policy was added to the correct peer
35 | 	peer1 := &config.Peers[0]
36 | 	if len(peer1.RoutingPolicies) != 1 {
37 | 		t.Fatalf("Expected 1 routing policy, got %d", len(peer1.RoutingPolicies))
38 | 	}
39 | 
40 | 	policy := peer1.RoutingPolicies[0]
41 | 	if policy.DestinationCIDR != "0.0.0.0/0" {
42 | 		t.Errorf("Expected destination CIDR 0.0.0.0/0, got %s", policy.DestinationCIDR)
43 | 	}
44 | 
45 | 	// Test specific routes
46 | 	routes := []string{"192.168.1.0/24:10.0.0.4", "172.16.0.0/12:10.1.0.5"}
47 | 	err = ApplyCLIRoutes(config, "", routes)
48 | 	if err != nil {
49 | 		t.Fatalf("Failed to apply routes: %v", err)
50 | 	}
51 | 
52 | 	// Check that routes were added to correct peers
53 | 	if len(peer1.RoutingPolicies) != 2 {
54 | 		t.Fatalf("Expected 2 routing policies on peer1, got %d", len(peer1.RoutingPolicies))
55 | 	}
56 | 
57 | 	peer2 := &config.Peers[1]
58 | 	if len(peer2.RoutingPolicies) != 1 {
59 | 		t.Fatalf("Expected 1 routing policy on peer2, got %d", len(peer2.RoutingPolicies))
60 | 	}
61 | 
62 | 	// Verify the specific route on peer2
63 | 	if peer2.RoutingPolicies[0].DestinationCIDR != "172.16.0.0/12" {
64 | 		t.Errorf("Expected destination CIDR 172.16.0.0/12, got %s", peer2.RoutingPolicies[0].DestinationCIDR)
65 | 	}
66 | }
67 | 
68 | func TestApplyCLIRoutesErrors(t *testing.T) {
69 | 	config := &WireGuardConfig{
70 | 		Interface: InterfaceConfig{
71 | 			PrivateKey: "test-private-key",
72 | 			Address:    "10.0.0.2/24",
73 | 		},
74 | 		Peers: []PeerConfig{
75 | 			{
76 | 				PublicKey:  "peer1-public-key",
77 | 				Endpoint:   "192.168.1.100:51820",
78 | 				AllowedIPs: []string{"10.0.0.0/24"},
79 | 			},
80 | 		},
81 | 	}
82 | 
83 | 	// Test invalid route format
84 | 	err := ApplyCLIRoutes(config, "", []string{"invalid-route"})
85 | 	if err == nil {
86 | 		t.Error("Expected
error for invalid route format") 87 | } 88 | 89 | // Test invalid CIDR 90 | err = ApplyCLIRoutes(config, "", []string{"invalid-cidr:10.0.0.3"}) 91 | if err == nil { 92 | t.Error("Expected error for invalid CIDR") 93 | } 94 | 95 | // Test peer IP not found 96 | err = ApplyCLIRoutes(config, "", []string{"192.168.1.0/24:192.168.1.1"}) 97 | if err == nil { 98 | t.Error("Expected error for peer IP not in any AllowedIPs") 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /demo/keygen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | echo "🔑 Generating WireGuard keys..." 5 | 6 | # Create keys directory 7 | mkdir -p /workspace/keys 8 | 9 | # Generate keys for each peer 10 | for peer in wg-server node-server-1 node-server-2; do 11 | echo "Generating keys for $peer..." 12 | wg genkey > /workspace/keys/${peer}_private.key 13 | wg pubkey < /workspace/keys/${peer}_private.key > /workspace/keys/${peer}_public.key 14 | echo "✅ Keys generated for $peer" 15 | done 16 | 17 | # Generate preshared keys for each client-server pair 18 | echo "Generating preshared keys..." 19 | wg genpsk > /workspace/keys/server_node1_preshared.key 20 | wg genpsk > /workspace/keys/server_node2_preshared.key 21 | 22 | # Set proper permissions 23 | chmod 600 /workspace/keys/*_private.key 24 | chmod 644 /workspace/keys/*_public.key 25 | chmod 600 /workspace/keys/*_preshared.key 26 | 27 | echo "" 28 | echo "📋 Generated keys:" 29 | echo "==================" 30 | for peer in wg-server node-server-1 node-server-2; do 31 | echo "$peer:" 32 | echo " Public: $(cat /workspace/keys/${peer}_public.key)" 33 | echo "" 34 | done 35 | 36 | # Generate configuration files 37 | echo "📝 Generating WireGuard configurations..." 
38 | 39 | # Read keys 40 | WG_SERVER_PRIVATE=$(cat /workspace/keys/wg-server_private.key) 41 | WG_SERVER_PUBLIC=$(cat /workspace/keys/wg-server_public.key) 42 | NODE1_PRIVATE=$(cat /workspace/keys/node-server-1_private.key) 43 | NODE1_PUBLIC=$(cat /workspace/keys/node-server-1_public.key) 44 | NODE2_PRIVATE=$(cat /workspace/keys/node-server-2_private.key) 45 | NODE2_PUBLIC=$(cat /workspace/keys/node-server-2_public.key) 46 | PSK_1=$(cat /workspace/keys/server_node1_preshared.key) 47 | PSK_2=$(cat /workspace/keys/server_node2_preshared.key) 48 | 49 | # Create configs directory 50 | mkdir -p /workspace/configs 51 | 52 | # Generate wg-server config 53 | cat > /workspace/configs/wg-server.conf < /workspace/configs/node-server-1.conf < /workspace/configs/node-server-2.conf < 16 | Route = CIDR:PROTOCOL:PORTS 17 | ``` 18 | 19 | ### Route Format 20 | 21 | - `CIDR`: Destination network in CIDR notation (e.g., `192.168.1.0/24`, `0.0.0.0/0`) 22 | - `PROTOCOL`: `tcp`, `udp`, or `any` (optional, defaults to `any`; fields are positional, so it must be given — use `any` — whenever `PORTS` is specified, e.g. `10.0.0.0/8:any:8080-9000`) 23 | - `PORTS`: Port or port range (optional, defaults to all ports) 24 | - Single port: `80` 25 | - Port range: `8080-9000` 26 | - Multiple ports: not supported in a single route (`80,443` is not accepted); add one `Route` line per port or range 27 | 28 | ## Examples 29 | 30 | ### Basic Routing by Destination Network 31 | 32 | ```ini 33 | [Peer] 34 | PublicKey = peer1_public_key 35 | Endpoint = vpn1.example.com:51820 36 | AllowedIPs = 0.0.0.0/0 37 | # Route all traffic through this peer by default 38 | Route = 0.0.0.0/0 39 | 40 | [Peer] 41 | PublicKey = peer2_public_key 42 | Endpoint = vpn2.example.com:51820 43 | AllowedIPs = 192.168.0.0/16 44 | # Route specific subnet through this peer 45 | Route = 192.168.1.0/24 46 | ``` 47 | 48 | ### Protocol and Port-Based Routing 49 | 50 | ```ini 51 | [Peer] 52 | PublicKey = web_peer_public_key 53 | Endpoint = web-vpn.example.com:51820 54 | AllowedIPs = 0.0.0.0/0 55 | # Route web traffic through this peer 56 | Route = 0.0.0.0/0:tcp:80 57 | Route = 0.0.0.0/0:tcp:443 58 | [Peer] 59 | PublicKey = dev_peer_public_key 60 | Endpoint = dev-vpn.example.com:51820
61 | AllowedIPs = 0.0.0.0/0 62 | # Route development services through this peer 63 | Route = 0.0.0.0/0:tcp:3000-4000 64 | Route = 0.0.0.0/0:tcp:8080-9000 65 | ``` 66 | 67 | ### Complex Multi-Peer Setup 68 | 69 | ```ini 70 | [Interface] 71 | PrivateKey = your_private_key 72 | Address = 10.150.0.2/24 73 | 74 | # Peer 1: General purpose VPN 75 | [Peer] 76 | PublicKey = general_vpn_public_key 77 | Endpoint = general-vpn.example.com:51820 78 | AllowedIPs = 0.0.0.0/0 79 | Route = 0.0.0.0/0 # Default route for all traffic 80 | 81 | # Peer 2: Corporate network access 82 | [Peer] 83 | PublicKey = corp_vpn_public_key 84 | Endpoint = corp-vpn.example.com:51820 85 | AllowedIPs = 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 86 | # Route corporate networks 87 | Route = 10.0.0.0/8 88 | Route = 172.16.0.0/12 89 | Route = 192.168.0.0/16 90 | # Route specific services 91 | Route = 0.0.0.0/0:tcp:22 # SSH through corporate VPN 92 | Route = 0.0.0.0/0:tcp:3389 # RDP through corporate VPN 93 | 94 | # Peer 3: Streaming and gaming 95 | [Peer] 96 | PublicKey = gaming_vpn_public_key 97 | Endpoint = gaming-vpn.example.com:51820 98 | AllowedIPs = 0.0.0.0/0 99 | # Route gaming and streaming ports 100 | Route = 0.0.0.0/0:udp:5000-6000 # Gaming ports 101 | Route = 0.0.0.0/0:tcp:1935 # RTMP streaming 102 | ``` 103 | 104 | ## Routing Priority 105 | 106 | 1. **Most specific CIDR wins**: `/32` routes take precedence over `/24`, which take precedence over `/0` 107 | 2. **Order matters**: For same CIDR specificity, routes listed first have higher priority 108 | 3. **Protocol matching**: Protocol-specific routes only match their protocol 109 | 4. **Port matching**: Port-specific routes only match connections to those ports 110 | 111 | ## How It Works 112 | 113 | 1. When a connection is initiated, WrapGuard checks the destination IP, protocol, and port 114 | 2. It searches through all configured routing policies to find the best match 115 | 3. 
Traffic is routed through the WireGuard peer with the matching policy 116 | 4. If no policy matches, it falls back to checking AllowedIPs 117 | 5. If still no match, the default peer (first one with `0.0.0.0/0`) is used 118 | 119 | ## Testing Your Configuration 120 | 121 | To test your routing configuration: 122 | 123 | ```bash 124 | # Check which peer would handle specific traffic 125 | wrapguard --config=policy-routing.conf -- curl https://example.com 126 | wrapguard --config=policy-routing.conf -- ssh user@192.168.1.100 127 | wrapguard --config=policy-routing.conf -- nc -v 10.0.0.5 3000 128 | ``` 129 | 130 | Enable debug logging to see routing decisions: 131 | 132 | ```bash 133 | wrapguard --config=policy-routing.conf --log-level=debug -- your_command 134 | ``` 135 | 136 | ## Limitations 137 | 138 | - Currently only supports IPv4 routing 139 | - Maximum of one route per line (no comma-separated CIDRs) 140 | - Port ranges in route specifications don't support comma-separated values -------------------------------------------------------------------------------- /routing_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "net" 5 | "testing" 6 | ) 7 | 8 | func TestParsePortRange(t *testing.T) { 9 | tests := []struct { 10 | input string 11 | expected PortRange 12 | hasError bool 13 | }{ 14 | {"80", PortRange{Start: 80, End: 80}, false}, 15 | {"8080-9000", PortRange{Start: 8080, End: 9000}, false}, 16 | {"any", PortRange{Start: 1, End: 65535}, false}, 17 | {"", PortRange{Start: 1, End: 65535}, false}, 18 | {"invalid", PortRange{}, true}, 19 | {"80-70", PortRange{}, true}, 20 | {"0-100", PortRange{}, true}, 21 | {"100-70000", PortRange{}, true}, 22 | } 23 | 24 | for _, test := range tests { 25 | result, err := ParsePortRange(test.input) 26 | if test.hasError { 27 | if err == nil { 28 | t.Errorf("Expected error for input %s, but got none", test.input) 29 | } 30 | } else { 31 | if err != nil { 
32 | t.Errorf("Unexpected error for input %s: %v", test.input, err) 33 | } 34 | if result != test.expected { 35 | t.Errorf("For input %s, expected %v but got %v", test.input, test.expected, result) 36 | } 37 | } 38 | } 39 | } 40 | 41 | func TestParseRoutingPolicy(t *testing.T) { 42 | tests := []struct { 43 | input string 44 | priority int 45 | expected RoutingPolicy 46 | hasError bool 47 | }{ 48 | { 49 | "192.168.1.0/24", 50 | 0, 51 | RoutingPolicy{ 52 | DestinationCIDR: "192.168.1.0/24", 53 | Protocol: "any", 54 | PortRange: PortRange{Start: 1, End: 65535}, 55 | Priority: 0, 56 | }, 57 | false, 58 | }, 59 | { 60 | "0.0.0.0/0:tcp:80", 61 | 1, 62 | RoutingPolicy{ 63 | DestinationCIDR: "0.0.0.0/0", 64 | Protocol: "tcp", 65 | PortRange: PortRange{Start: 80, End: 80}, 66 | Priority: 1, 67 | }, 68 | false, 69 | }, 70 | { 71 | "10.0.0.0/8:udp:5000-6000", 72 | 2, 73 | RoutingPolicy{ 74 | DestinationCIDR: "10.0.0.0/8", 75 | Protocol: "udp", 76 | PortRange: PortRange{Start: 5000, End: 6000}, 77 | Priority: 2, 78 | }, 79 | false, 80 | }, 81 | { 82 | "invalid-cidr", 83 | 0, 84 | RoutingPolicy{}, 85 | true, 86 | }, 87 | { 88 | "192.168.1.0/24:invalid-protocol:80", 89 | 0, 90 | RoutingPolicy{}, 91 | true, 92 | }, 93 | } 94 | 95 | for _, test := range tests { 96 | result, err := ParseRoutingPolicy(test.input, test.priority) 97 | if test.hasError { 98 | if err == nil { 99 | t.Errorf("Expected error for input %s, but got none", test.input) 100 | } 101 | } else { 102 | if err != nil { 103 | t.Errorf("Unexpected error for input %s: %v", test.input, err) 104 | } 105 | if result == nil { 106 | t.Errorf("Expected non-nil result for input %s", test.input) 107 | } else if *result != test.expected { 108 | t.Errorf("For input %s, expected %+v but got %+v", test.input, test.expected, *result) 109 | } 110 | } 111 | } 112 | } 113 | 114 | func TestRoutingEngine(t *testing.T) { 115 | // Create a test configuration 116 | config := &WireGuardConfig{ 117 | Interface: InterfaceConfig{ 118 | 
Address: "10.150.0.2/24", 119 | }, 120 | Peers: []PeerConfig{ 121 | { 122 | PublicKey: "peer1", 123 | Endpoint: "vpn1.example.com:51820", 124 | AllowedIPs: []string{"0.0.0.0/0"}, 125 | RoutingPolicies: []RoutingPolicy{ 126 | { 127 | DestinationCIDR: "0.0.0.0/0", 128 | Protocol: "any", 129 | PortRange: PortRange{Start: 1, End: 65535}, 130 | Priority: 0, 131 | }, 132 | }, 133 | }, 134 | { 135 | PublicKey: "peer2", 136 | Endpoint: "vpn2.example.com:51820", 137 | AllowedIPs: []string{"192.168.0.0/16", "172.16.0.0/12"}, 138 | RoutingPolicies: []RoutingPolicy{ 139 | { 140 | DestinationCIDR: "192.168.1.0/24", 141 | Protocol: "tcp", 142 | PortRange: PortRange{Start: 80, End: 443}, 143 | Priority: 1, 144 | }, 145 | { 146 | DestinationCIDR: "0.0.0.0/0", 147 | Protocol: "tcp", 148 | PortRange: PortRange{Start: 8080, End: 9000}, 149 | Priority: 2, 150 | }, 151 | }, 152 | }, 153 | { 154 | PublicKey: "peer3", 155 | Endpoint: "dev-vpn.example.com:51820", 156 | AllowedIPs: []string{"10.0.0.0/8"}, 157 | RoutingPolicies: []RoutingPolicy{ 158 | { 159 | DestinationCIDR: "10.0.0.0/8", 160 | Protocol: "any", 161 | PortRange: PortRange{Start: 1, End: 65535}, 162 | Priority: 0, 163 | }, 164 | }, 165 | }, 166 | }, 167 | } 168 | 169 | engine := NewRoutingEngine(config) 170 | 171 | tests := []struct { 172 | name string 173 | dstIP string 174 | dstPort int 175 | protocol string 176 | expectedPeer int 177 | }{ 178 | {"General traffic", "8.8.8.8", 53, "udp", 0}, 179 | {"HTTP to 192.168.1.x", "192.168.1.100", 80, "tcp", 1}, 180 | {"HTTPS to 192.168.1.x", "192.168.1.100", 443, "tcp", 1}, 181 | {"Port 8080 to any IP", "1.2.3.4", 8080, "tcp", 1}, 182 | {"Development network", "10.1.2.3", 3000, "tcp", 2}, 183 | {"SSH to 192.168.1.x (no specific rule)", "192.168.1.100", 22, "tcp", 0}, 184 | {"UDP to port 8080 (TCP-only rule)", "1.2.3.4", 8080, "udp", 0}, 185 | } 186 | 187 | for _, test := range tests { 188 | t.Run(test.name, func(t *testing.T) { 189 | ip := net.ParseIP(test.dstIP) 190 | if ip == nil 
{ 191 | t.Fatalf("Failed to parse IP: %s", test.dstIP) 192 | } 193 | 194 | peer, peerIdx := engine.FindPeerForDestination(ip, test.dstPort, test.protocol) 195 | if peerIdx != test.expectedPeer { 196 | t.Errorf("Expected peer %d, but got peer %d", test.expectedPeer, peerIdx) 197 | } 198 | if test.expectedPeer >= 0 && peer == nil { 199 | t.Errorf("Expected non-nil peer for peer index %d", test.expectedPeer) 200 | } 201 | }) 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WrapGuard - Userspace WireGuard Proxy 2 | 3 | WrapGuard enables any application to transparently route ALL network traffic through a WireGuard VPN without requiring container privileges or kernel modules. 4 | 5 | ## Features 6 | 7 | - **Pure Userspace**: No TUN interface creation, no NET_ADMIN capability needed 8 | - **Transparent Interception**: Uses LD_PRELOAD to intercept all network calls 9 | - **Bidirectional Support**: Both incoming and outgoing connections work 10 | - **Standard Config**: Uses standard WireGuard configuration files 11 | 12 | ## Installation 13 | 14 | ### Pre-compiled Binaries 15 | 16 | Download pre-compiled binaries for Linux and macOS from the [releases page](https://github.com/puzed/wrapguard/releases). 17 | 18 | **No additional dependencies required** - WrapGuard is a single binary that includes everything needed to create WireGuard connections. You don't need WireGuard installed on your host machine, kernel modules, or any other VPN software. 
19 | 20 | ### Building from Source 21 | 22 | ```bash 23 | make build 24 | ``` 25 | 26 | This will create: 27 | - `wrapguard` - The main executable 28 | - `libwrapguard.so` - The LD_PRELOAD library; `wrapguard` loads it from its own directory at runtime, so keep both files together 29 | 30 | ## Usage 31 | 32 | ```bash 33 | # Route incoming connections through WireGuard 34 | wrapguard --config=~/wg0.conf -- node -e 'http.createServer((_, res) => res.end("hello")).listen(8080)' 35 | 36 | # Route outgoing connections through WireGuard 37 | wrapguard --config=~/wg1.conf -- curl http://10.0.0.3:8080 38 | 39 | # Use an exit node (route all traffic through a specific peer) 40 | wrapguard --config=~/wg0.conf --exit-node=10.150.0.3 -- curl https://icanhazip.com 41 | 42 | # Route specific subnets through different peers 43 | wrapguard --config=~/wg0.conf \ 44 | --route=192.168.0.0/16:10.150.0.3 \ 45 | --route=172.16.0.0/12:10.150.0.4 \ 46 | -- curl https://internal.corp.com 47 | 48 | # With debug logging to console 49 | wrapguard --config=~/wg0.conf --log-level=debug -- curl https://icanhazip.com 50 | 51 | # With logging to file 52 | wrapguard --config=~/wg0.conf --log-level=info --log-file=/tmp/wrapguard.log -- curl https://icanhazip.com 53 | ``` 54 | 55 | ## Routing 56 | 57 | WrapGuard supports policy-based routing to direct traffic through specific WireGuard peers.
58 | 59 | ### Exit Node 60 | 61 | Use the `--exit-node` option to route all traffic through a specific peer (like a traditional VPN): 62 | 63 | ```bash 64 | wrapguard --config=~/wg0.conf --exit-node=10.150.0.3 -- curl https://example.com 65 | ``` 66 | 67 | ### Policy-Based Routing 68 | 69 | Use the `--route` option to route specific subnets through different peers: 70 | 71 | ```bash 72 | # Route corporate traffic through one peer, internet through another 73 | wrapguard --config=~/wg0.conf \ 74 | --route=192.168.0.0/16:10.150.0.3 \ 75 | --route=0.0.0.0/0:10.150.0.4 \ 76 | -- ssh internal.corp.com 77 | ``` 78 | 79 | ### Configuration File Routing 80 | 81 | You can also define routes in your WireGuard configuration file: 82 | 83 | ```ini 84 | [Peer] 85 | PublicKey = ... 86 | AllowedIPs = 10.150.0.0/24 87 | # Route all traffic through this peer 88 | Route = 0.0.0.0/0 89 | # Or route specific subnets 90 | Route = 192.168.0.0/16 91 | Route = 172.16.0.0/12:tcp:443 92 | ``` 93 | 94 | ## Logging 95 | 96 | WrapGuard provides structured JSON logging with configurable levels and output destinations. 97 | 98 | ### Logging Options 99 | 100 | - `--log-level=` - Set logging level (error, warn, info, debug). 
Default: info 101 | - `--log-file=FILE` - Write logs to file instead of terminal 102 | 103 | ### Log Levels 104 | 105 | - `error` - Only critical errors 106 | - `warn` - Warnings and errors 107 | - `info` - General information, warnings, and errors (default) 108 | - `debug` - Detailed debugging information 109 | 110 | ### Log Format 111 | 112 | All logs are output in structured JSON format with timestamps: 113 | 114 | ```json 115 | {"timestamp":"2025-05-26T10:00:00Z","level":"info","message":"WrapGuard v1.0.0-dev initialized"} 116 | {"timestamp":"2025-05-26T10:00:00Z","level":"info","message":"Config: example-wg0.conf"} 117 | {"timestamp":"2025-05-26T10:00:00Z","level":"info","message":"Interface: 10.2.0.2/32"} 118 | {"timestamp":"2025-05-26T10:00:00Z","level":"info","message":"Peer endpoint: 192.168.1.8:51820"} 119 | {"timestamp":"2025-05-26T10:00:00Z","level":"info","message":"Launching: [curl https://icanhazip.com]"} 120 | ``` 121 | 122 | When `--log-file` is specified, all logs are written to the file and nothing appears on the terminal. 123 | 124 | ## Configuration 125 | 126 | WrapGuard uses standard WireGuard configuration files: 127 | 128 | ```ini 129 | [Interface] 130 | PrivateKey = your_private_key 131 | Address = 10.0.0.2/24 132 | 133 | [Peer] 134 | PublicKey = server_public_key 135 | Endpoint = server.example.com:51820 136 | AllowedIPs = 0.0.0.0/0 137 | PersistentKeepalive = 25 138 | ``` 139 | 140 | ## How It Works 141 | 142 | 1. **Main Process**: Parses config, initializes WireGuard userspace implementation 143 | 2. **LD_PRELOAD Library**: Intercepts network system calls (socket, connect, send, recv, etc.) 144 | 3. **Virtual Network Stack**: Routes packets between intercepted connections and WireGuard tunnel 145 | 4.
**Memory-based TUN**: No kernel interface needed, packets processed entirely in memory 146 | 147 | ## Limitations 148 | 149 | - Linux and macOS only (Windows is not supported) 150 | - TCP and UDP protocols only 151 | - Performance overhead due to userspace packet processing 152 | 153 | ## Development 154 | 155 | ### Running Tests 156 | 157 | WrapGuard includes comprehensive unit tests for all core functionality: 158 | 159 | ```bash 160 | # Run all tests 161 | go test -v ./... 162 | 163 | # Run tests with coverage 164 | go test -cover ./... 165 | ``` 166 | 167 | ### Building 168 | 169 | ```bash 170 | # Build the main binary 171 | make build 172 | 173 | # Build with debug information 174 | make debug 175 | 176 | # Clean build artifacts 177 | make clean 178 | ``` 179 | 180 | ## Demo 181 | 182 | WrapGuard includes a comprehensive Docker-based demo that shows Node.js applications communicating through a WireGuard tunnel without requiring root privileges or kernel modules. 183 | 184 | The demo consists of: 185 | - A WireGuard server container 186 | - Two Node.js HTTP servers wrapped with WrapGuard 187 | - Cross-server communication through the WireGuard tunnel 188 | 189 | ### Running the Demo 190 | 191 | ```bash 192 | cd demo 193 | ./setup.sh 194 | docker compose up --build 195 | ``` 196 | 197 | ## Testing 198 | 199 | ```bash 200 | # Test outgoing connection 201 | wrapguard --config=example-wg0.conf -- curl https://example.com 202 | 203 | # Test incoming connection 204 | wrapguard --config=example-wg0.conf -- python3 -m http.server 8080 205 | ``` 206 | -------------------------------------------------------------------------------- /routing.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "net/netip" 7 | "strconv" 8 | "strings" 9 | ) 10 | 11 | // RoutingPolicy defines a policy for routing traffic through a specific peer 12 | type RoutingPolicy struct { 13 | DestinationCIDR 
string // e.g., "192.168.1.0/24" or "0.0.0.0/0" 14 | Protocol string // "tcp", "udp", or "any" 15 | PortRange PortRange // Port range for the policy 16 | Priority int // Higher priority policies are evaluated first 17 | } 18 | 19 | // PortRange represents a range of ports 20 | type PortRange struct { 21 | Start int 22 | End int 23 | } 24 | 25 | // RoutingEngine manages routing decisions for WireGuard peers 26 | type RoutingEngine struct { 27 | peers []PeerConfig 28 | routeTable map[string][]int // CIDR -> peer indices 29 | allowedIPs map[int][]netip.Prefix // peer index -> allowed IP prefixes 30 | } 31 | 32 | // NewRoutingEngine creates a new routing engine from the WireGuard configuration 33 | func NewRoutingEngine(config *WireGuardConfig) *RoutingEngine { 34 | engine := &RoutingEngine{ 35 | peers: config.Peers, 36 | routeTable: make(map[string][]int), 37 | allowedIPs: make(map[int][]netip.Prefix), 38 | } 39 | 40 | // Build routing table from AllowedIPs 41 | for peerIdx, peer := range config.Peers { 42 | for _, allowedIP := range peer.AllowedIPs { 43 | prefix, err := netip.ParsePrefix(allowedIP) 44 | if err != nil { 45 | if logger != nil { 46 | logger.Warnf("Invalid AllowedIP %s for peer %d: %v", allowedIP, peerIdx, err) 47 | } 48 | continue 49 | } 50 | engine.allowedIPs[peerIdx] = append(engine.allowedIPs[peerIdx], prefix) 51 | } 52 | 53 | // Process routing policies 54 | for _, policy := range peer.RoutingPolicies { 55 | if existingPeers, exists := engine.routeTable[policy.DestinationCIDR]; exists { 56 | engine.routeTable[policy.DestinationCIDR] = append(existingPeers, peerIdx) 57 | } else { 58 | engine.routeTable[policy.DestinationCIDR] = []int{peerIdx} 59 | } 60 | } 61 | } 62 | 63 | return engine 64 | } 65 | 66 | // FindPeerForDestination finds the appropriate peer for routing to a destination 67 | func (r *RoutingEngine) FindPeerForDestination(dstIP net.IP, dstPort int, protocol string) (*PeerConfig, int) { 68 | // Convert to netip.Addr for easier comparison 
69 | var addr netip.Addr 70 | if dstIP.To4() != nil { 71 | // Ensure we use IPv4 representation 72 | addr, _ = netip.AddrFromSlice(dstIP.To4()) 73 | } else { 74 | addr, _ = netip.AddrFromSlice(dstIP) 75 | } 76 | if !addr.IsValid() { 77 | return nil, -1 78 | } 79 | 80 | // First, check routing policies. Peers and their policies are walked in 81 | // configuration order rather than by ranging over the routeTable map, 82 | // whose iteration order Go randomizes, so the result is deterministic: 83 | // the most specific CIDR wins, then the higher Priority, and on a full 84 | // tie the peer/policy listed first in the configuration. 85 | bestPeer := -1 86 | bestPriority := -1 87 | bestSpecificity := -1 88 | 89 | for peerIdx := range r.peers { 90 | peer := &r.peers[peerIdx] 91 | 92 | for _, policy := range peer.RoutingPolicies { 93 | prefix, err := netip.ParsePrefix(policy.DestinationCIDR) 94 | if err != nil || !prefix.Contains(addr) { 95 | continue 96 | } 97 | 98 | // Check protocol match 99 | if policy.Protocol != "any" && policy.Protocol != protocol { 100 | continue 101 | } 102 | 103 | // Check port range; a dstPort <= 0 skips the port check entirely 104 | if dstPort > 0 && (dstPort < policy.PortRange.Start || dstPort > policy.PortRange.End) { 105 | continue 106 | } 107 | 108 | // Strict comparisons keep the earlier match on a full tie 109 | specificity := prefix.Bits() 110 | if specificity > bestSpecificity || 111 | (specificity == bestSpecificity && policy.Priority > bestPriority) { 112 | bestPeer = peerIdx 113 | bestPriority = policy.Priority 114 | bestSpecificity = specificity 115 | } 116 | } 117 | } 118 | 119 | if bestPeer >= 0 { 120 | return &r.peers[bestPeer], bestPeer 121 | } 122 | 123 | // If no routing policy matched, fall back to AllowedIPs with a 124 | // longest-prefix match (as WireGuard's cryptokey routing does); on equal 125 | // prefix length the peer listed first in the configuration wins. 126 | bestSpecificity = -1 127 | for peerIdx := range r.peers { 128 | for _, prefix := range r.allowedIPs[peerIdx] { 129 | if prefix.Contains(addr) && prefix.Bits() > bestSpecificity { 130 | bestPeer = peerIdx 131 | bestSpecificity = prefix.Bits() 132 | } 133 | } 134 | } 135 | 136 | if bestPeer >= 0 { 137 | return &r.peers[bestPeer], bestPeer 138 | } 139 | 140 | return nil, -1 141 | } 142 | 143 | // ParsePortRange parses a port range string like "80", "8080-9000", or "any". 144 | // Comma-separated port lists such as "80,443" are rejected; callers needing 145 | // several disjoint ports must create one policy per port or range. 146 | func
ParsePortRange(portStr string) (PortRange, error) { 147 | if portStr == "" || portStr == "any" { 148 | return PortRange{Start: 1, End: 65535}, nil 149 | } 150 | 151 | if strings.Contains(portStr, "-") { 152 | parts := strings.Split(portStr, "-") 153 | if len(parts) != 2 { 154 | return PortRange{}, fmt.Errorf("invalid port range format: %s", portStr) 155 | } 156 | 157 | start, err := strconv.Atoi(strings.TrimSpace(parts[0])) 158 | if err != nil { 159 | return PortRange{}, fmt.Errorf("invalid start port: %s", parts[0]) 160 | } 161 | 162 | end, err := strconv.Atoi(strings.TrimSpace(parts[1])) 163 | if err != nil { 164 | return PortRange{}, fmt.Errorf("invalid end port: %s", parts[1]) 165 | } 166 | 167 | if start > end || start < 1 || end > 65535 { 168 | return PortRange{}, fmt.Errorf("invalid port range: %d-%d", start, end) 169 | } 170 | 171 | return PortRange{Start: start, End: end}, nil 172 | } 173 | 174 | // Single port 175 | port, err := strconv.Atoi(strings.TrimSpace(portStr)) 176 | if err != nil { 177 | return PortRange{}, fmt.Errorf("invalid port: %s", portStr) 178 | } 179 | 180 | if port < 1 || port > 65535 { 181 | return PortRange{}, fmt.Errorf("port out of range: %d", port) 182 | } 183 | 184 | return PortRange{Start: port, End: port}, nil 185 | } 186 | 187 | // ParseRoutingPolicy parses a routing policy string of the form "CIDR", 188 | // "CIDR:protocol" or "CIDR:protocol:ports" (ports as accepted by ParsePortRange; no comma lists). 189 | // Examples: "192.168.1.0/24", "0.0.0.0/0:tcp:443", "10.0.0.0/8:any:8080-9000" 190 | func ParseRoutingPolicy(policyStr string, priority int) (*RoutingPolicy, error) { 191 | parts := strings.Split(policyStr, ":") 192 | if len(parts) == 0 || parts[0] == "" { 193 | return nil, fmt.Errorf("empty routing policy") 194 | } 195 | if len(parts) > 3 { 196 | return nil, fmt.Errorf("invalid routing policy format: %s (expected CIDR[:protocol[:ports]])", policyStr) 197 | } 198 | policy := &RoutingPolicy{ 199 | DestinationCIDR: parts[0], 200 | Protocol: "any", 201 | PortRange: PortRange{Start: 1, End: 65535}, 202 | Priority: priority, 203 | } 204 | // Validate CIDR 205 | if _, err := netip.ParsePrefix(policy.DestinationCIDR); err != 
nil { 206 | return nil, fmt.Errorf("invalid CIDR: %s", policy.DestinationCIDR) 207 | } 208 | 209 | if len(parts) > 1 { 210 | // Protocol specified 211 | protocol := strings.ToLower(parts[1]) 212 | if protocol != "tcp" && protocol != "udp" && protocol != "any" { 213 | return nil, fmt.Errorf("invalid protocol: %s", protocol) 214 | } 215 | policy.Protocol = protocol 216 | } 217 | 218 | if len(parts) > 2 { 219 | // Port range specified 220 | portRange, err := ParsePortRange(parts[2]) 221 | if err != nil { 222 | return nil, err 223 | } 224 | policy.PortRange = portRange 225 | } 226 | 227 | return policy, nil 228 | } 229 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "io" 8 | "os" 9 | "os/exec" 10 | "os/signal" 11 | "path/filepath" 12 | "strings" 13 | "syscall" 14 | "time" 15 | ) 16 | 17 | var version = "1.0.0-dev" 18 | 19 | func printUsage() { 20 | help := fmt.Sprintf(` 21 | ╦ ╦┬─┐┌─┐┌─┐╔═╗┬ ┬┌─┐┬─┐┌┬┐ 22 | ║║║├┬┘├─┤├─┘║ ╦│ │├─┤├┬┘ ││ 23 | ╚╩╝┴└─┴ ┴┴ ╚═╝└─┘┴ ┴┴└──┴┘ %s 24 | 25 | 🔒 Userspace WireGuard proxy for transparent network tunneling 26 | 27 | `, version) 28 | 29 | help += "\033[33mUSAGE:\033[0m\n" 30 | help += " wrapguard --config= -- [args...]\n\n" 31 | 32 | help += "\033[33mEXAMPLES:\033[0m\n" 33 | help += " \033[36m# Check your tunneled IP address\033[0m\n" 34 | help += " wrapguard --config=wg0.conf -- curl https://icanhazip.com\n\n" 35 | 36 | help += " \033[36m# Run a web server accessible through WireGuard\033[0m\n" 37 | help += " wrapguard --config=wg0.conf -- python3 -m http.server 8080\n\n" 38 | 39 | help += " \033[36m# Tunnel Node.js applications\033[0m\n" 40 | help += " wrapguard --config=wg0.conf -- node app.js\n\n" 41 | 42 | help += " \033[36m# Interactive shell with tunneled network\033[0m\n" 43 | help += " wrapguard --config=wg0.conf -- bash\n\n" 44 | 
45 | help += "\033[33mOPTIONS:\033[0m\n" 46 | help += " --config= Path to WireGuard configuration file\n" 47 | help += " --exit-node= Route all traffic through specified peer IP\n" 48 | help += " --route= Add routing policy (CIDR:peerIP)\n" 49 | help += " --log-level= Set log level (error, warn, info, debug)\n" 50 | help += " --log-file= Set file to write logs to (default: terminal)\n" 51 | help += " --help Show this help message\n" 52 | help += " --version Show version information\n\n" 53 | 54 | help += "\033[33mFEATURES:\033[0m\n" 55 | help += " ✓ No root/sudo required\n" 56 | help += " ✓ No kernel modules needed\n" 57 | help += " ✓ Works in containers\n" 58 | help += " ✓ Transparent to applications\n" 59 | help += " ✓ Standard WireGuard configs\n\n" 60 | 61 | help += "\033[33mCONFIG EXAMPLE:\033[0m\n" 62 | help += " [Interface]\n" 63 | help += " PrivateKey = \n" 64 | help += " Address = 10.0.0.2/24\n\n" 65 | 66 | help += " [Peer]\n" 67 | help += " PublicKey = \n" 68 | help += " Endpoint = vpn.example.com:51820\n" 69 | help += " AllowedIPs = 0.0.0.0/0\n\n" 70 | 71 | help += "\033[90mMore info: https://github.com/puzed/wrapguard\033[0m\n\n" 72 | 73 | os.Stderr.WriteString(help) 74 | } 75 | 76 | func main() { 77 | var configPath string 78 | var showHelp bool 79 | var showVersion bool 80 | var logLevelStr string 81 | var logFile string 82 | var exitNode string 83 | var routes []string 84 | flag.StringVar(&configPath, "config", "", "Path to WireGuard configuration file") 85 | flag.BoolVar(&showHelp, "help", false, "Show help message") 86 | flag.BoolVar(&showVersion, "version", false, "Show version information") 87 | flag.StringVar(&logLevelStr, "log-level", "info", "Set log level (error, warn, info, debug)") 88 | flag.StringVar(&logFile, "log-file", "", "Set file to write logs to (default: terminal)") 89 | flag.StringVar(&exitNode, "exit-node", "", "Route all traffic through specified peer IP (e.g., 10.0.0.3)") 90 | flag.Func("route", "Add routing policy (format: 
CIDR:peerIP, e.g., 192.168.1.0/24:10.0.0.3)", func(value string) error { 91 | routes = append(routes, value) 92 | return nil 93 | }) 94 | flag.Usage = printUsage 95 | flag.Parse() 96 | 97 | if showVersion { 98 | fmt.Printf("wrapguard version %s\n", version) 99 | os.Exit(0) 100 | } 101 | 102 | if showHelp { 103 | printUsage() 104 | os.Exit(0) 105 | } 106 | 107 | if configPath == "" { 108 | printUsage() 109 | os.Exit(1) 110 | } 111 | 112 | // Parse log level 113 | logLevel, err := ParseLogLevel(logLevelStr) 114 | if err != nil { 115 | fmt.Fprintf(os.Stderr, "\n\033[31m✗ Error:\033[0m Invalid log level: %v\n", err) 116 | os.Exit(1) 117 | } 118 | 119 | // Setup logger output 120 | var logOutput io.Writer = os.Stderr 121 | if logFile != "" { 122 | file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) 123 | if err != nil { 124 | fmt.Fprintf(os.Stderr, "\n\033[31m✗ Error:\033[0m Failed to open log file: %v\n", err) 125 | os.Exit(1) 126 | } 127 | defer file.Close() 128 | logOutput = file 129 | } 130 | 131 | // Create logger 132 | logger := NewLogger(logLevel, logOutput) 133 | SetGlobalLogger(logger) 134 | 135 | args := flag.Args() 136 | if len(args) == 0 { 137 | fmt.Fprintf(os.Stderr, "\n\033[31m✗ Error:\033[0m No command specified\n") 138 | printUsage() 139 | os.Exit(1) 140 | } 141 | 142 | // Parse WireGuard configuration 143 | config, err := ParseConfig(configPath) 144 | if err != nil { 145 | logger.Errorf("Failed to parse WireGuard config: %v", err) 146 | os.Exit(1) 147 | } 148 | 149 | // Apply CLI routing options 150 | if exitNode != "" || len(routes) > 0 { 151 | if err := ApplyCLIRoutes(config, exitNode, routes); err != nil { 152 | logger.Errorf("Failed to apply routing options: %v", err) 153 | os.Exit(1) 154 | } 155 | } 156 | 157 | // Create IPC server for communication with LD_PRELOAD library 158 | ipcServer, err := NewIPCServer() 159 | if err != nil { 160 | logger.Errorf("Failed to start IPC server: %v", err) 161 | os.Exit(1) 162 | } 163 | 
defer ipcServer.Close() 164 | 165 | // Create context for cancellation 166 | ctx, cancel := context.WithCancel(context.Background()) 167 | defer cancel() 168 | 169 | // Start WireGuard tunnel 170 | logger.Infof("Creating WireGuard tunnel...") 171 | tunnel, err := NewTunnel(ctx, config) 172 | if err != nil { 173 | logger.Errorf("Failed to create tunnel: %v", err) 174 | os.Exit(1) 175 | } 176 | defer tunnel.Close() 177 | logger.Infof("WireGuard tunnel created successfully") 178 | 179 | // Start SOCKS5 server that routes through WireGuard tunnel 180 | logger.Infof("Starting SOCKS5 server...") 181 | socksServer, err := NewSOCKS5Server(tunnel) 182 | if err != nil { 183 | logger.Errorf("Failed to start SOCKS5 server: %v", err) 184 | os.Exit(1) 185 | } 186 | defer socksServer.Close() 187 | logger.Infof("SOCKS5 server started on port %d", socksServer.Port()) 188 | 189 | // Start port forwarder for incoming connections 190 | forwarder := NewPortForwarder(tunnel, ipcServer.MessageChan()) 191 | go forwarder.Run(ctx) 192 | 193 | // Show startup messages using structured logging 194 | logger.Infof("WrapGuard v%s initialized", version) 195 | logger.Infof("Config: %s", configPath) 196 | logger.Infof("Interface: %s", config.Interface.Address) 197 | if len(config.Peers) > 0 { 198 | logger.Infof("Peer endpoint: %s", config.Peers[0].Endpoint) 199 | } 200 | logger.Infof("Launching: [%s]", strings.Join(args, " ")) 201 | 202 | // Get path to our LD_PRELOAD library 203 | execPath, err := os.Executable() 204 | if err != nil { 205 | logger.Errorf("Failed to get executable path: %v", err) 206 | os.Exit(1) 207 | } 208 | libPath := filepath.Join(filepath.Dir(execPath), "libwrapguard.so") 209 | 210 | // Prepare child process 211 | cmd := exec.Command(args[0], args[1:]...) 
212 | cmd.Stdin = os.Stdin 213 | cmd.Stdout = os.Stdout 214 | cmd.Stderr = os.Stderr 215 | 216 | // Set LD_PRELOAD and IPC socket path 217 | cmd.Env = append(os.Environ(), 218 | fmt.Sprintf("LD_PRELOAD=%s", libPath), 219 | fmt.Sprintf("WRAPGUARD_IPC_PATH=%s", ipcServer.SocketPath()), 220 | fmt.Sprintf("WRAPGUARD_SOCKS_PORT=%d", socksServer.Port()), 221 | ) 222 | 223 | // Start the child process 224 | if err := cmd.Start(); err != nil { 225 | logger.Errorf("Failed to start child process: %v", err) 226 | os.Exit(1) 227 | } 228 | 229 | // Handle signals 230 | sigChan := make(chan os.Signal, 1) 231 | signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) 232 | 233 | // Wait for child process or signal 234 | done := make(chan error, 1) 235 | go func() { 236 | done <- cmd.Wait() 237 | }() 238 | 239 | select { 240 | case err := <-done: 241 | if err != nil { 242 | if exitErr, ok := err.(*exec.ExitError); ok { 243 | os.Exit(exitErr.ExitCode()) 244 | } 245 | logger.Errorf("Child process error: %v", err) 246 | os.Exit(1) 247 | } 248 | // Exit cleanly when child process completes successfully 249 | os.Exit(0) 250 | case sig := <-sigChan: 251 | logger.Infof("Received signal %v, shutting down...", sig) 252 | // Forward signal to child process 253 | if cmd.Process != nil { 254 | cmd.Process.Signal(sig) 255 | } 256 | // Wait for child to exit 257 | select { 258 | case <-done: 259 | case <-time.After(5 * time.Second): 260 | logger.Warnf("Child process did not exit gracefully, killing...") 261 | cmd.Process.Kill() 262 | } 263 | os.Exit(1) 264 | } 265 | } 266 | -------------------------------------------------------------------------------- /socks_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "net" 5 | "net/netip" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestNewSOCKS5Server(t *testing.T) { 11 | // Create a mock tunnel 12 | tunnel := &Tunnel{ 13 | ourIP: mustParseIPAddr("10.150.0.2"), 14 | } 15 | 16 
| server, err := NewSOCKS5Server(tunnel) 17 | if err != nil { 18 | t.Fatalf("NewSOCKS5Server failed: %v", err) 19 | } 20 | defer server.Close() 21 | 22 | if server == nil { 23 | t.Fatal("NewSOCKS5Server returned nil") 24 | } 25 | 26 | if server.server == nil { 27 | t.Error("SOCKS5 server is nil") 28 | } 29 | 30 | if server.listener == nil { 31 | t.Error("listener is nil") 32 | } 33 | 34 | if server.tunnel != tunnel { 35 | t.Error("tunnel reference not set correctly") 36 | } 37 | 38 | if server.port == 0 { 39 | t.Error("port should be set to non-zero value") 40 | } 41 | } 42 | 43 | func TestSOCKS5Server_Port(t *testing.T) { 44 | tunnel := &Tunnel{ 45 | ourIP: mustParseIPAddr("10.150.0.2"), 46 | } 47 | 48 | server, err := NewSOCKS5Server(tunnel) 49 | if err != nil { 50 | t.Fatalf("NewSOCKS5Server failed: %v", err) 51 | } 52 | defer server.Close() 53 | 54 | port := server.Port() 55 | if port == 0 { 56 | t.Error("Port() returned 0") 57 | } 58 | 59 | // Port should be in valid range 60 | if port < 1024 || port > 65535 { 61 | t.Errorf("Port() returned invalid port %d", port) 62 | } 63 | } 64 | 65 | func TestSOCKS5Server_Close(t *testing.T) { 66 | tunnel := &Tunnel{ 67 | ourIP: mustParseIPAddr("10.150.0.2"), 68 | } 69 | 70 | server, err := NewSOCKS5Server(tunnel) 71 | if err != nil { 72 | t.Fatalf("NewSOCKS5Server failed: %v", err) 73 | } 74 | 75 | // Test that Close doesn't return error 76 | err = server.Close() 77 | if err != nil { 78 | t.Errorf("Close() returned error: %v", err) 79 | } 80 | 81 | // Test multiple closes don't panic (may return error, that's OK) 82 | server.Close() // Don't check error for second close as it may fail 83 | } 84 | 85 | func TestSOCKS5Server_Integration(t *testing.T) { 86 | // This is an integration test that may not work in all test environments 87 | // but tests the SOCKS5 server functionality 88 | 89 | tunnel := &Tunnel{ 90 | ourIP: mustParseIPAddr("10.150.0.2"), 91 | } 92 | 93 | server, err := NewSOCKS5Server(tunnel) 94 | if err != nil 
{ 95 | t.Fatalf("NewSOCKS5Server failed: %v", err) 96 | } 97 | defer server.Close() 98 | 99 | // Give the server a moment to start 100 | time.Sleep(50 * time.Millisecond) 101 | 102 | // Test that we can connect to the SOCKS5 server 103 | conn, err := net.DialTimeout("tcp", "127.0.0.1:"+itoa(server.Port()), 1*time.Second) 104 | if err != nil { 105 | t.Logf("Could not connect to SOCKS5 server (may be expected in test env): %v", err) 106 | return 107 | } 108 | defer conn.Close() 109 | 110 | // Basic connectivity test - just ensure we can connect 111 | // Full SOCKS5 protocol testing would require more complex setup 112 | } 113 | 114 | func TestSOCKS5Server_CustomDialer(t *testing.T) { 115 | // Test the custom dialer logic in the SOCKS5 server 116 | // This tests the logic but not the actual network connections 117 | 118 | tunnel := &Tunnel{ 119 | ourIP: mustParseIPAddr("10.150.0.2"), 120 | } 121 | 122 | // Since we can't easily override the method, we'll test the server creation 123 | // The actual dialer testing would require more complex mocking 124 | server, err := NewSOCKS5Server(tunnel) 125 | if err != nil { 126 | t.Fatalf("NewSOCKS5Server failed: %v", err) 127 | } 128 | defer server.Close() 129 | 130 | // Verify the server was created successfully 131 | if server.server == nil { 132 | t.Error("SOCKS5 server not created") 133 | } 134 | } 135 | 136 | func TestSOCKS5Server_ListenerAddress(t *testing.T) { 137 | tunnel := &Tunnel{ 138 | ourIP: mustParseIPAddr("10.150.0.2"), 139 | } 140 | 141 | server, err := NewSOCKS5Server(tunnel) 142 | if err != nil { 143 | t.Fatalf("NewSOCKS5Server failed: %v", err) 144 | } 145 | defer server.Close() 146 | 147 | // Check that the server listens on localhost 148 | addr := server.listener.Addr().(*net.TCPAddr) 149 | if !addr.IP.IsLoopback() { 150 | t.Errorf("server should listen on loopback interface, got %v", addr.IP) 151 | } 152 | 153 | if addr.Port == 0 { 154 | t.Error("server should have a non-zero port") 155 | } 156 | 157 | // 
Port should match what Port() returns 158 | if addr.Port != server.Port() { 159 | t.Errorf("listener port %d doesn't match Port() %d", addr.Port, server.Port()) 160 | } 161 | } 162 | 163 | func TestSOCKS5Server_NilTunnel(t *testing.T) { 164 | // Test behavior with nil tunnel (should not panic but may fail) 165 | _, err := NewSOCKS5Server(nil) 166 | 167 | // This will likely panic or fail, which is acceptable behavior 168 | // We just want to ensure it doesn't crash the test suite 169 | if err != nil { 170 | t.Logf("NewSOCKS5Server with nil tunnel failed as expected: %v", err) 171 | } 172 | } 173 | 174 | func TestSOCKS5Server_PortRange(t *testing.T) { 175 | tunnel := &Tunnel{ 176 | ourIP: mustParseIPAddr("10.150.0.2"), 177 | } 178 | 179 | // Create multiple servers to test port allocation 180 | servers := make([]*SOCKS5Server, 5) 181 | defer func() { 182 | for _, server := range servers { 183 | if server != nil { 184 | server.Close() 185 | } 186 | } 187 | }() 188 | 189 | ports := make(map[int]bool) 190 | 191 | for i := 0; i < 5; i++ { 192 | server, err := NewSOCKS5Server(tunnel) 193 | if err != nil { 194 | t.Fatalf("NewSOCKS5Server %d failed: %v", i, err) 195 | } 196 | servers[i] = server 197 | 198 | port := server.Port() 199 | if ports[port] { 200 | t.Errorf("duplicate port %d allocated", port) 201 | } 202 | ports[port] = true 203 | 204 | // Each server should get a different port 205 | if port < 1024 || port > 65535 { 206 | t.Errorf("invalid port %d allocated", port) 207 | } 208 | } 209 | } 210 | 211 | func TestSOCKS5Server_ServerRunning(t *testing.T) { 212 | tunnel := &Tunnel{ 213 | ourIP: mustParseIPAddr("10.150.0.2"), 214 | } 215 | 216 | server, err := NewSOCKS5Server(tunnel) 217 | if err != nil { 218 | t.Fatalf("NewSOCKS5Server failed: %v", err) 219 | } 220 | defer server.Close() 221 | 222 | // Give server time to start 223 | time.Sleep(10 * time.Millisecond) 224 | 225 | // Test that the server is actually listening 226 | listener := server.listener 227 | if 
listener == nil { 228 | t.Fatal("listener is nil") 229 | } 230 | 231 | addr := listener.Addr() 232 | if addr == nil { 233 | t.Fatal("listener address is nil") 234 | } 235 | 236 | // The server should be running (we can't easily test the Serve goroutine 237 | // without complex setup, but we can verify the listener is active) 238 | tcpAddr, ok := addr.(*net.TCPAddr) 239 | if !ok { 240 | t.Fatalf("listener address is not TCP: %T", addr) 241 | } 242 | 243 | if tcpAddr.Port == 0 { 244 | t.Error("listener has no port assigned") 245 | } 246 | } 247 | 248 | // Helper function to parse IP addresses for testing 249 | func mustParseIPAddr(s string) netip.Addr { 250 | ip, err := netip.ParseAddr(s) 251 | if err != nil { 252 | panic("invalid IP: " + s + " - " + err.Error()) 253 | } 254 | return ip 255 | } 256 | 257 | // Helper function to convert int to string (simple implementation) 258 | func itoa(i int) string { 259 | if i == 0 { 260 | return "0" 261 | } 262 | 263 | negative := false 264 | if i < 0 { 265 | negative = true 266 | i = -i 267 | } 268 | 269 | var digits []byte 270 | for i > 0 { 271 | digits = append([]byte{byte('0' + i%10)}, digits...) 272 | i /= 10 273 | } 274 | 275 | if negative { 276 | digits = append([]byte{'-'}, digits...) 
277 | } 278 | 279 | return string(digits) 280 | } 281 | 282 | // Test that tests the tunnel's IsWireGuardIP method with SOCKS5 context 283 | func TestSOCKS5_WireGuardIPDetection(t *testing.T) { 284 | tunnel := &Tunnel{ 285 | ourIP: mustParseIPAddr("10.150.0.2"), 286 | } 287 | 288 | tests := []struct { 289 | name string 290 | ip string 291 | want bool 292 | }{ 293 | {"WireGuard IP", "10.150.0.5", true}, 294 | {"Non-WireGuard IP", "8.8.8.8", false}, 295 | {"Localhost", "127.0.0.1", false}, 296 | } 297 | 298 | for _, tt := range tests { 299 | t.Run(tt.name, func(t *testing.T) { 300 | ip := net.ParseIP(tt.ip) 301 | if ip == nil { 302 | t.Fatalf("invalid IP: %s", tt.ip) 303 | } 304 | 305 | result := tunnel.IsWireGuardIP(ip) 306 | if result != tt.want { 307 | t.Errorf("IsWireGuardIP(%s) = %v, want %v", tt.ip, result, tt.want) 308 | } 309 | }) 310 | } 311 | } 312 | 313 | // Benchmark test for SOCKS5 server creation 314 | func BenchmarkNewSOCKS5Server(b *testing.B) { 315 | tunnel := &Tunnel{ 316 | ourIP: mustParseIPAddr("10.150.0.2"), 317 | } 318 | 319 | b.ResetTimer() 320 | for i := 0; i < b.N; i++ { 321 | server, err := NewSOCKS5Server(tunnel) 322 | if err != nil { 323 | b.Fatalf("NewSOCKS5Server failed: %v", err) 324 | } 325 | server.Close() 326 | } 327 | } 328 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "encoding/base64" 6 | "encoding/hex" 7 | "fmt" 8 | "net" 9 | "net/netip" 10 | "os" 11 | "strconv" 12 | "strings" 13 | ) 14 | 15 | type InterfaceConfig struct { 16 | PrivateKey string 17 | Address string 18 | DNS []string 19 | ListenPort int 20 | } 21 | 22 | type PeerConfig struct { 23 | PublicKey string 24 | PresharedKey string 25 | Endpoint string 26 | AllowedIPs []string 27 | PersistentKeepalive int 28 | RoutingPolicies []RoutingPolicy // New field for policy-based routing 29 | } 30 | 31 | 
type WireGuardConfig struct { 32 | Interface InterfaceConfig 33 | Peers []PeerConfig 34 | } 35 | 36 | func ParseConfig(filename string) (*WireGuardConfig, error) { 37 | file, err := os.Open(filename) 38 | if err != nil { 39 | return nil, fmt.Errorf("failed to open config file: %w", err) 40 | } 41 | defer file.Close() 42 | 43 | config := &WireGuardConfig{} 44 | scanner := bufio.NewScanner(file) 45 | var currentSection string 46 | var currentPeer *PeerConfig 47 | 48 | for scanner.Scan() { 49 | line := strings.TrimSpace(scanner.Text()) 50 | 51 | // Skip empty lines and comments 52 | if line == "" || strings.HasPrefix(line, "#") { 53 | continue 54 | } 55 | 56 | // Check for section headers 57 | if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { 58 | currentSection = strings.ToLower(line[1 : len(line)-1]) 59 | if currentSection == "peer" { 60 | if currentPeer != nil { 61 | config.Peers = append(config.Peers, *currentPeer) 62 | } 63 | currentPeer = &PeerConfig{} 64 | } 65 | continue 66 | } 67 | 68 | // Parse key-value pairs 69 | parts := strings.SplitN(line, "=", 2) 70 | if len(parts) != 2 { 71 | continue 72 | } 73 | 74 | key := strings.TrimSpace(parts[0]) 75 | value := strings.TrimSpace(parts[1]) 76 | 77 | switch currentSection { 78 | case "interface": 79 | if err := parseInterfaceField(&config.Interface, key, value); err != nil { 80 | return nil, fmt.Errorf("error parsing interface field %s: %w", key, err) 81 | } 82 | case "peer": 83 | if currentPeer != nil { 84 | if err := parsePeerField(currentPeer, key, value); err != nil { 85 | return nil, fmt.Errorf("error parsing peer field %s: %w", key, err) 86 | } 87 | } 88 | } 89 | } 90 | 91 | // Add the last peer if exists 92 | if currentPeer != nil { 93 | config.Peers = append(config.Peers, *currentPeer) 94 | } 95 | 96 | if err := scanner.Err(); err != nil { 97 | return nil, fmt.Errorf("error reading config file: %w", err) 98 | } 99 | 100 | if err := validateConfig(config); err != nil { 101 | return nil, 
fmt.Errorf("config validation failed: %w", err) 102 | } 103 | 104 | return config, nil 105 | } 106 | 107 | func parseInterfaceField(iface *InterfaceConfig, key, value string) error { 108 | switch strings.ToLower(key) { 109 | case "privatekey": 110 | // Convert base64 private key to hex for wireguard-go IPC 111 | hexKey, err := base64ToHex(value) 112 | if err != nil { 113 | return fmt.Errorf("invalid private key format: %w", err) 114 | } 115 | iface.PrivateKey = hexKey 116 | case "address": 117 | iface.Address = value 118 | case "dns": 119 | // Parse comma-separated DNS servers 120 | dns := strings.Split(value, ",") 121 | for i, d := range dns { 122 | dns[i] = strings.TrimSpace(d) 123 | } 124 | iface.DNS = dns 125 | case "listenport": 126 | port, err := strconv.Atoi(value) 127 | if err != nil { 128 | return fmt.Errorf("invalid listen port: %w", err) 129 | } 130 | iface.ListenPort = port 131 | } 132 | return nil 133 | } 134 | 135 | func parsePeerField(peer *PeerConfig, key, value string) error { 136 | switch strings.ToLower(key) { 137 | case "publickey": 138 | // Convert base64 public key to hex for wireguard-go IPC 139 | hexKey, err := base64ToHex(value) 140 | if err != nil { 141 | return fmt.Errorf("invalid public key format: %w", err) 142 | } 143 | peer.PublicKey = hexKey 144 | case "presharedkey": 145 | // Convert base64 preshared key to hex for wireguard-go IPC 146 | hexKey, err := base64ToHex(value) 147 | if err != nil { 148 | return fmt.Errorf("invalid preshared key format: %w", err) 149 | } 150 | peer.PresharedKey = hexKey 151 | case "endpoint": 152 | // Resolve hostname in endpoint to IP address 153 | resolvedEndpoint, err := resolveEndpoint(value) 154 | if err != nil { 155 | return fmt.Errorf("failed to resolve endpoint %s: %w", value, err) 156 | } 157 | peer.Endpoint = resolvedEndpoint 158 | case "allowedips": 159 | // Parse comma-separated allowed IPs 160 | ips := strings.Split(value, ",") 161 | for i, ip := range ips { 162 | ips[i] = 
strings.TrimSpace(ip) 163 | } 164 | peer.AllowedIPs = ips 165 | case "persistentkeepalive": 166 | keepalive, err := strconv.Atoi(value) 167 | if err != nil { 168 | return fmt.Errorf("invalid persistent keepalive: %w", err) 169 | } 170 | peer.PersistentKeepalive = keepalive 171 | case "route": 172 | // Parse routing policy with auto-incrementing priority 173 | priority := len(peer.RoutingPolicies) 174 | policy, err := ParseRoutingPolicy(value, priority) 175 | if err != nil { 176 | return fmt.Errorf("invalid routing policy: %w", err) 177 | } 178 | peer.RoutingPolicies = append(peer.RoutingPolicies, *policy) 179 | } 180 | return nil 181 | } 182 | 183 | func validateConfig(config *WireGuardConfig) error { 184 | // Validate interface 185 | if config.Interface.PrivateKey == "" { 186 | return fmt.Errorf("interface private key is required") 187 | } 188 | 189 | if config.Interface.Address == "" { 190 | return fmt.Errorf("interface address is required") 191 | } 192 | 193 | // Validate address format 194 | if _, err := netip.ParsePrefix(config.Interface.Address); err != nil { 195 | return fmt.Errorf("invalid interface address format: %w", err) 196 | } 197 | 198 | // Validate at least one peer 199 | if len(config.Peers) == 0 { 200 | return fmt.Errorf("at least one peer is required") 201 | } 202 | 203 | // Validate peers 204 | for i, peer := range config.Peers { 205 | if peer.PublicKey == "" { 206 | return fmt.Errorf("peer %d: public key is required", i) 207 | } 208 | 209 | if len(peer.AllowedIPs) == 0 { 210 | return fmt.Errorf("peer %d: at least one allowed IP is required", i) 211 | } 212 | 213 | // Validate allowed IPs format 214 | for _, allowedIP := range peer.AllowedIPs { 215 | if _, err := netip.ParsePrefix(allowedIP); err != nil { 216 | return fmt.Errorf("peer %d: invalid allowed IP format %s: %w", i, allowedIP, err) 217 | } 218 | } 219 | } 220 | 221 | return nil 222 | } 223 | 224 | // GetInterfaceIP extracts the IP address from the interface address (without CIDR) 225 | 
func (c *WireGuardConfig) GetInterfaceIP() (netip.Addr, error) { 226 | prefix, err := netip.ParsePrefix(c.Interface.Address) 227 | if err != nil { 228 | return netip.Addr{}, err 229 | } 230 | return prefix.Addr(), nil 231 | } 232 | 233 | // GetInterfacePrefix returns the interface address as a prefix 234 | func (c *WireGuardConfig) GetInterfacePrefix() (netip.Prefix, error) { 235 | return netip.ParsePrefix(c.Interface.Address) 236 | } 237 | 238 | // base64ToHex converts a base64-encoded WireGuard key to lowercase hex format 239 | // required by wireguard-go IPC protocol 240 | func base64ToHex(base64Key string) (string, error) { 241 | // Decode base64 key 242 | keyBytes, err := base64.StdEncoding.DecodeString(base64Key) 243 | if err != nil { 244 | return "", fmt.Errorf("failed to decode base64 key: %w", err) 245 | } 246 | 247 | // WireGuard keys should be exactly 32 bytes 248 | if len(keyBytes) != 32 { 249 | return "", fmt.Errorf("key must be 32 bytes, got %d", len(keyBytes)) 250 | } 251 | 252 | // Convert to lowercase hex 253 | return hex.EncodeToString(keyBytes), nil 254 | } 255 | 256 | // resolveEndpoint resolves a hostname:port endpoint to IP:port format 257 | // required by wireguard-go which expects IP addresses, not hostnames 258 | func resolveEndpoint(endpoint string) (string, error) { 259 | host, port, err := net.SplitHostPort(endpoint) 260 | if err != nil { 261 | return "", fmt.Errorf("invalid endpoint format: %w", err) 262 | } 263 | 264 | // Check if host is already an IP address 265 | if ip := net.ParseIP(host); ip != nil { 266 | return endpoint, nil // Already an IP, return as-is 267 | } 268 | 269 | // Resolve hostname to IP 270 | ips, err := net.LookupIP(host) 271 | if err != nil { 272 | return "", fmt.Errorf("failed to resolve hostname %s: %w", host, err) 273 | } 274 | 275 | if len(ips) == 0 { 276 | return "", fmt.Errorf("no IP addresses found for hostname %s", host) 277 | } 278 | 279 | // Use the first IP address (prefer IPv4) 280 | var resolvedIP 
net.IP 281 | for _, ip := range ips { 282 | if ip.To4() != nil { 283 | resolvedIP = ip 284 | break 285 | } 286 | } 287 | 288 | // If no IPv4 found, use the first IP 289 | if resolvedIP == nil { 290 | resolvedIP = ips[0] 291 | } 292 | 293 | return net.JoinHostPort(resolvedIP.String(), port), nil 294 | } 295 | 296 | // ApplyCLIRoutes applies routing policies from CLI arguments to the configuration 297 | func ApplyCLIRoutes(config *WireGuardConfig, exitNode string, routes []string) error { 298 | // Handle exit node (shorthand for routing all traffic through a peer) 299 | if exitNode != "" { 300 | routes = append([]string{fmt.Sprintf("0.0.0.0/0:%s", exitNode)}, routes...) 301 | } 302 | 303 | // Process each route 304 | for _, route := range routes { 305 | parts := strings.Split(route, ":") 306 | if len(parts) != 2 { 307 | return fmt.Errorf("invalid route format '%s', expected CIDR:peerIP", route) 308 | } 309 | 310 | cidr := strings.TrimSpace(parts[0]) 311 | peerIP := strings.TrimSpace(parts[1]) 312 | 313 | // Validate CIDR 314 | if _, err := netip.ParsePrefix(cidr); err != nil { 315 | return fmt.Errorf("invalid CIDR in route '%s': %w", route, err) 316 | } 317 | 318 | // Find the peer with the matching IP 319 | peerFound := false 320 | for i := range config.Peers { 321 | peer := &config.Peers[i] 322 | 323 | // Check if this peer can route to the specified IP 324 | for _, allowedIP := range peer.AllowedIPs { 325 | prefix, err := netip.ParsePrefix(allowedIP) 326 | if err != nil { 327 | continue 328 | } 329 | 330 | // Check if the peer IP is within this peer's allowed IPs 331 | addr, err := netip.ParseAddr(peerIP) 332 | if err != nil { 333 | continue 334 | } 335 | 336 | if prefix.Contains(addr) { 337 | // Add routing policy to this peer 338 | priority := len(peer.RoutingPolicies) 339 | policy := RoutingPolicy{ 340 | DestinationCIDR: cidr, 341 | Protocol: "any", 342 | PortRange: PortRange{Start: 1, End: 65535}, 343 | Priority: priority, 344 | } 345 | peer.RoutingPolicies = 
append(peer.RoutingPolicies, policy) 346 | peerFound = true 347 | 348 | if logger != nil { 349 | logger.Infof("Added route %s via peer %s", cidr, peerIP) 350 | } 351 | break 352 | } 353 | } 354 | 355 | if peerFound { 356 | break 357 | } 358 | } 359 | 360 | if !peerFound { 361 | return fmt.Errorf("no peer found that can route to %s", peerIP) 362 | } 363 | } 364 | 365 | return nil 366 | } 367 | -------------------------------------------------------------------------------- /tunnel.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "encoding/binary" 6 | "fmt" 7 | "net" 8 | "net/netip" 9 | "os" 10 | "strconv" 11 | "sync" 12 | "time" 13 | 14 | "golang.zx2c4.com/wireguard/conn" 15 | "golang.zx2c4.com/wireguard/device" 16 | "golang.zx2c4.com/wireguard/tun" 17 | ) 18 | 19 | type Tunnel struct { 20 | device *device.Device 21 | tun *MemoryTUN 22 | ourIP netip.Addr 23 | connMap map[string]*TunnelConn 24 | mutex sync.RWMutex 25 | router *RoutingEngine // Add routing engine 26 | config *WireGuardConfig // Keep config reference 27 | } 28 | 29 | type TunnelConn struct { 30 | localAddr net.Addr 31 | remoteAddr net.Addr 32 | readChan chan []byte 33 | writeChan chan []byte 34 | closed bool 35 | mutex sync.RWMutex 36 | } 37 | 38 | // MemoryTUN implements tun.Device for userspace packet handling 39 | type MemoryTUN struct { 40 | inbound chan []byte 41 | outbound chan []byte 42 | mtu int 43 | name string 44 | events chan tun.Event 45 | closed bool 46 | mutex sync.RWMutex 47 | tunnel *Tunnel 48 | } 49 | 50 | func NewMemoryTUN(name string, mtu int) *MemoryTUN { 51 | return &MemoryTUN{ 52 | inbound: make(chan []byte, 100), 53 | outbound: make(chan []byte, 100), 54 | mtu: mtu, 55 | name: name, 56 | events: make(chan tun.Event, 10), 57 | } 58 | } 59 | 60 | func (m *MemoryTUN) File() *os.File { return nil } 61 | 62 | func (m *MemoryTUN) Read(bufs [][]byte, sizes []int, offset int) (int, error) { 63 | 
packet, ok := <-m.inbound 64 | if !ok { 65 | return 0, fmt.Errorf("TUN closed") 66 | } 67 | 68 | // Read a single packet into the first buffer 69 | if len(bufs) > 0 { 70 | n := copy(bufs[0][offset:], packet) 71 | sizes[0] = n 72 | return 1, nil 73 | } 74 | return 0, nil 75 | } 76 | 77 | func (m *MemoryTUN) Write(bufs [][]byte, offset int) (int, error) { 78 | m.mutex.RLock() 79 | if m.closed { 80 | m.mutex.RUnlock() 81 | return 0, fmt.Errorf("TUN closed") 82 | } 83 | m.mutex.RUnlock() 84 | 85 | // Write all packets in the batch 86 | written := 0 87 | for _, buf := range bufs { 88 | packet := make([]byte, len(buf)-offset) 89 | copy(packet, buf[offset:]) 90 | 91 | // Handle incoming packets from WireGuard 92 | if m.tunnel != nil { 93 | go m.tunnel.handleIncomingPacket(packet) 94 | } 95 | 96 | select { 97 | case m.outbound <- packet: 98 | written++ 99 | default: 100 | // Drop if full 101 | break 102 | } 103 | } 104 | 105 | return written, nil 106 | } 107 | 108 | func (m *MemoryTUN) Flush() error { return nil } 109 | func (m *MemoryTUN) MTU() (int, error) { return m.mtu, nil } 110 | func (m *MemoryTUN) Name() (string, error) { return m.name, nil } 111 | func (m *MemoryTUN) Events() <-chan tun.Event { return m.events } 112 | func (m *MemoryTUN) BatchSize() int { return 1 } 113 | 114 | func (m *MemoryTUN) Close() error { 115 | m.mutex.Lock() 116 | defer m.mutex.Unlock() 117 | 118 | if !m.closed { 119 | m.closed = true 120 | close(m.inbound) 121 | close(m.outbound) 122 | close(m.events) 123 | } 124 | return nil 125 | } 126 | 127 | func NewTunnel(ctx context.Context, config *WireGuardConfig) (*Tunnel, error) { 128 | // Get our WireGuard IP 129 | ourIP, err := config.GetInterfaceIP() 130 | if err != nil { 131 | return nil, fmt.Errorf("failed to parse interface IP: %w", err) 132 | } 133 | 134 | // Create memory TUN 135 | memTun := NewMemoryTUN("wg0", 1420) 136 | 137 | tunnel := &Tunnel{ 138 | tun: memTun, 139 | ourIP: ourIP, 140 | connMap: make(map[string]*TunnelConn), 141 | 
config: config, 142 | router: NewRoutingEngine(config), 143 | } 144 | 145 | // Set tunnel reference in TUN for packet handling 146 | memTun.tunnel = tunnel 147 | 148 | // Create WireGuard device 149 | logger := device.NewLogger( 150 | device.LogLevelSilent, 151 | fmt.Sprintf("[%s] ", "wg"), 152 | ) 153 | 154 | dev := device.NewDevice(memTun, conn.NewDefaultBind(), logger) 155 | 156 | // Configure device 157 | if err := configureDevice(dev, config); err != nil { 158 | dev.Close() 159 | return nil, fmt.Errorf("failed to configure device: %w", err) 160 | } 161 | 162 | // Bring device up 163 | if err := dev.Up(); err != nil { 164 | dev.Close() 165 | return nil, fmt.Errorf("failed to bring device up: %w", err) 166 | } 167 | 168 | tunnel.device = dev 169 | return tunnel, nil 170 | } 171 | 172 | func configureDevice(dev *device.Device, config *WireGuardConfig) error { 173 | ipcConfig := fmt.Sprintf("private_key=%s\n", config.Interface.PrivateKey) 174 | 175 | if config.Interface.ListenPort > 0 { 176 | ipcConfig += fmt.Sprintf("listen_port=%d\n", config.Interface.ListenPort) 177 | } 178 | 179 | for _, peer := range config.Peers { 180 | ipcConfig += fmt.Sprintf("public_key=%s\n", peer.PublicKey) 181 | 182 | if peer.PresharedKey != "" { 183 | ipcConfig += fmt.Sprintf("preshared_key=%s\n", peer.PresharedKey) 184 | } 185 | 186 | if peer.Endpoint != "" { 187 | ipcConfig += fmt.Sprintf("endpoint=%s\n", peer.Endpoint) 188 | } 189 | 190 | if peer.PersistentKeepalive > 0 { 191 | ipcConfig += fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.PersistentKeepalive) 192 | } 193 | 194 | for _, allowedIP := range peer.AllowedIPs { 195 | ipcConfig += fmt.Sprintf("allowed_ip=%s\n", allowedIP) 196 | } 197 | } 198 | 199 | return dev.IpcSet(ipcConfig) 200 | } 201 | 202 | func (t *Tunnel) handleIncomingPacket(packet []byte) { 203 | if len(packet) < 20 { 204 | return // Too short for IP header 205 | } 206 | 207 | // Parse IP header to extract source/dest 208 | version := packet[0] >> 4 209 | 
if version != 4 { 210 | return // Only IPv4 for now 211 | } 212 | 213 | protocol := packet[9] 214 | if protocol != 6 { 215 | return // Only TCP for now 216 | } 217 | 218 | srcIP := net.IP(packet[12:16]) 219 | dstIP := net.IP(packet[16:20]) 220 | 221 | // Extract TCP ports 222 | if len(packet) < 24 { 223 | return 224 | } 225 | 226 | srcPort := binary.BigEndian.Uint16(packet[20:22]) 227 | dstPort := binary.BigEndian.Uint16(packet[22:24]) 228 | 229 | connKey := fmt.Sprintf("%s:%d->%s:%d", srcIP, srcPort, dstIP, dstPort) 230 | 231 | t.mutex.RLock() 232 | conn, exists := t.connMap[connKey] 233 | t.mutex.RUnlock() 234 | 235 | if exists { 236 | // Deliver to existing connection 237 | select { 238 | case conn.readChan <- packet[20:]: // TCP payload 239 | default: 240 | // Drop if full 241 | } 242 | } 243 | } 244 | 245 | // DialContext creates a connection through WireGuard 246 | func (t *Tunnel) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 247 | // For now, return an error since we need the WireGuard interface to be configured 248 | // In a full implementation, this would send packets through the WireGuard tunnel 249 | return nil, fmt.Errorf("WireGuard tunnel dial not implemented - requires system WireGuard interface or full TCP/IP stack") 250 | } 251 | 252 | func (t *Tunnel) createTCPSyn(dstIP net.IP, dstPort int) []byte { 253 | // Create a minimal TCP SYN packet 254 | // This is very simplified - a real implementation would need proper TCP handling 255 | packet := make([]byte, 40) // IP header (20) + TCP header (20) 256 | 257 | // IP header 258 | packet[0] = 0x45 // Version 4, header length 5 259 | packet[1] = 0x00 // DSCP/ECN 260 | binary.BigEndian.PutUint16(packet[2:4], 40) // Total length 261 | binary.BigEndian.PutUint16(packet[4:6], 0x1234) // ID 262 | binary.BigEndian.PutUint16(packet[6:8], 0x4000) // Flags 263 | packet[8] = 64 // TTL 264 | packet[9] = 6 // Protocol (TCP) 265 | copy(packet[12:16], t.ourIP.AsSlice()) // Source IP 266 
| copy(packet[16:20], dstIP.To4()) // Dest IP 267 | 268 | // TCP header 269 | binary.BigEndian.PutUint16(packet[20:22], 12345) // Source port 270 | binary.BigEndian.PutUint16(packet[22:24], uint16(dstPort)) // Dest port 271 | binary.BigEndian.PutUint32(packet[24:28], 0x12345678) // Seq number 272 | binary.BigEndian.PutUint32(packet[28:32], 0) // Ack number 273 | packet[32] = 0x50 // Header length 274 | packet[33] = 0x02 // SYN flag 275 | binary.BigEndian.PutUint16(packet[34:36], 8192) // Window 276 | 277 | return packet 278 | } 279 | 280 | func (t *Tunnel) Listen(network, address string) (net.Listener, error) { 281 | // For incoming connections, we need to listen on our WireGuard IP 282 | // This is a placeholder - real implementation would handle TCP listening 283 | return net.Listen("tcp", fmt.Sprintf("%s%s", t.ourIP.String(), address)) 284 | } 285 | 286 | // IsWireGuardIP checks if an IP is in the WireGuard network 287 | func (t *Tunnel) IsWireGuardIP(ip net.IP) bool { 288 | // Check if the IP is in the 10.150.0.0/24 range (our WireGuard network) 289 | _, wgNet, err := net.ParseCIDR("10.150.0.0/24") 290 | if err != nil { 291 | return false 292 | } 293 | return wgNet.Contains(ip) 294 | } 295 | 296 | // DialWireGuard creates a connection to a WireGuard IP through the tunnel 297 | func (t *Tunnel) DialWireGuard(ctx context.Context, network, host, port string) (net.Conn, error) { 298 | // Parse destination IP and port 299 | ip := net.ParseIP(host) 300 | if ip == nil { 301 | return nil, fmt.Errorf("invalid IP address: %s", host) 302 | } 303 | 304 | portNum, err := strconv.Atoi(port) 305 | if err != nil { 306 | return nil, fmt.Errorf("invalid port: %s", port) 307 | } 308 | 309 | // Find the appropriate peer using routing engine 310 | peer, peerIdx := t.router.FindPeerForDestination(ip, portNum, network) 311 | if peer == nil { 312 | return nil, fmt.Errorf("no route to %s:%s", host, port) 313 | } 314 | 315 | logger.Debugf("WireGuard tunnel: routing %s:%s through peer %d 
(endpoint: %s)", host, port, peerIdx, peer.Endpoint) 316 | 317 | // For now, fall back to hostname translation for testing 318 | // In a production system, this would send packets through the WireGuard tunnel 319 | // to the selected peer 320 | var realHost string 321 | switch host { 322 | case "10.150.0.2": 323 | realHost = "node-server-1" 324 | case "10.150.0.3": 325 | realHost = "node-server-2" 326 | default: 327 | // In a real implementation, we would encapsulate and send through the tunnel 328 | // For now, try direct connection as fallback 329 | logger.Warnf("No hostname mapping for %s, attempting direct connection", host) 330 | realHost = host 331 | } 332 | 333 | dialer := &net.Dialer{} 334 | return dialer.DialContext(ctx, network, realHost+":"+port) 335 | } 336 | 337 | func (t *Tunnel) Close() error { 338 | if t.device != nil { 339 | t.device.Close() 340 | } 341 | if t.tun != nil { 342 | t.tun.Close() 343 | } 344 | return nil 345 | } 346 | 347 | // TunnelConn implements net.Conn 348 | func (tc *TunnelConn) Read(b []byte) (int, error) { 349 | data, ok := <-tc.readChan 350 | if !ok { 351 | return 0, fmt.Errorf("connection closed") 352 | } 353 | copy(b, data) 354 | return len(data), nil 355 | } 356 | 357 | func (tc *TunnelConn) Write(b []byte) (int, error) { 358 | select { 359 | case tc.writeChan <- b: 360 | return len(b), nil 361 | default: 362 | return 0, fmt.Errorf("write buffer full") 363 | } 364 | } 365 | 366 | func (tc *TunnelConn) Close() error { 367 | tc.mutex.Lock() 368 | defer tc.mutex.Unlock() 369 | 370 | if !tc.closed { 371 | tc.closed = true 372 | close(tc.readChan) 373 | close(tc.writeChan) 374 | } 375 | return nil 376 | } 377 | 378 | func (tc *TunnelConn) LocalAddr() net.Addr { return tc.localAddr } 379 | func (tc *TunnelConn) RemoteAddr() net.Addr { return tc.remoteAddr } 380 | func (tc *TunnelConn) SetDeadline(t time.Time) error { return nil } 381 | func (tc *TunnelConn) SetReadDeadline(t time.Time) error { return nil } 382 | func (tc 
*TunnelConn) SetWriteDeadline(t time.Time) error { return nil } 383 | 384 | func mustParsePort(s string) int { 385 | p, _ := strconv.Atoi(s) 386 | return p 387 | } 388 | -------------------------------------------------------------------------------- /lib/intercept.c: -------------------------------------------------------------------------------- 1 | #define _GNU_SOURCE 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | // Function pointers for original functions 17 | static int (*real_connect)(int sockfd, const struct sockaddr *addr, socklen_t addrlen) = NULL; 18 | static int (*real_bind)(int sockfd, const struct sockaddr *addr, socklen_t addrlen) = NULL; 19 | 20 | // Global variables for configuration 21 | static char *ipc_path = NULL; 22 | static int socks_port = 0; 23 | static int initialized = 0; 24 | 25 | // Initialize the library 26 | static void init_library() { 27 | if (initialized) return; 28 | initialized = 1; 29 | 30 | // Load original functions 31 | real_connect = dlsym(RTLD_NEXT, "connect"); 32 | real_bind = dlsym(RTLD_NEXT, "bind"); 33 | 34 | // Get configuration from environment 35 | ipc_path = getenv("WRAPGUARD_IPC_PATH"); 36 | char *socks_port_str = getenv("WRAPGUARD_SOCKS_PORT"); 37 | if (socks_port_str) { 38 | socks_port = atoi(socks_port_str); 39 | } 40 | 41 | // Debug output (only in debug mode) 42 | char *debug_mode = getenv("WRAPGUARD_DEBUG"); 43 | if (debug_mode && strcmp(debug_mode, "1") == 0) { 44 | fprintf(stderr, "WrapGuard LD_PRELOAD: Initialized\n"); 45 | fprintf(stderr, "WrapGuard LD_PRELOAD: IPC path: %s\n", ipc_path ? 
ipc_path : "NULL"); 46 | fprintf(stderr, "WrapGuard LD_PRELOAD: SOCKS port: %d\n", socks_port); 47 | } 48 | 49 | if (!ipc_path || socks_port == 0) { 50 | fprintf(stderr, "WrapGuard: Missing environment variables\n"); 51 | } 52 | } 53 | 54 | // Check if an address should be intercepted 55 | static int should_intercept_connect(const struct sockaddr *addr) { 56 | if (addr->sa_family != AF_INET && addr->sa_family != AF_INET6) { 57 | return 0; // Only intercept IP connections 58 | } 59 | 60 | if (addr->sa_family == AF_INET) { 61 | struct sockaddr_in *in_addr = (struct sockaddr_in *)addr; 62 | 63 | // Don't intercept localhost connections (except when connecting to our SOCKS proxy) 64 | uint32_t ip = ntohl(in_addr->sin_addr.s_addr); 65 | if ((ip & 0xFF000000) == 0x7F000000) { // 127.x.x.x 66 | int port = ntohs(in_addr->sin_port); 67 | if (port == socks_port) { 68 | return 0; // Don't intercept connections to our own SOCKS proxy 69 | } 70 | } 71 | 72 | return 1; // Intercept all other connections 73 | } 74 | 75 | // TODO: Handle IPv6 if needed 76 | return 0; 77 | } 78 | 79 | // Send IPC message 80 | static void send_ipc_message(const char *type, int fd, int port, const char *addr) { 81 | if (!ipc_path) return; 82 | 83 | int sock = socket(AF_UNIX, SOCK_STREAM, 0); 84 | if (sock < 0) return; 85 | 86 | struct sockaddr_un sun; 87 | memset(&sun, 0, sizeof(sun)); 88 | sun.sun_family = AF_UNIX; 89 | strncpy(sun.sun_path, ipc_path, sizeof(sun.sun_path) - 1); 90 | 91 | if (connect(sock, (struct sockaddr *)&sun, sizeof(sun)) == 0) { 92 | char message[512]; 93 | snprintf(message, sizeof(message), 94 | "{\"type\":\"%s\",\"fd\":%d,\"port\":%d,\"addr\":\"%s\"}\n", 95 | type, fd, port, addr ? 
addr : ""); 96 | 97 | write(sock, message, strlen(message)); 98 | } 99 | 100 | close(sock); 101 | } 102 | 103 | // SOCKS5 connection helper 104 | static int socks5_connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { 105 | char *debug_mode = getenv("WRAPGUARD_DEBUG"); 106 | 107 | if (addr->sa_family != AF_INET) { 108 | errno = EAFNOSUPPORT; 109 | return -1; 110 | } 111 | 112 | struct sockaddr_in *target = (struct sockaddr_in *)addr; 113 | struct sockaddr_in socks_addr; 114 | memset(&socks_addr, 0, sizeof(socks_addr)); 115 | socks_addr.sin_family = AF_INET; 116 | socks_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); 117 | socks_addr.sin_port = htons(socks_port); 118 | 119 | // Connect to SOCKS5 proxy 120 | if (debug_mode && strcmp(debug_mode, "1") == 0) { 121 | fprintf(stderr, "WrapGuard LD_PRELOAD: Connecting to SOCKS5 proxy at 127.0.0.1:%d\n", socks_port); 122 | } 123 | int connect_result = real_connect(sockfd, (struct sockaddr *)&socks_addr, sizeof(socks_addr)); 124 | if (connect_result != 0 && errno != EINPROGRESS) { 125 | fprintf(stderr, "WrapGuard LD_PRELOAD: Failed to connect to SOCKS5 proxy: %s\n", strerror(errno)); 126 | return -1; 127 | } 128 | 129 | // For non-blocking sockets, we need to wait for connection to complete 130 | if (errno == EINPROGRESS) { 131 | if (debug_mode && strcmp(debug_mode, "1") == 0) { 132 | fprintf(stderr, "WrapGuard LD_PRELOAD: Non-blocking connect in progress, waiting...\n"); 133 | } 134 | fd_set write_fds; 135 | FD_ZERO(&write_fds); 136 | FD_SET(sockfd, &write_fds); 137 | 138 | struct timeval timeout = {5, 0}; // 5 second timeout 139 | int select_result = select(sockfd + 1, NULL, &write_fds, NULL, &timeout); 140 | if (select_result <= 0) { 141 | fprintf(stderr, "WrapGuard LD_PRELOAD: Timeout waiting for SOCKS5 connection\n"); 142 | return -1; 143 | } 144 | 145 | // Check if connection actually succeeded 146 | int so_error; 147 | socklen_t len = sizeof(so_error); 148 | if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, 
&so_error, &len) != 0 || so_error != 0) { 149 | fprintf(stderr, "WrapGuard LD_PRELOAD: SOCKS5 connection failed: %s\n", strerror(so_error)); 150 | return -1; 151 | } 152 | } 153 | 154 | if (debug_mode && strcmp(debug_mode, "1") == 0) { 155 | fprintf(stderr, "WrapGuard LD_PRELOAD: Connected to SOCKS5 proxy, starting handshake\n"); 156 | } 157 | 158 | // SOCKS5 handshake 159 | unsigned char handshake[] = {0x05, 0x01, 0x00}; // Version 5, 1 method, no auth 160 | if (debug_mode && strcmp(debug_mode, "1") == 0) { 161 | fprintf(stderr, "WrapGuard LD_PRELOAD: Sending SOCKS5 handshake\n"); 162 | } 163 | if (send(sockfd, handshake, 3, 0) != 3) { 164 | fprintf(stderr, "WrapGuard LD_PRELOAD: Failed to send SOCKS5 handshake\n"); 165 | return -1; 166 | } 167 | 168 | unsigned char response[2]; 169 | if (debug_mode && strcmp(debug_mode, "1") == 0) { 170 | fprintf(stderr, "WrapGuard LD_PRELOAD: Waiting for SOCKS5 handshake response\n"); 171 | } 172 | 173 | // Wait for response with timeout (non-blocking socket issue) 174 | fd_set read_fds; 175 | FD_ZERO(&read_fds); 176 | FD_SET(sockfd, &read_fds); 177 | 178 | struct timeval timeout = {5, 0}; // 5 second timeout 179 | int select_result = select(sockfd + 1, &read_fds, NULL, NULL, &timeout); 180 | if (select_result <= 0) { 181 | fprintf(stderr, "WrapGuard LD_PRELOAD: Timeout waiting for SOCKS5 handshake response\n"); 182 | return -1; 183 | } 184 | 185 | int recv_bytes = recv(sockfd, response, 2, 0); 186 | if (recv_bytes != 2) { 187 | fprintf(stderr, "WrapGuard LD_PRELOAD: SOCKS5 handshake response failed, got %d bytes, errno: %s\n", recv_bytes, strerror(errno)); 188 | return -1; 189 | } 190 | if (response[0] != 0x05 || response[1] != 0x00) { 191 | fprintf(stderr, "WrapGuard LD_PRELOAD: Invalid SOCKS5 handshake response: %02x %02x\n", response[0], response[1]); 192 | return -1; 193 | } 194 | if (debug_mode && strcmp(debug_mode, "1") == 0) { 195 | fprintf(stderr, "WrapGuard LD_PRELOAD: SOCKS5 handshake successful\n"); 196 | } 197 | 198 
| // SOCKS5 connect request 199 | unsigned char connect_req[10]; 200 | connect_req[0] = 0x05; // Version 201 | connect_req[1] = 0x01; // Connect command 202 | connect_req[2] = 0x00; // Reserved 203 | connect_req[3] = 0x01; // IPv4 address type 204 | memcpy(&connect_req[4], &target->sin_addr, 4); // IP address 205 | memcpy(&connect_req[8], &target->sin_port, 2); // Port 206 | 207 | if (send(sockfd, connect_req, 10, 0) != 10) { 208 | return -1; 209 | } 210 | 211 | // Read SOCKS5 response with timeout 212 | unsigned char connect_resp[10]; 213 | 214 | FD_ZERO(&read_fds); 215 | FD_SET(sockfd, &read_fds); 216 | timeout.tv_sec = 5; 217 | timeout.tv_usec = 0; 218 | 219 | select_result = select(sockfd + 1, &read_fds, NULL, NULL, &timeout); 220 | if (select_result <= 0) { 221 | fprintf(stderr, "WrapGuard LD_PRELOAD: Timeout waiting for SOCKS5 connect response\n"); 222 | return -1; 223 | } 224 | 225 | if (recv(sockfd, connect_resp, 10, 0) != 10 || connect_resp[0] != 0x05 || connect_resp[1] != 0x00) { 226 | fprintf(stderr, "WrapGuard LD_PRELOAD: SOCKS5 connect failed\n"); 227 | errno = ECONNREFUSED; 228 | return -1; 229 | } 230 | 231 | return 0; // Success 232 | } 233 | 234 | // Intercepted connect function 235 | int connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { 236 | init_library(); 237 | 238 | // Convert address to string for logging 239 | char addr_str[INET_ADDRSTRLEN + 16]; 240 | if (addr->sa_family == AF_INET) { 241 | struct sockaddr_in *in_addr = (struct sockaddr_in *)addr; 242 | char ip_str[INET_ADDRSTRLEN]; 243 | inet_ntop(AF_INET, &in_addr->sin_addr, ip_str, INET_ADDRSTRLEN); 244 | snprintf(addr_str, sizeof(addr_str), "%s:%d", ip_str, ntohs(in_addr->sin_port)); 245 | } else { 246 | strcpy(addr_str, "unknown"); 247 | } 248 | 249 | char *debug_mode = getenv("WRAPGUARD_DEBUG"); 250 | if (debug_mode && strcmp(debug_mode, "1") == 0) { 251 | fprintf(stderr, "WrapGuard LD_PRELOAD: connect() called for %s\n", addr_str); 252 | } 253 | 254 | if 
(!should_intercept_connect(addr)) { 255 | if (debug_mode && strcmp(debug_mode, "1") == 0) { 256 | fprintf(stderr, "WrapGuard LD_PRELOAD: NOT intercepting %s\n", addr_str); 257 | } 258 | return real_connect(sockfd, addr, addrlen); 259 | } 260 | 261 | if (debug_mode && strcmp(debug_mode, "1") == 0) { 262 | fprintf(stderr, "WrapGuard LD_PRELOAD: INTERCEPTING %s, routing through SOCKS5\n", addr_str); 263 | } 264 | 265 | // Send IPC message 266 | send_ipc_message("CONNECT", sockfd, 0, addr_str); 267 | 268 | // Route through SOCKS5 269 | return socks5_connect(sockfd, addr, addrlen); 270 | } 271 | 272 | // Intercepted bind function 273 | int bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { 274 | init_library(); 275 | 276 | // Call original bind first 277 | int result = real_bind(sockfd, addr, addrlen); 278 | 279 | // If bind succeeded and it's a TCP socket, notify the main process 280 | if (result == 0 && addr->sa_family == AF_INET) { 281 | struct sockaddr_in *in_addr = (struct sockaddr_in *)addr; 282 | int port = ntohs(in_addr->sin_port); 283 | 284 | // Get the actual port if it was auto-assigned (port 0) 285 | if (port == 0) { 286 | struct sockaddr_in actual_addr; 287 | socklen_t actual_len = sizeof(actual_addr); 288 | if (getsockname(sockfd, (struct sockaddr *)&actual_addr, &actual_len) == 0) { 289 | port = ntohs(actual_addr.sin_port); 290 | } 291 | } 292 | 293 | // Check if it's a TCP socket 294 | int sock_type; 295 | socklen_t opt_len = sizeof(sock_type); 296 | if (getsockopt(sockfd, SOL_SOCKET, SO_TYPE, &sock_type, &opt_len) == 0 && sock_type == SOCK_STREAM) { 297 | // Send IPC message to set up port forwarding 298 | send_ipc_message("BIND", sockfd, port, NULL); 299 | } 300 | } 301 | 302 | return result; 303 | } -------------------------------------------------------------------------------- /forwarder_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "net" 
6 | "net/netip" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestNewPortForwarder(t *testing.T) { 12 | tunnel := &Tunnel{ 13 | ourIP: netip.MustParseAddr("10.150.0.2"), 14 | } 15 | 16 | msgChan := make(chan IPCMessage, 10) 17 | forwarder := NewPortForwarder(tunnel, msgChan) 18 | 19 | if forwarder == nil { 20 | t.Fatal("NewPortForwarder returned nil") 21 | } 22 | 23 | if forwarder.tunnel != tunnel { 24 | t.Error("tunnel not set correctly") 25 | } 26 | 27 | if forwarder.msgChan != msgChan { 28 | t.Error("message channel not set correctly") 29 | } 30 | 31 | if forwarder.listeners == nil { 32 | t.Error("listeners map not initialized") 33 | } 34 | 35 | if len(forwarder.listeners) != 0 { 36 | t.Error("listeners map should be empty initially") 37 | } 38 | } 39 | 40 | func TestPortForwarder_HandleBind(t *testing.T) { 41 | tunnel := &Tunnel{ 42 | ourIP: netip.MustParseAddr("10.150.0.2"), 43 | } 44 | 45 | msgChan := make(chan IPCMessage, 10) 46 | forwarder := NewPortForwarder(tunnel, msgChan) 47 | 48 | // Test binding to a port 49 | port := 8080 50 | err := forwarder.handleBind(port) 51 | 52 | // In test environment, this might fail to bind to the WireGuard IP 53 | // but should fall back to localhost 54 | if err != nil { 55 | t.Logf("handleBind failed (expected in test env): %v", err) 56 | return 57 | } 58 | 59 | // Check that listener was created 60 | if _, exists := forwarder.listeners[port]; !exists { 61 | t.Error("listener not created for port") 62 | } 63 | 64 | // Clean up 65 | forwarder.closeAllListeners() 66 | } 67 | 68 | func TestPortForwarder_HandleBindDuplicate(t *testing.T) { 69 | tunnel := &Tunnel{ 70 | ourIP: netip.MustParseAddr("10.150.0.2"), 71 | } 72 | 73 | msgChan := make(chan IPCMessage, 10) 74 | forwarder := NewPortForwarder(tunnel, msgChan) 75 | defer forwarder.closeAllListeners() 76 | 77 | port := 8081 78 | 79 | // First bind should succeed or fail gracefully 80 | err1 := forwarder.handleBind(port) 81 | 82 | // Second bind to same port should not 
create duplicate listener 83 | err2 := forwarder.handleBind(port) 84 | 85 | // Both should either succeed or fail gracefully 86 | if err1 != nil && err2 != nil { 87 | t.Logf("Both bind attempts failed (expected in test env): %v, %v", err1, err2) 88 | return 89 | } 90 | 91 | // Should only have one listener for the port 92 | count := 0 93 | for p := range forwarder.listeners { 94 | if p == port { 95 | count++ 96 | } 97 | } 98 | 99 | if count > 1 { 100 | t.Errorf("found %d listeners for port %d, want at most 1", count, port) 101 | } 102 | } 103 | 104 | func TestPortForwarder_Run(t *testing.T) { 105 | tunnel := &Tunnel{ 106 | ourIP: netip.MustParseAddr("10.150.0.2"), 107 | } 108 | 109 | msgChan := make(chan IPCMessage, 10) 110 | forwarder := NewPortForwarder(tunnel, msgChan) 111 | 112 | ctx, cancel := context.WithCancel(context.Background()) 113 | defer cancel() 114 | 115 | // Start the forwarder in a goroutine 116 | done := make(chan bool) 117 | go func() { 118 | forwarder.Run(ctx) 119 | done <- true 120 | }() 121 | 122 | // Send a BIND message 123 | bindMsg := IPCMessage{ 124 | Type: "BIND", 125 | Port: 8082, 126 | } 127 | 128 | msgChan <- bindMsg 129 | 130 | // Give some time for message processing 131 | time.Sleep(50 * time.Millisecond) 132 | 133 | // Cancel context to stop the forwarder 134 | cancel() 135 | 136 | // Wait for forwarder to stop 137 | select { 138 | case <-done: 139 | // Good, forwarder stopped 140 | case <-time.After(1 * time.Second): 141 | t.Error("forwarder did not stop within timeout") 142 | } 143 | 144 | // All listeners should be closed 145 | if len(forwarder.listeners) != 0 { 146 | t.Errorf("expected 0 listeners after close, got %d", len(forwarder.listeners)) 147 | } 148 | } 149 | 150 | func TestPortForwarder_RunWithNonBindMessage(t *testing.T) { 151 | tunnel := &Tunnel{ 152 | ourIP: netip.MustParseAddr("10.150.0.2"), 153 | } 154 | 155 | msgChan := make(chan IPCMessage, 10) 156 | forwarder := NewPortForwarder(tunnel, msgChan) 157 | 158 | ctx, 
cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 159 | defer cancel() 160 | 161 | // Start the forwarder 162 | done := make(chan bool) 163 | go func() { 164 | forwarder.Run(ctx) 165 | done <- true 166 | }() 167 | 168 | // Send a non-BIND message 169 | connectMsg := IPCMessage{ 170 | Type: "CONNECT", 171 | Port: 8083, 172 | } 173 | 174 | msgChan <- connectMsg 175 | 176 | // Wait for context timeout 177 | <-done 178 | 179 | // Should not have created any listeners 180 | if len(forwarder.listeners) != 0 { 181 | t.Errorf("expected 0 listeners for CONNECT message, got %d", len(forwarder.listeners)) 182 | } 183 | } 184 | 185 | func TestPortForwarder_CloseAllListeners(t *testing.T) { 186 | tunnel := &Tunnel{ 187 | ourIP: netip.MustParseAddr("10.150.0.2"), 188 | } 189 | 190 | msgChan := make(chan IPCMessage, 10) 191 | forwarder := NewPortForwarder(tunnel, msgChan) 192 | 193 | // Create mock listeners (using real listeners would require available ports) 194 | listener1, err := net.Listen("tcp", "127.0.0.1:0") 195 | if err != nil { 196 | t.Fatalf("failed to create test listener 1: %v", err) 197 | } 198 | 199 | listener2, err := net.Listen("tcp", "127.0.0.1:0") 200 | if err != nil { 201 | listener1.Close() 202 | t.Fatalf("failed to create test listener 2: %v", err) 203 | } 204 | 205 | port1 := listener1.Addr().(*net.TCPAddr).Port 206 | port2 := listener2.Addr().(*net.TCPAddr).Port 207 | 208 | forwarder.listeners[port1] = listener1 209 | forwarder.listeners[port2] = listener2 210 | 211 | // Close all listeners 212 | forwarder.closeAllListeners() 213 | 214 | // Listeners map should be empty 215 | if len(forwarder.listeners) != 0 { 216 | t.Errorf("expected 0 listeners after closeAll, got %d", len(forwarder.listeners)) 217 | } 218 | 219 | // Listeners should be closed (attempting to accept should fail) 220 | _, err = listener1.Accept() 221 | if err == nil { 222 | t.Error("listener1 should be closed") 223 | } 224 | 225 | _, err = listener2.Accept() 226 | 
if err == nil { 227 | t.Error("listener2 should be closed") 228 | } 229 | } 230 | 231 | func TestPortForwarder_AcceptConnections(t *testing.T) { 232 | // This test is complex to implement without real network setup 233 | // We'll test the basic structure and error handling 234 | 235 | tunnel := &Tunnel{ 236 | ourIP: netip.MustParseAddr("10.150.0.2"), 237 | } 238 | 239 | msgChan := make(chan IPCMessage, 10) 240 | forwarder := NewPortForwarder(tunnel, msgChan) 241 | 242 | // Create a listener on an available port 243 | listener, err := net.Listen("tcp", "127.0.0.1:0") 244 | if err != nil { 245 | t.Fatalf("failed to create test listener: %v", err) 246 | } 247 | defer listener.Close() 248 | 249 | port := listener.Addr().(*net.TCPAddr).Port 250 | 251 | // Start accepting connections in a goroutine 252 | done := make(chan bool) 253 | go func() { 254 | forwarder.acceptConnections(listener, port) 255 | done <- true 256 | }() 257 | 258 | // Close the listener to stop accepting 259 | listener.Close() 260 | 261 | // Wait for acceptConnections to exit 262 | select { 263 | case <-done: 264 | // Good, acceptConnections stopped 265 | case <-time.After(1 * time.Second): 266 | t.Error("acceptConnections did not stop within timeout") 267 | } 268 | } 269 | 270 | func TestPortForwarder_HandleConnection(t *testing.T) { 271 | // This test requires a more complex setup with actual network connections 272 | // For now, we'll test the basic structure 273 | 274 | tunnel := &Tunnel{ 275 | ourIP: netip.MustParseAddr("10.150.0.2"), 276 | } 277 | 278 | msgChan := make(chan IPCMessage, 10) 279 | forwarder := NewPortForwarder(tunnel, msgChan) 280 | 281 | // Create a mock connection pair 282 | server, client := net.Pipe() 283 | defer server.Close() 284 | defer client.Close() 285 | 286 | // Test that handleConnection doesn't panic 287 | // In a real scenario, this would connect to localhost:port 288 | // but that requires a server running on that port 289 | 290 | done := make(chan bool) 291 | go 
func() { 292 | defer func() { 293 | if r := recover(); r != nil { 294 | t.Errorf("handleConnection panicked: %v", r) 295 | } 296 | done <- true 297 | }() 298 | forwarder.handleConnection(server, 8080) 299 | }() 300 | 301 | // Close connections to trigger exit 302 | server.Close() 303 | client.Close() 304 | 305 | // Wait for completion 306 | select { 307 | case <-done: 308 | // Good, no panic 309 | case <-time.After(1 * time.Second): 310 | t.Error("handleConnection did not complete within timeout") 311 | } 312 | } 313 | 314 | func TestPortForwarder_ConcurrentAccess(t *testing.T) { 315 | tunnel := &Tunnel{ 316 | ourIP: netip.MustParseAddr("10.150.0.2"), 317 | } 318 | 319 | msgChan := make(chan IPCMessage, 100) 320 | forwarder := NewPortForwarder(tunnel, msgChan) 321 | 322 | // Test concurrent access to the listeners map 323 | done := make(chan bool, 10) 324 | 325 | // Start multiple goroutines trying to bind to different ports 326 | for i := 0; i < 10; i++ { 327 | go func(port int) { 328 | defer func() { 329 | done <- true 330 | }() 331 | // This will likely fail in test environment, but tests concurrency 332 | forwarder.handleBind(8000 + port) 333 | }(i) 334 | } 335 | 336 | // Wait for all goroutines to complete 337 | for i := 0; i < 10; i++ { 338 | select { 339 | case <-done: 340 | // Good 341 | case <-time.After(2 * time.Second): 342 | t.Error("goroutine did not complete within timeout") 343 | return 344 | } 345 | } 346 | 347 | // Clean up 348 | forwarder.closeAllListeners() 349 | } 350 | 351 | func TestPortForwarder_MessageChannelClosed(t *testing.T) { 352 | tunnel := &Tunnel{ 353 | ourIP: netip.MustParseAddr("10.150.0.2"), 354 | } 355 | 356 | msgChan := make(chan IPCMessage, 10) 357 | forwarder := NewPortForwarder(tunnel, msgChan) 358 | 359 | ctx, cancel := context.WithCancel(context.Background()) 360 | defer cancel() 361 | 362 | // Start the forwarder 363 | done := make(chan bool) 364 | go func() { 365 | forwarder.Run(ctx) 366 | done <- true 367 | }() 368 | 369 
| // Close the message channel 370 | close(msgChan) 371 | 372 | // Give some time for the forwarder to handle the closed channel 373 | time.Sleep(50 * time.Millisecond) 374 | 375 | // Cancel context 376 | cancel() 377 | 378 | // Wait for forwarder to stop 379 | select { 380 | case <-done: 381 | // Good, forwarder handled closed channel gracefully 382 | case <-time.After(1 * time.Second): 383 | t.Error("forwarder did not stop after channel close") 384 | } 385 | } 386 | 387 | // Test IP address validation 388 | func TestPortForwarder_IPValidation(t *testing.T) { 389 | tests := []struct { 390 | name string 391 | ip string 392 | }{ 393 | {"IPv4", "10.150.0.2"}, 394 | {"IPv6", "::1"}, 395 | {"nil", ""}, 396 | } 397 | 398 | for _, tt := range tests { 399 | t.Run(tt.name, func(t *testing.T) { 400 | var ip netip.Addr 401 | if tt.ip != "" { 402 | var err error 403 | ip, err = netip.ParseAddr(tt.ip) 404 | if err != nil { 405 | t.Fatalf("invalid test IP: %v", err) 406 | } 407 | } 408 | 409 | tunnel := &Tunnel{ 410 | ourIP: ip, 411 | } 412 | 413 | msgChan := make(chan IPCMessage, 10) 414 | forwarder := NewPortForwarder(tunnel, msgChan) 415 | 416 | // Should not panic with any IP configuration 417 | if forwarder == nil { 418 | t.Error("NewPortForwarder returned nil") 419 | } 420 | }) 421 | } 422 | } 423 | 424 | // Benchmark test for port forwarder creation 425 | func BenchmarkNewPortForwarder(b *testing.B) { 426 | tunnel := &Tunnel{ 427 | ourIP: netip.MustParseAddr("10.150.0.2"), 428 | } 429 | 430 | msgChan := make(chan IPCMessage, 10) 431 | 432 | b.ResetTimer() 433 | for i := 0; i < b.N; i++ { 434 | forwarder := NewPortForwarder(tunnel, msgChan) 435 | _ = forwarder 436 | } 437 | } 438 | 439 | // Benchmark test for bind handling 440 | func BenchmarkPortForwarder_HandleBind(b *testing.B) { 441 | tunnel := &Tunnel{ 442 | ourIP: netip.MustParseAddr("10.150.0.2"), 443 | } 444 | 445 | msgChan := make(chan IPCMessage, 10) 446 | forwarder := NewPortForwarder(tunnel, msgChan) 447 | 
defer forwarder.closeAllListeners() 448 | 449 | b.ResetTimer() 450 | for i := 0; i < b.N; i++ { 451 | // Use different ports to avoid conflicts 452 | port := 8000 + (i % 1000) 453 | forwarder.handleBind(port) 454 | } 455 | } 456 | -------------------------------------------------------------------------------- /logger_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "os" 7 | "strings" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func TestLogLevel_String(t *testing.T) { 13 | tests := []struct { 14 | level LogLevel 15 | expected string 16 | }{ 17 | {LogLevelError, "error"}, 18 | {LogLevelWarn, "warn"}, 19 | {LogLevelInfo, "info"}, 20 | {LogLevelDebug, "debug"}, 21 | {LogLevel(999), "unknown"}, 22 | } 23 | 24 | for _, tt := range tests { 25 | t.Run(tt.expected, func(t *testing.T) { 26 | if got := tt.level.String(); got != tt.expected { 27 | t.Errorf("LogLevel.String() = %v, want %v", got, tt.expected) 28 | } 29 | }) 30 | } 31 | } 32 | 33 | func TestParseLogLevel(t *testing.T) { 34 | tests := []struct { 35 | input string 36 | expected LogLevel 37 | expectError bool 38 | }{ 39 | {"error", LogLevelError, false}, 40 | {"warn", LogLevelWarn, false}, 41 | {"warning", LogLevelWarn, false}, 42 | {"info", LogLevelInfo, false}, 43 | {"debug", LogLevelDebug, false}, 44 | {"ERROR", LogLevelError, false}, // Test case insensitive 45 | {"INFO", LogLevelInfo, false}, 46 | {"invalid", LogLevelInfo, true}, 47 | {"", LogLevelInfo, true}, 48 | {"trace", LogLevelInfo, true}, 49 | } 50 | 51 | for _, tt := range tests { 52 | t.Run(tt.input, func(t *testing.T) { 53 | got, err := ParseLogLevel(tt.input) 54 | 55 | if tt.expectError { 56 | if err == nil { 57 | t.Error("expected error but got none") 58 | } 59 | return 60 | } 61 | 62 | if err != nil { 63 | t.Errorf("unexpected error: %v", err) 64 | return 65 | } 66 | 67 | if got != tt.expected { 68 | t.Errorf("ParseLogLevel(%q) = %v, want 
%v", tt.input, got, tt.expected) 69 | } 70 | }) 71 | } 72 | } 73 | 74 | func TestNewLogger(t *testing.T) { 75 | var buf bytes.Buffer 76 | logger := NewLogger(LogLevelInfo, &buf) 77 | 78 | if logger == nil { 79 | t.Error("NewLogger returned nil") 80 | } 81 | 82 | if logger.level != LogLevelInfo { 83 | t.Errorf("expected level %v, got %v", LogLevelInfo, logger.level) 84 | } 85 | 86 | if logger.output != &buf { 87 | t.Error("output not set correctly") 88 | } 89 | } 90 | 91 | func TestLogger_Log(t *testing.T) { 92 | tests := []struct { 93 | name string 94 | loggerLevel LogLevel 95 | logLevel LogLevel 96 | message string 97 | shouldOutput bool 98 | }{ 99 | {"error at error level", LogLevelError, LogLevelError, "error message", true}, 100 | {"warn at error level", LogLevelError, LogLevelWarn, "warn message", false}, 101 | {"info at error level", LogLevelError, LogLevelInfo, "info message", false}, 102 | {"debug at error level", LogLevelError, LogLevelDebug, "debug message", false}, 103 | 104 | {"error at warn level", LogLevelWarn, LogLevelError, "error message", true}, 105 | {"warn at warn level", LogLevelWarn, LogLevelWarn, "warn message", true}, 106 | {"info at warn level", LogLevelWarn, LogLevelInfo, "info message", false}, 107 | {"debug at warn level", LogLevelWarn, LogLevelDebug, "debug message", false}, 108 | 109 | {"error at info level", LogLevelInfo, LogLevelError, "error message", true}, 110 | {"warn at info level", LogLevelInfo, LogLevelWarn, "warn message", true}, 111 | {"info at info level", LogLevelInfo, LogLevelInfo, "info message", true}, 112 | {"debug at info level", LogLevelInfo, LogLevelDebug, "debug message", false}, 113 | 114 | {"error at debug level", LogLevelDebug, LogLevelError, "error message", true}, 115 | {"warn at debug level", LogLevelDebug, LogLevelWarn, "warn message", true}, 116 | {"info at debug level", LogLevelDebug, LogLevelInfo, "info message", true}, 117 | {"debug at debug level", LogLevelDebug, LogLevelDebug, "debug message", true}, 
118 | } 119 | 120 | for _, tt := range tests { 121 | t.Run(tt.name, func(t *testing.T) { 122 | var buf bytes.Buffer 123 | logger := NewLogger(tt.loggerLevel, &buf) 124 | 125 | logger.log(tt.logLevel, "%s", tt.message) 126 | 127 | output := buf.String() 128 | if tt.shouldOutput { 129 | if output == "" { 130 | t.Error("expected output but got none") 131 | } 132 | 133 | // Verify JSON format 134 | var entry LogEntry 135 | if err := json.Unmarshal([]byte(strings.TrimSpace(output)), &entry); err != nil { 136 | t.Errorf("failed to parse JSON output: %v", err) 137 | } 138 | 139 | if entry.Level != tt.logLevel.String() { 140 | t.Errorf("expected level %s, got %s", tt.logLevel.String(), entry.Level) 141 | } 142 | 143 | if entry.Message != tt.message { 144 | t.Errorf("expected message %q, got %q", tt.message, entry.Message) 145 | } 146 | 147 | if entry.Timestamp == "" { 148 | t.Error("timestamp is empty") 149 | } 150 | 151 | // Verify timestamp format 152 | if _, err := time.Parse(time.RFC3339, entry.Timestamp); err != nil { 153 | t.Errorf("invalid timestamp format: %v", err) 154 | } 155 | } else { 156 | if output != "" { 157 | t.Errorf("expected no output but got: %s", output) 158 | } 159 | } 160 | }) 161 | } 162 | } 163 | 164 | func TestLogger_LogMethods(t *testing.T) { 165 | var buf bytes.Buffer 166 | logger := NewLogger(LogLevelDebug, &buf) 167 | 168 | tests := []struct { 169 | name string 170 | logFunc func(string, ...interface{}) 171 | level string 172 | message string 173 | }{ 174 | {"Errorf", logger.Errorf, "error", "error message"}, 175 | {"Warnf", logger.Warnf, "warn", "warning message"}, 176 | {"Infof", logger.Infof, "info", "info message"}, 177 | {"Debugf", logger.Debugf, "debug", "debug message"}, 178 | } 179 | 180 | for _, tt := range tests { 181 | t.Run(tt.name, func(t *testing.T) { 182 | buf.Reset() 183 | tt.logFunc(tt.message) 184 | 185 | output := buf.String() 186 | if output == "" { 187 | t.Error("expected output but got none") 188 | } 189 | 190 | var 
entry LogEntry 191 | if err := json.Unmarshal([]byte(strings.TrimSpace(output)), &entry); err != nil { 192 | t.Errorf("failed to parse JSON output: %v", err) 193 | } 194 | 195 | if entry.Level != tt.level { 196 | t.Errorf("expected level %s, got %s", tt.level, entry.Level) 197 | } 198 | 199 | if entry.Message != tt.message { 200 | t.Errorf("expected message %q, got %q", tt.message, entry.Message) 201 | } 202 | }) 203 | } 204 | } 205 | 206 | func TestLogger_LogWithFormatting(t *testing.T) { 207 | var buf bytes.Buffer 208 | logger := NewLogger(LogLevelInfo, &buf) 209 | 210 | logger.Infof("test message with %s and %d", "string", 42) 211 | 212 | output := buf.String() 213 | var entry LogEntry 214 | if err := json.Unmarshal([]byte(strings.TrimSpace(output)), &entry); err != nil { 215 | t.Errorf("failed to parse JSON output: %v", err) 216 | } 217 | 218 | expected := "test message with string and 42" 219 | if entry.Message != expected { 220 | t.Errorf("expected message %q, got %q", expected, entry.Message) 221 | } 222 | } 223 | 224 | func TestLogger_JSONMarshaling(t *testing.T) { 225 | var buf bytes.Buffer 226 | logger := NewLogger(LogLevelInfo, &buf) 227 | 228 | // Test with special characters that need JSON escaping 229 | message := `test "quoted" message with \backslash and 230 | newline` 231 | logger.Infof("%s", message) 232 | 233 | output := buf.String() 234 | var entry LogEntry 235 | if err := json.Unmarshal([]byte(strings.TrimSpace(output)), &entry); err != nil { 236 | t.Errorf("failed to parse JSON output: %v", err) 237 | } 238 | 239 | if entry.Message != message { 240 | t.Errorf("message not preserved correctly through JSON marshaling") 241 | } 242 | } 243 | 244 | func TestLogger_ConcurrentAccess(t *testing.T) { 245 | // Test that concurrent logging doesn't panic or cause data races 246 | // We'll use a simpler approach that just verifies the logger doesn't crash 247 | 248 | var buf bytes.Buffer 249 | logger := NewLogger(LogLevelDebug, &buf) 250 | 251 | // Test 
concurrent logging with fewer goroutines and messages 252 | done := make(chan bool, 2) 253 | 254 | for i := 0; i < 2; i++ { 255 | go func(id int) { 256 | // Just log a few messages to test thread safety 257 | for j := 0; j < 3; j++ { 258 | logger.Infof("goroutine %d message %d", id, j) 259 | time.Sleep(1 * time.Millisecond) // Small delay to reduce race conditions 260 | } 261 | done <- true 262 | }(i) 263 | } 264 | 265 | // Wait for all goroutines to complete 266 | for i := 0; i < 2; i++ { 267 | select { 268 | case <-done: 269 | case <-time.After(2 * time.Second): 270 | t.Error("goroutine did not complete within timeout") 271 | return 272 | } 273 | } 274 | 275 | // Give time for all writes to complete 276 | time.Sleep(50 * time.Millisecond) 277 | 278 | output := buf.String() 279 | 280 | // Just verify we got some output and it's not corrupted 281 | if len(output) == 0 { 282 | t.Error("expected some log output from concurrent access") 283 | } 284 | 285 | // Verify that we have at least some valid JSON lines 286 | lines := strings.Split(strings.TrimSpace(output), "\n") 287 | validLines := 0 288 | for _, line := range lines { 289 | line = strings.TrimSpace(line) 290 | if line == "" { 291 | continue 292 | } 293 | var entry LogEntry 294 | if err := json.Unmarshal([]byte(line), &entry); err == nil { 295 | validLines++ 296 | } 297 | } 298 | 299 | // We should have at least a few valid log entries 300 | if validLines < 2 { 301 | t.Errorf("expected at least 2 valid log entries from concurrent access, got %d", validLines) 302 | } 303 | } 304 | 305 | func TestSetGlobalLogger(t *testing.T) { 306 | // Save original logger 307 | originalLogger := logger 308 | 309 | // Create a new logger 310 | var buf bytes.Buffer 311 | testLogger := NewLogger(LogLevelError, &buf) 312 | 313 | // Set as global logger 314 | SetGlobalLogger(testLogger) 315 | 316 | // Verify it was set 317 | if logger != testLogger { 318 | t.Error("global logger not set correctly") 319 | } 320 | 321 | // Restore 
original logger 322 | SetGlobalLogger(originalLogger) 323 | } 324 | 325 | func TestGlobalLoggerInitialization(t *testing.T) { 326 | // The global logger should be initialized in init() 327 | if logger == nil { 328 | t.Error("global logger not initialized") 329 | } 330 | 331 | if logger.level != LogLevelInfo { 332 | t.Errorf("expected default log level %v, got %v", LogLevelInfo, logger.level) 333 | } 334 | 335 | if logger.output != os.Stderr { 336 | t.Error("expected default output to be os.Stderr") 337 | } 338 | } 339 | 340 | func TestLogEntry_JSONTags(t *testing.T) { 341 | entry := LogEntry{ 342 | Timestamp: "2023-01-01T00:00:00Z", 343 | Level: "info", 344 | Message: "test message", 345 | } 346 | 347 | data, err := json.Marshal(entry) 348 | if err != nil { 349 | t.Fatalf("failed to marshal LogEntry: %v", err) 350 | } 351 | 352 | expected := `{"timestamp":"2023-01-01T00:00:00Z","level":"info","message":"test message"}` 353 | if string(data) != expected { 354 | t.Errorf("JSON output mismatch:\nexpected: %s\ngot: %s", expected, string(data)) 355 | } 356 | } 357 | 358 | func TestLogger_EmptyMessage(t *testing.T) { 359 | var buf bytes.Buffer 360 | logger := NewLogger(LogLevelInfo, &buf) 361 | 362 | logger.Infof("") 363 | 364 | output := buf.String() 365 | var entry LogEntry 366 | if err := json.Unmarshal([]byte(strings.TrimSpace(output)), &entry); err != nil { 367 | t.Errorf("failed to parse JSON output: %v", err) 368 | } 369 | 370 | if entry.Message != "" { 371 | t.Errorf("expected empty message, got %q", entry.Message) 372 | } 373 | } 374 | 375 | func TestLogger_LongMessage(t *testing.T) { 376 | var buf bytes.Buffer 377 | logger := NewLogger(LogLevelInfo, &buf) 378 | 379 | // Create a very long message 380 | longMessage := strings.Repeat("a", 10000) 381 | logger.Infof("%s", longMessage) 382 | 383 | output := buf.String() 384 | var entry LogEntry 385 | if err := json.Unmarshal([]byte(strings.TrimSpace(output)), &entry); err != nil { 386 | t.Errorf("failed to parse 
JSON output: %v", err) 387 | } 388 | 389 | if entry.Message != longMessage { 390 | t.Error("long message not preserved correctly") 391 | } 392 | } 393 | 394 | // Benchmark tests for performance 395 | func BenchmarkLogger_Info(b *testing.B) { 396 | var buf bytes.Buffer 397 | logger := NewLogger(LogLevelInfo, &buf) 398 | 399 | b.ResetTimer() 400 | for i := 0; i < b.N; i++ { 401 | logger.Infof("benchmark message %d", i) 402 | } 403 | } 404 | 405 | func BenchmarkLogger_InfoFiltered(b *testing.B) { 406 | var buf bytes.Buffer 407 | logger := NewLogger(LogLevelError, &buf) // Debug messages will be filtered 408 | 409 | b.ResetTimer() 410 | for i := 0; i < b.N; i++ { 411 | logger.Debugf("benchmark message %d", i) 412 | } 413 | } 414 | 415 | func BenchmarkParseLogLevel(b *testing.B) { 416 | levels := []string{"error", "warn", "info", "debug"} 417 | 418 | b.ResetTimer() 419 | for i := 0; i < b.N; i++ { 420 | level := levels[i%len(levels)] 421 | _, _ = ParseLogLevel(level) 422 | } 423 | } 424 | -------------------------------------------------------------------------------- /ipc_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "net" 6 | "os" 7 | "path/filepath" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func TestNewIPCServer(t *testing.T) { 13 | server, err := NewIPCServer() 14 | if err != nil { 15 | t.Fatalf("NewIPCServer failed: %v", err) 16 | } 17 | defer server.Close() 18 | 19 | if server == nil { 20 | t.Fatal("NewIPCServer returned nil") 21 | } 22 | 23 | if server.listener == nil { 24 | t.Error("listener is nil") 25 | } 26 | 27 | if server.socketPath == "" { 28 | t.Error("socket path is empty") 29 | } 30 | 31 | if server.msgChan == nil { 32 | t.Error("message channel is nil") 33 | } 34 | 35 | // Check that socket path is in temp directory 36 | expectedDir := os.TempDir() 37 | actualDir := filepath.Dir(server.socketPath) 38 | // Clean the paths to handle trailing slashes 
consistently 39 | expectedDir = filepath.Clean(expectedDir) 40 | actualDir = filepath.Clean(actualDir) 41 | if actualDir != expectedDir { 42 | t.Errorf("socket path not in temp dir: expected %s, got %s", expectedDir, actualDir) 43 | } 44 | 45 | // Check that socket file contains PID 46 | if !containsPID(server.socketPath) { 47 | t.Error("socket path should contain PID") 48 | } 49 | } 50 | 51 | func TestIPCServer_SocketPath(t *testing.T) { 52 | server, err := NewIPCServer() 53 | if err != nil { 54 | t.Fatalf("NewIPCServer failed: %v", err) 55 | } 56 | defer server.Close() 57 | 58 | path := server.SocketPath() 59 | if path == "" { 60 | t.Error("SocketPath returned empty string") 61 | } 62 | 63 | if path != server.socketPath { 64 | t.Errorf("SocketPath() = %q, want %q", path, server.socketPath) 65 | } 66 | } 67 | 68 | func TestIPCServer_MessageChan(t *testing.T) { 69 | server, err := NewIPCServer() 70 | if err != nil { 71 | t.Fatalf("NewIPCServer failed: %v", err) 72 | } 73 | defer server.Close() 74 | 75 | msgChan := server.MessageChan() 76 | if msgChan == nil { 77 | t.Error("MessageChan returned nil") 78 | } 79 | 80 | // Test that it's the same channel 81 | if msgChan != server.msgChan { 82 | t.Error("MessageChan returned different channel") 83 | } 84 | 85 | // Test that it's read-only 86 | select { 87 | case <-msgChan: 88 | // This is fine, channel is empty 89 | default: 90 | // This is expected 91 | } 92 | } 93 | 94 | func TestIPCServer_Close(t *testing.T) { 95 | server, err := NewIPCServer() 96 | if err != nil { 97 | t.Fatalf("NewIPCServer failed: %v", err) 98 | } 99 | 100 | socketPath := server.socketPath 101 | 102 | // Socket file should exist 103 | if _, err := os.Stat(socketPath); os.IsNotExist(err) { 104 | t.Error("socket file should exist before close") 105 | } 106 | 107 | // Close the server 108 | err = server.Close() 109 | if err != nil { 110 | t.Errorf("Close() returned error: %v", err) 111 | } 112 | 113 | // Socket file should be removed 114 | if _, err 
:= os.Stat(socketPath); !os.IsNotExist(err) { 115 | t.Error("socket file should be removed after close") 116 | } 117 | 118 | // Multiple closes should not panic 119 | err = server.Close() 120 | if err != nil { 121 | t.Errorf("second Close() returned error: %v", err) 122 | } 123 | } 124 | 125 | func TestIPCServer_MessageHandling(t *testing.T) { 126 | server, err := NewIPCServer() 127 | if err != nil { 128 | t.Fatalf("NewIPCServer failed: %v", err) 129 | } 130 | defer server.Close() 131 | 132 | // Give server time to start accepting connections 133 | time.Sleep(10 * time.Millisecond) 134 | 135 | // Connect to the IPC server 136 | conn, err := net.Dial("unix", server.socketPath) 137 | if err != nil { 138 | t.Fatalf("failed to connect to IPC server: %v", err) 139 | } 140 | defer conn.Close() 141 | 142 | // Test message 143 | msg := IPCMessage{ 144 | Type: "CONNECT", 145 | FD: 42, 146 | Port: 8080, 147 | Addr: "127.0.0.1:8080", 148 | } 149 | 150 | // Send message 151 | msgBytes, err := json.Marshal(msg) 152 | if err != nil { 153 | t.Fatalf("failed to marshal message: %v", err) 154 | } 155 | 156 | _, err = conn.Write(append(msgBytes, '\n')) 157 | if err != nil { 158 | t.Fatalf("failed to write message: %v", err) 159 | } 160 | 161 | // Receive message from channel 162 | select { 163 | case receivedMsg := <-server.msgChan: 164 | if receivedMsg.Type != msg.Type { 165 | t.Errorf("received Type = %q, want %q", receivedMsg.Type, msg.Type) 166 | } 167 | if receivedMsg.FD != msg.FD { 168 | t.Errorf("received FD = %d, want %d", receivedMsg.FD, msg.FD) 169 | } 170 | if receivedMsg.Port != msg.Port { 171 | t.Errorf("received Port = %d, want %d", receivedMsg.Port, msg.Port) 172 | } 173 | if receivedMsg.Addr != msg.Addr { 174 | t.Errorf("received Addr = %q, want %q", receivedMsg.Addr, msg.Addr) 175 | } 176 | case <-time.After(1 * time.Second): 177 | t.Error("timeout waiting for message") 178 | } 179 | } 180 | 181 | func TestIPCServer_InvalidMessage(t *testing.T) { 182 | server, err 
:= NewIPCServer() 183 | if err != nil { 184 | t.Fatalf("NewIPCServer failed: %v", err) 185 | } 186 | defer server.Close() 187 | 188 | // Give server time to start 189 | time.Sleep(10 * time.Millisecond) 190 | 191 | // Connect to the IPC server 192 | conn, err := net.Dial("unix", server.socketPath) 193 | if err != nil { 194 | t.Fatalf("failed to connect to IPC server: %v", err) 195 | } 196 | defer conn.Close() 197 | 198 | // Send invalid JSON 199 | _, err = conn.Write([]byte("invalid json\n")) 200 | if err != nil { 201 | t.Fatalf("failed to write invalid message: %v", err) 202 | } 203 | 204 | // Should not receive anything on message channel 205 | select { 206 | case msg := <-server.msgChan: 207 | t.Errorf("received unexpected message: %+v", msg) 208 | case <-time.After(100 * time.Millisecond): 209 | // This is expected - invalid messages should be dropped 210 | } 211 | } 212 | 213 | func TestIPCServer_MultipleConnections(t *testing.T) { 214 | server, err := NewIPCServer() 215 | if err != nil { 216 | t.Fatalf("NewIPCServer failed: %v", err) 217 | } 218 | defer server.Close() 219 | 220 | // Give server time to start 221 | time.Sleep(10 * time.Millisecond) 222 | 223 | // Create multiple connections 224 | conns := make([]net.Conn, 3) 225 | defer func() { 226 | for _, conn := range conns { 227 | if conn != nil { 228 | conn.Close() 229 | } 230 | } 231 | }() 232 | 233 | for i := 0; i < 3; i++ { 234 | conn, err := net.Dial("unix", server.socketPath) 235 | if err != nil { 236 | t.Fatalf("failed to connect %d to IPC server: %v", i, err) 237 | } 238 | conns[i] = conn 239 | } 240 | 241 | // Send messages from all connections 242 | messages := []IPCMessage{ 243 | {Type: "CONNECT", FD: 1, Port: 8080, Addr: "127.0.0.1:8080"}, 244 | {Type: "BIND", FD: 2, Port: 8081, Addr: "127.0.0.1:8081"}, 245 | {Type: "CONNECT", FD: 3, Port: 8082, Addr: "127.0.0.1:8082"}, 246 | } 247 | 248 | for i, msg := range messages { 249 | msgBytes, err := json.Marshal(msg) 250 | if err != nil { 251 | 
t.Fatalf("failed to marshal message %d: %v", i, err) 252 | } 253 | 254 | _, err = conns[i].Write(append(msgBytes, '\n')) 255 | if err != nil { 256 | t.Fatalf("failed to write message %d: %v", i, err) 257 | } 258 | } 259 | 260 | // Receive all messages 261 | received := make(map[int]IPCMessage) 262 | for i := 0; i < 3; i++ { 263 | select { 264 | case msg := <-server.msgChan: 265 | received[msg.FD] = msg 266 | case <-time.After(1 * time.Second): 267 | t.Errorf("timeout waiting for message %d", i) 268 | } 269 | } 270 | 271 | // Verify all messages were received 272 | for i, originalMsg := range messages { 273 | receivedMsg, ok := received[originalMsg.FD] 274 | if !ok { 275 | t.Errorf("message %d not received", i) 276 | continue 277 | } 278 | 279 | if receivedMsg.Type != originalMsg.Type { 280 | t.Errorf("message %d: Type = %q, want %q", i, receivedMsg.Type, originalMsg.Type) 281 | } 282 | if receivedMsg.Port != originalMsg.Port { 283 | t.Errorf("message %d: Port = %d, want %d", i, receivedMsg.Port, originalMsg.Port) 284 | } 285 | if receivedMsg.Addr != originalMsg.Addr { 286 | t.Errorf("message %d: Addr = %q, want %q", i, receivedMsg.Addr, originalMsg.Addr) 287 | } 288 | } 289 | } 290 | 291 | func TestIPCServer_ChannelBuffering(t *testing.T) { 292 | server, err := NewIPCServer() 293 | if err != nil { 294 | t.Fatalf("NewIPCServer failed: %v", err) 295 | } 296 | defer server.Close() 297 | 298 | // Give server time to start 299 | time.Sleep(10 * time.Millisecond) 300 | 301 | // Connect to server 302 | conn, err := net.Dial("unix", server.socketPath) 303 | if err != nil { 304 | t.Fatalf("failed to connect to IPC server: %v", err) 305 | } 306 | defer conn.Close() 307 | 308 | // Send many messages without reading from channel 309 | // This tests the channel buffering (should be 100) 310 | for i := 0; i < 50; i++ { 311 | msg := IPCMessage{ 312 | Type: "CONNECT", 313 | FD: i, 314 | Port: 8080 + i, 315 | Addr: "127.0.0.1:8080", 316 | } 317 | 318 | msgBytes, err := 
json.Marshal(msg) 319 | if err != nil { 320 | t.Fatalf("failed to marshal message %d: %v", i, err) 321 | } 322 | 323 | _, err = conn.Write(append(msgBytes, '\n')) 324 | if err != nil { 325 | t.Fatalf("failed to write message %d: %v", i, err) 326 | } 327 | } 328 | 329 | // Give time for messages to be processed 330 | time.Sleep(100 * time.Millisecond) 331 | 332 | // Now read messages from channel 333 | count := 0 334 | for { 335 | select { 336 | case <-server.msgChan: 337 | count++ 338 | case <-time.After(100 * time.Millisecond): 339 | // No more messages 340 | goto done 341 | } 342 | } 343 | 344 | done: 345 | if count != 50 { 346 | t.Errorf("received %d messages, want 50", count) 347 | } 348 | } 349 | 350 | func TestIPCMessage_JSONMarshaling(t *testing.T) { 351 | msg := IPCMessage{ 352 | Type: "BIND", 353 | FD: 42, 354 | Port: 8080, 355 | Addr: "192.168.1.1:8080", 356 | } 357 | 358 | // Marshal to JSON 359 | data, err := json.Marshal(msg) 360 | if err != nil { 361 | t.Fatalf("failed to marshal IPCMessage: %v", err) 362 | } 363 | 364 | // Unmarshal from JSON 365 | var unmarshaled IPCMessage 366 | err = json.Unmarshal(data, &unmarshaled) 367 | if err != nil { 368 | t.Fatalf("failed to unmarshal IPCMessage: %v", err) 369 | } 370 | 371 | // Compare 372 | if unmarshaled.Type != msg.Type { 373 | t.Errorf("Type = %q, want %q", unmarshaled.Type, msg.Type) 374 | } 375 | if unmarshaled.FD != msg.FD { 376 | t.Errorf("FD = %d, want %d", unmarshaled.FD, msg.FD) 377 | } 378 | if unmarshaled.Port != msg.Port { 379 | t.Errorf("Port = %d, want %d", unmarshaled.Port, msg.Port) 380 | } 381 | if unmarshaled.Addr != msg.Addr { 382 | t.Errorf("Addr = %q, want %q", unmarshaled.Addr, msg.Addr) 383 | } 384 | } 385 | 386 | func TestIPCServer_ConnectionClosed(t *testing.T) { 387 | server, err := NewIPCServer() 388 | if err != nil { 389 | t.Fatalf("NewIPCServer failed: %v", err) 390 | } 391 | defer server.Close() 392 | 393 | // Give server time to start 394 | time.Sleep(10 * time.Millisecond) 
395 | 396 | // Connect and immediately close 397 | conn, err := net.Dial("unix", server.socketPath) 398 | if err != nil { 399 | t.Fatalf("failed to connect to IPC server: %v", err) 400 | } 401 | 402 | // Send a message and then close 403 | msg := IPCMessage{Type: "CONNECT", FD: 1, Port: 8080, Addr: "127.0.0.1:8080"} 404 | msgBytes, _ := json.Marshal(msg) 405 | conn.Write(append(msgBytes, '\n')) 406 | conn.Close() 407 | 408 | // Should receive the message 409 | select { 410 | case receivedMsg := <-server.msgChan: 411 | if receivedMsg.Type != msg.Type { 412 | t.Errorf("received wrong message type: %s", receivedMsg.Type) 413 | } 414 | case <-time.After(1 * time.Second): 415 | t.Error("timeout waiting for message") 416 | } 417 | 418 | // Server should handle the closed connection gracefully 419 | // (no panic or error) 420 | } 421 | 422 | func TestIPCServer_SocketPermissions(t *testing.T) { 423 | server, err := NewIPCServer() 424 | if err != nil { 425 | t.Fatalf("NewIPCServer failed: %v", err) 426 | } 427 | defer server.Close() 428 | 429 | // Check that socket file exists and has appropriate permissions 430 | info, err := os.Stat(server.socketPath) 431 | if err != nil { 432 | t.Fatalf("failed to stat socket file: %v", err) 433 | } 434 | 435 | // Should be a socket 436 | if info.Mode()&os.ModeSocket == 0 { 437 | t.Error("socket file is not a socket") 438 | } 439 | } 440 | 441 | // Helper function to check if path contains PID 442 | func containsPID(path string) bool { 443 | filename := filepath.Base(path) 444 | return len(filename) > len("wrapguard-.sock") 445 | } 446 | 447 | // Benchmark test for IPC server creation 448 | func BenchmarkNewIPCServer(b *testing.B) { 449 | for i := 0; i < b.N; i++ { 450 | server, err := NewIPCServer() 451 | if err != nil { 452 | b.Fatalf("NewIPCServer failed: %v", err) 453 | } 454 | server.Close() 455 | } 456 | } 457 | 458 | // Benchmark test for message handling 459 | func BenchmarkIPCServer_MessageHandling(b *testing.B) { 460 | server, 
err := NewIPCServer() 461 | if err != nil { 462 | b.Fatalf("NewIPCServer failed: %v", err) 463 | } 464 | defer server.Close() 465 | 466 | // Give server time to start 467 | time.Sleep(10 * time.Millisecond) 468 | 469 | conn, err := net.Dial("unix", server.socketPath) 470 | if err != nil { 471 | b.Fatalf("failed to connect to IPC server: %v", err) 472 | } 473 | defer conn.Close() 474 | 475 | msg := IPCMessage{ 476 | Type: "CONNECT", 477 | FD: 42, 478 | Port: 8080, 479 | Addr: "127.0.0.1:8080", 480 | } 481 | 482 | msgBytes, _ := json.Marshal(msg) 483 | msgLine := append(msgBytes, '\n') 484 | 485 | // Drain the channel in a goroutine 486 | go func() { 487 | for { 488 | select { 489 | case <-server.msgChan: 490 | case <-time.After(1 * time.Second): 491 | return 492 | } 493 | } 494 | }() 495 | 496 | b.ResetTimer() 497 | for i := 0; i < b.N; i++ { 498 | conn.Write(msgLine) 499 | } 500 | } 501 | -------------------------------------------------------------------------------- /main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "flag" 6 | "os" 7 | "os/exec" 8 | "path/filepath" 9 | "strings" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | func TestPrintUsage(t *testing.T) { 15 | // Capture stderr output 16 | oldStderr := os.Stderr 17 | r, w, _ := os.Pipe() 18 | os.Stderr = w 19 | 20 | printUsage() 21 | 22 | w.Close() 23 | os.Stderr = oldStderr 24 | 25 | // Read captured output 26 | buf := make([]byte, 4096) 27 | n, err := r.Read(buf) 28 | if err != nil && n == 0 { 29 | t.Fatal("failed to read usage output") 30 | } 31 | 32 | output := string(buf[:n]) 33 | 34 | // Check that usage contains expected elements 35 | expectedParts := []string{ 36 | "wrapguard", // Changed to lowercase to match actual output 37 | "USAGE:", 38 | "--config", 39 | "EXAMPLES:", 40 | "curl", 41 | "OPTIONS:", 42 | "--log-level", 43 | "--help", 44 | } 45 | 46 | for _, part := range expectedParts { 47 | if 
!strings.Contains(output, part) { 48 | t.Errorf("usage output missing expected part: %s", part) 49 | } 50 | } 51 | } 52 | 53 | func TestMainWithHelp(t *testing.T) { 54 | // Test the --help flag 55 | // We need to test this by running the program as a subprocess 56 | // since main() calls os.Exit() 57 | 58 | if os.Getenv("TEST_MAIN_HELP") == "1" { 59 | // We're in the subprocess 60 | os.Args = []string{"wrapguard", "--help"} 61 | main() 62 | return 63 | } 64 | 65 | // Run subprocess 66 | cmd := exec.Command(os.Args[0], "-test.run=TestMainWithHelp") 67 | cmd.Env = append(os.Environ(), "TEST_MAIN_HELP=1") 68 | 69 | output, err := cmd.CombinedOutput() 70 | if err != nil { 71 | // Exit code 0 is expected for --help 72 | if exitErr, ok := err.(*exec.ExitError); ok { 73 | if exitErr.ExitCode() != 0 { 74 | t.Errorf("expected exit code 0 for --help, got %d", exitErr.ExitCode()) 75 | } 76 | } 77 | } 78 | 79 | outputStr := string(output) 80 | if !strings.Contains(strings.ToLower(outputStr), "wrapguard") { 81 | t.Error("help output should contain 'wrapguard'") 82 | } 83 | } 84 | 85 | func TestMainWithVersion(t *testing.T) { 86 | if os.Getenv("TEST_MAIN_VERSION") == "1" { 87 | // We're in the subprocess 88 | os.Args = []string{"wrapguard", "--version"} 89 | main() 90 | return 91 | } 92 | 93 | // Run subprocess 94 | cmd := exec.Command(os.Args[0], "-test.run=TestMainWithVersion") 95 | cmd.Env = append(os.Environ(), "TEST_MAIN_VERSION=1") 96 | 97 | output, err := cmd.CombinedOutput() 98 | if err != nil { 99 | if exitErr, ok := err.(*exec.ExitError); ok { 100 | if exitErr.ExitCode() != 0 { 101 | t.Errorf("expected exit code 0 for --version, got %d", exitErr.ExitCode()) 102 | } 103 | } 104 | } 105 | 106 | outputStr := string(output) 107 | if !strings.Contains(outputStr, "wrapguard version") { 108 | t.Error("version output should contain 'wrapguard version'") 109 | } 110 | } 111 | 112 | func TestMainWithNoConfig(t *testing.T) { 113 | if os.Getenv("TEST_MAIN_NO_CONFIG") == "1" { 114 
| // We're in the subprocess 115 | os.Args = []string{"wrapguard", "echo", "hello"} 116 | main() 117 | return 118 | } 119 | 120 | // Run subprocess 121 | cmd := exec.Command(os.Args[0], "-test.run=TestMainWithNoConfig") 122 | cmd.Env = append(os.Environ(), "TEST_MAIN_NO_CONFIG=1") 123 | 124 | output, err := cmd.CombinedOutput() 125 | if err != nil { 126 | if exitErr, ok := err.(*exec.ExitError); ok { 127 | if exitErr.ExitCode() != 1 { 128 | t.Errorf("expected exit code 1 for no config, got %d", exitErr.ExitCode()) 129 | } 130 | } 131 | } 132 | 133 | outputStr := string(output) 134 | if !strings.Contains(outputStr, "USAGE:") { 135 | t.Error("no config should show usage") 136 | } 137 | } 138 | 139 | func TestMainWithInvalidLogLevel(t *testing.T) { 140 | if os.Getenv("TEST_MAIN_INVALID_LOG") == "1" { 141 | // We're in the subprocess 142 | tempConfig := createTempConfig(t) 143 | defer os.Remove(tempConfig) 144 | 145 | os.Args = []string{"wrapguard", "--config=" + tempConfig, "--log-level=invalid", "echo", "hello"} 146 | main() 147 | return 148 | } 149 | 150 | // Run subprocess 151 | cmd := exec.Command(os.Args[0], "-test.run=TestMainWithInvalidLogLevel") 152 | cmd.Env = append(os.Environ(), "TEST_MAIN_INVALID_LOG=1") 153 | 154 | output, err := cmd.CombinedOutput() 155 | if err != nil { 156 | if exitErr, ok := err.(*exec.ExitError); ok { 157 | if exitErr.ExitCode() != 1 { 158 | t.Errorf("expected exit code 1 for invalid log level, got %d", exitErr.ExitCode()) 159 | } 160 | } 161 | } 162 | 163 | outputStr := string(output) 164 | if !strings.Contains(outputStr, "Invalid log level") { 165 | t.Error("should show invalid log level error") 166 | } 167 | } 168 | 169 | func TestMainWithInvalidConfig(t *testing.T) { 170 | if os.Getenv("TEST_MAIN_INVALID_CONFIG") == "1" { 171 | // We're in the subprocess 172 | os.Args = []string{"wrapguard", "--config=/nonexistent/config.conf", "echo", "hello"} 173 | main() 174 | return 175 | } 176 | 177 | // Run subprocess 178 | cmd := 
exec.Command(os.Args[0], "-test.run=TestMainWithInvalidConfig") 179 | cmd.Env = append(os.Environ(), "TEST_MAIN_INVALID_CONFIG=1") 180 | 181 | _, err := cmd.CombinedOutput() 182 | if err != nil { 183 | if exitErr, ok := err.(*exec.ExitError); ok { 184 | if exitErr.ExitCode() != 1 { 185 | t.Errorf("expected exit code 1 for invalid config, got %d", exitErr.ExitCode()) 186 | } 187 | } 188 | } 189 | } 190 | 191 | func TestMainWithNoCommand(t *testing.T) { 192 | if os.Getenv("TEST_MAIN_NO_COMMAND") == "1" { 193 | // We're in the subprocess 194 | tempConfig := createTempConfig(t) 195 | defer os.Remove(tempConfig) 196 | 197 | os.Args = []string{"wrapguard", "--config=" + tempConfig} 198 | main() 199 | return 200 | } 201 | 202 | // Run subprocess 203 | cmd := exec.Command(os.Args[0], "-test.run=TestMainWithNoCommand") 204 | cmd.Env = append(os.Environ(), "TEST_MAIN_NO_COMMAND=1") 205 | 206 | output, err := cmd.CombinedOutput() 207 | if err != nil { 208 | if exitErr, ok := err.(*exec.ExitError); ok { 209 | if exitErr.ExitCode() != 1 { 210 | t.Errorf("expected exit code 1 for no command, got %d", exitErr.ExitCode()) 211 | } 212 | } 213 | } 214 | 215 | outputStr := string(output) 216 | if !strings.Contains(outputStr, "No command specified") { 217 | t.Error("should show no command error") 218 | } 219 | } 220 | 221 | func TestMainWithLogFile(t *testing.T) { 222 | if os.Getenv("TEST_MAIN_LOG_FILE") == "1" { 223 | // We're in the subprocess 224 | tempConfig := createTempConfig(t) 225 | defer os.Remove(tempConfig) 226 | 227 | tempLog := filepath.Join(os.TempDir(), "wrapguard-test.log") 228 | defer os.Remove(tempLog) 229 | 230 | os.Args = []string{"wrapguard", "--config=" + tempConfig, "--log-file=" + tempLog, "echo", "hello"} 231 | main() 232 | return 233 | } 234 | 235 | // Run subprocess 236 | cmd := exec.Command(os.Args[0], "-test.run=TestMainWithLogFile") 237 | cmd.Env = append(os.Environ(), "TEST_MAIN_LOG_FILE=1") 238 | 239 | // This will likely fail due to missing WireGuard 
setup, but we can test 240 | // that it attempts to create the log file 241 | cmd.Run() 242 | 243 | // The test mainly ensures no panic occurs with log file option 244 | } 245 | 246 | func TestFlagParsing(t *testing.T) { 247 | // Test flag parsing logic separately 248 | tests := []struct { 249 | name string 250 | args []string 251 | expected struct { 252 | config string 253 | help bool 254 | version bool 255 | logLevel string 256 | logFile string 257 | } 258 | }{ 259 | { 260 | name: "basic config", 261 | args: []string{"--config=test.conf", "echo", "hello"}, 262 | expected: struct { 263 | config string 264 | help bool 265 | version bool 266 | logLevel string 267 | logFile string 268 | }{ 269 | config: "test.conf", 270 | help: false, 271 | version: false, 272 | logLevel: "info", 273 | logFile: "", 274 | }, 275 | }, 276 | { 277 | name: "help flag", 278 | args: []string{"--help"}, 279 | expected: struct { 280 | config string 281 | help bool 282 | version bool 283 | logLevel string 284 | logFile string 285 | }{ 286 | config: "", 287 | help: true, 288 | version: false, 289 | logLevel: "info", 290 | logFile: "", 291 | }, 292 | }, 293 | { 294 | name: "version flag", 295 | args: []string{"--version"}, 296 | expected: struct { 297 | config string 298 | help bool 299 | version bool 300 | logLevel string 301 | logFile string 302 | }{ 303 | config: "", 304 | help: false, 305 | version: true, 306 | logLevel: "info", 307 | logFile: "", 308 | }, 309 | }, 310 | { 311 | name: "all flags", 312 | args: []string{"--config=test.conf", "--log-level=debug", "--log-file=test.log", "echo", "hello"}, 313 | expected: struct { 314 | config string 315 | help bool 316 | version bool 317 | logLevel string 318 | logFile string 319 | }{ 320 | config: "test.conf", 321 | help: false, 322 | version: false, 323 | logLevel: "debug", 324 | logFile: "test.log", 325 | }, 326 | }, 327 | } 328 | 329 | for _, tt := range tests { 330 | t.Run(tt.name, func(t *testing.T) { 331 | // Reset flag package for each 
test 332 | flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) 333 | 334 | var configPath, logLevelStr, logFile string 335 | var showHelp, showVersion bool 336 | 337 | flag.StringVar(&configPath, "config", "", "Path to WireGuard configuration file") 338 | flag.BoolVar(&showHelp, "help", false, "Show help message") 339 | flag.BoolVar(&showVersion, "version", false, "Show version information") 340 | flag.StringVar(&logLevelStr, "log-level", "info", "Set log level") 341 | flag.StringVar(&logFile, "log-file", "", "Set file to write logs to") 342 | 343 | // Parse the test arguments 344 | err := flag.CommandLine.Parse(tt.args) 345 | if err != nil { 346 | t.Fatalf("flag parsing failed: %v", err) 347 | } 348 | 349 | // Check results 350 | if configPath != tt.expected.config { 351 | t.Errorf("config = %q, want %q", configPath, tt.expected.config) 352 | } 353 | if showHelp != tt.expected.help { 354 | t.Errorf("help = %v, want %v", showHelp, tt.expected.help) 355 | } 356 | if showVersion != tt.expected.version { 357 | t.Errorf("version = %v, want %v", showVersion, tt.expected.version) 358 | } 359 | if logLevelStr != tt.expected.logLevel { 360 | t.Errorf("logLevel = %q, want %q", logLevelStr, tt.expected.logLevel) 361 | } 362 | if logFile != tt.expected.logFile { 363 | t.Errorf("logFile = %q, want %q", logFile, tt.expected.logFile) 364 | } 365 | }) 366 | } 367 | } 368 | 369 | func TestMainIntegration(t *testing.T) { 370 | // This is a more comprehensive integration test 371 | // It will likely fail in test environment due to missing WireGuard setup 372 | // but tests the full initialization flow 373 | 374 | if os.Getenv("TEST_MAIN_INTEGRATION") == "1" { 375 | // We're in the subprocess 376 | tempConfig := createTempConfig(t) 377 | defer os.Remove(tempConfig) 378 | 379 | tempLog := filepath.Join(os.TempDir(), "wrapguard-integration.log") 380 | defer os.Remove(tempLog) 381 | 382 | os.Args = []string{"wrapguard", 383 | "--config=" + tempConfig, 384 | 
"--log-level=debug", 385 | "--log-file=" + tempLog, 386 | "echo", "integration test"} 387 | main() 388 | return 389 | } 390 | 391 | // Run subprocess with timeout 392 | cmd := exec.Command(os.Args[0], "-test.run=TestMainIntegration") 393 | cmd.Env = append(os.Environ(), "TEST_MAIN_INTEGRATION=1") 394 | 395 | // Use a timeout to prevent hanging 396 | done := make(chan error, 1) 397 | go func() { 398 | done <- cmd.Run() 399 | }() 400 | 401 | select { 402 | case err := <-done: 403 | // Test completed (likely with error due to WireGuard setup) 404 | if err != nil { 405 | t.Logf("Integration test failed as expected (no WireGuard): %v", err) 406 | } 407 | case <-time.After(10 * time.Second): 408 | cmd.Process.Kill() 409 | t.Error("Integration test timed out") 410 | } 411 | } 412 | 413 | // Helper function to create a temporary valid config file 414 | func createTempConfig(t *testing.T) string { 415 | tempFile, err := os.CreateTemp("", "wrapguard-test-*.conf") 416 | if err != nil { 417 | t.Fatalf("failed to create temp config: %v", err) 418 | } 419 | 420 | config := `[Interface] 421 | PrivateKey = cGluZy1wcml2YXRlLWtleS0xMjM0NTY3ODkwMTIzNDU2Nzg5MDEyMzQ1Njc4OTA= 422 | Address = 10.150.0.2/24 423 | 424 | [Peer] 425 | PublicKey = cGluZy1wdWJsaWMta2V5LTEyMzQ1Njc4OTAxMjM0NTY3ODkwMTIzNDU2Nzg5MDEy 426 | Endpoint = 127.0.0.1:51820 427 | AllowedIPs = 0.0.0.0/0` 428 | 429 | if _, err := tempFile.WriteString(config); err != nil { 430 | tempFile.Close() 431 | os.Remove(tempFile.Name()) 432 | t.Fatalf("failed to write temp config: %v", err) 433 | } 434 | 435 | tempFile.Close() 436 | return tempFile.Name() 437 | } 438 | 439 | // Test global logger setup in main 440 | func TestMainLoggerSetup(t *testing.T) { 441 | // Test that the logger is set up correctly in main 442 | // We can't easily test this without running main, but we can test 443 | // the logger creation logic 444 | 445 | tests := []struct { 446 | name string 447 | logLevel string 448 | logFile string 449 | wantErr bool 450 | 
}{ 451 | {"valid info level", "info", "", false}, 452 | {"valid debug level", "debug", "", false}, 453 | {"valid error level", "error", "", false}, 454 | {"valid warn level", "warn", "", false}, 455 | {"invalid level", "invalid", "", true}, 456 | {"valid with file", "info", "/tmp/test.log", false}, 457 | } 458 | 459 | for _, tt := range tests { 460 | t.Run(tt.name, func(t *testing.T) { 461 | logLevel, err := ParseLogLevel(tt.logLevel) 462 | 463 | if tt.wantErr { 464 | if err == nil { 465 | t.Error("expected error for invalid log level") 466 | } 467 | return 468 | } 469 | 470 | if err != nil { 471 | t.Errorf("unexpected error: %v", err) 472 | return 473 | } 474 | 475 | // Test logger creation 476 | var output bytes.Buffer 477 | logger := NewLogger(logLevel, &output) 478 | 479 | if logger == nil { 480 | t.Error("NewLogger returned nil") 481 | } 482 | 483 | // Test logging 484 | logger.Infof("test message") 485 | 486 | // Only expect output for levels that should produce output 487 | if output.Len() == 0 && logLevel >= LogLevelInfo { 488 | t.Error("expected log output") 489 | } 490 | }) 491 | } 492 | } 493 | 494 | // Test version constant consistency 495 | func TestVersionConsistency(t *testing.T) { 496 | // The version in main.go should be consistent 497 | mainVersion := version // from main.go 498 | moduleVersion := Version // from version.go 499 | 500 | // They might be different (main.go has its own constant) 501 | // but we test that they're both non-empty 502 | if mainVersion == "" { 503 | t.Error("main.go version constant is empty") 504 | } 505 | 506 | if moduleVersion == "" { 507 | t.Error("version.go Version constant is empty") 508 | } 509 | 510 | // Both should contain version-like strings 511 | if !strings.Contains(mainVersion, ".") { 512 | t.Error("main version should contain version number") 513 | } 514 | } 515 | 516 | // Benchmark test for flag parsing 517 | func BenchmarkFlagParsing(b *testing.B) { 518 | args := []string{"--config=test.conf", 
"--log-level=info", "echo", "hello"} 519 | 520 | b.ResetTimer() 521 | for i := 0; i < b.N; i++ { 522 | // Reset flag package 523 | flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError) 524 | 525 | var configPath, logLevelStr, logFile string 526 | var showHelp, showVersion bool 527 | 528 | flag.StringVar(&configPath, "config", "", "Path to WireGuard configuration file") 529 | flag.BoolVar(&showHelp, "help", false, "Show help message") 530 | flag.BoolVar(&showVersion, "version", false, "Show version information") 531 | flag.StringVar(&logLevelStr, "log-level", "info", "Set log level") 532 | flag.StringVar(&logFile, "log-file", "", "Set file to write logs to") 533 | 534 | flag.CommandLine.Parse(args) 535 | } 536 | } 537 | -------------------------------------------------------------------------------- /tunnel_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestNewMemoryTUN(t *testing.T) { 11 | tun := NewMemoryTUN("test-tun", 1420) 12 | 13 | if tun == nil { 14 | t.Fatal("NewMemoryTUN returned nil") 15 | } 16 | 17 | if tun.mtu != 1420 { 18 | t.Errorf("expected MTU 1420, got %d", tun.mtu) 19 | } 20 | 21 | if tun.name != "test-tun" { 22 | t.Errorf("expected name 'test-tun', got %q", tun.name) 23 | } 24 | 25 | if tun.closed { 26 | t.Error("TUN should not be closed initially") 27 | } 28 | 29 | if tun.inbound == nil { 30 | t.Error("inbound channel not initialized") 31 | } 32 | 33 | if tun.outbound == nil { 34 | t.Error("outbound channel not initialized") 35 | } 36 | 37 | if tun.events == nil { 38 | t.Error("events channel not initialized") 39 | } 40 | 41 | tun.Close() 42 | } 43 | 44 | func TestMemoryTUN_File(t *testing.T) { 45 | tun := NewMemoryTUN("test", 1420) 46 | defer tun.Close() 47 | 48 | if file := tun.File(); file != nil { 49 | t.Error("File() should return nil for memory TUN") 50 | } 51 | } 52 | 53 | func 
TestMemoryTUN_MTU(t *testing.T) { 54 | tun := NewMemoryTUN("test", 1500) 55 | defer tun.Close() 56 | 57 | mtu, err := tun.MTU() 58 | if err != nil { 59 | t.Errorf("MTU() returned error: %v", err) 60 | } 61 | 62 | if mtu != 1500 { 63 | t.Errorf("expected MTU 1500, got %d", mtu) 64 | } 65 | } 66 | 67 | func TestMemoryTUN_Name(t *testing.T) { 68 | tun := NewMemoryTUN("test-interface", 1420) 69 | defer tun.Close() 70 | 71 | name, err := tun.Name() 72 | if err != nil { 73 | t.Errorf("Name() returned error: %v", err) 74 | } 75 | 76 | if name != "test-interface" { 77 | t.Errorf("expected name 'test-interface', got %q", name) 78 | } 79 | } 80 | 81 | func TestMemoryTUN_Events(t *testing.T) { 82 | tun := NewMemoryTUN("test", 1420) 83 | defer tun.Close() 84 | 85 | events := tun.Events() 86 | if events == nil { 87 | t.Error("Events() returned nil channel") 88 | } 89 | } 90 | 91 | func TestMemoryTUN_ReadWrite(t *testing.T) { 92 | tun := NewMemoryTUN("test", 1420) 93 | defer tun.Close() 94 | 95 | // Test data 96 | testData := []byte("test packet data") 97 | 98 | // Write data 99 | go func() { 100 | time.Sleep(10 * time.Millisecond) // Small delay to ensure Read is waiting 101 | tun.inbound <- testData 102 | }() 103 | 104 | // Read data 105 | buf := make([]byte, 1500) 106 | bufs := [][]byte{buf} 107 | sizes := make([]int, 1) 108 | n, err := tun.Read(bufs, sizes, 0) 109 | if err != nil { 110 | t.Errorf("Read() returned error: %v", err) 111 | } 112 | 113 | if n != 1 { 114 | t.Errorf("expected to read 1 packet, got %d", n) 115 | } 116 | 117 | if sizes[0] != len(testData) { 118 | t.Errorf("expected packet size %d bytes, got %d", len(testData), sizes[0]) 119 | } 120 | 121 | if string(buf[:sizes[0]]) != string(testData) { 122 | t.Errorf("expected data %q, got %q", string(testData), string(buf[:sizes[0]])) 123 | } 124 | } 125 | 126 | func TestMemoryTUN_WriteToOutbound(t *testing.T) { 127 | tun := NewMemoryTUN("test", 1420) 128 | defer tun.Close() 129 | 130 | testData := []byte("outbound 
packet data") 131 | 132 | // Write to TUN (simulating WireGuard writing) 133 | bufs := [][]byte{testData} 134 | n, err := tun.Write(bufs, 0) 135 | if err != nil { 136 | t.Errorf("Write() returned error: %v", err) 137 | } 138 | 139 | if n != 1 { 140 | t.Errorf("expected to write 1 packet, got %d", n) 141 | } 142 | 143 | // Check if data appeared in outbound channel 144 | select { 145 | case data := <-tun.outbound: 146 | if string(data) != string(testData) { 147 | t.Errorf("expected outbound data %q, got %q", string(testData), string(data)) 148 | } 149 | case <-time.After(100 * time.Millisecond): 150 | t.Error("no data received on outbound channel") 151 | } 152 | } 153 | 154 | func TestMemoryTUN_Close(t *testing.T) { 155 | tun := NewMemoryTUN("test", 1420) 156 | 157 | // Close the TUN 158 | err := tun.Close() 159 | if err != nil { 160 | t.Errorf("Close() returned error: %v", err) 161 | } 162 | 163 | if !tun.closed { 164 | t.Error("TUN should be marked as closed") 165 | } 166 | 167 | // Test that Read returns error after close 168 | buf := make([]byte, 100) 169 | bufs := [][]byte{buf} 170 | sizes := make([]int, 1) 171 | _, err = tun.Read(bufs, sizes, 0) 172 | if err == nil { 173 | t.Error("Read() should return error after close") 174 | } 175 | 176 | // Test that Write returns error after close 177 | _, err = tun.Write([][]byte{[]byte("test")}, 0) 178 | if err == nil { 179 | t.Error("Write() should return error after close") 180 | } 181 | 182 | // Multiple closes should not panic 183 | err = tun.Close() 184 | if err != nil { 185 | t.Errorf("Second Close() returned error: %v", err) 186 | } 187 | } 188 | 189 | func TestMemoryTUN_Flush(t *testing.T) { 190 | tun := NewMemoryTUN("test", 1420) 191 | defer tun.Close() 192 | 193 | // Flush should not return error 194 | err := tun.Flush() 195 | if err != nil { 196 | t.Errorf("Flush() returned error: %v", err) 197 | } 198 | } 199 | 200 | func TestTunnel_IsWireGuardIP(t *testing.T) { 201 | config := &WireGuardConfig{ 202 | 
Interface: InterfaceConfig{ 203 | Address: "10.150.0.2/24", 204 | }, 205 | } 206 | 207 | ourIP, _ := config.GetInterfaceIP() 208 | tunnel := &Tunnel{ 209 | ourIP: ourIP, 210 | } 211 | 212 | tests := []struct { 213 | name string 214 | ip string 215 | expected bool 216 | }{ 217 | {"WireGuard network IP", "10.150.0.5", true}, 218 | {"Our IP", "10.150.0.2", true}, 219 | {"Network address", "10.150.0.0", true}, 220 | {"Broadcast address", "10.150.0.255", true}, 221 | {"Outside network", "10.151.0.5", false}, 222 | {"Different network", "192.168.1.1", false}, 223 | {"Public IP", "8.8.8.8", false}, 224 | {"Invalid IP", "", false}, 225 | } 226 | 227 | for _, tt := range tests { 228 | t.Run(tt.name, func(t *testing.T) { 229 | ip := net.ParseIP(tt.ip) 230 | result := tunnel.IsWireGuardIP(ip) 231 | 232 | if result != tt.expected { 233 | t.Errorf("IsWireGuardIP(%q) = %v, want %v", tt.ip, result, tt.expected) 234 | } 235 | }) 236 | } 237 | } 238 | 239 | func TestTunnel_DialWireGuard(t *testing.T) { 240 | config := &WireGuardConfig{ 241 | Interface: InterfaceConfig{ 242 | Address: "10.150.0.2/24", 243 | }, 244 | Peers: []PeerConfig{ 245 | { 246 | PublicKey: "test-peer", 247 | Endpoint: "test.example.com:51820", 248 | AllowedIPs: []string{"0.0.0.0/0"}, 249 | }, 250 | }, 251 | } 252 | 253 | ourIP, _ := config.GetInterfaceIP() 254 | tunnel := &Tunnel{ 255 | ourIP: ourIP, 256 | config: config, 257 | router: NewRoutingEngine(config), 258 | } 259 | 260 | // Use a timeout context to prevent hanging on connection attempts 261 | ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 262 | defer cancel() 263 | 264 | // Test dialing known WireGuard IPs (fallback mode) 265 | tests := []struct { 266 | name string 267 | host string 268 | port string 269 | expectError bool 270 | }{ 271 | {"node-server-1", "10.150.0.2", "8080", false}, 272 | {"node-server-2", "10.150.0.3", "8080", false}, 273 | {"unknown WireGuard IP", "10.150.0.99", "8080", true}, 274 | } 275 | 276 | for _, tt 
:= range tests { 277 | t.Run(tt.name, func(t *testing.T) { 278 | conn, err := tunnel.DialWireGuard(ctx, "tcp", tt.host, tt.port) 279 | 280 | if tt.expectError { 281 | if err == nil { 282 | t.Error("expected error but got none") 283 | if conn != nil { 284 | conn.Close() 285 | } 286 | } 287 | } else { 288 | // Note: This will likely fail in test environment since 289 | // node-server-1 and node-server-2 don't exist, but we test 290 | // that the function doesn't panic and handles the mapping 291 | if err != nil { 292 | // Expected in test environment 293 | t.Logf("DialWireGuard failed as expected in test environment: %v", err) 294 | } else if conn != nil { 295 | conn.Close() 296 | } 297 | } 298 | }) 299 | } 300 | } 301 | 302 | func TestCreateTCPSyn(t *testing.T) { 303 | config := &WireGuardConfig{ 304 | Interface: InterfaceConfig{ 305 | Address: "10.150.0.2/24", 306 | }, 307 | } 308 | 309 | ourIP, _ := config.GetInterfaceIP() 310 | tunnel := &Tunnel{ 311 | ourIP: ourIP, 312 | } 313 | 314 | dstIP := net.ParseIP("10.150.0.3") 315 | dstPort := 80 316 | 317 | packet := tunnel.createTCPSyn(dstIP, dstPort) 318 | 319 | if len(packet) != 40 { 320 | t.Errorf("expected packet length 40, got %d", len(packet)) 321 | } 322 | 323 | // Check IP version 324 | version := packet[0] >> 4 325 | if version != 4 { 326 | t.Errorf("expected IP version 4, got %d", version) 327 | } 328 | 329 | // Check protocol (should be TCP = 6) 330 | protocol := packet[9] 331 | if protocol != 6 { 332 | t.Errorf("expected protocol 6 (TCP), got %d", protocol) 333 | } 334 | 335 | // Check source IP 336 | srcIP := net.IP(packet[12:16]) 337 | if !srcIP.Equal(ourIP.AsSlice()) { 338 | t.Errorf("expected source IP %v, got %v", ourIP, srcIP) 339 | } 340 | 341 | // Check destination IP 342 | dstIPFromPacket := net.IP(packet[16:20]) 343 | if !dstIPFromPacket.Equal(dstIP) { 344 | t.Errorf("expected destination IP %v, got %v", dstIP, dstIPFromPacket) 345 | } 346 | } 347 | 348 | func TestTunnel_HandleIncomingPacket(t 
*testing.T) { 349 | config := &WireGuardConfig{ 350 | Interface: InterfaceConfig{ 351 | Address: "10.150.0.2/24", 352 | }, 353 | } 354 | 355 | ourIP, _ := config.GetInterfaceIP() 356 | tunnel := &Tunnel{ 357 | ourIP: ourIP, 358 | connMap: make(map[string]*TunnelConn), 359 | } 360 | 361 | // Test with short packet 362 | tunnel.handleIncomingPacket([]byte("short")) 363 | // Should not panic 364 | 365 | // Test with non-IPv4 packet 366 | packet := make([]byte, 40) 367 | packet[0] = 0x60 // IPv6 368 | tunnel.handleIncomingPacket(packet) 369 | // Should not panic 370 | 371 | // Test with non-TCP packet 372 | packet[0] = 0x45 // IPv4 373 | packet[9] = 17 // UDP 374 | tunnel.handleIncomingPacket(packet) 375 | // Should not panic 376 | 377 | // Test with too short for TCP 378 | packet[9] = 6 // TCP 379 | shortPacket := packet[:23] 380 | tunnel.handleIncomingPacket(shortPacket) 381 | // Should not panic 382 | } 383 | 384 | func TestTunnelConn_Implementation(t *testing.T) { 385 | readChan := make(chan []byte, 10) 386 | writeChan := make(chan []byte, 10) 387 | 388 | localAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8080") 389 | remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:9090") 390 | 391 | conn := &TunnelConn{ 392 | localAddr: localAddr, 393 | remoteAddr: remoteAddr, 394 | readChan: readChan, 395 | writeChan: writeChan, 396 | } 397 | 398 | // Test addresses 399 | if conn.LocalAddr() != localAddr { 400 | t.Errorf("LocalAddr() = %v, want %v", conn.LocalAddr(), localAddr) 401 | } 402 | 403 | if conn.RemoteAddr() != remoteAddr { 404 | t.Errorf("RemoteAddr() = %v, want %v", conn.RemoteAddr(), remoteAddr) 405 | } 406 | 407 | // Test Write 408 | testData := []byte("test data") 409 | n, err := conn.Write(testData) 410 | if err != nil { 411 | t.Errorf("Write() returned error: %v", err) 412 | } 413 | if n != len(testData) { 414 | t.Errorf("Write() returned %d, want %d", n, len(testData)) 415 | } 416 | 417 | // Check data was written to channel 418 | select { 419 | case 
data := <-writeChan: 420 | if string(data) != string(testData) { 421 | t.Errorf("written data = %q, want %q", string(data), string(testData)) 422 | } 423 | case <-time.After(100 * time.Millisecond): 424 | t.Error("no data written to channel") 425 | } 426 | 427 | // Test Read 428 | readData := []byte("read test data") 429 | readChan <- readData 430 | 431 | buf := make([]byte, 100) 432 | n, err = conn.Read(buf) 433 | if err != nil { 434 | t.Errorf("Read() returned error: %v", err) 435 | } 436 | if n != len(readData) { 437 | t.Errorf("Read() returned %d bytes, want %d", n, len(readData)) 438 | } 439 | if string(buf[:n]) != string(readData) { 440 | t.Errorf("read data = %q, want %q", string(buf[:n]), string(readData)) 441 | } 442 | 443 | // Test deadline methods (should not return error) 444 | if err := conn.SetDeadline(time.Now()); err != nil { 445 | t.Errorf("SetDeadline() returned error: %v", err) 446 | } 447 | if err := conn.SetReadDeadline(time.Now()); err != nil { 448 | t.Errorf("SetReadDeadline() returned error: %v", err) 449 | } 450 | if err := conn.SetWriteDeadline(time.Now()); err != nil { 451 | t.Errorf("SetWriteDeadline() returned error: %v", err) 452 | } 453 | 454 | // Test Close 455 | err = conn.Close() 456 | if err != nil { 457 | t.Errorf("Close() returned error: %v", err) 458 | } 459 | 460 | if !conn.closed { 461 | t.Error("connection should be marked as closed") 462 | } 463 | 464 | // Test Read after close 465 | _, err = conn.Read(buf) 466 | if err == nil { 467 | t.Error("Read() should return error after close") 468 | } 469 | 470 | // Multiple closes should not panic 471 | err = conn.Close() 472 | if err != nil { 473 | t.Errorf("second Close() returned error: %v", err) 474 | } 475 | } 476 | 477 | func TestTunnelConn_WriteBufferFull(t *testing.T) { 478 | // Create connection with small buffer 479 | writeChan := make(chan []byte, 1) 480 | 481 | conn := &TunnelConn{ 482 | writeChan: writeChan, 483 | } 484 | 485 | // Fill the buffer 486 | _, err := 
conn.Write([]byte("first")) 487 | if err != nil { 488 | t.Errorf("first Write() returned error: %v", err) 489 | } 490 | 491 | // Second write should fail due to full buffer 492 | _, err = conn.Write([]byte("second")) 493 | if err == nil { 494 | t.Error("Write() should return error when buffer is full") 495 | } 496 | } 497 | 498 | func TestMustParsePort(t *testing.T) { 499 | tests := []struct { 500 | input string 501 | expected int 502 | }{ 503 | {"80", 80}, 504 | {"8080", 8080}, 505 | {"443", 443}, 506 | {"0", 0}, 507 | {"invalid", 0}, // strconv.Atoi returns 0 for invalid input 508 | } 509 | 510 | for _, tt := range tests { 511 | t.Run(tt.input, func(t *testing.T) { 512 | result := mustParsePort(tt.input) 513 | if result != tt.expected { 514 | t.Errorf("mustParsePort(%q) = %d, want %d", tt.input, result, tt.expected) 515 | } 516 | }) 517 | } 518 | } 519 | 520 | // Integration test for tunnel creation (may fail due to WireGuard dependencies) 521 | func TestNewTunnel_Integration(t *testing.T) { 522 | // This test may fail in CI/test environments without proper WireGuard setup 523 | // but tests the tunnel creation logic 524 | 525 | config := &WireGuardConfig{ 526 | Interface: InterfaceConfig{ 527 | PrivateKey: "cGluZy1wcml2YXRlLWtleS0xMjM0NTY3ODkwMTIzNDU2Nzg5MDEyMzQ1Njc4OTA=", // base64 encoded 32 bytes 528 | Address: "10.150.0.2/24", 529 | }, 530 | Peers: []PeerConfig{ 531 | { 532 | PublicKey: "cGluZy1wdWJsaWMta2V5LTEyMzQ1Njc4OTAxMjM0NTY3ODkwMTIzNDU2Nzg5MDEy", // base64 encoded 32 bytes 533 | Endpoint: "127.0.0.1:51820", 534 | AllowedIPs: []string{"0.0.0.0/0"}, 535 | }, 536 | }, 537 | } 538 | 539 | ctx := context.Background() 540 | tunnel, err := NewTunnel(ctx, config) 541 | 542 | // In test environment, this will likely fail due to missing WireGuard setup 543 | // but we test that it doesn't panic and handles errors gracefully 544 | if err != nil { 545 | t.Logf("NewTunnel failed as expected in test environment: %v", err) 546 | return 547 | } 548 | 549 | if tunnel 
== nil { 550 | t.Error("NewTunnel returned nil tunnel without error") 551 | return 552 | } 553 | 554 | // Test tunnel properties 555 | expectedIP, _ := config.GetInterfaceIP() 556 | if tunnel.ourIP != expectedIP { 557 | t.Errorf("tunnel.ourIP = %v, want %v", tunnel.ourIP, expectedIP) 558 | } 559 | 560 | if tunnel.device == nil { 561 | t.Error("tunnel.device is nil") 562 | } 563 | 564 | if tunnel.tun == nil { 565 | t.Error("tunnel.tun is nil") 566 | } 567 | 568 | if tunnel.connMap == nil { 569 | t.Error("tunnel.connMap is nil") 570 | } 571 | 572 | // Clean up 573 | tunnel.Close() 574 | } 575 | 576 | // Test tunnel close 577 | func TestTunnel_Close(t *testing.T) { 578 | tun := NewMemoryTUN("test", 1420) 579 | tunnel := &Tunnel{ 580 | tun: tun, 581 | // device: nil, // Don't create actual WireGuard device in test 582 | } 583 | 584 | err := tunnel.Close() 585 | if err != nil { 586 | t.Errorf("Close() returned error: %v", err) 587 | } 588 | 589 | // TUN should be closed 590 | if !tun.closed { 591 | t.Error("TUN should be closed after tunnel close") 592 | } 593 | } 594 | -------------------------------------------------------------------------------- /config_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/base64" 5 | "net/netip" 6 | "os" 7 | "testing" 8 | ) 9 | 10 | func TestParseConfig(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | config string 14 | expectError bool 15 | validate func(*WireGuardConfig) error 16 | }{ 17 | { 18 | name: "valid basic config", 19 | config: `[Interface] 20 | PrivateKey = ` + generateTestKey() + ` 21 | Address = 10.0.0.2/24 22 | 23 | [Peer] 24 | PublicKey = ` + generateTestKey() + ` 25 | Endpoint = 192.168.1.1:51820 26 | AllowedIPs = 0.0.0.0/0`, 27 | expectError: false, 28 | validate: func(c *WireGuardConfig) error { 29 | if c.Interface.Address != "10.0.0.2/24" { 30 | t.Errorf("expected address 10.0.0.2/24, got %s", c.Interface.Address) 
31 | } 32 | if len(c.Peers) != 1 { 33 | t.Errorf("expected 1 peer, got %d", len(c.Peers)) 34 | } 35 | if c.Peers[0].Endpoint != "192.168.1.1:51820" { 36 | t.Errorf("expected endpoint 192.168.1.1:51820, got %s", c.Peers[0].Endpoint) 37 | } 38 | return nil 39 | }, 40 | }, 41 | { 42 | name: "config with DNS", 43 | config: `[Interface] 44 | PrivateKey = ` + generateTestKey() + ` 45 | Address = 10.0.0.2/24 46 | DNS = 1.1.1.1, 8.8.8.8 47 | 48 | [Peer] 49 | PublicKey = ` + generateTestKey() + ` 50 | Endpoint = 127.0.0.1:51820 51 | AllowedIPs = 0.0.0.0/0`, 52 | expectError: false, 53 | validate: func(c *WireGuardConfig) error { 54 | if len(c.Interface.DNS) != 2 { 55 | t.Errorf("expected 2 DNS servers, got %d", len(c.Interface.DNS)) 56 | } 57 | if c.Interface.DNS[0] != "1.1.1.1" || c.Interface.DNS[1] != "8.8.8.8" { 58 | t.Errorf("unexpected DNS servers: %v", c.Interface.DNS) 59 | } 60 | return nil 61 | }, 62 | }, 63 | { 64 | name: "config with listen port", 65 | config: `[Interface] 66 | PrivateKey = ` + generateTestKey() + ` 67 | Address = 10.0.0.2/24 68 | ListenPort = 51820 69 | 70 | [Peer] 71 | PublicKey = ` + generateTestKey() + ` 72 | Endpoint = 192.168.1.1:51820 73 | AllowedIPs = 0.0.0.0/0`, 74 | expectError: false, 75 | validate: func(c *WireGuardConfig) error { 76 | if c.Interface.ListenPort != 51820 { 77 | t.Errorf("expected listen port 51820, got %d", c.Interface.ListenPort) 78 | } 79 | return nil 80 | }, 81 | }, 82 | { 83 | name: "config with preshared key", 84 | config: `[Interface] 85 | PrivateKey = ` + generateTestKey() + ` 86 | Address = 10.0.0.2/24 87 | 88 | [Peer] 89 | PublicKey = ` + generateTestKey() + ` 90 | PresharedKey = ` + generateTestKey() + ` 91 | Endpoint = 192.168.1.1:51820 92 | AllowedIPs = 0.0.0.0/0`, 93 | expectError: false, 94 | validate: func(c *WireGuardConfig) error { 95 | if c.Peers[0].PresharedKey == "" { 96 | t.Error("expected preshared key to be set") 97 | } 98 | return nil 99 | }, 100 | }, 101 | { 102 | name: "config with keepalive", 
103 | config: `[Interface] 104 | PrivateKey = ` + generateTestKey() + ` 105 | Address = 10.0.0.2/24 106 | 107 | [Peer] 108 | PublicKey = ` + generateTestKey() + ` 109 | Endpoint = 192.168.1.1:51820 110 | AllowedIPs = 0.0.0.0/0 111 | PersistentKeepalive = 25`, 112 | expectError: false, 113 | validate: func(c *WireGuardConfig) error { 114 | if c.Peers[0].PersistentKeepalive != 25 { 115 | t.Errorf("expected keepalive 25, got %d", c.Peers[0].PersistentKeepalive) 116 | } 117 | return nil 118 | }, 119 | }, 120 | { 121 | name: "config with multiple peers", 122 | config: `[Interface] 123 | PrivateKey = ` + generateTestKey() + ` 124 | Address = 10.0.0.2/24 125 | 126 | [Peer] 127 | PublicKey = ` + generateTestKey() + ` 128 | Endpoint = 192.168.1.1:51820 129 | AllowedIPs = 10.0.0.0/24 130 | 131 | [Peer] 132 | PublicKey = ` + generateTestKey() + ` 133 | Endpoint = 192.168.1.2:51820 134 | AllowedIPs = 10.1.0.0/24`, 135 | expectError: false, 136 | validate: func(c *WireGuardConfig) error { 137 | if len(c.Peers) != 2 { 138 | t.Errorf("expected 2 peers, got %d", len(c.Peers)) 139 | } 140 | return nil 141 | }, 142 | }, 143 | { 144 | name: "config with comments and empty lines", 145 | config: `# This is a comment 146 | [Interface] 147 | # Interface configuration 148 | PrivateKey = ` + generateTestKey() + ` 149 | Address = 10.0.0.2/24 150 | 151 | # Peer configuration 152 | [Peer] 153 | PublicKey = ` + generateTestKey() + ` 154 | Endpoint = 192.168.1.1:51820 155 | AllowedIPs = 0.0.0.0/0 156 | # End of config`, 157 | expectError: false, 158 | validate: func(c *WireGuardConfig) error { 159 | if c.Interface.Address != "10.0.0.2/24" { 160 | t.Errorf("expected address 10.0.0.2/24, got %s", c.Interface.Address) 161 | } 162 | return nil 163 | }, 164 | }, 165 | { 166 | name: "missing private key", 167 | config: `[Interface] 168 | Address = 10.0.0.2/24 169 | 170 | [Peer] 171 | PublicKey = ` + generateTestKey() + ` 172 | Endpoint = 192.168.1.1:51820 173 | AllowedIPs = 0.0.0.0/0`, 174 | 
expectError: true, 175 | }, 176 | { 177 | name: "missing address", 178 | config: `[Interface] 179 | PrivateKey = ` + generateTestKey() + ` 180 | 181 | [Peer] 182 | PublicKey = ` + generateTestKey() + ` 183 | Endpoint = 192.168.1.1:51820 184 | AllowedIPs = 0.0.0.0/0`, 185 | expectError: true, 186 | }, 187 | { 188 | name: "missing peer", 189 | config: `[Interface] 190 | PrivateKey = ` + generateTestKey() + ` 191 | Address = 10.0.0.2/24`, 192 | expectError: true, 193 | }, 194 | { 195 | name: "missing peer public key", 196 | config: `[Interface] 197 | PrivateKey = ` + generateTestKey() + ` 198 | Address = 10.0.0.2/24 199 | 200 | [Peer] 201 | Endpoint = 192.168.1.1:51820 202 | AllowedIPs = 0.0.0.0/0`, 203 | expectError: true, 204 | }, 205 | { 206 | name: "missing peer allowed IPs", 207 | config: `[Interface] 208 | PrivateKey = ` + generateTestKey() + ` 209 | Address = 10.0.0.2/24 210 | 211 | [Peer] 212 | PublicKey = ` + generateTestKey() + ` 213 | Endpoint = 192.168.1.1:51820`, 214 | expectError: true, 215 | }, 216 | { 217 | name: "invalid private key", 218 | config: `[Interface] 219 | PrivateKey = invalid-key 220 | Address = 10.0.0.2/24 221 | 222 | [Peer] 223 | PublicKey = ` + generateTestKey() + ` 224 | Endpoint = 192.168.1.1:51820 225 | AllowedIPs = 0.0.0.0/0`, 226 | expectError: true, 227 | }, 228 | { 229 | name: "invalid address format", 230 | config: `[Interface] 231 | PrivateKey = ` + generateTestKey() + ` 232 | Address = invalid-address 233 | 234 | [Peer] 235 | PublicKey = ` + generateTestKey() + ` 236 | Endpoint = 192.168.1.1:51820 237 | AllowedIPs = 0.0.0.0/0`, 238 | expectError: true, 239 | }, 240 | { 241 | name: "invalid allowed IP format", 242 | config: `[Interface] 243 | PrivateKey = ` + generateTestKey() + ` 244 | Address = 10.0.0.2/24 245 | 246 | [Peer] 247 | PublicKey = ` + generateTestKey() + ` 248 | Endpoint = 192.168.1.1:51820 249 | AllowedIPs = invalid-ip`, 250 | expectError: true, 251 | }, 252 | { 253 | name: "invalid listen port", 254 | config: 
`[Interface] 255 | PrivateKey = ` + generateTestKey() + ` 256 | Address = 10.0.0.2/24 257 | ListenPort = invalid-port 258 | 259 | [Peer] 260 | PublicKey = ` + generateTestKey() + ` 261 | Endpoint = 192.168.1.1:51820 262 | AllowedIPs = 0.0.0.0/0`, 263 | expectError: true, 264 | }, 265 | { 266 | name: "invalid keepalive", 267 | config: `[Interface] 268 | PrivateKey = ` + generateTestKey() + ` 269 | Address = 10.0.0.2/24 270 | 271 | [Peer] 272 | PublicKey = ` + generateTestKey() + ` 273 | Endpoint = 192.168.1.1:51820 274 | AllowedIPs = 0.0.0.0/0 275 | PersistentKeepalive = invalid-keepalive`, 276 | expectError: true, 277 | }, 278 | } 279 | 280 | for _, tt := range tests { 281 | t.Run(tt.name, func(t *testing.T) { 282 | // Create temporary config file 283 | tempFile, err := os.CreateTemp("", "wg-test-*.conf") 284 | if err != nil { 285 | t.Fatalf("failed to create temp file: %v", err) 286 | } 287 | defer os.Remove(tempFile.Name()) 288 | 289 | if _, err := tempFile.WriteString(tt.config); err != nil { 290 | t.Fatalf("failed to write config: %v", err) 291 | } 292 | tempFile.Close() 293 | 294 | // Parse config 295 | config, err := ParseConfig(tempFile.Name()) 296 | 297 | if tt.expectError { 298 | if err == nil { 299 | t.Errorf("expected error but got none") 300 | } 301 | return 302 | } 303 | 304 | if err != nil { 305 | t.Errorf("unexpected error: %v", err) 306 | return 307 | } 308 | 309 | if config == nil { 310 | t.Error("config is nil") 311 | return 312 | } 313 | 314 | // Run validation if provided 315 | if tt.validate != nil { 316 | if err := tt.validate(config); err != nil { 317 | t.Errorf("validation failed: %v", err) 318 | } 319 | } 320 | }) 321 | } 322 | } 323 | 324 | func TestParseConfigFileNotFound(t *testing.T) { 325 | _, err := ParseConfig("/nonexistent/file.conf") 326 | if err == nil { 327 | t.Error("expected error for nonexistent file") 328 | } 329 | } 330 | 331 | func TestGetInterfaceIP(t *testing.T) { 332 | config := &WireGuardConfig{ 333 | Interface: 
InterfaceConfig{ 334 | Address: "10.0.0.2/24", 335 | }, 336 | } 337 | 338 | ip, err := config.GetInterfaceIP() 339 | if err != nil { 340 | t.Errorf("unexpected error: %v", err) 341 | } 342 | 343 | expected, _ := netip.ParseAddr("10.0.0.2") 344 | if ip != expected { 345 | t.Errorf("expected IP %v, got %v", expected, ip) 346 | } 347 | } 348 | 349 | func TestGetInterfaceIPInvalid(t *testing.T) { 350 | config := &WireGuardConfig{ 351 | Interface: InterfaceConfig{ 352 | Address: "invalid-address", 353 | }, 354 | } 355 | 356 | _, err := config.GetInterfaceIP() 357 | if err == nil { 358 | t.Error("expected error for invalid address") 359 | } 360 | } 361 | 362 | func TestGetInterfacePrefix(t *testing.T) { 363 | config := &WireGuardConfig{ 364 | Interface: InterfaceConfig{ 365 | Address: "10.0.0.2/24", 366 | }, 367 | } 368 | 369 | prefix, err := config.GetInterfacePrefix() 370 | if err != nil { 371 | t.Errorf("unexpected error: %v", err) 372 | } 373 | 374 | expected, _ := netip.ParsePrefix("10.0.0.2/24") 375 | if prefix != expected { 376 | t.Errorf("expected prefix %v, got %v", expected, prefix) 377 | } 378 | } 379 | 380 | func TestBase64ToHex(t *testing.T) { 381 | tests := []struct { 382 | name string 383 | input string 384 | expectError bool 385 | expectedLen int 386 | }{ 387 | { 388 | name: "valid key", 389 | input: generateTestKey(), 390 | expectError: false, 391 | expectedLen: 64, // 32 bytes = 64 hex chars 392 | }, 393 | { 394 | name: "invalid base64", 395 | input: "invalid-base64!@#", 396 | expectError: true, 397 | }, 398 | { 399 | name: "wrong length", 400 | input: base64.StdEncoding.EncodeToString([]byte("short")), 401 | expectError: true, 402 | }, 403 | { 404 | name: "empty key", 405 | input: "", 406 | expectError: true, 407 | }, 408 | } 409 | 410 | for _, tt := range tests { 411 | t.Run(tt.name, func(t *testing.T) { 412 | result, err := base64ToHex(tt.input) 413 | 414 | if tt.expectError { 415 | if err == nil { 416 | t.Error("expected error but got none") 417 | } 
418 | return 419 | } 420 | 421 | if err != nil { 422 | t.Errorf("unexpected error: %v", err) 423 | return 424 | } 425 | 426 | if len(result) != tt.expectedLen { 427 | t.Errorf("expected length %d, got %d", tt.expectedLen, len(result)) 428 | } 429 | }) 430 | } 431 | } 432 | 433 | func TestResolveEndpoint(t *testing.T) { 434 | tests := []struct { 435 | name string 436 | endpoint string 437 | expectError bool 438 | }{ 439 | { 440 | name: "IP endpoint", 441 | endpoint: "192.168.1.1:51820", 442 | expectError: false, 443 | }, 444 | { 445 | name: "localhost endpoint", 446 | endpoint: "localhost:51820", 447 | expectError: false, 448 | }, 449 | { 450 | name: "invalid format", 451 | endpoint: "invalid-endpoint", 452 | expectError: true, 453 | }, 454 | { 455 | name: "nonexistent hostname", 456 | endpoint: "nonexistent-hostname-12345.invalid:51820", 457 | expectError: true, 458 | }, 459 | } 460 | 461 | for _, tt := range tests { 462 | t.Run(tt.name, func(t *testing.T) { 463 | result, err := resolveEndpoint(tt.endpoint) 464 | 465 | if tt.expectError { 466 | if err == nil { 467 | t.Error("expected error but got none") 468 | } 469 | return 470 | } 471 | 472 | if err != nil { 473 | t.Errorf("unexpected error: %v", err) 474 | return 475 | } 476 | 477 | if result == "" { 478 | t.Error("expected non-empty result") 479 | } 480 | }) 481 | } 482 | } 483 | 484 | func TestParseInterfaceField(t *testing.T) { 485 | tests := []struct { 486 | name string 487 | key string 488 | value string 489 | expectError bool 490 | validate func(*InterfaceConfig) error 491 | }{ 492 | { 493 | name: "private key", 494 | key: "PrivateKey", 495 | value: generateTestKey(), 496 | expectError: false, 497 | validate: func(iface *InterfaceConfig) error { 498 | if iface.PrivateKey == "" { 499 | t.Error("private key not set") 500 | } 501 | return nil 502 | }, 503 | }, 504 | { 505 | name: "address", 506 | key: "Address", 507 | value: "10.0.0.2/24", 508 | expectError: false, 509 | validate: func(iface *InterfaceConfig) 
error { 510 | if iface.Address != "10.0.0.2/24" { 511 | t.Errorf("expected address 10.0.0.2/24, got %s", iface.Address) 512 | } 513 | return nil 514 | }, 515 | }, 516 | { 517 | name: "DNS", 518 | key: "DNS", 519 | value: "1.1.1.1, 8.8.8.8", 520 | expectError: false, 521 | validate: func(iface *InterfaceConfig) error { 522 | if len(iface.DNS) != 2 { 523 | t.Errorf("expected 2 DNS servers, got %d", len(iface.DNS)) 524 | } 525 | return nil 526 | }, 527 | }, 528 | { 529 | name: "listen port", 530 | key: "ListenPort", 531 | value: "51820", 532 | expectError: false, 533 | validate: func(iface *InterfaceConfig) error { 534 | if iface.ListenPort != 51820 { 535 | t.Errorf("expected listen port 51820, got %d", iface.ListenPort) 536 | } 537 | return nil 538 | }, 539 | }, 540 | { 541 | name: "invalid private key", 542 | key: "PrivateKey", 543 | value: "invalid-key", 544 | expectError: true, 545 | }, 546 | { 547 | name: "invalid listen port", 548 | key: "ListenPort", 549 | value: "invalid-port", 550 | expectError: true, 551 | }, 552 | } 553 | 554 | for _, tt := range tests { 555 | t.Run(tt.name, func(t *testing.T) { 556 | iface := &InterfaceConfig{} 557 | err := parseInterfaceField(iface, tt.key, tt.value) 558 | 559 | if tt.expectError { 560 | if err == nil { 561 | t.Error("expected error but got none") 562 | } 563 | return 564 | } 565 | 566 | if err != nil { 567 | t.Errorf("unexpected error: %v", err) 568 | return 569 | } 570 | 571 | if tt.validate != nil { 572 | if err := tt.validate(iface); err != nil { 573 | t.Errorf("validation failed: %v", err) 574 | } 575 | } 576 | }) 577 | } 578 | } 579 | 580 | func TestParsePeerField(t *testing.T) { 581 | tests := []struct { 582 | name string 583 | key string 584 | value string 585 | expectError bool 586 | validate func(*PeerConfig) error 587 | }{ 588 | { 589 | name: "public key", 590 | key: "PublicKey", 591 | value: generateTestKey(), 592 | expectError: false, 593 | validate: func(peer *PeerConfig) error { 594 | if peer.PublicKey == 
"" { 595 | t.Error("public key not set") 596 | } 597 | return nil 598 | }, 599 | }, 600 | { 601 | name: "preshared key", 602 | key: "PresharedKey", 603 | value: generateTestKey(), 604 | expectError: false, 605 | validate: func(peer *PeerConfig) error { 606 | if peer.PresharedKey == "" { 607 | t.Error("preshared key not set") 608 | } 609 | return nil 610 | }, 611 | }, 612 | { 613 | name: "endpoint", 614 | key: "Endpoint", 615 | value: "192.168.1.1:51820", 616 | expectError: false, 617 | validate: func(peer *PeerConfig) error { 618 | if peer.Endpoint != "192.168.1.1:51820" { 619 | t.Errorf("expected endpoint 192.168.1.1:51820, got %s", peer.Endpoint) 620 | } 621 | return nil 622 | }, 623 | }, 624 | { 625 | name: "allowed IPs", 626 | key: "AllowedIPs", 627 | value: "0.0.0.0/0, 10.0.0.0/24", 628 | expectError: false, 629 | validate: func(peer *PeerConfig) error { 630 | if len(peer.AllowedIPs) != 2 { 631 | t.Errorf("expected 2 allowed IPs, got %d", len(peer.AllowedIPs)) 632 | } 633 | return nil 634 | }, 635 | }, 636 | { 637 | name: "persistent keepalive", 638 | key: "PersistentKeepalive", 639 | value: "25", 640 | expectError: false, 641 | validate: func(peer *PeerConfig) error { 642 | if peer.PersistentKeepalive != 25 { 643 | t.Errorf("expected keepalive 25, got %d", peer.PersistentKeepalive) 644 | } 645 | return nil 646 | }, 647 | }, 648 | { 649 | name: "invalid public key", 650 | key: "PublicKey", 651 | value: "invalid-key", 652 | expectError: true, 653 | }, 654 | { 655 | name: "invalid endpoint", 656 | key: "Endpoint", 657 | value: "invalid-endpoint", 658 | expectError: true, 659 | }, 660 | { 661 | name: "invalid keepalive", 662 | key: "PersistentKeepalive", 663 | value: "invalid-keepalive", 664 | expectError: true, 665 | }, 666 | } 667 | 668 | for _, tt := range tests { 669 | t.Run(tt.name, func(t *testing.T) { 670 | peer := &PeerConfig{} 671 | err := parsePeerField(peer, tt.key, tt.value) 672 | 673 | if tt.expectError { 674 | if err == nil { 675 | t.Error("expected 
error but got none")
				}
				return
			}

			if err != nil {
				t.Errorf("unexpected error: %v", err)
				return
			}

			if tt.validate != nil {
				if err := tt.validate(peer); err != nil {
					t.Errorf("validation failed: %v", err)
				}
			}
		})
	}
}

// TestValidateConfig checks that validateConfig accepts a minimal complete
// configuration and rejects one with a missing private key, a missing or
// malformed interface address, no peers, or a peer that lacks a public key
// or has missing/invalid AllowedIPs entries.
//
// The key strings here ("test-key", "test-public-key") are not base64; the
// "valid config" case therefore relies on validateConfig not checking key
// encoding (key format is exercised by the parsePeerField tests instead).
func TestValidateConfig(t *testing.T) {
	tests := []struct {
		name        string
		config      *WireGuardConfig
		expectError bool
	}{
		{
			name: "valid config",
			config: &WireGuardConfig{
				Interface: InterfaceConfig{
					PrivateKey: "test-key",
					Address:    "10.0.0.2/24",
				},
				Peers: []PeerConfig{
					{
						PublicKey:  "test-public-key",
						AllowedIPs: []string{"0.0.0.0/0"},
					},
				},
			},
			expectError: false,
		},
		{
			name: "missing private key",
			config: &WireGuardConfig{
				Interface: InterfaceConfig{
					Address: "10.0.0.2/24",
				},
				Peers: []PeerConfig{
					{
						PublicKey:  "test-public-key",
						AllowedIPs: []string{"0.0.0.0/0"},
					},
				},
			},
			expectError: true,
		},
		{
			name: "missing address",
			config: &WireGuardConfig{
				Interface: InterfaceConfig{
					PrivateKey: "test-key",
				},
				Peers: []PeerConfig{
					{
						PublicKey:  "test-public-key",
						AllowedIPs: []string{"0.0.0.0/0"},
					},
				},
			},
			expectError: true,
		},
		{
			name: "invalid address format",
			config: &WireGuardConfig{
				Interface: InterfaceConfig{
					PrivateKey: "test-key",
					Address:    "invalid-address",
				},
				Peers: []PeerConfig{
					{
						PublicKey:  "test-public-key",
						AllowedIPs: []string{"0.0.0.0/0"},
					},
				},
			},
			expectError: true,
		},
		{
			name: "no peers",
			config: &WireGuardConfig{
				Interface: InterfaceConfig{
					PrivateKey: "test-key",
					Address:    "10.0.0.2/24",
				},
				Peers: []PeerConfig{},
			},
			expectError: true,
		},
		{
			name: "peer missing public key",
			config: &WireGuardConfig{
				Interface: InterfaceConfig{
					PrivateKey: "test-key",
					Address:    "10.0.0.2/24",
				},
				Peers: []PeerConfig{
					{
						AllowedIPs: []string{"0.0.0.0/0"},
					},
				},
			},
			expectError: true,
		},
		{
			name: "peer missing allowed IPs",
			config: &WireGuardConfig{
				Interface: InterfaceConfig{
					PrivateKey: "test-key",
					Address:    "10.0.0.2/24",
				},
				Peers: []PeerConfig{
					{
						PublicKey: "test-public-key",
					},
				},
			},
			expectError: true,
		},
		{
			name: "peer invalid allowed IP",
			config: &WireGuardConfig{
				Interface: InterfaceConfig{
					PrivateKey: "test-key",
					Address:    "10.0.0.2/24",
				},
				Peers: []PeerConfig{
					{
						PublicKey:  "test-public-key",
						AllowedIPs: []string{"invalid-ip"},
					},
				},
			},
			expectError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validateConfig(tt.config)

			// Report only a mismatch between the expected and actual outcome.
			switch {
			case tt.expectError && err == nil:
				t.Error("expected error but got none")
			case !tt.expectError && err != nil:
				t.Errorf("unexpected error: %v", err)
			}
		})
	}
}

// generateTestKey returns a fixed test key: the 32 bytes 0x00, 0x01, ..., 0x1f
// encoded with standard (padded) base64. The bytes are deterministic, not
// random, so every call returns the same string and test/benchmark runs are
// reproducible.
func generateTestKey() string {
	key := make([]byte, 32)
	for i := range key {
		key[i] = byte(i)
	}
	return base64.StdEncoding.EncodeToString(key)
}

// BenchmarkParseConfig measures ParseConfig on a small on-disk config with one
// [Interface] and one [Peer] section. The file is written before the timer is
// reset, so only parsing is measured.
func BenchmarkParseConfig(b *testing.B) {
	config := `[Interface]
PrivateKey = ` + generateTestKey() + `
Address = 10.0.0.2/24

[Peer]
PublicKey = ` + generateTestKey() + `
Endpoint = 192.168.1.1:51820
AllowedIPs = 0.0.0.0/0`

	// b.TempDir is removed automatically when the benchmark finishes, including
	// after an early b.Fatalf, so no manual os.Remove is needed.
	tempFile, err := os.CreateTemp(b.TempDir(), "wg-bench-*.conf")
	if err != nil {
		b.Fatalf("failed to create temp file: %v", err)
	}
	path := tempFile.Name()

	// Close the file on every path (including a failed write) and check the
	// Close error: some filesystems only report write failures at Close.
	_, writeErr := tempFile.WriteString(config)
	closeErr := tempFile.Close()
	if writeErr != nil {
		b.Fatalf("failed to write config: %v", writeErr)
	}
	if closeErr != nil {
		b.Fatalf("failed to close temp file: %v", closeErr)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := ParseConfig(path); err != nil {
			b.Fatalf("parse error: %v", err)
		}
	}
}

// BenchmarkBase64ToHex measures base64ToHex on the fixed key returned by
// generateTestKey.
func BenchmarkBase64ToHex(b *testing.B) {
	key := generateTestKey()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := base64ToHex(key); err != nil {
			b.Fatalf("conversion error: %v", err)
		}
	}
}

--------------------------------------------------------------------------------