├── .gitignore ├── misc └── logo.png ├── LICENSE ├── src ├── common.zig ├── main.zig ├── connection.zig ├── client.zig ├── sender.zig └── receiver.zig └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | zig-cache/ 2 | zig-out/ -------------------------------------------------------------------------------- /misc/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikneym/ws/HEAD/misc/logo.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 nikneym 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/src/common.zig:
--------------------------------------------------------------------------------
/// Maximum payload length of a control frame (ping, pong, close).
pub const MAX_CTL_FRAME_LENGTH = 125;

/// WebSocket frame opcodes.
pub const Opcode = enum(u4) {
    continuation = 0x0,
    text = 0x1,
    binary = 0x2,
    close = 0x8,
    ping = 0x9,
    pong = 0xA,
    // this one is custom for this implementation.
    // see how it's used in sender.zig: `Sender.stream` maps it to a final
    // continuation frame (FIN set); it is never written to the wire as 0xF.
    end = 0xF,
    _,
};

/// Decoded fields of a WebSocket frame header.
pub const Header = packed struct {
    /// Payload length in bytes.
    len: u64,
    opcode: Opcode,
    /// Set on the last frame of a message.
    fin: bool,
    rsv1: bool = false,
    rsv2: bool = false,
    rsv3: bool = false,

    pub const Error = error{MaskedMessageFromServer};
};

/// A complete WebSocket message as handed to the user.
pub const Message = struct {
    type: Opcode,
    /// Message payload. The receiver builds it from slices of its internal
    /// buffers, so it is only valid until the next receive call.
    data: []const u8,
    /// Status code; only used in close messages.
    code: ?u16,

    pub const Error = error{ FragmentedMessage, UnknownOpcode };

    /// Create a WebSocket message from given fields.
    /// Returns `error.FragmentedMessage` for `continuation` and
    /// `error.UnknownOpcode` for anything other than text, binary, ping, pong or close.
    pub fn from(opcode: Opcode, data: []const u8, code: ?u16) Message.Error!Message {
        return switch (opcode) {
            .text, .binary, .ping, .pong, .close => Message{ .type = opcode, .data = data, .code = code },
            .continuation => error.FragmentedMessage,
            else => error.UnknownOpcode,
        };
    }
};

--------------------------------------------------------------------------------
/src/main.zig:
--------------------------------------------------------------------------------
const std = @import("std");
const net = std.net;
const mem = std.mem;
const io = std.io;

// these can be used directly too
pub const Client = @import("client.zig").Client;
pub const client = @import("client.zig").client;
pub const Connection = @import("connection.zig").Connection;
/// A request header given as a `.{ name, value }` pair.
pub const Header = [2][]const u8;

/// Either a literal IP address or a host name that still needs DNS resolution.
pub const Address = union(enum) {
    ip: std.net.Address,
    host: []const u8,

    /// Classify `host`: an IP literal becomes `.ip` (with `port`), anything else `.host`.
    pub fn resolve(host: []const u8, port: u16) Address {
        const ip = std.net.Address.parseIp(host, port) catch return Address{ .host = host };
        return Address{ .ip = ip };
    }
};

/// Open a new WebSocket connection.
/// Allocator is used for DNS resolving of host and the storage of response headers.
///
/// Only the `ws` scheme (case-insensitive, default port 80) is supported for now.
/// `wss` URIs fail with `error.TlsNotSupported` rather than sending the handshake,
/// including any user-supplied headers, in plaintext to a TLS endpoint.
/// Other schemes fail with `error.UnknownScheme`; a URI without host fails with
/// `error.MissingHost`.
pub fn connect(allocator: mem.Allocator, uri: std.Uri, request_headers: ?[]const Header) !Connection {
    // TODO: implement TLS connection (wss).
    if (std.ascii.eqlIgnoreCase(uri.scheme, "wss")) return error.TlsNotSupported;
    if (!std.ascii.eqlIgnoreCase(uri.scheme, "ws")) return error.UnknownScheme;

    const host = uri.host orelse return error.MissingHost;
    const port: u16 = uri.port orelse 80;

    const stream = try switch (Address.resolve(host, port)) {
        .ip => |ip| net.tcpConnectToAddress(ip),
        .host => |name| net.tcpConnectToHost(allocator, name, port),
    };
    // Connection.init does not take ownership of the stream when it fails.
    errdefer stream.close();

    return Connection.init(allocator, stream, uri, request_headers);
}

