├── eventing ├── selectnotwindows.zig ├── common.zig ├── selectwindows.zig ├── epoll.zig └── select.zig ├── .gitignore ├── punch.zig ├── logging.zig ├── eventing.zig ├── relative-times ├── README.md ├── punch ├── proto.zig └── util.zig ├── restarter.zig ├── pool.zig ├── test ├── old ├── config-server.zig └── reverse-tunnel-client.zig ├── proxy.zig ├── nc.zig ├── double-server.zig ├── netext.zig ├── timing.zig ├── punch-client-forwarder.zig ├── common.zig ├── punch-server-initiator.zig └── socat.zig /eventing/selectnotwindows.zig: -------------------------------------------------------------------------------- 1 | // placeholder for now 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /zig-cache 2 | zig-out/ 3 | /debug 4 | /release 5 | /scratch 6 | *~ 7 | -------------------------------------------------------------------------------- /punch.zig: -------------------------------------------------------------------------------- 1 | pub const proto = @import("./punch/proto.zig"); 2 | pub const util = @import("./punch/util.zig"); 3 | -------------------------------------------------------------------------------- /logging.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | 3 | pub fn logTimestamp() void { 4 | std.debug.warn("{}: ", .{std.time.milliTimestamp()}); 5 | } 6 | pub fn log(comptime fmt: []const u8, args: anytype) void { 7 | logTimestamp(); 8 | std.debug.warn(fmt ++ "\n", args); 9 | } 10 | -------------------------------------------------------------------------------- /eventing/common.zig: -------------------------------------------------------------------------------- 1 | 2 | const EpollCtlError = error { 3 | FileDescriptorAlreadyPresentInSet, 4 | OperationCausesCircularLoop, 5 | FileDescriptorNotRegistered, 6 | SystemResources, 7 | UserResourceLimitReached, 
8 | FileDescriptorIncompatibleWithEpoll, 9 | Unexpected, 10 | }; 11 | 12 | pub const EventerAddError = EpollCtlError; 13 | pub const EventerModifyError = EpollCtlError; 14 | -------------------------------------------------------------------------------- /eventing.zig: -------------------------------------------------------------------------------- 1 | const builtin = @import("builtin"); 2 | const std = @import("std"); 3 | 4 | pub const select = @import("eventing/select.zig"); 5 | pub const epoll = @import("eventing/epoll.zig"); 6 | 7 | pub const default = if (builtin.os.tag == .windows) select else epoll; 8 | 9 | pub const EventerOptions = struct { 10 | // The extra data type that the eventer tracks 11 | Data: type = struct {}, 12 | // The error type for callbacks 13 | CallbackError: type = anyerror, 14 | // The data that is passed to callbacks 15 | CallbackData: type = struct {}, 16 | }; 17 | -------------------------------------------------------------------------------- /relative-times: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Convert the timestamps in a logfile to relative 3 | import sys 4 | import fileinput 5 | 6 | def main(): 7 | first_time = None 8 | for line in fileinput.input(): 9 | time_sep_idx = line.find(': ') 10 | num = line[:time_sep_idx] 11 | content = line[time_sep_idx + 2:] 12 | 13 | if first_time == None: 14 | first_time = int(num) 15 | time_offset = 0 16 | else: 17 | time_offset = int(num) - first_time 18 | 19 | sys.stdout.write("{}: {}".format(time_offset / 1000, content)) 20 | 21 | main() 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # TODO 3 | 4 | I need to fix timestamp semantics after `std.time.milliTimestamp` was changed to a signed type. I need to fix `zig test timing.zig`. 
5 | 6 | # Punch Protocol 7 | 8 | Connection is initiated with a handshake. 8 bytes for the punch protocol magic value `0x8ec04ff4a00e8694`, then 1-byte indicating which role the endpoint is taking. `0` for the initiator role which will be opening tunnels, and `1` for the forwarder role which will accept OpenTunnel messages and forward the tunnel data to another endpoint. 9 | 10 | > TODO: support authentication? Allow an authenticate command which requires a sequence of bytes to be sent from the other endpoint. 11 | 12 | ### Common Messages 13 | 14 | | Message | ID| Length | Data | 15 | |-------------|---|------------------|---------| 16 | | Heartbeat | 0 | | | 17 | | CloseTunnel | 1 | | | 18 | | Data | 2 | Length (8 bytes) | Data... | 19 | 20 | ### InitiatorOnly Messages 21 | 22 | | Message | ID| Length | Data | 23 | |-------------|---|------------------|---------| 24 | | OpenTunnel |128| | | 25 | -------------------------------------------------------------------------------- /eventing/selectwindows.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const os = std.os; 3 | 4 | pub const fd_t = os.socket_t; 5 | 6 | pub const fd_base_set = extern struct { 7 | fd_count: c_uint, 8 | fd_array: [0]fd_t, 9 | }; 10 | 11 | pub fn fd_set(comptime setSize: comptime_int) type { 12 | return extern struct { 13 | fd_count: c_uint, 14 | fd_array: [setSize]fd_t, 15 | pub fn base(self: *@This()) *fd_base_set { 16 | return @ptrCast(*fd_base_set, self); 17 | } 18 | pub fn add(self: *@This(), fd: fd_t) void { 19 | self.fd_array[self.fd_count] = fd; 20 | self.fd_count += 1; 21 | } 22 | }; 23 | } 24 | 25 | pub const timeval = extern struct { 26 | tv_sec: c_long, 27 | tv_usec: c_long, 28 | }; 29 | 30 | pub extern "ws2_32" fn select( 31 | nfds: c_int, // ignored 32 | readfds: *fd_base_set, 33 | writefds: *fd_base_set, 34 | exceptfds: *fd_base_set, 35 | timeout: ?*const timeval, 36 | ) callconv(os.windows.WINAPI) c_int; 
37 | 38 | pub fn set_fd(comptime SetType: type, set: *SetType, s: fd_t) void { 39 | set.fd_array[set.fd_count] = s; 40 | set.fd_count += 1; 41 | } 42 | 43 | pub fn msToTimeval(ms: u31) timeval { 44 | return .{ 45 | .tv_sec = ms / 1000, 46 | .tv_usec = (ms % 1000) * 1000, 47 | }; 48 | } -------------------------------------------------------------------------------- /punch/proto.zig: -------------------------------------------------------------------------------- 1 | // 2 | // these magic values were randomly generated 3 | // 4 | // Client/Server 5 | // 6 | // The client connects to the server for punch communication. Once connected, each side will send 8-bytes to identify which role they are taking (i.e. the 'initiator' or the 'forwarder'). 7 | // 8 | // Initiator/Forwarder Roles: 9 | // 10 | // The 'initiator' is the one who can send the 'OpenTunnel' message to establish a new tunneled-connection through the punch data stream. 11 | // The 'initator' will have a "raw server" socket waiting for connections. 12 | // The 'forwader' will make a new "raw client" connection when it receives the OpenTunnel message. 
13 | // 14 | pub const magic = [8]u8 { 0x8e, 0xc0, 0x4f, 0xf4, 0xa0, 0x0e, 0x86, 0x94 }; 15 | 16 | pub const Role = enum(u8) { 17 | initiator = 0, 18 | forwarder = 1, 19 | }; 20 | // TODO: it's possible an endpoint could be both an initator and a forwarder 21 | // if I find a use-case for this, I can make an initiator and forwarder flag 22 | pub const initiatorHandshake = magic ++ [1]u8 {@enumToInt(Role.initiator)}; 23 | pub const forwarderHandshake = magic ++ [1]u8 {@enumToInt(Role.forwarder)}; 24 | 25 | pub const TwoWayMessage = struct { 26 | pub const Heartbeat = 0; 27 | pub const CloseTunnel = 1; 28 | pub const Data = 2; 29 | }; 30 | pub const InitiatorMessage = struct { 31 | pub const OpenTunnel = 128; 32 | }; 33 | pub const ForwarderMessage = struct { 34 | }; 35 | -------------------------------------------------------------------------------- /restarter.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const mem = std.mem; 3 | const os = std.os; 4 | 5 | const logging = @import("./logging.zig"); 6 | const timing = @import("./timing.zig"); 7 | 8 | const ChildProcess = std.ChildProcess; 9 | const log = logging.log; 10 | 11 | fn makeThrottler(logPrefix: []const u8) timing.Throttler { 12 | return (timing.makeThrottler { 13 | .logPrefix = logPrefix, 14 | .desiredSleepMillis = 10000, 15 | .slowRateMillis = 500, 16 | }).create(); 17 | } 18 | 19 | const global = struct { 20 | var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); 21 | }; 22 | 23 | fn usage() void { 24 | std.debug.warn("Usage: restarter PROGRAM ARGS\n", .{}); 25 | } 26 | 27 | pub fn main() anyerror!u8 { 28 | var args = try std.process.argsAlloc(&global.arena.allocator); 29 | if (args.len <= 1) { 30 | usage(); 31 | return 1; 32 | } 33 | args = args[1..]; 34 | 35 | var throttler = makeThrottler("[restarter] throttle: "); 36 | while (true) { 37 | throttler.throttle(); 38 | logging.logTimestamp(); 39 | 
std.debug.warn("[restarter] starting: ", .{}); 40 | printArgs(args); 41 | std.debug.warn("\n", .{}); 42 | // TODO: is there a way to use an allocator that can free? 43 | var proc = try std.ChildProcess.init(args, &global.arena.allocator); 44 | defer proc.deinit(); 45 | try proc.spawn(); 46 | try waitForChild(proc); 47 | } 48 | } 49 | 50 | fn printArgs(argv: []const []const u8) void { 51 | var prefix : []const u8 = ""; 52 | for (argv) |arg| { 53 | std.debug.warn("{s}'{s}'", .{prefix, arg}); 54 | prefix = " "; 55 | } 56 | } 57 | 58 | fn waitForChild(proc: *ChildProcess) !void { 59 | // prottect from printing signals too fast 60 | var signalThrottler = (timing.makeThrottler { 61 | .logPrefix = "[restarter] signal throttler: ", 62 | .desiredSleepMillis = 10000, 63 | .slowRateMillis = 100, 64 | }).create(); 65 | while (true) { 66 | signalThrottler.throttle(); 67 | switch (try proc.spawnAndWait()) { 68 | .Exited => |code| { 69 | log("[restarter] child process exited with {}", .{code}); 70 | return; 71 | }, 72 | .Stopped => |sig| log("[restarter] child process has stopped ({})", .{sig}), 73 | .Signal => |sig| log("[restarter] child process signal ({})", .{sig}), 74 | .Unknown => |sig| log("[restarter] child process unknown ({})", .{sig}), 75 | } 76 | } 77 | } -------------------------------------------------------------------------------- /pool.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const mem = std.mem; 3 | 4 | const ArrayList = std.ArrayList; 5 | const Allocator = mem.Allocator; 6 | 7 | /// Maintains a pool of objects. Objects created by this pool are never moved. 8 | /// It will re-use memory that has been freed, but will not try to release 9 | /// memory back to the underlying allocator. 
10 | pub fn Pool(comptime T: type, chunkSize: usize) type { 11 | // TODO: assert if chunkSize < 1 12 | 13 | const Chunk = struct { 14 | array: [chunkSize]T, 15 | // TODO: this isn't necessary if T has in invalid bit pattern 16 | allocated: [chunkSize]bool, 17 | }; 18 | 19 | return struct { 20 | allocator: *Allocator, 21 | chunks: ArrayList(*Chunk), 22 | pub fn init(allocator: *Allocator) @This() { 23 | return @This() { 24 | .allocator = allocator, 25 | .chunks = ArrayList(*Chunk).init(allocator), 26 | }; 27 | } 28 | pub fn create(self: *@This()) Allocator.Error!*T { 29 | for (self.chunks.span()) |chunk| { 30 | var i: usize = 0; 31 | while (i < chunkSize) : (i += 1) { 32 | if (!chunk.allocated[i]) { 33 | chunk.allocated[i] = true; 34 | //std.debug.warn("[DEBUG] returning existing chunk 0x{x} index {}\n", .{@ptrToInt(&chunk.array[i]), i}); 35 | return &chunk.array[i]; 36 | } 37 | } 38 | } 39 | var newChunk = try self.allocator.create(Chunk); 40 | @memset(@ptrCast([*]u8, &newChunk.allocated), 0, @sizeOf(@TypeOf(newChunk.allocated))); 41 | try self.chunks.append(newChunk); 42 | newChunk.allocated[0] = true; 43 | //std.debug.warn("[DEBUG] returning new chunk 0x{x}\n", .{@ptrToInt(&newChunk.array[0])}); 44 | return &newChunk.array[0]; 45 | } 46 | 47 | pub fn destroy(self: *@This(), ptr: *T) void { 48 | for (self.chunks.span()) |chunk| { 49 | if (@ptrToInt(ptr) <= @ptrToInt(&chunk.array[chunkSize-1]) and 50 | @ptrToInt(ptr) >= @ptrToInt(&chunk.array[0])) { 51 | const diff = @ptrToInt(ptr) - @ptrToInt(&chunk.array[0]); 52 | const index = diff / @sizeOf(T); 53 | std.debug.assert(chunk.allocated[index]); // freed non-allocated pointer 54 | chunk.allocated[index] = false; 55 | // TODO: zero the memory? 
56 | return; 57 | } 58 | } 59 | //std.debug.warn("destroy got invalid address 0x{x}\n", .{@ptrToInt(ptr)}); 60 | std.debug.assert(false); 61 | } 62 | pub fn range(self: *@This()) PoolRange(T, chunkSize) { 63 | return PoolRange(T, chunkSize) { .pool = self, .nextChunkIndex = 0, .nextElementIndex = 0 }; 64 | } 65 | }; 66 | } 67 | 68 | pub fn PoolRange(comptime T: type, chunkSize: usize) type { 69 | return struct { 70 | pool: *Pool(T, chunkSize), 71 | nextChunkIndex: usize, 72 | nextElementIndex: usize, 73 | fn inc(self: *@This()) void { 74 | self.nextElementIndex += 1; 75 | if (self.nextElementIndex == chunkSize) { 76 | self.nextChunkIndex += 1; 77 | self.nextElementIndex = 0; 78 | } 79 | } 80 | pub fn next(self: *@This()) ?*T { 81 | while (true) : (self.inc()) { 82 | if (self.nextChunkIndex >= self.pool.chunks.len) 83 | return null; 84 | var chunk = self.pool.chunks.span()[self.nextChunkIndex]; 85 | if (chunk.allocated[self.nextElementIndex]) { 86 | var result = &chunk.array[self.nextElementIndex]; 87 | self.inc(); 88 | return result; 89 | } 90 | } 91 | } 92 | }; 93 | } 94 | -------------------------------------------------------------------------------- /test: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # manual tests 4 | # 5 | # * connect punch-client-forwarder to a server that doesn't respond like a web server 6 | # 7 | # ./zig-out/bin/punch-client-forwarder 127.0.0.1 80 127.0.0.1 1234 8 | # 9 | # make sure it keeps reconnecting waiting for magic value 10 | # 11 | 12 | set -euo pipefail 13 | 14 | trap 'for job in $(jobs -p); do kill -n 9 $job; done' EXIT 15 | 16 | wait_for_port() { 17 | host=$1 18 | port=$2 19 | attempt=0 20 | # todo: increase sleep time 21 | while true; do 22 | attempt=$((attempt+1)) 23 | if netstat -tln | grep -q ":$port"; then 24 | break 25 | fi 26 | if [ "$attempt" == "6" ]; then 27 | echo "ERROR: Port '$port' on host '$host' did not open after $attempt attempts" 28 | 
return 1 29 | fi 30 | echo "Port '$port' on host '$host' not open on attempt $attempt, waiting..." 31 | sleep 0.2 32 | done 33 | echo "Port '$port' on host '$host' is open ($attempt attempts)" 34 | } 35 | 36 | wait_for_pid() { 37 | tail --pid=$1 -f /dev/null 38 | } 39 | 40 | clean_scratch() { 41 | rm -rf scratch 42 | mkdir scratch 43 | } 44 | clean_scratch 45 | 46 | bin=./zig-out/bin 47 | 48 | if [ -z "${START_PORT+x}" ]; then 49 | start_port=9281 50 | else 51 | start_port=$START_PORT 52 | fi 53 | 54 | port0=$start_port 55 | port1=$(expr $start_port + 1) 56 | port2=$(expr $start_port + 2) 57 | 58 | echo Using Ports: $port0 $port1 $port2 59 | 60 | ###$bin/reverse-tunnel-client 127.0.0.1 > scratch/reverse-tunnel-client.log 2>&1 & 61 | ###reverse_tunnel_client_pid=$! ### 62 | ###$bin/config-server > scratch/config-server1.log 2>&1 & 63 | ###config_server_pid=$! 64 | #### wait for server to start 65 | ###sleep 1 66 | ###kill -n 9 $config_server_pid 67 | #### wait for server to exit 68 | ###sleep 1 69 | ### 70 | ###$bin/config-server > scratch/config-server2.log 2>&1 & 71 | ###config_server_pid=$! 72 | #### wait for server to start 73 | ###sleep 1 74 | ### 75 | 76 | test_punch() { 77 | echo "test_punch" 78 | clean_scratch 79 | $bin/punch-client-forwarder 127.0.0.1 $port0 127.0.0.1 $port2 > scratch/punch-client-forwarder.log 2>&1 & 80 | punch_client_forwarder_pid=$! 81 | 82 | $bin/punch-server-initiator 127.0.0.1 $port0 127.0.0.1 $port1 > scratch/punch-server-initiator.log 2>&1 & 83 | punch_server_initiator_pid=$! 84 | 85 | $bin/nc -l $port2 > scratch/inside-server.log 2>&1 & 86 | inside_server_pid=$! 
87 | wait_for_port 127.0.0.1 $port2 88 | wait_for_port 127.0.0.1 $port1 89 | echo "hello from outside-client" | $bin/nc 127.0.0.1 $port1 > scratch/outside-client.log 2>&1 90 | wait_for_pid $inside_server_pid 91 | grep -q "hello from outside-client" scratch/inside-server.log 92 | } 93 | 94 | test_socat() { 95 | echo "test_socat" 96 | clean_scratch 97 | for i in {1..5}; do 98 | $bin/socat tcp-listen:1280 tcp-listen:1281 > scratch/temp-socat-double-server-$i.log 2>&1 & 99 | socat_double_server_pid=$! 100 | wait_for_port 127.0.0.1 1280 101 | # kill and restart to make sure immediate restart works 102 | kill -n 9 $socat_double_server_pid 103 | done 104 | # disable throttling so the test runs fast 105 | $bin/socat --no-throttle tcp-listen:1280 tcp-listen:1281 > scratch/socat-double-server-main.log 2>&1 & 106 | socat_double_server_pid=$! 107 | wait_for_port 127.0.0.1 1280 108 | for i in {1..20}; do 109 | mkfifo scratch/nc-1280.fifo 110 | cat scratch/nc-1280.fifo | $bin/nc 127.0.0.1 1280 > scratch/nc-1280.log 2>&1 & 111 | wait_for_port 127.0.0.1 1281 112 | $bin/nc 127.0.0.1 1281 > scratch/nc-1281.log 2>&1 & 113 | nc_1281_pid=$! 114 | echo "what" > scratch/nc-1280.fifo 115 | wait_for_pid $nc_1281_pid 116 | grep -q what scratch/nc-1281.log 117 | rm scratch/nc-128* 118 | done 119 | kill -n 9 $socat_double_server_pid 120 | } 121 | 122 | 123 | #$bin/config-server > scratch/config-server1.log 2>&1 & 124 | #config_server_pid=$! 125 | ## wait for server to start 126 | #sleep 1 127 | #kill -n 9 $config_server_pid 128 | ## wait for server to exit 129 | #sleep 1 130 | 131 | #$bin/config-server > scratch/config-server2.log 2>&1 & 132 | #config_server_pid=$! 133 | ## wait for server to start 134 | #sleep 1 135 | 136 | #echo "" | $bin/nc 127.0.0.1 $port0 | hexdump -C 137 | 138 | #kill -n 9 $config_server_pid 139 | 140 | #$bin/double-server > scratch/double-server.log 2>&1 & 141 | #double_server_pid=$! 
142 | ## wait for server to start 143 | #sleep 1 144 | 145 | #echo "" | $bin/nc 127.0.0.1 $port1 146 | 147 | #echo "killing double-server pid=$double_server_pid..." 148 | #kill -n 9 $double_server_pid 149 | 150 | test_punch 151 | test_socat 152 | 153 | echo Success 154 | -------------------------------------------------------------------------------- /old/config-server.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const mem = std.mem; 3 | const os = std.os; 4 | const net = std.net; 5 | 6 | const common = @import("./common.zig"); 7 | const eventing = @import("./eventing.zig"); 8 | const pool = @import("./pool.zig"); 9 | const Pool = pool.Pool; 10 | 11 | const fd_t = os.fd_t; 12 | const Address = net.Address; 13 | const EventFlags = eventing.EventFlags; 14 | 15 | const Fd = struct { 16 | fd: fd_t, 17 | }; 18 | const Eventer = eventing.EventerTemplate(anyerror, struct {}, Fd); 19 | 20 | const global = struct { 21 | var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); 22 | 23 | // TODO: change from 1 to something else 24 | var clientPool = Pool(Client, 1).init(&arena.allocator); 25 | var config = [_]u8 { 26 | 11, // msg size 27 | 3, // add endpoint 28 | 0, 0, 0, 0, // connection ID (one connection per ID) 29 | 96, 19, 192, 252, // ip address 30 | 0xFF & (9282 >> 8), // port 31 | 0xFF & (9282 >> 0), 32 | }; 33 | }; 34 | 35 | const Client = struct { 36 | callback: Eventer.Callback, 37 | }; 38 | 39 | fn callbackToClient(callback: *Eventer.Callback) *Client { 40 | return @ptrCast(*Client, 41 | @ptrCast([*]u8, callback) - @byteOffsetOf(Client, "callback") 42 | ); 43 | } 44 | 45 | pub fn main() anyerror!u8 { 46 | var eventer = try Eventer.init(.{}); 47 | 48 | var serverCallback = initServer: { 49 | const port : u16 = 9281; // picked a random one 50 | const sockfd = try common.makeListenSock(&Address.initIp4([4]u8 {0, 0, 0, 0}, port)); 51 | std.debug.warn("[DEBUG] server socket is {}\n", 
.{sockfd}); 52 | break :initServer Eventer.Callback { 53 | .func = onAccept, 54 | .data = Fd { .fd = sockfd }, 55 | }; 56 | }; 57 | try eventer.add(serverCallback.data.fd, EventFlags.read, &serverCallback); 58 | try eventer.loop(); 59 | return 0; 60 | } 61 | 62 | fn onAccept(eventer: *Eventer, callback: *Eventer.Callback) anyerror!void { 63 | const fd = callback.data.fd; 64 | std.debug.warn("accepting client on socket {}!\n", .{fd}); 65 | 66 | var addr : Address = undefined; 67 | var addrlen : os.socklen_t = @sizeOf(@TypeOf(addr)); 68 | 69 | const newsockfd = try os.accept(fd, &addr.any, &addrlen, os.SOCK_NONBLOCK); 70 | errdefer common.shutdownclose(newsockfd); 71 | std.debug.warn("got new client {} from {}\n", .{newsockfd, addr}); 72 | 73 | // can add a client/server handshake/auth but for now I'm just going to send the config 74 | // send the config 75 | { 76 | std.debug.warn("s={} sending {}-byte config...\n", .{newsockfd, global.config.len}); 77 | const sendResult = os.send(newsockfd, &global.config, 0) catch |e| { 78 | std.debug.warn("s={} send initial config of {}-bytes failed: {}\n", .{newsockfd, global.config.len, e}); 79 | common.shutdownclose(newsockfd); 80 | return; 81 | }; 82 | if (sendResult != global.config.len) { 83 | std.debug.warn("s={} failed to send {}-byte initial config, returned {}\n", .{newsockfd, global.config.len, sendResult}); 84 | common.shutdownclose(newsockfd); 85 | return; 86 | } 87 | } 88 | 89 | { 90 | var newClient = try global.clientPool.create(); 91 | errdefer global.clientPool.destroy(newClient); 92 | //std.debug.warn("[DEBUG] new client at 0x{x}\n", .{@ptrToInt(newClient)}); 93 | newClient.* = Client { 94 | .callback = Eventer.Callback { 95 | .func = onClientData, 96 | .data = Fd { .fd = newsockfd }, 97 | }, 98 | }; 99 | try eventer.add(newsockfd, EventFlags.read, &newClient.callback); 100 | // we've now tranferred ownership of newClient, do not free it here, even on errors 101 | } 102 | } 103 | 104 | fn removeClient(eventer: 
*Eventer, callback: *Eventer.Callback) !void { 105 | eventer.remove(callback.data.fd); 106 | common.shutdownclose(callback.data.fd); 107 | global.clientPool.destroy(callbackToClient(callback)); 108 | } 109 | 110 | fn onClientData(eventer: *Eventer, callback: *Eventer.Callback) anyerror!void { 111 | const fd = callback.data.fd; 112 | //std.debug.warn("got data on socket {}!\n", .{fd}); 113 | var buffer: [100]u8 = undefined; 114 | const len = os.read(fd, &buffer) catch |e| { 115 | try removeClient(eventer, callback); 116 | return e; 117 | }; 118 | if (len == 0) { 119 | std.debug.warn("client {} closed because read returned 0\n", .{fd}); 120 | try removeClient(eventer, callback); 121 | return; 122 | } 123 | std.debug.warn("[DEBUG] got {} bytes from socket {}\n", .{len, fd}); 124 | } 125 | -------------------------------------------------------------------------------- /eventing/epoll.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const os = std.os; 3 | 4 | const logging = @import("../logging.zig"); 5 | const common = @import("./common.zig"); 6 | 7 | const eventing = @import("../eventing.zig"); 8 | const EventerOptions = eventing.EventerOptions; 9 | 10 | const fd_t = os.fd_t; 11 | const log = logging.log; 12 | 13 | pub const EventFlags = struct { 14 | pub const read = os.linux.EPOLL.IN; 15 | pub const write = os.linux.EPOLL.OUT; 16 | pub const hangup = os.linux.EPOLL.RDHUP; 17 | }; 18 | 19 | pub fn EventerTemplate(comptime options: EventerOptions) type { 20 | return struct { 21 | pub const Fd = fd_t; 22 | pub const Data = options.Data; 23 | pub const CallbackError = options.CallbackError; 24 | pub const Callback = struct { 25 | func: CallbackFn, 26 | data: options.CallbackData, 27 | pub fn init(func: CallbackFn, data: options.CallbackData) @This() { 28 | return @This() { 29 | .func = func, 30 | .data = data, 31 | }; 32 | } 33 | }; 34 | pub const CallbackFn = fn(server: *@This(), callback: 
*Callback) CallbackError!void; 35 | 36 | /// data that can be shared between all callbacks 37 | data: Data, 38 | epollfd: fd_t, 39 | ownEpollFd: bool, 40 | pub fn init(data: Data) !@This() { 41 | return @This().initEpoll(data, try os.epoll_create1(0), true); 42 | } 43 | pub fn initEpoll(data: Data, epollfd: fd_t, ownEpollFd: bool) @This() { 44 | return @This() { 45 | .data = data, 46 | .epollfd = epollfd, 47 | .ownEpollFd = ownEpollFd, 48 | }; 49 | } 50 | pub fn deinit(self: *@This()) void { 51 | if (self.ownEpollFd) 52 | os.close(self.epollfd); 53 | } 54 | 55 | pub fn add(self: *@This(), fd: fd_t, flags: u32, callback: *Callback) common.EventerAddError!void { 56 | var event = os.linux.epoll_event { 57 | .events = flags, 58 | .data = os.linux.epoll_data { .ptr = @ptrToInt(callback) }, 59 | }; 60 | try os.epoll_ctl(self.epollfd, os.linux.EPOLL.CTL_ADD, fd, &event); 61 | } 62 | pub fn modify(self: *@This(), fd: fd_t, flags: u32, callback: *Callback) common.EventerModifyError!void { 63 | var event = os.linux.epoll_event { 64 | .events = flags, 65 | .data = os.linux.epoll_data { .ptr = @ptrToInt(callback) }, 66 | }; 67 | try os.epoll_ctl(self.epollfd, os.linux.EPOLL.CTL_MOD, fd, &event); 68 | } 69 | 70 | pub fn remove(self: *@This(), fd: fd_t) void { 71 | // TODO: kernels before 2.6.9 had a bug where event must be non-null 72 | os.epoll_ctl(self.epollfd, os.linux.EPOLL.CTL_DEL, fd, null) catch |e| switch (e) { 73 | error.FileDescriptorNotRegistered // we could ignore this, but this represents a code bug 74 | ,error.FileDescriptorAlreadyPresentInSet 75 | ,error.FileDescriptorIncompatibleWithEpoll 76 | ,error.OperationCausesCircularLoop 77 | ,error.SystemResources // this should never happen during removal 78 | ,error.UserResourceLimitReached // this should never happen during removal 79 | ,error.Unexpected 80 | => std.debug.panic("epoll_ctl DEL failed with {}", .{e}), 81 | }; 82 | } 83 | 84 | // returns: false if there was a timeout 85 | fn handleEventsGeneric(self: 
*@This(), timeoutMillis: i32) CallbackError!bool { 86 | // get 1 event at a time to prevent stale events 87 | var events : [1]os.linux.epoll_event = undefined; 88 | const count = os.epoll_wait(self.epollfd, &events, timeoutMillis); 89 | const errno = os.errno(count); 90 | switch (errno) { 91 | .SUCCESS => {}, 92 | .BADF 93 | ,.FAULT 94 | ,.INTR 95 | ,.INVAL 96 | => std.debug.panic("epoll_wait failed with {}", .{errno}), 97 | else => std.debug.panic("epoll_wait failed with {}", .{errno}), 98 | } 99 | if (count == 0) 100 | return false; // timeout 101 | for (events[0..count]) |event| { 102 | const callback = @intToPtr(*Callback, event.data.ptr); 103 | try callback.func(self, callback); 104 | } 105 | return true; // was not a timeout 106 | } 107 | pub fn handleEvents(self: *@This(), timeoutMillis: u32) CallbackError!bool { 108 | return self.handleEventsGeneric(@intCast(i32, timeoutMillis)); 109 | } 110 | 111 | pub fn handleEventsNoTimeout(self: *@This()) CallbackError!void { 112 | if (!try self.handleEventsGeneric(-1)) 113 | std.debug.panic("epoll returned 0 with ifinite timeout?", .{}); 114 | } 115 | // a convenient helper method, might remove this 116 | // TODO: should only return CallbackError, not CallbackError!void 117 | pub fn loop(self: *@This()) CallbackError!void { 118 | while (true) { 119 | try self.handleEventsNoTimeout(); 120 | } 121 | } 122 | }; 123 | } 124 | 125 | pub fn epoll_create1(flags: u32) !fd_t { 126 | return os.epoll_create1(flags) catch |e| switch (e) { 127 | error.SystemFdQuotaExceeded 128 | ,error.ProcessFdQuotaExceeded 129 | ,error.SystemResources 130 | => { 131 | log("epoll_create1 failed with {}", .{e}); 132 | return error.Retry; 133 | }, 134 | error.Unexpected 135 | => std.debug.panic("epoll_create1 failed with {}", .{e}), 136 | }; 137 | } 138 | -------------------------------------------------------------------------------- /proxy.zig: -------------------------------------------------------------------------------- 1 | const std = 
@import("std");
const mem = std.mem;
const os = std.os;

const common = @import("./common.zig");

const assert = std.debug.assert;
const fd_t = os.fd_t;
const socket_t = os.socket_t;

// limits used to size the HTTP CONNECT request buffer
const MAX_HOST = 253;
const MAX_PORT_DIGITS = 5;

/// A connection strategy: either connect directly (.None) or tunnel
/// through an HTTP CONNECT proxy (.Http).
pub const Proxy = union(enum) {
    None: void,
    Http: Http,

    pub const Http = struct {
        host: []const u8,
        port: u16,
    };

    // TODO: the HTTP protocol means that we could read data from the target
    // server during negotiation, so this function would need to support
    // returning any extra data received from the target server
    /// Connect to host:port, going through the proxy when one is configured.
    /// Returns the connected socket; on proxy-negotiation failure the socket
    /// is shut down and closed before the error is returned.
    pub fn connectHost(self: *const @This(), host: []const u8, port: u16) !socket_t {
        std.debug.assert(host.len <= MAX_HOST);

        const http = switch (self.*) {
            .None => return common.connectHost(host, port),
            .Http => |http| http,
        };
        // connect to the proxy itself, then ask it to tunnel to host:port
        const sockfd = try common.connectHost(http.host, http.port);
        errdefer common.shutdownclose(sockfd);
        try sendHttpConnect(sockfd, host, port);
        try receiveHttpOk(sockfd, 10000);
        return sockfd;
    }

    /// Structural equality: same variant, and for .Http the same port and
    /// byte-equal host string.
    pub fn eql(self: *const @This(), other: *const @This()) bool {
        switch (self.*) {
            .None => return other.* == .None,
            .Http => |a| switch (other.*) {
                .Http => |b| return a.port == b.port and
                    std.mem.eql(u8, a.host, b.host),
                else => return false,
            },
        }
    }

    /// std.fmt integration: prints nothing for .None and
    /// "http://HOST:PORT/" for .Http.
    pub fn format(
        self: Proxy,
        comptime fmt: []const u8,
        options: std.fmt.FormatOptions,
        out_stream: anytype,
    ) !void {
        _ = fmt;
        _ = options;
        switch (self) {
            .None => {},
            .Http => |http| try std.fmt.format(out_stream, "http://{s}:{}/", .{http.host, http.port}),
        }
    }
};

pub fn sendHttpConnect(sockfd: socket_t, host: []const u8, port: u16) !void {
    const PART1 = "CONNECT ";
const PART2 = " HTTP/1.1\r\nHost: ";
    const PART3 = "\r\n\r\n";
    // worst-case request size: both "{s}:{}" holes expand to at most
    // MAX_HOST + ':' + MAX_PORT_DIGITS bytes
    const MAX_CONNECT_REQUEST =
           PART1.len
         + MAX_HOST + 1 + MAX_PORT_DIGITS
         + PART2.len
         + MAX_HOST + 1 + MAX_PORT_DIGITS
         + PART3.len;
    var requestBuffer : [MAX_CONNECT_REQUEST]u8 = undefined;
    const request = std.fmt.bufPrint(&requestBuffer,
        PART1 ++ "{s}:{}" ++ PART2 ++ "{s}:{}" ++ PART3,
        .{host, port, host, port}) catch |e| switch (e) {
        error.NoSpaceLeft
        // fixed typo in panic message: "requeset" -> "request"
        => std.debug.panic("code bug: HTTP CONNECT request buffer {} not big enough", .{MAX_CONNECT_REQUEST}),
    };
    try common.sendfull(sockfd, request, 0);
}

const Http200Response = "HTTP/1.1 200";
const HttpEndResponse = "\r\n\r\n";

/// Reads the proxy's reply to a CONNECT request one byte at a time.
/// Succeeds only if the status line starts with "HTTP/1.1 200"; then keeps
/// reading until the blank line ("\r\n\r\n") that ends the response headers.
/// Errors: HttpProxyUnexpectedReply on a non-200 status,
/// HttpProxyDisconnectedDurringReply if the proxy closes mid-reply.
pub fn receiveHttpOk(sockfd: fd_t, readTimeoutMillis: i32) !void {
    // TODO: I must implement a reasonable timeout
    // to prevent waiting forever if I never get \r\n\r\n
    _ = readTimeoutMillis;
    const State = union(enum) {
        Reading200: u8,   // count of Http200Response bytes matched so far
        ReadingToEnd: u8, // count of HttpEndResponse bytes matched so far
    };
    var buf: [1]u8 = undefined;
    var state = State { .Reading200 = 0 };
    while (true) {
        // TODO: read with a timeout
        const received = try os.read(sockfd, &buf);
        if (received == 0)
            return error.HttpProxyDisconnectedDurringReply;
        switch (state) {
            .Reading200 => |left| {
                if (buf[0] != Http200Response[left])
                    return error.HttpProxyUnexpectedReply;
                state.Reading200 += 1;
                if (state.Reading200 == Http200Response.len)
                    state = State { .ReadingToEnd = 0 };
            },
            .ReadingToEnd => |matched| {
                if (buf[0] == HttpEndResponse[matched]) {
                    state.ReadingToEnd += 1;
                    if (state.ReadingToEnd == HttpEndResponse.len)
                        return; // success
                } else {
                    // BUG FIX: on a mismatch the current byte may itself begin
                    // a new "\r\n\r\n" sequence (e.g. a stray '\r' right after
                    // a partial match); previously it was always discarded,
                    // which could skip past the real terminator.
                    state.ReadingToEnd = if (buf[0] == HttpEndResponse[0]) 1 else 0;
                }
            },
        }
    }
}


pub const HostAndProxy = struct {
    host: []const u8,
proxy: Proxy,

    /// True when both host strings are byte-equal and the proxies compare
    /// equal via Proxy.eql.
    pub fn eql(self: *const @This(), other: *const @This()) bool {
        return std.mem.eql(u8, self.host, other.host) and
            self.proxy.eql(&other.proxy);
    }
};

pub fn parseProxy(connectSpec: anytype) !HostAndProxy {
    return parseProxyTyped(@TypeOf(connectSpec), connectSpec);
}

/// Parses a connect spec. "http://PROXYHOST:PROXYPORT/HOST" yields HOST plus
/// an HTTP proxy; any other spec is returned as-is with Proxy.None.
/// The returned slices alias connectSpec (no allocation).
pub fn parseProxyTyped(comptime String: type, connectSpec: String) !HostAndProxy {
    var rest = connectSpec;
    if (common.skipOver(String, &rest, "http://")) {
        const slashIndex = mem.indexOfScalar(u8, rest, '/') orelse
            return error.MissingSlashToDelimitProxy;
        const host = rest[slashIndex + 1..];
        if (host.len == 0)
            return error.NoHostAfterProxy;
        const proxyHostPort = rest[0 .. slashIndex];
        const proxyColonIndex = mem.indexOfScalar(u8, proxyHostPort, ':') orelse
            return error.ProxyMissingPort;
        const proxyHost = proxyHostPort[0 .. proxyColonIndex];
        if (proxyHost.len == 0)
            return error.ProxyMissingHost;
        const proxyPortString = proxyHostPort[proxyColonIndex+1..];
        // BUG FIX: this used to re-check proxyHost.len, so an empty port
        // string ("http://host:/...") fell through to parseInt and came
        // back as ProxyPortNotNumber instead of ProxyMissingPort.
        if (proxyPortString.len == 0)
            return error.ProxyMissingPort;
        const proxyPort = std.fmt.parseInt(u16, proxyPortString, 10) catch |e| switch (e) {
            error.Overflow => return error.ProxyPortOutOfRange,
            error.InvalidCharacter => return error.ProxyPortNotNumber,
        };
        return HostAndProxy {
            .host = host,
            .proxy = Proxy { .Http = .{
                .host = proxyHost,
                .port = proxyPort,
            }},
        };
    }
    return HostAndProxy { .host = connectSpec, .proxy = Proxy.None };
}

test "parseProxy" {
    assert((HostAndProxy {
        .host = "a",
        .proxy = Proxy.None,
    }).eql(&try parseProxyTyped([]const u8, "a")));
    assert((HostAndProxy {
        .host = "hey",
        .proxy = Proxy { .Http = .{.host = "what.com", .port = 1234} },
    }).eql(&try parseProxyTyped([]const u8, "http://what.com:1234/hey")));
}
-------------------------------------------------------------------------------- /nc.zig: -------------------------------------------------------------------------------- 1 | // a simple version of netcat for testing 2 | // this is created so we have a common implementation for things like "CLOSE ON EOF" 3 | // 4 | // TODO: is it worth it to support the sendfile syscall variation? 5 | // maybe not since it will make this more complicated and its 6 | // main purpose is just for testing 7 | // 8 | const builtin = @import("builtin"); 9 | const std = @import("std"); 10 | const mem = std.mem; 11 | const os = std.os; 12 | const net = std.net; 13 | 14 | const common = @import("./common.zig"); 15 | const eventing = @import("./eventing.zig").default; 16 | 17 | const fd_t = os.fd_t; 18 | const Address = net.Address; 19 | const EventFlags = eventing.EventFlags; 20 | const Eventer = eventing.EventerTemplate(.{}); 21 | 22 | const global = struct { 23 | var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); 24 | var stdout : fd_t = undefined; 25 | var stdin : fd_t = undefined; 26 | var sockfd : fd_t = undefined; 27 | var buffer : [8192]u8 = undefined; 28 | }; 29 | 30 | fn usage() void { 31 | std.debug.warn("Usage: nc [-l PORT]\n", .{}); 32 | std.debug.warn(" nc [-z] HOST PORT\n", .{}); 33 | std.debug.warn(" -z Scan for open port without sending data\n", .{}); 34 | } 35 | 36 | pub fn main() anyerror!u8 { 37 | global.stdout = std.io.getStdOut().handle; 38 | global.stdin = std.io.getStdIn().handle; 39 | 40 | var args = try std.process.argsAlloc(&global.arena.allocator); 41 | if (args.len <= 1) { 42 | usage(); 43 | return 1; 44 | } 45 | args = args[1..]; 46 | 47 | var optionalListenPort : ?u16 = null; 48 | var portScan = false; 49 | { 50 | var newArgsLen : usize = 0; 51 | defer args = args[0..newArgsLen]; 52 | var i : usize = 0; 53 | while (i < args.len) : (i += 1) { 54 | const arg = args[i]; 55 | if (!std.mem.startsWith(u8, arg, "-")) { 56 | args[newArgsLen] = arg; 57 | 
newArgsLen += 1; 58 | } else if (std.mem.eql(u8, arg, "-l")) { 59 | optionalListenPort = common.parsePort(common.getOptArg(args, &i) catch return 1) catch return 1; 60 | } else if (std.mem.eql(u8, arg, "-z")) { 61 | portScan = true; 62 | } else { 63 | std.debug.warn("Error: unknown command-line option '{s}'\n", .{arg}); 64 | return 1; 65 | } 66 | } 67 | } 68 | 69 | global.sockfd = initSock: { 70 | if (optionalListenPort) |listenPort| { 71 | if (args.len != 0) { 72 | usage(); 73 | return 1; 74 | } 75 | if (portScan) { 76 | std.debug.warn("Error: '-z' (port scan) is not compatible with '-l PORT'\n", .{}); 77 | return 1; 78 | } 79 | var addr = Address.initIp4([4]u8{0,0,0,0}, listenPort); 80 | const listenFd = try os.socket(addr.any.family, os.SOCK.STREAM, os.IPPROTO.TCP); 81 | defer os.close(listenFd); 82 | if (builtin.os.tag != .windows) { 83 | try os.setsockopt(listenFd, os.SOL.SOCKET, os.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); 84 | } 85 | try os.bind(listenFd, &addr.any, addr.getOsSockLen()); 86 | try os.listen(listenFd, 1); 87 | std.debug.warn("[NC] listening on {}...\n", .{addr}); 88 | var clientAddr : Address = undefined; 89 | var clientAddrLen : os.socklen_t = @sizeOf(@TypeOf(clientAddr)); 90 | const clientFd = try os.accept(listenFd, &clientAddr.any, &clientAddrLen, 0); 91 | std.debug.warn("[NC] accepted client {}\n", .{clientAddr}); 92 | break :initSock clientFd; 93 | } else { 94 | if (args.len != 2) { 95 | usage(); 96 | return 1; 97 | } 98 | const hostString = args[0]; 99 | const portString = args[1]; 100 | const port = common.parsePort(portString) catch return 1; 101 | const addr = Address.parseIp4(hostString, port) catch |e| { 102 | std.debug.warn("Error: failed to parse '{s}' as an IPv4 address: {}\n", .{hostString, e}); 103 | return 1; 104 | }; 105 | std.debug.warn("[NC] connecting to {}...\n", .{addr}); 106 | // tcpConnectToHost is not working 107 | //break :initSock net.tcpConnectToHost(&global.arena.allocator, "localhost", 9282)).handle; 108 | 
const sockFile = try net.tcpConnectToAddress(addr); 109 | std.debug.warn("[NC] connected\n", .{}); 110 | if (portScan) { 111 | try common.shutdown(sockFile.handle); 112 | os.close(sockFile.handle); 113 | return 0; 114 | } 115 | break :initSock sockFile.handle; 116 | } 117 | }; 118 | 119 | var eventer = try Eventer.init(.{}); 120 | var sockCallback = Eventer.Callback { 121 | .func = onSockData, 122 | .data = .{}, 123 | }; 124 | try eventer.add(global.sockfd, EventFlags.read, &sockCallback); 125 | 126 | var stdinCallback = Eventer.Callback { 127 | .func = onStdinData, 128 | .data = .{}, 129 | }; 130 | eventer.add(global.stdin, EventFlags.read, &stdinCallback) catch |e| switch (e) { 131 | error.FileDescriptorIncompatibleWithEpoll => { 132 | std.debug.warn("[NC] stdin appears to be closed, will ignore it\n", .{}); 133 | }, 134 | else => return e, 135 | }; 136 | try eventer.loop(); 137 | return 0; 138 | } 139 | 140 | fn sockDisconnected() noreturn { 141 | // we can't shutdown stdin, so the only thing to do once 142 | // the socket is disconnected is to exit 143 | os.exit(0); // nothing else we can do 144 | } 145 | 146 | fn onSockData(eventer: *Eventer, callback: *Eventer.Callback) anyerror!void { 147 | _ = eventer; 148 | _ = callback; 149 | // TODO: I should use the sendfile syscall if available 150 | const length = os.read(global.sockfd, &global.buffer) catch |e| { 151 | std.debug.warn("[NC] s={} read failed: {}\n", .{global.sockfd, e}); 152 | sockDisconnected(); 153 | }; 154 | if (length == 0) { 155 | std.debug.warn("[NC] s={} disconnected\n", .{global.sockfd}); 156 | sockDisconnected(); 157 | } 158 | if (common.tryWriteAll(global.stdout, global.buffer[0..length])) |result| { 159 | std.debug.warn("[NC] s={} write failed with {}, wrote {} bytes out of {}", .{global.sockfd, result.err, result.wrote, length}); 160 | return error.StdoutClosed; 161 | } 162 | } 163 | 164 | fn stdinClosed(eventer: *Eventer) !void { 165 | eventer.remove(global.stdin); 166 | 
os.close(global.stdin); // TODO: I don't have to close stdin...should I? 167 | try common.shutdown(global.sockfd); 168 | } 169 | 170 | fn onStdinData(eventer: *Eventer, callback: *Eventer.Callback) anyerror!void { 171 | _ = callback; 172 | // TODO: I should use the sendfile syscall if available 173 | const length = os.read(global.stdin, &global.buffer) catch |e| { 174 | std.debug.warn("[NC] stdin read failed: {}\n", .{e}); 175 | try stdinClosed(eventer); 176 | return; 177 | }; 178 | if (length == 0) { 179 | std.debug.warn("[NC] stdin EOF\n", .{}); 180 | try stdinClosed(eventer); 181 | return; 182 | } 183 | try common.sendfull(global.sockfd, global.buffer[0..length], 0); 184 | } 185 | -------------------------------------------------------------------------------- /double-server.zig: -------------------------------------------------------------------------------- 1 | const builtin = @import("builtin"); 2 | const std = @import("std"); 3 | const mem = std.mem; 4 | const os = std.os; 5 | const net = std.net; 6 | 7 | const common = @import("./common.zig"); 8 | const eventing = @import("./eventing.zig").default; 9 | 10 | const fd_t = os.fd_t; 11 | const Address = net.Address; 12 | const EventFlags = eventing.EventFlags; 13 | const Eventer = eventing.EventerTemplate(.{}); 14 | 15 | const INVALID_FD = if(builtin.os.tag == .windows) std.os.windows.ws2_32.INVALID_SOCKET 16 | else -1; 17 | 18 | const Client = struct { 19 | fd: fd_t, 20 | callback: Eventer.Callback, 21 | }; 22 | 23 | const global = struct { 24 | var listenFd : fd_t = undefined; 25 | var clientA : Client = undefined; 26 | var clientB : Client = undefined; 27 | 28 | // client's are 'linked' once they have sent data to each-other 29 | // if client's are linked, then when one closes it will cause 30 | // the other to close 31 | var clientsLinked : bool = undefined; 32 | 33 | var buffer : [8192]u8 = undefined; 34 | }; 35 | 36 | fn callbackToClient(callback: *Eventer.Callback) *Client { 37 | if (callback == 
&global.clientA.callback) 38 | return &global.clientA; 39 | std.debug.assert(callback == &global.clientB.callback); // code bug if false 40 | return &global.clientB; 41 | } 42 | 43 | pub fn main() anyerror!u8 { 44 | global.clientA.fd = INVALID_FD; 45 | global.clientB.fd = INVALID_FD; 46 | var eventer = try Eventer.init(.{}); 47 | 48 | var serverCallback = initServer: { 49 | const port : u16 = 9282; 50 | global.listenFd = try common.makeListenSock(&Address.initIp4([4]u8 {0, 0, 0, 0}, port)); 51 | std.debug.warn("[DEBUG] server socket is {}\n", .{global.listenFd}); 52 | break :initServer Eventer.Callback { 53 | .func = onAccept, 54 | .data = .{}, 55 | }; 56 | }; 57 | try eventer.add(global.listenFd, EventFlags.read, &serverCallback); 58 | try eventer.loop(); 59 | return 0; 60 | } 61 | 62 | fn onAccept(eventer: *Eventer, callback: *Eventer.Callback) anyerror!void { 63 | _ = callback; 64 | var addr : Address = undefined; 65 | var addrlen : os.socklen_t = @sizeOf(@TypeOf(addr)); 66 | 67 | const newsockfd = try os.accept(global.listenFd, &addr.any, &addrlen, os.SOCK.NONBLOCK); 68 | errdefer common.shutdownclose(newsockfd); 69 | 70 | const ClientInfos = struct { newClient: *Client, otherClient: *Client }; 71 | var info = clientInit: { 72 | if (global.clientA.fd == INVALID_FD) 73 | break :clientInit ClientInfos {.newClient=&global.clientA, .otherClient=&global.clientB}; 74 | if (global.clientB.fd == INVALID_FD) 75 | break :clientInit ClientInfos {.newClient=&global.clientB, .otherClient=&global.clientA}; 76 | 77 | std.debug.warn("s={} closing connection from {}, already have 2 clients\n", .{newsockfd, addr}); 78 | common.shutdownclose(newsockfd); 79 | return; 80 | }; 81 | 82 | std.debug.warn("s={} new client from {}\n", .{newsockfd, addr}); 83 | errdefer info.newClient.fd = INVALID_FD; 84 | if (info.otherClient.fd == INVALID_FD or info.otherClient.callback.func == onDataClosing) { 85 | info.newClient.* = Client { 86 | .fd = newsockfd, 87 | .callback = Eventer.Callback { 88 
| .func = onDataOneClient, 89 | .data = .{}, 90 | }, 91 | }; 92 | try eventer.add(newsockfd, EventFlags.hangup, &info.newClient.callback); 93 | } else { 94 | std.debug.assert(info.otherClient.callback.func == onDataOneClient); 95 | info.newClient.* = Client { 96 | .fd = newsockfd, 97 | .callback = Eventer.Callback { 98 | .func = onDataTwoClients, 99 | .data = .{}, 100 | }, 101 | }; 102 | try eventer.add(newsockfd, EventFlags.read, &info.newClient.callback); 103 | 104 | try eventer.modify(info.otherClient.fd, EventFlags.read, &info.otherClient.callback); 105 | info.otherClient.callback.func = onDataTwoClients; 106 | 107 | // initialize this because now we're using onDataTwoClients which is where it 108 | // will be used 109 | global.clientsLinked = false; 110 | } 111 | } 112 | 113 | // because we aren't listening for data, this should only be called if the socket has been closed 114 | fn onDataOneClient(eventer: *Eventer, callback: *Eventer.Callback) anyerror!void { 115 | const clientRef = callbackToClient(callback); 116 | std.debug.warn("s={} connection closed\n", .{clientRef.fd}); 117 | std.debug.assert(clientRef.fd != INVALID_FD); 118 | eventer.remove(clientRef.fd); 119 | common.shutdownclose(clientRef.fd); 120 | clientRef.fd = INVALID_FD; 121 | } 122 | 123 | // in this callback, you can assume that both clients are valid and have this callback 124 | fn onDataTwoClients(eventer: *Eventer, callback: *Eventer.Callback) anyerror!void { 125 | std.debug.assert(global.clientA.fd != INVALID_FD and global.clientB.fd != INVALID_FD); 126 | std.debug.assert(global.clientA.callback.func == onDataTwoClients); 127 | std.debug.assert(global.clientB.callback.func == onDataTwoClients); 128 | 129 | if (callback == &global.clientA.callback) { 130 | try forward(eventer, &global.clientA, &global.clientB); 131 | } else { 132 | std.debug.assert(callback == &global.clientB.callback); // code bug if false 133 | try forward(eventer, &global.clientB, &global.clientA); 134 | } 135 | } 136 | 
137 | fn in_forward_from_closed(eventer: *Eventer, from: *Client, to: *Client) !void { 138 | eventer.remove(from.fd); 139 | os.close(from.fd); 140 | from.fd = INVALID_FD; 141 | 142 | if (!global.clientsLinked) { 143 | std.debug.warn("client's weren't linked, back to one-client s={}\n", .{to.fd}); 144 | // clients haven't sent any data so keep the other client open 145 | to.callback.func = onDataOneClient; 146 | try eventer.modify(to.fd, EventFlags.hangup, &to.callback); 147 | } else { 148 | std.debug.warn("client's were linked, disconnecting s={}\n", .{to.fd}); 149 | to.callback.func = onDataClosing; 150 | try common.shutdown(to.fd); 151 | } 152 | } 153 | 154 | fn forward(eventer: *Eventer, from: *Client, to: *Client) !void { 155 | const length = os.read(from.fd, &global.buffer) catch |e| { 156 | std.debug.warn("s={} read failed: {}\n", .{from.fd, e}); 157 | try in_forward_from_closed(eventer, from, to); 158 | return; 159 | }; 160 | if (length == 0) { 161 | std.debug.warn("s={} disconnected\n", .{from.fd}); 162 | try in_forward_from_closed(eventer, from, to); 163 | return; 164 | } 165 | std.debug.warn("s={} forwarding {} bytes to {}\n", .{from.fd, length, to.fd}); 166 | const sendResult = os.send(to.fd, global.buffer[0..length], 0) catch |e| { 167 | std.debug.warn("send on {} failed: {}\n", .{to.fd, e}); 168 | std.debug.warn("TODO: implement cleanup...\n", .{}); 169 | return error.NotImplemented; 170 | }; 171 | if (sendResult != length) { 172 | std.debug.warn("only sent {} out of {} on {}\n", .{sendResult, length, to.fd}); 173 | std.debug.warn("TODO: implement something here...\n", .{}); 174 | return error.NotImplemented; 175 | } 176 | if (!global.clientsLinked) { 177 | std.debug.warn("s={} linked to s={}\n", .{from.fd, to.fd}); 178 | global.clientsLinked = true; 179 | } 180 | } 181 | 182 | fn onDataClosing(eventer: *Eventer, callback: *Eventer.Callback) anyerror!void { 183 | const clientRef = callbackToClient(callback); 184 | std.debug.warn("s={} finishing 
close\n", .{clientRef.fd}); 185 | eventer.remove(clientRef.fd); 186 | os.close(clientRef.fd); 187 | clientRef.fd = INVALID_FD; 188 | } 189 | -------------------------------------------------------------------------------- /eventing/select.zig: -------------------------------------------------------------------------------- 1 | const builtin = @import("builtin"); 2 | const std = @import("std"); 3 | const os = std.os; 4 | 5 | const common = @import("./common.zig"); 6 | 7 | const eventing = @import("../eventing.zig"); 8 | const EventerOptions = eventing.EventerOptions; 9 | 10 | const platform = struct { 11 | usingnamespace if (builtin.os.tag == .windows) 12 | @import("./selectwindows.zig") 13 | else 14 | @import("./selectnotwindows.zig"); 15 | }; 16 | // just hardcode to 64 for now 17 | const fd_set = platform.fd_set(64); 18 | 19 | pub const EventFlags = struct { 20 | pub const read = 0x01; 21 | pub const write = 0x02; 22 | pub const hangup = 0x04; 23 | }; 24 | 25 | // TODO: allow various backend-specific options like select fd capacity 26 | // TODO: add Eventer reference to EventerData rather 27 | // than passing it by default 28 | // some programs only have 1 eventer and don't need to 29 | // pass it as an argument 30 | pub fn EventerTemplate(comptime options: EventerOptions) type { 31 | return struct { 32 | pub const Fd = platform.fd_t; 33 | pub const Data = options.Data; 34 | pub const CallbackError = options.CallbackError; 35 | pub const CallbackFn = fn(server: *@This(), callback: *Callback) CallbackError!void; 36 | pub const Callback = struct { 37 | func: CallbackFn, 38 | data: options.CallbackData, 39 | pub fn init(func: CallbackFn, data: options.CallbackData) @This() { 40 | return @This() { 41 | .func = func, 42 | .data = data, 43 | }; 44 | } 45 | }; 46 | const FdInfo = struct { 47 | fd: platform.fd_t, 48 | flags: u32, 49 | callback: *Callback, 50 | }; 51 | const CountType = u8; 52 | 53 | /// data that can be shared between all callbacks 54 | data: Data, 55 | 
fdlist: [64]FdInfo,
        fdcount: CountType,

        pub fn init(data: Data) !@This() {
            var this : @This() = undefined;
            this.data = data;
            this.fdcount = 0; // fdlist entries are valid only up to fdcount
            return this;
        }

        /// Returns the index of fd in fdlist, or null if it isn't registered.
        fn find(self: @This(), fd: platform.fd_t) ?CountType {
            var i : CountType = 0;
            while (i < self.fdcount) : (i += 1) {
                if (self.fdlist[i].fd == fd)
                    return i;
            }
            return null;
        }

        pub fn add(self: *@This(), fd: platform.fd_t, flags: u32, callback: *Callback) common.EventerAddError!void {
            if (self.fdcount == self.fdlist.len)
                return error.UserResourceLimitReached;
            std.debug.assert(self.find(fd) == null);
            self.fdlist[self.fdcount] = .{ .fd = fd, .flags = flags, .callback = callback };
            self.fdcount += 1;
        }
        pub fn modify(self: *@This(), fd: platform.fd_t, flags: u32, callback: *Callback) common.EventerAddError!void {
            if (self.find(fd)) |i| {
                self.fdlist[i].flags = flags;
                self.fdlist[i].callback = callback;
            } else return error.SocketNotAddedToEventer;
        }
        pub fn remove(self: *@This(), fd: platform.fd_t) void {
            if (self.find(fd)) |i| {
                // BUG FIX: this compaction loop never incremented j, so
                // removing any entry that was not the last one spun forever.
                var j = i;
                while (j + 1 < self.fdcount) : (j += 1) {
                    self.fdlist[j] = self.fdlist[j+1];
                }
                self.fdcount -= 1;
            } else std.debug.panic("remove called on socket {} that is not registered with eventer", .{fd});
        }

        // returns: false if there was a timeout
        fn handleEventsGeneric(self: *@This(), timeout_ms: i32) CallbackError!bool {

            // winsock select ignores nfds; other platforms not supported yet
            const nfds = if (builtin.os.tag == .windows) 0 else @compileError("select nfds not implemented for non-windows");

            var read_set : fd_set = .{ .fd_count = 0, .fd_array = undefined };
            var write_set : fd_set = .{ .fd_count = 0, .fd_array = undefined };
            var error_set : fd_set = .{ .fd_count = 0, .fd_array = undefined };
            {var i : CountType = 0; while (i < self.fdcount) : (i += 1) {
                if ( (self.fdlist[i].flags & EventFlags.read) != 0) {
                    platform.set_fd(fd_set, &read_set, self.fdlist[i].fd);
                }
                if ( (self.fdlist[i].flags & EventFlags.write) != 0) {
                    platform.set_fd(fd_set, &write_set, self.fdlist[i].fd);
                }
                if ( (self.fdlist[i].flags & EventFlags.hangup) != 0) {
                    platform.set_fd(fd_set, &error_set, self.fdlist[i].fd);
                }
            }}
            var timeout_buf : platform.timeval = undefined;
            const timeout = init: {
                if (timeout_ms == -1) break :init null; // null timeval => block forever
                std.debug.assert(timeout_ms >= 0);
                timeout_buf = platform.msToTimeval(@intCast(u31, timeout_ms));
                break :init &timeout_buf;
            };
            const result = platform.select(nfds, read_set.base(), write_set.base(), error_set.base(), timeout);
            if (result == -1) {
                // TODO: create wrapper function in std.os to handle all error codes
                std.debug.panic("select failed, lasterror = {}", .{std.os.windows.ws2_32.WSAGetLastError()});
            }
            if (result == 0)
                return false; // timeout

            // BUG FIX: the old code looped `result` times over the read set,
            // firing every read-ready callback `result` times and never
            // dispatching write/error (hangup) readiness at all.  Dispatch
            // each ready socket from each set exactly once instead.
            // TODO: prevent a socket's callback from running more than once
            // per wakeup when it is ready in multiple sets?
            for (read_set.fd_array[0..read_set.fd_count]) |fd| {
                if (self.find(fd)) |i| {
                    try self.fdlist[i].callback.func(self, self.fdlist[i].callback);
                } else std.debug.panic("bug, select returned socket not in list {}", .{fd});
            }
            for (write_set.fd_array[0..write_set.fd_count]) |fd| {
                if (self.find(fd)) |i| {
                    try self.fdlist[i].callback.func(self, self.fdlist[i].callback);
                } else std.debug.panic("bug, select returned socket not in list {}", .{fd});
            }
            for (error_set.fd_array[0..error_set.fd_count]) |fd| {
                if (self.find(fd)) |i| {
                    try self.fdlist[i].callback.func(self, self.fdlist[i].callback);
                } else std.debug.panic("bug, select returned socket not in list {}", .{fd});
            }
            return true;
        }

        pub fn handleEventsNoTimeout(self: *@This()) CallbackError!void {
            if (!try self.handleEventsGeneric(-1))
                std.debug.panic("select returned 0 with infinite timeout?", .{});
        }

        pub fn loop(self: *@This()) anyerror!void {
            _ = self;
            // select-based loop not implemented yet (see epoll.zig for the
            // epoll equivalent); removed the stale commented-out epoll code.
            std.debug.panic("not implemented", .{});
        }
    };
}
-------------------------------------------------------------------------------- /netext.zig: --------------------------------------------------------------------------------
///
/// network functions that both log errors and return actionable error codes
///
const std = @import("std");
const os = std.os;
const net = std.net;

const fd_t = os.fd_t;
const socket_t = os.socket_t;
const Address = net.Address;

const logging = @import("./logging.zig");
const common = @import("./common.zig");
const proxy = @import("./proxy.zig");

const panic = 
std.debug.panic; 17 | const log = logging.log; 18 | const Proxy = proxy.Proxy; 19 | 20 | /// logs errors and returns either fatal or retry 21 | pub fn socket(domain: u32, socketType: u32, proto: u32) !socket_t { 22 | return os.socket(domain, socketType, proto) catch |e| switch (e) { 23 | error.ProcessFdQuotaExceeded 24 | ,error.SystemFdQuotaExceeded 25 | ,error.SystemResources 26 | => { 27 | log("WARNING: socket function error: {}", .{e}); 28 | return error.Retry; 29 | }, 30 | error.PermissionDenied 31 | ,error.AddressFamilyNotSupported 32 | ,error.SocketTypeNotSupported 33 | ,error.ProtocolFamilyNotAvailable 34 | ,error.ProtocolNotSupported 35 | ,error.Unexpected 36 | => panic("socket function failed with: {}", .{e}), 37 | }; 38 | } 39 | 40 | pub fn connect(sockfd: socket_t, addr: *const Address) !void { 41 | return common.connect(sockfd, addr) catch |e| switch (e) { 42 | error.AddressNotAvailable 43 | ,error.AddressInUse 44 | ,error.ConnectionRefused 45 | ,error.ConnectionTimedOut 46 | ,error.ConnectionResetByPeer 47 | ,error.NetworkUnreachable 48 | ,error.SystemResources 49 | ,error.WouldBlock 50 | ,error.ConnectionPending 51 | => { 52 | log("WARNING: connect function returned error: {}", .{e}); 53 | return error.Retry; 54 | }, 55 | error.PermissionDenied 56 | ,error.AddressFamilyNotSupported 57 | ,error.Unexpected 58 | ,error.FileNotFound 59 | => panic("connect function failed with: {}", .{e}), 60 | }; 61 | } 62 | 63 | pub fn proxyConnect(prox: *const Proxy, host: []const u8, port: u16) !socket_t { 64 | return prox.connectHost(host, port) catch |e| switch (e) { 65 | error.AddressNotAvailable 66 | ,error.AddressInUse 67 | ,error.ConnectionRefused 68 | ,error.ConnectionTimedOut 69 | ,error.NetworkUnreachable 70 | ,error.SystemResources 71 | ,error.ProcessFdQuotaExceeded 72 | ,error.SystemFdQuotaExceeded 73 | ,error.WouldBlock 74 | ,error.SendReturnedZero 75 | ,error.ConnectionResetByPeer 76 | ,error.MessageTooBig 77 | ,error.BrokenPipe 78 | ,error.InputOutput 79 
| ,error.NotOpenForReading 80 | ,error.OperationAborted 81 | ,error.HttpProxyDisconnectedDurringReply 82 | ,error.HttpProxyUnexpectedReply 83 | ,error.NetworkSubsystemFailed 84 | ,error.ConnectionPending 85 | => { 86 | log("WARNING: proxy connectHost returned error: {}", .{e}); 87 | return error.Retry; 88 | }, 89 | error.AccessDenied 90 | ,error.PermissionDenied 91 | ,error.AddressFamilyNotSupported 92 | ,error.Unexpected 93 | ,error.FileDescriptorNotASocket 94 | ,error.FileNotFound 95 | ,error.ProtocolFamilyNotAvailable 96 | ,error.ProtocolNotSupported 97 | ,error.DnsNotSupported 98 | ,error.IsDir 99 | ,error.FastOpenAlreadyInProgress // this is from sendto, EALREADY, not sure what it means 100 | ,error.SocketTypeNotSupported 101 | => panic("proxy connectHost failed with: {}", .{e}), 102 | }; 103 | } 104 | 105 | pub fn send(sockfd: socket_t, buf: []const u8, flags: u32) !void { 106 | common.sendfull(sockfd, buf, flags) catch |e| switch (e) { 107 | error.ConnectionResetByPeer 108 | ,error.BrokenPipe 109 | ,error.NetworkUnreachable 110 | ,error.NetworkSubsystemFailed 111 | => { 112 | log("send function error: {}", .{e}); 113 | return error.Disconnected; 114 | }, 115 | error.WouldBlock 116 | ,error.MessageTooBig 117 | ,error.SystemResources 118 | ,error.SendReturnedZero 119 | => { 120 | log("WARNING: send function error: {}", .{e}); 121 | return error.Retry; 122 | }, 123 | error.AccessDenied 124 | ,error.FastOpenAlreadyInProgress // don't know what this is 125 | ,error.FileDescriptorNotASocket 126 | ,error.Unexpected 127 | => panic("send function failed with: {}", .{e}), 128 | }; 129 | } 130 | 131 | pub fn read(fd: fd_t, buf: []u8) !usize { 132 | return os.read(fd, buf) catch |e| switch (e) { 133 | error.BrokenPipe 134 | ,error.ConnectionResetByPeer 135 | ,error.ConnectionTimedOut 136 | ,error.InputOutput 137 | ,error.NotOpenForReading 138 | => { 139 | log("read function disconnect error: {}", .{e}); 140 | return error.Disconnected; 141 | }, 142 | error.WouldBlock 
143 | ,error.SystemResources 144 | ,error.OperationAborted 145 | => { 146 | log("WARNING: read function retry error: {}", .{e}); 147 | return error.Retry; 148 | }, 149 | error.IsDir 150 | ,error.Unexpected 151 | ,error.AccessDenied 152 | => panic("read function failed with: {}", .{e}), 153 | }; 154 | } 155 | 156 | pub fn recvfullTimeout(sockfd: socket_t, buf: []u8, timeoutMillis: u32) !bool { 157 | return common.recvfullTimeout(sockfd, buf, timeoutMillis) catch |e| switch (e) { 158 | error.BrokenPipe 159 | ,error.ConnectionResetByPeer 160 | ,error.ConnectionTimedOut 161 | ,error.InputOutput 162 | ,error.NotOpenForReading 163 | => { 164 | log("read function disconnect error: {}", .{e}); 165 | return error.Disconnected; 166 | }, 167 | error.Retry => return error.Retry, // already logged 168 | error.WouldBlock 169 | ,error.SystemResources 170 | ,error.OperationAborted 171 | => { 172 | log("WARNING: read function retry error: {}", .{e}); 173 | return error.Retry; 174 | }, 175 | error.IsDir 176 | ,error.Unexpected 177 | ,error.AccessDenied 178 | => panic("read function failed with: {}", .{e}), 179 | }; 180 | } 181 | 182 | pub fn setsockopt(sockfd: socket_t, level: u32, optname: u32, opt: []const u8) !void { 183 | os.setsockopt(sockfd, level, optname, opt) catch |e| switch (e) { 184 | error.SystemResources 185 | ,error.NetworkSubsystemFailed 186 | => { 187 | log("WARNING: setsockopt function error: {}", .{e}); 188 | return error.Retry; 189 | }, 190 | error.InvalidProtocolOption 191 | ,error.FileDescriptorNotASocket 192 | ,error.TimeoutTooBig 193 | ,error.AlreadyConnected 194 | ,error.SocketNotBound 195 | ,error.Unexpected 196 | ,error.PermissionDenied 197 | => panic("setsockopt function fatal with: {}", .{e}), 198 | }; 199 | } 200 | 201 | pub fn bind(sockfd: socket_t, addr: *const os.sockaddr, len: os.socklen_t) !void { 202 | os.bind(sockfd, addr, len) catch |e| switch (e) { 203 | error.SystemResources 204 | ,error.AddressInUse 205 | ,error.AddressNotAvailable 206 | 
,error.NetworkSubsystemFailed
=> {
    // Transient bind failures: log and let the caller retry.
    log("WARNING: bind function error: {}", .{e});
    return error.Retry;
},
error.AccessDenied
,error.AlreadyBound
,error.Unexpected
,error.FileNotFound
,error.FileDescriptorNotASocket
,error.NotDir
,error.ReadOnlyFileSystem
,error.SymLinkLoop
,error.NameTooLong
=> panic("bind function failed with: {}", .{e}),
};
}

