├── .claude └── settings.local.json ├── .github └── workflows │ ├── release.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── config.go ├── config_ipv6_test.go ├── config_test.go ├── demo.sh ├── example-wg0.conf ├── go.mod ├── go.sum ├── ipc.go ├── ipc_test.go ├── lib └── intercept.c ├── logging.go ├── logging_test.go ├── main.go ├── memorytun.go ├── network.go ├── network_test.go ├── test.sh ├── version.go ├── version_test.go └── wireguard.go /.claude/settings.local.json: -------------------------------------------------------------------------------- 1 | { 2 | "permissions": { 3 | "allow": [ 4 | "Bash(go mod:*)", 5 | "Bash(make:*)", 6 | "Bash(mkdir:*)", 7 | "Bash(mv:*)", 8 | "Bash(go build:*)", 9 | "Bash(go doc:*)", 10 | "Bash(true)", 11 | "Bash(go test:*)", 12 | "Bash(rm:*)", 13 | "Bash(gcc:*)", 14 | "Bash(chmod:*)", 15 | "Bash(ls:*)", 16 | "Bash(./wrapguard:*)", 17 | "Bash(git init:*)", 18 | "Bash(git add:*)", 19 | "Bash(gofmt:*)", 20 | "Bash(go vet:*)", 21 | "Bash(go get:*)", 22 | "Bash(go list:*)", 23 | "Bash(go clean:*)", 24 | "Bash(rg:*)", 25 | "Bash(git checkout:*)", 26 | "Bash(git commit:*)", 27 | "Bash(git push:*)", 28 | "Bash(touch:*)" 29 | ], 30 | "deny": [] 31 | }, 32 | "enableAllProjectMcpServers": false 33 | } -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | permissions: 8 | contents: write 9 | 10 | jobs: 11 | build-linux: 12 | name: Build Linux binaries 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | arch: [amd64, arm64] 17 | steps: 18 | - name: Checkout code 19 | uses: actions/checkout@v4 20 | 21 | - name: Set up Go 22 | uses: actions/setup-go@v5 23 | with: 24 | go-version: '1.21' 25 | 26 | - name: Install cross-compilation tools 27 | if: matrix.arch == 'arm64' 28 | run: | 29 | sudo apt-get 
update 30 | sudo apt-get install -y gcc-aarch64-linux-gnu 31 | 32 | - name: Build binaries 33 | run: | 34 | # Build Go binary 35 | GOOS=linux GOARCH=${{ matrix.arch }} CGO_ENABLED=0 \ 36 | go build -ldflags="-s -w -X main.Version=${{ github.event.release.tag_name }}" \ 37 | -o wrapguard . 38 | 39 | # Build C library 40 | if [ "${{ matrix.arch }}" = "arm64" ]; then 41 | aarch64-linux-gnu-gcc -fPIC -shared -Wall -O2 \ 42 | -o libwrapguard.so lib/intercept.c -ldl -lpthread 43 | else 44 | gcc -fPIC -shared -Wall -O2 \ 45 | -o libwrapguard.so lib/intercept.c -ldl -lpthread 46 | fi 47 | 48 | - name: Create release archive 49 | run: | 50 | chmod +x wrapguard 51 | tar -czf wrapguard-${{ github.event.release.tag_name }}-linux-${{ matrix.arch }}.tar.gz \ 52 | wrapguard libwrapguard.so README.md example-wg0.conf 53 | 54 | - name: Upload Release Asset 55 | uses: actions/upload-release-asset@v1 56 | env: 57 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 58 | with: 59 | upload_url: ${{ github.event.release.upload_url }} 60 | asset_path: ./wrapguard-${{ github.event.release.tag_name }}-linux-${{ matrix.arch }}.tar.gz 61 | asset_name: wrapguard-${{ github.event.release.tag_name }}-linux-${{ matrix.arch }}.tar.gz 62 | asset_content_type: application/gzip 63 | 64 | build-macos: 65 | name: Build macOS binaries 66 | runs-on: macos-latest 67 | strategy: 68 | matrix: 69 | arch: [amd64, arm64] 70 | steps: 71 | - name: Checkout code 72 | uses: actions/checkout@v4 73 | 74 | - name: Set up Go 75 | uses: actions/setup-go@v5 76 | with: 77 | go-version: '1.21' 78 | 79 | - name: Build binaries 80 | run: | 81 | # Build Go binary 82 | GOOS=darwin GOARCH=${{ matrix.arch }} CGO_ENABLED=0 \ 83 | go build -ldflags="-s -w -X main.Version=${{ github.event.release.tag_name }}" \ 84 | -o wrapguard . 
85 | 86 | # Build C library (dylib for macOS) for the matrix arch; without -arch clang targets the runner's own CPU 87 | clang -arch ${{ matrix.arch == 'amd64' && 'x86_64' || 'arm64' }} -fPIC -shared -Wall -O2 \ 88 | -o libwrapguard.dylib lib/intercept.c -ldl -lpthread 89 | 90 | - name: Create release archive 91 | run: | 92 | chmod +x wrapguard 93 | tar -czf wrapguard-${{ github.event.release.tag_name }}-darwin-${{ matrix.arch }}.tar.gz \ 94 | wrapguard libwrapguard.dylib README.md example-wg0.conf 95 | 96 | - name: Upload Release Asset 97 | uses: actions/upload-release-asset@v1 98 | env: 99 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 100 | with: 101 | upload_url: ${{ github.event.release.upload_url }} 102 | asset_path: ./wrapguard-${{ github.event.release.tag_name }}-darwin-${{ matrix.arch }}.tar.gz 103 | asset_name: wrapguard-${{ github.event.release.tag_name }}-darwin-${{ matrix.arch }}.tar.gz 104 | asset_content_type: application/gzip 105 | 106 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | test: 11 | name: Test 12 | runs-on: ubuntu-latest 13 | 14 | strategy: 15 | matrix: 16 | go-version: ['1.21', '1.22', '1.23'] 17 | 18 | steps: 19 | - name: Check out code 20 | uses: actions/checkout@v4 21 | 22 | - name: Set up Go 23 | uses: actions/setup-go@v5 24 | with: 25 | go-version: ${{ matrix.go-version }} 26 | 27 | - name: Cache Go modules 28 | uses: actions/cache@v4 29 | with: 30 | path: | 31 | ~/.cache/go-build 32 | ~/go/pkg/mod 33 | key: ${{ runner.os }}-go-${{ matrix.go-version }}-${{ hashFiles('**/go.sum') }} 34 | restore-keys: | 35 | ${{ runner.os }}-go-${{ matrix.go-version }}- 36 | 37 | - name: Download dependencies 38 | run: go mod download 39 | 40 | - name: Verify dependencies 41 | run: go mod verify 42 | 43 | - name: Run tests 44 | run: go test -v -race -coverprofile=coverage.out ./...
45 | 46 | - name: Run tests with coverage 47 | run: go test -cover ./... 48 | 49 | - name: Upload coverage reports to Codecov 50 | if: matrix.go-version == '1.23' 51 | uses: codecov/codecov-action@v4 52 | with: 53 | file: ./coverage.out 54 | flags: unittests 55 | name: codecov-umbrella 56 | fail_ci_if_error: false 57 | 58 | lint: 59 | name: Lint 60 | runs-on: ubuntu-latest 61 | 62 | steps: 63 | - name: Check out code 64 | uses: actions/checkout@v4 65 | 66 | - name: Set up Go 67 | uses: actions/setup-go@v5 68 | with: 69 | go-version: '1.23' 70 | 71 | - name: Run go vet 72 | run: go vet ./... 73 | 74 | - name: Check formatting 75 | run: | 76 | if [ "$(gofmt -s -l . | wc -l)" -gt 0 ]; then 77 | echo "Go files are not formatted:" 78 | gofmt -d . 79 | exit 1 80 | fi 81 | 82 | build: 83 | name: Build 84 | runs-on: ubuntu-latest 85 | 86 | steps: 87 | - name: Check out code 88 | uses: actions/checkout@v4 89 | 90 | - name: Set up Go 91 | uses: actions/setup-go@v5 92 | with: 93 | go-version: '1.23' 94 | 95 | - name: Install build dependencies 96 | run: sudo apt-get update && sudo apt-get install -y gcc 97 | 98 | - name: Build binary 99 | run: make build 100 | 101 | - name: Verify binary exists 102 | run: | 103 | ls -la wrapguard 104 | ls -la libwrapguard.so 105 | file wrapguard 106 | file libwrapguard.so 107 | 108 | - name: Test binary runs 109 | run: | 110 | ./wrapguard --version 111 | ./wrapguard --help 112 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries 2 | wrapguard 3 | libwrapguard.so 4 | *.so 5 | *.dylib 6 | *.dll 7 | 8 | # Go build artifacts 9 | *.exe 10 | *.exe~ 11 | *.test 12 | *.out 13 | 14 | # Go workspace files 15 | go.work 16 | go.work.sum 17 | 18 | # Dependency directories 19 | vendor/ 20 | 21 | # IDE and editor files 22 | .vscode/ 23 | .idea/ 24 | *.swp 25 | *.swo 26 | *~ 27 | .DS_Store 28 | 29 | # Test coverage 30 | 
*.prof 31 | coverage.txt 32 | coverage.html 33 | 34 | # Debug files 35 | *.log 36 | debug 37 | 38 | # Temporary files 39 | *.tmp 40 | *.temp 41 | /tmp/ 42 | 43 | # WireGuard config files with real keys 44 | *.conf 45 | !example-wg0.conf 46 | 47 | # Build directories 48 | /dist/ 49 | /build/ 50 | /release/ 51 | 52 | # Local environment files 53 | .env 54 | .env.local 55 | 56 | # OS specific files 57 | Thumbs.db 58 | .Spotlight-V100 59 | .Trashes 60 | 61 | # Claude AI files (keep these) 62 | # .claude/ 63 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Mark Wylde 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all clean build test 2 | 3 | # Output binary names 4 | BINARY = wrapguard 5 | UNAME_S := $(shell uname -s) 6 | ifeq ($(UNAME_S),Darwin) 7 | LIBRARY = libwrapguard.dylib 8 | else 9 | LIBRARY = libwrapguard.so 10 | endif 11 | 12 | # Go build flags 13 | GO_BUILD_FLAGS = -ldflags="-s -w" 14 | 15 | # C compiler flags 16 | CC = gcc 17 | CFLAGS = -fPIC -shared -Wall -O2 18 | 19 | all: build 20 | 21 | build: $(BINARY) $(LIBRARY) 22 | 23 | $(BINARY): *.go go.mod 24 | go build $(GO_BUILD_FLAGS) -o $(BINARY) . 25 | 26 | $(LIBRARY): lib/intercept.c 27 | $(CC) $(CFLAGS) -o $(LIBRARY) lib/intercept.c -ldl -lpthread 28 | 29 | test: build 30 | # Test basic connectivity 31 | ./test.sh 32 | 33 | clean: 34 | rm -f $(BINARY) $(LIBRARY) 35 | go clean -cache 36 | 37 | install: build 38 | install -m 755 $(BINARY) /usr/local/bin/ 39 | install -m 755 $(LIBRARY) /usr/local/lib/ 40 | 41 | .PHONY: deps 42 | deps: 43 | go mod download 44 | go mod tidy 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WrapGuard - Userspace WireGuard Proxy 2 | 3 | WrapGuard enables any application to transparently route ALL network traffic through a WireGuard VPN without requiring container privileges or kernel modules. 
4 | 5 | ## Features 6 | 7 | - **Pure Userspace**: No TUN interface creation, no NET_ADMIN capability needed 8 | - **Transparent Interception**: Uses LD_PRELOAD to intercept all network calls 9 | - **Bidirectional Support**: Both incoming and outgoing connections work 10 | - **Standard Config**: Uses standard WireGuard configuration files 11 | 12 | ## Installation 13 | 14 | ### Pre-compiled Binaries 15 | 16 | Download pre-compiled binaries for Linux and macOS from the [releases page](https://github.com/puzed/wrapguard/releases). 17 | 18 | **No additional dependencies required** - each release archive contains everything needed to create WireGuard connections: the `wrapguard` executable and its `libwrapguard` preload library. You don't need WireGuard installed on your host machine, kernel modules, or any other VPN software. 19 | 20 | ### Building from Source 21 | 22 | ```bash 23 | make build 24 | ``` 25 | 26 | This will create: 27 | - `wrapguard` - The main executable 28 | - `libwrapguard.so` (`libwrapguard.dylib` on macOS) - The LD_PRELOAD library that intercepts network calls 29 | 30 | ## Usage 31 | 32 | ```bash 33 | # Route outgoing connections through WireGuard 34 | wrapguard --config=~/wg0.conf -- curl https://icanhazip.com 35 | 36 | # Route incoming connections through WireGuard 37 | wrapguard --config=~/wg0.conf -- node -e 'http.createServer().listen(8080)' 38 | 39 | # With debug logging to console 40 | wrapguard --config=~/wg0.conf --log-level=debug -- curl https://icanhazip.com 41 | 42 | # With logging to file 43 | wrapguard --config=~/wg0.conf --log-level=info --log-file=/tmp/wrapguard.log -- curl https://icanhazip.com 44 | ``` 45 | 46 | ## Logging 47 | 48 | WrapGuard provides structured JSON logging with configurable levels and output destinations. 49 | 50 | ### Logging Options 51 | 52 | - `--log-level=` - Set logging level (error, warn, info, debug).
Default: info 53 | - `--log-file=` - Write logs to file instead of terminal 54 | 55 | ### Log Levels 56 | 57 | - `error` - Only critical errors 58 | - `warn` - Warnings and errors 59 | - `info` - General information, warnings, and errors (default) 60 | - `debug` - Detailed debugging information 61 | 62 | ### Log Format 63 | 64 | All logs are output in structured JSON format with timestamps: 65 | 66 | ```json 67 | {"timestamp":"2025-05-26T10:00:00Z","level":"info","message":"WrapGuard v1.0.0-dev initialized"} 68 | {"timestamp":"2025-05-26T10:00:00Z","level":"info","message":"Config: example-wg0.conf"} 69 | {"timestamp":"2025-05-26T10:00:00Z","level":"info","message":"Interface: 10.2.0.2/32"} 70 | {"timestamp":"2025-05-26T10:00:00Z","level":"info","message":"Peer endpoint: 192.168.1.8:51820"} 71 | {"timestamp":"2025-05-26T10:00:00Z","level":"info","message":"Launching: curl https://icanhazip.com"} 72 | ``` 73 | 74 | When `--log-file` is specified, all logs are written to the file and nothing appears on the terminal. 75 | 76 | ## Configuration 77 | 78 | WrapGuard uses standard WireGuard configuration files: 79 | 80 | ```ini 81 | [Interface] 82 | PrivateKey = 83 | Address = 10.0.0.2/24 84 | 85 | [Peer] 86 | PublicKey = 87 | Endpoint = server.example.com:51820 88 | AllowedIPs = 0.0.0.0/0 89 | PersistentKeepalive = 25 90 | ``` 91 | 92 | ## How It Works 93 | 94 | 1. **Main Process**: Parses config, initializes WireGuard userspace implementation 95 | 2. **LD_PRELOAD Library**: Intercepts network system calls (socket, connect, send, recv, etc.) 96 | 3. **Virtual Network Stack**: Routes packets between intercepted connections and WireGuard tunnel 97 | 4. 
**Memory-based TUN**: No kernel interface needed, packets processed entirely in memory 98 | 99 | ## Limitations 100 | 101 | - Linux and macOS only (Windows is not supported) 102 | - TCP and UDP protocols only 103 | - Performance overhead due to userspace packet processing 104 | 105 | ## Development 106 | 107 | ### Running Tests 108 | 109 | WrapGuard includes unit tests for its core functionality: 110 | 111 | ```bash 112 | # Run all tests 113 | go test -v ./... 114 | 115 | # Run tests with coverage 116 | go test -cover ./... 117 | 118 | # Run specific tests by name (a lone _test.go file does not compile without the package sources) 119 | go test -v -run TestParseWireGuardConfig . 120 | ``` 121 | 122 | ### Test Coverage 123 | 124 | The test suite covers: 125 | - Configuration parsing and validation (`config_test.go`, `config_ipv6_test.go`) 126 | - Structured JSON logging (`logging_test.go`) 127 | - Virtual network stack operations (`network_test.go`) 128 | - IPC communication protocols (`ipc_test.go`) 129 | - Version information (`version_test.go`) 130 | 131 | ### Building 132 | 133 | ```bash 134 | # Build wrapguard and the preload library 135 | make build 136 | 137 | # Build an unstripped binary for debugging (make build passes -s -w) 138 | go build -o wrapguard . 139 | 140 | # Clean build artifacts 141 | make clean 142 | ``` 143 | 144 | ## Testing 145 | 146 | ```bash 147 | # Test outgoing connection 148 | wrapguard --config=example-wg0.conf -- curl https://example.com 149 | 150 | # Test incoming connection 151 | wrapguard --config=example-wg0.conf -- python3 -m http.server 8080 152 | ``` 153 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "encoding/base64" 6 | "fmt" 7 | "net" 8 | "os" 9 | "strconv" 10 | "strings" 11 | ) 12 | 13 | type WireGuardConfig struct { 14 | Interface InterfaceConfig 15 | Peers []PeerConfig 16 | } 17 | 18 | type InterfaceConfig struct { 19 | PrivateKey string 20 | Address *net.IPNet 21 | ListenPort int 22 | MTU int 23 | } 24 | 25
| type PeerConfig struct { 26 | PublicKey string 27 | Endpoint *net.UDPAddr 28 | AllowedIPs []*net.IPNet 29 | PersistentKeepalive int 30 | PresharedKey string 31 | } 32 | 33 | func ParseWireGuardConfig(path string) (*WireGuardConfig, error) { 34 | file, err := os.Open(path) 35 | if err != nil { 36 | return nil, fmt.Errorf("failed to open config file: %w", err) 37 | } 38 | defer file.Close() 39 | 40 | config := &WireGuardConfig{ 41 | Interface: InterfaceConfig{ 42 | MTU: 1420, // Default WireGuard MTU 43 | }, 44 | } 45 | 46 | scanner := bufio.NewScanner(file) 47 | var currentSection string 48 | var currentPeer *PeerConfig 49 | 50 | for scanner.Scan() { 51 | line := strings.TrimSpace(scanner.Text()) 52 | 53 | // Skip empty lines and comments 54 | if line == "" || strings.HasPrefix(line, "#") { 55 | continue 56 | } 57 | 58 | // Check for section headers 59 | if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { 60 | section := strings.ToLower(strings.TrimSpace(line[1 : len(line)-1])) 61 | currentSection = section 62 | 63 | if section == "peer" { 64 | if currentPeer != nil { 65 | config.Peers = append(config.Peers, *currentPeer) 66 | } 67 | currentPeer = &PeerConfig{} 68 | } 69 | continue 70 | } 71 | 72 | // Parse key-value pairs 73 | parts := strings.SplitN(line, "=", 2) 74 | if len(parts) != 2 { 75 | continue 76 | } 77 | 78 | key := strings.TrimSpace(strings.ToLower(parts[0])) 79 | value := strings.TrimSpace(parts[1]) 80 | 81 | switch currentSection { 82 | case "interface": 83 | if err := parseInterfaceConfig(&config.Interface, key, value); err != nil { 84 | return nil, fmt.Errorf("invalid interface config: %w", err) 85 | } 86 | case "peer": 87 | if currentPeer == nil { 88 | currentPeer = &PeerConfig{} 89 | } 90 | if err := parsePeerConfig(currentPeer, key, value); err != nil { 91 | return nil, fmt.Errorf("invalid peer config: %w", err) 92 | } 93 | } 94 | } 95 | 96 | // Add the last peer if exists 97 | if currentPeer != nil { 98 | config.Peers = 
append(config.Peers, *currentPeer) 99 | } 100 | 101 | if err := scanner.Err(); err != nil { 102 | return nil, fmt.Errorf("error reading config: %w", err) 103 | } 104 | 105 | // Validate configuration 106 | if config.Interface.PrivateKey == "" { 107 | return nil, fmt.Errorf("missing private key in interface section") 108 | } 109 | if config.Interface.Address == nil { 110 | return nil, fmt.Errorf("missing address in interface section") 111 | } 112 | if len(config.Peers) == 0 { 113 | return nil, fmt.Errorf("no peers configured") 114 | } 115 | 116 | return config, nil 117 | } 118 | 119 | func parseInterfaceConfig(config *InterfaceConfig, key, value string) error { 120 | switch key { 121 | case "privatekey": 122 | // Validate base64 123 | if _, err := base64.StdEncoding.DecodeString(value); err != nil { 124 | return fmt.Errorf("invalid private key: %w", err) 125 | } 126 | config.PrivateKey = value 127 | case "address": 128 | // Handle comma-separated addresses for dual-stack IPv4/IPv6 129 | addresses := strings.Split(value, ",") 130 | for _, addr := range addresses { 131 | addr = strings.TrimSpace(addr) 132 | ip, ipnet, err := net.ParseCIDR(addr) 133 | if err != nil { 134 | return fmt.Errorf("invalid address: %w", err) 135 | } 136 | ipnet.IP = ip 137 | // For now, use the first address as the primary 138 | if config.Address == nil { 139 | config.Address = ipnet 140 | } 141 | } 142 | case "listenport": 143 | port, err := strconv.Atoi(value) 144 | if err != nil { 145 | return fmt.Errorf("invalid listen port: %w", err) 146 | } 147 | config.ListenPort = port 148 | case "mtu": 149 | mtu, err := strconv.Atoi(value) 150 | if err != nil { 151 | return fmt.Errorf("invalid MTU: %w", err) 152 | } 153 | config.MTU = mtu 154 | } 155 | return nil 156 | } 157 | 158 | func parsePeerConfig(config *PeerConfig, key, value string) error { 159 | switch key { 160 | case "publickey": 161 | // Validate base64 162 | if _, err := base64.StdEncoding.DecodeString(value); err != nil { 163 | return 
fmt.Errorf("invalid public key: %w", err) 164 | } 165 | config.PublicKey = value 166 | case "endpoint": 167 | // Handle IPv6 endpoints with brackets 168 | addr, err := net.ResolveUDPAddr("udp", value) 169 | if err != nil { 170 | return fmt.Errorf("invalid endpoint: %w", err) 171 | } 172 | config.Endpoint = addr 173 | case "allowedips": 174 | ips := strings.Split(value, ",") 175 | for _, ipStr := range ips { 176 | ipStr = strings.TrimSpace(ipStr) 177 | _, ipnet, err := net.ParseCIDR(ipStr) 178 | if err != nil { 179 | // Try parsing as single IP 180 | ip := net.ParseIP(ipStr) 181 | if ip == nil { 182 | return fmt.Errorf("invalid allowed IP: %s", ipStr) 183 | } 184 | if ip.To4() != nil { 185 | ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(32, 32)} 186 | } else { 187 | ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} 188 | } 189 | } 190 | config.AllowedIPs = append(config.AllowedIPs, ipnet) 191 | } 192 | case "persistentkeepalive": 193 | keepalive, err := strconv.Atoi(value) 194 | if err != nil { 195 | return fmt.Errorf("invalid persistent keepalive: %w", err) 196 | } 197 | config.PersistentKeepalive = keepalive 198 | case "presharedkey": 199 | // Validate base64 200 | if _, err := base64.StdEncoding.DecodeString(value); err != nil { 201 | return fmt.Errorf("invalid preshared key: %w", err) 202 | } 203 | config.PresharedKey = value 204 | } 205 | return nil 206 | } 207 | -------------------------------------------------------------------------------- /config_ipv6_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "net" 5 | "os" 6 | "testing" 7 | ) 8 | 9 | func TestParseWireGuardConfigIPv6(t *testing.T) { 10 | // Create a temporary config file with IPv6 addresses 11 | content := `[Interface] 12 | PrivateKey = MFUi5Ifm+AoCB83PFlv4MIbT+gUUOCATAR1o+qJnuVc= 13 | Address = 2001:db8::1/64 14 | ListenPort = 51820 15 | MTU = 1420 16 | 17 | [Peer] 18 | PublicKey = 
lBKGHDRS3JrAJCFHJLe4cqhMnaaymBpKAhTxOFb8gT8= 19 | AllowedIPs = ::/0 20 | Endpoint = [2001:db8::2]:51820 21 | PersistentKeepalive = 25 22 | ` 23 | 24 | tmpFile, err := os.CreateTemp("", "test-wg-ipv6-*.conf") 25 | if err != nil { 26 | t.Fatalf("Failed to create temp file: %v", err) 27 | } 28 | defer os.Remove(tmpFile.Name()) 29 | 30 | if _, err := tmpFile.WriteString(content); err != nil { 31 | t.Fatalf("Failed to write temp file: %v", err) 32 | } 33 | tmpFile.Close() 34 | 35 | config, err := ParseWireGuardConfig(tmpFile.Name()) 36 | if err != nil { 37 | t.Fatalf("Failed to parse config: %v", err) 38 | } 39 | 40 | // Check interface configuration 41 | expectedIP := net.ParseIP("2001:db8::1") 42 | if !config.Interface.Address.IP.Equal(expectedIP) { 43 | t.Errorf("Expected interface IP %v, got %v", expectedIP, config.Interface.Address.IP) 44 | } 45 | 46 | expectedMask := net.CIDRMask(64, 128) 47 | if config.Interface.Address.Mask.String() != expectedMask.String() { 48 | t.Errorf("Expected mask %v, got %v", expectedMask, config.Interface.Address.Mask) 49 | } 50 | 51 | if config.Interface.ListenPort != 51820 { 52 | t.Errorf("Expected listen port 51820, got %d", config.Interface.ListenPort) 53 | } 54 | 55 | // Check peer configuration 56 | if len(config.Peers) != 1 { 57 | t.Fatalf("Expected 1 peer, got %d", len(config.Peers)) 58 | } 59 | 60 | peer := config.Peers[0] 61 | if peer.PublicKey != "lBKGHDRS3JrAJCFHJLe4cqhMnaaymBpKAhTxOFb8gT8=" { 62 | t.Errorf("Unexpected public key: %s", peer.PublicKey) 63 | } 64 | 65 | expectedEndpoint := &net.UDPAddr{ 66 | IP: net.ParseIP("2001:db8::2"), 67 | Port: 51820, 68 | } 69 | if peer.Endpoint.String() != expectedEndpoint.String() { 70 | t.Errorf("Expected endpoint %v, got %v", expectedEndpoint, peer.Endpoint) 71 | } 72 | 73 | // Check allowed IPs 74 | if len(peer.AllowedIPs) != 1 { 75 | t.Fatalf("Expected 1 allowed IP, got %d", len(peer.AllowedIPs)) 76 | } 77 | 78 | expectedAllowedIP := &net.IPNet{ 79 | IP: net.IPv6zero, 80 | Mask: 
net.CIDRMask(0, 128), 81 | } 82 | if peer.AllowedIPs[0].String() != expectedAllowedIP.String() { 83 | t.Errorf("Expected allowed IP %v, got %v", expectedAllowedIP, peer.AllowedIPs[0]) 84 | } 85 | } 86 | 87 | func TestParseWireGuardConfigDualStack(t *testing.T) { 88 | // Create a temporary config file with both IPv4 and IPv6 addresses 89 | content := `[Interface] 90 | PrivateKey = MFUi5Ifm+AoCB83PFlv4MIbT+gUUOCATAR1o+qJnuVc= 91 | Address = 10.2.0.2/32, 2001:db8::1/64 92 | ListenPort = 51820 93 | 94 | [Peer] 95 | PublicKey = lBKGHDRS3JrAJCFHJLe4cqhMnaaymBpKAhTxOFb8gT8= 96 | AllowedIPs = 0.0.0.0/0, ::/0 97 | Endpoint = 192.168.64.6:51820 98 | PersistentKeepalive = 25 99 | ` 100 | 101 | tmpFile, err := os.CreateTemp("", "test-wg-dual-*.conf") 102 | if err != nil { 103 | t.Fatalf("Failed to create temp file: %v", err) 104 | } 105 | defer os.Remove(tmpFile.Name()) 106 | 107 | if _, err := tmpFile.WriteString(content); err != nil { 108 | t.Fatalf("Failed to write temp file: %v", err) 109 | } 110 | tmpFile.Close() 111 | 112 | config, err := ParseWireGuardConfig(tmpFile.Name()) 113 | if err != nil { 114 | t.Fatalf("Failed to parse config: %v", err) 115 | } 116 | 117 | // Check interface configuration (should use the first address) 118 | expectedIP := net.ParseIP("10.2.0.2") 119 | if !config.Interface.Address.IP.Equal(expectedIP) { 120 | t.Errorf("Expected interface IP %v, got %v", expectedIP, config.Interface.Address.IP) 121 | } 122 | 123 | // Check peer configuration 124 | if len(config.Peers) != 1 { 125 | t.Fatalf("Expected 1 peer, got %d", len(config.Peers)) 126 | } 127 | 128 | peer := config.Peers[0] 129 | 130 | // Check allowed IPs (should have both IPv4 and IPv6) 131 | if len(peer.AllowedIPs) != 2 { 132 | t.Fatalf("Expected 2 allowed IPs, got %d", len(peer.AllowedIPs)) 133 | } 134 | 135 | // Check IPv4 allowed IP 136 | expectedIPv4 := &net.IPNet{ 137 | IP: net.IPv4zero, 138 | Mask: net.CIDRMask(0, 32), 139 | } 140 | found := false 141 | for _, allowedIP := range 
peer.AllowedIPs { 142 | if allowedIP.String() == expectedIPv4.String() { 143 | found = true 144 | break 145 | } 146 | } 147 | if !found { 148 | t.Errorf("Expected IPv4 allowed IP %v not found", expectedIPv4) 149 | } 150 | 151 | // Check IPv6 allowed IP 152 | expectedIPv6 := &net.IPNet{ 153 | IP: net.IPv6zero, 154 | Mask: net.CIDRMask(0, 128), 155 | } 156 | found = false 157 | for _, allowedIP := range peer.AllowedIPs { 158 | if allowedIP.String() == expectedIPv6.String() { 159 | found = true 160 | break 161 | } 162 | } 163 | if !found { 164 | t.Errorf("Expected IPv6 allowed IP %v not found", expectedIPv6) 165 | } 166 | } 167 | 168 | func TestParseIPv6EndpointFormats(t *testing.T) { 169 | tests := []struct { 170 | name string 171 | endpoint string 172 | expected string 173 | }{ 174 | { 175 | name: "IPv6 with brackets", 176 | endpoint: "[2001:db8::1]:51820", 177 | expected: "[2001:db8::1]:51820", 178 | }, 179 | { 180 | name: "IPv4 endpoint", 181 | endpoint: "192.168.1.1:51820", 182 | expected: "192.168.1.1:51820", 183 | }, 184 | { 185 | name: "IPv6 localhost with brackets", 186 | endpoint: "[::1]:51820", 187 | expected: "[::1]:51820", 188 | }, 189 | } 190 | 191 | for _, tt := range tests { 192 | t.Run(tt.name, func(t *testing.T) { 193 | content := `[Interface] 194 | PrivateKey = MFUi5Ifm+AoCB83PFlv4MIbT+gUUOCATAR1o+qJnuVc= 195 | Address = 10.2.0.2/32 196 | 197 | [Peer] 198 | PublicKey = lBKGHDRS3JrAJCFHJLe4cqhMnaaymBpKAhTxOFb8gT8= 199 | AllowedIPs = 0.0.0.0/0 200 | Endpoint = ` + tt.endpoint + ` 201 | ` 202 | 203 | tmpFile, err := os.CreateTemp("", "test-wg-endpoint-*.conf") 204 | if err != nil { 205 | t.Fatalf("Failed to create temp file: %v", err) 206 | } 207 | defer os.Remove(tmpFile.Name()) 208 | 209 | if _, err := tmpFile.WriteString(content); err != nil { 210 | t.Fatalf("Failed to write temp file: %v", err) 211 | } 212 | tmpFile.Close() 213 | 214 | config, err := ParseWireGuardConfig(tmpFile.Name()) 215 | if err != nil { 216 | t.Fatalf("Failed to parse config: 
%v", err) 217 | } 218 | 219 | if len(config.Peers) != 1 { 220 | t.Fatalf("Expected 1 peer, got %d", len(config.Peers)) 221 | } 222 | 223 | peer := config.Peers[0] 224 | if peer.Endpoint.String() != tt.expected { 225 | t.Errorf("Expected endpoint %s, got %s", tt.expected, peer.Endpoint.String()) 226 | } 227 | }) 228 | } 229 | } 230 | -------------------------------------------------------------------------------- /config_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | func TestParseWireGuardConfig(t *testing.T) { 9 | // Create a temporary config file 10 | configContent := `[Interface] 11 | PrivateKey = YWJjZGVmZ2hpamtsb21ub3Bxcnp0dXZ3eHl6MTIzNDU2Nzg5MA== 12 | Address = 10.0.0.2/24 13 | ListenPort = 51820 14 | MTU = 1420 15 | 16 | [Peer] 17 | PublicKey = dGVzdGtleWZvcnRlc3RpbmcxMjM0NTY3ODlhYmNkZWZnaGlqaw== 18 | Endpoint = 192.168.1.1:51820 19 | AllowedIPs = 0.0.0.0/0 20 | PersistentKeepalive = 25 21 | ` 22 | 23 | tmpfile, err := os.CreateTemp("", "wg-config-*.conf") 24 | if err != nil { 25 | t.Fatal(err) 26 | } 27 | defer os.Remove(tmpfile.Name()) 28 | 29 | if _, err := tmpfile.Write([]byte(configContent)); err != nil { 30 | t.Fatal(err) 31 | } 32 | tmpfile.Close() 33 | 34 | // Test parsing 35 | config, err := ParseWireGuardConfig(tmpfile.Name()) 36 | if err != nil { 37 | t.Fatalf("Failed to parse config: %v", err) 38 | } 39 | 40 | // Validate interface config 41 | if config.Interface.PrivateKey != "YWJjZGVmZ2hpamtsb21ub3Bxcnp0dXZ3eHl6MTIzNDU2Nzg5MA==" { 42 | t.Errorf("Expected private key to match") 43 | } 44 | 45 | if config.Interface.Address.String() != "10.0.0.2/24" { 46 | t.Errorf("Expected address 10.0.0.2/24, got %s", config.Interface.Address.String()) 47 | } 48 | 49 | if config.Interface.ListenPort != 51820 { 50 | t.Errorf("Expected listen port 51820, got %d", config.Interface.ListenPort) 51 | } 52 | 53 | if config.Interface.MTU != 1420 { 54 | 
t.Errorf("Expected MTU 1420, got %d", config.Interface.MTU) 55 | } 56 | 57 | // Validate peer config 58 | if len(config.Peers) != 1 { 59 | t.Fatalf("Expected 1 peer, got %d", len(config.Peers)) 60 | } 61 | 62 | peer := config.Peers[0] 63 | if peer.PublicKey != "dGVzdGtleWZvcnRlc3RpbmcxMjM0NTY3ODlhYmNkZWZnaGlqaw==" { 64 | t.Errorf("Expected peer public key to match") 65 | } 66 | 67 | if peer.Endpoint.String() != "192.168.1.1:51820" { 68 | t.Errorf("Expected endpoint 192.168.1.1:51820, got %s", peer.Endpoint.String()) 69 | } 70 | 71 | if len(peer.AllowedIPs) != 1 { 72 | t.Fatalf("Expected 1 allowed IP, got %d", len(peer.AllowedIPs)) 73 | } 74 | 75 | if peer.AllowedIPs[0].String() != "0.0.0.0/0" { 76 | t.Errorf("Expected allowed IP 0.0.0.0/0, got %s", peer.AllowedIPs[0].String()) 77 | } 78 | 79 | if peer.PersistentKeepalive != 25 { 80 | t.Errorf("Expected persistent keepalive 25, got %d", peer.PersistentKeepalive) 81 | } 82 | } 83 | // NOTE(review): wantErr is never compared with the returned error; the loop only requires an error whose message is at least 3 characters long. Consider asserting strings.Contains(err.Error(), tt.wantErr), which needs a "strings" import. 84 | func TestParseWireGuardConfigMissingFields(t *testing.T) { 85 | tests := []struct { 86 | name string 87 | config string 88 | wantErr string 89 | }{ 90 | { 91 | name: "missing private key", 92 | config: "[Interface]\nAddress = 10.0.0.2/24\n", 93 | wantErr: "missing private key", 94 | }, 95 | { 96 | name: "missing address", 97 | config: "[Interface]\nPrivateKey = YWJjZGVmZ2hpamtsb21ub3Bxcnp0dXZ3eHl6MTIzNDU2Nzg5MA==\n", 98 | wantErr: "missing address", 99 | }, 100 | { 101 | name: "no peers", 102 | config: "[Interface]\nPrivateKey = YWJjZGVmZ2hpamtsb21ub3Bxcnp0dXZ3eHl6MTIzNDU2Nzg5MA==\nAddress = 10.0.0.2/24\n", 103 | wantErr: "no peers configured", 104 | }, 105 | { 106 | name: "invalid private key", 107 | config: "[Interface]\nPrivateKey = invalid-key\nAddress = 10.0.0.2/24\n[Peer]\nPublicKey = dGVzdGtleWZvcnRlc3RpbmcxMjM0NTY3ODlhYmNkZWZnaGlqaw==\n", 108 | wantErr: "invalid private key", 109 | }, 110 | } 111 | 112 | for _, tt := range tests { 113 | t.Run(tt.name, func(t *testing.T) { 114 | tmpfile, err := os.CreateTemp("", 
"wg-config-*.conf") 115 | if err != nil { 116 | t.Fatal(err) 117 | } 118 | defer os.Remove(tmpfile.Name()) 119 | 120 | if _, err := tmpfile.Write([]byte(tt.config)); err != nil { 121 | t.Fatal(err) 122 | } 123 | tmpfile.Close() 124 | 125 | _, err = ParseWireGuardConfig(tmpfile.Name()) 126 | if err == nil { 127 | t.Errorf("Expected error containing '%s', but got no error", tt.wantErr) 128 | } else if err.Error() == "" || len(err.Error()) < 3 { 129 | t.Errorf("Expected meaningful error message, got: %v", err) 130 | } 131 | }) 132 | } 133 | } 134 | 135 | func TestParseInterfaceConfig(t *testing.T) { 136 | config := &InterfaceConfig{MTU: 1420} 137 | 138 | tests := []struct { 139 | key string 140 | value string 141 | check func(*testing.T, *InterfaceConfig) 142 | }{ 143 | { 144 | key: "privatekey", 145 | value: "YWJjZGVmZ2hpamtsb21ub3Bxcnp0dXZ3eHl6MTIzNDU2Nzg5MA==", 146 | check: func(t *testing.T, c *InterfaceConfig) { 147 | if c.PrivateKey != "YWJjZGVmZ2hpamtsb21ub3Bxcnp0dXZ3eHl6MTIzNDU2Nzg5MA==" { 148 | t.Errorf("Private key not set correctly") 149 | } 150 | }, 151 | }, 152 | { 153 | key: "address", 154 | value: "10.0.0.2/24", 155 | check: func(t *testing.T, c *InterfaceConfig) { 156 | if c.Address == nil || c.Address.String() != "10.0.0.2/24" { 157 | t.Errorf("Address not set correctly") 158 | } 159 | }, 160 | }, 161 | { 162 | key: "listenport", 163 | value: "51820", 164 | check: func(t *testing.T, c *InterfaceConfig) { 165 | if c.ListenPort != 51820 { 166 | t.Errorf("Listen port not set correctly") 167 | } 168 | }, 169 | }, 170 | { 171 | key: "mtu", 172 | value: "1500", 173 | check: func(t *testing.T, c *InterfaceConfig) { 174 | if c.MTU != 1500 { 175 | t.Errorf("MTU not set correctly") 176 | } 177 | }, 178 | }, 179 | } 180 | 181 | for _, tt := range tests { 182 | t.Run(tt.key, func(t *testing.T) { 183 | err := parseInterfaceConfig(config, tt.key, tt.value) 184 | if err != nil { 185 | t.Errorf("Unexpected error: %v", err) 186 | } 187 | tt.check(t, config) 188 | }) 
189 | } 190 | } 191 | 192 | func TestParsePeerConfig(t *testing.T) { 193 | config := &PeerConfig{} 194 | 195 | // Test public key 196 | err := parsePeerConfig(config, "publickey", "dGVzdGtleWZvcnRlc3RpbmcxMjM0NTY3ODlhYmNkZWZnaGlqaw==") 197 | if err != nil { 198 | t.Errorf("Unexpected error: %v", err) 199 | } 200 | if config.PublicKey != "dGVzdGtleWZvcnRlc3RpbmcxMjM0NTY3ODlhYmNkZWZnaGlqaw==" { 201 | t.Errorf("Public key not set correctly") 202 | } 203 | 204 | // Test endpoint 205 | err = parsePeerConfig(config, "endpoint", "192.168.1.1:51820") 206 | if err != nil { 207 | t.Errorf("Unexpected error: %v", err) 208 | } 209 | if config.Endpoint == nil || config.Endpoint.String() != "192.168.1.1:51820" { 210 | t.Errorf("Endpoint not set correctly") 211 | } 212 | 213 | // Test allowed IPs 214 | err = parsePeerConfig(config, "allowedips", "0.0.0.0/0, 192.168.1.0/24") 215 | if err != nil { 216 | t.Errorf("Unexpected error: %v", err) 217 | } 218 | if len(config.AllowedIPs) != 2 { 219 | t.Errorf("Expected 2 allowed IPs, got %d", len(config.AllowedIPs)) 220 | } 221 | 222 | // Test persistent keepalive 223 | err = parsePeerConfig(config, "persistentkeepalive", "25") 224 | if err != nil { 225 | t.Errorf("Unexpected error: %v", err) 226 | } 227 | if config.PersistentKeepalive != 25 { 228 | t.Errorf("Persistent keepalive not set correctly") 229 | } 230 | } 231 | 232 | func TestParseAllowedIPs(t *testing.T) { 233 | config := &PeerConfig{} 234 | 235 | // Test various IP formats 236 | tests := []struct { 237 | input string 238 | expected int 239 | }{ 240 | {"0.0.0.0/0", 1}, 241 | {"192.168.1.0/24", 1}, 242 | {"10.0.0.1", 1}, // Single IP should become /32 243 | {"192.168.1.0/24, 10.0.0.0/8", 2}, 244 | {"8.8.8.8, 1.1.1.1", 2}, 245 | } 246 | 247 | for _, tt := range tests { 248 | t.Run(tt.input, func(t *testing.T) { 249 | config.AllowedIPs = nil // Reset 250 | err := parsePeerConfig(config, "allowedips", tt.input) 251 | if err != nil { 252 | t.Errorf("Unexpected error for input '%s': 
%v", tt.input, err) 253 | } 254 | if len(config.AllowedIPs) != tt.expected { 255 | t.Errorf("Expected %d allowed IPs for input '%s', got %d", tt.expected, tt.input, len(config.AllowedIPs)) 256 | } 257 | }) 258 | } 259 | } 260 | 261 | func TestConfigFileNotExists(t *testing.T) { 262 | _, err := ParseWireGuardConfig("/nonexistent/file.conf") 263 | if err == nil { 264 | t.Error("Expected error for non-existent file") 265 | } 266 | } 267 | -------------------------------------------------------------------------------- /demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Enable color output 4 | export TERM=xterm-256color 5 | 6 | echo "=== WrapGuard Demo ===" 7 | echo 8 | 9 | echo "1. Showing help screen:" 10 | echo "----------------------" 11 | ./wrapguard --help 12 | 13 | echo 14 | echo 15 | echo "2. Showing version:" 16 | echo "------------------" 17 | ./wrapguard --version 18 | 19 | echo 20 | echo 21 | echo "3. Running without config (shows help):" 22 | echo "--------------------------------------" 23 | ./wrapguard 2>&1 || true 24 | 25 | echo 26 | echo 27 | echo "4. 
Example with config file:" 28 | echo "---------------------------" 29 | echo "./wrapguard --config=example-wg0.conf -- echo 'Hello from WireGuard tunnel!'" -------------------------------------------------------------------------------- /example-wg0.conf: -------------------------------------------------------------------------------- 1 | [Interface] 2 | PrivateKey = MFUi5Ifm+AoCB83PFlv4MIbT+gUUOCATAR1o+qJnuVc= 3 | Address = 10.2.0.2/32 4 | DNS = 8.8.8.8 5 | 6 | [Peer] 7 | PublicKey = lBKGHDRS3JrAJCFHJLe4cqhMnaaymBpKAhTxOFb8gT8= 8 | AllowedIPs = 0.0.0.0/0 9 | Endpoint = 192.168.64.6:51820 10 | PersistentKeepalive = 25 11 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/puzed/wrapguard 2 | 3 | go 1.23.1 4 | 5 | toolchain go1.24.2 6 | 7 | require ( 8 | golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb 9 | golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 10 | ) 11 | 12 | require ( 13 | golang.org/x/crypto v0.38.0 // indirect 14 | golang.org/x/net v0.40.0 // indirect 15 | golang.org/x/sys v0.33.0 // indirect 16 | golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect 17 | ) 18 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= 2 | github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= 3 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 4 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 5 | golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= 6 | golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= 7 | golang.org/x/net 
v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= 8 | golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= 9 | golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= 10 | golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 11 | golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= 12 | golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= 13 | golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= 14 | golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= 15 | golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A= 16 | golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= 17 | golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU= 18 | golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= 19 | gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI= 20 | gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g= 21 | -------------------------------------------------------------------------------- /ipc.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/binary" 5 | "encoding/json" 6 | "fmt" 7 | "net" 8 | "os" 9 | "path/filepath" 10 | "sync" 11 | "time" 12 | ) 13 | 14 | // IPCServer handles communication with the LD_PRELOAD library 15 | type IPCServer struct { 16 | socketPath string 17 | listener net.Listener 18 | netStack *VirtualNetworkStack 19 | wgProxy *WireGuardProxy 20 | stopChan chan struct{} 21 | wg 
sync.WaitGroup 22 | } 23 | 24 | // IPCMessage represents a message between the main process and LD_PRELOAD library 25 | type IPCMessage struct { 26 | Type string `json:"type"` 27 | ConnID uint32 `json:"conn_id,omitempty"` 28 | SocketFD int `json:"socket_fd,omitempty"` 29 | Domain int `json:"domain,omitempty"` 30 | SockType int `json:"sock_type,omitempty"` 31 | Protocol int `json:"protocol,omitempty"` 32 | Address string `json:"address,omitempty"` 33 | Port int `json:"port,omitempty"` 34 | Data []byte `json:"data,omitempty"` 35 | Error string `json:"error,omitempty"` 36 | } 37 | 38 | // IPCResponse represents a response to an IPC message 39 | type IPCResponse struct { 40 | Success bool `json:"success"` 41 | ConnID uint32 `json:"conn_id,omitempty"` 42 | Data []byte `json:"data,omitempty"` 43 | Error string `json:"error,omitempty"` 44 | } 45 | 46 | // NewIPCServer creates a new IPC server 47 | func NewIPCServer(netStack *VirtualNetworkStack, wgProxy *WireGuardProxy) (*IPCServer, error) { 48 | // Create socket path in temp directory 49 | tmpDir := os.TempDir() 50 | socketPath := filepath.Join(tmpDir, fmt.Sprintf("wrapguard_%d.sock", os.Getpid())) 51 | 52 | return &IPCServer{ 53 | socketPath: socketPath, 54 | netStack: netStack, 55 | wgProxy: wgProxy, 56 | stopChan: make(chan struct{}), 57 | }, nil 58 | } 59 | 60 | // SocketPath returns the Unix socket path 61 | func (s *IPCServer) SocketPath() string { 62 | return s.socketPath 63 | } 64 | 65 | // Start starts the IPC server 66 | func (s *IPCServer) Start() error { 67 | // Remove any existing socket 68 | os.Remove(s.socketPath) 69 | 70 | // Create Unix domain socket 71 | listener, err := net.Listen("unix", s.socketPath) 72 | if err != nil { 73 | return fmt.Errorf("failed to create Unix socket: %w", err) 74 | } 75 | s.listener = listener 76 | 77 | // Start accepting connections 78 | s.wg.Add(1) 79 | go s.acceptConnections() 80 | 81 | return nil 82 | } 83 | 84 | // Stop stops the IPC server 85 | func (s *IPCServer) Stop() 
error { 86 | close(s.stopChan) 87 | if s.listener != nil { 88 | s.listener.Close() 89 | } 90 | s.wg.Wait() 91 | os.Remove(s.socketPath) 92 | return nil 93 | } 94 | 95 | // acceptConnections accepts incoming IPC connections 96 | func (s *IPCServer) acceptConnections() { 97 | defer s.wg.Done() 98 | 99 | for { 100 | conn, err := s.listener.Accept() 101 | if err != nil { 102 | select { 103 | case <-s.stopChan: 104 | return 105 | default: 106 | // Log error and continue 107 | continue 108 | } 109 | } 110 | 111 | s.wg.Add(1) 112 | go s.handleConnection(conn) 113 | } 114 | } 115 | 116 | // handleConnection handles a single IPC connection 117 | func (s *IPCServer) handleConnection(conn net.Conn) { 118 | defer s.wg.Done() 119 | defer conn.Close() 120 | 121 | decoder := json.NewDecoder(conn) 122 | encoder := json.NewEncoder(conn) 123 | 124 | for { 125 | var msg IPCMessage 126 | if err := decoder.Decode(&msg); err != nil { 127 | return 128 | } 129 | 130 | response := s.handleMessage(&msg) 131 | if err := encoder.Encode(response); err != nil { 132 | return 133 | } 134 | } 135 | } 136 | 137 | // handleMessage processes an IPC message and returns a response 138 | func (s *IPCServer) handleMessage(msg *IPCMessage) *IPCResponse { 139 | switch msg.Type { 140 | case "socket": 141 | return s.handleSocket(msg) 142 | case "bind": 143 | return s.handleBind(msg) 144 | case "listen": 145 | return s.handleListen(msg) 146 | case "accept": 147 | return s.handleAccept(msg) 148 | case "connect": 149 | return s.handleConnect(msg) 150 | case "send": 151 | return s.handleSend(msg) 152 | case "recv": 153 | return s.handleRecv(msg) 154 | case "close": 155 | return s.handleClose(msg) 156 | default: 157 | return &IPCResponse{ 158 | Success: false, 159 | Error: fmt.Sprintf("unknown message type: %s", msg.Type), 160 | } 161 | } 162 | } 163 | 164 | // handleSocket handles socket creation 165 | func (s *IPCServer) handleSocket(msg *IPCMessage) *IPCResponse { 166 | // Only support AF_INET (2) and 
SOCK_STREAM (1) or SOCK_DGRAM (2) 167 | if msg.Domain != 2 { 168 | return &IPCResponse{ 169 | Success: false, 170 | Error: "only AF_INET supported", 171 | } 172 | } 173 | 174 | var connType string 175 | switch msg.SockType { 176 | case 1: // SOCK_STREAM 177 | connType = "tcp" 178 | case 2: // SOCK_DGRAM 179 | connType = "udp" 180 | default: 181 | return &IPCResponse{ 182 | Success: false, 183 | Error: "unsupported socket type", 184 | } 185 | } 186 | 187 | conn, err := s.netStack.CreateConnection(connType) 188 | if err != nil { 189 | return &IPCResponse{ 190 | Success: false, 191 | Error: err.Error(), 192 | } 193 | } 194 | 195 | return &IPCResponse{ 196 | Success: true, 197 | ConnID: conn.ID, 198 | } 199 | } 200 | 201 | // handleBind handles bind requests 202 | func (s *IPCServer) handleBind(msg *IPCMessage) *IPCResponse { 203 | var addr net.Addr 204 | if msg.Address == "" || msg.Address == "0.0.0.0" { 205 | // Use WireGuard interface IP 206 | if s.wgProxy.config.Interface.Address != nil { 207 | msg.Address = s.wgProxy.config.Interface.Address.IP.String() 208 | } 209 | } 210 | 211 | conn, _ := s.getConnection(msg.ConnID) 212 | if conn == nil { 213 | return &IPCResponse{ 214 | Success: false, 215 | Error: "connection not found", 216 | } 217 | } 218 | 219 | if conn.Type == "tcp" { 220 | addr = &net.TCPAddr{ 221 | IP: net.ParseIP(msg.Address), 222 | Port: msg.Port, 223 | } 224 | } else { 225 | addr = &net.UDPAddr{ 226 | IP: net.ParseIP(msg.Address), 227 | Port: msg.Port, 228 | } 229 | } 230 | 231 | if err := s.netStack.BindConnection(msg.ConnID, addr); err != nil { 232 | return &IPCResponse{ 233 | Success: false, 234 | Error: err.Error(), 235 | } 236 | } 237 | 238 | return &IPCResponse{Success: true} 239 | } 240 | 241 | // handleListen handles listen requests 242 | func (s *IPCServer) handleListen(msg *IPCMessage) *IPCResponse { 243 | if err := s.netStack.ListenConnection(msg.ConnID); err != nil { 244 | return &IPCResponse{ 245 | Success: false, 246 | Error: 
err.Error(), 247 | } 248 | } 249 | 250 | return &IPCResponse{Success: true} 251 | } 252 | 253 | // handleAccept handles accept requests 254 | func (s *IPCServer) handleAccept(msg *IPCMessage) *IPCResponse { 255 | conn, _ := s.getConnection(msg.ConnID) 256 | if conn == nil { 257 | return &IPCResponse{ 258 | Success: false, 259 | Error: "connection not found", 260 | } 261 | } 262 | 263 | // Try to accept with timeout 264 | for i := 0; i < 100; i++ { // 10 second timeout 265 | newConn, err := s.netStack.AcceptConnection(conn.LocalAddr) 266 | if err == nil { 267 | return &IPCResponse{ 268 | Success: true, 269 | ConnID: newConn.ID, 270 | } 271 | } 272 | time.Sleep(100 * time.Millisecond) 273 | } 274 | 275 | return &IPCResponse{ 276 | Success: false, 277 | Error: "accept timeout", 278 | } 279 | } 280 | 281 | // handleConnect handles connect requests 282 | func (s *IPCServer) handleConnect(msg *IPCMessage) *IPCResponse { 283 | conn, _ := s.getConnection(msg.ConnID) 284 | if conn == nil { 285 | return &IPCResponse{ 286 | Success: false, 287 | Error: "connection not found", 288 | } 289 | } 290 | 291 | var remoteAddr net.Addr 292 | if conn.Type == "tcp" { 293 | remoteAddr = &net.TCPAddr{ 294 | IP: net.ParseIP(msg.Address), 295 | Port: msg.Port, 296 | } 297 | } else { 298 | remoteAddr = &net.UDPAddr{ 299 | IP: net.ParseIP(msg.Address), 300 | Port: msg.Port, 301 | } 302 | } 303 | 304 | if err := s.netStack.ConnectConnection(msg.ConnID, remoteAddr); err != nil { 305 | return &IPCResponse{ 306 | Success: false, 307 | Error: err.Error(), 308 | } 309 | } 310 | 311 | // Set local address in network stack 312 | if s.wgProxy.config.Interface.Address != nil { 313 | s.netStack.SetLocalAddress(s.wgProxy.config.Interface.Address) 314 | } 315 | 316 | return &IPCResponse{Success: true} 317 | } 318 | 319 | // handleSend handles send requests 320 | func (s *IPCServer) handleSend(msg *IPCMessage) *IPCResponse { 321 | if err := s.netStack.SendData(msg.ConnID, msg.Data); err != nil { 322 | 
return &IPCResponse{ 323 | Success: false, 324 | Error: err.Error(), 325 | } 326 | } 327 | 328 | return &IPCResponse{ 329 | Success: true, 330 | Data: []byte{}, // Return number of bytes sent 331 | } 332 | } 333 | 334 | // handleRecv handles receive requests 335 | func (s *IPCServer) handleRecv(msg *IPCMessage) *IPCResponse { 336 | data, err := s.netStack.ReceiveData(msg.ConnID) 337 | if err != nil { 338 | // Try waiting a bit for data 339 | for i := 0; i < 10; i++ { 340 | time.Sleep(100 * time.Millisecond) 341 | data, err = s.netStack.ReceiveData(msg.ConnID) 342 | if err == nil { 343 | break 344 | } 345 | } 346 | 347 | if err != nil { 348 | return &IPCResponse{ 349 | Success: false, 350 | Error: err.Error(), 351 | } 352 | } 353 | } 354 | 355 | return &IPCResponse{ 356 | Success: true, 357 | Data: data, 358 | } 359 | } 360 | 361 | // handleClose handles close requests 362 | func (s *IPCServer) handleClose(msg *IPCMessage) *IPCResponse { 363 | if err := s.netStack.CloseConnection(msg.ConnID); err != nil { 364 | return &IPCResponse{ 365 | Success: false, 366 | Error: err.Error(), 367 | } 368 | } 369 | 370 | return &IPCResponse{Success: true} 371 | } 372 | 373 | // getConnection retrieves a connection by ID 374 | func (s *IPCServer) getConnection(connID uint32) (*VirtualConnection, bool) { 375 | s.netStack.mu.RLock() 376 | defer s.netStack.mu.RUnlock() 377 | conn, exists := s.netStack.connections[connID] 378 | return conn, exists 379 | } 380 | 381 | // Helper to read length-prefixed messages 382 | func readMessage(conn net.Conn) ([]byte, error) { 383 | // Read 4-byte length prefix 384 | lengthBuf := make([]byte, 4) 385 | if _, err := conn.Read(lengthBuf); err != nil { 386 | return nil, err 387 | } 388 | 389 | length := binary.BigEndian.Uint32(lengthBuf) 390 | if length > 1024*1024 { // 1MB max message size 391 | return nil, fmt.Errorf("message too large: %d bytes", length) 392 | } 393 | 394 | // Read message data 395 | data := make([]byte, length) 396 | if _, err := 
conn.Read(data); err != nil { 397 | return nil, err 398 | } 399 | 400 | return data, nil 401 | } 402 | 403 | // Helper to write length-prefixed messages 404 | func writeMessage(conn net.Conn, data []byte) error { 405 | // Write 4-byte length prefix 406 | lengthBuf := make([]byte, 4) 407 | binary.BigEndian.PutUint32(lengthBuf, uint32(len(data))) 408 | 409 | if _, err := conn.Write(lengthBuf); err != nil { 410 | return err 411 | } 412 | 413 | // Write message data 414 | if _, err := conn.Write(data); err != nil { 415 | return err 416 | } 417 | 418 | return nil 419 | } 420 | -------------------------------------------------------------------------------- /ipc_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "net" 6 | "path/filepath" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestNewIPCServer(t *testing.T) { 12 | netStack, _ := NewVirtualNetworkStack() 13 | 14 | // Create a minimal WireGuard config for testing 15 | config := &WireGuardConfig{ 16 | Interface: InterfaceConfig{ 17 | PrivateKey: "YWJjZGVmZ2hpamtsb21ub3Bxcnp0dXZ3eHl6MTIzNDU2Nzg5MA==", 18 | }, 19 | } 20 | _, ipnet, _ := net.ParseCIDR("10.0.0.2/24") 21 | config.Interface.Address = ipnet 22 | 23 | // Create a mock WireGuard proxy for testing 24 | wgProxy := &WireGuardProxy{ 25 | config: config, 26 | } 27 | 28 | server, err := NewIPCServer(netStack, wgProxy) 29 | if err != nil { 30 | t.Fatalf("Failed to create IPC server: %v", err) 31 | } 32 | 33 | if server == nil { 34 | t.Fatal("IPC server should not be nil") 35 | } 36 | 37 | if server.socketPath == "" { 38 | t.Error("Socket path should not be empty") 39 | } 40 | 41 | if server.netStack != netStack { 42 | t.Error("Network stack should be set correctly") 43 | } 44 | 45 | if server.wgProxy != wgProxy { 46 | t.Error("WireGuard proxy should be set correctly") 47 | } 48 | } 49 | 50 | func TestSocketPath(t *testing.T) { 51 | netStack, _ := NewVirtualNetworkStack() 
52 | config := &WireGuardConfig{ 53 | Interface: InterfaceConfig{ 54 | PrivateKey: "YWJjZGVmZ2hpamtsb21ub3Bxcnp0dXZ3eHl6MTIzNDU2Nzg5MA==", 55 | }, 56 | } 57 | _, ipnet, _ := net.ParseCIDR("10.0.0.2/24") 58 | config.Interface.Address = ipnet 59 | 60 | wgProxy := &WireGuardProxy{config: config} 61 | server, _ := NewIPCServer(netStack, wgProxy) 62 | 63 | socketPath := server.SocketPath() 64 | if socketPath == "" { 65 | t.Error("Socket path should not be empty") 66 | } 67 | 68 | // Check that socket path is absolute 69 | if !filepath.IsAbs(socketPath) { 70 | t.Error("Socket path should be absolute") 71 | } 72 | } 73 | 74 | func TestHandleSocketMessage(t *testing.T) { 75 | netStack, _ := NewVirtualNetworkStack() 76 | config := &WireGuardConfig{ 77 | Interface: InterfaceConfig{ 78 | PrivateKey: "YWJjZGVmZ2hpamtsb21ub3Bxcnp0dXZ3eHl6MTIzNDU2Nzg5MA==", 79 | }, 80 | } 81 | _, ipnet, _ := net.ParseCIDR("10.0.0.2/24") 82 | config.Interface.Address = ipnet 83 | 84 | wgProxy := &WireGuardProxy{config: config} 85 | server, _ := NewIPCServer(netStack, wgProxy) 86 | 87 | // Test TCP socket creation 88 | msg := &IPCMessage{ 89 | Type: "socket", 90 | Domain: 2, // AF_INET 91 | SockType: 1, // SOCK_STREAM 92 | Protocol: 0, 93 | } 94 | 95 | response := server.handleSocket(msg) 96 | if !response.Success { 97 | t.Errorf("Expected success, got error: %s", response.Error) 98 | } 99 | 100 | if response.ConnID == 0 { 101 | t.Error("Connection ID should be non-zero") 102 | } 103 | 104 | // Test UDP socket creation 105 | msg.SockType = 2 // SOCK_DGRAM 106 | response = server.handleSocket(msg) 107 | if !response.Success { 108 | t.Errorf("Expected success, got error: %s", response.Error) 109 | } 110 | 111 | // Test unsupported domain 112 | msg.Domain = 10 // AF_INET6 (unsupported) 113 | response = server.handleSocket(msg) 114 | if response.Success { 115 | t.Error("Expected failure for unsupported domain") 116 | } 117 | 118 | // Test unsupported socket type 119 | msg.Domain = 2 120 | msg.SockType 
= 3 // SOCK_RAW (unsupported) 121 | response = server.handleSocket(msg) 122 | if response.Success { 123 | t.Error("Expected failure for unsupported socket type") 124 | } 125 | } 126 | 127 | func TestHandleBindMessage(t *testing.T) { 128 | netStack, _ := NewVirtualNetworkStack() 129 | config := &WireGuardConfig{ 130 | Interface: InterfaceConfig{ 131 | PrivateKey: "YWJjZGVmZ2hpamtsb21ub3Bxcnp0dXZ3eHl6MTIzNDU2Nzg5MA==", 132 | }, 133 | } 134 | _, ipnet, _ := net.ParseCIDR("10.0.0.2/24") 135 | config.Interface.Address = ipnet 136 | 137 | wgProxy := &WireGuardProxy{config: config} 138 | server, _ := NewIPCServer(netStack, wgProxy) 139 | 140 | // Create a connection first 141 | conn, _ := netStack.CreateConnection("tcp") 142 | 143 | msg := &IPCMessage{ 144 | Type: "bind", 145 | ConnID: conn.ID, 146 | Address: "10.0.0.2", 147 | Port: 8080, 148 | } 149 | 150 | response := server.handleBind(msg) 151 | if !response.Success { 152 | t.Errorf("Expected success, got error: %s", response.Error) 153 | } 154 | 155 | // Test bind with non-existent connection 156 | msg.ConnID = 999 157 | response = server.handleBind(msg) 158 | if response.Success { 159 | t.Error("Expected failure for non-existent connection") 160 | } 161 | } 162 | 163 | func TestHandleConnectMessage(t *testing.T) { 164 | netStack, _ := NewVirtualNetworkStack() 165 | config := &WireGuardConfig{ 166 | Interface: InterfaceConfig{ 167 | PrivateKey: "YWJjZGVmZ2hpamtsb21ub3Bxcnp0dXZ3eHl6MTIzNDU2Nzg5MA==", 168 | }, 169 | } 170 | _, ipnet, _ := net.ParseCIDR("10.0.0.2/24") 171 | config.Interface.Address = ipnet 172 | 173 | wgProxy := &WireGuardProxy{config: config} 174 | server, _ := NewIPCServer(netStack, wgProxy) 175 | 176 | // Create a connection first 177 | conn, _ := netStack.CreateConnection("tcp") 178 | 179 | msg := &IPCMessage{ 180 | Type: "connect", 181 | ConnID: conn.ID, 182 | Address: "192.168.1.1", 183 | Port: 80, 184 | } 185 | 186 | response := server.handleConnect(msg) 187 | if !response.Success { 188 | 
t.Errorf("Expected success, got error: %s", response.Error) 189 | } 190 | 191 | // Test connect with non-existent connection 192 | msg.ConnID = 999 193 | response = server.handleConnect(msg) 194 | if response.Success { 195 | t.Error("Expected failure for non-existent connection") 196 | } 197 | } 198 | 199 | func TestHandleSendMessage(t *testing.T) { 200 | netStack, _ := NewVirtualNetworkStack() 201 | config := &WireGuardConfig{ 202 | Interface: InterfaceConfig{ 203 | PrivateKey: "YWJjZGVmZ2hpamtsb21ub3Bxcnp0dXZ3eHl6MTIzNDU2Nzg5MA==", 204 | }, 205 | } 206 | _, ipnet, _ := net.ParseCIDR("10.0.0.2/24") 207 | config.Interface.Address = ipnet 208 | 209 | wgProxy := &WireGuardProxy{config: config} 210 | server, _ := NewIPCServer(netStack, wgProxy) 211 | 212 | // Create and connect a connection 213 | conn, _ := netStack.CreateConnection("tcp") 214 | remoteAddr := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 80} 215 | netStack.ConnectConnection(conn.ID, remoteAddr) 216 | 217 | testData := []byte("hello world") 218 | msg := &IPCMessage{ 219 | Type: "send", 220 | ConnID: conn.ID, 221 | Data: testData, 222 | } 223 | 224 | response := server.handleSend(msg) 225 | if !response.Success { 226 | t.Errorf("Expected success, got error: %s", response.Error) 227 | } 228 | } 229 | 230 | func TestHandleRecvMessage(t *testing.T) { 231 | netStack, _ := NewVirtualNetworkStack() 232 | config := &WireGuardConfig{ 233 | Interface: InterfaceConfig{ 234 | PrivateKey: "YWJjZGVmZ2hpamtsb21ub3Bxcnp0dXZ3eHl6MTIzNDU2Nzg5MA==", 235 | }, 236 | } 237 | _, ipnet, _ := net.ParseCIDR("10.0.0.2/24") 238 | config.Interface.Address = ipnet 239 | 240 | wgProxy := &WireGuardProxy{config: config} 241 | server, _ := NewIPCServer(netStack, wgProxy) 242 | 243 | // Create a connection 244 | conn, _ := netStack.CreateConnection("tcp") 245 | 246 | // Put some data in the incoming channel 247 | testData := []byte("hello world") 248 | go func() { 249 | time.Sleep(10 * time.Millisecond) 250 | conn.IncomingData <- 
testData 251 | }() 252 | 253 | msg := &IPCMessage{ 254 | Type: "recv", 255 | ConnID: conn.ID, 256 | } 257 | 258 | response := server.handleRecv(msg) 259 | if !response.Success { 260 | t.Errorf("Expected success, got error: %s", response.Error) 261 | } 262 | 263 | if string(response.Data) != string(testData) { 264 | t.Errorf("Expected data %s, got %s", string(testData), string(response.Data)) 265 | } 266 | } 267 | 268 | func TestHandleCloseMessage(t *testing.T) { 269 | netStack, _ := NewVirtualNetworkStack() 270 | config := &WireGuardConfig{ 271 | Interface: InterfaceConfig{ 272 | PrivateKey: "YWJjZGVmZ2hpamtsb21ub3Bxcnp0dXZ3eHl6MTIzNDU2Nzg5MA==", 273 | }, 274 | } 275 | _, ipnet, _ := net.ParseCIDR("10.0.0.2/24") 276 | config.Interface.Address = ipnet 277 | 278 | wgProxy := &WireGuardProxy{config: config} 279 | server, _ := NewIPCServer(netStack, wgProxy) 280 | 281 | // Create a connection 282 | conn, _ := netStack.CreateConnection("tcp") 283 | 284 | msg := &IPCMessage{ 285 | Type: "close", 286 | ConnID: conn.ID, 287 | } 288 | 289 | response := server.handleClose(msg) 290 | if !response.Success { 291 | t.Errorf("Expected success, got error: %s", response.Error) 292 | } 293 | 294 | // Verify connection was actually closed 295 | netStack.mu.RLock() 296 | _, exists := netStack.connections[conn.ID] 297 | netStack.mu.RUnlock() 298 | 299 | if exists { 300 | t.Error("Connection should have been removed") 301 | } 302 | } 303 | 304 | func TestHandleUnknownMessage(t *testing.T) { 305 | netStack, _ := NewVirtualNetworkStack() 306 | config := &WireGuardConfig{ 307 | Interface: InterfaceConfig{ 308 | PrivateKey: "YWJjZGVmZ2hpamtsb21ub3Bxcnp0dXZ3eHl6MTIzNDU2Nzg5MA==", 309 | }, 310 | } 311 | _, ipnet, _ := net.ParseCIDR("10.0.0.2/24") 312 | config.Interface.Address = ipnet 313 | 314 | wgProxy := &WireGuardProxy{config: config} 315 | server, _ := NewIPCServer(netStack, wgProxy) 316 | 317 | msg := &IPCMessage{ 318 | Type: "unknown", 319 | } 320 | 321 | response := 
server.handleMessage(msg) 322 | if response.Success { 323 | t.Error("Expected failure for unknown message type") 324 | } 325 | 326 | if response.Error == "" { 327 | t.Error("Expected error message for unknown type") 328 | } 329 | } 330 | 331 | func TestIPCMessageSerialization(t *testing.T) { 332 | msg := &IPCMessage{ 333 | Type: "socket", 334 | ConnID: 123, 335 | SocketFD: 456, 336 | Domain: 2, 337 | SockType: 1, 338 | Protocol: 0, 339 | Address: "10.0.0.2", 340 | Port: 8080, 341 | Data: []byte("test data"), 342 | Error: "", 343 | } 344 | 345 | // Test JSON marshaling 346 | data, err := json.Marshal(msg) 347 | if err != nil { 348 | t.Fatalf("Failed to marshal IPC message: %v", err) 349 | } 350 | 351 | // Test JSON unmarshaling 352 | var unmarshaled IPCMessage 353 | err = json.Unmarshal(data, &unmarshaled) 354 | if err != nil { 355 | t.Fatalf("Failed to unmarshal IPC message: %v", err) 356 | } 357 | 358 | // Verify fields 359 | if unmarshaled.Type != msg.Type { 360 | t.Errorf("Type mismatch: expected %s, got %s", msg.Type, unmarshaled.Type) 361 | } 362 | if unmarshaled.ConnID != msg.ConnID { 363 | t.Errorf("ConnID mismatch: expected %d, got %d", msg.ConnID, unmarshaled.ConnID) 364 | } 365 | if unmarshaled.Address != msg.Address { 366 | t.Errorf("Address mismatch: expected %s, got %s", msg.Address, unmarshaled.Address) 367 | } 368 | if string(unmarshaled.Data) != string(msg.Data) { 369 | t.Errorf("Data mismatch: expected %s, got %s", string(msg.Data), string(unmarshaled.Data)) 370 | } 371 | } 372 | 373 | func TestIPCResponseSerialization(t *testing.T) { 374 | response := &IPCResponse{ 375 | Success: true, 376 | ConnID: 123, 377 | Data: []byte("response data"), 378 | Error: "", 379 | } 380 | 381 | // Test JSON marshaling 382 | data, err := json.Marshal(response) 383 | if err != nil { 384 | t.Fatalf("Failed to marshal IPC response: %v", err) 385 | } 386 | 387 | // Test JSON unmarshaling 388 | var unmarshaled IPCResponse 389 | err = json.Unmarshal(data, &unmarshaled) 
390 | if err != nil { 391 | t.Fatalf("Failed to unmarshal IPC response: %v", err) 392 | } 393 | 394 | // Verify fields 395 | if unmarshaled.Success != response.Success { 396 | t.Errorf("Success mismatch: expected %t, got %t", response.Success, unmarshaled.Success) 397 | } 398 | if unmarshaled.ConnID != response.ConnID { 399 | t.Errorf("ConnID mismatch: expected %d, got %d", response.ConnID, unmarshaled.ConnID) 400 | } 401 | if string(unmarshaled.Data) != string(response.Data) { 402 | t.Errorf("Data mismatch: expected %s, got %s", string(response.Data), string(unmarshaled.Data)) 403 | } 404 | } 405 | -------------------------------------------------------------------------------- /lib/intercept.c: -------------------------------------------------------------------------------- 1 | #define _GNU_SOURCE 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | // JSON parsing would normally use a library, but for simplicity, we'll use a basic approach 16 | typedef struct { 17 | char type[32]; 18 | uint32_t conn_id; 19 | int socket_fd; 20 | int domain; 21 | int sock_type; 22 | int protocol; 23 | char address[64]; 24 | int port; 25 | size_t data_len; 26 | char *data; 27 | char error[256]; 28 | } IPCMessage; 29 | 30 | typedef struct { 31 | int success; 32 | uint32_t conn_id; 33 | size_t data_len; 34 | char *data; 35 | char error[256]; 36 | } IPCResponse; 37 | 38 | // Connection mapping 39 | typedef struct { 40 | int fd; 41 | uint32_t conn_id; 42 | } FDMapping; 43 | 44 | static FDMapping fd_mappings[1024]; 45 | static int next_fake_fd = 1000; 46 | static pthread_mutex_t fd_mutex = PTHREAD_MUTEX_INITIALIZER; 47 | static int ipc_socket = -1; 48 | static char ipc_path[256]; 49 | 50 | // Function pointers to real functions 51 | static int (*real_socket)(int domain, int type, int protocol); 52 | static int (*real_bind)(int sockfd, const struct sockaddr *addr, 
socklen_t addrlen); 53 | static int (*real_listen)(int sockfd, int backlog); 54 | static int (*real_accept)(int sockfd, struct sockaddr *addr, socklen_t *addrlen); 55 | static int (*real_connect)(int sockfd, const struct sockaddr *addr, socklen_t addrlen); 56 | static ssize_t (*real_send)(int sockfd, const void *buf, size_t len, int flags); 57 | static ssize_t (*real_recv)(int sockfd, void *buf, size_t len, int flags); 58 | static ssize_t (*real_sendto)(int sockfd, const void *buf, size_t len, int flags, 59 | const struct sockaddr *dest_addr, socklen_t addrlen); 60 | static ssize_t (*real_recvfrom)(int sockfd, void *buf, size_t len, int flags, 61 | struct sockaddr *src_addr, socklen_t *addrlen); 62 | static int (*real_close)(int fd); 63 | 64 | // Initialize the library 65 | __attribute__((constructor)) 66 | void init_intercept() { 67 | // Get real function pointers 68 | real_socket = dlsym(RTLD_NEXT, "socket"); 69 | real_bind = dlsym(RTLD_NEXT, "bind"); 70 | real_listen = dlsym(RTLD_NEXT, "listen"); 71 | real_accept = dlsym(RTLD_NEXT, "accept"); 72 | real_connect = dlsym(RTLD_NEXT, "connect"); 73 | real_send = dlsym(RTLD_NEXT, "send"); 74 | real_recv = dlsym(RTLD_NEXT, "recv"); 75 | real_sendto = dlsym(RTLD_NEXT, "sendto"); 76 | real_recvfrom = dlsym(RTLD_NEXT, "recvfrom"); 77 | real_close = dlsym(RTLD_NEXT, "close"); 78 | 79 | // Get IPC socket path from environment 80 | const char *path = getenv("WRAPGUARD_IPC_PATH"); 81 | if (path) { 82 | strncpy(ipc_path, path, sizeof(ipc_path) - 1); 83 | ipc_path[sizeof(ipc_path) - 1] = '\0'; 84 | } 85 | 86 | // Initialize FD mappings 87 | memset(fd_mappings, 0, sizeof(fd_mappings)); 88 | } 89 | 90 | // Connect to IPC server 91 | static int connect_ipc() { 92 | if (ipc_socket >= 0) { 93 | return ipc_socket; 94 | } 95 | 96 | ipc_socket = real_socket(AF_UNIX, SOCK_STREAM, 0); 97 | if (ipc_socket < 0) { 98 | return -1; 99 | } 100 | 101 | struct sockaddr_un addr; 102 | memset(&addr, 0, sizeof(addr)); 103 | addr.sun_family = 
AF_UNIX; 104 | strncpy(addr.sun_path, ipc_path, sizeof(addr.sun_path) - 1); 105 | 106 | if (real_connect(ipc_socket, (struct sockaddr *)&addr, sizeof(addr)) < 0) { 107 | real_close(ipc_socket); 108 | ipc_socket = -1; 109 | return -1; 110 | } 111 | 112 | return ipc_socket; 113 | } 114 | 115 | // Simple JSON serialization helpers 116 | static void write_json_string(char *buf, size_t *pos, const char *key, const char *value) { 117 | *pos += snprintf(buf + *pos, 4096 - *pos, "\"%s\":\"%s\",", key, value); 118 | } 119 | 120 | static void write_json_int(char *buf, size_t *pos, const char *key, int value) { 121 | *pos += snprintf(buf + *pos, 4096 - *pos, "\"%s\":%d,", key, value); 122 | } 123 | 124 | static void write_json_uint32(char *buf, size_t *pos, const char *key, uint32_t value) { 125 | *pos += snprintf(buf + *pos, 4096 - *pos, "\"%s\":%u,", key, value); 126 | } 127 | 128 | // Send IPC message and get response 129 | static int send_ipc_message(IPCMessage *msg, IPCResponse *resp) { 130 | int sock = connect_ipc(); 131 | if (sock < 0) { 132 | return -1; 133 | } 134 | 135 | // Build JSON message 136 | char json_buf[4096]; 137 | size_t pos = 0; 138 | 139 | json_buf[pos++] = '{'; 140 | write_json_string(json_buf, &pos, "type", msg->type); 141 | 142 | if (msg->conn_id > 0) { 143 | write_json_uint32(json_buf, &pos, "conn_id", msg->conn_id); 144 | } 145 | if (msg->socket_fd > 0) { 146 | write_json_int(json_buf, &pos, "socket_fd", msg->socket_fd); 147 | } 148 | if (msg->domain > 0) { 149 | write_json_int(json_buf, &pos, "domain", msg->domain); 150 | } 151 | if (msg->sock_type > 0) { 152 | write_json_int(json_buf, &pos, "sock_type", msg->sock_type); 153 | } 154 | if (msg->protocol >= 0) { 155 | write_json_int(json_buf, &pos, "protocol", msg->protocol); 156 | } 157 | if (strlen(msg->address) > 0) { 158 | write_json_string(json_buf, &pos, "address", msg->address); 159 | } 160 | if (msg->port > 0) { 161 | write_json_int(json_buf, &pos, "port", msg->port); 162 | } 163 | 164 | // 
Remove trailing comma and close JSON 165 | if (json_buf[pos-1] == ',') pos--; 166 | json_buf[pos++] = '}'; 167 | json_buf[pos++] = '\n'; 168 | json_buf[pos] = '\0'; 169 | 170 | // Send message 171 | if (write(sock, json_buf, pos) < 0) { 172 | return -1; 173 | } 174 | 175 | // Read response (simplified parsing) 176 | char resp_buf[4096]; 177 | ssize_t n = read(sock, resp_buf, sizeof(resp_buf) - 1); 178 | if (n <= 0) { 179 | return -1; 180 | } 181 | resp_buf[n] = '\0'; 182 | 183 | // Parse response (very basic) 184 | resp->success = (strstr(resp_buf, "\"success\":true") != NULL); 185 | 186 | char *conn_id_str = strstr(resp_buf, "\"conn_id\":"); 187 | if (conn_id_str) { 188 | resp->conn_id = atoi(conn_id_str + 10); 189 | } 190 | 191 | char *error_str = strstr(resp_buf, "\"error\":\""); 192 | if (error_str) { 193 | error_str += 9; 194 | char *end = strchr(error_str, '"'); 195 | if (end) { 196 | size_t len = end - error_str; 197 | if (len > sizeof(resp->error) - 1) { 198 | len = sizeof(resp->error) - 1; 199 | } 200 | strncpy(resp->error, error_str, len); 201 | resp->error[len] = '\0'; 202 | } 203 | } 204 | 205 | return 0; 206 | } 207 | 208 | // Map connection ID to fake FD 209 | static int map_conn_to_fd(uint32_t conn_id) { 210 | pthread_mutex_lock(&fd_mutex); 211 | 212 | int fd = next_fake_fd++; 213 | if (fd < 1024) { 214 | fd_mappings[fd].fd = fd; 215 | fd_mappings[fd].conn_id = conn_id; 216 | } 217 | 218 | pthread_mutex_unlock(&fd_mutex); 219 | return fd; 220 | } 221 | 222 | // Get connection ID from FD 223 | static uint32_t get_conn_id(int fd) { 224 | if (fd < 1000 || fd >= 1024) { 225 | return 0; 226 | } 227 | 228 | pthread_mutex_lock(&fd_mutex); 229 | uint32_t conn_id = fd_mappings[fd].conn_id; 230 | pthread_mutex_unlock(&fd_mutex); 231 | 232 | return conn_id; 233 | } 234 | 235 | // Intercepted functions 236 | int socket(int domain, int type, int protocol) { 237 | // Only intercept AF_INET sockets 238 | if (domain != AF_INET) { 239 | return real_socket(domain, 
type, protocol); 240 | } 241 | 242 | IPCMessage msg = {0}; 243 | IPCResponse resp = {0}; 244 | 245 | strcpy(msg.type, "socket"); 246 | msg.domain = domain; 247 | msg.sock_type = type; 248 | msg.protocol = protocol; 249 | 250 | if (send_ipc_message(&msg, &resp) < 0 || !resp.success) { 251 | errno = ENOTSUP; 252 | return -1; 253 | } 254 | 255 | return map_conn_to_fd(resp.conn_id); 256 | } 257 | 258 | int bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { 259 | uint32_t conn_id = get_conn_id(sockfd); 260 | if (conn_id == 0) { 261 | return real_bind(sockfd, addr, addrlen); 262 | } 263 | 264 | if (addr->sa_family != AF_INET) { 265 | errno = EAFNOSUPPORT; 266 | return -1; 267 | } 268 | 269 | struct sockaddr_in *sin = (struct sockaddr_in *)addr; 270 | 271 | IPCMessage msg = {0}; 272 | IPCResponse resp = {0}; 273 | 274 | strcpy(msg.type, "bind"); 275 | msg.conn_id = conn_id; 276 | inet_ntop(AF_INET, &sin->sin_addr, msg.address, sizeof(msg.address)); 277 | msg.port = ntohs(sin->sin_port); 278 | 279 | if (send_ipc_message(&msg, &resp) < 0 || !resp.success) { 280 | errno = EADDRINUSE; 281 | return -1; 282 | } 283 | 284 | return 0; 285 | } 286 | 287 | int listen(int sockfd, int backlog) { 288 | uint32_t conn_id = get_conn_id(sockfd); 289 | if (conn_id == 0) { 290 | return real_listen(sockfd, backlog); 291 | } 292 | 293 | IPCMessage msg = {0}; 294 | IPCResponse resp = {0}; 295 | 296 | strcpy(msg.type, "listen"); 297 | msg.conn_id = conn_id; 298 | 299 | if (send_ipc_message(&msg, &resp) < 0 || !resp.success) { 300 | errno = EOPNOTSUPP; 301 | return -1; 302 | } 303 | 304 | return 0; 305 | } 306 | 307 | int accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen) { 308 | uint32_t conn_id = get_conn_id(sockfd); 309 | if (conn_id == 0) { 310 | return real_accept(sockfd, addr, addrlen); 311 | } 312 | 313 | IPCMessage msg = {0}; 314 | IPCResponse resp = {0}; 315 | 316 | strcpy(msg.type, "accept"); 317 | msg.conn_id = conn_id; 318 | 319 | if (send_ipc_message(&msg, 
&resp) < 0 || !resp.success) { 320 | errno = EAGAIN; 321 | return -1; 322 | } 323 | 324 | // TODO: Fill in addr if provided 325 | 326 | return map_conn_to_fd(resp.conn_id); 327 | } 328 | 329 | int connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { 330 | uint32_t conn_id = get_conn_id(sockfd); 331 | if (conn_id == 0) { 332 | return real_connect(sockfd, addr, addrlen); 333 | } 334 | 335 | if (addr->sa_family != AF_INET) { 336 | errno = EAFNOSUPPORT; 337 | return -1; 338 | } 339 | 340 | struct sockaddr_in *sin = (struct sockaddr_in *)addr; 341 | 342 | IPCMessage msg = {0}; 343 | IPCResponse resp = {0}; 344 | 345 | strcpy(msg.type, "connect"); 346 | msg.conn_id = conn_id; 347 | inet_ntop(AF_INET, &sin->sin_addr, msg.address, sizeof(msg.address)); 348 | msg.port = ntohs(sin->sin_port); 349 | 350 | if (send_ipc_message(&msg, &resp) < 0 || !resp.success) { 351 | errno = ECONNREFUSED; 352 | return -1; 353 | } 354 | 355 | return 0; 356 | } 357 | 358 | ssize_t send(int sockfd, const void *buf, size_t len, int flags) { 359 | uint32_t conn_id = get_conn_id(sockfd); 360 | if (conn_id == 0) { 361 | return real_send(sockfd, buf, len, flags); 362 | } 363 | 364 | IPCMessage msg = {0}; 365 | IPCResponse resp = {0}; 366 | 367 | strcpy(msg.type, "send"); 368 | msg.conn_id = conn_id; 369 | msg.data = (char *)buf; 370 | msg.data_len = len; 371 | 372 | if (send_ipc_message(&msg, &resp) < 0 || !resp.success) { 373 | errno = EPIPE; 374 | return -1; 375 | } 376 | 377 | return len; 378 | } 379 | 380 | ssize_t recv(int sockfd, void *buf, size_t len, int flags) { 381 | uint32_t conn_id = get_conn_id(sockfd); 382 | if (conn_id == 0) { 383 | return real_recv(sockfd, buf, len, flags); 384 | } 385 | 386 | IPCMessage msg = {0}; 387 | IPCResponse resp = {0}; 388 | 389 | strcpy(msg.type, "recv"); 390 | msg.conn_id = conn_id; 391 | 392 | if (send_ipc_message(&msg, &resp) < 0 || !resp.success) { 393 | if (flags & MSG_DONTWAIT) { 394 | errno = EAGAIN; 395 | } else { 396 | errno = 
ECONNRESET; 397 | } 398 | return -1; 399 | } 400 | 401 | // Copy received data 402 | size_t copy_len = resp.data_len < len ? resp.data_len : len; 403 | if (resp.data && copy_len > 0) { 404 | memcpy(buf, resp.data, copy_len); 405 | } 406 | 407 | return copy_len; 408 | } 409 | 410 | int close(int fd) { 411 | uint32_t conn_id = get_conn_id(fd); 412 | if (conn_id == 0) { 413 | return real_close(fd); 414 | } 415 | 416 | IPCMessage msg = {0}; 417 | IPCResponse resp = {0}; 418 | 419 | strcpy(msg.type, "close"); 420 | msg.conn_id = conn_id; 421 | 422 | send_ipc_message(&msg, &resp); 423 | 424 | // Remove mapping 425 | pthread_mutex_lock(&fd_mutex); 426 | if (fd < 1024) { 427 | fd_mappings[fd].conn_id = 0; 428 | } 429 | pthread_mutex_unlock(&fd_mutex); 430 | 431 | return 0; 432 | } 433 | 434 | // Also intercept sendto and recvfrom for UDP 435 | ssize_t sendto(int sockfd, const void *buf, size_t len, int flags, 436 | const struct sockaddr *dest_addr, socklen_t addrlen) { 437 | uint32_t conn_id = get_conn_id(sockfd); 438 | if (conn_id == 0) { 439 | return real_sendto(sockfd, buf, len, flags, dest_addr, addrlen); 440 | } 441 | 442 | // For UDP, we might need to handle destination address 443 | return send(sockfd, buf, len, flags); 444 | } 445 | 446 | ssize_t recvfrom(int sockfd, void *buf, size_t len, int flags, 447 | struct sockaddr *src_addr, socklen_t *addrlen) { 448 | uint32_t conn_id = get_conn_id(sockfd); 449 | if (conn_id == 0) { 450 | return real_recvfrom(sockfd, buf, len, flags, src_addr, addrlen); 451 | } 452 | 453 | // For UDP, we might need to fill in source address 454 | return recv(sockfd, buf, len, flags); 455 | } 456 | -------------------------------------------------------------------------------- /logging.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "log" 8 | "strings" 9 | "time" 10 | ) 11 | 12 | // LogLevel represents the severity of a log 
message 13 | type LogLevel int 14 | 15 | const ( 16 | LogLevelError LogLevel = iota 17 | LogLevelWarn 18 | LogLevelInfo 19 | LogLevelDebug 20 | ) 21 | 22 | // String returns the string representation of the log level 23 | func (l LogLevel) String() string { 24 | switch l { 25 | case LogLevelError: 26 | return "error" 27 | case LogLevelWarn: 28 | return "warn" 29 | case LogLevelInfo: 30 | return "info" 31 | case LogLevelDebug: 32 | return "debug" 33 | default: 34 | return "unknown" 35 | } 36 | } 37 | 38 | // ParseLogLevel parses a string into a LogLevel 39 | func ParseLogLevel(s string) (LogLevel, error) { 40 | switch strings.ToLower(s) { 41 | case "error": 42 | return LogLevelError, nil 43 | case "warn", "warning": 44 | return LogLevelWarn, nil 45 | case "info": 46 | return LogLevelInfo, nil 47 | case "debug": 48 | return LogLevelDebug, nil 49 | default: 50 | return LogLevelInfo, fmt.Errorf("unknown log level: %s", s) 51 | } 52 | } 53 | 54 | // LogEntry represents a structured log entry 55 | type LogEntry struct { 56 | Timestamp string `json:"timestamp"` 57 | Level string `json:"level"` 58 | Message string `json:"message"` 59 | Component string `json:"component,omitempty"` 60 | } 61 | 62 | // Logger provides structured JSON logging 63 | type Logger struct { 64 | level LogLevel 65 | output io.Writer 66 | } 67 | 68 | // NewLogger creates a new logger with the specified level and output 69 | func NewLogger(level LogLevel, output io.Writer) *Logger { 70 | return &Logger{ 71 | level: level, 72 | output: output, 73 | } 74 | } 75 | 76 | // shouldLog checks if a message at the given level should be logged 77 | func (l *Logger) shouldLog(level LogLevel) bool { 78 | return level <= l.level 79 | } 80 | 81 | // log writes a log entry to the output 82 | func (l *Logger) log(level LogLevel, component, message string) { 83 | if !l.shouldLog(level) { 84 | return 85 | } 86 | 87 | entry := LogEntry{ 88 | Timestamp: time.Now().UTC().Format(time.RFC3339), 89 | Level: level.String(), 
90 | Message: message, 91 | Component: component, 92 | } 93 | 94 | data, err := json.Marshal(entry) 95 | if err != nil { 96 | // Fallback to simple format if JSON marshaling fails 97 | fmt.Fprintf(l.output, "LOG_ERROR: failed to marshal log entry: %v\n", err) 98 | return 99 | } 100 | 101 | fmt.Fprintln(l.output, string(data)) 102 | } 103 | 104 | // Error logs an error message 105 | func (l *Logger) Error(message string) { 106 | l.log(LogLevelError, "", message) 107 | } 108 | 109 | // Errorf logs a formatted error message 110 | func (l *Logger) Errorf(format string, args ...interface{}) { 111 | l.log(LogLevelError, "", fmt.Sprintf(format, args...)) 112 | } 113 | 114 | // ErrorWithComponent logs an error message with a component 115 | func (l *Logger) ErrorWithComponent(component, message string) { 116 | l.log(LogLevelError, component, message) 117 | } 118 | 119 | // Warn logs a warning message 120 | func (l *Logger) Warn(message string) { 121 | l.log(LogLevelWarn, "", message) 122 | } 123 | 124 | // Warnf logs a formatted warning message 125 | func (l *Logger) Warnf(format string, args ...interface{}) { 126 | l.log(LogLevelWarn, "", fmt.Sprintf(format, args...)) 127 | } 128 | 129 | // Info logs an info message 130 | func (l *Logger) Info(message string) { 131 | l.log(LogLevelInfo, "", message) 132 | } 133 | 134 | // Infof logs a formatted info message 135 | func (l *Logger) Infof(format string, args ...interface{}) { 136 | l.log(LogLevelInfo, "", fmt.Sprintf(format, args...)) 137 | } 138 | 139 | // InfoWithComponent logs an info message with a component 140 | func (l *Logger) InfoWithComponent(component, message string) { 141 | l.log(LogLevelInfo, component, message) 142 | } 143 | 144 | // Debug logs a debug message 145 | func (l *Logger) Debug(message string) { 146 | l.log(LogLevelDebug, "", message) 147 | } 148 | 149 | // Debugf logs a formatted debug message 150 | func (l *Logger) Debugf(format string, args ...interface{}) { 151 | l.log(LogLevelDebug, "", 
fmt.Sprintf(format, args...)) 152 | } 153 | 154 | // DebugWithComponent logs a debug message with a component 155 | func (l *Logger) DebugWithComponent(component, message string) { 156 | l.log(LogLevelDebug, component, message) 157 | } 158 | 159 | // WireGuardLogger creates a logger compatible with WireGuard device logger 160 | func (l *Logger) WireGuardLogger() *log.Logger { 161 | return log.New(&wireGuardLogWriter{logger: l}, "", 0) 162 | } 163 | 164 | // wireGuardLogWriter adapts our Logger to work with standard log.Logger 165 | type wireGuardLogWriter struct { 166 | logger *Logger 167 | } 168 | 169 | func (w *wireGuardLogWriter) Write(p []byte) (n int, err error) { 170 | message := strings.TrimSpace(string(p)) 171 | if message != "" { 172 | w.logger.DebugWithComponent("wireguard", message) 173 | } 174 | return len(p), nil 175 | } 176 | -------------------------------------------------------------------------------- /logging_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | func TestParseLogLevel(t *testing.T) { 11 | tests := []struct { 12 | input string 13 | expected LogLevel 14 | wantErr bool 15 | }{ 16 | {"error", LogLevelError, false}, 17 | {"warn", LogLevelWarn, false}, 18 | {"warning", LogLevelWarn, false}, 19 | {"info", LogLevelInfo, false}, 20 | {"debug", LogLevelDebug, false}, 21 | {"ERROR", LogLevelError, false}, // Case insensitive 22 | {"INFO", LogLevelInfo, false}, 23 | {"invalid", LogLevelInfo, true}, 24 | {"", LogLevelInfo, true}, 25 | } 26 | 27 | for _, tt := range tests { 28 | t.Run(tt.input, func(t *testing.T) { 29 | level, err := ParseLogLevel(tt.input) 30 | if (err != nil) != tt.wantErr { 31 | t.Errorf("ParseLogLevel(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) 32 | return 33 | } 34 | if !tt.wantErr && level != tt.expected { 35 | t.Errorf("ParseLogLevel(%q) = %v, want %v", tt.input, 
level, tt.expected) 36 | } 37 | }) 38 | } 39 | } 40 | 41 | func TestLogLevelString(t *testing.T) { 42 | tests := []struct { 43 | level LogLevel 44 | expected string 45 | }{ 46 | {LogLevelError, "error"}, 47 | {LogLevelWarn, "warn"}, 48 | {LogLevelInfo, "info"}, 49 | {LogLevelDebug, "debug"}, 50 | {LogLevel(999), "unknown"}, // Invalid level 51 | } 52 | 53 | for _, tt := range tests { 54 | t.Run(tt.expected, func(t *testing.T) { 55 | result := tt.level.String() 56 | if result != tt.expected { 57 | t.Errorf("LogLevel(%d).String() = %q, want %q", tt.level, result, tt.expected) 58 | } 59 | }) 60 | } 61 | } 62 | 63 | func TestLoggerOutput(t *testing.T) { 64 | var buf bytes.Buffer 65 | logger := NewLogger(LogLevelDebug, &buf) 66 | 67 | logger.Error("test error message") 68 | 69 | // Parse the JSON output 70 | var entry LogEntry 71 | line := strings.TrimSpace(buf.String()) 72 | if err := json.Unmarshal([]byte(line), &entry); err != nil { 73 | t.Fatalf("Failed to parse log output as JSON: %v", err) 74 | } 75 | 76 | if entry.Level != "error" { 77 | t.Errorf("Expected log level 'error', got %q", entry.Level) 78 | } 79 | if entry.Message != "test error message" { 80 | t.Errorf("Expected message 'test error message', got %q", entry.Message) 81 | } 82 | if entry.Timestamp == "" { 83 | t.Error("Expected timestamp to be set") 84 | } 85 | } 86 | 87 | func TestLoggerLevels(t *testing.T) { 88 | tests := []struct { 89 | loggerLevel LogLevel 90 | logLevel LogLevel 91 | shouldLog bool 92 | }{ 93 | {LogLevelError, LogLevelError, true}, 94 | {LogLevelError, LogLevelWarn, false}, 95 | {LogLevelWarn, LogLevelError, true}, 96 | {LogLevelWarn, LogLevelWarn, true}, 97 | {LogLevelWarn, LogLevelInfo, false}, 98 | {LogLevelInfo, LogLevelError, true}, 99 | {LogLevelInfo, LogLevelWarn, true}, 100 | {LogLevelInfo, LogLevelInfo, true}, 101 | {LogLevelInfo, LogLevelDebug, false}, 102 | {LogLevelDebug, LogLevelError, true}, 103 | {LogLevelDebug, LogLevelWarn, true}, 104 | {LogLevelDebug, LogLevelInfo, 
true}, 105 | {LogLevelDebug, LogLevelDebug, true}, 106 | } 107 | 108 | for _, tt := range tests { 109 | t.Run("", func(t *testing.T) { 110 | var buf bytes.Buffer 111 | logger := NewLogger(tt.loggerLevel, &buf) 112 | 113 | // Test the shouldLog method 114 | result := logger.shouldLog(tt.logLevel) 115 | if result != tt.shouldLog { 116 | t.Errorf("shouldLog(%v) with logger level %v = %v, want %v", 117 | tt.logLevel, tt.loggerLevel, result, tt.shouldLog) 118 | } 119 | 120 | // Test actual logging 121 | buf.Reset() 122 | logger.log(tt.logLevel, "", "test message") 123 | 124 | hasOutput := buf.Len() > 0 125 | if hasOutput != tt.shouldLog { 126 | t.Errorf("Expected output: %v, got output: %v", tt.shouldLog, hasOutput) 127 | } 128 | }) 129 | } 130 | } 131 | 132 | func TestLoggerMethods(t *testing.T) { 133 | var buf bytes.Buffer 134 | logger := NewLogger(LogLevelDebug, &buf) 135 | 136 | tests := []struct { 137 | name string 138 | logFunc func() 139 | expected string 140 | }{ 141 | { 142 | name: "Error", 143 | logFunc: func() { logger.Error("error message") }, 144 | expected: "error", 145 | }, 146 | { 147 | name: "Errorf", 148 | logFunc: func() { logger.Errorf("error %s", "formatted") }, 149 | expected: "error", 150 | }, 151 | { 152 | name: "Warn", 153 | logFunc: func() { logger.Warn("warn message") }, 154 | expected: "warn", 155 | }, 156 | { 157 | name: "Warnf", 158 | logFunc: func() { logger.Warnf("warn %d", 123) }, 159 | expected: "warn", 160 | }, 161 | { 162 | name: "Info", 163 | logFunc: func() { logger.Info("info message") }, 164 | expected: "info", 165 | }, 166 | { 167 | name: "Infof", 168 | logFunc: func() { logger.Infof("info %v", true) }, 169 | expected: "info", 170 | }, 171 | { 172 | name: "Debug", 173 | logFunc: func() { logger.Debug("debug message") }, 174 | expected: "debug", 175 | }, 176 | { 177 | name: "Debugf", 178 | logFunc: func() { logger.Debugf("debug %f", 3.14) }, 179 | expected: "debug", 180 | }, 181 | } 182 | 183 | for _, tt := range tests { 184 | 
t.Run(tt.name, func(t *testing.T) { 185 | buf.Reset() 186 | tt.logFunc() 187 | 188 | var entry LogEntry 189 | line := strings.TrimSpace(buf.String()) 190 | if err := json.Unmarshal([]byte(line), &entry); err != nil { 191 | t.Fatalf("Failed to parse log output as JSON: %v", err) 192 | } 193 | 194 | if entry.Level != tt.expected { 195 | t.Errorf("Expected log level %q, got %q", tt.expected, entry.Level) 196 | } 197 | }) 198 | } 199 | } 200 | 201 | func TestLoggerWithComponent(t *testing.T) { 202 | var buf bytes.Buffer 203 | logger := NewLogger(LogLevelDebug, &buf) 204 | 205 | logger.ErrorWithComponent("test-component", "error message") 206 | 207 | var entry LogEntry 208 | line := strings.TrimSpace(buf.String()) 209 | if err := json.Unmarshal([]byte(line), &entry); err != nil { 210 | t.Fatalf("Failed to parse log output as JSON: %v", err) 211 | } 212 | 213 | if entry.Component != "test-component" { 214 | t.Errorf("Expected component 'test-component', got %q", entry.Component) 215 | } 216 | if entry.Level != "error" { 217 | t.Errorf("Expected log level 'error', got %q", entry.Level) 218 | } 219 | if entry.Message != "error message" { 220 | t.Errorf("Expected message 'error message', got %q", entry.Message) 221 | } 222 | } 223 | 224 | func TestWireGuardLogger(t *testing.T) { 225 | var buf bytes.Buffer 226 | logger := NewLogger(LogLevelDebug, &buf) 227 | 228 | wgLogger := logger.WireGuardLogger() 229 | wgLogger.Println("test wireguard message") 230 | 231 | var entry LogEntry 232 | line := strings.TrimSpace(buf.String()) 233 | if err := json.Unmarshal([]byte(line), &entry); err != nil { 234 | t.Fatalf("Failed to parse log output as JSON: %v", err) 235 | } 236 | 237 | if entry.Component != "wireguard" { 238 | t.Errorf("Expected component 'wireguard', got %q", entry.Component) 239 | } 240 | if entry.Level != "debug" { 241 | t.Errorf("Expected log level 'debug', got %q", entry.Level) 242 | } 243 | if !strings.Contains(entry.Message, "test wireguard message") { 244 | 
t.Errorf("Expected message to contain 'test wireguard message', got %q", entry.Message) 245 | } 246 | } 247 | 248 | func TestLoggerJSONFormat(t *testing.T) { 249 | var buf bytes.Buffer 250 | logger := NewLogger(LogLevelInfo, &buf) 251 | 252 | logger.Info("test message") 253 | 254 | var entry LogEntry 255 | line := strings.TrimSpace(buf.String()) 256 | if err := json.Unmarshal([]byte(line), &entry); err != nil { 257 | t.Fatalf("Failed to parse log output as JSON: %v", err) 258 | } 259 | 260 | // Validate all required fields are present 261 | if entry.Timestamp == "" { 262 | t.Error("Timestamp should not be empty") 263 | } 264 | if entry.Level == "" { 265 | t.Error("Level should not be empty") 266 | } 267 | if entry.Message == "" { 268 | t.Error("Message should not be empty") 269 | } 270 | 271 | // Validate timestamp format (RFC3339) 272 | if !strings.Contains(entry.Timestamp, "T") || !strings.Contains(entry.Timestamp, "Z") { 273 | t.Errorf("Timestamp should be in RFC3339 format, got %q", entry.Timestamp) 274 | } 275 | } 276 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "io" 7 | "os" 8 | "os/exec" 9 | "os/signal" 10 | "path/filepath" 11 | "strings" 12 | "syscall" 13 | ) 14 | 15 | func printUsage() { 16 | help := fmt.Sprintf(` 17 | ╦ ╦┬─┐┌─┐┌─┐╔═╗┬ ┬┌─┐┬─┐┌┬┐ 18 | ║║║├┬┘├─┤├─┘║ ╦│ │├─┤├┬┘ ││ 19 | ╚╩╝┴└─┴ ┴┴ ╚═╝└─┘┴ ┴┴└──┴┘ %s 20 | 21 | 🔒 Userspace WireGuard proxy for transparent network tunneling 22 | 23 | `, Version) 24 | 25 | help += "\033[33mUSAGE:\033[0m\n" 26 | help += " wrapguard --config= -- [args...]\n\n" 27 | 28 | help += "\033[33mEXAMPLES:\033[0m\n" 29 | help += " \033[36m# Check your tunneled IP address\033[0m\n" 30 | help += " wrapguard --config=wg0.conf -- curl https://icanhazip.com\n\n" 31 | 32 | help += " \033[36m# Run a web server accessible through WireGuard\033[0m\n" 
33 | help += " wrapguard --config=wg0.conf -- python3 -m http.server 8080\n\n" 34 | 35 | help += " \033[36m# Tunnel Node.js applications\033[0m\n" 36 | help += " wrapguard --config=wg0.conf -- node app.js\n\n" 37 | 38 | help += " \033[36m# Interactive shell with tunneled network\033[0m\n" 39 | help += " wrapguard --config=wg0.conf -- bash\n\n" 40 | 41 | help += "\033[33mOPTIONS:\033[0m\n" 42 | help += " --config= Path to WireGuard configuration file\n" 43 | help += " --log-level= Set log level (error, warn, info, debug)\n" 44 | help += " --log-file= Set file to write logs to (default: terminal)\n" 45 | help += " --help Show this help message\n" 46 | help += " --version Show version information\n\n" 47 | 48 | help += "\033[33mFEATURES:\033[0m\n" 49 | help += " ✓ No root/sudo required\n" 50 | help += " ✓ No kernel modules needed\n" 51 | help += " ✓ Works in containers\n" 52 | help += " ✓ Transparent to applications\n" 53 | help += " ✓ Standard WireGuard configs\n\n" 54 | 55 | help += "\033[33mCONFIG EXAMPLE:\033[0m\n" 56 | help += " [Interface]\n" 57 | help += " PrivateKey = \n" 58 | help += " Address = 10.0.0.2/24\n\n" 59 | 60 | help += " [Peer]\n" 61 | help += " PublicKey = \n" 62 | help += " Endpoint = vpn.example.com:51820\n" 63 | help += " AllowedIPs = 0.0.0.0/0\n\n" 64 | 65 | help += "\033[90mMore info: https://github.com/puzed/wrapguard\033[0m\n\n" 66 | 67 | os.Stderr.WriteString(help) 68 | } 69 | 70 | func main() { 71 | var configPath string 72 | var showHelp bool 73 | var showVersion bool 74 | var logLevelStr string 75 | var logFile string 76 | flag.StringVar(&configPath, "config", "", "Path to WireGuard configuration file") 77 | flag.BoolVar(&showHelp, "help", false, "Show help message") 78 | flag.BoolVar(&showVersion, "version", false, "Show version information") 79 | flag.StringVar(&logLevelStr, "log-level", "info", "Set log level (error, warn, info, debug)") 80 | flag.StringVar(&logFile, "log-file", "", "Set file to write logs to (default: terminal)") 81 
| flag.Usage = printUsage 82 | flag.Parse() 83 | 84 | if showVersion { 85 | fmt.Printf("wrapguard version %s\n", Version) 86 | os.Exit(0) 87 | } 88 | 89 | if showHelp { 90 | printUsage() 91 | os.Exit(0) 92 | } 93 | 94 | if configPath == "" { 95 | printUsage() 96 | os.Exit(1) 97 | } 98 | 99 | // Parse log level 100 | logLevel, err := ParseLogLevel(logLevelStr) 101 | if err != nil { 102 | fmt.Fprintf(os.Stderr, "\n\033[31m✗ Error:\033[0m Invalid log level: %v\n", err) 103 | os.Exit(1) 104 | } 105 | 106 | // Setup logger output 107 | var logOutput io.Writer = os.Stdout 108 | if logFile != "" { 109 | file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) 110 | if err != nil { 111 | fmt.Fprintf(os.Stderr, "\n\033[31m✗ Error:\033[0m Failed to open log file: %v\n", err) 112 | os.Exit(1) 113 | } 114 | defer file.Close() 115 | logOutput = file 116 | } 117 | 118 | // Create logger 119 | logger := NewLogger(logLevel, logOutput) 120 | 121 | args := flag.Args() 122 | if len(args) == 0 { 123 | fmt.Fprintf(os.Stderr, "\n\033[31m✗ Error:\033[0m No command specified\n") 124 | printUsage() 125 | os.Exit(1) 126 | } 127 | 128 | // Parse WireGuard configuration 129 | config, err := ParseWireGuardConfig(configPath) 130 | if err != nil { 131 | logger.Errorf("Failed to parse WireGuard config: %v", err) 132 | os.Exit(1) 133 | } 134 | 135 | // Initialize the virtual network stack 136 | netStack, err := NewVirtualNetworkStack() 137 | if err != nil { 138 | logger.Errorf("Failed to create virtual network stack: %v", err) 139 | os.Exit(1) 140 | } 141 | 142 | // Initialize WireGuard with memory-based TUN 143 | wg, err := NewWireGuardProxy(config, netStack, logger) 144 | if err != nil { 145 | logger.Errorf("Failed to initialize WireGuard: %v", err) 146 | os.Exit(1) 147 | } 148 | 149 | // Start the WireGuard proxy 150 | if err := wg.Start(); err != nil { 151 | logger.Errorf("Failed to start WireGuard: %v", err) 152 | os.Exit(1) 153 | } 154 | defer wg.Stop() 155 | 156 | // 
Start IPC server for LD_PRELOAD library communication 157 | ipcServer, err := NewIPCServer(netStack, wg) 158 | if err != nil { 159 | logger.Errorf("Failed to create IPC server: %v", err) 160 | os.Exit(1) 161 | } 162 | 163 | if err := ipcServer.Start(); err != nil { 164 | logger.Errorf("Failed to start IPC server: %v", err) 165 | os.Exit(1) 166 | } 167 | defer ipcServer.Stop() 168 | 169 | // Show startup messages using structured logging 170 | logger.Infof("WrapGuard %s initialized", Version) 171 | logger.Infof("Config: %s", configPath) 172 | logger.Infof("Interface: %s", config.Interface.Address.String()) 173 | logger.Infof("Peer endpoint: %s", config.Peers[0].Endpoint.String()) 174 | logger.Infof("Launching: %s", strings.Join(args, " ")) 175 | 176 | // Get path to our LD_PRELOAD library 177 | execPath, err := os.Executable() 178 | if err != nil { 179 | logger.Errorf("Failed to get executable path: %v", err) 180 | os.Exit(1) 181 | } 182 | libPath := filepath.Join(filepath.Dir(execPath), "libwrapguard.so") 183 | 184 | // Prepare child process 185 | cmd := exec.Command(args[0], args[1:]...) 
186 | cmd.Stdin = os.Stdin 187 | cmd.Stdout = os.Stdout 188 | cmd.Stderr = os.Stderr 189 | 190 | // Set LD_PRELOAD and IPC socket path 191 | cmd.Env = append(os.Environ(), 192 | fmt.Sprintf("LD_PRELOAD=%s", libPath), 193 | fmt.Sprintf("WRAPGUARD_IPC_PATH=%s", ipcServer.SocketPath()), 194 | ) 195 | 196 | // Start the child process 197 | if err := cmd.Start(); err != nil { 198 | logger.Errorf("Failed to start child process: %v", err) 199 | os.Exit(1) 200 | } 201 | 202 | // Handle signals 203 | sigChan := make(chan os.Signal, 1) 204 | signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) 205 | 206 | // Wait for child process or signal 207 | done := make(chan error, 1) 208 | go func() { 209 | done <- cmd.Wait() 210 | }() 211 | 212 | select { 213 | case err := <-done: 214 | if err != nil { 215 | if exitErr, ok := err.(*exec.ExitError); ok { 216 | os.Exit(exitErr.ExitCode()) 217 | } 218 | logger.Errorf("Child process error: %v", err) 219 | os.Exit(1) 220 | } 221 | // Exit cleanly when child process completes successfully 222 | os.Exit(0) 223 | case sig := <-sigChan: 224 | // Forward signal to child process 225 | if cmd.Process != nil { 226 | cmd.Process.Signal(sig) 227 | } 228 | // Wait for child to exit 229 | <-done 230 | os.Exit(1) 231 | } 232 | } 233 | -------------------------------------------------------------------------------- /memorytun.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "os" 7 | "sync" 8 | "time" 9 | 10 | "golang.zx2c4.com/wireguard/tun" 11 | ) 12 | 13 | // MemoryTUN implements a memory-based TUN device that doesn't require kernel interfaces 14 | type MemoryTUN struct { 15 | name string 16 | mtu int 17 | closed chan struct{} 18 | events chan tun.Event 19 | inbound chan []byte // Packets from WireGuard to applications 20 | outbound chan []byte // Packets from applications to WireGuard 21 | closeOnce sync.Once 22 | mu sync.Mutex 23 | } 24 | 25 | // 
NewMemoryTUN creates a new memory-based TUN device 26 | func NewMemoryTUN(name string, mtu int) *MemoryTUN { 27 | return &MemoryTUN{ 28 | name: name, 29 | mtu: mtu, 30 | closed: make(chan struct{}), 31 | events: make(chan tun.Event, 10), 32 | inbound: make(chan []byte, 1000), 33 | outbound: make(chan []byte, 1000), 34 | } 35 | } 36 | 37 | // Name returns the name of the TUN device 38 | func (t *MemoryTUN) Name() (string, error) { 39 | return t.name, nil 40 | } 41 | 42 | // File returns a nil file descriptor as we don't use real files 43 | func (t *MemoryTUN) File() *os.File { 44 | return nil 45 | } 46 | 47 | // Events returns the event channel 48 | func (t *MemoryTUN) Events() <-chan tun.Event { 49 | return t.events 50 | } 51 | 52 | // Read reads one or more packets from the TUN device (packets coming from applications) 53 | // On a successful read it returns the number of packets read, and sets 54 | // packet lengths within the sizes slice. 55 | func (t *MemoryTUN) Read(bufs [][]byte, sizes []int, offset int) (int, error) { 56 | if len(bufs) == 0 || len(sizes) < len(bufs) { 57 | return 0, errors.New("invalid buffer or sizes slice") 58 | } 59 | 60 | packetsRead := 0 61 | for i := range bufs { 62 | select { 63 | case <-t.closed: 64 | if packetsRead == 0 { 65 | return 0, io.EOF 66 | } 67 | return packetsRead, nil 68 | case packet := <-t.outbound: 69 | if len(packet) > len(bufs[i])-offset { 70 | return packetsRead, errors.New("packet too large for buffer") 71 | } 72 | copy(bufs[i][offset:], packet) 73 | sizes[i] = len(packet) 74 | packetsRead++ 75 | default: 76 | // No more packets available 77 | if packetsRead == 0 { 78 | // Block for at least one packet 79 | select { 80 | case <-t.closed: 81 | return 0, io.EOF 82 | case packet := <-t.outbound: 83 | if len(packet) > len(bufs[i])-offset { 84 | return 0, errors.New("packet too large for buffer") 85 | } 86 | copy(bufs[i][offset:], packet) 87 | sizes[i] = len(packet) 88 | return 1, nil 89 | } 90 | } 91 | return 
packetsRead, nil 92 | } 93 | } 94 | return packetsRead, nil 95 | } 96 | 97 | // Write writes one or more packets to the TUN device (packets going to applications) 98 | // On a successful write it returns the number of packets written. 99 | func (t *MemoryTUN) Write(bufs [][]byte, offset int) (int, error) { 100 | if len(bufs) == 0 { 101 | return 0, nil 102 | } 103 | 104 | packetsWritten := 0 105 | for _, buf := range bufs { 106 | if offset >= len(buf) { 107 | continue 108 | } 109 | 110 | packet := make([]byte, len(buf)-offset) 111 | copy(packet, buf[offset:]) 112 | 113 | select { 114 | case <-t.closed: 115 | if packetsWritten == 0 { 116 | return 0, io.EOF 117 | } 118 | return packetsWritten, nil 119 | case t.inbound <- packet: 120 | packetsWritten++ 121 | default: 122 | // Drop packet if buffer is full but count as written 123 | packetsWritten++ 124 | } 125 | } 126 | return packetsWritten, nil 127 | } 128 | 129 | // MTU returns the MTU of the TUN device 130 | func (t *MemoryTUN) MTU() (int, error) { 131 | return t.mtu, nil 132 | } 133 | 134 | // Close closes the TUN device 135 | func (t *MemoryTUN) Close() error { 136 | t.closeOnce.Do(func() { 137 | close(t.closed) 138 | close(t.events) 139 | }) 140 | return nil 141 | } 142 | 143 | // BatchSize returns the preferred/max number of packets that can be read or 144 | // written in a single read/write call. 
145 | func (t *MemoryTUN) BatchSize() int { 146 | return 128 // Allow batching up to 128 packets 147 | } 148 | 149 | // InjectInbound injects a packet as if it came from the network (for sending to WireGuard) 150 | func (t *MemoryTUN) InjectInbound(packet []byte) error { 151 | select { 152 | case <-t.closed: 153 | return io.EOF 154 | case t.outbound <- packet: 155 | return nil 156 | case <-time.After(100 * time.Millisecond): 157 | return errors.New("timeout injecting packet") 158 | } 159 | } 160 | 161 | // ReadOutbound reads a packet that WireGuard wants to send to the network 162 | func (t *MemoryTUN) ReadOutbound() ([]byte, error) { 163 | select { 164 | case <-t.closed: 165 | return nil, io.EOF 166 | case packet := <-t.inbound: 167 | return packet, nil 168 | } 169 | } 170 | 171 | // SendUp sends the TUN up event 172 | func (t *MemoryTUN) SendUp() { 173 | select { 174 | case t.events <- tun.EventUp: 175 | default: 176 | } 177 | } 178 | 179 | // SendDown sends the TUN down event 180 | func (t *MemoryTUN) SendDown() { 181 | select { 182 | case t.events <- tun.EventDown: 183 | default: 184 | } 185 | } 186 | -------------------------------------------------------------------------------- /network.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "net" 7 | "sync" 8 | "sync/atomic" 9 | ) 10 | 11 | // VirtualNetworkStack manages virtual connections and packet routing 12 | type VirtualNetworkStack struct { 13 | mu sync.RWMutex 14 | connections map[uint32]*VirtualConnection 15 | listeningSockets map[string]*VirtualListener 16 | outgoingPackets chan []byte 17 | nextConnID uint32 18 | localIP net.IP 19 | localNet *net.IPNet 20 | } 21 | 22 | // VirtualConnection represents a virtual network connection 23 | type VirtualConnection struct { 24 | ID uint32 25 | LocalAddr net.Addr 26 | RemoteAddr net.Addr 27 | Type string // "tcp" or "udp" 28 | State string // "connected", 
"listening", etc. 29 | IncomingData chan []byte 30 | OutgoingData chan []byte 31 | } 32 | 33 | // VirtualListener represents a listening socket 34 | type VirtualListener struct { 35 | Addr net.Addr 36 | Type string // "tcp" or "udp" 37 | AcceptQueue chan *VirtualConnection 38 | } 39 | 40 | // NewVirtualNetworkStack creates a new virtual network stack 41 | func NewVirtualNetworkStack() (*VirtualNetworkStack, error) { 42 | return &VirtualNetworkStack{ 43 | connections: make(map[uint32]*VirtualConnection), 44 | listeningSockets: make(map[string]*VirtualListener), 45 | outgoingPackets: make(chan []byte, 1000), 46 | }, nil 47 | } 48 | 49 | // SetLocalAddress sets the local WireGuard IP address 50 | func (s *VirtualNetworkStack) SetLocalAddress(addr *net.IPNet) { 51 | s.mu.Lock() 52 | defer s.mu.Unlock() 53 | s.localIP = addr.IP 54 | s.localNet = addr 55 | } 56 | 57 | // CreateConnection creates a new virtual connection 58 | func (s *VirtualNetworkStack) CreateConnection(connType string) (*VirtualConnection, error) { 59 | connID := atomic.AddUint32(&s.nextConnID, 1) 60 | 61 | conn := &VirtualConnection{ 62 | ID: connID, 63 | Type: connType, 64 | State: "created", 65 | IncomingData: make(chan []byte, 100), 66 | OutgoingData: make(chan []byte, 100), 67 | } 68 | 69 | s.mu.Lock() 70 | s.connections[connID] = conn 71 | s.mu.Unlock() 72 | 73 | // Start packet handler for this connection 74 | go s.handleConnectionPackets(conn) 75 | 76 | return conn, nil 77 | } 78 | 79 | // BindConnection binds a connection to a local address 80 | func (s *VirtualNetworkStack) BindConnection(connID uint32, addr net.Addr) error { 81 | s.mu.Lock() 82 | defer s.mu.Unlock() 83 | 84 | conn, exists := s.connections[connID] 85 | if !exists { 86 | return fmt.Errorf("connection %d not found", connID) 87 | } 88 | 89 | conn.LocalAddr = addr 90 | return nil 91 | } 92 | 93 | // ListenConnection puts a connection in listening state 94 | func (s *VirtualNetworkStack) ListenConnection(connID uint32) error { 95 
| s.mu.Lock() 96 | defer s.mu.Unlock() 97 | 98 | conn, exists := s.connections[connID] 99 | if !exists { 100 | return fmt.Errorf("connection %d not found", connID) 101 | } 102 | 103 | if conn.LocalAddr == nil { 104 | return fmt.Errorf("connection must be bound before listening") 105 | } 106 | 107 | listener := &VirtualListener{ 108 | Addr: conn.LocalAddr, 109 | Type: conn.Type, 110 | AcceptQueue: make(chan *VirtualConnection, 10), 111 | } 112 | 113 | s.listeningSockets[conn.LocalAddr.String()] = listener 114 | conn.State = "listening" 115 | 116 | return nil 117 | } 118 | 119 | // ConnectConnection connects to a remote address 120 | func (s *VirtualNetworkStack) ConnectConnection(connID uint32, remoteAddr net.Addr) error { 121 | s.mu.Lock() 122 | conn, exists := s.connections[connID] 123 | s.mu.Unlock() 124 | 125 | if !exists { 126 | return fmt.Errorf("connection %d not found", connID) 127 | } 128 | 129 | // Assign local address if not bound 130 | if conn.LocalAddr == nil { 131 | // Auto-assign ephemeral port 132 | localPort := 30000 + (connID % 30000) 133 | if conn.Type == "tcp" { 134 | conn.LocalAddr = &net.TCPAddr{IP: s.localIP, Port: int(localPort)} 135 | } else { 136 | conn.LocalAddr = &net.UDPAddr{IP: s.localIP, Port: int(localPort)} 137 | } 138 | } 139 | 140 | conn.RemoteAddr = remoteAddr 141 | conn.State = "connected" 142 | 143 | // Send SYN packet for TCP 144 | if conn.Type == "tcp" { 145 | synPacket := s.createTCPPacket(conn, nil, true, false, false) 146 | s.outgoingPackets <- synPacket 147 | } 148 | 149 | return nil 150 | } 151 | 152 | // SendData sends data on a connection 153 | func (s *VirtualNetworkStack) SendData(connID uint32, data []byte) error { 154 | s.mu.RLock() 155 | conn, exists := s.connections[connID] 156 | s.mu.RUnlock() 157 | 158 | if !exists { 159 | return fmt.Errorf("connection %d not found", connID) 160 | } 161 | 162 | if conn.State != "connected" { 163 | return fmt.Errorf("connection not in connected state") 164 | } 165 | 166 | // 
Queue data for sending 167 | select { 168 | case conn.OutgoingData <- data: 169 | return nil 170 | default: 171 | return fmt.Errorf("outgoing buffer full") 172 | } 173 | } 174 | 175 | // ReceiveData receives data from a connection 176 | func (s *VirtualNetworkStack) ReceiveData(connID uint32) ([]byte, error) { 177 | s.mu.RLock() 178 | conn, exists := s.connections[connID] 179 | s.mu.RUnlock() 180 | 181 | if !exists { 182 | return nil, fmt.Errorf("connection %d not found", connID) 183 | } 184 | 185 | select { 186 | case data := <-conn.IncomingData: 187 | return data, nil 188 | default: 189 | return nil, fmt.Errorf("no data available") 190 | } 191 | } 192 | 193 | // AcceptConnection accepts a new connection on a listening socket 194 | func (s *VirtualNetworkStack) AcceptConnection(listenAddr net.Addr) (*VirtualConnection, error) { 195 | s.mu.RLock() 196 | listener, exists := s.listeningSockets[listenAddr.String()] 197 | s.mu.RUnlock() 198 | 199 | if !exists { 200 | return nil, fmt.Errorf("no listener on %s", listenAddr.String()) 201 | } 202 | 203 | select { 204 | case conn := <-listener.AcceptQueue: 205 | return conn, nil 206 | default: 207 | return nil, fmt.Errorf("no pending connections") 208 | } 209 | } 210 | 211 | // CloseConnection closes a virtual connection 212 | func (s *VirtualNetworkStack) CloseConnection(connID uint32) error { 213 | s.mu.Lock() 214 | defer s.mu.Unlock() 215 | 216 | conn, exists := s.connections[connID] 217 | if !exists { 218 | return fmt.Errorf("connection %d not found", connID) 219 | } 220 | 221 | // Send FIN packet for TCP 222 | if conn.Type == "tcp" && conn.State == "connected" { 223 | finPacket := s.createTCPPacket(conn, nil, false, false, true) 224 | s.outgoingPackets <- finPacket 225 | } 226 | 227 | close(conn.IncomingData) 228 | close(conn.OutgoingData) 229 | delete(s.connections, connID) 230 | 231 | // Remove from listening sockets if it was listening 232 | if conn.State == "listening" && conn.LocalAddr != nil { 233 | 
delete(s.listeningSockets, conn.LocalAddr.String()) 234 | } 235 | 236 | return nil 237 | } 238 | 239 | // OutgoingPackets returns the channel for outgoing packets 240 | func (s *VirtualNetworkStack) OutgoingPackets() <-chan []byte { 241 | return s.outgoingPackets 242 | } 243 | 244 | // DeliverIncomingPacket processes an incoming packet from WireGuard 245 | func (s *VirtualNetworkStack) DeliverIncomingPacket(packet []byte) error { 246 | if len(packet) < 20 { 247 | return fmt.Errorf("packet too short") 248 | } 249 | 250 | // Parse IP header 251 | version := packet[0] >> 4 252 | switch version { 253 | case 4: 254 | return s.handleIPv4Packet(packet) 255 | case 6: 256 | return s.handleIPv6Packet(packet) 257 | default: 258 | return fmt.Errorf("unsupported IP version: %d", version) 259 | } 260 | } 261 | 262 | // handleIPv4Packet processes IPv4 packets 263 | func (s *VirtualNetworkStack) handleIPv4Packet(packet []byte) error { 264 | if len(packet) < 20 { 265 | return fmt.Errorf("IPv4 packet too short") 266 | } 267 | 268 | protocol := packet[9] 269 | srcIP := net.IP(packet[12:16]) 270 | dstIP := net.IP(packet[16:20]) 271 | 272 | headerLen := int(packet[0]&0x0f) * 4 273 | if len(packet) < headerLen { 274 | return fmt.Errorf("invalid IPv4 header length") 275 | } 276 | 277 | payload := packet[headerLen:] 278 | 279 | switch protocol { 280 | case 6: // TCP 281 | return s.handleIncomingTCP(srcIP, dstIP, payload) 282 | case 17: // UDP 283 | return s.handleIncomingUDP(srcIP, dstIP, payload) 284 | default: 285 | return fmt.Errorf("unsupported protocol: %d", protocol) 286 | } 287 | } 288 | 289 | // handleIPv6Packet processes IPv6 packets 290 | func (s *VirtualNetworkStack) handleIPv6Packet(packet []byte) error { 291 | if len(packet) < 40 { 292 | return fmt.Errorf("IPv6 packet too short") 293 | } 294 | 295 | nextHeader := packet[6] 296 | srcIP := net.IP(packet[8:24]) 297 | dstIP := net.IP(packet[24:40]) 298 | 299 | payload := packet[40:] 300 | 301 | switch nextHeader { 302 | case 6: 
// TCP 303 | return s.handleIncomingTCP(srcIP, dstIP, payload) 304 | case 17: // UDP 305 | return s.handleIncomingUDP(srcIP, dstIP, payload) 306 | default: 307 | return fmt.Errorf("unsupported protocol: %d", nextHeader) 308 | } 309 | } 310 | 311 | // handleConnectionPackets handles outgoing packets for a connection 312 | func (s *VirtualNetworkStack) handleConnectionPackets(conn *VirtualConnection) { 313 | for data := range conn.OutgoingData { 314 | var packet []byte 315 | if conn.Type == "tcp" { 316 | packet = s.createTCPPacket(conn, data, false, true, false) 317 | } else { 318 | packet = s.createUDPPacket(conn, data) 319 | } 320 | s.outgoingPackets <- packet 321 | } 322 | } 323 | 324 | // createTCPPacket creates a TCP/IP packet (IPv4 or IPv6) 325 | func (s *VirtualNetworkStack) createTCPPacket(conn *VirtualConnection, data []byte, syn, ack, fin bool) []byte { 326 | // This is a simplified implementation 327 | // In production, you'd need proper TCP sequence numbers, checksums, etc. 328 | 329 | tcpAddr, _ := conn.LocalAddr.(*net.TCPAddr) 330 | remoteTCPAddr, _ := conn.RemoteAddr.(*net.TCPAddr) 331 | 332 | var ipHeader []byte 333 | 334 | // Determine if we're dealing with IPv4 or IPv6 335 | if tcpAddr.IP.To4() != nil && remoteTCPAddr.IP.To4() != nil { 336 | // IPv4 header (20 bytes) 337 | ipHeader = make([]byte, 20) 338 | ipHeader[0] = 0x45 // Version 4, header length 5 (20 bytes) 339 | ipHeader[1] = 0 // TOS 340 | binary.BigEndian.PutUint16(ipHeader[2:4], uint16(20+20+len(data))) // Total length 341 | binary.BigEndian.PutUint16(ipHeader[4:6], 0) // ID 342 | ipHeader[6] = 0x40 // Flags (Don't Fragment) 343 | ipHeader[8] = 64 // TTL 344 | ipHeader[9] = 6 // Protocol (TCP) 345 | // Checksum would go in bytes 10-11 346 | copy(ipHeader[12:16], tcpAddr.IP.To4()) 347 | copy(ipHeader[16:20], remoteTCPAddr.IP.To4()) 348 | } else { 349 | // IPv6 header (40 bytes) 350 | ipHeader = make([]byte, 40) 351 | ipHeader[0] = 0x60 // Version 6 352 | // Traffic class and flow label in 
bytes 1-3 353 | binary.BigEndian.PutUint16(ipHeader[4:6], uint16(20+len(data))) // Payload length 354 | ipHeader[6] = 6 // Next header (TCP) 355 | ipHeader[7] = 64 // Hop limit 356 | copy(ipHeader[8:24], tcpAddr.IP.To16()) 357 | copy(ipHeader[24:40], remoteTCPAddr.IP.To16()) 358 | } 359 | 360 | // TCP header (20 bytes minimum) 361 | tcpHeader := make([]byte, 20) 362 | binary.BigEndian.PutUint16(tcpHeader[0:2], uint16(tcpAddr.Port)) 363 | binary.BigEndian.PutUint16(tcpHeader[2:4], uint16(remoteTCPAddr.Port)) 364 | // Sequence number, ACK number would go here 365 | tcpHeader[12] = 0x50 // Header length (5 * 4 = 20 bytes) 366 | 367 | // Flags 368 | flags := byte(0) 369 | if syn { 370 | flags |= 0x02 371 | } 372 | if ack { 373 | flags |= 0x10 374 | } 375 | if fin { 376 | flags |= 0x01 377 | } 378 | tcpHeader[13] = flags 379 | 380 | binary.BigEndian.PutUint16(tcpHeader[14:16], 65535) // Window size 381 | // Checksum would go in bytes 16-18 382 | 383 | // Combine all parts 384 | packet := make([]byte, 0, len(ipHeader)+20+len(data)) 385 | packet = append(packet, ipHeader...) 386 | packet = append(packet, tcpHeader...) 387 | packet = append(packet, data...) 
388 | 389 | return packet 390 | } 391 | 392 | // createUDPPacket creates a UDP/IP packet (IPv4 or IPv6) 393 | func (s *VirtualNetworkStack) createUDPPacket(conn *VirtualConnection, data []byte) []byte { 394 | udpAddr, _ := conn.LocalAddr.(*net.UDPAddr) 395 | remoteUDPAddr, _ := conn.RemoteAddr.(*net.UDPAddr) 396 | 397 | var ipHeader []byte 398 | 399 | // Determine if we're dealing with IPv4 or IPv6 400 | if udpAddr.IP.To4() != nil && remoteUDPAddr.IP.To4() != nil { 401 | // IPv4 header (20 bytes) 402 | ipHeader = make([]byte, 20) 403 | ipHeader[0] = 0x45 // Version 4, header length 5 404 | binary.BigEndian.PutUint16(ipHeader[2:4], uint16(20+8+len(data))) // Total length 405 | ipHeader[8] = 64 // TTL 406 | ipHeader[9] = 17 // Protocol (UDP) 407 | copy(ipHeader[12:16], udpAddr.IP.To4()) 408 | copy(ipHeader[16:20], remoteUDPAddr.IP.To4()) 409 | } else { 410 | // IPv6 header (40 bytes) 411 | ipHeader = make([]byte, 40) 412 | ipHeader[0] = 0x60 // Version 6 413 | // Traffic class and flow label in bytes 1-3 414 | binary.BigEndian.PutUint16(ipHeader[4:6], uint16(8+len(data))) // Payload length 415 | ipHeader[6] = 17 // Next header (UDP) 416 | ipHeader[7] = 64 // Hop limit 417 | copy(ipHeader[8:24], udpAddr.IP.To16()) 418 | copy(ipHeader[24:40], remoteUDPAddr.IP.To16()) 419 | } 420 | 421 | // UDP header (8 bytes) 422 | udpHeader := make([]byte, 8) 423 | binary.BigEndian.PutUint16(udpHeader[0:2], uint16(udpAddr.Port)) 424 | binary.BigEndian.PutUint16(udpHeader[2:4], uint16(remoteUDPAddr.Port)) 425 | binary.BigEndian.PutUint16(udpHeader[4:6], uint16(8+len(data))) // Length 426 | 427 | // Combine all parts 428 | packet := make([]byte, 0, len(ipHeader)+8+len(data)) 429 | packet = append(packet, ipHeader...) 430 | packet = append(packet, udpHeader...) 431 | packet = append(packet, data...) 
432 | 433 | return packet 434 | } 435 | 436 | // handleIncomingTCP handles incoming TCP packets 437 | func (s *VirtualNetworkStack) handleIncomingTCP(srcIP, dstIP net.IP, payload []byte) error { 438 | if len(payload) < 20 { 439 | return fmt.Errorf("TCP header too short") 440 | } 441 | 442 | srcPort := binary.BigEndian.Uint16(payload[0:2]) 443 | dstPort := binary.BigEndian.Uint16(payload[2:4]) 444 | flags := payload[13] 445 | 446 | localAddr := &net.TCPAddr{IP: dstIP, Port: int(dstPort)} 447 | remoteAddr := &net.TCPAddr{IP: srcIP, Port: int(srcPort)} 448 | 449 | // Check if this is for a listening socket 450 | s.mu.RLock() 451 | listener, hasListener := s.listeningSockets[localAddr.String()] 452 | s.mu.RUnlock() 453 | 454 | if hasListener && (flags&0x02) != 0 { // SYN flag 455 | // Create new connection for incoming SYN 456 | newConn, _ := s.CreateConnection("tcp") 457 | newConn.LocalAddr = localAddr 458 | newConn.RemoteAddr = remoteAddr 459 | newConn.State = "connected" 460 | 461 | // Queue for accept 462 | select { 463 | case listener.AcceptQueue <- newConn: 464 | default: 465 | // Accept queue full 466 | } 467 | 468 | // Send SYN-ACK 469 | synAckPacket := s.createTCPPacket(newConn, nil, true, true, false) 470 | s.outgoingPackets <- synAckPacket 471 | return nil 472 | } 473 | 474 | // Find existing connection 475 | s.mu.RLock() 476 | var conn *VirtualConnection 477 | for _, c := range s.connections { 478 | if c.Type == "tcp" && 479 | c.LocalAddr != nil && c.LocalAddr.String() == localAddr.String() && 480 | c.RemoteAddr != nil && c.RemoteAddr.String() == remoteAddr.String() { 481 | conn = c 482 | break 483 | } 484 | } 485 | s.mu.RUnlock() 486 | 487 | if conn == nil { 488 | return fmt.Errorf("no connection found for TCP packet") 489 | } 490 | 491 | // Extract data after TCP header 492 | headerLen := int((payload[12]>>4)&0x0f) * 4 493 | if len(payload) > headerLen { 494 | data := payload[headerLen:] 495 | select { 496 | case conn.IncomingData <- data: 497 | default: 
498 | // Buffer full 499 | } 500 | } 501 | 502 | return nil 503 | } 504 | 505 | // handleIncomingUDP handles incoming UDP packets 506 | func (s *VirtualNetworkStack) handleIncomingUDP(srcIP, dstIP net.IP, payload []byte) error { 507 | if len(payload) < 8 { 508 | return fmt.Errorf("UDP header too short") 509 | } 510 | 511 | srcPort := binary.BigEndian.Uint16(payload[0:2]) 512 | dstPort := binary.BigEndian.Uint16(payload[2:4]) 513 | 514 | localAddr := &net.UDPAddr{IP: dstIP, Port: int(dstPort)} 515 | remoteAddr := &net.UDPAddr{IP: srcIP, Port: int(srcPort)} 516 | 517 | // Find connection 518 | s.mu.RLock() 519 | var conn *VirtualConnection 520 | for _, c := range s.connections { 521 | if c.Type == "udp" && 522 | c.LocalAddr != nil && c.LocalAddr.String() == localAddr.String() { 523 | conn = c 524 | break 525 | } 526 | } 527 | s.mu.RUnlock() 528 | 529 | if conn == nil { 530 | return fmt.Errorf("no connection found for UDP packet") 531 | } 532 | 533 | // Update remote address for UDP (connectionless) 534 | conn.RemoteAddr = remoteAddr 535 | 536 | // Extract data 537 | data := payload[8:] 538 | select { 539 | case conn.IncomingData <- data: 540 | default: 541 | // Buffer full 542 | } 543 | 544 | return nil 545 | } 546 | -------------------------------------------------------------------------------- /network_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "net" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestNewVirtualNetworkStack(t *testing.T) { 10 | netStack, err := NewVirtualNetworkStack() 11 | if err != nil { 12 | t.Fatalf("Failed to create network stack: %v", err) 13 | } 14 | 15 | if netStack == nil { 16 | t.Fatal("Network stack should not be nil") 17 | } 18 | 19 | if netStack.connections == nil { 20 | t.Error("Connections map should be initialized") 21 | } 22 | 23 | if netStack.listeningSockets == nil { 24 | t.Error("Listening sockets map should be initialized") 25 | } 26 | 27 | if 
netStack.outgoingPackets == nil { 28 | t.Error("Outgoing packets channel should be initialized") 29 | } 30 | } 31 | 32 | func TestSetLocalAddress(t *testing.T) { 33 | netStack, _ := NewVirtualNetworkStack() 34 | 35 | _, ipnet, _ := net.ParseCIDR("10.0.0.2/24") 36 | netStack.SetLocalAddress(ipnet) 37 | 38 | if !netStack.localIP.Equal(ipnet.IP) { 39 | t.Errorf("Expected local IP %v, got %v", ipnet.IP, netStack.localIP) 40 | } 41 | 42 | if netStack.localNet.String() != ipnet.String() { 43 | t.Errorf("Expected local net %v, got %v", ipnet, netStack.localNet) 44 | } 45 | } 46 | 47 | func TestSetLocalAddressIPv6(t *testing.T) { 48 | netStack, _ := NewVirtualNetworkStack() 49 | 50 | _, ipnet, _ := net.ParseCIDR("2001:db8::1/64") 51 | netStack.SetLocalAddress(ipnet) 52 | 53 | if !netStack.localIP.Equal(ipnet.IP) { 54 | t.Errorf("Expected local IP %v, got %v", ipnet.IP, netStack.localIP) 55 | } 56 | 57 | if netStack.localNet.String() != ipnet.String() { 58 | t.Errorf("Expected local net %v, got %v", ipnet, netStack.localNet) 59 | } 60 | } 61 | 62 | func TestCreateConnection(t *testing.T) { 63 | netStack, _ := NewVirtualNetworkStack() 64 | 65 | conn, err := netStack.CreateConnection("tcp") 66 | if err != nil { 67 | t.Fatalf("Failed to create connection: %v", err) 68 | } 69 | 70 | if conn.ID == 0 { 71 | t.Error("Connection ID should be non-zero") 72 | } 73 | 74 | if conn.Type != "tcp" { 75 | t.Errorf("Expected connection type 'tcp', got %s", conn.Type) 76 | } 77 | 78 | if conn.State != "created" { 79 | t.Errorf("Expected connection state 'created', got %s", conn.State) 80 | } 81 | 82 | // Check if connection is stored 83 | netStack.mu.RLock() 84 | storedConn, exists := netStack.connections[conn.ID] 85 | netStack.mu.RUnlock() 86 | 87 | if !exists { 88 | t.Error("Connection should be stored in connections map") 89 | } 90 | 91 | if storedConn.ID != conn.ID { 92 | t.Error("Stored connection should match created connection") 93 | } 94 | } 95 | 96 | func TestBindConnection(t 
*testing.T) { 97 | netStack, _ := NewVirtualNetworkStack() 98 | conn, _ := netStack.CreateConnection("tcp") 99 | 100 | addr := &net.TCPAddr{IP: net.ParseIP("10.0.0.2"), Port: 8080} 101 | err := netStack.BindConnection(conn.ID, addr) 102 | if err != nil { 103 | t.Fatalf("Failed to bind connection: %v", err) 104 | } 105 | 106 | if conn.LocalAddr.String() != addr.String() { 107 | t.Errorf("Expected local address %s, got %s", addr.String(), conn.LocalAddr.String()) 108 | } 109 | } 110 | 111 | func TestBindConnectionIPv6(t *testing.T) { 112 | netStack, _ := NewVirtualNetworkStack() 113 | conn, _ := netStack.CreateConnection("tcp") 114 | 115 | addr := &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 8080} 116 | err := netStack.BindConnection(conn.ID, addr) 117 | if err != nil { 118 | t.Fatalf("Failed to bind connection: %v", err) 119 | } 120 | 121 | if conn.LocalAddr.String() != addr.String() { 122 | t.Errorf("Expected local address %s, got %s", addr.String(), conn.LocalAddr.String()) 123 | } 124 | } 125 | 126 | func TestBindConnectionNotFound(t *testing.T) { 127 | netStack, _ := NewVirtualNetworkStack() 128 | 129 | addr := &net.TCPAddr{IP: net.ParseIP("10.0.0.2"), Port: 8080} 130 | err := netStack.BindConnection(999, addr) 131 | if err == nil { 132 | t.Error("Expected error for non-existent connection") 133 | } 134 | } 135 | 136 | func TestListenConnection(t *testing.T) { 137 | netStack, _ := NewVirtualNetworkStack() 138 | conn, _ := netStack.CreateConnection("tcp") 139 | 140 | addr := &net.TCPAddr{IP: net.ParseIP("10.0.0.2"), Port: 8080} 141 | netStack.BindConnection(conn.ID, addr) 142 | 143 | err := netStack.ListenConnection(conn.ID) 144 | if err != nil { 145 | t.Fatalf("Failed to set connection to listening: %v", err) 146 | } 147 | 148 | if conn.State != "listening" { 149 | t.Errorf("Expected connection state 'listening', got %s", conn.State) 150 | } 151 | 152 | // Check if listener is created 153 | netStack.mu.RLock() 154 | listener, exists := 
netStack.listeningSockets[addr.String()] 155 | netStack.mu.RUnlock() 156 | 157 | if !exists { 158 | t.Error("Listener should be created") 159 | } 160 | 161 | if listener.Addr.String() != addr.String() { 162 | t.Errorf("Expected listener address %s, got %s", addr.String(), listener.Addr.String()) 163 | } 164 | } 165 | 166 | func TestListenConnectionNotBound(t *testing.T) { 167 | netStack, _ := NewVirtualNetworkStack() 168 | conn, _ := netStack.CreateConnection("tcp") 169 | 170 | err := netStack.ListenConnection(conn.ID) 171 | if err == nil { 172 | t.Error("Expected error for unbound connection") 173 | } 174 | } 175 | 176 | func TestConnectConnection(t *testing.T) { 177 | netStack, _ := NewVirtualNetworkStack() 178 | _, ipnet, _ := net.ParseCIDR("10.0.0.2/24") 179 | netStack.SetLocalAddress(ipnet) 180 | 181 | conn, _ := netStack.CreateConnection("tcp") 182 | remoteAddr := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 80} 183 | 184 | err := netStack.ConnectConnection(conn.ID, remoteAddr) 185 | if err != nil { 186 | t.Fatalf("Failed to connect: %v", err) 187 | } 188 | 189 | if conn.State != "connected" { 190 | t.Errorf("Expected connection state 'connected', got %s", conn.State) 191 | } 192 | 193 | if conn.RemoteAddr.String() != remoteAddr.String() { 194 | t.Errorf("Expected remote address %s, got %s", remoteAddr.String(), conn.RemoteAddr.String()) 195 | } 196 | 197 | // Check if local address was auto-assigned 198 | if conn.LocalAddr == nil { 199 | t.Error("Local address should be auto-assigned") 200 | } 201 | } 202 | 203 | func TestConnectConnectionIPv6(t *testing.T) { 204 | netStack, _ := NewVirtualNetworkStack() 205 | _, ipnet, _ := net.ParseCIDR("2001:db8::1/64") 206 | netStack.SetLocalAddress(ipnet) 207 | 208 | conn, _ := netStack.CreateConnection("tcp") 209 | remoteAddr := &net.TCPAddr{IP: net.ParseIP("2001:db8::2"), Port: 80} 210 | 211 | err := netStack.ConnectConnection(conn.ID, remoteAddr) 212 | if err != nil { 213 | t.Fatalf("Failed to connect: %v", 
err) 214 | } 215 | 216 | if conn.State != "connected" { 217 | t.Errorf("Expected connection state 'connected', got %s", conn.State) 218 | } 219 | 220 | if conn.RemoteAddr.String() != remoteAddr.String() { 221 | t.Errorf("Expected remote address %s, got %s", remoteAddr.String(), conn.RemoteAddr.String()) 222 | } 223 | 224 | // Check if local address was auto-assigned 225 | if conn.LocalAddr == nil { 226 | t.Error("Local address should be auto-assigned") 227 | } 228 | } 229 | 230 | func TestSendData(t *testing.T) { 231 | netStack, _ := NewVirtualNetworkStack() 232 | conn, _ := netStack.CreateConnection("tcp") 233 | 234 | remoteAddr := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 80} 235 | netStack.ConnectConnection(conn.ID, remoteAddr) 236 | 237 | testData := []byte("hello world") 238 | err := netStack.SendData(conn.ID, testData) 239 | if err != nil { 240 | t.Fatalf("Failed to send data: %v", err) 241 | } 242 | 243 | // Give the handleConnectionPackets goroutine time to process 244 | time.Sleep(50 * time.Millisecond) 245 | 246 | // Check if packet was sent to outgoing channel 247 | select { 248 | case packet := <-netStack.OutgoingPackets(): 249 | if len(packet) == 0 { 250 | t.Error("Expected non-empty packet") 251 | } 252 | case <-time.After(100 * time.Millisecond): 253 | t.Error("Packet should be sent to OutgoingPackets channel") 254 | } 255 | } 256 | 257 | func TestReceiveData(t *testing.T) { 258 | netStack, _ := NewVirtualNetworkStack() 259 | conn, _ := netStack.CreateConnection("tcp") 260 | 261 | testData := []byte("hello world") 262 | 263 | // Simulate incoming data 264 | go func() { 265 | time.Sleep(10 * time.Millisecond) 266 | conn.IncomingData <- testData 267 | }() 268 | 269 | // Wait a bit for the goroutine to send data 270 | time.Sleep(20 * time.Millisecond) 271 | data, err := netStack.ReceiveData(conn.ID) 272 | if err != nil { 273 | t.Fatalf("Failed to receive data: %v", err) 274 | } 275 | 276 | if string(data) != string(testData) { 277 | 
t.Errorf("Expected data %s, got %s", string(testData), string(data)) 278 | } 279 | } 280 | 281 | func TestReceiveDataNoData(t *testing.T) { 282 | netStack, _ := NewVirtualNetworkStack() 283 | conn, _ := netStack.CreateConnection("tcp") 284 | 285 | _, err := netStack.ReceiveData(conn.ID) 286 | if err == nil { 287 | t.Error("Expected error when no data available") 288 | } 289 | } 290 | 291 | func TestCloseConnection(t *testing.T) { 292 | netStack, _ := NewVirtualNetworkStack() 293 | conn, _ := netStack.CreateConnection("tcp") 294 | 295 | err := netStack.CloseConnection(conn.ID) 296 | if err != nil { 297 | t.Fatalf("Failed to close connection: %v", err) 298 | } 299 | 300 | // Check if connection is removed from map 301 | netStack.mu.RLock() 302 | _, exists := netStack.connections[conn.ID] 303 | netStack.mu.RUnlock() 304 | 305 | if exists { 306 | t.Error("Connection should be removed from connections map") 307 | } 308 | } 309 | 310 | func TestDeliverIncomingPacket(t *testing.T) { 311 | netStack, _ := NewVirtualNetworkStack() 312 | 313 | // Test with packet too short 314 | shortPacket := []byte{0x45, 0x00} 315 | err := netStack.DeliverIncomingPacket(shortPacket) 316 | if err == nil { 317 | t.Error("Expected error for packet too short") 318 | } 319 | 320 | // Test with unsupported protocol 321 | unsupportedProtocolPacket := make([]byte, 20) 322 | unsupportedProtocolPacket[0] = 0x45 // IPv4, header length 20 323 | unsupportedProtocolPacket[9] = 1 // ICMP protocol 324 | err = netStack.DeliverIncomingPacket(unsupportedProtocolPacket) 325 | if err == nil { 326 | t.Error("Expected error for unsupported protocol") 327 | } 328 | 329 | // Test with unsupported IP version 330 | unsupportedVersionPacket := make([]byte, 20) 331 | unsupportedVersionPacket[0] = 0x75 // Version 7 332 | err = netStack.DeliverIncomingPacket(unsupportedVersionPacket) 333 | if err == nil { 334 | t.Error("Expected error for unsupported IP version") 335 | } 336 | } 337 | 338 | func 
TestDeliverIncomingIPv6Packet(t *testing.T) { 339 | netStack, _ := NewVirtualNetworkStack() 340 | 341 | // Test with IPv6 packet too short 342 | shortIPv6Packet := make([]byte, 20) 343 | shortIPv6Packet[0] = 0x60 // IPv6 344 | err := netStack.DeliverIncomingPacket(shortIPv6Packet) 345 | if err == nil { 346 | t.Error("Expected error for IPv6 packet too short") 347 | } 348 | 349 | // Test with IPv6 unsupported protocol 350 | unsupportedIPv6Packet := make([]byte, 40) 351 | unsupportedIPv6Packet[0] = 0x60 // IPv6 352 | unsupportedIPv6Packet[6] = 1 // ICMP next header 353 | err = netStack.DeliverIncomingPacket(unsupportedIPv6Packet) 354 | if err == nil { 355 | t.Error("Expected error for unsupported IPv6 protocol") 356 | } 357 | } 358 | 359 | func TestCreateTCPPacket(t *testing.T) { 360 | netStack, _ := NewVirtualNetworkStack() 361 | conn, _ := netStack.CreateConnection("tcp") 362 | 363 | localAddr := &net.TCPAddr{IP: net.ParseIP("10.0.0.2"), Port: 8080} 364 | remoteAddr := &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 80} 365 | conn.LocalAddr = localAddr 366 | conn.RemoteAddr = remoteAddr 367 | 368 | testData := []byte("test data") 369 | packet := netStack.createTCPPacket(conn, testData, false, true, false) 370 | 371 | if len(packet) < 40 { 372 | t.Error("TCP packet should be at least 40 bytes (20 IP + 20 TCP)") 373 | } 374 | 375 | // Check IP version and header length 376 | if packet[0] != 0x45 { 377 | t.Error("IP version should be 4 and header length should be 20") 378 | } 379 | 380 | // Check protocol 381 | if packet[9] != 6 { 382 | t.Error("Protocol should be TCP (6)") 383 | } 384 | 385 | // Check source and destination IPs 386 | srcIP := net.IP(packet[12:16]) 387 | dstIP := net.IP(packet[16:20]) 388 | 389 | if !srcIP.Equal(localAddr.IP) { 390 | t.Errorf("Source IP should be %v, got %v", localAddr.IP, srcIP) 391 | } 392 | 393 | if !dstIP.Equal(remoteAddr.IP) { 394 | t.Errorf("Destination IP should be %v, got %v", remoteAddr.IP, dstIP) 395 | } 396 | } 397 | 398 
| func TestCreateTCPPacketIPv6(t *testing.T) { 399 | netStack, _ := NewVirtualNetworkStack() 400 | conn, _ := netStack.CreateConnection("tcp") 401 | 402 | localAddr := &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 8080} 403 | remoteAddr := &net.TCPAddr{IP: net.ParseIP("2001:db8::2"), Port: 80} 404 | conn.LocalAddr = localAddr 405 | conn.RemoteAddr = remoteAddr 406 | 407 | testData := []byte("test data") 408 | packet := netStack.createTCPPacket(conn, testData, false, true, false) 409 | 410 | if len(packet) < 60 { 411 | t.Error("IPv6 TCP packet should be at least 60 bytes (40 IP + 20 TCP)") 412 | } 413 | 414 | // Check IP version 415 | if packet[0] != 0x60 { 416 | t.Error("IP version should be 6") 417 | } 418 | 419 | // Check next header (TCP) 420 | if packet[6] != 6 { 421 | t.Error("Next header should be TCP (6)") 422 | } 423 | 424 | // Check source and destination IPs 425 | srcIP := net.IP(packet[8:24]) 426 | dstIP := net.IP(packet[24:40]) 427 | 428 | if !srcIP.Equal(localAddr.IP) { 429 | t.Errorf("Source IP should be %v, got %v", localAddr.IP, srcIP) 430 | } 431 | 432 | if !dstIP.Equal(remoteAddr.IP) { 433 | t.Errorf("Destination IP should be %v, got %v", remoteAddr.IP, dstIP) 434 | } 435 | } 436 | 437 | func TestCreateUDPPacket(t *testing.T) { 438 | netStack, _ := NewVirtualNetworkStack() 439 | conn, _ := netStack.CreateConnection("udp") 440 | 441 | localAddr := &net.UDPAddr{IP: net.ParseIP("10.0.0.2"), Port: 8080} 442 | remoteAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 80} 443 | conn.LocalAddr = localAddr 444 | conn.RemoteAddr = remoteAddr 445 | 446 | testData := []byte("test data") 447 | packet := netStack.createUDPPacket(conn, testData) 448 | 449 | if len(packet) < 28 { 450 | t.Error("UDP packet should be at least 28 bytes (20 IP + 8 UDP)") 451 | } 452 | 453 | // Check IP version and header length 454 | if packet[0] != 0x45 { 455 | t.Error("IP version should be 4 and header length should be 20") 456 | } 457 | 458 | // Check protocol 459 | 
if packet[9] != 17 { 460 | t.Error("Protocol should be UDP (17)") 461 | } 462 | 463 | // Check source and destination IPs 464 | srcIP := net.IP(packet[12:16]) 465 | dstIP := net.IP(packet[16:20]) 466 | 467 | if !srcIP.Equal(localAddr.IP) { 468 | t.Errorf("Source IP should be %v, got %v", localAddr.IP, srcIP) 469 | } 470 | 471 | if !dstIP.Equal(remoteAddr.IP) { 472 | t.Errorf("Destination IP should be %v, got %v", remoteAddr.IP, dstIP) 473 | } 474 | } 475 | 476 | func TestCreateUDPPacketIPv6(t *testing.T) { 477 | netStack, _ := NewVirtualNetworkStack() 478 | conn, _ := netStack.CreateConnection("udp") 479 | 480 | localAddr := &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 8080} 481 | remoteAddr := &net.UDPAddr{IP: net.ParseIP("2001:db8::2"), Port: 80} 482 | conn.LocalAddr = localAddr 483 | conn.RemoteAddr = remoteAddr 484 | 485 | testData := []byte("test data") 486 | packet := netStack.createUDPPacket(conn, testData) 487 | 488 | if len(packet) < 48 { 489 | t.Error("IPv6 UDP packet should be at least 48 bytes (40 IP + 8 UDP)") 490 | } 491 | 492 | // Check IP version 493 | if packet[0] != 0x60 { 494 | t.Error("IP version should be 6") 495 | } 496 | 497 | // Check next header (UDP) 498 | if packet[6] != 17 { 499 | t.Error("Next header should be UDP (17)") 500 | } 501 | 502 | // Check source and destination IPs 503 | srcIP := net.IP(packet[8:24]) 504 | dstIP := net.IP(packet[24:40]) 505 | 506 | if !srcIP.Equal(localAddr.IP) { 507 | t.Errorf("Source IP should be %v, got %v", localAddr.IP, srcIP) 508 | } 509 | 510 | if !dstIP.Equal(remoteAddr.IP) { 511 | t.Errorf("Destination IP should be %v, got %v", remoteAddr.IP, dstIP) 512 | } 513 | } 514 | 515 | func TestOutgoingPacketsChannel(t *testing.T) { 516 | netStack, _ := NewVirtualNetworkStack() 517 | 518 | // Test that we can get the channel 519 | ch := netStack.OutgoingPackets() 520 | if ch == nil { 521 | t.Error("Outgoing packets channel should not be nil") 522 | } 523 | 524 | // Test that we can send to the channel 
525 | testPacket := []byte("test packet") 526 | go func() { 527 | netStack.outgoingPackets <- testPacket 528 | }() 529 | 530 | select { 531 | case packet := <-ch: 532 | if string(packet) != string(testPacket) { 533 | t.Errorf("Expected packet %s, got %s", string(testPacket), string(packet)) 534 | } 535 | case <-time.After(100 * time.Millisecond): 536 | t.Error("Should receive packet from channel") 537 | } 538 | } 539 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "WrapGuard Test Script" 4 | echo "====================" 5 | 6 | # Check if binaries exist 7 | if [ ! -f "wrapguard" ]; then 8 | echo "Error: wrapguard binary not found. Run 'make build' first." 9 | exit 1 10 | fi 11 | 12 | if [ ! -f "libwrapguard.so" ]; then 13 | echo "Error: libwrapguard.so not found. Run 'make build' first." 14 | exit 1 15 | fi 16 | 17 | echo "✓ Binaries found" 18 | 19 | # Test 1: Show help 20 | echo -e "\nTest 1: Show usage" 21 | ./wrapguard || true 22 | 23 | # Test 2: Try with a simple command (without actual WireGuard connection) 24 | echo -e "\nTest 2: Run echo command through wrapguard" 25 | ./wrapguard --config=example-wg0.conf -- echo "Hello from wrapguard!" 26 | 27 | echo -e "\nBasic tests completed!" 
--------------------------------------------------------------------------------
/version.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | // Version is the application version string. It defaults to a development
4 | // value and is overridden at build time; the release workflow passes
5 | // -ldflags "-X main.Version=<release tag>".
6 | var Version = "v1.0.0-dev"
7 | 
--------------------------------------------------------------------------------
/version_test.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"testing"
5 | )
6 | 
7 | func TestVersion(t *testing.T) {
8 | 	// Test that Version is not empty
9 | 	if Version == "" {
10 | 		t.Error("Version should not be empty")
11 | 	}
12 | 
13 | 	// Test that Version has the expected default value. This only holds for
14 | 	// test binaries built without an -ldflags override of main.Version.
15 | 	if Version != "v1.0.0-dev" {
16 | 		t.Errorf("Expected default version 'v1.0.0-dev', got %s", Version)
17 | 	}
18 | }
19 | 
20 | func TestVersionFormat(t *testing.T) {
21 | 	// Test that Version starts with 'v'
22 | 	if len(Version) == 0 || Version[0] != 'v' {
23 | 		t.Errorf("Version should start with 'v', got %s", Version)
24 | 	}
25 | 
26 | 	// The shortest well-formed version is "v1.0.0", which is 6 characters.
27 | 	if len(Version) < 6 {
28 | 		t.Errorf("Version should be at least 6 characters long, got %s", Version)
29 | 	}
30 | }
31 | 
--------------------------------------------------------------------------------
/wireguard.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"bufio"
5 | 	"encoding/base64"
6 | 	"encoding/hex"
7 | 	"errors"
8 | 	"fmt"
9 | 	"io"
10 | 	"net"
11 | 	"net/netip"
12 | 	"strings"
13 | 	"sync"
14 | 
15 | 	"golang.zx2c4.com/wireguard/conn"
16 | 	"golang.zx2c4.com/wireguard/device"
17 | )
18 | 
19 | // WireGuardProxy manages the WireGuard connection and packet routing between
20 | // an in-memory TUN device and the virtual network stack.
21 | type WireGuardProxy struct {
22 | 	config   *WireGuardConfig
23 | 	device   *device.Device
24 | 	memTun   *MemoryTUN
25 | 	netStack *VirtualNetworkStack
26 | 	udpConn  *net.UDPConn // always nil; the device's bind owns the UDP socket
27 | 	logger   *device.Logger
28 | 	stopChan chan struct{}
29 | 	wg       sync.WaitGroup
30 | }
31 | 
32 | // NewWireGuardProxy creates a WireGuard device on top of an in-memory TUN,
33 | // applies config to it, and returns a proxy that routes packets between the
34 | // device and netStack. The device is not brought up until Start is called.
35 | //
36 | // No separate UDP socket is opened here: the device's own bind
37 | // (conn.NewDefaultBind) carries encrypted traffic, and configureDevice hands
38 | // it config.Interface.ListenPort. A second socket bound to ListenPort would
39 | // never be read from and would occupy the port the device needs.
40 | func NewWireGuardProxy(config *WireGuardConfig, netStack *VirtualNetworkStack, appLogger *Logger) (*WireGuardProxy, error) {
41 | 	// Route WireGuard's internal logging through the structured app logger.
42 | 	logger := &device.Logger{
43 | 		Verbosef: func(format string, args ...interface{}) {
44 | 			appLogger.DebugWithComponent("wireguard", fmt.Sprintf(format, args...))
45 | 		},
46 | 		Errorf: func(format string, args ...interface{}) {
47 | 			appLogger.ErrorWithComponent("wireguard", fmt.Sprintf(format, args...))
48 | 		},
49 | 	}
50 | 
51 | 	memTun := NewMemoryTUN("wg0", config.Interface.MTU)
52 | 
53 | 	// Named dev (not device) so the device package is not shadowed.
54 | 	dev := device.NewDevice(memTun, conn.NewDefaultBind(), logger)
55 | 
56 | 	if err := configureDevice(dev, config); err != nil {
57 | 		dev.Close()
58 | 		return nil, fmt.Errorf("failed to configure device: %w", err)
59 | 	}
60 | 
61 | 	return &WireGuardProxy{
62 | 		config:   config,
63 | 		device:   dev,
64 | 		memTun:   memTun,
65 | 		netStack: netStack,
66 | 		logger:   logger,
67 | 		stopChan: make(chan struct{}),
68 | 	}, nil
69 | }
70 | 
71 | // Start brings the WireGuard device and the in-memory TUN up and launches the
72 | // two packet-routing goroutines. It returns an error if the device cannot be
73 | // brought up (for example, when its UDP port cannot be bound); in that case no
74 | // goroutines are started.
75 | func (w *WireGuardProxy) Start() error {
76 | 	if err := w.device.Up(); err != nil {
77 | 		return fmt.Errorf("failed to bring up device: %w", err)
78 | 	}
79 | 	w.memTun.SendUp()
80 | 
81 | 	w.wg.Add(2)
82 | 	go w.routeIncomingPackets()
83 | 	go w.routeOutgoingPackets()
84 | 
85 | 	return nil
86 | }
87 | 
88 | // Stop signals the routing goroutines to exit, waits for them, then tears
89 | // down the device and the TUN. It must be called at most once: a second call
90 | // panics when closing the already-closed stopChan.
91 | //
92 | // NOTE(review): routeIncomingPackets only checks stopChan between
93 | // ReadOutbound calls, and the TUN is closed only after wg.Wait returns. If
94 | // MemoryTUN.ReadOutbound blocks while no packet is pending, Stop can hang
95 | // here; confirm ReadOutbound's blocking behavior.
96 | func (w *WireGuardProxy) Stop() error {
97 | 	close(w.stopChan)
98 | 	w.wg.Wait()
99 | 
100 | 	w.device.Down()
101 | 	w.device.Close()
102 | 	if w.udpConn != nil {
103 | 		w.udpConn.Close()
104 | 	}
105 | 	w.memTun.Close()
106 | 
107 | 	return nil
108 | }
109 | 
110 | // routeIncomingPackets moves packets that WireGuard has decrypted (read from
111 | // the TUN via ReadOutbound) into the virtual network stack until stopChan closes.
112 | func (w *WireGuardProxy) routeIncomingPackets() {
113 | 	defer w.wg.Done()
114 | 
115 | 	for {
116 | 		select {
117 | 		case <-w.stopChan:
118 | 			return
119 | 		default:
120 | 		}
121 | 
122 | 		packet, err := w.memTun.ReadOutbound()
123 | 		if err != nil {
124 | 			// EOF is not logged. Accept a (possibly wrapped) io.EOF as well as
125 | 			// any error whose text is exactly "EOF".
126 | 			if !errors.Is(err, io.EOF) && err.Error() != "EOF" {
127 | 				w.logger.Errorf("Failed to read from TUN: %v", err)
128 | 			}
129 | 			continue
130 | 		}
131 | 
132 | 		if err := w.netStack.DeliverIncomingPacket(packet); err != nil {
133 | 			w.logger.Errorf("Failed to deliver incoming packet: %v", err)
134 | 		}
135 | 	}
136 | }
137 | 
138 | // routeOutgoingPackets hands packets produced by the virtual network stack to
139 | // WireGuard for encryption until stopChan closes or the stack's channel closes.
140 | func (w *WireGuardProxy) routeOutgoingPackets() {
141 | 	defer w.wg.Done()
142 | 
143 | 	for {
144 | 		select {
145 | 		case <-w.stopChan:
146 | 			return
147 | 		case packet, ok := <-w.netStack.OutgoingPackets():
148 | 			if !ok {
149 | 				// A closed channel would otherwise yield nil packets forever,
150 | 				// turning this loop into a busy spin.
151 | 				return
152 | 			}
153 | 			if err := w.memTun.InjectInbound(packet); err != nil {
154 | 				w.logger.Errorf("Failed to inject packet to WireGuard: %v", err)
155 | 			}
156 | 		}
157 | 	}
158 | }
159 | 
160 | // SendPacket injects packet into the TUN so WireGuard encrypts and sends it.
161 | func (w *WireGuardProxy) SendPacket(packet []byte) error {
162 | 	return w.memTun.InjectInbound(packet)
163 | }
164 | 
165 | // configureDevice translates config into wireguard-go's UAPI "set" format and
166 | // applies it to dev with IpcSetOperation. Keys in config are base64-encoded
167 | // while UAPI expects hex, so each key is re-encoded. Interface-level settings
168 | // (private_key, listen_port) are written before the first public_key line,
169 | // because public_key starts a peer section.
170 | func configureDevice(dev *device.Device, config *WireGuardConfig) error {
171 | 	privateKeyBytes, err := base64.StdEncoding.DecodeString(config.Interface.PrivateKey)
172 | 	if err != nil {
173 | 		return fmt.Errorf("failed to decode private key: %w", err)
174 | 	}
175 | 
176 | 	var uapi strings.Builder
177 | 	fmt.Fprintf(&uapi, "private_key=%s\n", hex.EncodeToString(privateKeyBytes))
178 | 
179 | 	// Make the device's own bind listen on the configured port; if omitted, the
180 | 	// device chooses a port itself and ListenPort has no effect.
181 | 	if config.Interface.ListenPort > 0 {
182 | 		fmt.Fprintf(&uapi, "listen_port=%d\n", config.Interface.ListenPort)
183 | 	}
184 | 
185 | 	for _, peer := range config.Peers {
186 | 		publicKeyBytes, err := base64.StdEncoding.DecodeString(peer.PublicKey)
187 | 		if err != nil {
188 | 			return fmt.Errorf("failed to decode public key: %w", err)
189 | 		}
190 | 		// public_key opens a new peer section; the keys below apply to this peer.
191 | 		fmt.Fprintf(&uapi, "public_key=%s\n", hex.EncodeToString(publicKeyBytes))
192 | 
193 | 		if peer.Endpoint != nil {
194 | 			// Format via netip.AddrPort so IPv6 endpoints are bracketed ([addr]:port).
195 | 			ip, err := netip.ParseAddr(peer.Endpoint.IP.String())
196 | 			if err != nil {
197 | 				return fmt.Errorf("failed to parse endpoint IP: %w", err)
198 | 			}
199 | 			endpoint := netip.AddrPortFrom(ip, uint16(peer.Endpoint.Port))
200 | 			fmt.Fprintf(&uapi, "endpoint=%s\n", endpoint.String())
201 | 		}
202 | 
203 | 		for _, allowedIP := range peer.AllowedIPs {
204 | 			ones, _ := allowedIP.Mask.Size()
205 | 			ip, err := netip.ParseAddr(allowedIP.IP.String())
206 | 			if err != nil {
207 | 				return fmt.Errorf("failed to parse allowed IP: %w", err)
208 | 			}
209 | 			prefix := netip.PrefixFrom(ip, ones)
210 | 			fmt.Fprintf(&uapi, "allowed_ip=%s\n", prefix.String())
211 | 		}
212 | 
213 | 		if peer.PersistentKeepalive > 0 {
214 | 			fmt.Fprintf(&uapi, "persistent_keepalive_interval=%d\n", peer.PersistentKeepalive)
215 | 		}
216 | 
217 | 		if peer.PresharedKey != "" {
218 | 			presharedKeyBytes, err := base64.StdEncoding.DecodeString(peer.PresharedKey)
219 | 			if err != nil {
220 | 				return fmt.Errorf("failed to decode preshared key: %w", err)
221 | 			}
222 | 			fmt.Fprintf(&uapi, "preshared_key=%s\n", hex.EncodeToString(presharedKeyBytes))
223 | 		}
224 | 	}
225 | 
226 | 	if err := dev.IpcSetOperation(bufio.NewReader(strings.NewReader(uapi.String()))); err != nil {
227 | 		return fmt.Errorf("failed to configure device: %w", err)
228 | 	}
229 | 
230 | 	return nil
231 | }
232 | 
--------------------------------------------------------------------------------