test "Simple connection to :8080" {
    const allocator = std.testing.allocator;

    var cli = try connect(allocator, try std.Uri.parse("ws://localhost:8080"), &.{
        .{ "Host", "localhost" },
        .{ "Origin", "http://localhost/" },
    });
    defer cli.deinit(allocator);

    while (true) {
        const msg = try cli.receive();
        switch (msg.type) {
            .text => {
                std.debug.print("received: {s}\n", .{msg.data});
                try cli.send(.text, msg.data);
            },

            .ping => {
                std.debug.print("got ping! sending pong...\n", .{});
                // A pong answering a ping must echo the ping's application data.
                try cli.send(.pong, msg.data);
            },

            .close => {
                std.debug.print("close", .{});
                break;
            },

            else => {
                std.debug.print("got {s}: {s}\n", .{ @tagName(msg.type), msg.data });
            },
        }
    }

    try cli.close();
}

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |

2 | ws 3 |

4 | 5 | ws 6 | =========== 7 | A lightweight WebSocket library for Zig ⚡ 8 | 9 | Features 10 | =========== 11 | * Allocates only while setting up the connection (handshake and response headers); building and parsing messages never allocates 12 | * Easy to use: works directly with `net.Stream` 13 | * Buffered reads and writes (works with any other reader/writer too) 14 | * Supports streaming output through WebSocket fragmentation 15 | 16 | Example 17 | =========== 18 | By default, ws uses the `Stream` type from the `net` namespace. 19 | You can plug in a stream of your choice through the `ws.Client` interface. 20 | ```zig 21 | test "Simple connection to :8080" { 22 | const allocator = std.testing.allocator; 23 | 24 | var cli = try connect(allocator, try std.Uri.parse("ws://localhost:8080"), &.{ 25 | .{"Host", "localhost"}, 26 | .{"Origin", "http://localhost/"}, 27 | }); 28 | defer cli.deinit(allocator); 29 | 30 | while (true) { 31 | const msg = try cli.receive(); 32 | switch (msg.type) { 33 | .text => { 34 | std.debug.print("received: {s}\n", .{msg.data}); 35 | try cli.send(.text, msg.data); 36 | }, 37 | 38 | .ping => { 39 | std.debug.print("got ping!
sending pong...\n", .{});
40 | try cli.send(.pong, msg.data);
41 | },
42 | 
43 | .close => {
44 | std.debug.print("close", .{});
45 | break;
46 | },
47 | 
48 | else => {
49 | std.debug.print("got {s}: {s}\n", .{@tagName(msg.type), msg.data});
50 | },
51 | }
52 | }
53 | 
54 | try cli.close();
55 | }
56 | ```
57 | 
58 | Planned
59 | ===========
60 | - [ ] WebSocket server support
61 | - [ ] TLS support out of the box (tracks `std.crypto.tls.Client`)
62 | - [x] Request & response headers
63 | - [ ] WebSocket Compression support
64 | 
65 | Acknowledgements
66 | ===========
67 | This library wouldn't be possible without these cool projects & posts:
68 | * [truemedian/wz](https://github.com/truemedian/wz)
69 | * [frmdstryr/zhp](https://github.com/frmdstryr/zhp/blob/master/src/websocket.zig)
70 | * [treeform/ws](https://github.com/treeform/ws)
71 | * [openmymind.net/WebSocket-Framing-Masking-Fragmentation-and-More](https://www.openmymind.net/WebSocket-Framing-Masking-Fragmentation-and-More/)
72 | * [Writing WebSocket servers](https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers)
73 | 
74 | License
75 | ===========
76 | MIT License, [check out](https://github.com/nikneym/ws/blob/main/LICENSE).
77 | 
--------------------------------------------------------------------------------
/src/connection.zig:
--------------------------------------------------------------------------------
const std = @import("std");
const net = std.net;
const mem = std.mem;
const io = std.io;

const Client = @import("client.zig").Client;
const client = @import("client.zig").client;

const common = @import("common.zig");
const Opcode = common.Opcode;
const Message = common.Message;

const READ_BUFFER_SIZE: usize = 1024 * 8;
const WRITE_BUFFER_SIZE: usize = 1024 * 4;

/// This is the direct implementation of ws over regular net.Stream.
/// The Connection object will always use the current Stream implementation of net namespace.
pub const Connection = struct {
    underlying_stream: net.Stream,
    ws_client: *WsClient,
    /// Heap-allocated: the reader stored inside `ws_client` holds a pointer to this
    /// BufferedReader, so it needs an address that stays valid after `init` returns.
    buffered_reader: *BufferedReader,
    /// Response headers from the handshake; owned by the Connection, freed in `deinit`.
    headers: std.StringHashMapUnmanaged([]const u8),

    /// general types
    const WsClient = Client(Reader, Writer, READ_BUFFER_SIZE, WRITE_BUFFER_SIZE);
    const BufferedReader = io.BufferedReader(4096, net.Stream.Reader);
    const Reader = BufferedReader.Reader;
    const Writer = net.Stream.Writer;

    /// Run the WebSocket opening handshake over an already-connected stream.
    /// The stream is not closed here on failure; the caller keeps ownership of it
    /// until this returns successfully.
    pub fn init(
        allocator: mem.Allocator,
        underlying_stream: net.Stream,
        uri: std.Uri,
        request_headers: ?[]const [2][]const u8,
    ) !Connection {
        // `BufferedReader.reader()` captures a pointer to the BufferedReader, so it must
        // not live in this stack frame; otherwise every later read would go through a
        // dangling pointer.
        const buffered_reader = try allocator.create(BufferedReader);
        errdefer allocator.destroy(buffered_reader);
        buffered_reader.* = .{ .unbuffered_reader = underlying_stream.reader() };

        const ws_client = try allocator.create(WsClient);
        errdefer allocator.destroy(ws_client);
        ws_client.* = client(
            buffered_reader.reader(),
            underlying_stream.writer(),
            READ_BUFFER_SIZE,
            WRITE_BUFFER_SIZE,
        );

        var headers: std.StringHashMapUnmanaged([]const u8) = .{};
        try ws_client.handshake(allocator, uri, request_headers, &headers);

        return Connection{
            .underlying_stream = underlying_stream,
            .ws_client = ws_client,
            .buffered_reader = buffered_reader,
            .headers = headers,
        };
    }

    /// Free the response headers, close the stream and release the heap state.
    pub fn deinit(self: *Connection, allocator: mem.Allocator) void {
        self.ws_client.deinit(allocator, &self.headers);
        self.underlying_stream.close();
        allocator.destroy(self.ws_client);
        allocator.destroy(self.buffered_reader);
    }

    /// Send a WebSocket message to the server.
    /// The `opcode` field can be text, binary, ping, pong or close.
    /// In order to send continuation frames or streaming messages, check out `stream` function.
    pub fn send(self: Connection, opcode: Opcode, data: []const u8) !void {
        return self.ws_client.send(opcode, data);
    }

    /// Send a ping message to the server.
    pub fn ping(self: Connection) !void {
        return self.send(.ping, "");
    }

    /// Send an empty pong message to the server.
    /// To answer a ping, echo its data instead: `send(.pong, ping_msg.data)`.
    pub fn pong(self: Connection) !void {
        return self.send(.pong, "");
    }

    /// Send a close message to the server.
    pub fn close(self: Connection) !void {
        return self.ws_client.close();
    }

    /// TODO: Add usage example
    /// Send continuation frames or streaming messages to the server.
    pub fn stream(self: Connection, opcode: Opcode, payload: ?[]const u8) !void {
        return self.ws_client.stream(opcode, payload);
    }

    /// Receive a message from the server.
    pub fn receive(self: Connection) !Message {
        return self.ws_client.receive();
    }
};

--------------------------------------------------------------------------------
/src/client.zig:
--------------------------------------------------------------------------------
const std = @import("std");
const mem = std.mem;

const common = @import("common.zig");
const Message = common.Message;
const Opcode = common.Opcode;

const Receiver = @import("receiver.zig").Receiver;
const Sender = @import("sender.zig").Sender;

/// Create a new WebSocket client.
/// This interface is for using your own reader and writer.
/// The client is returned by value; if `reader` or `writer` point at other state
/// (e.g. a BufferedReader), that state must outlive the client.
pub fn client(
    reader: anytype,
    writer: anytype,
    comptime read_buffer_size: usize,
    comptime write_buffer_size: usize,
) Client(@TypeOf(reader), @TypeOf(writer), read_buffer_size, write_buffer_size) {
    var mask: [4]u8 = undefined;
    std.crypto.random.bytes(&mask);

    return .{
        .receiver = .{ .reader = reader },
        .sender = .{ .writer = writer, .mask = mask },
    };
}

/// Create a new WebSocket client.
/// This interface is for using your own reader and writer: `Client` builds the client
/// type for a reader/writer pair and buffer sizes, and `client` instantiates it.
pub fn Client(
    comptime Reader: type,
    comptime Writer: type,
    comptime read_buffer_size: usize,
    comptime write_buffer_size: usize,
) type {
    return struct {
        const Self = @This();

        receiver: Receiver(Reader, read_buffer_size),
        sender: Sender(Writer, write_buffer_size),

        /// Deallocate response headers.
        pub fn deinit(
            self: Self,
            allocator: mem.Allocator,
            headers: *std.StringHashMapUnmanaged([]const u8),
        ) void {
            self.receiver.freeHttpHeaders(allocator, headers);
        }

        /// Perform the client side of the opening handshake: send the upgrade request,
        /// parse the response into `response_headers` (allocated with `allocator`) and
        /// verify `Sec-WebSocket-Accept`. The headers are freed again if verification fails.
        pub fn handshake(
            self: *Self,
            allocator: mem.Allocator,
            uri: std.Uri,
            request_headers: ?[]const [2][]const u8,
            response_headers: *std.StringHashMapUnmanaged([]const u8),
        ) !void {
            // Sec-WebSocket-Key is the base64 encoding of a random 16-byte nonce.
            // Nonce and encoded key use separate buffers: encoding in place may
            // overwrite nonce bytes before the encoder has read them.
            var nonce: [16]u8 = undefined;
            std.crypto.random.bytes(&nonce);
            var key_buf: [24]u8 = undefined;
            const key = std.base64.standard.Encoder.encode(&key_buf, &nonce);

            try self.sender.sendRequest(uri, request_headers, key);
            try self.receiver.receiveResponse(allocator, response_headers);
            errdefer self.receiver.freeHttpHeaders(allocator, response_headers);

            try checkWebSocketAcceptKey(response_headers.*, key);
        }

        const WsAcceptKeyError = error{ KeyControlFailed, AcceptKeyNotFound };

        /// Look up a response header by name, ignoring ASCII case
        /// (HTTP field names are case-insensitive).
        fn findHeader(
            headers: std.StringHashMapUnmanaged([]const u8),
            name: []const u8,
        ) ?[]const u8 {
            var it = headers.iterator();
            while (it.next()) |entry| {
                if (std.ascii.eqlIgnoreCase(entry.key_ptr.*, name)) return entry.value_ptr.*;
            }
            return null;
        }

        /// Controls the accept key received from the server:
        /// it must equal base64(SHA-1(key ++ magic GUID)).
        fn checkWebSocketAcceptKey(
            headers: std.StringHashMapUnmanaged([]const u8),
            key: []const u8,
        ) WsAcceptKeyError!void {
            const received = findHeader(headers, "Sec-WebSocket-Accept") orelse
                return error.AcceptKeyNotFound;
            const magic_string = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

            var digest: [std.crypto.hash.Sha1.digest_length]u8 = undefined;
            var hasher = std.crypto.hash.Sha1.init(.{});
            hasher.update(key);
            hasher.update(magic_string);
            hasher.final(&digest);

            var expected_buf: [28]u8 = undefined;
            const expected = std.base64.standard.Encoder.encode(&expected_buf, &digest);

            if (!mem.eql(u8, expected, received))
                return error.KeyControlFailed;
        }

        /// Send a WebSocket message to the server.
        /// The `opcode` field can be text, binary, ping, pong or close.
        /// In order to send continuation frames or streaming messages, check out `stream` function.
        pub fn send(self: *Self, opcode: Opcode, data: []const u8) !void {
            return self.sender.send(opcode, data);
        }

        /// Send a ping message to the server.
        pub fn ping(self: *Self) !void {
            return self.send(.ping, "");
        }

        /// Send an empty pong message to the server.
        pub fn pong(self: *Self) !void {
            return self.send(.pong, "");
        }

        /// Send a close message to the server.
        pub fn close(self: *Self) !void {
            return self.sender.close();
        }

        /// TODO: Add usage example
        /// Send continuation frames or streaming messages to the server.
        pub fn stream(self: *Self, opcode: Opcode, payload: ?[]const u8) !void {
            return self.sender.stream(opcode, payload);
        }

        /// Receive a message from the server.
        pub fn receive(self: *Self) !Message {
            return self.receiver.receive();
        }
    };
}

--------------------------------------------------------------------------------
/src/sender.zig:
--------------------------------------------------------------------------------
const std = @import("std");
const mem = std.mem;
const io = std.io;
const common = @import("common.zig");
const Opcode = common.Opcode;
const Header = common.Header;

const MAX_CTL_FRAME_LENGTH = common.MAX_CTL_FRAME_LENGTH;
const MASK_BUFFER_SIZE: usize = 1024;
const DEFAULT_CLOSE_CODE: u16 = 1000;
/// Upper bound for the request target ("/path?query") of the handshake request.
const MAX_REQUEST_TARGET_LENGTH: usize = 1024;

/// Format the request target of `uri` (path and query, escaped by `std.Uri`'s
/// formatter) into `buf` and return the written part of `buf`.
/// Fails with `error.NoSpaceLeft` when `buf` is too small.
fn getUriFullPath(buf: []u8, uri: std.Uri) ![]const u8 {
    return std.fmt.bufPrint(buf, "{}", .{uri});
}

/// Whether the user-supplied request headers already contain `name` (ASCII case-insensitive).
fn hasHeader(headers: ?[]const [2][]const u8, name: []const u8) bool {
    const list = headers orelse return false;
    for (list) |header| {
        if (std.ascii.eqlIgnoreCase(header[0], name)) return true;
    }
    return false;
}

pub fn Sender(comptime Writer: type, comptime capacity: usize) type {
    return struct {
        const Self = @This();

        writer: Writer,
        /// Masking key of the frame currently being written; refreshed by `putHeader`.
        mask: [4]u8,
        // for buffered writes
        buffer: [capacity]u8 = undefined,
        end: usize = 0,

        /// Write the HTTP/1.1 upgrade request that starts the WebSocket handshake.
        /// A `Host` header is derived from `uri` unless `request_headers` provides one.
        pub fn sendRequest(
            self: *Self,
            uri: std.Uri,
            request_headers: ?[]const [2][]const u8,
            sec_websocket_key: []const u8,
        ) !void {
            // push http request line
            var target_buf: [MAX_REQUEST_TARGET_LENGTH]u8 = undefined;
            try self.put("GET ");
            try self.put(try getUriFullPath(&target_buf, uri));
            try self.put(" HTTP/1.1\r\n");

            // HTTP/1.1 requires a Host header; add one unless the caller did.
            if (!hasHeader(request_headers, "Host")) {
                if (uri.host) |host| {
                    try self.put("Host: ");
                    try self.put(host);
                    if (uri.port) |port| {
                        var port_buf: [8]u8 = undefined;
                        try self.put(try std.fmt.bufPrint(&port_buf, ":{d}", .{port}));
                    }
                    try self.put("\r\n");
                }
            }

            // push default headers
            const default_headers =
                "Pragma: no-cache\r\n" ++
                "Cache-Control: no-cache\r\n" ++
                "Connection: Upgrade\r\n" ++
                "Upgrade: websocket\r\n" ++
                "Sec-WebSocket-Version: 13\r\n";
            try self.put(default_headers);

            // push websocket key
            try self.put("Sec-WebSocket-Key: ");
            try self.put(sec_websocket_key);
            try self.put("\r\n");

            // push user defined headers
            if (request_headers) |headers| {
                for (headers) |header| {
                    try self.put(header[0]);
                    try self.put(": ");
                    try self.put(header[1]);
                    try self.put("\r\n");
                }
            }

            // send 'em all
            try self.put("\r\n");
            return self.flush();
        }

        /// Write bytes that're buffered in Sender and reset the terminator.
        fn flush(self: *Self) Writer.Error!void {
            try self.writer.writeAll(self.buffer[0..self.end]);
            self.end = 0;
        }

        /// Does buffered writes, pretty similar to io.BufferedWriter.
        fn put(self: *Self, bytes: []const u8) Writer.Error!void {
            if (self.end + bytes.len > self.buffer.len) {
                try self.flush();
                // too large to ever fit in the buffer: write it straight through
                if (bytes.len > self.buffer.len)
                    return self.writer.writeAll(bytes);
            }

            mem.copy(u8, self.buffer[self.end..], bytes);
            self.end += bytes.len;
        }

        /// Buffer a client frame header. Client frames are always masked, and a fresh
        /// random masking key is drawn for every frame; `putMasked` then uses it.
        fn putHeader(self: *Self, header: Header) Writer.Error!void {
            std.crypto.random.bytes(&self.mask);

            var buf: [14]u8 = undefined;
            buf[0] = @as(u8, @intFromEnum(header.opcode));
            if (header.fin) buf[0] |= 0x80;
            // MASK bit
            buf[1] = 0x80;

            // 7-bit, 16-bit or 64-bit payload length; the mask key follows it.
            const mask_offset: usize = if (header.len < 126) blk: {
                buf[1] |= @as(u8, @intCast(header.len));
                break :blk 2;
            } else if (header.len <= std.math.maxInt(u16)) blk: {
                buf[1] |= 126;
                mem.writeIntBig(u16, buf[2..4], @as(u16, @intCast(header.len)));
                break :blk 4;
            } else blk: {
                buf[1] |= 127;
                mem.writeIntBig(u64, buf[2..10], header.len);
                break :blk 10;
            };

            mem.copy(u8, buf[mask_offset..], &self.mask);
            return self.put(buf[0 .. mask_offset + 4]);
        }

        /// XOR `source` with the mask into `buf`; `pos` is the offset of `source`
        /// within the frame payload, so the mask stays aligned across chunks.
        fn maskBytes(self: *const Self, buf: []u8, source: []const u8, pos: usize) void {
            for (source, 0..) |c, i|
                buf[i] = c ^ self.mask[(i + pos) % 4];
        }

        /// Mask `data` in fixed-size chunks (the input is never modified) and buffer it.
        fn putMasked(self: *Self, data: []const u8) Writer.Error!void {
            var scratch: [MASK_BUFFER_SIZE]u8 = undefined;
            var offset: usize = 0;

            while (offset < data.len) {
                const chunk = data[offset..@min(offset + scratch.len, data.len)];
                self.maskBytes(scratch[0..chunk.len], chunk, offset);
                try self.put(scratch[0..chunk.len]);
                offset += chunk.len;
            }
        }

        // ----------------------------------
        // Send API
        // ----------------------------------

        /// Send a WebSocket message.
        pub fn send(self: *Self, opcode: Opcode, data: []const u8) !void {
            return switch (opcode) {
                .text, .binary => self.regular(opcode, data),
                .ping, .pong => self.pingPong(opcode, data),
                .close => self.close(),

                .continuation, .end => error.UseStreamInstead,
                else => error.UnknownOpcode,
            };
        }

        // text + binary messages, sent as one unfragmented frame
        fn regular(self: *Self, opcode: Opcode, data: []const u8) !void {
            try self.putHeader(.{
                .len = data.len,
                .opcode = opcode,
                .fin = true,
            });
            try self.putMasked(data);

            return self.flush();
        }

        // ping + pong; control frame payloads are limited to MAX_CTL_FRAME_LENGTH
        fn pingPong(self: *Self, opcode: Opcode, data: []const u8) !void {
            if (data.len > MAX_CTL_FRAME_LENGTH)
                return error.PayloadTooBig;

            try self.putHeader(.{
                .len = data.len,
                .opcode = opcode,
                .fin = true,
            });
            try self.putMasked(data);

            return self.flush();
        }

        // TODO: implement close code & reason.
        pub fn close(self: *Self) !void {
            try self.putHeader(.{
                .len = 0,
                .opcode = .close,
                .fin = true,
            });

            return self.flush();
        }

        // ----------------------------------
        // Stream API
        // ----------------------------------

        /// writes data piece by piece, good for streaming big or unknown amounts of data as chunks.
        /// text/binary/continuation send a non-final fragment; `.end` sends the final
        /// continuation fragment. A null payload sends an empty fragment.
        pub fn stream(self: *Self, opcode: Opcode, payload: ?[]const u8) !void {
            const is_last = switch (opcode) {
                .text, .binary, .continuation => false,
                .end => true,
                else => return error.UnknownOpcode,
            };
            const frame_opcode: Opcode = if (is_last) .continuation else opcode;

            return self.fragmented(frame_opcode, payload orelse "", is_last);
        }

        fn fragmented(self: *Self, opcode: Opcode, data: []const u8, fin: bool) !void {
            try self.putHeader(.{
                .len = data.len,
                .opcode = opcode,
                .fin = fin,
            });
            try self.putMasked(data);

            return self.flush();
        }
    };
}

test "std.Uri processing results in expected paths" {
    const uris = [_]std.Uri{
        try std.Uri.parse("ws://localhost"),
        try std.Uri.parse("ws://localhost/"),
        try std.Uri.parse("ws://localhost?query=example"),
        try std.Uri.parse("ws://localhost/?query=example"),
        try std.Uri.parse("ws://localhost/?query1=&&something&query2=somethingelse"),
        try std.Uri.parse("ws://localhost/?query1=something with spaces&query2=somethingelse"),
        try std.Uri.parse("ws://localhost:8080"),
        try std.Uri.parse("ws://localhost:8080/"),
        try std.Uri.parse("ws://localhost:8080?query=example"),
        try std.Uri.parse("ws://localhost:8080/?query=example"),
        try std.Uri.parse("ws://localhost:8080/?query1=&&something&query2=somethingelse"),
        try std.Uri.parse("ws://localhost:8080/?query1=something with spaces&query2=somethingelse"),
    };

    const paths = [_][]const u8{
        "/",
        "/",
        "/?query=example",
        "/?query=example",
        "/?query1=&&something&query2=somethingelse",
        "/?query1=something%20with%20spaces&query2=somethingelse",
        "/",
        "/",
        "/?query=example",
        "/?query=example",
        "/?query1=&&something&query2=somethingelse",
        "/?query1=something%20with%20spaces&query2=somethingelse",
    };

    var buf: [MAX_REQUEST_TARGET_LENGTH]u8 = undefined;
    for (uris, paths) |uri, path| {
        try std.testing.expectEqualSlices(u8, path, try getUriFullPath(&buf, uri));
    }
}

--------------------------------------------------------------------------------
/src/receiver.zig:
--------------------------------------------------------------------------------
const std = @import("std");
const mem = std.mem;
const common = @import("common.zig");
const Opcode = common.Opcode;
const Header = common.Header;
const Message = common.Message;

const MAX_CTL_FRAME_LENGTH = common.MAX_CTL_FRAME_LENGTH;
// max header size can be 10 * u8,
// if masking is allowed, header size can be up to 14 * u8
// server should not be sending masked messages.
const MAX_HEADER_SIZE = 10;
/// Longest accepted line (status line or one header line) of the handshake response.
const MAX_HTTP_LINE_LENGTH = 2048;

pub fn Receiver(comptime Reader: type, comptime capacity: usize) type {
    return struct {
        const Self = @This();

        reader: Reader,
        /// Payload storage for data frames, including fragments collected so far.
        buffer: [capacity]u8 = undefined,
        header_buffer: [MAX_HEADER_SIZE]u8 = undefined,
        // specified for ping, pong and close frames.
        control_buffer: [MAX_CTL_FRAME_LENGTH]u8 = undefined,
        /// Bytes of `buffer` filled by the fragmented message in progress.
        end: usize = 0,
        fragmentation: Fragmentation = .{},

        const Fragmentation = struct {
            /// Set while a fragmented message has started but its final frame is pending.
            on: bool = false,
            /// Opcode (text or binary) of the message being reassembled.
            opcode: Opcode = .text,
        };

        /// Deallocate HTTP response headers and string hashmap.
        pub fn freeHttpHeaders(
            _: Self,
            allocator: mem.Allocator,
            headers: *std.StringHashMapUnmanaged([]const u8),
        ) void {
            defer headers.deinit(allocator);
            var iter = headers.iterator();
            while (iter.next()) |entry| {
                allocator.free(entry.key_ptr.*);
                allocator.free(entry.value_ptr.*);
            }
        }

        /// Receive and allocate for HTTP headers, uses a StringHashMapUnmanaged([]const u8) to store the parsed headers.
        /// Lines may end in CRLF or a bare LF; values are trimmed of surrounding spaces/tabs.
        /// Repeated header names are combined into one comma-separated value.
        /// On error every header stored so far is freed and the map is deinitialized.
        pub fn receiveResponse(
            self: Self,
            allocator: mem.Allocator,
            headers: *std.StringHashMapUnmanaged([]const u8),
        ) !void {
            errdefer self.freeHttpHeaders(allocator, headers);

            var line_buf: [MAX_HTTP_LINE_LENGTH]u8 = undefined;

            // HTTP/1.1 101 Switching Protocols
            const status_line = try self.readHttpLine(&line_buf);
            if (!mem.startsWith(u8, status_line, "HTTP/1.1 101 Switching Protocols"))
                return error.FailedSwitchingProtocols;

            // header lines, terminated by an empty line
            while (true) {
                const line = try self.readHttpLine(&line_buf);
                if (line.len == 0) break;

                const colon = mem.indexOfScalar(u8, line, ':') orelse return error.BadHttpResponse;
                if (colon == 0) return error.BadHttpResponse;

                const name = line[0..colon];
                const value = mem.trim(u8, line[colon + 1 ..], " \t");
                try storeHttpHeader(allocator, headers, name, value);
            }
        }

        /// Read one line into `buf`, without its LF (and optional preceding CR).
        fn readHttpLine(self: Self, buf: []u8) ![]u8 {
            const line = self.reader.readUntilDelimiter(buf, '\n') catch |err| switch (err) {
                error.StreamTooLong => return error.HttpHeaderTooLong,
                else => |e| return e,
            };
            return if (mem.endsWith(u8, line, "\r")) line[0 .. line.len - 1] else line;
        }

        /// Insert copies of `name`/`value` into `headers`. If `name` is already present,
        /// the value is appended as ", value" (the standard way to combine repeated fields).
        fn storeHttpHeader(
            allocator: mem.Allocator,
            headers: *std.StringHashMapUnmanaged([]const u8),
            name: []const u8,
            value: []const u8,
        ) !void {
            if (headers.getPtr(name)) |existing| {
                const merged = try mem.concat(allocator, u8, &[_][]const u8{ existing.*, ", ", value });
                allocator.free(existing.*);
                existing.* = merged;
                return;
            }

            const owned_name = try allocator.dupe(u8, name);
            errdefer allocator.free(owned_name);
            const owned_value = try allocator.dupe(u8, value);
            errdefer allocator.free(owned_value);

            try headers.put(allocator, owned_name, owned_value);
        }

        pub const GetHeaderError = error{EndOfStream} || Header.Error || Reader.Error;

        fn getHeader(self: *Self) GetHeaderError!Header {
            const buf = self.header_buffer[0..2];

            const len = try self.reader.readAll(buf);
            if (len < 2) return error.EndOfStream;

            const is_masked = buf[1] & 0x80 != 0;
            if (is_masked)
                return error.MaskedMessageFromServer; // FIXME: should this be allowed?

            // get length from variable length
            const var_length: u7 = @truncate(buf[1] & 0x7F);
            const length = try self.getLength(var_length);

            const b = buf[0];
            const op: u4 = @truncate(b & 0x0F);

            return Header{
                .len = length,
                .opcode = @enumFromInt(op),
                .fin = b & 0x80 != 0,
                .rsv1 = b & 0x40 != 0,
                .rsv2 = b & 0x20 != 0,
                .rsv3 = b & 0x10 != 0,
            };
        }

        pub const GetLengthError = error{EndOfStream} || Reader.Error;

        fn getLength(self: *Self, var_length: u7) GetLengthError!u64 {
            return switch (var_length) {
                126 => {
                    const len = try self.reader.readAll(self.header_buffer[2..4]);
                    if (len < 2) return error.EndOfStream;

                    return mem.readIntBig(u16, self.header_buffer[2..4]);
                },

                127 => {
                    const len = try self.reader.readAll(self.header_buffer[2..]);
                    if (len < 8) return error.EndOfStream;

                    return mem.readIntBig(u64, self.header_buffer[2..]);
                },

                inline else => var_length,
            };
        }

        /// Read a `len`-byte data payload into `buffer` right after the `end` bytes already
        /// collected and return the newly filled slice (`end` is not advanced).
        /// `len` comes from the wire, so it is bounds-checked without any addition that
        /// could overflow.
        fn readPayload(self: *Self, len: u64) FrameError![]u8 {
            const available = self.buffer.len - self.end;
            if (len > available) return error.PayloadTooBig;

            const n: usize = @intCast(len);
            const buf = self.buffer[self.end .. self.end + n];
            if (try self.reader.readAll(buf) < n) return error.EndOfStream;
            return buf;
        }

        /// Read a control frame payload (at most MAX_CTL_FRAME_LENGTH bytes) into `control_buffer`.
        fn readControlPayload(self: *Self, len: u64) FrameError![]u8 {
            if (len > self.control_buffer.len) return error.PayloadTooBig;

            const buf = self.control_buffer[0..@as(usize, @intCast(len))];
            if (try self.reader.readAll(buf) < buf.len) return error.EndOfStream;
            return buf;
        }

        fn pingPong(self: *Self, header: Header) FrameError!Message {
            const payload = try self.readControlPayload(header.len);
            return Message.from(header.opcode, payload, null);
        }

        /// Close frame: empty, or a 2-byte big-endian status code followed by an
        /// optional reason, which becomes the message data. A 1-byte body is invalid.
        fn close(self: *Self, header: Header) FrameError!Message {
            const buf = try self.readControlPayload(header.len);

            return switch (buf.len) {
                0 => Message.from(.close, buf, null),
                1 => error.InvalidCloseFrame,
                else => Message.from(.close, buf[2..], mem.readIntBig(u16, buf[0..2])),
            };
        }

        pub const ContinuationError = error{UnknownOpcode} || FrameError || GetHeaderError;
        const FragmentError = error{BadMessageOrder} || ContinuationError;

        // this must be called when text or binary frame without fin is received
        fn beginFragmented(self: *Self, header: Header) FragmentError!Message {
            self.fragmentation = .{ .on = true, .opcode = header.opcode };
            self.end = 0;
            return self.collectFragments(header);
        }

        // this must be called when continuation frame is received
        fn resumeFragmented(self: *Self, header: Header) FragmentError!Message {
            if (!self.fragmentation.on)
                return error.BadMessageOrder;
            return self.collectFragments(header);
        }

        /// Append fragments to `buffer` until a frame with FIN arrives, then return the
        /// reassembled message and reset the fragmentation state.
        fn collectFragments(self: *Self, first: Header) FragmentError!Message {
            var frame = first;
            while (true) {
                const chunk = try self.readPayload(frame.len);
                self.end += chunk.len;
                if (frame.fin) break;

                frame = try self.getHeader();
                switch (frame.opcode) {
                    .continuation => {},
                    // another data message must not start before this one is finished
                    .text, .binary => return error.BadMessageOrder,
                    // control frames may be interleaved: the fragments collected so far
                    // stay in `buffer` and collection resumes on the next continuation
                    .ping, .pong => return self.pingPong(frame),
                    .close => return self.close(frame),

                    else => return error.UnknownOpcode,
                }
            }

            const payload = self.buffer[0..self.end];
            self.end = 0;
            self.fragmentation.on = false;

            return Message.from(self.fragmentation.opcode, payload, null);
        }

        pub const FrameError = error{
            EndOfStream,
            PayloadTooBig,
            InvalidCloseFrame,
        } || Message.Error || Reader.Error;

        fn regular(self: *Self, header: Header) FrameError!Message {
            const payload = try self.readPayload(header.len);
            return Message.from(header.opcode, payload, null);
        }

        pub const Error = error{BadMessageOrder} || Header.Error || FrameError || ContinuationError;

        /// Receive the next message from the stream.
        /// The returned data points into this Receiver's buffers and is only valid
        /// until the next call.
        pub fn receive(self: *Self) Error!Message {
            const header = try self.getHeader();

            return switch (header.opcode) {
                .continuation => self.resumeFragmented(header),
                // a new data message may not interleave with a fragmented one
                .text, .binary => if (self.fragmentation.on)
                    error.BadMessageOrder
                else if (header.fin)
                    self.regular(header)
                else
                    self.beginFragmented(header),

                // control frames
                .ping, .pong => self.pingPong(header),
                .close => self.close(header),

                else => error.UnknownOpcode,
            };
        }
    };
}

--------------------------------------------------------------------------------