/// listen wrapper: transient conditions become error.Retry,
/// programming errors panic.
pub fn listen(sockfd: socket_t, backlog: u31) !void {
    os.listen(sockfd, backlog) catch |e| switch (e) {
        error.AddressInUse,
        error.NetworkSubsystemFailed,
        error.SystemResources,
        => {
            log("WARNING: listen function error: {}", .{e});
            return error.Retry;
        },
        error.OperationNotSupported,
        error.AlreadyConnected,
        error.FileDescriptorNotASocket,
        error.SocketNotBound,
        error.Unexpected,
        => panic("listen function failed with: {}", .{e}),
    };
}

/// Create a TCP listening socket bound to `addr` with SO_REUSEADDR set.
/// On any failure the socket is closed via errdefer before the error
/// propagates; on success the caller owns the returned fd.
pub fn makeListenSock(addr: *std.net.Address, backlog: u31) !socket_t {
    const sockfd = try socket(addr.any.family, os.SOCK.STREAM, os.IPPROTO.TCP);
    errdefer os.close(sockfd);
    try setsockopt(sockfd, os.SOL.SOCKET, os.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
    try bind(sockfd, &addr.any, addr.getOsSockLen());
    try listen(sockfd, backlog);
    return sockfd;
}

/// accept wrapper. Per-client failures become error.ClientDropped
/// (the server should keep accepting), transient resource exhaustion
/// becomes error.Retry, and programming errors panic.
pub fn accept(sockfd: socket_t, addr: *os.sockaddr, addr_size: *os.socklen_t, flags: u32) !socket_t {
    return os.accept(sockfd, addr, addr_size, flags) catch |e| switch (e) {
        error.ConnectionAborted,
        error.ConnectionResetByPeer,
        error.ProtocolFailure,
        error.BlockedByFirewall,
        error.WouldBlock,
        => {
            log("accept dropped client: {}", .{e});
            return error.ClientDropped;
        },
        error.SystemResources,
        error.ProcessFdQuotaExceeded,
        error.SystemFdQuotaExceeded,
        error.NetworkSubsystemFailed,
        => {
            log("WARNING: accept function error: {}", .{e});
            return error.Retry;
        },
        error.Unexpected,
        error.FileDescriptorNotASocket,
        error.SocketNotListening,
        error.OperationNotSupported,
        => panic("accept function failed with: {}", .{e}),
    };
}
--------------------------------------------------------------------------------
/timing.zig:
--------------------------------------------------------------------------------
const std = @import("std");

const logging = @import("./logging.zig");
const log = logging.log;

/// TODO: these functions should go somewhere else
/// Returns an integer type with the same bit width as `T` but with the
/// requested signedness. Compile error if `T` is not an integer type.
pub fn SignModified(comptime T: type, comptime signedness: std.builtin.Signedness) type {
    return switch (@typeInfo(T)) {
        .Int => |info| @Type(std.builtin.TypeInfo{ .Int = .{
            .signedness = signedness,
            .bits = info.bits,
        } }),
        else => @compileError("Signed requires an Int type but got: " ++ @typeName(T)),
    };
}
// Signed/unsigned views of an integer type, same bit width.
pub fn Signed(comptime T: type) type { return SignModified(T, .signed); }
pub fn Unsigned(comptime T: type) type { return SignModified(T, .unsigned); }

// Timestamp is whatever std.time.milliTimestamp returns; TimestampDiff
// is its signed counterpart so differences can be negative.
pub const Timestamp = @typeInfo(@TypeOf(std.time.milliTimestamp)).Fn.return_type.?;
pub const TimestampDiff = Signed(Timestamp);

pub fn getNowTimestamp() Timestamp {
    return std.time.milliTimestamp();
}

pub fn timestampToMillis(timestamp: Timestamp) Timestamp {
    return timestamp; // already millis right now
}
pub fn secondsToTimestamp(value: anytype) Timestamp {
    return 1000 * value;
}

// 2's complement negate (wrapping, so negate(minInt) stays minInt)
pub fn negate(val: anytype) @TypeOf(val) {
    var result: @TypeOf(val) = undefined;
    _ = @addWithOverflow(@TypeOf(val), ~val, 1, &result);
    return result;
}

// Wrapping subtraction left - right, reinterpreted as a signed diff so
// callers can compare timestamps that may have wrapped around.
pub fn timestampDiff(left: Timestamp, right: Timestamp) TimestampDiff {
    var result: Timestamp = undefined;
    _ = @subWithOverflow(Timestamp, left, right,
&result);
    return @intCast(TimestampDiff, result);
}

test "timestampDiff" {
    const Test = struct { left: Timestamp, right: Timestamp, diff: TimestampDiff };
    const tests = [_]Test {
        Test {.left=  0, .right= 0, .diff= 0},
        Test {.left=  1, .right= 0, .diff= 1},
        Test {.left=100, .right=83, .diff=17},
        Test {.left=std.math.maxInt(Timestamp)      , .right=std.math.maxInt(Timestamp)      , .diff=0},
        Test {.left=std.math.maxInt(Timestamp)      , .right=std.math.maxInt(Timestamp) -   1, .diff=1},
        Test {.left=std.math.maxInt(Timestamp) -  80, .right=std.math.maxInt(Timestamp) - 223, .diff=143},
        Test {.left=0, .right=std.math.maxInt(Timestamp), .diff=1},
        Test {.left=1234, .right=std.math.maxInt(Timestamp) - 100, .diff=1335},
        Test {.left=std.math.maxInt(Timestamp)/2, .right=0, .diff=std.math.maxInt(Timestamp)/2},
        Test {.left=std.math.maxInt(Timestamp)/2, .right=123, .diff=std.math.maxInt(Timestamp)/2 - 123},
        Test {.left=std.math.maxInt(Timestamp)/2 + 234, .right=234, .diff=std.math.maxInt(Timestamp)/2},
        Test {.left=std.math.maxInt(Timestamp)/2 + 1, .right=0, .diff=std.math.minInt(TimestampDiff)},
        Test {.left=std.math.maxInt(Timestamp)/2 + 2, .right=0, .diff=std.math.minInt(TimestampDiff) + 1},
    };
    for (tests) |t| {
        std.debug.warn("left {} right {} diff {} ndiff {}\n", .{t.left, t.right, t.diff, negate(t.diff)});
        std.debug.warn("    {}\n", .{timestampDiff(t.left , t.right)});
        std.debug.warn("    {}\n", .{timestampDiff(t.right, t.left )});
        std.debug.assert(timestampDiff(t.left , t.right) == t.diff);
        std.debug.assert(timestampDiff(t.right, t.left ) == negate(t.diff));
    }
}

// TODO: create a single timer (used for the heartbeat timer)
pub const TimerCheckResult = union (enum) {
    Expired: void,
    Wait: u32,
};

/// A repeating interval timer. `check` reports either that the interval
/// has elapsed (Expired) or how many milliseconds remain (Wait).
pub const Timer = struct {
    durationMillis: u32,
    started: bool,
    lastExpireTimestamp: Timestamp,
    pub fn init(durationMillis: u32) Timer {
        return Timer {
            .durationMillis = durationMillis,
            .started = false,
            .lastExpireTimestamp = undefined,
        };
    }
    pub fn check(self: *Timer) TimerCheckResult {
        const nowMillis = getNowTimestamp();
        // first call starts the timer rather than expiring it
        if (!self.started) {
            self.started = true;
            self.lastExpireTimestamp = nowMillis;
            return TimerCheckResult { .Wait = self.durationMillis };
        }
        const diff = timestampDiff(nowMillis, self.lastExpireTimestamp);
        // a negative diff means the clock moved backwards; treat as expired
        if (diff < 0 or diff >= self.durationMillis) {
            self.lastExpireTimestamp = nowMillis;
            return TimerCheckResult.Expired;
        }
        return TimerCheckResult { .Wait = self.durationMillis - @intCast(u32, diff) };
    }
};

/// A single-slot list of timestamp-ordered callbacks, parameterized by
/// per-callback user data. NOTE: `add` asserts if a callback is already
/// pending — at most one callback is supported at a time.
pub fn TimersTemplate(comptime CallbackData: type) type {
    return struct {
        pub const CallbackFn = fn(*@This(), *Callback) anyerror!void;
        pub const Callback = struct {
            optionalNext: ?*Callback,
            timestamp: Timestamp,
            func: CallbackFn,
            data: CallbackData,
            pub fn init(timestamp: Timestamp, func: CallbackFn, data: CallbackData) @This() {
                return @This() {
                    .optionalNext = null,
                    .timestamp = timestamp,
                    .func = func,
                    .data = data,
                };
            }
        };

        optionalNext: ?*Callback,
        pub fn init() @This() {
            return @This() { .optionalNext = null };
        }
        pub fn add(self: *@This(), callback: *Callback) !void {
            if (self.optionalNext) |_| {
                std.debug.assert(false); // only one pending callback supported
            } else {
                self.optionalNext = callback;
            }
        }
        /// Runs every callback whose timestamp has passed. Returns the
        /// time until the next callback is due, or null if none pending.
        pub fn handleEvents(self: *@This()) !?Timestamp {
            while (true) {
                if (self.optionalNext) |next| {
                    const diff = timestampDiff(next.timestamp, getNowTimestamp());
                    //std.debug.warn("[DEBUG] timestamp diff {}\n", .{diff});
                    if (diff > 0) return @intCast(Timestamp, diff);
                    // unlink before invoking so the callback may re-add itself
                    self.optionalNext = next.optionalNext;
                    try next.func(self, next);
                } else {
                    return null;
                }
            }
        }
    };
}


/// Configuration bundle for building a Throttler (see Throttler.init).
pub const makeThrottler = struct {
    logPrefix: []const u8,
    desiredSleepMillis: Timestamp,
    slowRateMillis: Timestamp,
    pub fn create(self: *const makeThrottler) Throttler {
        return Throttler.init(self.logPrefix, self.desiredSleepMillis, self.slowRateMillis);
    }
};

const ns_per_ms = std.time.ns_per_s / std.time.ms_per_s;

/// Use to throttle an operation from happening too quickly.
/// Note that if it is occurring too fast, it will slow down
/// gradually based on `slowRateMillis`.
pub const Throttler = struct {
    logPrefix: []const u8,
    desiredSleepMillis: Timestamp,
    slowRateMillis: Timestamp,
    started: bool,
    sleepMillis: Timestamp,
    checkinTimestamp : Timestamp,
    beforeWorkTimestamp: Timestamp,
    pub fn init(logPrefix: []const u8, desiredSleepMillis: Timestamp, slowRateMillis: Timestamp) Throttler {
        return Throttler {
            .logPrefix = logPrefix,
            .desiredSleepMillis = desiredSleepMillis,
            .slowRateMillis = slowRateMillis,
            .started = false,
            .sleepMillis = 0,
            .checkinTimestamp = undefined,
            .beforeWorkTimestamp = undefined,
        };
    }
    /// call this function before performing the work because it needs to
    /// save the timestamp before performing the work, i.e.
    ///     var t = Throttler.init()
    ///     while (true) { t.throttle(); dowork() }
    pub fn throttle(self: *Throttler) void {
        const nowMillis = getNowTimestamp();
        if (!self.started) {
            self.started = true;
        } else {
            const elapsedMillis = timestampDiff(nowMillis, self.checkinTimestamp);
            if (elapsedMillis < 0) {
                // clock went backwards; sleep the full interval and reset
                if (self.logPrefix.len > 0)
                    log("{s}elapsed time is negative ({} ms), will wait {} ms...", .{self.logPrefix, elapsedMillis, self.desiredSleepMillis});
                std.time.sleep(ns_per_ms * @intCast(u64, self.desiredSleepMillis));
                self.sleepMillis = 0; // reset sleep time
            } else if (elapsedMillis >= self.desiredSleepMillis) {
                // enough time passed on its own; precompute next sleep from
                // how long the work itself took
                const workMillis = timestampDiff(nowMillis, self.beforeWorkTimestamp);
                std.debug.assert(workMillis >= 0 and workMillis <= elapsedMillis);
                if (workMillis >= self.desiredSleepMillis) {
                    self.sleepMillis = 0;
                } else {
                    self.sleepMillis = self.desiredSleepMillis - @intCast(Timestamp, workMillis);
                }
                if (self.logPrefix.len > 0)
                    log("{s}last operation took {} ms, no throttling needed (next sleep {} ms)...", .{self.logPrefix, workMillis, self.sleepMillis});
            } else {
                // too fast: ramp the sleep up by at most slowRateMillis
                const millisNeeded = self.desiredSleepMillis - @intCast(Timestamp, elapsedMillis);
                const addMillis = if (millisNeeded < self.slowRateMillis) millisNeeded else self.slowRateMillis;
                self.sleepMillis += addMillis;
                if (self.logPrefix.len > 0)
                    log("{s}{} ms since last operation, will sleep {} ms...", .{self.logPrefix, elapsedMillis, self.sleepMillis});
                std.time.sleep(ns_per_ms * @intCast(u64, self.sleepMillis));
            }
        }
        self.checkinTimestamp = nowMillis;
        self.beforeWorkTimestamp = getNowTimestamp();
    }
};

//pub fn throttle(comptime eventName: []const u8, minTimeMillis: u32, elapsedMillis: timing.TimestampDiff) void {
//    if (elapsedMillis < 0) {
//        log("time since " ++ eventName ++ " is
negative ({} ms)? Will wait {} ms...", .{elapsedMillis, minTimeMillis});
//        std.time.sleep(ns_per_ms * @intCast(u64, minTimeMillis));
//    } else if (elapsedMillis < minTimeMillis) {
//        const sleepMillis : u64 = minTimeMillis - @intCast(u32, elapsedMillis);
//        log("been {} ms since " ++ eventName ++ ", will sleep for {} ms...", .{elapsedMillis, sleepMillis});
//        std.time.sleep(ns_per_ms * sleepMillis);
//    } else {
//        log("been {} ms since last connect, will retry immediately", .{elapsedMillis});
//    }
//}
--------------------------------------------------------------------------------
/punch-client-forwarder.zig:
--------------------------------------------------------------------------------
const std = @import("std");
const mem = std.mem;
const os = std.os;
const net = std.net;

const logging = @import("./logging.zig");
const common = @import("./common.zig");
const netext = @import("./netext.zig");
const timing = @import("./timing.zig");
const eventing = @import("./eventing.zig").default;
const punch = @import("./punch.zig");
const proxy = @import("./proxy.zig");

const log = logging.log;
const fd_t = os.fd_t;
const Address = net.Address;
const delaySeconds = common.delaySeconds;
const Timestamp = timing.Timestamp;
const Timer = timing.Timer;
const EventFlags = eventing.EventFlags;
const PunchRecvState = punch.util.PunchRecvState;
const Proxy = proxy.Proxy;
const HostAndProxy = proxy.HostAndProxy;

// Event loop specialization: shared loop data is the punch/raw fd pair
// plus per-tunnel receive state; callbacks carry the fd they service.
const Eventer = eventing.EventerTemplate(.{
    .Data = struct {
        punchFd: fd_t,
        rawFd: fd_t,
        punchRecvState: *PunchRecvState,
        gotCloseTunnel: *bool,
    },
    .CallbackError = error {
        PunchSocketDisconnect,
        RawSocketDisconnect,
    },
    .CallbackData = struct {
        fd: fd_t,
    },
});

// Process-lifetime state; this program is single-threaded.
const global = struct {
    var ignoreSigaction : os.Sigaction = undefined;
    var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
    var rawForwardAddr : Address = undefined;
    var buffer : [8192]u8 = undefined;
};


// Ignore SIGPIPE so writes to a dead socket surface as errors
// instead of killing the process.
fn setupSignals() void {
    global.ignoreSigaction.handler.sigaction = os.SIG.IGN;
    std.mem.set(u32, &global.ignoreSigaction.mask, 0);
    global.ignoreSigaction.flags = 0;
    os.sigaction(os.SIG.PIPE, &global.ignoreSigaction, null);
}

fn usage() void {
    std.debug.warn("Usage: punch-client-forwarder PUNCH_SERVER PUNCH_PORT FORWARD_HOST FORWARD_PORT\n", .{});
    std.debug.warn("\n", .{});
    std.debug.warn("enable proxy with http://PROXY_HOST:PROXY_PORT/PUNCH_SERVER\n", .{});
}
pub fn main() anyerror!u8 {
    setupSignals();

    var args = try std.process.argsAlloc(&global.arena.allocator);
    if (args.len <= 1) {
        usage();
        return 1;
    }
    args = args[1..]; // drop the program name
    if (args.len != 4) {
        usage();
        return 1;
    }
    const punchConnectSpec = args[0];
    const punchPort = common.parsePort(args[1]) catch return 1;
    const rawForwardString = args[2];
    const rawForwardPort = common.parsePort(args[3]) catch return 1;

    const punchHostAndProxy = proxy.parseProxy(punchConnectSpec) catch |e| {
        log("Error: invalid connect specifier '{s}': {}", .{punchConnectSpec, e});
        return 1;
    };

    global.rawForwardAddr = common.parseIp4(rawForwardString, rawForwardPort) catch return 1;

    // reconnect forever, throttled so a flapping server is not hammered
    var connectThrottler = makeThrottler("connect throttler: ");
    while (true) {
        connectThrottler.throttle();
        switch (sequenceConnectToPunchClient(&punchHostAndProxy, punchPort)) {
            error.PunchSocketDisconnect => {},
        }
    }
}

// All throttlers in this program share the same pacing configuration.
fn makeThrottler(logPrefix: []const u8) timing.Throttler {
    return (timing.makeThrottler {
        .logPrefix = logPrefix,
        .desiredSleepMillis = 15000,
        .slowRateMillis = 500,
    }).create();
}

fn
sequenceConnectToPunchClient(punchHostAndProxy: *const HostAndProxy, punchPort: u16) error {
    PunchSocketDisconnect,
} {
    // connect (possibly through an http proxy) to the punch server
    log("connecting to punch server {}{s}:{}...", .{punchHostAndProxy.proxy, punchHostAndProxy.host, punchPort});
    const punchFd = netext.proxyConnect(&punchHostAndProxy.proxy, punchHostAndProxy.host, punchPort) catch |e| switch (e) {
        error.Retry => return error.PunchSocketDisconnect,
    };
    defer common.shutdownclose(punchFd);
    log("connected to punch server", .{});

    punch.util.doHandshake(punchFd, .forwarder, 10000) catch |e| switch (e) {
        error.PunchSocketDisconnect
        ,error.BadPunchHandshake
        => return error.PunchSocketDisconnect,
    };

    var heartbeatTimer = Timer.init(15000);
    var waitOpenTunnelThrottler = makeThrottler("wait for OpenTunnel: ");
    // serve one tunnel at a time over this punch connection
    while (true) {
        waitOpenTunnelThrottler.throttle();
        log("waiting for OpenTunnel...", .{});
        waitForOpenTunnelMessage(punchFd, &heartbeatTimer) catch |e| switch (e) {
            error.PunchSocketDisconnect => return error.PunchSocketDisconnect,
        };
        var punchRecvState : PunchRecvState = PunchRecvState.Initial;
        var gotCloseTunnel = false;

        switch (sequenceConnectRawClient(punchFd, &heartbeatTimer, &punchRecvState, &gotCloseTunnel)) {
            error.PunchSocketDisconnect => return error.PunchSocketDisconnect,
            error.RawSocketDisconnect => {
                // raw side died: close this tunnel but keep the punch
                // connection for the next OpenTunnel
                try punch.util.closeTunnel(punchFd, &punchRecvState, &gotCloseTunnel, &global.buffer);
                continue;
            },
        }
    }
}

// Connect the raw (forwarded-to) side and run the forwarding loop
// until either side disconnects.
fn sequenceConnectRawClient(punchFd: fd_t, heartbeatTimer: *Timer, punchRecvState: *PunchRecvState, gotCloseTunnel: *bool) error {
    PunchSocketDisconnect,
    RawSocketDisconnect,
} {
    const rawFd = netext.socket(global.rawForwardAddr.any.family, os.SOCK.STREAM , os.IPPROTO.TCP) catch |e| switch (e) {
        error.Retry => return error.RawSocketDisconnect,
    };
    defer os.close(rawFd);

    log("s={} connecting raw to {}", .{rawFd, global.rawForwardAddr});
    netext.connect(rawFd, &global.rawForwardAddr) catch |e| switch (e) {
        error.Retry => return error.RawSocketDisconnect,
    };
    defer common.shutdown(rawFd) catch |e| {
        log("WARNING: shutdown raw s={} failed with {}", .{rawFd, e});
    };
    log("s={} raw side connected", .{rawFd});

    var eventingThrottler = makeThrottler("eventing throttler: ");
    while (true) {
        eventingThrottler.throttle();
        switch (sequenceForwardingLoop(punchFd, heartbeatTimer, punchRecvState, gotCloseTunnel, rawFd)) {
            error.EpollError => continue, // rebuild the eventer and keep going
            error.PunchSocketDisconnect => return error.PunchSocketDisconnect,
            error.RawSocketDisconnect => return error.RawSocketDisconnect,
        }
    }
}

// Service heartbeats while waiting for the initiator to open a tunnel.
// Returns normally once OpenTunnel arrives.
fn waitForOpenTunnelMessage(punchFd: fd_t, heartbeatTimer: *Timer) !void {
    while (true) {
        const sleepMillis = punch.util.serviceHeartbeat(punchFd, heartbeatTimer, false) catch |e| switch (e) {
            error.PunchSocketDisconnect => return error.PunchSocketDisconnect,
        };
        var buf: [1]u8 = undefined;
        const gotMessage = netext.recvfullTimeout(punchFd, &buf, sleepMillis) catch |e| switch (e) {
            error.Disconnected => return error.PunchSocketDisconnect,
            error.Retry => {
                // we can do this because we are only receiving 1-byte
                delaySeconds(1, "before calling recv again...");
                continue;
            },
        };
        if (gotMessage) {
            if (buf[0] == punch.proto.TwoWayMessage.Heartbeat) {
                //log("[DEBUG] got heartbeat", .{});
            } else if(buf[0] == punch.proto.InitiatorMessage.OpenTunnel) {
                log("got OpenTunnel message", .{});
                return;
            } else {
                log("got unexpected punch message {}, will disconnect", .{buf[0]});
                return error.PunchSocketDisconnect;
            }
        }
    }
}

// Note: !noreturn would be better in this case (see https://github.com/ziglang/zig/issues/3461)
// Pump data between punchFd and rawFd until a disconnect or epoll failure.
fn sequenceForwardingLoop(punchFd: fd_t, heartbeatTimer: *Timer, punchRecvState: *PunchRecvState,
                          gotCloseTunnel: *bool, rawFd: fd_t) error {
    EpollError,
    PunchSocketDisconnect,
    RawSocketDisconnect,
} {
    var eventer = common.eventerInit(Eventer, Eventer.Data {
        .punchFd = punchFd,
        .rawFd = rawFd,
        .punchRecvState = punchRecvState,
        .gotCloseTunnel = gotCloseTunnel,
    }) catch |e| switch (e) {
        error.Retry => return error.EpollError,
    };
    defer eventer.deinit();

    var punchCallback = Eventer.Callback {
        .func = onPunchData,
        .data = .{.fd = punchFd},
    };
    common.eventerAdd(Eventer, &eventer, punchFd, EventFlags.read, &punchCallback) catch |e| switch (e) {
        error.Retry => return error.EpollError,
    };
    defer eventer.remove(punchFd);

    var rawCallback = Eventer.Callback {
        .func = onRawData,
        .data = .{.fd = rawFd},
    };
    common.eventerAdd(Eventer, &eventer, rawFd, EventFlags.read, &rawCallback) catch |e| switch (e) {
        error.Retry => return error.EpollError,
    };
    defer eventer.remove(rawFd);

    while (true) {
        const sleepMillis = punch.util.serviceHeartbeat(punchFd, heartbeatTimer, false) catch |e| switch (e) {
            error.PunchSocketDisconnect => return error.PunchSocketDisconnect,
        };
        //log("[DEBUG] waiting for events (sleep {} ms)...", .{sleepMillis});
        _ = eventer.handleEvents(sleepMillis) catch |e| switch (e) {
            error.PunchSocketDisconnect => return error.PunchSocketDisconnect,
            error.RawSocketDisconnect => return error.RawSocketDisconnect,
        };
    }
}

// raw -> punch direction: wrap raw bytes in punch framing and forward.
fn onRawData(eventer: *Eventer, callback: *Eventer.Callback) Eventer.CallbackError!void {
    std.debug.assert(callback.data.fd == eventer.data.rawFd);
    punch.util.forwardRawToPunch(callback.data.fd, eventer.data.punchFd, &global.buffer) catch |e| switch (e) {
        error.RawSocketDisconnect => return error.RawSocketDisconnect,
        error.PunchSocketDisconnect => return error.PunchSocketDisconnect,
    };
}

// punch -> raw direction: parse punch messages and forward tunnel data.
fn onPunchData(eventer: *Eventer, callback: *Eventer.Callback) Eventer.CallbackError!void {
    const len = netext.read(callback.data.fd, &global.buffer) catch |e| switch (e) {
        error.Retry => {
            delaySeconds(1, "before trying to read punch socket again...");
            return;
        },
        error.Disconnected => return error.PunchSocketDisconnect,
    };
    if (len == 0) {
        log("punch socket disconnected (read returned 0)", .{});
        return error.PunchSocketDisconnect;
    }
    var data = global.buffer[0..len];
    // a single read may contain multiple punch messages
    while (data.len > 0) {
        const action = punch.util.parsePunchToNextAction(eventer.data.punchRecvState, &data) catch |e| switch (e) {
            error.InvalidPunchMessage => {
                log("received unexpected punch message {}", .{data[0]});
                // socket will be shutdown in a defer
                return error.PunchSocketDisconnect;
            },
        };
        switch (action) {
            .None => {
                std.debug.assert(data.len == 0);
                break;
            },
            .OpenTunnel => {
                log("WARNING: received OpenTunnel message when a tunnel is already open", .{});
                // socket will be shutdown in a defer
                return error.PunchSocketDisconnect;
            },
            .CloseTunnel => {
                log("received CloseTunnel message", .{});
                eventer.data.gotCloseTunnel.* = true;
                return error.RawSocketDisconnect;
            },
            .ForwardData => |forwardAction| {
                //log("[VERBOSE] forwarding {} bytes to raw socket...", .{forwardAction.data.len});
                netext.send(eventer.data.rawFd, forwardAction.data, 0) catch |e| switch (e) {
                    error.Disconnected, error.Retry => {
                        log("s={} send failed on raw socket", .{eventer.data.rawFd});
                        return error.RawSocketDisconnect;
                    },
                };
            },
        }
    }
}
--------------------------------------------------------------------------------
/old/reverse-tunnel-client.zig:
--------------------------------------------------------------------------------
const std = @import("std");
const mem = std.mem;
const os = std.os;
const net = std.net;

const logging = @import("./logging.zig");
const common = @import("./common.zig");
const timing = @import("./timing.zig");
const eventing = @import("./eventing.zig");
const pool = @import("./pool.zig");
const Pool = pool.Pool;

const log = logging.log;
const fd_t = os.fd_t;
const Address = net.Address;
const Timestamp = timing.Timestamp;
const Timers = timing.TimersTemplate(struct {});
const EventFlags = eventing.EventFlags;

// Eventer callbacks carry connection state; `timer` is embedded so a
// pending reconnect timer can be mapped back to its callback.
const Eventer = eventing.EventerTemplate(anyerror, struct {
    timers: *Timers
}, struct {
    fd: fd_t,
    addr: Address,
    connectAttempt: u32,
    timer: ?Timers.Callback,
    connectionId: u32,
});

// Looks up a struct field by name; returns null if not present.
fn getField(comptime structInfo: std.builtin.TypeInfo.Struct, name: []const u8) ?std.builtin.TypeInfo.StructField {
    for (structInfo.fields) |field| {
        if (std.mem.eql(u8, field.name, name)) return field;
    }
    return null;
}

// Recovers the enclosing Eventer.Callback from a pointer to its
// embedded `data.timer` field (container-of via byte offsets).
fn timerToEventerCallback(callback: *Timers.Callback) *Eventer.Callback {
    const basePtr = @ptrCast(*Eventer.Callback,
        @ptrCast([*]u8, callback) - (
            @byteOffsetOf(Eventer.Callback, "data")
            + @byteOffsetOf(getField(@typeInfo(Eventer.Callback).Struct, "data").?.field_type, "timer")
            //+ @byteOffsetOf(?Timers.Callback, "?")
        ));
    std.debug.assert(&(basePtr.data.timer.?) == callback);
    return basePtr;
}

const global = struct {
    var eventer : Eventer = undefined;
    var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
    var configServerCallback : Eventer.Callback = undefined;
    // one more than 256 because we are using a 1-byte length for now
    var configRecvBuf : [257]u8 = undefined;
    var configRecvLen : usize = 0;

    // TODO: change from 1 to something else
    var hostPool = Pool(Eventer.Callback, 1).init(&arena.allocator);
};

fn usage() void {
    log("Usage: reverse-tunnel-client CONFIG_SERVER", .{});
}
pub fn main() anyerror!u8 {
    // FIX: was `const args`, but args is re-sliced below (`args = args[1..]`),
    // which is an assignment to a const — must be `var`.
    var args = try std.process.argsAlloc(&global.arena.allocator);
    if (args.len <= 1) {
        usage();
        return 1;
    }
    args = args[1..]; // drop the program name
    if (args.len != 1) {
        usage();
        return 1;
    }
    const configServerString = args[0];
    if (std.builtin.os.tag == .windows)
        @compileError("how do I make a socket non-blocking on windows?");

    var timers = Timers.init();
    global.eventer = try Eventer.init(.{
        .timers = &timers,
    });
    global.configServerCallback = .{
        .func = invalidCallback,
        .data = .{
            .fd = -1, // set to -1 for sanity checking
            .addr = Address.parseIp4(configServerString, 9281) catch |e| {
                log("Error: failed to parse '{}' as an IPv4 address: {}", .{configServerString, e});
                return 1;
            },
            .connectAttempt = 0,
            .timer = null,
            .connectionId = undefined, // unused here on config-server
        },
    };
    try startConnect(&global.eventer, &global.configServerCallback);
    // main loop: run timers, then block in the eventer until the next
    // timer deadline (or forever when no timer is pending)
    while (true) {
        const optionalTimeout = try timers.handleEvents();
        if (optionalTimeout) |timeout| {
            const millis = timing.timestampToMillis(timeout);
            //log("timeoutMillis {}", .{timeoutMillis});
            _ = try global.eventer.handleEvents(@intCast(u32, millis));
        } else {
            _ = try global.eventer.handleEventsNoTimeout();
        }
    }
}

// Placeholder callback used for sanity checking: firing it is a bug.
fn invalidCallback(eventer: *Eventer, callback: *Eventer.Callback) !void {
    return error.InvalidCallback;
}

// assumption: callback has not been added to eventer yet
fn startConnect(eventer: *Eventer, callback: *Eventer.Callback) !void {
    std.debug.assert(callback.func == invalidCallback);
    std.debug.assert(callback.data.fd == -1);
    std.debug.assert(callback.data.timer == null);
    callback.data.connectAttempt += 1;
    callback.data.fd = try os.socket(callback.data.addr.any.family, os.SOCK_STREAM | os.SOCK_NONBLOCK, os.IPPROTO_TCP);
    log("s={} connecting to {} (attempt {})", .{callback.data.fd, callback.data.addr, callback.data.connectAttempt});
    common.connect(callback.data.fd, &callback.data.addr) catch |e| {
        if (e == error.WouldBlock) {
            // non-blocking connect in progress: wait for writability
            callback.func = onConnecting;
            try eventer.add(callback.data.fd, EventFlags.write, callback);
        } else {
            log("connect failed with {}, will retry", .{e});
            try startConnectTimer(eventer, callback);
        }
        return;
    };
    log("s={} connected to {} (immediately)...", .{callback.data.fd, callback.data.addr});
    if (callback == &global.configServerCallback) {
        callback.func = onConfigData;
        global.configRecvLen = 0;
    } else {
        callback.func = onHostData;
    }
    try eventer.add(callback.data.fd, EventFlags.read, callback);
}

// Schedule a reconnect attempt; backoff grows with the attempt count.
fn startConnectTimer(eventer: *Eventer, callback :*Eventer.Callback) !void {
    std.debug.assert(callback.data.fd == -1);
    std.debug.assert(callback.data.timer == null);
    const delaySeconds : Timestamp = init: {
        if (callback.data.connectAttempt <= 10)
            break :init 1;
        if (callback.data.connectAttempt <= 30)
            break :init 5;
        break :init 15;
    };
    callback.data.timer = Timers.Callback.init(
        timing.getNowTimestamp() +
timing.secondsToTimestamp(delaySeconds), 152 | finishConnectTimer, .{}); 153 | try eventer.data.timers.add(&callback.data.timer.?); 154 | } 155 | fn finishConnectTimer(timers: *Timers, timerCallback: *Timers.Callback) !void { 156 | const eventCallback = timerToEventerCallback(timerCallback); 157 | std.debug.assert(eventCallback.data.fd == -1); 158 | std.debug.assert(eventCallback.data.timer != null); 159 | //log("[DEBUG] finishConnect s={} addr={}", .{eventCallback.data.fd, eventCallback.data.addr}); 160 | eventCallback.data.timer = null; // for sanity checking 161 | try startConnect(&global.eventer, eventCallback); 162 | } 163 | 164 | fn reconnect(eventer: *Eventer, callback: *Eventer.Callback) !void { 165 | std.debug.assert(callback.data.fd != -1); 166 | 167 | eventer.remove(callback.data.fd); 168 | os.close(callback.data.fd); 169 | if (callback.func != onConnecting) 170 | callback.data.connectAttempt = 0; 171 | callback.data.fd = -1; // used for sanity checking 172 | callback.func = invalidCallback; // used for sanity checking 173 | try startConnectTimer(eventer, callback); 174 | } 175 | 176 | fn onConnecting(eventer: *Eventer, callback: *Eventer.Callback) anyerror!void { 177 | std.debug.assert(callback.data.fd != -1); 178 | std.debug.assert(callback.data.timer == null); 179 | 180 | const sockError = try common.getsockerror(callback.data.fd); 181 | if (sockError != 0) { 182 | log("s={} socket error {}", .{callback.data.fd, sockError}); 183 | try reconnect(eventer, callback); 184 | return; 185 | } 186 | log("s={} connected to {} (attempt {})", .{callback.data.fd, callback.data.addr, callback.data.connectAttempt}); 187 | if (callback == &global.configServerCallback) { 188 | callback.func = onConfigData; 189 | global.configRecvLen = 0; 190 | } else { 191 | callback.func = onHostData; 192 | } 193 | try eventer.modify(callback.data.fd, EventFlags.read, callback); 194 | } 195 | 196 | fn onConfigData(eventer: *Eventer, callback: *Eventer.Callback) anyerror!void { 197 | 
std.debug.assert(callback == &global.configServerCallback); 198 | std.debug.assert(callback.data.fd != -1); 199 | std.debug.assert(callback.data.timer == null); 200 | 201 | if (global.configRecvLen == global.configRecvBuf.len) { 202 | log("s={} no more room left ({}) in config recv buffer", .{callback.data.fd, global.configRecvLen}); 203 | try reconnect(eventer, callback); 204 | return; 205 | } 206 | const len = os.read(callback.data.fd, global.configRecvBuf[global.configRecvLen..]) catch |e| { 207 | log("s={} read on config-server failed: {}", .{callback.data.fd, e}); 208 | try reconnect(eventer, callback); 209 | return; 210 | }; 211 | if (len == 0) { 212 | log("s={} config-server closed connection (read returned 0)", .{callback.data.fd}); 213 | try reconnect(eventer, callback); 214 | return; 215 | } 216 | log("s={} got {} bytes from config-server", .{callback.data.fd, len}); 217 | global.configRecvLen += len; 218 | handleConfigMessages() catch |e| { 219 | if (e == error.InvalidMessage) { 220 | // error already logged 221 | try reconnect(eventer, callback); 222 | return; 223 | } 224 | return e; 225 | }; 226 | } 227 | fn onHostData(eventer: *Eventer, callback: *Eventer.Callback) anyerror!void { 228 | return error.NotImplemented; 229 | } 230 | 231 | fn handleConfigMessages() !void { 232 | var next : usize = 0; 233 | while (true) { 234 | if (next >= global.configRecvLen) 235 | break; // need more data 236 | const msgLen = global.configRecvBuf[next]; 237 | const msgOff = next + 1; 238 | const msgLimit = msgOff + msgLen; 239 | if (msgLimit > global.configRecvLen) 240 | break; // need more data 241 | try handleConfigMessage(global.configRecvBuf[msgOff..msgLimit]); 242 | next = msgLimit; 243 | } 244 | // shift everything to the left 245 | if (next > 0) { 246 | global.configRecvLen = global.configRecvLen - next; 247 | memcpyLowToHigh(&global.configRecvBuf, &global.configRecvBuf + next, global.configRecvLen); 248 | } 249 | } 250 | 251 | fn bigEndianDeserializeU32(ptr: 
[*]const u8) u32 {
    return
        @intCast(u32, ptr[0]) << 24 |
        @intCast(u32, ptr[1]) << 16 |
        @intCast(u32, ptr[2]) << 8 |
        @intCast(u32, ptr[3]) << 0 ;
}

/// Dispatch a single framed config message (length prefix already stripped).
/// An empty message is a heartbeat.  Message layout for type 3 (add host):
///   [0]=3, [1..5]=connection id (big endian), [5..9]=ipv4 addr, [9..11]=port (big endian)
/// Returns error.InvalidMessage (after logging) for malformed/unknown messages.
fn handleConfigMessage(msg: []const u8) !void {
    if (msg.len == 0) {
        log("got heartbeat!", .{});
        return;
    }
    // add host
    if (msg[0] == 3) {
        if (msg.len != 11) {
            // BUGFIX: the message claimed "7 bytes" but the check above (and
            // the 1+4+4+2 layout) require 11 bytes
            log("WARNING: msg 3 (add host) must be 11 bytes but got {}", .{msg.len});
            return error.InvalidMessage;
        }
        const connectionId = bigEndianDeserializeU32(msg.ptr + 1);
        var addrBytes : [4]u8 = undefined;
        const port : u16 = (@intCast(u16, msg[9]) << 8) | msg[10];
        mem.copy(u8, &addrBytes, msg[5..9]);
        const addr = Address.initIp4(addrBytes, port);
        try addHost(connectionId, &addr);
    } else {
        log("WARNING: unknown message type '{}'", .{msg[0]});
        return error.InvalidMessage;
    }
}



/// Create a new host connection for `connectionId` and start connecting to it.
/// A duplicate (same id, same address) is ignored; the same id with a different
/// address is not handled yet and returns error.notimpl.
fn addHost(connectionId: u32, addr: *const Address) !void {
    {
        var range = global.hostPool.range();
        while (range.next()) |hostCallback| {
            if (hostCallback.data.connectionId == connectionId) {
                if (Address.eql(hostCallback.data.addr, addr.*)) {
                    log("host '{}' id {} already exists, s={}", .{addr, connectionId, hostCallback.data.fd});
                    return;
                }
                // address changed for an existing id; reconnect logic not implemented
                return error.notimpl;
            }
            log("[DEBUG] existing host s={} id={}", .{hostCallback.data.fd, hostCallback.data.connectionId});
        }
    }
    log("add host command for '{}'", .{addr});
    var newHost = try global.hostPool.create();
    errdefer global.hostPool.destroy(newHost);
    newHost.* = .{
        .func = invalidCallback,
        .data = .{
            .fd = -1, // set to -1 for sanity checking
            .addr = addr.*,
            .connectAttempt = 0,
            .timer = null,
            .connectionId = connectionId,
        },
    };
    try startConnect(&global.eventer, newHost);
}

// copy memory from 
src to dst, moving from low addresses to higher 315 | fn memcpyLowToHigh(dst: [*]u8, src: [*]const u8, len: usize) void { 316 | var i : usize = 0; 317 | while (i < len) : (i += 1) { 318 | dst[i] = src[i]; 319 | } 320 | } 321 | -------------------------------------------------------------------------------- /common.zig: -------------------------------------------------------------------------------- 1 | const builtin = @import("builtin"); 2 | const std = @import("std"); 3 | const mem = std.mem; 4 | const os = std.os; 5 | 6 | const logging = @import("./logging.zig"); 7 | const timing = @import("./timing.zig"); 8 | 9 | const panic = std.debug.panic; 10 | const log = logging.log; 11 | const fd_t = os.fd_t; 12 | const socket_t = os.socket_t; 13 | const Address = std.net.Address; 14 | 15 | // TODO: this should go somewhere else (i.e. std.algorithm in D) 16 | pub fn skipOver(comptime T: type, haystack: *T, needle: []const u8) bool { 17 | if (mem.startsWith(u8, haystack.*, needle)) { 18 | haystack.* = haystack.*[needle.len..]; 19 | return true; 20 | } 21 | return false; 22 | } 23 | 24 | pub fn delaySeconds(seconds: u32, msg: []const u8) void { 25 | log("waiting {} seconds {s}", .{seconds, msg}); 26 | std.time.sleep(@intCast(u64, seconds) * std.time.ns_per_s); 27 | } 28 | 29 | pub fn makeListenSock(listenAddr: *Address) !socket_t { 30 | var flags : u32 = os.SOCK.STREAM; 31 | if (builtin.os.tag != .windows) { 32 | flags = flags | os.SOCK.NONBLOCK; 33 | } 34 | const sockfd = try os.socket(listenAddr.any.family, flags, os.IPPROTO.TCP); 35 | errdefer os.close(sockfd); 36 | if (builtin.os.tag != .windows) { 37 | try os.setsockopt(sockfd, os.SOL.SOCKET, os.SO.REUSEADDR, &mem.toBytes(@as(c_int, 1))); 38 | } 39 | os.bind(sockfd, &listenAddr.any, listenAddr.getOsSockLen()) catch |e| { 40 | std.debug.warn("bind to address '{}' failed: {}\n", .{listenAddr, e}); 41 | return error.AlreadyReported; 42 | }; 43 | os.listen(sockfd, 8) catch |e| { 44 | std.debug.warn("listen failed: {}\n", 
.{e});
        return error.AlreadyReported;
    };
    return sockfd;
}

