├── .gitignore
├── misc
└── logo.png
├── LICENSE
├── src
├── common.zig
├── main.zig
├── connection.zig
├── client.zig
├── sender.zig
└── receiver.zig
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | zig-cache/
2 | zig-out/
--------------------------------------------------------------------------------
/misc/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikneym/ws/HEAD/misc/logo.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 nikneym
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/common.zig:
--------------------------------------------------------------------------------
/// Upper bound, in bytes, on the payload of a control frame
/// (close, ping, pong); see RFC 6455 section 5.5.
pub const MAX_CTL_FRAME_LENGTH = 125;

/// Frame opcodes as carried in the low 4 bits of a frame's first byte.
/// Unlisted 4-bit values are representable through the `_` tag.
pub const Opcode = enum(u4) {
    continuation = 0,
    text = 1,
    binary = 2,
    close = 8,
    ping = 9,
    pong = 10,
    /// Library-specific marker, not a wire opcode: `Sender.stream` treats it
    /// as "last continuation frame" and sends `.continuation` with FIN set.
    end = 15,
    _,
};

/// Decoded frame header. Only the fields below are kept; no masking key
/// is stored.
pub const Header = packed struct {
    /// Payload length in bytes.
    len: u64,
    opcode: Opcode,
    /// FIN bit: set on the final frame of a message.
    fin: bool,
    // Reserved bits, zero unless an extension gives them a meaning.
    rsv1: bool = false,
    rsv2: bool = false,
    rsv3: bool = false,

    pub const Error = error{MaskedMessageFromServer};
};