/// Return the pending error on the socket (SO_ERROR), or 0 if none.
pub fn getsockerror(sockfd: socket_t) !c_int {
    var errorCode : c_int = undefined;
    var resultLen : os.socklen_t = @sizeOf(c_int);
    // BUGFIX/consistency: os.errno returns the E enum, so the prongs must be
    // .SUCCESS/.BADF/etc (not 0/.EBADF/...), matching the errno switch in
    // shutdown below
    switch (os.errno(os.linux.getsockopt(sockfd, os.SOL.SOCKET, os.SO.ERROR, @ptrCast([*]u8, &errorCode), &resultLen))) {
        .SUCCESS => return errorCode,
        .BADF => unreachable,
        .FAULT => unreachable,
        .INVAL => unreachable,
        .NOPROTOOPT => unreachable,
        .NOTSOCK => unreachable,
        else => |err| return os.unexpectedErrno(err),
    }
}

/// Connect `sockfd` to `addr` (thin wrapper over os.connect).
pub fn connect(sockfd: socket_t, addr: *const Address) os.ConnectError!void {
    return os.connect(sockfd, &addr.any, addr.getOsSockLen());
}
/// Create a TCP socket and connect it to host:port.
/// Caller owns (and must close) the returned socket.
pub fn connectHost(host: []const u8, port: u16) !socket_t {
    // so far only ipv4 addresses supported
    if (Address.parseIp(host, port)) |addr| {
        const sockfd = try os.socket(addr.any.family, os.SOCK.STREAM, os.IPPROTO.TCP);
        errdefer os.close(sockfd);
        try os.connect(sockfd, &addr.any, addr.getOsSockLen());
        return sockfd;
    } else |_| {
        // TODO: implement DNS
        return error.DnsNotSupported;
    }
}

const extern_windows = struct {
    pub extern "ws2_32" fn shutdown(
        s: socket_t,
        how: c_int
    ) callconv(os.windows.WINAPI) c_int;
    pub const SD_BOTH = 2;
};

// TODO: move to standard library
pub const ShutdownError = error{
    ConnectionAborted,

    /// Connection was reset by peer, application should close socket as it is no longer usable.
    ConnectionResetByPeer,

    BlockingOperationInProgress,

    /// Shutdown was passed an invalid "how" argument
    InvalidShutdownHow,

    /// The network subsystem has failed.
    NetworkSubsystemFailed,

    /// The socket is not connected (connection-oriented sockets only).
104 | SocketNotConnected, 105 | 106 | /// The file descriptor sockfd does not refer to a socket. 107 | FileDescriptorNotASocket, 108 | 109 | SystemResources 110 | } || std.os.UnexpectedError; 111 | 112 | pub fn shutdown(sockfd: socket_t) ShutdownError!void { 113 | if (builtin.os.tag == .windows) { 114 | const result = extern_windows.shutdown(sockfd, extern_windows.SD_BOTH); 115 | if (0 != result) switch (std.os.windows.ws2_32.WSAGetLastError()) { 116 | .WSAECONNABORTED => return error.ConnectionAborted, 117 | .WSAECONNRESET => return error.ConnectionResetByPeer, 118 | .WSAEINPROGRESS => return error.BlockingOperationInProgress, 119 | .WSAEINVAL => return error.InvalidShutdownHow, 120 | .WSAENETDOWN => return error.NetworkSubsystemFailed, 121 | .WSAENOTCONN => return error.SocketNotConnected, 122 | .WSAENOTSOCK => return error.FileDescriptorNotASocket, 123 | .WSANOTINITIALISED => unreachable, 124 | else => |err| return std.os.windows.unexpectedWSAError(err), 125 | }; 126 | } else switch (os.errno(os.linux.shutdown(sockfd, os.SHUT.RDWR))) { 127 | .SUCCESS => return, 128 | .BADF => unreachable, 129 | .INVAL => return error.InvalidShutdownHow, 130 | .NOTCONN => return error.SocketNotConnected, 131 | .NOTSOCK => return error.FileDescriptorNotASocket, 132 | .NOBUFS => return error.SystemResources, 133 | else => |err| return os.unexpectedErrno(err), 134 | } 135 | } 136 | 137 | pub fn shutdownclose(sockfd: socket_t) void { 138 | shutdown(sockfd) catch { }; // ignore error 139 | os.close(sockfd); 140 | } 141 | 142 | // workaround https://github.com/ziglang/zig/issues/9971 143 | fn sendWorkaround(sockfd: socket_t, buf: []const u8, flags: u32) os.SendError!usize { 144 | if (builtin.os.tag == .windows) { 145 | const rc = os.windows.ws2_32.send(sockfd, buf.ptr, @intCast(i32, buf.len), flags); 146 | if (rc != os.windows.ws2_32.SOCKET_ERROR) 147 | return @intCast(usize, rc); 148 | 149 | switch (os.windows.ws2_32.WSAGetLastError()) { 150 | .WSAEACCES => return error.AccessDenied, 
151 | .WSAECONNRESET => return error.ConnectionResetByPeer, 152 | .WSAEMSGSIZE => return error.MessageTooBig, 153 | .WSAENOBUFS => return error.SystemResources, 154 | .WSAENOTSOCK => return error.FileDescriptorNotASocket, 155 | .WSAEFAULT => unreachable, // The lpBuffers, lpTo, lpOverlapped, lpNumberOfBytesSent, or lpCompletionRoutine parameters are not part of the user address space, or the lpTo parameter is too small. 156 | .WSAEHOSTUNREACH => unreachable, 157 | // TODO: WSAEINPROGRESS, WSAEINTR 158 | .WSAEINVAL => unreachable, 159 | .WSAENETDOWN => return error.NetworkSubsystemFailed, 160 | .WSAENETRESET => return error.ConnectionResetByPeer, 161 | .WSAENOTCONN => unreachable, 162 | .WSAESHUTDOWN => unreachable, // The socket has been shut down; it is not possible to WSASendTo on a socket after shutdown has been invoked with how set to SD_SEND or SD_BOTH. 163 | .WSAEWOULDBLOCK => return error.WouldBlock, 164 | .WSANOTINITIALISED => unreachable, // A successful WSAStartup call must occur before using this function. 
165 | else => |err| return os.windows.unexpectedWSAError(err), 166 | } 167 | } 168 | return os.send(sockfd, buf, flags); 169 | } 170 | 171 | pub fn sendfull(sockfd: socket_t, buf: []const u8, flags: u32) !void { 172 | var totalSent : usize = 0; 173 | while (totalSent < buf.len) { 174 | const lastSent = try sendWorkaround(sockfd, buf[totalSent..], flags); 175 | if (lastSent == 0) 176 | return error.SendReturnedZero; 177 | totalSent += lastSent; 178 | } 179 | } 180 | 181 | const WriteAllError = error { FdClosed } || std.os.WriteError; 182 | const WriteAllErrorResult = struct { 183 | err: WriteAllError, 184 | wrote: usize, 185 | }; 186 | pub fn tryWriteAll(fd: fd_t, buf: []const u8) ?WriteAllErrorResult { 187 | var total_wrote : usize = 0; 188 | while (total_wrote < buf.len) { 189 | const last_wrote = os.write(fd, buf[total_wrote..]) catch |e| 190 | return WriteAllErrorResult { .err = e, .wrote = total_wrote }; 191 | if (last_wrote == 0) 192 | return WriteAllErrorResult { .err = error.FdClosed, .wrote = total_wrote }; 193 | total_wrote += last_wrote; 194 | } 195 | return null; 196 | } 197 | 198 | fn waitGenericTimeout(fd: fd_t, timeoutMillis: i32, events: i16) !bool { 199 | var pollfds = [1]os.linux.pollfd { 200 | os.linux.pollfd { .fd = fd, .events = events, .revents = undefined }, 201 | }; 202 | const result = os.poll(&pollfds, timeoutMillis) catch |e| switch (e) { 203 | error.SystemResources 204 | ,error.NetworkSubsystemFailed 205 | => { 206 | log("poll function failed with {}", .{e}); 207 | return error.Retry; 208 | }, 209 | error.Unexpected 210 | => panic("poll function failed with {}", .{e}), 211 | }; 212 | if (result == 0) return false; // timeout 213 | if (result == 1) return true; // socket is readable 214 | panic("poll function with only 1 fd returned {}", .{result}); 215 | } 216 | 217 | // returns: true if readable, false on timeout 218 | pub fn waitReadableTimeout(fd: fd_t, timeoutMillis: i32) !bool { 219 | return waitGenericTimeout(fd, timeoutMillis, 
os.POLL.IN); 220 | } 221 | pub fn waitReadable(fd: fd_t) !void { 222 | if (!try waitReadableTimeout(fd, -1)) 223 | panic("poll function with infinite timeout returned 0", .{}); 224 | } 225 | 226 | pub fn waitWriteableTimeout(fd: fd_t, timeoutMillis: i32) !bool { 227 | return waitGenericTimeout(fd, timeoutMillis, os.POLL.OUT); 228 | } 229 | 230 | pub fn recvfullTimeout(sockfd: socket_t, buf: []u8, timeoutMillis: u32) !bool { 231 | var newTimeoutMillis = timeoutMillis; 232 | var totalReceived : usize = 0; 233 | while (newTimeoutMillis > @intCast(u32, std.math.maxInt(i32))) { 234 | const received = try recvfullTimeoutHelper(sockfd, buf[totalReceived..], std.math.maxInt(i32)); 235 | totalReceived += received; 236 | if (totalReceived == buf.len) return true; 237 | newTimeoutMillis -= std.math.maxInt(i32); 238 | } 239 | totalReceived += try recvfullTimeoutHelper(sockfd, buf[totalReceived..], @intCast(i32, newTimeoutMillis)); 240 | return totalReceived == buf.len; 241 | } 242 | fn recvfullTimeoutHelper(sockfd: socket_t, buf: []u8, timeoutMillis: i32) !usize { 243 | std.debug.assert(timeoutMillis >= 0); // code bug otherwise 244 | var totalReceived : usize = 0; 245 | if (buf.len > 0) { 246 | const startTime = std.time.milliTimestamp(); 247 | while (true) { 248 | const readable = try waitReadableTimeout(sockfd, timeoutMillis); 249 | if (!readable) break; 250 | const result = try os.read(sockfd, buf[totalReceived..]); 251 | if (result <= 0) break; 252 | totalReceived += result; 253 | if (totalReceived == buf.len) break; 254 | const elapsed = timing.timestampDiff(std.time.milliTimestamp(), startTime); 255 | if (elapsed > timeoutMillis) break; 256 | } 257 | return totalReceived; 258 | } 259 | return totalReceived; 260 | } 261 | 262 | pub fn getOptArg(args: anytype, i: *usize) !@TypeOf(args[0]) { 263 | i.* += 1; 264 | if (i.* >= args.len) { 265 | std.debug.warn("Error: option '{s}' requires an argument\n", .{args[i.* - 1]}); 266 | return error.CommandLineOptionMissingArgument; 
267 | } 268 | return args[i.*]; 269 | } 270 | 271 | /// logs an error if it fails 272 | pub fn parsePort(s: []const u8) !u16 { 273 | return std.fmt.parseInt(u16, s, 10) catch |e| { 274 | log("Error: failed to parse '{s}' as a port: {}", .{s, e}); 275 | return error.InvalidPortString; 276 | }; 277 | } 278 | /// logs an error if it fails 279 | pub fn parseIp4(s: []const u8, port: u16) !Address { 280 | return Address.parseIp4(s, port) catch |e| { 281 | log("Error: failed to parse '{s}' as an IPv4 address: {}", .{s, e}); 282 | return e; 283 | }; 284 | } 285 | 286 | pub fn eventerAdd(comptime Eventer: type, eventer: *Eventer, fd: Eventer.Fd, flags: u32, callback: *Eventer.Callback) !void { 287 | eventer.add(fd, flags, callback) catch |e| switch (e) { 288 | error.SystemResources 289 | ,error.UserResourceLimitReached 290 | => { 291 | log("epoll add error {}", .{e}); 292 | return error.Retry; 293 | }, 294 | error.FileDescriptorAlreadyPresentInSet 295 | ,error.OperationCausesCircularLoop 296 | ,error.FileDescriptorNotRegistered 297 | ,error.FileDescriptorIncompatibleWithEpoll 298 | ,error.Unexpected 299 | => panic("epoll add failed with {}", .{e}), 300 | }; 301 | } 302 | 303 | pub fn eventerInit(comptime Eventer: type, data: Eventer.Data) !Eventer { 304 | return Eventer.init(data) catch |e| switch (e) { 305 | error.ProcessFdQuotaExceeded 306 | ,error.SystemFdQuotaExceeded 307 | ,error.SystemResources 308 | => { 309 | log("epoll_create failed with {}", .{e}); 310 | return error.Retry; 311 | }, 312 | error.Unexpected 313 | => std.debug.panic("epoll_create failed with {}", .{e}), 314 | }; 315 | } 316 | -------------------------------------------------------------------------------- /punch/util.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const os = std.os; 3 | 4 | const logging = @import("../logging.zig"); 5 | const common = @import("../common.zig"); 6 | const timing = @import("../timing.zig"); 7 | 
const punch = @import("../punch.zig");
const proto = punch.proto;
const netext = @import("../netext.zig");