/// A message handed to the user: its type, a borrowed payload slice and,
/// for close messages, an optional status code.
pub const Message = struct {
    type: Opcode,
    /// Payload bytes; the slice is stored as given, not copied.
    data: []const u8,
    /// Close status code; only meaningful for close messages.
    code: ?u16,

    pub const Error = error{ FragmentedMessage, UnknownOpcode };

    /// Wrap `data` in a Message of kind `opcode`.
    /// Fails with `error.FragmentedMessage` for `.continuation` and with
    /// `error.UnknownOpcode` for anything that is not a complete message type.
    pub fn from(opcode: Opcode, data: []const u8, code: ?u16) Message.Error!Message {
        return switch (opcode) {
            .text, .binary, .ping, .pong, .close => Message{
                .type = opcode,
                .data = data,
                .code = code,
            },
            .continuation => error.FragmentedMessage,
            else => error.UnknownOpcode,
        };
    }
};
49 |
--------------------------------------------------------------------------------
/src/main.zig:
--------------------------------------------------------------------------------
1 | const std = @import("std");
2 | const net = std.net;
3 | const mem = std.mem;
4 | const io = std.io;
5 |
6 | // these can be used directly too
7 | pub const Client = @import("client.zig").Client;
8 | pub const client = @import("client.zig").client;
9 | pub const Connection = @import("connection.zig").Connection;
/// One extra HTTP request header for the handshake, as a `{ name, value }`
/// pair, e.g. `.{ "Origin", "http://localhost/" }`.
pub const Header: type = [2][]const u8;
11 |
/// Connection target: either a literal IP address (with port) or a host
/// name that still has to be resolved.
pub const Address = union(enum) {
    ip: std.net.Address,
    host: []const u8,

    /// Classify `host`: a parseable IPv4/IPv6 literal becomes `.ip` (with
    /// `port` applied); anything else is kept verbatim as `.host`.
    pub fn resolve(host: []const u8, port: u16) Address {
        const literal = std.net.Address.parseIp(host, port) catch {
            // Not an IP literal; leave it for DNS resolution.
            return Address{ .host = host };
        };
        return Address{ .ip = literal };
    }
};
21 |
/// Open a new WebSocket connection to `uri` over plain TCP and perform the
/// opening handshake.
///
/// `allocator` is used for DNS resolution of host names, for the
/// connection's internal state and for the response headers; release them
/// with `Connection.deinit`.
///
/// Errors:
/// - `error.UnknownScheme` if the scheme is not `ws` or `wss`;
/// - `error.TlsNotSupported` for `wss`, since TLS is not implemented yet;
/// - `error.MissingHost` if `uri` has no host;
/// - any connect or handshake error.
pub fn connect(allocator: mem.Allocator, uri: std.Uri, request_headers: ?[]const Header) !Connection {
    // Validate the scheme even when an explicit port is given.
    const default_port: u16 = if (mem.eql(u8, uri.scheme, "ws"))
        80
    else if (mem.eql(u8, uri.scheme, "wss"))
        // TODO: implement TLS connection. Refuse instead of speaking
        // plaintext to an endpoint that expects a TLS handshake.
        return error.TlsNotSupported
    else
        return error.UnknownScheme;

    const port = uri.port orelse default_port;
    const host = uri.host orelse return error.MissingHost;

    const stream = try switch (Address.resolve(host, port)) {
        .ip => |ip| net.tcpConnectToAddress(ip),
        .host => |name| net.tcpConnectToHost(allocator, name, port),
    };
    // The socket is owned by the Connection on success; close it if the
    // handshake fails.
    errdefer stream.close();

    return Connection.init(allocator, stream, uri, request_headers);
}
39 |
test "Simple connection to :8080" {
    // Integration test: expects a WebSocket server on localhost:8080.
    // Text messages are echoed back, pings are answered with a pong, and the
    // loop ends when the server sends a close frame.
    const allocator = std.testing.allocator;
    const target = try std.Uri.parse("ws://localhost:8080");
    const extra_headers = [_]Header{
        .{ "Host", "localhost" },
        .{ "Origin", "http://localhost/" },
    };

    var conn = try connect(allocator, target, &extra_headers);
    defer conn.deinit(allocator);

    receive_loop: while (true) {
        const incoming = try conn.receive();
        switch (incoming.type) {
            .text => {
                std.debug.print("received: {s}\n", .{incoming.data});
                try conn.send(.text, incoming.data);
            },
            .ping => {
                std.debug.print("got ping! sending pong...\n", .{});
                try conn.pong();
            },
            .close => {
                std.debug.print("close", .{});
                break :receive_loop;
            },
            else => std.debug.print("got {s}: {s}\n", .{ @tagName(incoming.type), incoming.data }),
        }
    }

    try conn.close();
}
75 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | ws
6 | ===========
7 | a lightweight WebSocket library for Zig ⚡
8 |
9 | Features
10 | ===========
11 | * Only allocates during the WebSocket handshake; message parsing and building do not allocate
12 | * Ease of use, can be used directly with `net.Stream`
13 | * Does buffered reads and writes (can be used with any other reader/writer too)
14 | * Supports streaming output thanks to WebSocket fragmentation
15 |
16 | Example
17 | ===========
18 | By default, ws runs over `std.net.Stream` (see `ws.connect` and `ws.Connection`).
19 | To use any other stream, build a client from its reader and writer with `ws.client` / `ws.Client`.
20 | ```zig
21 | test "Simple connection to :8080" {
22 | const allocator = std.testing.allocator;
23 |
24 | var cli = try connect(allocator, try std.Uri.parse("ws://localhost:8080"), &.{
25 | .{"Host", "localhost"},
26 | .{"Origin", "http://localhost/"},
27 | });
28 | defer cli.deinit(allocator);
29 |
30 | while (true) {
31 | const msg = try cli.receive();
32 | switch (msg.type) {
33 | .text => {
34 | std.debug.print("received: {s}\n", .{msg.data});
35 | try cli.send(.text, msg.data);
36 | },
37 |
38 | .ping => {
39 | std.debug.print("got ping! sending pong...\n", .{});
40 | try cli.pong();
41 | },
42 |
43 | .close => {
44 | std.debug.print("close", .{});
45 | break;
46 | },
47 |
48 | else => {
49 | std.debug.print("got {s}: {s}\n", .{@tagName(msg.type), msg.data});
50 | },
51 | }
52 | }
53 |
54 | try cli.close();
55 | }
56 | ```
57 |
58 | Planned
59 | ===========
60 | - [ ] WebSocket server support
61 | - [ ] TLS support out of the box (tracks `std.crypto.tls.Client`)
62 | - [x] Request & response headers
63 | - [ ] WebSocket Compression support
64 |
65 | Acknowledgements
66 | ===========
67 | This library wouldn't be possible without these cool projects & posts:
68 | * [truemedian/wz](https://github.com/truemedian/wz)
69 | * [frmdstryr/zhp](https://github.com/frmdstryr/zhp/blob/master/src/websocket.zig)
70 | * [treeform/ws](https://github.com/treeform/ws)
71 | * [openmymind.net/WebSocket-Framing-Masking-Fragmentation-and-More](https://www.openmymind.net/WebSocket-Framing-Masking-Fragmentation-and-More/)
72 | * [Writing WebSocket servers](https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers)
73 |
74 | License
75 | ===========
76 | MIT License; see [LICENSE](https://github.com/nikneym/ws/blob/main/LICENSE).
77 |
--------------------------------------------------------------------------------
/src/connection.zig:
--------------------------------------------------------------------------------
1 | const std = @import("std");
2 | const net = std.net;
3 | const mem = std.mem;
4 | const io = std.io;
5 |
6 | const Client = @import("client.zig").Client;
7 | const client = @import("client.zig").client;
8 |
9 | const common = @import("common.zig");
10 | const Opcode = common.Opcode;
11 | const Message = common.Message;
12 |
const READ_BUFFER_SIZE: usize = 1024 * 8;
const WRITE_BUFFER_SIZE: usize = 1024 * 4;

/// WebSocket client running directly over a `net.Stream`.
/// Reads go through a buffered reader; writes are buffered by the client's
/// sender.
pub const Connection = struct {
    underlying_stream: net.Stream,
    ws_client: *WsClient,
    // Heap-allocated on purpose: the reader stored inside `ws_client` keeps a
    // pointer to this BufferedReader, so it needs a stable address that
    // outlives `init` and survives copies of `Connection`.
    buffered_reader: *BufferedReader,
    headers: std.StringHashMapUnmanaged([]const u8),

    /// general types
    const WsClient = Client(Reader, Writer, READ_BUFFER_SIZE, WRITE_BUFFER_SIZE);
    const BufferedReader = io.BufferedReader(4096, net.Stream.Reader);
    const Reader = BufferedReader.Reader;
    const Writer = net.Stream.Writer;

    /// Wrap an already-connected `underlying_stream` and perform the WebSocket
    /// handshake for `uri`, sending `request_headers` in addition to the
    /// built-in ones.
    ///
    /// On success the Connection owns the allocations made with `allocator`
    /// and the stream; release both with `deinit`. On failure the allocations
    /// are freed, but the stream is left for the caller to close.
    pub fn init(
        allocator: mem.Allocator,
        underlying_stream: net.Stream,
        uri: std.Uri,
        request_headers: ?[]const [2][]const u8,
    ) !Connection {
        const buffered_reader = try allocator.create(BufferedReader);
        errdefer allocator.destroy(buffered_reader);
        buffered_reader.* = BufferedReader{ .unbuffered_reader = underlying_stream.reader() };

        const ws_client = try allocator.create(WsClient);
        errdefer allocator.destroy(ws_client);
        ws_client.* = client(
            buffered_reader.reader(),
            underlying_stream.writer(),
            READ_BUFFER_SIZE,
            WRITE_BUFFER_SIZE,
        );

        var self = Connection{
            .underlying_stream = underlying_stream,
            .ws_client = ws_client,
            .buffered_reader = buffered_reader,
            .headers = .{},
        };

        try self.ws_client.handshake(allocator, uri, request_headers, &self.headers);
        return self;
    }

    /// Free the response headers and internal state, then close the stream.
    /// `allocator` must be the one passed to `init`.
    pub fn deinit(self: *Connection, allocator: mem.Allocator) void {
        self.ws_client.deinit(allocator, &self.headers);
        allocator.destroy(self.ws_client);
        allocator.destroy(self.buffered_reader);
        self.underlying_stream.close();
    }

    /// Send a WebSocket message to the server.
    /// The `opcode` field can be text, binary, ping, pong or close.
    /// In order to send continuation frames or streaming messages, check out `stream` function.
    pub fn send(self: Connection, opcode: Opcode, data: []const u8) !void {
        return self.ws_client.send(opcode, data);
    }

    /// Send a ping message (empty payload) to the server.
    pub fn ping(self: Connection) !void {
        return self.send(.ping, "");
    }

    /// Send a pong message (empty payload) to the server.
    pub fn pong(self: Connection) !void {
        return self.send(.pong, "");
    }

    /// Send a close message to the server.
    pub fn close(self: Connection) !void {
        return self.ws_client.close();
    }

    /// TODO: Add usage example
    /// Send continuation frames or a streaming message to the server.
    pub fn stream(self: Connection, opcode: Opcode, payload: ?[]const u8) !void {
        return self.ws_client.stream(opcode, payload);
    }

    /// Receive a message from the server.
    pub fn receive(self: Connection) !Message {
        return self.ws_client.receive();
    }
};
99 |
--------------------------------------------------------------------------------
/src/client.zig:
--------------------------------------------------------------------------------
1 | const std = @import("std");
2 | const mem = std.mem;
3 |
4 | const common = @import("common.zig");
5 | const Message = common.Message;
6 | const Opcode = common.Opcode;
7 |
8 | const Receiver = @import("receiver.zig").Receiver;
9 | const Sender = @import("sender.zig").Sender;
10 |
/// Construct a WebSocket client on top of a caller-provided `reader` and
/// `writer`. The buffer sizes are fixed at compile time, and the sender's
/// masking key is seeded from `std.crypto.random`.
pub fn client(
    reader: anytype,
    writer: anytype,
    comptime read_buffer_size: usize,
    comptime write_buffer_size: usize,
) Client(@TypeOf(reader), @TypeOf(writer), read_buffer_size, write_buffer_size) {
    const Result = Client(@TypeOf(reader), @TypeOf(writer), read_buffer_size, write_buffer_size);

    var result = Result{
        .receiver = .{ .reader = reader },
        .sender = .{ .writer = writer, .mask = undefined },
    };
    std.crypto.random.bytes(&result.sender.mask);
    return result;
}
28 |
/// WebSocket client type for arbitrary `Reader` / `Writer` types, pairing a
/// frame receiver with a frame sender. Build instances with `client`.
pub fn Client(
    comptime Reader: type,
    comptime Writer: type,
    comptime read_buffer_size: usize,
    comptime write_buffer_size: usize,
) type {
    return struct {
        const Self = @This();

        receiver: Receiver(Reader, read_buffer_size),
        sender: Sender(Writer, write_buffer_size),

        /// Deallocate the response headers stored by `handshake`
        /// (names, values and the map itself).
        pub fn deinit(
            self: Self,
            allocator: mem.Allocator,
            headers: *std.StringHashMapUnmanaged([]const u8),
        ) void {
            self.receiver.freeHttpHeaders(allocator, headers);
        }

        /// Perform the opening handshake.
        /// Sends the upgrade request for `uri` (plus `request_headers`), reads
        /// the response into `response_headers` (allocated with `allocator`)
        /// and verifies `Sec-WebSocket-Accept`. On error no headers stay
        /// allocated.
        pub fn handshake(
            self: *Self,
            allocator: mem.Allocator,
            uri: std.Uri,
            request_headers: ?[]const [2][]const u8,
            response_headers: *std.StringHashMapUnmanaged([]const u8),
        ) !void {
            // Sec-WebSocket-Key: 16 random bytes, base64-encoded (24 chars).
            // The nonce and its encoding live in separate buffers: the
            // encoder must not write over bytes it has yet to read.
            var nonce: [16]u8 = undefined;
            std.crypto.random.bytes(&nonce);
            var key_buf: [24]u8 = undefined;
            const key = std.base64.standard.Encoder.encode(&key_buf, &nonce);

            try self.sender.sendRequest(uri, request_headers, key);
            try self.receiver.receiveResponse(allocator, response_headers);
            errdefer self.receiver.freeHttpHeaders(allocator, response_headers);

            try checkWebSocketAcceptKey(response_headers.*, key);
        }

        const WsAcceptKeyError = error{ KeyControlFailed, AcceptKeyNotFound };

        /// Check the server's `Sec-WebSocket-Accept` value against `key`.
        /// Expected value: base64(SHA-1(key ++ "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")),
        /// e.g. "dGhlIHNhbXBsZSBub25jZQ==" -> "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".
        /// The header name is matched case-insensitively, as HTTP field
        /// names are, and surrounding spaces/tabs in the value are ignored.
        fn checkWebSocketAcceptKey(
            headers: std.StringHashMapUnmanaged([]const u8),
            key: []const u8,
        ) WsAcceptKeyError!void {
            const received = findHeader(headers, "Sec-WebSocket-Accept") orelse
                return error.AcceptKeyNotFound;

            const magic_string = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

            var hash_buf: [20]u8 = undefined;
            var h = std.crypto.hash.Sha1.init(.{});
            h.update(key);
            h.update(magic_string);
            h.final(&hash_buf);

            var encoded_hash_buf: [28]u8 = undefined;
            const our_accept = std.base64.standard.Encoder.encode(&encoded_hash_buf, &hash_buf);

            if (!mem.eql(u8, our_accept, mem.trim(u8, received, " \t")))
                return error.KeyControlFailed;
        }

        /// Return the value of header `name`, comparing names without regard
        /// to ASCII case, or null if it is absent.
        fn findHeader(
            headers: std.StringHashMapUnmanaged([]const u8),
            name: []const u8,
        ) ?[]const u8 {
            // Fast path: exact spelling.
            if (headers.get(name)) |value| return value;

            var it = headers.iterator();
            while (it.next()) |entry| {
                if (std.ascii.eqlIgnoreCase(entry.key_ptr.*, name))
                    return entry.value_ptr.*;
            }
            return null;
        }

        /// Send a WebSocket message to the server.
        /// The `opcode` field can be text, binary, ping, pong or close.
        /// In order to send continuation frames or streaming messages, check out `stream` function.
        pub fn send(self: *Self, opcode: Opcode, data: []const u8) !void {
            return self.sender.send(opcode, data);
        }

        /// Send a ping message (empty payload) to the server.
        pub fn ping(self: *Self) !void {
            return self.send(.ping, "");
        }

        /// Send a pong message (empty payload) to the server.
        pub fn pong(self: *Self) !void {
            return self.send(.pong, "");
        }

        /// Send a close message to the server.
        pub fn close(self: *Self) !void {
            return self.sender.close();
        }

        /// TODO: Add usage example
        /// Send continuation frames or a streaming message to the server.
        pub fn stream(self: *Self, opcode: Opcode, payload: ?[]const u8) !void {
            return self.sender.stream(opcode, payload);
        }

        /// Receive a message from the server.
        pub fn receive(self: *Self) !Message {
            return self.receiver.receive();
        }
    };
}
129 |
--------------------------------------------------------------------------------
/src/sender.zig:
--------------------------------------------------------------------------------
1 | const std = @import("std");
2 | const mem = std.mem;
3 | const io = std.io;
4 | const common = @import("common.zig");
5 | const Opcode = common.Opcode;
6 | const Header = common.Header;
7 |
/// Status code 1000 ("normal closure", RFC 6455 section 7.4.1). It is not
/// sent yet: `close` currently emits an empty close frame.
const DEFAULT_CLOSE_CODE: u16 = 1000;
/// Size of the scratch buffer used to mask outgoing payloads piece by piece.
const MASK_BUFFER_SIZE: usize = 1024;
/// Control-frame payload limit shared with the receiver.
const MAX_CTL_FRAME_LENGTH = common.MAX_CTL_FRAME_LENGTH;
11 |
/// Render the HTTP request target (path plus query) of `uri`, e.g.
/// "/?query=example"; an empty path renders as "/".
///
/// The returned slice points into a thread-local buffer. It stays valid
/// until the next call on the same thread, so copy it if it must live
/// longer. Fails with `error.NoSpaceLeft` if the target exceeds 1024 bytes.
fn getUriFullPath(uri: std.Uri) ![]const u8 {
    // Static (thread-local) storage: a function-local array would be released
    // on return and leave the caller holding a dangling slice.
    const Storage = struct {
        threadlocal var buf: [1024]u8 = undefined;
    };
    return std.fmt.bufPrint(&Storage.buf, "{}", .{uri});
}
16 |
/// Client-side frame writer over `Writer` with a `capacity`-byte write
/// buffer. Every outgoing frame is masked, as RFC 6455 requires of clients.
pub fn Sender(comptime Writer: type, comptime capacity: usize) type {
    return struct {
        const Self = @This();

        writer: Writer,
        /// Masking key of the frame being written. `putHeader` refills it
        /// with fresh random bytes for every frame (RFC 6455 section 5.3).
        mask: [4]u8,
        // for buffered writes
        buffer: [capacity]u8 = undefined,
        end: usize = 0,

        /// Write and flush the HTTP/1.1 upgrade request for `uri`.
        /// `request_headers` are emitted verbatim as `name: value` lines after
        /// the built-in headers and the `Sec-WebSocket-Key` line.
        pub fn sendRequest(
            self: *Self,
            uri: std.Uri,
            request_headers: ?[]const [2][]const u8,
            sec_websocket_key: []const u8,
        ) !void {
            // Render the request target into storage owned by this frame so
            // the slice stays valid while `put` copies it.
            var target_buf: [MASK_BUFFER_SIZE]u8 = undefined;
            const target = try std.fmt.bufPrint(&target_buf, "{}", .{uri});

            // push http request line
            try self.put("GET ");
            try self.put(target);
            try self.put(" HTTP/1.1\r\n");

            // push default headers
            const default_headers =
                "Pragma: no-cache\r\n" ++
                "Cache-Control: no-cache\r\n" ++
                "Connection: Upgrade\r\n" ++
                "Upgrade: websocket\r\n" ++
                "Sec-WebSocket-Version: 13\r\n";
            try self.put(default_headers);

            // push websocket key
            try self.put("Sec-WebSocket-Key: ");
            try self.put(sec_websocket_key);
            try self.put("\r\n");

            // push user defined headers
            if (request_headers) |headers| {
                for (headers) |header| {
                    try self.put(header[0]);
                    try self.put(": ");
                    try self.put(header[1]);
                    try self.put("\r\n");
                }
            }

            // empty line terminates the request; send it all
            try self.put("\r\n");
            return self.flush();
        }

        /// Write out everything buffered so far and reset the fill level.
        fn flush(self: *Self) Writer.Error!void {
            try self.writer.writeAll(self.buffer[0..self.end]);
            self.end = 0;
        }

        /// Buffered write, similar to io.BufferedWriter. Data larger than the
        /// whole buffer is written straight through after a flush.
        fn put(self: *Self, bytes: []const u8) Writer.Error!void {
            if (self.end + bytes.len > self.buffer.len) {
                try self.flush();
                if (bytes.len > self.buffer.len)
                    return self.writer.writeAll(bytes);
            }

            @memcpy(self.buffer[self.end..][0..bytes.len], bytes);
            self.end += bytes.len;
        }

        /// Queue a masked frame header: FIN/opcode byte, MASK bit + length
        /// (7-bit, 16-bit or 64-bit form), then a fresh 4-byte masking key.
        fn putHeader(self: *Self, header: Header) Writer.Error!void {
            // A new unpredictable key per frame (RFC 6455 section 5.3).
            // `putMasked` uses it for this frame's payload.
            std.crypto.random.bytes(&self.mask);

            var buf: [14]u8 = undefined;
            buf[0] = @as(u8, @intFromEnum(header.opcode));
            if (header.fin) buf[0] |= 0x80;
            buf[1] = 0x80; // MASK bit

            const mask_at: usize = if (header.len < 126) blk: {
                buf[1] |= @as(u8, @truncate(header.len));
                break :blk 2;
            } else if (header.len < 65536) blk: {
                buf[1] |= 126;
                mem.writeIntBig(u16, buf[2..4], @as(u16, @truncate(header.len)));
                break :blk 4;
            } else blk: {
                buf[1] |= 127;
                mem.writeIntBig(u64, buf[2..10], header.len);
                break :blk 10;
            };

            @memcpy(buf[mask_at..][0..4], &self.mask);
            return self.put(buf[0 .. mask_at + 4]);
        }

        /// XOR `source` into `buf` with the current key; `pos` is the offset
        /// of `source` within the frame payload, so the key stays aligned.
        fn maskBytes(self: *const Self, buf: []u8, source: []const u8, pos: usize) void {
            for (source, 0..) |c, i|
                buf[i] = c ^ self.mask[(i + pos) % 4];
        }

        /// Queue `data` masked, going through a fixed scratch buffer.
        fn putMasked(self: *Self, data: []const u8) Writer.Error!void {
            var buf: [MASK_BUFFER_SIZE]u8 = undefined;
            var offset: usize = 0;
            while (offset < data.len) {
                const n = @min(buf.len, data.len - offset);
                self.maskBytes(buf[0..n], data[offset..][0..n], offset);
                try self.put(buf[0..n]);
                offset += n;
            }
        }

        // ----------------------------------
        // Send API
        // ----------------------------------

        /// Send a complete WebSocket message. For `.close`, `data` is ignored
        /// and an empty close frame is sent. Continuation frames must go
        /// through `stream`.
        pub fn send(self: *Self, opcode: Opcode, data: []const u8) !void {
            return switch (opcode) {
                .text, .binary => self.regular(opcode, data),
                .ping, .pong => self.pingPong(opcode, data),
                .close => self.close(),

                .continuation, .end => error.UseStreamInstead,
                else => error.UnknownOpcode,
            };
        }

        // text + binary messages, sent as one final frame
        fn regular(self: *Self, opcode: Opcode, data: []const u8) !void {
            return self.fragmented(opcode, data, true);
        }

        // ping / pong; control frames are limited to MAX_CTL_FRAME_LENGTH bytes
        fn pingPong(self: *Self, opcode: Opcode, data: []const u8) !void {
            if (data.len > MAX_CTL_FRAME_LENGTH)
                return error.PayloadTooBig;

            return self.fragmented(opcode, data, true);
        }

        // TODO: implement close code & reason.
        /// Send an empty close frame.
        pub fn close(self: *Self) !void {
            try self.putHeader(.{
                .len = 0,
                .opcode = .close,
                .fin = true,
            });

            return self.flush();
        }

        // ----------------------------------
        // Stream API
        // ----------------------------------

        /// Write data piece by piece, for streaming large or unknown amounts of
        /// data. Start with `.text`/`.binary`, continue with `.continuation`,
        /// finish with `.end`. A null payload sends an empty frame.
        pub fn stream(self: *Self, opcode: Opcode, payload: ?[]const u8) !void {
            const data = payload orelse "";
            return switch (opcode) {
                .text, .binary, .continuation => self.fragmented(opcode, data, false),
                .end => self.fragmented(.continuation, data, true),

                else => error.UnknownOpcode,
            };
        }

        /// Write one masked frame with the given FIN bit and flush it.
        fn fragmented(self: *Self, opcode: Opcode, data: []const u8, fin: bool) !void {
            try self.putHeader(.{
                .len = data.len,
                .opcode = opcode,
                .fin = fin,
            });
            try self.putMasked(data);

            return self.flush();
        }
    };
}
258 |
test "std.Uri processing results in expected paths" {
    // Each input URI paired with the request target it should render to.
    const Expectation = struct { input: []const u8, target: []const u8 };
    const table = [_]Expectation{
        .{ .input = "ws://localhost", .target = "/" },
        .{ .input = "ws://localhost/", .target = "/" },
        .{ .input = "ws://localhost?query=example", .target = "/?query=example" },
        .{ .input = "ws://localhost/?query=example", .target = "/?query=example" },
        .{ .input = "ws://localhost/?query1=&&something&query2=somethingelse", .target = "/?query1=&&something&query2=somethingelse" },
        .{ .input = "ws://localhost/?query1=something with spaces&query2=somethingelse", .target = "/?query1=something%20with%20spaces&query2=somethingelse" },
        .{ .input = "ws://localhost:8080", .target = "/" },
        .{ .input = "ws://localhost:8080/", .target = "/" },
        .{ .input = "ws://localhost:8080?query=example", .target = "/?query=example" },
        .{ .input = "ws://localhost:8080/?query=example", .target = "/?query=example" },
        .{ .input = "ws://localhost:8080/?query1=&&something&query2=somethingelse", .target = "/?query1=&&something&query2=somethingelse" },
        .{ .input = "ws://localhost:8080/?query1=something with spaces&query2=somethingelse", .target = "/?query1=something%20with%20spaces&query2=somethingelse" },
    };

    for (table) |entry| {
        const parsed = try std.Uri.parse(entry.input);
        try std.testing.expectEqualSlices(u8, entry.target, try getUriFullPath(parsed));
    }
}
294 |
--------------------------------------------------------------------------------
/src/receiver.zig:
--------------------------------------------------------------------------------
1 | const std = @import("std");
2 | const mem = std.mem;
3 | const common = @import("common.zig");
4 | const Opcode = common.Opcode;
5 | const Header = common.Header;
6 | const Message = common.Message;
7 |
const MAX_CTL_FRAME_LENGTH = common.MAX_CTL_FRAME_LENGTH;
// max header size can be 10 * u8,
// if masking is allowed, header size can be up to 14 * u8
// server should not be sending masked messages.
const MAX_HEADER_SIZE = 10;

/// Client-side frame reader over `Reader`.
/// Text and binary payloads, including reassembled fragmented messages, are
/// stored in a `capacity`-byte buffer; control payloads use a separate
/// buffer. The `data` of a returned Message points into these buffers, so it
/// is overwritten by later `receive` calls.
pub fn Receiver(comptime Reader: type, comptime capacity: usize) type {
    return struct {
        const Self = @This();

        reader: Reader,
        buffer: [capacity]u8 = undefined,
        header_buffer: [MAX_HEADER_SIZE]u8 = undefined,
        // specified for ping, pong and close frames.
        control_buffer: [MAX_CTL_FRAME_LENGTH]u8 = undefined,
        // number of bytes of a fragmented message collected in `buffer`
        end: usize = 0,
        fragmentation: Fragmentation = .{},

        const Fragmentation = struct {
            // true while a fragmented message is being reassembled
            on: bool = false,
            opcode: Opcode = .text,
        };

        /// Deallocate HTTP response headers (every name and value) and the map.
        pub fn freeHttpHeaders(
            _: Self,
            allocator: mem.Allocator,
            headers: *std.StringHashMapUnmanaged([]const u8),
        ) void {
            defer headers.deinit(allocator);
            var iter = headers.iterator();
            while (iter.next()) |entry| {
                allocator.free(entry.key_ptr.*);
                allocator.free(entry.value_ptr.*);
            }
        }

        /// Read the HTTP response to the upgrade request.
        /// The status line must start with "HTTP/1.1 101" (the reason phrase
        /// is not checked). Each `Name: value` line is stored in `headers`,
        /// with name and value copied using `allocator`; optional whitespace
        /// around the value is dropped. A repeated name fails with
        /// `error.RepeatingHttpHeader`; a line over 2048 bytes fails with
        /// `error.HttpHeaderTooLong`. On any error, everything stored so far
        /// is freed.
        pub fn receiveResponse(
            self: Self,
            allocator: mem.Allocator,
            headers: *std.StringHashMapUnmanaged([]const u8),
        ) !void {
            errdefer self.freeHttpHeaders(allocator, headers);
            var buf: [2048]u8 = undefined;

            // e.g. "HTTP/1.1 101 Switching Protocols"
            const status_line = mem.trimRight(u8, try self.reader.readUntilDelimiter(&buf, '\n'), "\r");
            const expected_status = "HTTP/1.1 101";
            if (!mem.startsWith(u8, status_line, expected_status))
                return error.FailedSwitchingProtocols;
            // Reject e.g. "HTTP/1.1 1010": the code must end here.
            if (status_line.len > expected_status.len and status_line[expected_status.len] != ' ')
                return error.FailedSwitchingProtocols;

            while (true) {
                const raw_line = self.reader.readUntilDelimiter(&buf, '\n') catch |err| switch (err) {
                    error.StreamTooLong => return error.HttpHeaderTooLong,
                    else => |e| return e,
                };
                const line = mem.trimRight(u8, raw_line, "\r");
                // an empty line ends the header section
                if (line.len == 0) break;

                const colon = mem.indexOfScalar(u8, line, ':') orelse return error.BadHttpResponse;
                const name = line[0..colon];
                if (name.len == 0) return error.BadHttpResponse;
                const value = mem.trim(u8, line[colon + 1 ..], " \t");

                if (headers.contains(name))
                    return error.RepeatingHttpHeader;

                // Free the copies if they never make it into the map.
                const owned_name = try allocator.dupe(u8, name);
                errdefer allocator.free(owned_name);
                const owned_value = try allocator.dupe(u8, value);
                errdefer allocator.free(owned_value);
                try headers.put(allocator, owned_name, owned_value);
            }
        }

        pub const GetHeaderError = error{EndOfStream} || Header.Error || Reader.Error;

        /// Read and decode the next frame header (payload not consumed).
        fn getHeader(self: *Self) GetHeaderError!Header {
            const first_two = self.header_buffer[0..2];
            if (try self.reader.readAll(first_two) < 2) return error.EndOfStream;

            const b0 = first_two[0];
            const b1 = first_two[1];

            if (b1 & 0x80 != 0)
                return error.MaskedMessageFromServer; // FIXME: should this be allowed?

            // 7-bit length, or 126/127 announcing a 16/64-bit extended length.
            const length = try self.getLength(@truncate(b1 & 0x7F));

            return Header{
                .len = length,
                .opcode = @enumFromInt(@as(u4, @truncate(b0 & 0x0F))),
                .fin = b0 & 0x80 != 0,
                .rsv1 = b0 & 0x40 != 0,
                .rsv2 = b0 & 0x20 != 0,
                .rsv3 = b0 & 0x10 != 0,
            };
        }

        pub const GetLengthError = error{EndOfStream} || Reader.Error;

        /// Resolve the payload length, reading the extended length if present.
        fn getLength(self: *Self, var_length: u7) GetLengthError!u64 {
            switch (var_length) {
                126 => {
                    const ext = self.header_buffer[2..4];
                    if (try self.reader.readAll(ext) < ext.len) return error.EndOfStream;
                    return mem.readIntBig(u16, ext);
                },
                127 => {
                    const ext = self.header_buffer[2..10];
                    if (try self.reader.readAll(ext) < ext.len) return error.EndOfStream;
                    return mem.readIntBig(u64, ext);
                },
                else => return var_length,
            }
        }

        /// Read a control frame's payload into `control_buffer`.
        fn readControlPayload(self: *Self, header: Header) FrameError![]u8 {
            // Control frames must not be fragmented (RFC 6455 section 5.5).
            if (!header.fin)
                return error.FragmentedMessage;
            if (header.len > self.control_buffer.len)
                return error.PayloadTooBig;

            const payload = self.control_buffer[0..@as(usize, @intCast(header.len))];
            if (try self.reader.readAll(payload) < payload.len)
                return error.EndOfStream;
            return payload;
        }

        fn pingPong(self: *Self, header: Header) FrameError!Message {
            const payload = try self.readControlPayload(header);
            return Message.from(header.opcode, payload, null);
        }

        /// Decode a close frame: an optional 2-byte status code followed by
        /// an optional reason, which becomes the Message data.
        fn close(self: *Self, header: Header) FrameError!Message {
            const payload = try self.readControlPayload(header);

            if (payload.len == 0)
                return Message.from(.close, payload, null);
            // A non-empty body must begin with a full 2-byte code.
            if (payload.len == 1)
                return error.InvalidCloseFrame;

            const code = mem.readIntBig(u16, payload[0..2]);
            return Message.from(.close, payload[2..], code);
        }

        pub const ContinuationError = error{UnknownOpcode} || FrameError || GetHeaderError;

        const FragmentError = error{BadMessageOrder} || ContinuationError;

        /// Append the payload of `header` to the reassembly buffer.
        fn appendPayload(self: *Self, header: Header) FrameError!void {
            // Compare against the remaining space instead of computing
            // `end + len`, which could overflow for a hostile 64-bit length.
            if (header.len > self.buffer.len - self.end)
                return error.PayloadTooBig;

            const n: usize = @intCast(header.len);
            const dest = self.buffer[self.end..][0..n];
            if (try self.reader.readAll(dest) < n)
                return error.EndOfStream;
            self.end += n;
        }

        /// Hand out the reassembled message and leave fragmentation mode.
        fn finishMessage(self: *Self) Message.Error!Message {
            const data = self.buffer[0..self.end];
            self.end = 0;
            self.fragmentation.on = false;
            return Message.from(self.fragmentation.opcode, data, null);
        }

        /// Consume continuation frames, starting with `first` (payload not yet
        /// read), until the final one. An interleaved control frame is
        /// returned immediately; the partial message stays buffered and later
        /// continuation frames resume it.
        fn drainFragments(self: *Self, first: Header) FragmentError!Message {
            var frame = first;
            while (true) : (frame = try self.getHeader()) {
                switch (frame.opcode) {
                    .continuation => {},
                    // a new data message may not start mid-message
                    .text, .binary => return error.BadMessageOrder,
                    .ping, .pong => return self.pingPong(frame),
                    .close => return self.close(frame),

                    else => return error.UnknownOpcode,
                }

                try self.appendPayload(frame);
                if (frame.fin) return self.finishMessage();
            }
        }

        // this must be called when continuation frame is received
        fn continuation1(self: *Self, header: Header) FragmentError!Message {
            if (!self.fragmentation.on)
                return error.BadMessageOrder;
            return self.drainFragments(header);
        }

        // this must be called when text or binary frame without fin is received
        fn continuation(self: *Self, header: Header) FragmentError!Message {
            self.fragmentation = .{ .on = true, .opcode = header.opcode };
            try self.appendPayload(header);
            if (header.fin) return self.finishMessage();
            return self.drainFragments(try self.getHeader());
        }

        pub const FrameError = error{
            EndOfStream,
            PayloadTooBig,
            InvalidCloseFrame,
        } || Message.Error || Reader.Error;

        /// Read an unfragmented text/binary payload.
        fn regular(self: *Self, header: Header) FrameError!Message {
            if (header.len > self.buffer.len - self.end)
                return error.PayloadTooBig;

            const payload = self.buffer[self.end..][0..@as(usize, @intCast(header.len))];
            if (try self.reader.readAll(payload) < payload.len)
                return error.EndOfStream;

            return Message.from(header.opcode, payload, null);
        }

        pub const Error = error{BadMessageOrder} || Header.Error || FrameError || ContinuationError;

        /// Receive the next message from the stream.
        pub fn receive(self: *Self) Error!Message {
            const header = try self.getHeader();

            switch (header.opcode) {
                .continuation => return self.continuation1(header),
                .text, .binary => {
                    // A new data message may not begin while another one is
                    // still being reassembled.
                    if (self.fragmentation.on)
                        return error.BadMessageOrder;
                    return if (header.fin) self.regular(header) else self.continuation(header);
                },

                // control frames
                .ping, .pong => return self.pingPong(header),
                .close => return self.close(header),

                else => return error.UnknownOpcode,
            }
        }
    };
}
350 |
--------------------------------------------------------------------------------