const assert = std.debug.assert;
const fd_t = os.fd_t;
const log = logging.log;
const Timestamp = timing.Timestamp;
const Timer = timing.Timer;

/// Exchange and validate the punch handshake on a freshly connected socket.
/// Sends our role's handshake, then expects the peer to identify itself with
/// the opposite role within `recvTimeoutMillis`.
/// Errors: error.PunchSocketDisconnect on send/recv failure,
/// error.BadPunchHandshake on timeout, bad magic, or wrong role.
pub fn doHandshake(punchFd: fd_t, myRole: proto.Role, recvTimeoutMillis: u32) !void {
    // the peer must claim the role opposite to ours
    const expectRole : proto.Role = switch (myRole) {
        .initiator => .forwarder,
        .forwarder => .initiator,
    };
    const handshakeToSend : []const u8 = switch (myRole) {
        .initiator => &proto.initiatorHandshake,
        .forwarder => &proto.forwarderHandshake,
    };

    netext.send(punchFd, handshakeToSend, 0) catch |e| switch (e) {
        error.Disconnected,error.Retry => {
            log("failed to send punch handshake", .{});
            return error.PunchSocketDisconnect;
        },
    };

    // peer handshake is the magic bytes followed by a single role byte
    var handshake: [proto.magic.len + 1]u8 = undefined;
    const gotHandshake = netext.recvfullTimeout(punchFd, &handshake, recvTimeoutMillis) catch |e| switch (e) {
        error.Disconnected,error.Retry => {
            log("failed to receive punch handshake", .{});
            return error.PunchSocketDisconnect;
        },
    };
    if (!gotHandshake) {
        log("timed out waiting for punch handshake", .{});
        return error.BadPunchHandshake;
    }
    const magic = handshake[0..proto.magic.len];
    if (!std.mem.eql(u8, magic, &proto.magic)) {
        log("got punch connection but received invalid magic value {}", .{std.fmt.fmtSliceHexLower(magic)});
        return error.BadPunchHandshake;
    }
    const role = handshake[proto.magic.len];
    if (role != @enumToInt(expectRole)) {
        log("received punch role {} but expected {} ({})", .{role, expectRole, @enumToInt(expectRole)});
        return error.BadPunchHandshake;
    }
}

pub fn serviceHeartbeat(punchFd: fd_t, heartbeatTimer: 
*Timer, verboseHeartbeats: bool) !u32 { 62 | switch (heartbeatTimer.check()) { 63 | .Expired => { 64 | if (verboseHeartbeats) { 65 | log("[VERBOSE] sending heartbeat...", .{}); 66 | } 67 | punch.util.sendHeartbeat(punchFd) catch |e| switch (e) { 68 | error.Disconnected, error.Retry => return error.PunchSocketDisconnect, 69 | }; 70 | return heartbeatTimer.durationMillis; 71 | }, 72 | .Wait => |millis| return millis, 73 | } 74 | } 75 | 76 | pub fn sendHeartbeat(punchFd: fd_t) !void { 77 | const msg = [1]u8 {proto.TwoWayMessage.Heartbeat}; 78 | try netext.send(punchFd, &msg, 0); 79 | } 80 | pub fn sendCloseTunnel(punchFd: fd_t) !void { 81 | const msg = [1]u8 {proto.TwoWayMessage.CloseTunnel}; 82 | try netext.send(punchFd, &msg, 0); 83 | } 84 | pub fn sendOpenTunnel(punchFd: fd_t) !void { 85 | const msg = [1]u8 {proto.InitiatorMessage.OpenTunnel}; 86 | try netext.send(punchFd, &msg, 0); 87 | } 88 | 89 | pub fn waitForCloseTunnel(punchFd: fd_t, punchRecvState: *PunchRecvState, buffer: []u8, timeoutMillis: i32) !void { 90 | var failedAttempts : u16 = 0; 91 | const maxFailedAttempts = 5; 92 | while (true) { 93 | if (failedAttempts >= maxFailedAttempts) { 94 | log("failed to read from punch socket after {} attempts", .{failedAttempts}); 95 | return error.PunchSocketDisconnect; 96 | } 97 | const isReadable = common.waitReadableTimeout(punchFd, timeoutMillis) catch |e| switch (e) { 98 | error.Retry => { 99 | failedAttempts += 1; 100 | continue; 101 | }, 102 | }; 103 | if (!isReadable) 104 | return error.PunchSocketDisconnect; 105 | 106 | const len = netext.read(punchFd, buffer) catch |e| switch (e) { 107 | error.Retry => { 108 | failedAttempts += 1; 109 | continue; 110 | }, 111 | error.Disconnected => return error.PunchSocketDisconnect, 112 | }; 113 | if (len == 0) { 114 | log("punch socket disconnected (read returned 0)", .{}); 115 | return error.PunchSocketDisconnect; 116 | } 117 | var data = buffer[0..len]; 118 | while (data.len > 0) { 119 | const action = 
punch.util.parsePunchToNextAction(punchRecvState, &data) catch |e| switch (e) { 120 | error.InvalidPunchMessage => { 121 | log("received unexpected punch message {}", .{data[0]}); 122 | return error.PunchSocketDisconnect; 123 | }, 124 | }; 125 | switch (action) { 126 | .None => { 127 | std.debug.assert(data.len == 0); 128 | break; 129 | }, 130 | .OpenTunnel => { 131 | log("WARNING: received OpenTunnel message when a tunnel is already open", .{}); 132 | return error.PunchSocketDisconnect; 133 | }, 134 | .CloseTunnel => { 135 | log("received CloseTunnel message", .{}); 136 | return; 137 | }, 138 | .ForwardData => |forwardAction| { 139 | log("ignore {} bytes of forwarding data", .{forwardAction.data.len}); 140 | }, 141 | } 142 | } 143 | } 144 | } 145 | pub fn closeTunnel(punchFd: fd_t, punchRecvState: *PunchRecvState, gotCloseTunnel: *bool, buffer: []u8) !void { 146 | log("sending CloseTunnel...", .{}); 147 | punch.util.sendCloseTunnel(punchFd) catch |e| switch (e) { 148 | error.Disconnected, error.Retry => return error.PunchSocketDisconnect, 149 | }; 150 | if (!gotCloseTunnel.*) { 151 | punch.util.waitForCloseTunnel(punchFd, punchRecvState, buffer, 8000) catch |e| switch (e) { 152 | error.PunchSocketDisconnect => return error.PunchSocketDisconnect, 153 | }; 154 | gotCloseTunnel.* = true; 155 | } 156 | } 157 | 158 | pub fn forwardRawToPunch(rawFd: fd_t, punchFd: fd_t, buffer: []u8) !void { 159 | // NOTE: can't use the sendfile syscall because I need to 160 | // convert raw data to the punch-data packets 161 | // I could make this work if I opened a new raw connection instead 162 | // of sending data through the punch protocol 163 | // But the punch protocol isn't meant for alot of data, it's meant 164 | // to facilate things like manual SSH session to start other sessions. 
165 | 166 | // receive at an offset to save room for the punch data command prefix 167 | // offset 64 to give the best chance for aligned data copy 168 | const length = netext.read(rawFd, buffer[64..]) catch |e| switch (e) { 169 | error.Disconnected, error.Retry => { 170 | log("s={} read on raw socket failed", .{rawFd}); 171 | common.shutdown(rawFd) catch {}; 172 | return error.RawSocketDisconnect; 173 | }, 174 | }; 175 | if (length == 0) { 176 | log("s={} raw socket disconnected (read returned 0)", .{rawFd}); 177 | return error.RawSocketDisconnect; 178 | } 179 | 180 | buffer[55] = proto.TwoWayMessage.Data; 181 | // std.mem.writeIntSliceBig doesn't seem to be working 182 | //std.mem.writeIntSliceBig(u64, buffer[56..], length); 183 | writeU64Big(buffer[56..].ptr, length); 184 | log("[VERBOSE] fowarding {} bytes to punch socket...", .{length}); 185 | netext.send(punchFd, buffer[55.. 64 + length], 0) catch |e| switch (e) { 186 | error.Disconnected, error.Retry => { 187 | log("s={} send failed on punch socket", .{punchFd}); 188 | return error.PunchSocketDisconnect; 189 | }, 190 | }; 191 | } 192 | 193 | fn writeU64Big(buf: [*]u8, value: u64) void { 194 | buf[0] = @truncate(u8, value >> 56); 195 | buf[1] = @truncate(u8, value >> 48); 196 | buf[2] = @truncate(u8, value >> 40); 197 | buf[3] = @truncate(u8, value >> 32); 198 | buf[4] = @truncate(u8, value >> 24); 199 | buf[5] = @truncate(u8, value >> 16); 200 | buf[6] = @truncate(u8, value >> 8); 201 | buf[7] = @truncate(u8, value >> 0); 202 | } 203 | 204 | pub const PunchRecvState = union(enum) { 205 | Initial: void, 206 | Data: Data, 207 | 208 | pub const Data = struct { 209 | lenBytesLeft: u8, 210 | dataLeft: u64, 211 | }; 212 | }; 213 | // tells the caller what to do 214 | const PunchAction = union(enum) { 215 | None: void, 216 | OpenTunnel: void, 217 | CloseTunnel: void, 218 | ForwardData: ForwardData, 219 | 220 | pub const ForwardData = struct { 221 | data: []const u8, 222 | }; 223 | }; 224 | 225 | pub fn 
parsePunchToNextAction(state: *PunchRecvState, data: *[]const u8) !PunchAction { 226 | assert(data.*.len > 0); 227 | //std.debug.warn("[DEBUG] parsing {}-bytes...\n", .{data.*.len}); 228 | while (true) { 229 | switch (try parsePunchMessage(state, data)) { 230 | .None => if (data.*.len == 0) return PunchAction.None, 231 | else => |action| return action, 232 | } 233 | } 234 | } 235 | 236 | fn parsePunchMessage(state: *PunchRecvState, data: *[]const u8) !PunchAction { 237 | switch (state.*) { 238 | .Initial => { 239 | const msgType = data.*[0]; 240 | data.* = data.*[1..]; 241 | if (msgType == proto.TwoWayMessage.Heartbeat) 242 | return PunchAction.None; 243 | if (msgType == proto.TwoWayMessage.CloseTunnel) 244 | return PunchAction.CloseTunnel; 245 | if (msgType == proto.InitiatorMessage.OpenTunnel) 246 | return PunchAction.OpenTunnel; 247 | if (msgType == proto.TwoWayMessage.Data) { 248 | state.* = PunchRecvState { .Data = .{ 249 | .lenBytesLeft = 8, 250 | .dataLeft = 0, 251 | }}; 252 | return try parsePunchDataMessage(state, data); 253 | } 254 | // rewind data so caller can see the invalid byte 255 | data.* = (data.ptr - 1)[0 .. 
data.*.len + 1]; 256 | return error.InvalidPunchMessage; 257 | }, 258 | .Data => return try parsePunchDataMessage(state, data), 259 | } 260 | } 261 | 262 | fn parsePunchDataMessage(state: *PunchRecvState, data: *[]const u8) !PunchAction { 263 | switch (state.*) { .Data=>{}, else => assert(false), } 264 | 265 | while (state.Data.lenBytesLeft > 0) : (state.Data.lenBytesLeft -= 1) { 266 | if (data.*.len == 0) 267 | return PunchAction.None; 268 | 269 | state.Data.dataLeft <<= 8; 270 | state.Data.dataLeft |= data.*[0]; 271 | data.* = data.*[1..]; 272 | } 273 | if (state.Data.dataLeft == 0) { 274 | state.* = PunchRecvState.Initial; 275 | return PunchAction.None; 276 | } 277 | if (data.*.len == 0) 278 | return PunchAction.None; 279 | if (data.*.len >= state.Data.dataLeft) { 280 | var forwardData = data.*[0..state.Data.dataLeft]; 281 | data.* = data.*[state.Data.dataLeft..]; 282 | state.* = PunchRecvState.Initial; 283 | return PunchAction { .ForwardData = .{ .data = forwardData } }; 284 | } 285 | state.Data.dataLeft -= data.*.len; 286 | var forwardData = data.*; 287 | data.* = data.*[data.*.len..]; 288 | return PunchAction { .ForwardData = .{ .data = forwardData } }; 289 | } 290 | 291 | 292 | const ParserTest = struct { 293 | data: []const u8, 294 | actions: []const PunchAction, 295 | }; 296 | 297 | fn testParser(t: *const ParserTest, chunkLen: usize) !void { 298 | var expectedActionIndex : usize = 0; 299 | var expectedForwardDataOffset : usize = 0; 300 | var state : PunchRecvState = PunchRecvState.Initial; 301 | var data = t.data; 302 | while (data.len > 0) { 303 | const nextLen = if (data.len < chunkLen) 304 | data.len else chunkLen; 305 | var nextChunk = data[0..nextLen]; 306 | const action = try parsePunchToNextAction(&state, &nextChunk); 307 | std.debug.warn("action {}\n", .{action}); 308 | switch (action) { 309 | .None => std.debug.assert(nextChunk.len == 0), 310 | .OpenTunnel => { 311 | switch (t.actions[expectedActionIndex]) { 312 | .OpenTunnel => {}, 313 | else => 
std.debug.assert(false), 314 | } 315 | expectedActionIndex += 1; 316 | }, 317 | .CloseTunnel => { 318 | switch (t.actions[expectedActionIndex]) { 319 | .CloseTunnel => {}, 320 | else => std.debug.assert(false), 321 | } 322 | expectedActionIndex += 1; 323 | }, 324 | .ForwardData => |actualForward| { 325 | std.debug.assert(actualForward.data.len > 0); 326 | switch (t.actions[expectedActionIndex]) { 327 | .ForwardData => |expectedForward| { 328 | const expected = expectedForward.data[expectedForwardDataOffset..]; 329 | //std.debug.warn("[DEBUG] verifying {} bytes {x}\n", .{actualForward.data.len, std.fmt.fmtSliceHexLower(actualForward.data)}); 330 | std.debug.assert(std.mem.startsWith(u8, expected, actualForward.data)); 331 | expectedForwardDataOffset += actualForward.data.len; 332 | if (expectedForwardDataOffset == expectedForward.data.len) { 333 | expectedActionIndex += 1; 334 | expectedForwardDataOffset = 0; 335 | } 336 | }, 337 | else => std.debug.assert(false), 338 | } 339 | }, 340 | } 341 | std.debug.assert(nextLen > nextChunk.len); 342 | data = data[nextLen - nextChunk.len..]; 343 | } 344 | std.debug.assert(expectedActionIndex == t.actions.len); 345 | } 346 | 347 | test "parsePunchMessage" { 348 | const tests = [_]ParserTest { 349 | ParserTest { 350 | .data = &[_]u8 { 351 | proto.TwoWayMessage.Heartbeat, 352 | proto.InitiatorMessage.OpenTunnel, 353 | proto.TwoWayMessage.Heartbeat, 354 | proto.TwoWayMessage.Heartbeat, 355 | proto.TwoWayMessage.CloseTunnel, 356 | proto.TwoWayMessage.Heartbeat, 357 | proto.TwoWayMessage.Heartbeat, 358 | }, 359 | .actions = &[_]PunchAction { 360 | PunchAction.OpenTunnel, 361 | PunchAction.CloseTunnel, 362 | }, 363 | }, 364 | ParserTest { 365 | .data = &[_]u8 { 366 | proto.TwoWayMessage.Heartbeat, 367 | proto.TwoWayMessage.Heartbeat, 368 | proto.TwoWayMessage.Data, 369 | 0,0,0,0,0,0,0,0, 370 | proto.TwoWayMessage.Heartbeat, 371 | proto.TwoWayMessage.Data, 372 | 0,0,0,0,0,0,0,1, 373 | 0xac, 374 | proto.TwoWayMessage.Data, 375 | 
0,0,0,0,0,0,0,0, 376 | proto.TwoWayMessage.Heartbeat, 377 | proto.TwoWayMessage.Data, 378 | 0,0,0,0,0,0,0,10, 379 | 0x12,0x34,0x45,0x67,0x89,0xab,0xcd,0xef,0x0a,0xf4, 380 | }, 381 | .actions = &[_]PunchAction { 382 | PunchAction { .ForwardData = .{ .data = &[_]u8{0xac} }}, 383 | PunchAction { .ForwardData = .{ .data = &[_]u8{0x12,0x34,0x45,0x67,0x89,0xab,0xcd,0xef,0x0a,0xf4} }}, 384 | }, 385 | }, 386 | }; 387 | for (tests) |t| { 388 | var i : usize = 1; 389 | while (i <= t.data.len) : (i += 1) { 390 | try testParser(&t, i); 391 | } 392 | } 393 | 394 | { 395 | var state : PunchRecvState = PunchRecvState.Initial; 396 | { 397 | var data : []const u8 = &[_]u8 {proto.TwoWayMessage.Heartbeat}; 398 | switch (try parsePunchMessage(&state, &data)) { 399 | .None => {}, 400 | else => assert(false), 401 | } 402 | assert(data.len == 0); 403 | } 404 | { 405 | var data : []const u8 = &[_]u8 {proto.TwoWayMessage.CloseTunnel}; 406 | switch (try parsePunchMessage(&state, &data)) { 407 | .CloseTunnel => {}, 408 | else => assert(false), 409 | } 410 | assert(data.len == 0); 411 | } 412 | { 413 | var data : []const u8 = &[_]u8 {proto.InitiatorMessage.OpenTunnel}; 414 | switch (try parsePunchMessage(&state, &data)) { 415 | .OpenTunnel => {}, 416 | else => assert(false), 417 | } 418 | assert(data.len == 0); 419 | } 420 | blk: { 421 | var data : []const u8 = &[_]u8 {10}; 422 | _ = parsePunchMessage(&state, &data) catch |e| { 423 | assert(e == error.InvalidPunchMessage); 424 | break :blk; 425 | }; 426 | assert(false); 427 | } 428 | { 429 | var data : []const u8 = &[_]u8 {proto.TwoWayMessage.Data,0,0,0,0,0,0,0,0}; 430 | switch (try parsePunchMessage(&state, &data)) { 431 | .None => {}, 432 | else => assert(false), 433 | } 434 | assert(data.len == 0); 435 | } 436 | { 437 | var data : []const u8 = &[_]u8 {proto.TwoWayMessage.Data,0,0,0,0,0,0,0,1,0xa3}; 438 | switch (try parsePunchMessage(&state, &data)) { 439 | .ForwardData => |forwardData| assert(std.mem.eql(u8, &[_]u8 {0xa3}, 
forwardData.data)), 440 | else => assert(false), 441 | } 442 | assert(data.len == 0); 443 | } 444 | } 445 | } 446 | -------------------------------------------------------------------------------- /punch-server-initiator.zig: -------------------------------------------------------------------------------- 1 | // 2 | // TODO: look at the delaySeconds where we retry accepting client 3 | // I should come up with a way to detect when the event loop 4 | // just starts churning with failed accept calls rather than 5 | // just delaying a second each time we get one 6 | // 7 | const std = @import("std"); 8 | const mem = std.mem; 9 | const os = std.os; 10 | const net = std.net; 11 | 12 | const logging = @import("./logging.zig"); 13 | const common = @import("./common.zig"); 14 | const netext = @import("./netext.zig"); 15 | const timing = @import("./timing.zig"); 16 | const eventing = @import("./eventing.zig").default; 17 | const punch = @import("./punch.zig"); 18 | 19 | const fd_t = os.fd_t; 20 | const Address = net.Address; 21 | const log = logging.log; 22 | const delaySeconds = common.delaySeconds; 23 | const Timer = timing.Timer; 24 | const EventFlags = eventing.EventFlags; 25 | const PunchRecvState = punch.util.PunchRecvState; 26 | 27 | const AcceptRawEventer = eventing.EventerTemplate(.{ 28 | .Data = struct { 29 | punchRecvState: *PunchRecvState, 30 | acceptedRawClient: fd_t, 31 | }, 32 | .CallbackError = error { PunchSocketDisconnect }, 33 | .CallbackData = struct { 34 | fd: fd_t, 35 | }, 36 | }); 37 | 38 | const ForwardingEventer = eventing.EventerTemplate(.{ 39 | .Data = struct { 40 | punchRecvState: *PunchRecvState, 41 | punchFd: fd_t, 42 | rawFd: fd_t, 43 | gotCloseTunnel: *bool, 44 | }, 45 | .CallbackError = error {PunchSocketDisconnect, RawSocketDisconnect}, 46 | .CallbackData = struct { 47 | fd: fd_t, 48 | }, 49 | }); 50 | 51 | const global = struct { 52 | var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); 53 | var rawListenAddr : Address = 
undefined; 54 | var listenFd : fd_t = undefined; 55 | var buffer : [8192]u8 = undefined; 56 | }; 57 | 58 | fn makeThrottler(logPrefix: []const u8) timing.Throttler { 59 | return (timing.makeThrottler { 60 | .logPrefix = logPrefix, 61 | .desiredSleepMillis = 15000, 62 | .slowRateMillis = 500, 63 | }).create(); 64 | } 65 | 66 | fn usage() void { 67 | log("Usage: punch-server-initiator PUNCH_LISTEN_ADDR PUNCH_PORT RAW_LISTEN_ADDR RAW_PORT", .{}); 68 | } 69 | pub fn main() !u8 { 70 | var args = try std.process.argsAlloc(&global.arena.allocator); 71 | if (args.len <= 1) { 72 | usage(); 73 | return 1; 74 | } 75 | args = args[1..]; 76 | if (args.len != 4) { 77 | usage(); 78 | return 1; 79 | } 80 | const punchListenAddrString = args[0]; 81 | const punchPort = common.parsePort(args[1]) catch return 1; 82 | const rawListenAddrString = args[2]; 83 | const rawPort = common.parsePort(args[3]) catch return 1; 84 | 85 | var punchListenAddr = common.parseIp4(punchListenAddrString, punchPort) catch return 1; 86 | global.rawListenAddr = common.parseIp4(rawListenAddrString, rawPort) catch return 1; 87 | 88 | var bindThrottler = makeThrottler("bind throttler: "); 89 | while (true) { 90 | bindThrottler.throttle(); 91 | const punchListenFd = netext.makeListenSock(&punchListenAddr, 1) catch |e| switch (e) { 92 | error.Retry => continue, 93 | }; 94 | log("created punch listen socket s={}", .{punchListenFd}); 95 | defer os.close(punchListenFd); 96 | 97 | switch (sequenceAcceptPunchClient(punchListenFd)) { 98 | //error.RetryMakePunchListenSocket => continue, 99 | } 100 | } 101 | } 102 | 103 | fn sequenceAcceptPunchClient(punchListenFd: fd_t) error {} { 104 | var acceptPunchThrottler = makeThrottler("accept punch throttler: "); 105 | while (true) { 106 | acceptPunchThrottler.throttle(); 107 | 108 | log("accepting punch client...", .{}); 109 | var clientAddr : Address = undefined; 110 | var clientAddrLen : os.socklen_t = @sizeOf(@TypeOf(clientAddr)); 111 | const punchFd = 
netext.accept(punchListenFd, &clientAddr.any, &clientAddrLen, 0) catch |e| switch (e) { 112 | error.ClientDropped, error.Retry => continue, 113 | }; 114 | defer common.shutdownclose(punchFd); 115 | log("s={} accepted punch client {}", .{punchFd, clientAddr}); 116 | 117 | punch.util.doHandshake(punchFd, .initiator, 10000) catch |e| switch (e) { 118 | error.PunchSocketDisconnect 119 | ,error.BadPunchHandshake 120 | => continue, 121 | }; 122 | 123 | var heartbeatTimer = Timer.init(15000); 124 | switch (sequenceSetupEventing(punchListenFd, punchFd, &heartbeatTimer)) { 125 | error.PunchSocketDisconnect => continue, 126 | } 127 | } 128 | } 129 | 130 | fn sequenceSetupEventing(punchListenFd: fd_t, punchFd: fd_t, heartbeatTimer: *Timer) error { 131 | PunchSocketDisconnect, 132 | } { 133 | var eventingThrottler = makeThrottler("eventing throttler"); 134 | while (true) { 135 | eventingThrottler.throttle(); 136 | _ = punch.util.serviceHeartbeat(punchFd, heartbeatTimer, false) catch |e| switch (e) { 137 | error.PunchSocketDisconnect => return error.PunchSocketDisconnect, 138 | }; 139 | switch (sequenceAcceptRawClient(punchListenFd, punchFd, heartbeatTimer)) { 140 | error.EpollError 141 | ,error.CreateRawListenSocketFailed 142 | => continue, 143 | error.PunchSocketDisconnect => return error.PunchSocketDisconnect, 144 | } 145 | } 146 | } 147 | 148 | fn sequenceAcceptRawClient(punchListenFd: fd_t, punchFd: fd_t, heartbeatTimer: *Timer) error { 149 | EpollError, 150 | CreateRawListenSocketFailed, 151 | PunchSocketDisconnect, 152 | } { 153 | const epollfd = eventing.epoll_create1(0) catch |e| switch (e) { 154 | error.Retry => return error.EpollError, 155 | }; 156 | defer os.close(epollfd); 157 | 158 | const rawListenFd = netext.makeListenSock(&global.rawListenAddr, 1) catch |e| switch (e) { 159 | error.Retry => return error.CreateRawListenSocketFailed, 160 | }; 161 | defer os.close(rawListenFd); 162 | 163 | var punchRecvState : PunchRecvState = PunchRecvState.Initial; 164 | 165 | 
    var acceptRawThrottler = makeThrottler("accept raw throttler: ");
    while (true) {
        acceptRawThrottler.throttle();
        const rawFd = waitForRawClient(epollfd, punchListenFd, punchFd, heartbeatTimer, &punchRecvState, rawListenFd) catch |e| switch (e) {
            error.PunchSocketDisconnect => return error.PunchSocketDisconnect,
            error.EpollError => return error.EpollError,
        };

        // tell the forwarder side a raw client arrived so it opens the tunnel
        punch.util.sendOpenTunnel(punchFd) catch |e| switch (e) {
            error.Disconnected, error.Retry => return error.PunchSocketDisconnect,
        };
        var gotCloseTunnel = false;

        switch (sequenceForwardingLoop(epollfd, punchListenFd, punchFd, heartbeatTimer, &punchRecvState, rawListenFd, rawFd, &gotCloseTunnel)) {
            error.PunchSocketDisconnect => return error.PunchSocketDisconnect,
            error.EpollError => {
                try punch.util.closeTunnel(punchFd, &punchRecvState, &gotCloseTunnel, &global.buffer);
                return error.EpollError;
            },
            error.RawSocketDisconnect => {
                // raw client went away: close the tunnel, wait for the next one
                try punch.util.closeTunnel(punchFd, &punchRecvState, &gotCloseTunnel, &global.buffer);
                continue;
            },
        }
    }
}

/// Blocks until a raw client is accepted (recorded by the onFirstRawAccept
/// callback), servicing punch-socket heartbeats/data in the meantime.
/// Returns the accepted raw client's fd.
fn waitForRawClient(epollfd: fd_t, punchListenFd: fd_t, punchFd: fd_t, heartbeatTimer: *Timer,
    punchRecvState: *PunchRecvState, rawListenFd: fd_t) !fd_t {

    // acceptedRawClient == -1 means "no raw client accepted yet"
    var eventer = AcceptRawEventer.initEpoll(.{
        .punchRecvState = punchRecvState,
        .acceptedRawClient = -1,
    }, epollfd, false);
    defer eventer.deinit();

    var punchListenCallback = AcceptRawEventer.Callback {
        .func = onPunchAcceptAcceptRaw,
        .data = .{.fd = punchListenFd},
    };
    common.eventerAdd(AcceptRawEventer, &eventer, punchListenFd, EventFlags.read, &punchListenCallback) catch |e| switch (e) {
        error.Retry => return error.EpollError,
    };
    defer eventer.remove(punchListenFd);

    var punchCallback = AcceptRawEventer.Callback {
        .func = onPunchDataAcceptRaw,
        .data = .{.fd =
punchFd},
    };
    common.eventerAdd(AcceptRawEventer, &eventer, punchFd, EventFlags.read, &punchCallback) catch |e| switch (e) {
        error.Retry => return error.EpollError,
    };
    defer eventer.remove(punchFd);

    var rawListenCallback = AcceptRawEventer.Callback {
        .func = onFirstRawAccept,
        .data = .{.fd = rawListenFd},
    };
    common.eventerAdd(AcceptRawEventer, &eventer, rawListenFd, EventFlags.read, &rawListenCallback) catch |e| switch (e) {
        error.Retry => return error.EpollError,
    };
    defer eventer.remove(rawListenFd);

    while (true) {
        // serviceHeartbeat returns how long we may sleep before the next beat
        const sleepMillis = punch.util.serviceHeartbeat(punchFd, heartbeatTimer, false) catch |e| switch (e) {
            error.PunchSocketDisconnect => return error.PunchSocketDisconnect,
        };
        //log("[DEBUG] waiting for events (sleep {} ms)...", .{sleepMillis});
        _ = eventer.handleEvents(sleepMillis) catch |e| switch (e) {
            error.PunchSocketDisconnect => return error.PunchSocketDisconnect,
        };
        // set by onFirstRawAccept once a raw client has been accepted
        if (eventer.data.acceptedRawClient != -1)
            return eventer.data.acceptedRawClient;
    }
}

/// Forwards traffic between the accepted raw client and the punch socket
/// until either side disconnects or an epoll error forces a re-setup.
fn sequenceForwardingLoop(epollfd: fd_t, punchListenFd: fd_t, punchFd: fd_t, heartbeatTimer: *Timer,
    punchRecvState: *PunchRecvState, rawListenFd: fd_t, rawFd: fd_t, gotCloseTunnel: *bool) error {
    EpollError,
    RawSocketDisconnect,
    PunchSocketDisconnect,
} {

    var eventer = ForwardingEventer.initEpoll(.{
        .punchRecvState = punchRecvState,
        .punchFd = punchFd,
        .rawFd = rawFd,
        .gotCloseTunnel = gotCloseTunnel,
    }, epollfd, false);
    defer eventer.deinit();

    var punchListenCallback = ForwardingEventer.Callback {
        .func = onPunchAcceptForwarding,
        .data = .{.fd = punchListenFd},
    };
    common.eventerAdd(ForwardingEventer, &eventer, punchListenFd, EventFlags.read, &punchListenCallback) catch |e| switch (e) {
        error.Retry => return
error.EpollError,
    };
    defer eventer.remove(punchListenFd);

    var punchCallback = ForwardingEventer.Callback {
        .func = onPunchDataForwarding,
        .data = .{.fd = punchFd},
    };
    common.eventerAdd(ForwardingEventer, &eventer, punchFd, EventFlags.read, &punchCallback) catch |e| switch (e) {
        error.Retry => return error.EpollError,
    };
    defer eventer.remove(punchFd);

    // extra raw clients are dropped by onRawAcceptForwarding while one is active
    var rawListenCallback = ForwardingEventer.Callback {
        .func = onRawAcceptForwarding,
        .data = .{.fd = rawListenFd},
    };
    common.eventerAdd(ForwardingEventer, &eventer, rawListenFd, EventFlags.read, &rawListenCallback) catch |e| switch (e) {
        error.Retry => return error.EpollError,
    };
    defer eventer.remove(rawListenFd);

    var rawCallback = ForwardingEventer.Callback {
        .func = onRawData,
        .data = .{.fd = rawFd},
    };
    common.eventerAdd(ForwardingEventer, &eventer, rawFd, EventFlags.read, &rawCallback) catch |e| switch (e) {
        error.Retry => return error.EpollError,
    };
    defer eventer.remove(rawFd);

    while (true) {
        const sleepMillis = punch.util.serviceHeartbeat(punchFd, heartbeatTimer, false) catch |e| switch (e) {
            error.PunchSocketDisconnect => return error.PunchSocketDisconnect,
        };
        //log("[DEBUG] waiting for events (sleep {} ms)...", .{sleepMillis});
        _ = eventer.handleEvents(sleepMillis) catch |e| switch (e) {
            error.PunchSocketDisconnect => return error.PunchSocketDisconnect,
            error.RawSocketDisconnect => return error.RawSocketDisconnect,
        };
    }
}

/// While waiting for the first raw client: a second punch client is not
/// allowed, so accept it and drop it.
fn onPunchAcceptAcceptRaw(eventer: *AcceptRawEventer, callback: *AcceptRawEventer.Callback) AcceptRawEventer.CallbackError!void {
    _ = eventer;
    dropClient(callback.data.fd, true);
}
/// While forwarding: same policy, a second punch client is dropped.
fn onPunchAcceptForwarding(eventer: *ForwardingEventer, callback: *ForwardingEventer.Callback) ForwardingEventer.CallbackError!void {
    _ = eventer;
    dropClient(callback.data.fd, true);
}
/// While forwarding: only one raw client is serviced at a time, so additional
/// raw clients are accepted and immediately closed.
fn onRawAcceptForwarding(eventer: *ForwardingEventer, callback: *ForwardingEventer.Callback) ForwardingEventer.CallbackError!void {
    _ = eventer;
    dropClient(callback.data.fd, false);
}
/// Accepts a pending client on listenFd and immediately shuts it down; used to
/// reject extra punch/raw clients.  isPunch only affects the log message.
fn dropClient(listenFd: fd_t, isPunch: bool) void {
    var addr : Address = undefined;
    var addrLen : os.socklen_t = @sizeOf(Address);
    const fd = netext.accept(listenFd, &addr.any, &addrLen, 0) catch |e| switch (e) {
        error.ClientDropped => return,
        error.Retry => {
            delaySeconds(1, "before calling accept again...");
            return;
        },
    };
    const kind : []const u8 = if (isPunch) "punch" else "raw";
    log("got another {s} client s={} from {}, closing it...", .{kind, fd, addr});
    common.shutdownclose(fd);
}

/// Accepts the first raw client and records its fd in eventer.data so the
/// waitForRawClient loop can return it.
fn onFirstRawAccept(eventer: *AcceptRawEventer, callback: *AcceptRawEventer.Callback) AcceptRawEventer.CallbackError!void {
    std.debug.assert(eventer.data.acceptedRawClient == -1);

    var addr : Address = undefined;
    var addrLen : os.socklen_t = @sizeOf(Address);
    const rawFd = netext.accept(callback.data.fd, &addr.any, &addrLen, 0) catch |e| switch (e) {
        error.ClientDropped => return,
        error.Retry => {
            delaySeconds(1, "before accepting raw client again...");
            return;
        },
    };
    errdefer common.shutdownclose(rawFd);
    log("accepted raw client s={} from {}", .{rawFd, addr});
    eventer.data.acceptedRawClient = rawFd; // signals the eventer loop that we have accepted a raw client
}

/// The raw client has data: forward it to the punch socket.
fn onRawData(eventer: *ForwardingEventer, callback: *ForwardingEventer.Callback) ForwardingEventer.CallbackError!void {
    punch.util.forwardRawToPunch(callback.data.fd, eventer.data.punchFd, &global.buffer) catch |e| switch (e) {
        error.RawSocketDisconnect => return error.RawSocketDisconnect,
        error.PunchSocketDisconnect => return
error.PunchSocketDisconnect,
    };
}

/// Punch-socket data arrived while still waiting for a raw client.
fn onPunchDataAcceptRaw(eventer: *AcceptRawEventer, callback: *AcceptRawEventer.Callback) AcceptRawEventer.CallbackError!void {
    onPunchData(AcceptRawEventer, eventer, callback.data.fd) catch |e| switch (e) {
        error.PunchSocketDisconnect => return error.PunchSocketDisconnect,
    };
}
/// Punch-socket data arrived while forwarding to a raw client.
fn onPunchDataForwarding(eventer: *ForwardingEventer, callback: *ForwardingEventer.Callback) ForwardingEventer.CallbackError!void {
    onPunchData(ForwardingEventer, eventer, callback.data.fd) catch |e| switch (e) {
        error.PunchSocketDisconnect => return error.PunchSocketDisconnect,
        error.RawSocketDisconnect => return error.RawSocketDisconnect,
    };
}
/// Shared punch-socket read path for both phases.  Reads one chunk, then
/// walks it message-by-message; the comptime Eventer type selects which
/// messages are legal in the current phase.
fn onPunchData(comptime Eventer: type, eventer: *Eventer, punchFd: fd_t) !void {
    const len = netext.read(punchFd, &global.buffer) catch |e| switch (e) {
        error.Retry => {
            delaySeconds(1, "before trying to read punch socket again...");
            return;
        },
        error.Disconnected => return error.PunchSocketDisconnect,
    };
    if (len == 0) {
        log("punch socket disconnected (read returned 0)", .{});
        return error.PunchSocketDisconnect;
    }
    var data = global.buffer[0..len];
    //log("[DEBUG] received {}-bytes of punch data", .{data.len});
    while (data.len > 0) {
        const action = punch.util.parsePunchToNextAction(eventer.data.punchRecvState, &data) catch |e| switch (e) {
            error.InvalidPunchMessage => {
                log("received unexpected punch message {}", .{data[0]});
                return error.PunchSocketDisconnect;
            },
        };
        //log("[DEBUG] action {} data.len {}", .{action, data.len});
        switch (action) {
            .None => {
                // parser consumed everything without a complete message
                std.debug.assert(data.len == 0);
                break;
            },
            .OpenTunnel => {
                // only this side (the initiator) is allowed to open tunnels
                log("WARNING: received OpenTunnel message from the forwarder", .{});
                return error.PunchSocketDisconnect;
            },
            .CloseTunnel => {
                if
(Eventer == AcceptRawEventer) {
                    log("WARNING: got CloseTunnel message but the tunnel is not open", .{});
                    return error.PunchSocketDisconnect;
                } else {
                    log("received CloseTunnel message", .{});
                    eventer.data.gotCloseTunnel.* = true;
                    return error.RawSocketDisconnect;
                }
            },
            .ForwardData => |forwardAction| {
                if (Eventer == AcceptRawEventer) {
                    log("WARNING: got ForwardData message but the tunnel is not open", .{});
                    return error.PunchSocketDisconnect;
                } else {
                    log("[VERBOSE] forwarding {} bytes to raw socket s={}", .{forwardAction.data.len, eventer.data.rawFd});
                    netext.send(eventer.data.rawFd, forwardAction.data, 0) catch |e| switch (e) {
                        error.Disconnected, error.Retry => {
                            log("s={} send failed on raw socket", .{eventer.data.rawFd});
                            return error.RawSocketDisconnect;
                        },
                    };
                }
            },
        }
    }
}
--------------------------------------------------------------------------------
/socat.zig:
--------------------------------------------------------------------------------
// a simple version of netcat for testing
// this is created so we have a common implementation for things like "CLOSE ON EOF"
//
// TODO: is it worth it to support the sendfile syscall variation?
// maybe not since it will make this more complicated and its
// main purpose is just for testing
//
const builtin = @import("builtin");
const std = @import("std");
const mem = std.mem;
const os = std.os;
const net = std.net;

const common = @import("./common.zig");
const logging = @import("./logging.zig");
const timing = @import("./timing.zig");
const eventing = @import("./eventing.zig").default;
const netext = @import("./netext.zig");
const proxy = @import("./proxy.zig");

const fd_t = os.fd_t;
const socket_t = os.socket_t;
const Address = net.Address;
const log = logging.log;
const EventFlags = eventing.EventFlags;
// One shared eventer type: every callback either succeeds or reports Disconnect.
const Eventer = eventing.EventerTemplate(.{
    .CallbackError = error { Disconnect },
    .CallbackData = struct {
        inOut: InOut,
    },
});
const Proxy = proxy.Proxy;

// process-wide state: socat services exactly one connection pair at a time
const global = struct {
    var addr1String: []const u8 = undefined;
    var addr2String: []const u8 = undefined;
    var addr1 : Addr = undefined;
    var addr2 : Addr = undefined;
    var eventer : Eventer = undefined;
    var buffer : [8192]u8 = undefined;
};

/// Scans forward for the byte `to`; on a hit, returns the prefix before it
/// and advances strRef past the delimiter.  Returns null (strRef unchanged)
/// when the delimiter is absent.
fn peelTo(strRef: *[]const u8, to: u8) ?[]const u8 {
    var str = strRef.*;
    for (str) |c, i| {
        if (c == to) {
            strRef.* = str[i+1..];
            return str[0..i];
        }
    }
    return null;
}

var noThrottle = false; // set by the --no-throttle command-line option
/// Retry-pacing throttler; --no-throttle disables all pacing (for tests).
fn makeThrottler(logPrefix: []const u8) timing.Throttler {
    return if (noThrottle) (timing.makeThrottler {
        .logPrefix = logPrefix,
        .desiredSleepMillis = 0,
        .slowRateMillis = 0,
    }).create() else (timing.makeThrottler {
        .logPrefix = logPrefix,
        .desiredSleepMillis = 15000,
        .slowRateMillis = 500,
    }).create();
}

// Per-address state created before connect; only tcp-listen needs any
// (its listening socket).
const ConnectPrep = union(enum) {
    None,
    TcpListen: TcpListen,

    pub const TcpListen = struct {
        listenFd: socket_t,
    };
};

const
Addr = union(enum) {
    TcpConnect: TcpConnect,
    ProxyConnect: ProxyConnect,
    TcpListen: TcpListen,

    /// Parses a "TYPE:REST" address specifier (see usage()).
    pub fn parse(spec: []const u8) !Addr {
        var rest = spec;
        const specType = peelTo(&rest, ':') orelse {
            std.debug.warn("Error: address '{s}' missing ':' to delimit type\n", .{spec});
            return error.ParseAddrFailed;
        };
        if (mem.eql(u8, specType, "tcp-connect"))
            return Addr { .TcpConnect = try TcpConnect.parse(rest) };
        if (mem.eql(u8, specType, "proxy-connect"))
            return Addr { .ProxyConnect = try ProxyConnect.parse(rest) };
        if (mem.eql(u8, specType, "tcp-listen"))
            return Addr { .TcpListen = try TcpListen.parse(rest) };

        std.debug.warn("Error: unknown address-specifier type '{s}'\n", .{specType});
        return error.ParseAddrFailed;
    }
    // The methods below simply dispatch to the active variant's implementation.
    pub fn prepareConnect(self: *const Addr) !ConnectPrep {
        switch (self.*) {
            .TcpConnect => |a| return a.prepareConnect(),
            .ProxyConnect => |a| return a.prepareConnect(),
            .TcpListen => |a| return a.prepareConnect(),
        }
    }
    pub fn unprepareConnect(self: *const Addr, prep: *const ConnectPrep) void {
        switch (self.*) {
            .TcpConnect => |a| return a.unprepareConnect(prep),
            .ProxyConnect => |a| return a.unprepareConnect(prep),
            .TcpListen => |a| return a.unprepareConnect(prep),
        }
    }

    pub fn connect(self: *const Addr, prep: *const ConnectPrep) !InOut {
        switch (self.*) {
            .TcpConnect => |a| return a.connect(prep),
            .ProxyConnect => |a| return a.connect(prep),
            .TcpListen => |a| return a.connect(prep),
        }
    }
    /// connect, but collapses every transient failure into error.Retry and
    /// panics on errors that indicate a programming/configuration bug.
    pub fn connectSqueezeErrors(self: *const Addr, prep: *const ConnectPrep) !InOut {
        return self.connect(prep) catch |e| switch (e) {
            error.Retry => return error.Retry,
            error.RetryConnect => return error.RetryConnect,
            error.AddressInUse
            ,error.AddressNotAvailable
            ,error.SystemResources
            ,error.ConnectionRefused
            ,error.ConnectionResetByPeer
            ,error.NetworkUnreachable
            ,error.PermissionDenied
            ,error.ConnectionTimedOut
            ,error.WouldBlock
            ,error.FileNotFound
            ,error.ProcessFdQuotaExceeded
            ,error.SystemFdQuotaExceeded
            ,error.ConnectionPending
            => {
                log("connect failed with {}", .{e});
                return error.Retry;
            },
            error.AddressFamilyNotSupported
            ,error.SocketTypeNotSupported
            ,error.ProtocolFamilyNotAvailable
            ,error.ProtocolNotSupported
            ,error.Unexpected
            ,error.DnsNotSupported
            => std.debug.panic("FATAL ERROR: connect failed with {}", .{e}),
        };
    }
    pub fn disconnect(self: *const Addr, inOut: InOut) void {
        switch (self.*) {
            .TcpConnect => |a| return a.disconnect(inOut),
            .ProxyConnect => |a| return a.disconnect(inOut),
            .TcpListen => |a| return a.disconnect(inOut),
        }
    }

    pub fn eventerAdd(self: *const Addr, prep: *const ConnectPrep, callback: *Eventer.Callback) !void {
        switch (self.*) {
            .TcpConnect => |a| return a.eventerAdd(prep, callback),
            .ProxyConnect => |a| return a.eventerAdd(prep, callback),
            .TcpListen => |a| return a.eventerAdd(prep, callback),
        }
    }
    pub fn eventerRemove(self: *const Addr, prep: *const ConnectPrep) void {
        switch (self.*) {
            .TcpConnect => |a| return a.eventerRemove(prep),
            .ProxyConnect => |a| return a.eventerRemove(prep),
            .TcpListen => |a| return a.eventerRemove(prep),
        }
    }

    /// An outbound TCP connection (host + port).
    pub const TcpConnect = struct {
        host: []const u8,
        port: u16,
        pub fn parse(spec: []const u8) !TcpConnect {
            var rest = spec;
            const host = peelTo(&rest, ':') orelse {
                std.debug.warn("Error: 'tcp-connect:{s}' missing ':' to delimit host\n", .{spec});
                return error.ParseAddrFailed;
            };
            const port = try common.parsePort(rest);
            return TcpConnect {
                .host = host,
                .port = port,
            };
        }
        pub fn prepareConnect(self: *const TcpConnect) !ConnectPrep {
            _ = self;
            return ConnectPrep.None; // nothing to prepare for an outbound connect
        }
        pub fn unprepareConnect(self: *const TcpConnect, prep: *const ConnectPrep) void {
            _ = self;
            _ = prep;
        }
        pub fn connect(self: *const TcpConnect, prep: *const ConnectPrep) !InOut {
            _ = prep;
            const sockfd = try common.connectHost(self.host, self.port);
            return InOut { .in = sockfd, .out = sockfd };
        }
        pub fn disconnect(self: *const TcpConnect, inOut: InOut) void {
            _ = self;
            std.debug.assert(inOut.in == inOut.out);
            common.shutdownclose(inOut.in);
        }
        // no listen socket, so nothing to register with the eventer
        pub fn eventerAdd(self: *const TcpConnect, prep: *const ConnectPrep, callback: *Eventer.Callback) !void {
            _ = self;
            _ = prep;
            _ = callback;
        }
        pub fn eventerRemove(self: *const TcpConnect, prep: *const ConnectPrep) void {
            _ = self;
            _ = prep;
        }
    };
    /// An outbound TCP connection established through an HTTP proxy
    /// (proxy-host:proxy-port:target-host:target-port).
    pub const ProxyConnect = struct {
        httpProxy: Proxy,
        targetHost: []const u8,
        targetPort: u16,
        pub fn parse(spec: []const u8) !ProxyConnect {
            var rest = spec;
            const proxyHost = peelTo(&rest, ':') orelse {
                std.debug.warn("Error: 'proxy-connect:{s}' missing ':' to delimit proxy-host\n", .{spec});
                return error.ParseAddrFailed;
            };
            const proxyPort = try common.parsePort(peelTo(&rest, ':') orelse {
                std.debug.warn("Error: 'proxy-connect:{s}' missing 2nd ':' to delimit proxy-port\n", .{spec});
                return error.ParseAddrFailed;
            });
            const targetHost = peelTo(&rest, ':') orelse {
                std.debug.warn("Error: 'proxy-connect:{s}' missing the 3rd ':' to delimit host\n", .{spec});
                return error.ParseAddrFailed;
            };
            const targetPort = try common.parsePort(rest);
            return ProxyConnect {
                .httpProxy = Proxy { .Http = .{ .host = proxyHost, .port = proxyPort } },
                .targetHost = targetHost,
                .targetPort =
targetPort,
            };
        }
        pub fn prepareConnect(self: *const ProxyConnect) !ConnectPrep {
            _ = self;
            return ConnectPrep.None;
        }
        pub fn unprepareConnect(self: *const ProxyConnect, prep: *const ConnectPrep) void {
            _ = self;
            _ = prep;
        }
        pub fn connect(self: *const ProxyConnect, prep: *const ConnectPrep) !InOut {
            _ = prep;
            const sockfd = try netext.proxyConnect(&self.httpProxy, self.targetHost, self.targetPort);
            return InOut { .in = sockfd, .out = sockfd };
        }
        pub fn disconnect(self: *const ProxyConnect, inOut: InOut) void {
            _ = self;
            std.debug.assert(inOut.in == inOut.out);
            common.shutdownclose(inOut.in);
        }
        pub fn eventerAdd(self: *const ProxyConnect, prep: *const ConnectPrep, callback: *Eventer.Callback) !void {
            _ = self;
            _ = prep;
            _ = callback;
        }
        pub fn eventerRemove(self: *const ProxyConnect, prep: *const ConnectPrep) void {
            _ = self;
            _ = prep;
        }
    };
    /// Accept a single inbound TCP connection on the given port.
    pub const TcpListen = struct {
        port: u16,
        //listenAddr: ?Address,
        pub fn parse(spec: []const u8) !TcpListen {
            var rest = spec;
            const port = try common.parsePort(rest);
            return TcpListen {
                .port = port,
            };
        }
        /// Binds the listening socket up front so connect() can accept on it.
        pub fn prepareConnect(self: *const TcpListen) !ConnectPrep {
            var listenAddr = Address.initIp4([4]u8{0,0,0,0}, self.port);
            const listenFd = netext.makeListenSock(&listenAddr, 1) catch |e| switch (e) {
                error.Retry => return error.RetryPrepareConnect,
            };
            return ConnectPrep { .TcpListen = .{ .listenFd = listenFd } };
        }
        fn getListenFd(prep: *const ConnectPrep) socket_t {
            return switch (prep.*) {
                .TcpListen => |p| p.listenFd,
                else => @panic("code bug: connect prep type is wrong"),
            };
        }
        pub fn unprepareConnect(self: *const TcpListen, prep: *const ConnectPrep) void {
            _ = self;
            os.close(getListenFd(prep));
294 | } 295 | pub fn connect(self: *const TcpListen, prep: *const ConnectPrep) !InOut { 296 | _ = self; 297 | const listenFd = getListenFd(prep); 298 | var clientAddr : Address = undefined; 299 | var clientAddrLen : os.socklen_t = @sizeOf(@TypeOf(clientAddr)); 300 | const clientFd = netext.accept(listenFd, &clientAddr.any, &clientAddrLen, 0) catch |e| switch (e) { 301 | error.ClientDropped, error.Retry => return error.RetryConnect, 302 | }; 303 | log("accepted client from {}", .{clientAddr}); 304 | return InOut { .in = clientFd, .out = clientFd }; 305 | } 306 | pub fn disconnect(self: *const TcpListen, inOut: InOut) void { 307 | _ = self; 308 | std.debug.assert(inOut.in == inOut.out); 309 | common.shutdownclose(inOut.in); 310 | } 311 | pub fn eventerAdd(self: *const TcpListen, prep: *const ConnectPrep, callback: *Eventer.Callback) !void { 312 | _ = self; 313 | const listenFd = getListenFd(prep); 314 | callback.* = Eventer.Callback { 315 | .func = onAccept, 316 | .data = .{ .inOut = InOut {.in = listenFd, .out = undefined } }, 317 | }; 318 | common.eventerAdd(Eventer, &global.eventer, listenFd, EventFlags.read, callback) catch |e| switch (e) { 319 | error.Retry => return error.EpollError, 320 | }; 321 | } 322 | pub fn eventerRemove(self: *const TcpListen, prep: *const ConnectPrep) void { 323 | _ = self; 324 | global.eventer.remove(getListenFd(prep)); 325 | } 326 | }; 327 | }; 328 | 329 | const InOut = struct { in: socket_t, out: socket_t }; 330 | 331 | fn usage() void { 332 | std.debug.warn( 333 | \\Usage: socat ADDRESS1 ADDRESS2 334 | \\Address Specifiers: 335 | \\ tcp-connect:: 336 | \\ tcp-listen:[,] 337 | \\ proxy-connect:::: 338 | , .{}); 339 | } 340 | 341 | pub fn main() anyerror!u8 { 342 | if (builtin.os.tag == .windows) { 343 | _ = try std.os.windows.WSAStartup(2, 2); 344 | } 345 | 346 | var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); 347 | var args = try std.process.argsAlloc(&arena.allocator); 348 | if (args.len <= 1) { 349 | usage(); 
        return 1;
    }
    args = args[1..];

    // Filter out option flags, keeping only positional arguments.
    {
        var newArgsLen : usize = 0;
        defer args = args[0..newArgsLen];
        var i : usize = 0;
        while (i < args.len) : (i += 1) {
            const arg = args[i];
            if (!std.mem.startsWith(u8, arg, "-")) {
                args[newArgsLen] = arg;
                newArgsLen += 1;
            } else if (std.mem.eql(u8, arg, "--no-throttle")) {
                noThrottle = true;
            } else {
                std.debug.warn("Error: unknown command-line option '{s}'\n", .{arg});
                return 1;
            }
        }
    }

    if (args.len != 2) {
        std.debug.warn("Error: expected 2 command-line arguments but got {}\n", .{args.len});
        return 1;
    }

    global.addr1String = args[0];
    global.addr2String = args[1];

    global.addr1 = Addr.parse(global.addr1String) catch return 1;
    global.addr2 = Addr.parse(global.addr2String) catch return 1;

    global.eventer = try Eventer.init(.{});
    defer global.eventer.deinit();

    var prepareConnectThrottler = makeThrottler("addr1 prepare connect: ");
    while (true) {
        prepareConnectThrottler.throttle();
        const addr1Prep = global.addr1.prepareConnect() catch |e| switch (e) {
            error.RetryPrepareConnect => continue,
            //error.Retry => continue,
        };
        defer global.addr1.unprepareConnect(&addr1Prep);
        // empty error set: sequenceConnectAddr1 never returns
        switch (sequenceConnectAddr1(&addr1Prep)) {
            //error.Disconnect => continue,
        }
    }
}

/// Connect/retry loop for address 1; loops forever (empty error set).
fn sequenceConnectAddr1(addr1Prep: *const ConnectPrep) error { } {
    var connectThrottler = makeThrottler("addr1 connect: ");
    while (true) {
        connectThrottler.throttle();
        log("connecting to {s}...", .{global.addr1String});
        const addr1InOut = global.addr1.connectSqueezeErrors(addr1Prep) catch |e| switch (e) {
            error.RetryConnect, error.Retry => continue,
        };
        defer global.addr1.disconnect(addr1InOut);
        log("connected to {s} (in={} out={})",
.{global.addr1String, addr1InOut.in, addr1InOut.out});
        switch (sequencePrepareAddr2(addr1Prep, addr1InOut)) {
            error.Disconnect => continue,
        }
    }
}

/// Prepares address 2 (up to 10 attempts) while address 1 is connected.
/// error.Disconnect means address 1's connection should be dropped too.
fn sequencePrepareAddr2(addr1Prep: *const ConnectPrep, addr1InOut: InOut) error{ Disconnect } {
    var attempt : u16 = 0;
    var prepareThrottler = makeThrottler("addr2 prepare connect: ");
    while (true) {
        attempt += 1;
        if (attempt >= 10) {
            log("failed {} attempts to prepare address 2, disconnecting...", .{attempt});
            return error.Disconnect;
        }
        prepareThrottler.throttle();
        const addr2Prep = global.addr2.prepareConnect() catch |e| switch (e) {
            error.RetryPrepareConnect => continue,
        };
        defer global.addr2.unprepareConnect(&addr2Prep);
        switch (sequenceConnectAddr2(addr1Prep, addr1InOut, &addr2Prep)) {
            error.Disconnect => return error.Disconnect,
        }
    }
}

/// Connects to address 2 (up to 10 attempts) and runs the relay.
fn sequenceConnectAddr2(addr1Prep: *const ConnectPrep, addr1InOut: InOut, addr2Prep: *const ConnectPrep) error{ Disconnect } {
    var attempt : u16 = 0;
    var connectThrottler = makeThrottler("addr2 connect: ");
    while (true) {
        attempt += 1;
        if (attempt >= 10) {
            log("failed {} attempts to connect to address 2, disconnecting...", .{attempt});
            return error.Disconnect;
        }
        connectThrottler.throttle();
        log("connecting to {s}...", .{global.addr2String});
        const addr2InOut = global.addr2.connectSqueezeErrors(addr2Prep) catch |e| switch (e) {
            error.Retry, error.RetryConnect => continue,
        };
        log("connected to {s} (in={} out={})", .{global.addr2String, addr2InOut.in, addr2InOut.out});
        defer global.addr2.disconnect(addr2InOut);
        switch (sequenceSetupEventing(addr1Prep, addr1InOut, addr2InOut, addr2Prep)) {
            error.Disconnect => return error.Disconnect,
        }
    }
}

/// Retry wrapper around sequenceForwardLoop: epoll registration failures are
/// retried, disconnects propagate.
fn sequenceSetupEventing(addr1Prep: *const ConnectPrep,
addr1InOut: InOut, addr2InOut: InOut, addr2Prep: *const ConnectPrep) error{ Disconnect } {
    var eventingThrottler = makeThrottler("setup eventing: ");
    while (true) {
        eventingThrottler.throttle();
        switch (sequenceForwardLoop(addr1Prep, addr1InOut, addr2InOut, addr2Prep)) {
            error.EpollError => continue, // re-register callbacks and try again
            error.Disconnect => return error.Disconnect,
        }
    }
}

/// Registers both relay directions (and each address's listen socket, when it
/// has one) with the global eventer, then pumps events until a disconnect.
fn sequenceForwardLoop(addr1Prep: *const ConnectPrep, addr1InOut: InOut, addr2InOut: InOut, addr2Prep: *const ConnectPrep) error { EpollError, Disconnect } {
    // direction: addr1 -> addr2
    var addr1Callback = Eventer.Callback {
        .func = onAddr1Read,
        .data = .{ .inOut = InOut {.in = addr1InOut.in, .out = addr2InOut.out } },
    };
    common.eventerAdd(Eventer, &global.eventer, addr1InOut.in, EventFlags.read, &addr1Callback) catch |e| switch (e) {
        error.Retry => return error.EpollError,
    };
    defer global.eventer.remove(addr1InOut.in);

    // direction: addr2 -> addr1
    var addr2Callback = Eventer.Callback {
        .func = onAddr2Read,
        .data = .{ .inOut = InOut {.in = addr2InOut.in, .out = addr1InOut.out } },
    };
    common.eventerAdd(Eventer, &global.eventer, addr2InOut.in, EventFlags.read, &addr2Callback) catch |e| switch (e) {
        error.Retry => return error.EpollError,
    };
    defer global.eventer.remove(addr2InOut.in);

    var addr1PrepCallback : Eventer.Callback = undefined;
    global.addr1.eventerAdd(addr1Prep, &addr1PrepCallback) catch |e| switch (e) {
        error.EpollError => return error.EpollError,
    };
    defer global.addr1.eventerRemove(addr1Prep);

    var addr2PrepCallback : Eventer.Callback = undefined;
    global.addr2.eventerAdd(addr2Prep, &addr2PrepCallback) catch |e| switch (e) {
        error.EpollError => return error.EpollError,
    };
    defer global.addr2.eventerRemove(addr2Prep);

    while (true) {
        global.eventer.handleEventsNoTimeout() catch |e| switch (e) {
            error.Disconnect
=> return error.Disconnect,
        };
    }
}

fn onAddr1Read(eventer: *Eventer, callback: *Eventer.Callback) Eventer.CallbackError!void {
    _ = eventer;
    return onRead(true, callback);
}
fn onAddr2Read(eventer: *Eventer, callback: *Eventer.Callback) Eventer.CallbackError!void {
    _ = eventer;
    return try onRead(false, callback);
}

/// Copies one buffered chunk from inOut.in to inOut.out; a read/write failure
/// or EOF is reported as Disconnect so the whole relay tears down.
fn onRead(isAddr1Read: bool, callback: *Eventer.Callback) Eventer.CallbackError!void {
    // TODO: I should use the sendfile syscall if available
    const length = os.read(callback.data.inOut.in, &global.buffer) catch |e| {
        log("read failed: {}", .{e});
        return error.Disconnect;
    };
    if (length == 0) {
        // EOF: peer closed its write side
        log("read fd={} returned 0", .{callback.data.inOut.in});
        return error.Disconnect;
    }
    if (common.tryWriteAll(callback.data.inOut.out, global.buffer[0..length])) |result| {
        log("write fd={} len={} failed with {}, wrote {}", .{callback.data.inOut.out, length, result.err, result.wrote});
        return error.Disconnect;
    }
    _ = isAddr1Read;
    //const dirString : []const u8 = if (isAddr1Read) ">>>" else "<<<";
    //log("[VERBOSE] {} {} bytes", .{dirString, length});
}

/// A second client arrived on a tcp-listen address while one is already being
/// serviced: accept it and close it immediately.
fn onAccept(eventer: *Eventer, callback: *Eventer.Callback) Eventer.CallbackError!void {
    _ = eventer;
    var addr : Address = undefined;
    var addrLen : os.socklen_t = @sizeOf(Address);
    const fd = netext.accept(callback.data.inOut.in, &addr.any, &addrLen, 0) catch |e| switch (e) {
        error.Retry, error.ClientDropped => return,
    };
    log("s={} already have client, dropping client s={} from {}", .{callback.data.inOut.in, fd, addr});
    common.shutdownclose(fd);
}
--------------------------------------------------------------------------------