├── .gitattributes ├── .gitignore ├── zigmod.yml ├── src ├── lib.zig ├── main_tinyhost.zig ├── enums.zig ├── main.zig ├── test.zig ├── resource_data.zig ├── parser.zig ├── packet.zig ├── cidr.zig ├── name.zig └── helpers.zig ├── LICENSE └── README.md /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto 2 | *.zig text eol=lf 3 | zigmod.* text eol=lf 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | zig-cache/ 2 | .zig-cache/ 3 | zig-out/ 4 | .zigmod 5 | deps.zig 6 | -------------------------------------------------------------------------------- /zigmod.yml: -------------------------------------------------------------------------------- 1 | id: zsu7cmhy4r867qqklav1yjl4ww8dvyvlkndwv0k8t6b0sv6c 2 | name: zigdig 3 | main: src/lib.zig 4 | license: MIT 5 | description: A naively implemented DNS library for Zig 6 | dependencies: 7 | -------------------------------------------------------------------------------- /src/lib.zig: -------------------------------------------------------------------------------- 1 | pub const ResourceType = @import("enums.zig").ResourceType; 2 | pub const ResourceClass = @import("enums.zig").ResourceClass; 3 | 4 | pub const names = @import("name.zig"); 5 | pub const FullName = names.FullName; 6 | pub const RawName = names.RawName; 7 | pub const Name = names.Name; 8 | pub const LabelComponent = names.LabelComponent; 9 | pub const NamePool = names.NamePool; 10 | 11 | const pkt = @import("packet.zig"); 12 | pub const Packet = pkt.Packet; 13 | pub const ResponseCode = pkt.ResponseCode; 14 | pub const OpCode = pkt.OpCode; 15 | pub const IncomingPacket = pkt.IncomingPacket; 16 | pub const Question = pkt.Question; 17 | pub const Resource = pkt.Resource; 18 | pub const Header = pkt.Header; 19 | 20 | pub const parserlib = @import("parser.zig"); 21 | pub const 
parser = parserlib.parser; 22 | pub const Parser = parserlib.Parser; 23 | pub const ParserOptions = parserlib.ParserOptions; 24 | pub const ParserContext = parserlib.ParserContext; 25 | 26 | pub const helpers = @import("helpers.zig"); 27 | 28 | const resource_data = @import("resource_data.zig"); 29 | pub const ResourceData = resource_data.ResourceData; 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019 Luna Mendes 2 | Copyright (c) 2005-2014 Rich Felker, et al. 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy 5 | of this software and associated documentation files (the "Software"), to deal 6 | in the Software without restriction, including without limitation the rights 7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in all 12 | copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | SOFTWARE. 
21 | -------------------------------------------------------------------------------- /src/main_tinyhost.zig: -------------------------------------------------------------------------------- 1 | //! zigdig-tiny: a host(1)-style lookup that prints every address 2 | //! dns.helpers.getAddressList returns for the name given on the command line. 3 | const std = @import("std"); 4 | const dns = @import("lib.zig"); 5 | 6 | const logger = std.log.scoped(.zigdig_main); 7 | 8 | // log_level stays at .debug so no message is compiled out; logfn filters 9 | // at runtime against current_log_level instead. 10 | pub const std_options = std.Options{ 11 | .log_level = .debug, 12 | .logFn = logfn, 13 | }; 14 | 15 | /// Runtime log threshold; main sets it to .debug when DEBUG=1. 16 | pub var current_log_level: std.log.Level = .info; 17 | 18 | /// Log function installed through std_options: forwards messages at least 19 | /// as severe as current_log_level to std.log.defaultLog and drops the rest. 20 | fn logfn( 21 | comptime message_level: std.log.Level, 22 | comptime scope: @Type(.enum_literal), 23 | comptime format: []const u8, 24 | args: anytype, 25 | ) void { 26 | if (@intFromEnum(message_level) <= @intFromEnum(@import("root").current_log_level)) { 27 | std.log.defaultLog(message_level, scope, format, args); 28 | } 29 | } 30 | 31 | /// Usage: zigdig-tiny NAME 32 | /// 33 | /// Prints one "NAME has address ADDR" line per resolved address. 34 | /// Set DEBUG=1 in the environment for debug-level logging. 35 | pub fn main() !void { 36 | var gpa = std.heap.GeneralPurposeAllocator(.{}){}; 37 | defer { 38 | _ = gpa.deinit(); 39 | } 40 | const allocator = gpa.allocator(); 41 | 42 | const debug = std.process.getEnvVarOwned(allocator, "DEBUG") catch |err| switch (err) { 43 | error.EnvironmentVariableNotFound => try allocator.dupe(u8, ""), 44 | else => return err, 45 | }; 46 | defer allocator.free(debug); 47 | if (std.mem.eql(u8, debug, "1")) current_log_level = .debug; 48 | 49 | var args_it = try std.process.argsWithAllocator(allocator); 50 | defer args_it.deinit(); 51 | _ = args_it.skip(); // skip argv[0], the program name 52 | 53 | const name_string = (args_it.next() orelse { 54 | logger.warn("no name provided", .{}); 55 | return error.InvalidArgs; 56 | }); 57 | 58 | // `const`: neither addrs nor stdout is reassigned, and Zig rejects never-mutated `var` locals 59 | const addrs = try dns.helpers.getAddressList(name_string, 80, allocator); 60 | defer addrs.deinit(); 61 | 62 | const stdout = std.io.getStdOut().writer(); 63 | 64 | for (addrs.addrs) |addr| { 65 | try stdout.print("{s} has address {any}\n", .{ name_string, addr }); 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /src/enums.zig:
-------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | 3 | const logger = std.log.scoped(.dns_enums); 4 | 5 | /// Represents a DNS type. 6 | /// Keep in mind this enum does not declare all possible DNS types. 7 | pub const ResourceType = enum(u16) { 8 | A = 1, 9 | NS = 2, 10 | MD = 3, 11 | MF = 4, 12 | CNAME = 5, 13 | SOA = 6, 14 | MB = 7, 15 | MG = 8, 16 | MR = 9, 17 | NULL = 10, 18 | WKS = 11, 19 | PTR = 12, 20 | HINFO = 13, 21 | MINFO = 14, 22 | MX = 15, 23 | TXT = 16, 24 | 25 | AAAA = 28, 26 | // TODO LOC = 29, (check if it's worth it. https://tools.ietf.org/html/rfc1876) 27 | SRV = 33, 28 | 29 | // https://www.rfc-editor.org/rfc/rfc6891#section-6 30 | // The OPT RR has RR type 41. 31 | OPT = 41, 32 | 33 | // those types are only valid in request packets. they may be wanted 34 | // later on for completeness, but for now, it's more hassle than it's worth. 35 | // AXFR = 252, 36 | // MAILB = 253, 37 | // MAILA = 254, 38 | // ANY = 255, 39 | 40 | // should this enum be non-exhaustive? 41 | // what does it actually mean to be non-exhaustive? 42 | //_, 43 | 44 | const Self = @This(); 45 | 46 | /// Try to convert a given string (case-insensitive compare) to an 47 | /// integer representing a Type. 48 | /// 49 | /// Returns error.InvalidResourceType when no variant name matches. 50 | pub fn fromString(str: []const u8) error{InvalidResourceType}!Self { 51 | // this returned Overflow but i think InvalidResourceType is also valid 52 | // considering we don't have resource types that are more than 10 53 | // characters long. 54 | if (str.len > 10) return error.InvalidResourceType; 55 | 56 | // std.ascii.eqlIgnoreCase does the ASCII case-insensitive compare 57 | // directly, so no uppercased copy of `str` is needed.
58 | const type_info = @typeInfo(Self).@"enum"; 59 | inline for (type_info.fields) |field| { 60 | if (std.ascii.eqlIgnoreCase(str, field.name)) { 61 | return @as(Self, @enumFromInt(field.value)); 62 | } 63 | } 64 | 65 | return error.InvalidResourceType; 66 | } 67 | 68 | /// Read a big-endian u16 from `reader` and convert it to a Type. 69 | /// 70 | /// Values with no matching variant are logged and returned as 71 | /// error.InvalidEnumTag (from std.meta.intToEnum); reader errors 72 | /// are passed through unchanged. 73 | pub fn readFrom(reader: anytype) !Self { 74 | const resource_type_int = try reader.readInt(u16, .big); 75 | return std.meta.intToEnum(Self, resource_type_int) catch |err| { 76 | logger.err( 77 | "unknown resource type {d}, got {s}", 78 | .{ resource_type_int, @errorName(err) }, 79 | ); 80 | return err; 81 | }; 82 | } 83 | 84 | /// Write the network representation of this type to a stream. 85 | /// 86 | /// Returns amount of bytes written. 87 | pub fn writeTo(self: Self, writer: anytype) !usize { 88 | try writer.writeInt(u16, @intFromEnum(self), .big); 89 | return @sizeOf(u16); 90 | } 91 | }; 92 | 93 | /// Represents a DNS class. 94 | /// See RFC 1035 sections 3.2.4 (CLASS) and 3.2.5 (QCLASS, where 255 is "*"). 95 | pub const ResourceClass = enum(u16) { 96 | /// The internet 97 | IN = 1, 98 | CS = 2, 99 | CH = 3, 100 | HS = 4, 101 | WILDCARD = 255, 102 | 103 | pub fn readFrom(reader: anytype) !@This() { 104 | const resource_class_int = try reader.readInt(u16, .big); 105 | return std.meta.intToEnum(@This(), resource_class_int) catch |err| { 106 | logger.err( 107 | "unknown resource class {d}, got {s}", 108 | .{ resource_class_int, @errorName(err) }, 109 | ); 110 | return err; 111 | }; 112 | } 113 | 114 | /// Write the network representation of this class to a stream. 115 | /// 116 | /// Returns amount of bytes written.
117 | pub fn writeTo(self: @This(), writer: anytype) !usize { 118 | try writer.writeInt(u16, @intFromEnum(self), .big); 119 | return @sizeOf(u16); 120 | } 121 | }; 122 | -------------------------------------------------------------------------------- /src/main.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const builtin = @import("builtin"); 3 | const dns = @import("lib.zig"); 4 | 5 | const logger = std.log.scoped(.zigdig_main); 6 | 7 | // log_level stays at .debug so no message is compiled out; logfn filters 8 | // at runtime against current_log_level instead. 9 | pub const std_options = std.Options{ 10 | .log_level = .debug, 11 | .logFn = logfn, 12 | }; 13 | 14 | /// Runtime log threshold; main sets it to .debug when DEBUG=1. 15 | pub var current_log_level: std.log.Level = .info; 16 | 17 | /// Log function installed through std_options: forwards messages at least 18 | /// as severe as current_log_level to std.log.defaultLog and drops the rest. 19 | fn logfn( 20 | comptime message_level: std.log.Level, 21 | comptime scope: @Type(.enum_literal), 22 | comptime format: []const u8, 23 | args: anytype, 24 | ) void { 25 | if (@intFromEnum(message_level) <= @intFromEnum(@import("root").current_log_level)) { 26 | std.log.defaultLog(message_level, scope, format, args); 27 | } 28 | } 29 | 30 | /// Usage: zigdig NAME QTYPE 31 | /// 32 | /// Sends one recursive query for NAME/QTYPE and prints both the query and 33 | /// the reply in zone-file form. Set DEBUG=1 for debug-level logging. 34 | pub fn main() !void { 35 | var gpa = std.heap.GeneralPurposeAllocator(.{}){}; 36 | defer { 37 | _ = gpa.deinit(); 38 | } 39 | const allocator = gpa.allocator(); 40 | 41 | if (builtin.os.tag == .windows) { 42 | // comptime UTF-16 literals: reading DEBUG needs no allocation and cannot fail with OOM 43 | const debug = std.process.getenvW(std.unicode.utf8ToUtf16LeStringLiteral("DEBUG")) orelse &[_]u16{}; 44 | if (std.mem.eql(u16, debug, std.unicode.utf8ToUtf16LeStringLiteral("1"))) current_log_level = .debug; 45 | } else { 46 | if (std.mem.eql(u8, std.posix.getenv("DEBUG") orelse "", "1")) current_log_level = .debug; 47 | } 48 | 49 | var args_it = try std.process.argsWithAllocator(allocator); 50 | defer args_it.deinit(); 51 | _ = args_it.skip(); // skip argv[0], the program name 52 | 53 | const name_string = (args_it.next() orelse { 54 | logger.warn("no name provided", .{}); 55 | return error.InvalidArgs; 56 | }); 57 |
58 | const qtype_str = (args_it.next() orelse { 59 | logger.warn("no qtype provided", .{}); 60 | return error.InvalidArgs; 61 | }); 62 | 63 | const qtype = dns.ResourceType.fromString(qtype_str) catch |err| switch (err) { 64 | error.InvalidResourceType => { 65 | logger.warn("invalid query type provided", .{}); 66 | return error.InvalidArgs; 67 | }, 68 | }; 69 | 70 | var name_buffer: [128][]const u8 = undefined; 71 | const name = try dns.Name.fromString(name_string, &name_buffer); 72 | 73 | var questions = [_]dns.Question{ 74 | .{ 75 | .name = name, 76 | .typ = qtype, 77 | .class = .IN, 78 | }, 79 | }; 80 | 81 | var empty = [0]dns.Resource{}; 82 | 83 | // create question packet 84 | var packet = dns.Packet{ 85 | .header = .{ 86 | .id = dns.helpers.randomHeaderId(), 87 | .is_response = false, 88 | .wanted_recursion = true, 89 | .question_length = 1, 90 | }, 91 | .questions = &questions, 92 | .answers = &empty, 93 | .nameservers = &empty, 94 | .additionals = &empty, 95 | }; 96 | 97 | logger.debug("packet: {}", .{packet}); 98 | 99 | // on Windows a fixed public resolver is used instead of the system configuration 100 | const conn = if (builtin.os.tag == .windows) try dns.helpers.connectToResolver("8.8.8.8", null) else try dns.helpers.connectToSystemResolver(); 101 | defer conn.close(); 102 | 103 | // std.log appends a newline to every message itself, so the format carries none 104 | logger.info("selected nameserver: {}", .{conn.address}); 105 | const stdout = std.io.getStdOut(); 106 | 107 | // print out our same question as a zone file for debugging purposes. 108 | // NOTE(review): the name pool argument is `undefined`; presumably printAsZoneFile 109 | // does not read it for this freshly built question packet -- confirm in helpers.zig. 110 | try dns.helpers.printAsZoneFile(&packet, undefined, stdout.writer()); 111 | 112 | try conn.sendPacket(packet); 113 | 114 | // as we need Names inside the NamePool to live beyond the call to 115 | // receiveFullPacket (since we need to deserialize names in RDATA) 116 | // we must take ownership of them and deinit ourselves 117 | var name_pool = dns.NamePool.init(allocator); 118 | defer name_pool.deinitWithNames(); 119 | 120 | const reply = try conn.receiveFullPacket( 121 | allocator, 122 | 4096, 123 | .{ .name_pool = &name_pool }, 124 | ); 125 | defer reply.deinit(.{ .names = false }); 126 |
127 | const reply_packet = reply.packet; 128 | logger.debug("reply: {}", .{reply_packet}); 129 | 130 | // validate the reply with ordinary errors; std.testing is meant for test blocks 131 | if (reply_packet.header.id != packet.header.id) { 132 | logger.err("reply id {d} does not match query id {d}", .{ reply_packet.header.id, packet.header.id }); 133 | return error.ReplyIdMismatch; 134 | } 135 | if (!reply_packet.header.is_response) { 136 | logger.err("reply is not marked as a response", .{}); 137 | return error.NotAResponse; 138 | } 139 | 140 | try dns.helpers.printAsZoneFile(reply_packet, &name_pool, stdout.writer()); 141 | } 142 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # zigdig 2 | 3 | naive dns client library in zig 4 | 5 | help me decide if this api is good: https://github.com/lun-4/zigdig/issues/10 6 | 7 | ## what does it do 8 | - serialization and deserialization of dns packets as per rfc1035 9 | - supports a subset of rdata (i do not have any plans to support 100% of DNS, but SRV/MX/TXT/A/AAAA 10 | are there, which most likely will be enough for your use cases) 11 | - has helpers for reading `/etc/resolv.conf` (not that much, really) 12 | 13 | ## what does it not do 14 | - no edns0 15 | - support all resolv.conf options 16 | - can deserialize pointer labels, but does not serialize into pointers 17 | - follow CNAME records, this provides only the basic 18 | serialization/deserialization 19 | 20 | ## how do 21 | 22 | - zig 0.14.0: https://ziglang.org 23 | - have a `/etc/resolv.conf` 24 | - tested on linux, should work on bsd i think 25 | 26 | ``` 27 | git clone ...
28 | cd zigdig 29 | 30 | zig build test 31 | zig build install --prefix ~/.local/ 32 | ``` 33 | 34 | and then 35 | 36 | ```bash 37 | zigdig google.com a 38 | ``` 39 | 40 | or, for the host(1) equivalent 41 | 42 | ```bash 43 | zigdig-tiny google.com 44 | ``` 45 | 46 | ## using the library 47 | 48 | ### getAddressList-style api 49 | 50 | ```zig 51 | const std = @import("std"); 52 | const dns = @import("dns"); 53 | pub fn main() !void { 54 | var gpa = std.heap.GeneralPurposeAllocator(.{}){}; 55 | defer { 56 | _ = gpa.deinit(); 57 | } 58 | const allocator = gpa.allocator(); 59 | 60 | const addresses = try dns.helpers.getAddressList("ziglang.org", 80, allocator); 61 | defer addresses.deinit(); 62 | 63 | for (addresses.addrs) |address| { 64 | std.debug.print("we live in a society {}\n", .{address}); 65 | } 66 | } 67 | ``` 68 | 69 | ### full api 70 | 71 | ```zig 72 | const std = @import("std"); 73 | const dns = @import("dns"); 74 | pub fn main() !void { 75 | var gpa = std.heap.GeneralPurposeAllocator(.{}){}; 76 | defer { 77 | _ = gpa.deinit(); 78 | } 79 | const allocator = gpa.allocator(); 80 | 81 | var name_buffer: [128][]const u8 = undefined; 82 | const name = try dns.Name.fromString("ziglang.org", &name_buffer); 83 | 84 | var questions = [_]dns.Question{ 85 | .{ 86 | .name = name, 87 | .typ = .A, 88 | .class = .IN, 89 | }, 90 | }; 91 | 92 | const packet = dns.Packet{ 93 | .header = .{ 94 | .id = dns.helpers.randomHeaderId(), 95 | .is_response = false, 96 | .wanted_recursion = true, 97 | .question_length = 1, 98 | }, 99 | .questions = &questions, 100 | .answers = &[_]dns.Resource{}, 101 | .nameservers = &[_]dns.Resource{}, 102 | .additionals = &[_]dns.Resource{}, 103 | }; 104 | 105 | // use a helper function to connect to a resolver from the system's 106 | // resolv.conf 107 | 108 | const conn = try dns.helpers.connectToSystemResolver(); 109 | defer conn.close(); 110 | 111 | try conn.sendPacket(packet); 112 | 113 | // you can also do this to support any Writer 114 | // const written_bytes = try packet.writeTo(some_fun_writer_goes_here); 115 |
116 | const reply = try conn.receiveFullPacket(allocator, 4096, .{}); 117 | defer reply.deinit(.{}); 118 | 119 | // you can also do this to support any Reader 120 | // const incoming = try dns.helpers.parseFullPacket(some_fun_reader, allocator, .{}); 121 | // defer incoming.deinit(.{}); 122 | 123 | const reply_packet = reply.packet; 124 | std.debug.print("reply: {}\n", .{reply_packet}); 125 | 126 | try std.testing.expectEqual(packet.header.id, reply_packet.header.id); 127 | try std.testing.expect(reply_packet.header.is_response); 128 | 129 | // ASSERTS that there's one A resource in the answer!!! you should check 130 | // the reply header's response code to see if there were any errors 131 | 132 | const resource = reply_packet.answers[0]; 133 | const resource_data = try dns.ResourceData.fromOpaque( 134 | resource.typ, 135 | resource.opaque_rdata.?, 136 | // default options suffice for A records, which carry no names 137 | .{}, 138 | ); 139 | defer resource_data.deinit(allocator); 140 | 141 | // you now have a std.net.Address to use to your heart's content 142 | std.debug.print("ziglang.org is at {}\n", .{resource_data.A}); 143 | } 144 | 145 | ``` 146 | 147 | it is recommended to look at zigdig's source in `src/main.zig` to understand 148 | how things tick using the library, but it boils down to three things: 149 | - packet generation and serialization 150 | - sending/receiving (via a small shim on top of std's socket APIs) 151 | - packet deserialization 152 | -------------------------------------------------------------------------------- /src/test.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const testing = std.testing; 3 | const Allocator = std.mem.Allocator; 4 | const io = std.io; 5 | 6 | const dns = @import("lib.zig"); 7 | const helpers = @import("helpers.zig"); 8 | const Packet = dns.Packet; 9 | 10 | test "convert domain string to dns name" { 11 | const domain = "www.google.com"; 12 | var name_buffer: [3][]const u8 = undefined; 13 | const name = (try dns.Name.fromString(domain[0..], &name_buffer)).full;
14 | try std.testing.expectEqual(3, name.labels.len); 15 | try std.testing.expectEqualStrings("www", name.labels[0]); 16 | try std.testing.expectEqualStrings("google", name.labels[1]); 17 | try std.testing.expectEqualStrings("com", name.labels[2]); 18 | } 19 | 20 | test "convert domain string to dns name (buffer overflow case)" { 21 | const domain = "www.google.com"; 22 | // three labels cannot fit in a one-slot label buffer, so fromString must 23 | // fail with Overflow. expectError also fails the test if the call 24 | // unexpectedly succeeds, which the old catch/switch form did not. 25 | var name_buffer: [1][]const u8 = undefined; 26 | try std.testing.expectError(error.Overflow, dns.Name.fromString(domain[0..], &name_buffer)); 27 | } 28 | 29 | // extracted with 'dig google.com a +noedns' 30 | const TEST_PKT_QUERY = "FEUBIAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ=="; 31 | const TEST_PKT_RESPONSE = "RM2BgAABAAEAAAAABmdvb2dsZQNjb20AAAEAAcAMAAEAAQAAASwABNg6yo4="; 32 | const GOOGLE_COM_LABELS = [_][]const u8{ "google"[0..], "com"[0..] }; 33 | 34 | test "Packet serialize/deserialize" { 35 | const random_id = dns.helpers.randomHeaderId(); 36 | const packet = dns.Packet{ 37 | .header = .{ .id = random_id }, 38 | .questions = &[_]dns.Question{}, 39 | .answers = &[_]dns.Resource{}, 40 | .nameservers = &[_]dns.Resource{}, 41 | .additionals = &[_]dns.Resource{}, 42 | }; 43 | 44 | // then we'll serialize it under a buffer on the stack, 45 | // deserialize it, and the header.id should be equal to random_id 46 | var write_buffer: [1024]u8 = undefined; 47 | const buf = try serialTest(packet, &write_buffer); 48 | 49 | // deserialize it and compare if everything's equal 50 | var incoming = try deserialTest(buf); 51 | defer incoming.deinit(.{}); 52 | const deserialized = incoming.packet; 53 | 54 | try std.testing.expectEqual(packet.header.id, deserialized.header.id); 55 | 56 | const fields = [_][]const u8{ "id", "opcode", "question_length", "answer_length" }; 57 | 58 | const new_header = deserialized.header; 59 | const header = packet.header; 60 | 61 | inline for (fields) |field| { 62 | try std.testing.expectEqual( 63 | @field(header, field), 64 | @field(new_header, field), 65 | ); 66 | } 67 | }
68 | 69 | fn decodeBase64(encoded: []const u8, write_buffer: []u8) ![]const u8 { 70 | const size = try std.base64.standard.Decoder.calcSizeForSlice(encoded); 71 | try std.base64.standard.Decoder.decode(write_buffer[0..size], encoded); 72 | return write_buffer[0..size]; 73 | } 74 | 75 | fn expectGoogleLabels(actual: [][]const u8) !void { 76 | // compare the count first so missing or extra labels fail the test instead of passing or indexing out of bounds 77 | try std.testing.expectEqual(GOOGLE_COM_LABELS.len, actual.len); 78 | for (GOOGLE_COM_LABELS, actual) |expected, label| try std.testing.expectEqualStrings(expected, label); 79 | } 80 | 81 | test "deserialization of original question google.com/A" { 82 | var write_buffer: [0x10000]u8 = undefined; 83 | 84 | const decoded = try decodeBase64(TEST_PKT_QUERY, &write_buffer); 85 | 86 | var incoming = try deserialTest(decoded); 87 | defer incoming.deinit(.{}); 88 | const pkt = incoming.packet; 89 | 90 | try std.testing.expectEqual(@as(u16, 5189), pkt.header.id); 91 | try std.testing.expectEqual(@as(u16, 1), pkt.header.question_length); 92 | try std.testing.expectEqual(@as(u16, 0), pkt.header.answer_length); 93 | try std.testing.expectEqual(@as(u16, 0), pkt.header.nameserver_length); 94 | try std.testing.expectEqual(@as(u16, 0), pkt.header.additional_length); 95 | try std.testing.expectEqual(@as(usize, 1), pkt.questions.len); 96 | 97 | const question = pkt.questions[0]; 98 | 99 | try expectGoogleLabels(question.name.?.full.labels); 100 | try std.testing.expectEqual(@as(usize, 12), question.name.?.full.packet_index.?); 101 | try std.testing.expectEqual(dns.ResourceType.A, question.typ); 102 | try std.testing.expectEqual(dns.ResourceClass.IN, question.class); 103 | } 104 | 105 | test "deserialization of reply google.com/A" { 106 | var encode_buffer: [0x10000]u8 = undefined; 107 | const decoded = try decodeBase64(TEST_PKT_RESPONSE, &encode_buffer); 108 | 109 | var incoming = try deserialTest(decoded); 110 | defer incoming.deinit(.{}); 111 | const pkt = incoming.packet; 112 |
113 | try std.testing.expectEqual(@as(u16, 17613), pkt.header.id); 114 | try std.testing.expectEqual(@as(u16, 1), pkt.header.question_length); 115 | try std.testing.expectEqual(@as(u16, 1), pkt.header.answer_length); 116 | try std.testing.expectEqual(@as(u16, 0), pkt.header.nameserver_length); 117 | try std.testing.expectEqual(@as(u16, 0), pkt.header.additional_length); 118 | 119 | const question = pkt.questions[0]; 120 | 121 | try expectGoogleLabels(question.name.?.full.labels); 122 | try testing.expectEqual(dns.ResourceType.A, question.typ); 123 | try testing.expectEqual(dns.ResourceClass.IN, question.class); 124 | 125 | const answer = pkt.answers[0]; 126 | 127 | try expectGoogleLabels(answer.name.?.full.labels); 128 | try testing.expectEqual(dns.ResourceType.A, answer.typ); 129 | try testing.expectEqual(dns.ResourceClass.IN, answer.class); 130 | try testing.expectEqual(@as(i32, 300), answer.ttl); 131 | 132 | const resource_data = try dns.ResourceData.fromOpaque( 133 | .A, 134 | answer.opaque_rdata.?, 135 | .{}, 136 | ); 137 | 138 | try testing.expectEqual( 139 | dns.ResourceType.A, 140 | @as(dns.ResourceType, resource_data), 141 | ); 142 | 143 | const addr = @as(*const [4]u8, @ptrCast(&resource_data.A.in.sa.addr)).*; 144 | try testing.expectEqual(@as(u8, 216), addr[0]); 145 | try testing.expectEqual(@as(u8, 58), addr[1]); 146 | try testing.expectEqual(@as(u8, 202), addr[2]); 147 | try testing.expectEqual(@as(u8, 142), addr[3]); 148 | } 149 | 150 | fn encodeBase64(buffer: []u8, source: []const u8) []const u8 { 151 | const encoded = buffer[0..std.base64.standard.Encoder.calcSize(source.len)]; 152 | return std.base64.standard.Encoder.encode(encoded, source); 153 | } 154 | 155 | fn encodePacket(pkt: Packet, encode_buffer: []u8, write_buffer: []u8) ![]const u8 { 156 | const out = try serialTest(pkt, write_buffer); 157 | return encodeBase64(encode_buffer, out); 158 | } 159 | 160 | test "serialization of google.com/A (question)" { 161 | const domain = "google.com"; 162 | var name_buffer: [2][]const u8 = undefined;
163 | const name = try dns.Name.fromString(domain[0..], &name_buffer); 164 | 165 | var questions = [_]dns.Question{.{ 166 | .name = name, 167 | .typ = .A, 168 | .class = .IN, 169 | }}; 170 | 171 | var empty = [0]dns.Resource{}; 172 | 173 | const packet = dns.Packet{ 174 | .header = .{ 175 | .id = 5189, 176 | .wanted_recursion = true, 177 | .z = 2, 178 | .question_length = 1, 179 | }, 180 | .questions = &questions, 181 | .answers = &empty, 182 | .nameservers = &empty, 183 | .additionals = &empty, 184 | }; 185 | 186 | var encode_buffer: [256]u8 = undefined; 187 | var write_buffer: [256]u8 = undefined; 188 | const encoded = try encodePacket(packet, &encode_buffer, &write_buffer); 189 | try std.testing.expectEqualStrings(TEST_PKT_QUERY, encoded); 190 | } 191 | 192 | fn serialTest(packet: Packet, write_buffer: []u8) ![]u8 { 193 | const typ = std.io.FixedBufferStream([]u8); 194 | var stream = typ{ .buffer = write_buffer, .pos = 0 }; 195 | 196 | const written_bytes = try packet.writeTo(stream.writer()); 197 | const written_data = stream.getWritten(); 198 | try std.testing.expectEqual(written_bytes, written_data.len); 199 | 200 | return written_data; 201 | } 202 | 203 | const FixedStream = std.io.FixedBufferStream([]const u8); 204 | fn deserialTest(packet_data: []const u8) !dns.IncomingPacket { 205 | var stream = FixedStream{ .buffer = packet_data, .pos = 0 }; 206 | return try dns.helpers.parseFullPacket( 207 | stream.reader(), 208 | std.testing.allocator, 209 | .{}, 210 | ); 211 | } 212 | 213 | test "convert string to dns type" { 214 | try std.testing.expectEqual(dns.ResourceType.AAAA, try dns.ResourceType.fromString("AAAA")); 215 | try std.testing.expectEqual(dns.ResourceType.CNAME, try dns.ResourceType.fromString("cName")); 216 | } 217 | 218 | test "names have good sizes" { 219 | var name_buffer: [10][]const u8 = undefined; 220 | var name = try dns.Name.fromString("example.com", &name_buffer); 221 | 222 | var buf: [256]u8 = undefined; 223 | var stream = std.io.FixedBufferStream([]u8){ .buffer = &buf, .pos = 0 };
224 | const network_size = try name.writeTo(stream.writer()); 225 | 226 | // length + data + length + data + null 227 | try testing.expectEqual(@as(usize, 1 + 7 + 1 + 3 + 1), network_size); 228 | } 229 | 230 | test "resources have good sizes" { 231 | var name_buffer: [10][]const u8 = undefined; 232 | var name = try dns.Name.fromString("example.com", &name_buffer); 233 | 234 | var resource = dns.Resource{ 235 | .name = name, 236 | .typ = .A, 237 | .class = .IN, 238 | .ttl = 300, 239 | .opaque_rdata = .{ .data = "", .current_byte_count = 0 }, 240 | }; 241 | 242 | var buf: [256]u8 = undefined; 243 | var stream = std.io.FixedBufferStream([]u8){ .buffer = &buf, .pos = 0 }; 244 | const network_size = try resource.writeTo(stream.writer()); 245 | 246 | // name + rr (2) + class (2) + ttl (4) + rdlength (2) 247 | try testing.expectEqual( 248 | @as(usize, name.networkSize() + 10 + resource.opaque_rdata.?.data.len), 249 | network_size, 250 | ); 251 | } 252 | 253 | // Known packet generated by zigdig; it would be welcome to have it tested in other libraries. 254 | // NOTE(review): its A rdata is 01 00 00 7f (1.0.0.127), not 127.0.0.1 -- ResourceData.writeTo's .A byte order looks swapped on little-endian hosts.
255 | const PACKET_WITH_RDATA = "FEUBIAAAAAEAAAAABmdvb2dsZQNjb20AAAEAAQAAASwABAEAAH8="; 256 | 257 | test "rdata serialization" { 258 | var name_buffer: [2][]const u8 = undefined; 259 | const name = try dns.Name.fromString("google.com", &name_buffer); 260 | var resource_data = dns.ResourceData{ 261 | .A = try std.net.Address.parseIp4("127.0.0.1", 0), 262 | }; 263 | 264 | var opaque_rdata_buffer: [1024]u8 = undefined; 265 | var stream = std.io.fixedBufferStream(&opaque_rdata_buffer); 266 | _ = try resource_data.writeTo(stream.writer()); 267 | const opaque_rdata = stream.getWritten(); 268 | 269 | var answers = [_]dns.Resource{.{ 270 | .name = name, 271 | .typ = .A, 272 | .class = .IN, 273 | .ttl = 300, 274 | .opaque_rdata = .{ .data = opaque_rdata, .current_byte_count = 0 }, 275 | }}; 276 | 277 | var empty_res = [_]dns.Resource{}; 278 | var empty_question = [_]dns.Question{}; 279 | const packet = dns.Packet{ 280 | .header = .{ 281 | .id = 5189, 282 | .wanted_recursion = true, 283 | .z = 2, 284 | .answer_length = 1, 285 | }, 286 | .questions = &empty_question, 287 | .answers = &answers, 288 | .nameservers = &empty_res, 289 | .additionals = &empty_res, 290 | }; 291 | 292 | var write_buffer: [1024]u8 = undefined; 293 | const serialized_result = try serialTest(packet, &write_buffer); 294 | 295 | var encode_buffer: [1024]u8 = undefined; 296 | const encoded_result = encodeBase64(&encode_buffer, serialized_result); 297 | try std.testing.expectEqualStrings(PACKET_WITH_RDATA, encoded_result); 298 | } 299 | 300 | test "localhost always resolves to 127.0.0.1" { 301 | const addrs = try helpers.getAddressList("localhost", 80, std.testing.allocator); 302 | defer addrs.deinit(); 303 | try std.testing.expectEqual(16777343, addrs.addrs[1].in.sa.addr); 304 | try std.testing.expectEqualStrings("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", &addrs.addrs[0].in6.sa.addr); 305 | } 306 | 307 | test "everything" { 308 | std.testing.refAllDecls(@This()); 309 | 
std.testing.refAllDecls(@import("name.zig")); 310 | std.testing.refAllDecls(@import("cidr.zig")); 311 | } 312 | -------------------------------------------------------------------------------- /src/resource_data.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const fmt = std.fmt; 3 | 4 | const dns = @import("lib.zig"); 5 | const pkt = @import("packet.zig"); 6 | const Type = dns.ResourceType; 7 | 8 | const logger = std.log.scoped(.dns_rdata); 9 | 10 | pub const SOAData = struct { 11 | mname: ?dns.Name, 12 | rname: ?dns.Name, 13 | serial: u32, 14 | refresh: u32, 15 | retry: u32, 16 | expire: u32, 17 | minimum: u32, 18 | }; 19 | 20 | pub const MXData = struct { 21 | preference: u16, 22 | exchange: ?dns.Name, 23 | }; 24 | 25 | pub const SRVData = struct { 26 | priority: u16, 27 | weight: u16, 28 | port: u16, 29 | target: ?dns.Name, 30 | }; 31 | 32 | fn maybeReadResourceName( 33 | reader: anytype, 34 | options: ResourceData.ParseOptions, 35 | ) !?dns.Name { 36 | return switch (options.name_provider) { 37 | .none => null, 38 | .raw => |allocator| try dns.Name.readFrom(reader, .{ .allocator = allocator }), 39 | .full => |name_pool| blk: { 40 | const name = try dns.Name.readFrom( 41 | reader, 42 | .{ .allocator = name_pool.allocator }, 43 | ); 44 | break :blk try name_pool.transmuteName(name.?); 45 | }, 46 | }; 47 | } 48 | 49 | /// Common representations of DNS' Resource Data. 50 | pub const ResourceData = union(Type) { 51 | A: std.net.Address, 52 | 53 | NS: ?dns.Name, 54 | MD: ?dns.Name, 55 | MF: ?dns.Name, 56 | CNAME: ?dns.Name, 57 | SOA: SOAData, 58 | 59 | MB: ?dns.Name, 60 | MG: ?dns.Name, 61 | MR: ?dns.Name, 62 | 63 | // ???? 64 | NULL: void, 65 | 66 | // TODO WKS bit map 67 | WKS: struct { 68 | addr: u32, 69 | proto: u8, 70 | // how to define bit map? align(8)? 71 | }, 72 | PTR: ?dns.Name, 73 | 74 | // TODO replace []const u8 by Name? 
75 | HINFO: struct { 76 | cpu: []const u8, 77 | os: []const u8, 78 | }, 79 | MINFO: struct { 80 | rmailbx: ?dns.Name, 81 | emailbx: ?dns.Name, 82 | }, 83 | MX: MXData, 84 | TXT: ?[]const u8, 85 | AAAA: std.net.Address, 86 | SRV: SRVData, 87 | OPT: void, // EDNS0 is not implemented 88 | 89 | const Self = @This(); 90 | 91 | pub fn networkSize(self: Self) usize { 92 | return switch (self) { 93 | .A => 4, 94 | .AAAA => 16, 95 | .NS, .MD, .MF, .MB, .MG, .MR, .CNAME, .PTR => |name| name.size(), 96 | .TXT => |text| blk: { 97 | var len: usize = 0; 98 | len += @sizeOf(u16) * text.len; 99 | for (text) |string| { 100 | len += string.len; 101 | } 102 | break :blk len; 103 | }, 104 | 105 | else => @panic("TODO"), 106 | }; 107 | } 108 | 109 | /// Format the RData into a human-readable form of it. 110 | /// 111 | /// For example, a resource data of type A would be 112 | /// formatted to its representing IPv4 address. 113 | pub fn format( 114 | self: Self, 115 | comptime f: []const u8, 116 | options: fmt.FormatOptions, 117 | writer: anytype, 118 | ) !void { 119 | _ = f; 120 | _ = options; 121 | 122 | switch (self) { 123 | .A, .AAAA => |addr| return fmt.format(writer, "{}", .{addr}), 124 | 125 | .NS, .MD, .MF, .MB, .MG, .MR, .CNAME, .PTR => |name| return fmt.format(writer, "{?}", .{name}), 126 | 127 | .SOA => |soa| return fmt.format(writer, "{?} {?} {} {} {} {} {}", .{ 128 | soa.mname, 129 | soa.rname, 130 | soa.serial, 131 | soa.refresh, 132 | soa.retry, 133 | soa.expire, 134 | soa.minimum, 135 | }), 136 | 137 | .MX => |mx| return fmt.format(writer, "{} {?}", .{ mx.preference, mx.exchange }), 138 | .SRV => |srv| return fmt.format(writer, "{} {} {} {?}", .{ 139 | srv.priority, 140 | srv.weight, 141 | srv.port, 142 | srv.target, 143 | }), 144 | 145 | .TXT => |text| return fmt.format(writer, "{?s}", .{text}), 146 | else => return fmt.format(writer, "TODO support {s}", .{@tagName(self)}), 147 | } 148 | } 149 | 150 | pub fn writeTo(self: Self, writer: anytype) !usize { 151 | return 
switch (self) { 152 | .A => |addr| blk: { 153 | try writer.writeInt(u32, addr.in.sa.addr, .big); 154 | break :blk @sizeOf(@TypeOf(addr.in.sa.addr)); 155 | }, 156 | .AAAA => |addr| try writer.write(&addr.in6.sa.addr), 157 | 158 | .NS, .MD, .MF, .MB, .MG, .MR, .CNAME, .PTR => |name| try name.?.writeTo(writer), 159 | 160 | .SOA => |soa_data| blk: { 161 | const mname_size = try soa_data.mname.?.writeTo(writer); 162 | const rname_size = try soa_data.rname.?.writeTo(writer); 163 | 164 | try writer.writeInt(u32, soa_data.serial, .big); 165 | try writer.writeInt(u32, soa_data.refresh, .big); 166 | try writer.writeInt(u32, soa_data.retry, .big); 167 | try writer.writeInt(u32, soa_data.expire, .big); 168 | try writer.writeInt(u32, soa_data.minimum, .big); 169 | 170 | break :blk mname_size + rname_size + (5 * @sizeOf(u32)); 171 | }, 172 | 173 | .MX => |mxdata| blk: { 174 | try writer.writeInt(u16, mxdata.preference, .big); 175 | const exchange_size = try mxdata.exchange.?.writeTo(writer); 176 | break :blk @sizeOf(@TypeOf(mxdata.preference)) + exchange_size; 177 | }, 178 | 179 | .SRV => |srv| { 180 | try writer.writeInt(u16, srv.priority, .big); 181 | try writer.writeInt(u16, srv.weight, .big); 182 | try writer.writeInt(u16, srv.port, .big); 183 | 184 | const target_size = try srv.target.?.writeTo(writer); 185 | return target_size + (3 * @sizeOf(u16)); 186 | }, 187 | 188 | // TODO TXT 189 | 190 | else => @panic("not implemented"), 191 | }; 192 | } 193 | 194 | pub fn deinit(self: Self, allocator: std.mem.Allocator) void { 195 | switch (self) { 196 | .NS, .MD, .MF, .MB, .MG, .MR, .CNAME, .PTR => |maybe_name| if (maybe_name) |name| name.deinit(allocator), 197 | .SOA => |soa_data| { 198 | if (soa_data.mname) |name| name.deinit(allocator); 199 | if (soa_data.rname) |name| name.deinit(allocator); 200 | }, 201 | .MX => |mxdata| if (mxdata.exchange) |name| name.deinit(allocator), 202 | .SRV => |srv| if (srv.target) |name| name.deinit(allocator), 203 | .TXT => |maybe_data| if 
(maybe_data) |data| allocator.free(data), 204 | else => {}, 205 | } 206 | } 207 | 208 | pub const Opaque = struct { 209 | data: []const u8, 210 | current_byte_count: usize, 211 | }; 212 | 213 | pub const NameProvider = union(enum) { 214 | none: void, 215 | raw: std.mem.Allocator, 216 | full: *dns.NamePool, 217 | }; 218 | 219 | pub const ParseOptions = struct { 220 | name_provider: NameProvider = NameProvider.none, 221 | allocator: ?std.mem.Allocator = null, 222 | }; 223 | 224 | /// Deserialize a given opaque resource data. 225 | /// 226 | /// Call deinit() with the same allocator. 227 | pub fn fromOpaque( 228 | resource_type: dns.ResourceType, 229 | opaque_resource_data: Opaque, 230 | options: ParseOptions, 231 | ) !ResourceData { 232 | const BufferT = std.io.FixedBufferStream([]const u8); 233 | var stream = BufferT{ .buffer = opaque_resource_data.data, .pos = 0 }; 234 | const underlying_reader = stream.reader(); 235 | 236 | // important to keep track of that rdata's position in the packet 237 | // as rdata could point to other rdata. 
238 | var parser_ctx = dns.ParserContext{ 239 | .current_byte_count = opaque_resource_data.current_byte_count, 240 | }; 241 | 242 | const WrapperR = dns.parserlib.WrapperReader(BufferT.Reader); 243 | var wrapper_reader = WrapperR{ 244 | .underlying_reader = underlying_reader, 245 | .ctx = &parser_ctx, 246 | }; 247 | var reader = wrapper_reader.reader(); 248 | 249 | return switch (resource_type) { 250 | .A => blk: { 251 | var ip4addr: [4]u8 = undefined; 252 | _ = try reader.read(&ip4addr); 253 | break :blk ResourceData{ 254 | .A = std.net.Address.initIp4(ip4addr, 0), 255 | }; 256 | }, 257 | .AAAA => blk: { 258 | var ip6_addr: [16]u8 = undefined; 259 | _ = try reader.read(&ip6_addr); 260 | break :blk ResourceData{ 261 | .AAAA = std.net.Address.initIp6(ip6_addr, 0, 0, 0), 262 | }; 263 | }, 264 | 265 | .NS => ResourceData{ .NS = try maybeReadResourceName(reader, options) }, 266 | .CNAME => ResourceData{ .CNAME = try maybeReadResourceName(reader, options) }, 267 | .PTR => ResourceData{ .PTR = try maybeReadResourceName(reader, options) }, 268 | .MD => ResourceData{ .MD = try maybeReadResourceName(reader, options) }, 269 | .MF => ResourceData{ .MF = try maybeReadResourceName(reader, options) }, 270 | 271 | .MX => blk: { 272 | break :blk ResourceData{ 273 | .MX = MXData{ 274 | .preference = try reader.readInt(u16, .big), 275 | .exchange = try maybeReadResourceName(reader, options), 276 | }, 277 | }; 278 | }, 279 | 280 | .SOA => blk: { 281 | const mname = try maybeReadResourceName(reader, options); 282 | const rname = try maybeReadResourceName(reader, options); 283 | const serial = try reader.readInt(u32, .big); 284 | const refresh = try reader.readInt(u32, .big); 285 | const retry = try reader.readInt(u32, .big); 286 | const expire = try reader.readInt(u32, .big); 287 | const minimum = try reader.readInt(u32, .big); 288 | 289 | break :blk ResourceData{ 290 | .SOA = SOAData{ 291 | .mname = mname, 292 | .rname = rname, 293 | .serial = serial, 294 | .refresh = refresh, 295 | 
.retry = retry, 296 | .expire = expire, 297 | .minimum = minimum, 298 | }, 299 | }; 300 | }, 301 | .SRV => blk: { 302 | const priority = try reader.readInt(u16, .big); 303 | const weight = try reader.readInt(u16, .big); 304 | const port = try reader.readInt(u16, .big); 305 | const target = try maybeReadResourceName(reader, options); 306 | break :blk ResourceData{ 307 | .SRV = .{ 308 | .priority = priority, 309 | .weight = weight, 310 | .port = port, 311 | .target = target, 312 | }, 313 | }; 314 | }, 315 | .TXT => blk: { 316 | const length = try reader.readInt(u8, .big); 317 | if (length > 256) return error.Overflow; 318 | 319 | if (options.allocator) |allocator| { 320 | const text = try allocator.alloc(u8, length); 321 | _ = try reader.read(text); 322 | 323 | break :blk ResourceData{ .TXT = text }; 324 | } else { 325 | try reader.skipBytes(length, .{}); 326 | break :blk ResourceData{ .TXT = null }; 327 | } 328 | }, 329 | 330 | else => { 331 | logger.warn("unexpected rdata: {}\n", .{resource_type}); 332 | return error.UnknownResourceType; 333 | }, 334 | }; 335 | } 336 | }; 337 | -------------------------------------------------------------------------------- /src/parser.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const dns = @import("lib.zig"); 3 | 4 | const logger = std.log.scoped(.dns_parser); 5 | 6 | /// Create a Parser object out of a reader, context, and options. 7 | /// 8 | /// If you do not wish to have full control over deserialization, look at 9 | /// dns.helpers.parseFullPacket, which is a wrapper around the Parser that 10 | /// allocates everything. 
pub fn parser(
    reader: anytype,
    ctx: *ParserContext,
    options: dns.ParserOptions,
) Parser(@TypeOf(reader)) {
    return Parser(@TypeOf(reader)).init(reader, ctx, options);
}

pub const ResourceResolutionOptions = struct {
    max_follow: usize = 32,
};

const ParserState = enum {
    header,
    question,
    answer,
    nameserver,
    additional,
    answer_rdata,
    nameserver_rdata,
    additional_rdata,
    done,
};

/// A given frame from the parser, depending on the given options, some frames
/// will not be emitted by Parser.next, look at options for more information.
pub const ParserFrame = union(enum) {
    header: dns.Header,

    question: dns.Question,
    end_question: void,

    answer: dns.Resource,
    answer_rdata: dns.parserlib.ResourceDataHolder,
    end_answer: void,

    nameserver: dns.Resource,
    nameserver_rdata: dns.parserlib.ResourceDataHolder,
    end_nameserver: void,

    additional: dns.Resource,
    additional_rdata: dns.parserlib.ResourceDataHolder,
    end_additional: void,
};

/// Location of a not-yet-consumed RDATA section: its length and the
/// packet offset at which it starts.
pub const ResourceDataHolder = struct {
    size: usize,
    current_byte_index: usize,

    /// Consume the RDATA bytes without storing them.
    pub fn skip(self: @This(), reader: anytype) !void {
        try reader.skipBytes(self.size, .{});
    }

    /// Read the RDATA bytes into a freshly allocated buffer.
    ///
    /// Returns error.EndOfStream if the reader ends before `size` bytes.
    /// Caller owns `data` of the returned value.
    pub fn readAllAlloc(
        self: @This(),
        allocator: std.mem.Allocator,
        reader: anytype,
    ) !dns.ResourceData.Opaque {
        const opaque_rdata = try allocator.alloc(u8, self.size);
        errdefer allocator.free(opaque_rdata);
        // readNoEof loops over short reads and reports truncation as an
        // error instead of handing out a partially undefined buffer.
        try reader.readNoEof(opaque_rdata);
        return .{
            .data = opaque_rdata,
            .current_byte_count = self.current_byte_index,
        };
    }
};

pub const ParserOptions = struct {
    /// When given an allocator, the following happens:
    ///  - the parser creates RawName or FullName entities for the
    ///     respective entities with names on them.
    ///     (RawName when names end in Pointers, FullName when not)
    ///  - the parser will automatically allocate RDATA sections inside
    ///     Resource entities. It is on the parser's client to free the memory
    ///     (e.g by putting it inside an IncomingPacket's Packet)
    ///
    /// If allocator is null, the following happens:
    ///  - The name fields will be set to null.
    ///  - answer_rdata, nameserver_rdata, additional_rdata events are
    ///     emitted so the client of the Parser interface can decide if they
    ///     will be allocated, or parsed onto the stack, or something else.
    ///
    /// It is required to pass an allocator to have any access to name
    /// information. We can't parse the names in a standalone manner as
    /// they are usually the *first* field in a Question or Resource, so we
    /// need to decide if we read and allocate, or skip and don't.
    allocator: ?std.mem.Allocator = null,

    /// The maximum amount of labels in a name while parsing.
    ///
    /// Makes parser return `error.Overflow` when
    /// the given name to deserialize surpasses the value in this field.
    max_label_size: usize = 32,
};

/// Mutable parsing state shared between the Parser and its WrapperReader.
pub const ParserContext = struct {
    header: ?dns.Header = null,
    current_byte_count: usize = 0,
    current_counts: struct {
        question: usize = 0,
        answer: usize = 0,
        nameserver: usize = 0,
        additional: usize = 0,
    } = .{},
};

pub const DeserializationContext = struct {
    current_byte_count: usize = 0,
};

/// Wrap a Reader with a type that holds a pointer to a ParserContext.
///
/// Automatically increments the ParserContext's current_byte_count
/// on every read().
///
/// Useful to hold deserialization state without having to pass an entire
/// parameter around on every single helper function.
pub fn WrapperReader(comptime ReaderType: type) type {
    return struct {
        underlying_reader: ReaderType,
        ctx: *ParserContext,

        const Self = @This();

        /// Read into `buffer` from the underlying reader and advance the
        /// context's byte counter by the number of bytes actually read.
        pub fn read(self: *Self, buffer: []u8) !usize {
            const bytes_read = try self.underlying_reader.read(buffer);
            self.ctx.current_byte_count += bytes_read;
            logger.debug(
                "wrapper reader: read {d} bytes, now at {d}",
                .{ bytes_read, self.ctx.current_byte_count },
            );
            return bytes_read;
        }

        pub const Error = ReaderType.Error;
        pub const Reader = std.io.Reader(*Self, Error, read);
        pub fn reader(self: *Self) Reader {
            return Reader{ .context = self };
        }
    };
}

/// Low level parser for DNS packets.
///
/// There are two wrappers for this parser, dns.helpers.parseFullPacket,
/// and dns.helpers.receiveTrustedAddresses.
pub fn Parser(comptime ReaderType: type) type {
    const WrapperR = WrapperReader(ReaderType);

    return struct {
        state: ParserState = .header,
        wrapper_reader: WrapperR,
        options: ParserOptions,
        ctx: *ParserContext,

        const Self = @This();

        pub fn init(
            incoming_reader: ReaderType,
            ctx: *ParserContext,
            options: ParserOptions,
        ) Self {
            return Self{
                .wrapper_reader = WrapperR{
                    .underlying_reader = incoming_reader,
                    .ctx = ctx,
                },
                .options = options,
                .ctx = ctx,
            };
        }

        /// Receive the next frame from the parser.
        ///
        /// Returns null once the additional section has been fully consumed.
        pub fn next(self: *Self) !?ParserFrame {
            // self.state dictates what we *want* from the reader
            // at the moment, first state always being header.
            logger.debug("next(): enter {}", .{self.state});

            logger.debug(
                "parser reader is at {d} bytes of message",
                .{self.wrapper_reader.ctx.current_byte_count},
            );

            var reader = self.wrapper_reader.reader();

            switch (self.state) {
                .header => {
                    // since header is constant size, store it
                    // in our parser state so we know how to continue
                    const header = try dns.Header.readFrom(reader);
                    self.ctx.header = header;
                    self.state = .question;
                    logger.debug(
                        "next(): header read ({?}). state is now {}",
                        .{ self.ctx.header, self.state },
                    );
                    return ParserFrame{ .header = header };
                },
                .question => {
                    logger.debug("next(): read {d} out of {d} questions", .{
                        self.ctx.current_counts.question,
                        self.ctx.header.?.question_length,
                    });

                    self.ctx.current_counts.question += 1;

                    if (self.ctx.current_counts.question > self.ctx.header.?.question_length) {
                        self.state = .answer;
                        logger.debug("parser: end question, go to resources", .{});
                        return ParserFrame{ .end_question = {} };
                    } else {
                        const raw_question = try dns.Question.readFrom(reader, self.options);
                        return ParserFrame{ .question = raw_question };
                    }
                },
                .answer, .nameserver, .additional => {
                    const count_holder = (switch (self.state) {
                        .answer => &self.ctx.current_counts.answer,
                        .nameserver => &self.ctx.current_counts.nameserver,
                        .additional => &self.ctx.current_counts.additional,
                        else => unreachable,
                    });

                    const header_count = switch (self.state) {
                        .answer => self.ctx.header.?.answer_length,
                        .nameserver => self.ctx.header.?.nameserver_length,
                        .additional => self.ctx.header.?.additional_length,
                        else => unreachable,
                    };

                    logger.debug("next(): read {d} out of {d} resources", .{
                        count_holder.*, header_count,
                    });

                    count_holder.* += 1;

                    if (count_holder.* > header_count) {
                        const old_state = self.state;
                        self.state = switch (self.state) {
                            .answer => .nameserver,
                            .nameserver => .additional,
                            .additional => .done,
                            else => unreachable,
                        };

                        logger.debug(
                            "end resource list. state transition {} -> {}",
                            .{ old_state, self.state },
                        );

                        return switch (old_state) {
                            .answer => ParserFrame{ .end_answer = {} },
                            .nameserver => ParserFrame{ .end_nameserver = {} },
                            .additional => ParserFrame{ .end_additional = {} },
                            else => unreachable,
                        };
                    } else {
                        const raw_resource = try dns.Resource.readFrom(reader, self.options);

                        // not at end yet, which means resource_rdata event
                        // must happen if we don't have allocator

                        const old_state = self.state;

                        // if we don't have allocator, we emit rdata records
                        if (self.options.allocator == null) {
                            self.state = switch (self.state) {
                                .answer => .answer_rdata,
                                .nameserver => .nameserver_rdata,
                                .additional => .additional_rdata,
                                else => unreachable,
                            };
                        }

                        logger.debug("resource from {}: {}", .{ old_state, raw_resource });

                        return switch (old_state) {
                            .answer => ParserFrame{ .answer = raw_resource },
                            .nameserver => ParserFrame{ .nameserver = raw_resource },
                            .additional => ParserFrame{ .additional = raw_resource },
                            else => unreachable,
                        };
                    }
                },

                .answer_rdata, .nameserver_rdata, .additional_rdata => {
                    const old_state = self.state;

                    self.state = switch (self.state) {
                        .answer_rdata => .answer,
                        .nameserver_rdata => .nameserver,
                        .additional_rdata => .additional,
                        else => unreachable,
                    };

                    // Only RDLENGTH is consumed here; the client must skip or
                    // read the RDATA bytes through the holder before calling
                    // next() again.
                    const rdata_length = try reader.readInt(u16, .big);
                    const rdata_index = reader.context.ctx.current_byte_count;
                    const rdata = ResourceDataHolder{
                        .size = rdata_length,
                        .current_byte_index = rdata_index,
                    };

                    return switch (old_state) {
                        .answer_rdata => ParserFrame{ .answer_rdata = rdata },
                        .nameserver_rdata => ParserFrame{ .nameserver_rdata = rdata },
                        .additional_rdata => ParserFrame{ .additional_rdata = rdata },
                        else => unreachable,
                    };
                },

                .done => return null,
            }
        }
    };
}

--------------------------------------------------------------------------------
/src/packet.zig:
--------------------------------------------------------------------------------
const std = @import("std");
const dns = @import("lib.zig");

const Name = dns.Name;
const ResourceType = dns.ResourceType;
const ResourceClass = dns.ResourceClass;

const logger = std.log.scoped(.dns_packet);

/// Represents the response code of the packet.
///
/// RCODE, in https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1
pub const ResponseCode = enum(u4) {
    NoError = 0,

    /// Format error - The name server was unable to interpret the query.
    FormatError = 1,

    /// Server failure - The name server was unable to process this query
    /// due to a problem with the name server.
    ServerFailure = 2,

    /// Name Error - Meaningful only for responses from an authoritative name
    /// server, this code signifies that the domain name referenced in
    /// the query does not exist.
    NameError = 3,

    /// Not Implemented - The name server does not support the requested
    /// kind of query.
    NotImplemented = 4,

    /// Refused - The name server refuses to perform the specified
    /// operation for policy reasons. For example, a name server may not
    /// wish to provide the information to the particular requester,
    /// or a name server may not wish to perform a particular operation
    /// (e.g., zone transfer) for particular data.
    Refused = 5,
};

/// OPCODE from https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1
///
/// This value is set by the originator of a query and copied into the response.
pub const OpCode = enum(u4) {
    /// a standard query (QUERY)
    Query = 0,
    /// an inverse query (IQUERY)
    InverseQuery = 1,
    /// a server status request (STATUS)
    ServerStatusRequest = 2,

    // rest is unused as per RFC1035
};

/// Describes the header of a DNS packet.
///
/// https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1
pub const Header = packed struct {
    /// The ID of the packet. Replies to a packet MUST have the same ID.
    id: u16 = 0,

    /// Query/Response flag
    /// Defines if this is a response packet or not.
    is_response: bool = false,

    /// specifies kind of query in this message.
    opcode: OpCode = .Query,

    /// Authoritative Answer flag
    /// Only valid in response packets. Specifies if the server
    /// replying is an authority for the domain name.
    aa_flag: bool = false,

    /// TC flag - TrunCation.
    /// If the packet was truncated.
    truncated: bool = false,

    /// RD flag - Recursion Desired.
    /// Must be copied to a response packet. If set, the server
    /// handling the request can pursue the query recursively.
    wanted_recursion: bool = false,

    /// RA flag - Recursion Available
    /// Whether recursive query support is available on the server.
    recursion_available: bool = false,

    /// DO NOT USE. RFC1035 has not assigned anything to the Z bits
    z: u3 = 0,

    /// Response code.
    response_code: ResponseCode = .NoError,

    /// Amount of questions in the packet.
    question_length: u16 = 0,

    /// Amount of answers in the packet.
    answer_length: u16 = 0,

    /// Amount of nameservers in the packet.
    nameserver_length: u16 = 0,

    /// Amount of additional records in the packet.
    additional_length: u16 = 0,

    const Self = @This();

    /// Read a header from its network representation in a stream.
    pub fn readFrom(byte_reader: anytype) !Self {
        var self = Self{};

        // turn incoming reader into a bitReader so that we can extract
        // non-u8-aligned data from it
        var reader = std.io.bitReader(.big, byte_reader);

        const fields = @typeInfo(Self).@"struct".fields;
        inline for (fields) |field| {
            var out_bits: u16 = undefined;
            @field(self, field.name) = switch (field.type) {
                bool => (try reader.readBits(u1, 1, &out_bits)) > 0,
                u3 => try reader.readBits(u3, 3, &out_bits),
                u4 => try reader.readBits(u4, 4, &out_bits),
                OpCode, ResponseCode => blk: {
                    const tag_int = try reader.readBits(u4, 4, &out_bits);
                    break :blk try std.meta.intToEnum(field.type, tag_int);
                },
                u16 => try byte_reader.readInt(field.type, .big),
                else => @compileError(
                    "unsupported type on header " ++ @typeName(field.type),
                ),
            };
        }
        return self;
    }

    /// Write the network representation of a header to the given writer.
    pub fn writeTo(self: Self, byte_writer: anytype) !usize {
        var writer = std.io.bitWriter(.big, byte_writer);

        var written_bits: usize = 0;

        const fields = @typeInfo(Self).@"struct".fields;
        inline for (fields) |field| {
            const value = @field(self, field.name);
            written_bits += @bitSizeOf(field.type);
            switch (field.type) {
                bool => try writer.writeBits(@as(u1, if (value) 1 else 0), 1),
                u3 => try writer.writeBits(value, 3),
                u4 => try writer.writeBits(value, 4),
                OpCode, ResponseCode => try writer.writeBits(@intFromEnum(value), 4),
                u16 => try writer.writeBits(value, 16),
                else => @compileError(
                    "unsupported type on header " ++ @typeName(field.type),
                ),
            }
        }

        try writer.flushBits();
        const written_bytes = written_bits / 8;
        std.debug.assert(written_bytes == 12);
        return written_bytes;
    }
};

/// Represents a DNS question.
///
/// https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.2
pub const Question = struct {
    name: ?dns.Name,
    typ: ResourceType,
    class: ResourceClass = .IN,

    const Self = @This();

    pub fn readFrom(reader: anytype, options: dns.ParserOptions) !Self {
        // TODO assert reader is WrapperReader
        logger.debug(
            "reading question at {d} bytes",
            .{reader.context.ctx.current_byte_count},
        );

        const name = try Name.readFrom(reader, options);
        // don't leak the freshly read name if the rest of the question is
        // malformed or truncated
        errdefer {
            if (name) |owned_name| {
                if (options.allocator) |allocator| owned_name.deinit(allocator);
            }
        }
        const qtype = try reader.readEnum(ResourceType, .big);
        const qclass = try ResourceClass.readFrom(reader);

        return Self{
            .name = name,
            .typ = qtype,
            .class = qclass,
        };
    }
};

/// DNS resource
///
/// https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.3
pub const Resource = struct {
    name: ?dns.Name,
    typ: ResourceType,
    class: ResourceClass,

    ttl: i32,

    /// Opaque Resource Data. This holds the bytes representing the RDATA
    /// section of the resource, with some metadata for pointer resolution.
    ///
    /// To parse this section, use dns.ResourceData.fromOpaque
    opaque_rdata: ?dns.ResourceData.Opaque,

    const Self = @This();

    /// Extract an RDATA. This only spits out a slice of u8.
    /// Parsing of RDATA sections are in the dns.rdata module.
    ///
    /// Returns null (consuming nothing) when options has no allocator.
    /// Returns error.EndOfStream if the RDATA is truncated.
    ///
    /// Caller owns returned memory.
    fn readResourceDataFrom(
        reader: anytype,
        options: dns.ParserOptions,
    ) !?dns.ResourceData.Opaque {
        const allocator = options.allocator orelse return null;

        const rdata_length = try reader.readInt(u16, .big);
        const rdata_index = reader.context.ctx.current_byte_count;

        const opaque_rdata = try allocator.alloc(u8, rdata_length);
        errdefer allocator.free(opaque_rdata);
        // readNoEof retries short reads and reports truncation as an error
        try reader.readNoEof(opaque_rdata);

        return .{
            .data = opaque_rdata,
            .current_byte_count = rdata_index,
        };
    }

    pub fn readFrom(reader: anytype, options: dns.ParserOptions) !Self {
        // TODO assert reader is WrapperReader
        logger.debug(
            "reading resource at {d} bytes",
            .{reader.context.ctx.current_byte_count},
        );
        const name = try Name.readFrom(reader, options);
        // don't leak the freshly read name if the rest of the resource is
        // malformed or truncated
        errdefer {
            if (name) |owned_name| {
                if (options.allocator) |allocator| owned_name.deinit(allocator);
            }
        }
        const typ = try ResourceType.readFrom(reader);
        const class = try ResourceClass.readFrom(reader);
        const ttl = try reader.readInt(i32, .big);
        const opaque_rdata = try Self.readResourceDataFrom(reader, options);

        return Self{
            .name = name,
            .typ = typ,
            .class = class,
            .ttl = ttl,
            .opaque_rdata = opaque_rdata,
        };
    }

    /// Write the network representation of this resource, returning the
    /// number of bytes written.
    ///
    /// `name` and `opaque_rdata` must be non-null. Returns error.Overflow
    /// when the RDATA does not fit the 16-bit RDLENGTH field.
    pub fn writeTo(self: @This(), writer: anytype) !usize {
        const name_size = try self.name.?.writeTo(writer);
        const typ_size = try self.typ.writeTo(writer);
        const class_size = try self.class.writeTo(writer);

        const ttl_size = @sizeOf(i32);
        try writer.writeInt(i32, self.ttl, .big);

        const rdata = self.opaque_rdata.?.data;
        // a checked cast instead of @intCast: oversized RDATA is an error,
        // not undefined behavior
        const rdata_length = std.math.cast(u16, rdata.len) orelse return error.Overflow;
        const rdata_prefix_size = @sizeOf(u16);
        try writer.writeInt(u16, rdata_length, .big);
        // writeAll: a plain write() may legally be partial
        try writer.writeAll(rdata);

        return name_size + typ_size + class_size + ttl_size +
            rdata_prefix_size + rdata.len;
    }
};

/// A DNS packet, as specified in RFC1035.
///
/// Beware, the amount of questions or resources given in this Packet
/// MUST be synchronized with the lengths set in the Header field.
///
/// https://datatracker.ietf.org/doc/html/rfc1035#section-4.1
pub const Packet = struct {
    header: Header,
    questions: []Question,
    answers: []Resource,
    nameservers: []Resource,
    additionals: []Resource,

    const Self = @This();

    fn writeResourceListTo(resource_list: []Resource, writer: anytype) !usize {
        var size: usize = 0;
        for (resource_list) |resource| {
            size += try resource.writeTo(writer);
        }
        return size;
    }

    /// Write the network representation of this packet into a Writer.
    pub fn writeTo(self: Self, writer: anytype) !usize {
        std.debug.assert(self.header.question_length == self.questions.len);
        std.debug.assert(self.header.answer_length == self.answers.len);
        std.debug.assert(self.header.nameserver_length == self.nameservers.len);
        std.debug.assert(self.header.additional_length == self.additionals.len);

        const header_size = try self.header.writeTo(writer);

        var question_size: usize = 0;

        for (self.questions) |question| {
            const question_name_size = try question.name.?.writeTo(writer);
            const question_typ_size = try question.typ.writeTo(writer);
            const question_class_size = try question.class.writeTo(writer);

            question_size += question_name_size + question_typ_size + question_class_size;
        }

        const answers_size = try Self.writeResourceListTo(self.answers, writer);
        const nameservers_size = try Self.writeResourceListTo(self.nameservers, writer);
        const additionals_size = try Self.writeResourceListTo(self.additionals, writer);

        logger.debug(
            "header = {d}, question_size = {d}, answers_size = {d}," ++
                " nameservers_size = {d}, additionals_size = {d}",
            .{ header_size, question_size, answers_size, nameservers_size, additionals_size },
        );

        return header_size + question_size +
            answers_size + nameservers_size + additionals_size;
    }
};

/// Represents a Packet where all of its data was allocated dynamically.
pub const IncomingPacket = struct {
    allocator: std.mem.Allocator,
    packet: *Packet,

    /// Release one resource's name (when requested) and its opaque RDATA.
    fn freeResource(
        self: @This(),
        resource: Resource,
        options: DeinitOptions,
    ) void {
        if (options.names) {
            if (resource.name) |resource_name| resource_name.deinit(self.allocator);
        }
        if (resource.opaque_rdata) |rdata| {
            self.allocator.free(rdata.data);
        }
    }

    /// Release every resource in the list, then the list itself.
    fn freeResourceList(
        self: @This(),
        resource_list: []Resource,
        options: DeinitOptions,
    ) void {
        for (resource_list) |resource| {
            self.freeResource(resource, options);
        }
        self.allocator.free(resource_list);
    }

    pub const DeinitOptions = struct {
        /// If the names inside the packet should be deinitialized or not.
        ///
        /// This should be set to false if you are passing ownership of the Name
        /// to dns.NamePool, as it has dns.NamePool.deinitWithNames().
        names: bool = true,
    };

    /// Free the questions, all three resource sections, and the Packet
    /// allocation itself, all with `self.allocator`.
    pub fn deinit(self: @This(), options: DeinitOptions) void {
        const packet = self.packet;

        if (options.names) {
            for (packet.questions) |question| {
                if (question.name) |question_name| question_name.deinit(self.allocator);
            }
        }
        self.allocator.free(packet.questions);

        const sections = [_][]Resource{ packet.answers, packet.nameservers, packet.additionals };
        for (sections) |section| {
            self.freeResourceList(section, options);
        }

        self.allocator.destroy(packet);
    }
};

--------------------------------------------------------------------------------
/src/cidr.zig:
--------------------------------------------------------------------------------
const std = @import("std");
const net = std.net;
const mem = std.mem;
const fmt = std.fmt;

/// Address family a CidrRange was parsed from.
pub const IpVersion = enum {
    v4,
    v6,
};

/// Errors produced while parsing or matching CIDR ranges.
pub const CidrParseError = error{
    InvalidFormat,
    InvalidAddress,
    InvalidPrefixLength,
    AddressesExhausted,
};
/// A contiguous block of IP addresses written in CIDR notation, such as
/// "192.168.0.0/16" or "2001:db8::/32".
///
/// Both families share one 16-byte representation: IPv4 ranges are stored as
/// IPv4-mapped IPv6 addresses (::ffff:a.b.c.d).
pub const CidrRange = struct {
    version: IpVersion,
    /// Network address of the range, with every host bit cleared.
    first_address: [16]u8,
    /// Number of leading network bits: 0-32 for IPv4, 0-128 for IPv6.
    prefix_len: u8,

    const Self = @This();

    /// First 12 bytes shared by every IPv4-mapped IPv6 address (::ffff:0:0/96).
    const v4_mapped_prefix = [12]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff };

    /// Embed four IPv4 octets (network order) into an IPv4-mapped IPv6 address.
    fn mapIp4(octets: [4]u8) [16]u8 {
        var mapped: [16]u8 = undefined;
        mapped[0..12].* = v4_mapped_prefix;
        mapped[12..16].* = octets;
        return mapped;
    }

    /// Parse a CIDR notation string ("address/prefix") into a CidrRange.
    ///
    /// Host bits present in the address are cleared, so "192.168.1.1/24"
    /// yields the same range as "192.168.1.0/24".
    ///
    /// Errors: InvalidFormat when the string is not exactly `address/prefix`,
    /// InvalidPrefixLength when the prefix is not a number or exceeds the
    /// address width, InvalidAddress when the address is neither IPv4 nor IPv6.
    pub fn parse(cidr: []const u8) !CidrRange {
        var it = std.mem.splitSequence(u8, cidr, "/");
        const addr_str = it.next() orelse return error.InvalidFormat;
        const prefix_str = it.next() orelse return error.InvalidFormat;
        if (it.next() != null) return error.InvalidFormat;

        const prefix_len = std.fmt.parseInt(u8, prefix_str, 10) catch return error.InvalidPrefixLength;

        if (std.net.Address.parseIp4(addr_str, 0)) |ipv4| {
            // an IPv4 address only has 32 bits to select from
            if (prefix_len > 32) return CidrParseError.InvalidPrefixLength;

            const octets = std.mem.toBytes(ipv4.in.sa.addr);
            const network_mask: u32 = if (prefix_len == 0)
                0
            else
                ~@as(u32, 0) << @intCast(32 - prefix_len);
            const network = std.mem.readInt(u32, &octets, .big) & network_mask;

            var result = CidrRange{
                .version = .v4,
                .first_address = mapIp4(octets),
                .prefix_len = prefix_len,
            };
            std.mem.writeInt(u32, result.first_address[12..16], network, .big);
            return result;
        } else |_| {}

        if (std.net.Address.parseIp6(addr_str, 0)) |ipv6| {
            if (prefix_len > 128) return CidrParseError.InvalidPrefixLength;

            var result = CidrRange{
                .version = .v6,
                .first_address = ipv6.in6.sa.addr,
                .prefix_len = prefix_len,
            };

            // clear the host portion: keep `full_bytes` whole bytes, the top
            // `remaining_bits` of the following byte, and zero everything after
            const full_bytes = prefix_len / 8;
            const remaining_bits = prefix_len % 8;
            var zero_from: usize = full_bytes;
            if (remaining_bits > 0) {
                result.first_address[full_bytes] &= @as(u8, 0xFF) << @intCast(8 - remaining_bits);
                zero_from += 1;
            }
            @memset(result.first_address[zero_from..], 0);

            return result;
        } else |_| {}

        return CidrParseError.InvalidAddress;
    }

    /// Check if an IP address is within this CIDR range.
    ///
    /// IPv4 addresses are compared in their IPv4-mapped IPv6 form, so an IPv6
    /// range such as "::ffff:0:0/96" contains them. An IPv4 range contains only
    /// IPv4 addresses and IPv4-mapped IPv6 addresses.
    ///
    /// Returns error.InvalidAddress for families other than INET/INET6.
    pub fn contains(self: Self, addr: std.net.Address) !bool {
        const candidate: [16]u8 = switch (addr.any.family) {
            std.posix.AF.INET => mapIp4(std.mem.toBytes(addr.in.sa.addr)),
            std.posix.AF.INET6 => addr.in6.sa.addr,
            else => return CidrParseError.InvalidAddress,
        };

        // Without this check a native IPv6 address whose last four bytes happen
        // to equal an address of this IPv4 network would be reported as inside it.
        if (self.version == .v4 and !std.mem.eql(u8, candidate[0..12], &v4_mapped_prefix))
            return false;

        // IPv4 prefixes count from the start of the embedded IPv4 address.
        const start_byte: usize = if (self.version == .v4) 12 else 0;
        const full_bytes = self.prefix_len / 8;
        const remaining_bits = self.prefix_len % 8;
        const end = start_byte + full_bytes;

        if (!std.mem.eql(u8, self.first_address[start_byte..end], candidate[start_byte..end]))
            return false;

        if (remaining_bits > 0) {
            const mask = @as(u8, 0xFF) << @intCast(8 - remaining_bits);
            if (((self.first_address[end] ^ candidate[end]) & mask) != 0) return false;
        }

        return true;
    }

    /// Format as "a.b.c.d/len" for IPv4 or as RFC 5952 text plus "/len" for
    /// IPv6 (no brackets or port, unlike std.net.Ip6Address).
    pub fn format(
        self: Self,
        comptime fmt_str: []const u8,
        options: std.fmt.FormatOptions,
        writer: anytype,
    ) !void {
        _ = fmt_str;
        _ = options;

        switch (self.version) {
            .v4 => {
                try writer.print("{}.{}.{}.{}/{}", .{
                    self.first_address[12],
                    self.first_address[13],
                    self.first_address[14],
                    self.first_address[15],
                    self.prefix_len,
                });
            },
            .v6 => {
                try writeIp6(self.first_address, writer);
                try writer.print("/{}", .{self.prefix_len});
            },
        }
    }

    /// Write `addr` in RFC 5952 form: lowercase hex groups without leading
    /// zeros, with the longest run (at least two, first on ties) of zero
    /// groups replaced by "::".
    fn writeIp6(addr: [16]u8, writer: anytype) !void {
        var groups: [8]u16 = undefined;
        for (&groups, 0..) |*group, group_index| {
            group.* = std.mem.readInt(u16, addr[group_index * 2 ..][0..2], .big);
        }

        // locate the longest run of zero groups
        var best_start: usize = 8;
        var best_len: usize = 0;
        var i: usize = 0;
        while (i < 8) {
            if (groups[i] != 0) {
                i += 1;
                continue;
            }
            const run_start = i;
            while (i < 8 and groups[i] == 0) : (i += 1) {}
            const run_len = i - run_start;
            if (run_len > best_len) {
                best_start = run_start;
                best_len = run_len;
            }
        }
        // a single zero group is never shortened to "::"
        if (best_len < 2) {
            best_start = 8;
            best_len = 0;
        }

        i = 0;
        while (i < 8) {
            if (i == best_start) {
                try writer.writeAll("::");
                i += best_len;
                continue;
            }
            // no separator at the start or right after "::"
            if (i != 0 and i != best_start + best_len) try writer.writeAll(":");
            try writer.print("{x}", .{groups[i]});
            i += 1;
        }
    }
};

const testing = std.testing;

test "IPv4 basic parsing" {
    const cases = .{
        .{
            "192.168.1.0/24",
            [16]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 0 },
            24,
            IpVersion.v4,
        },
        .{
            "10.0.0.0/8",
            [16]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 10, 0, 0, 0 },
            8,
            IpVersion.v4,
        },
        .{
            "172.16.0.0/12",
            [16]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 172, 16, 0, 0 },
            12,
            IpVersion.v4,
        },
    };

    inline for (cases) |case| {
        const cidr = try CidrRange.parse(case[0]);
        try testing.expectEqual(case[1], cidr.first_address);
        try testing.expectEqual(case[2], cidr.prefix_len);
        try testing.expectEqual(case[3], cidr.version);
    }
}

test "IPv6 basic parsing" {
    const cases = .{
        .{
            "2001:db8::/32",
            [16]u8{ 0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
            32,
            IpVersion.v6,
        },
        .{
            "fe80::/10",
            [16]u8{ 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
            10,
            IpVersion.v6,
        },
        .{
            "::1/128",
            [16]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 },
            128,
            IpVersion.v6,
        },
    };

    inline for (cases) |case| {
        const cidr = try CidrRange.parse(case[0]);
        try testing.expectEqual(case[1], cidr.first_address);
        try testing.expectEqual(case[2], cidr.prefix_len);
        try testing.expectEqual(case[3], cidr.version);
    }
}

test "IPv6 compressed notation" {
    const cases = .{
        .{
            "2001:db8:0:0:0:0:0:0/32",
            "2001:db8::/32",
        },
        .{
            "2001:0db8:0000:0000:0000:0000:0000:0000/32",
            "2001:db8::/32",
        },
        .{
            "fe80:0:0:0:0:0:0:0/10",
            "fe80::/10",
        },
    };

    inline for (cases) |case| {
        const cidr1 = try CidrRange.parse(case[0]);
        const cidr2 = try CidrRange.parse(case[1]);
        try testing.expectEqual(cidr1.first_address, cidr2.first_address);
        try testing.expectEqual(cidr1.prefix_len, cidr2.prefix_len);
        try testing.expectEqual(cidr1.version, cidr2.version);
    }
}

test "Invalid CIDR formats" {
    const cases = .{
        "192.168.1.0", // Missing prefix
        "192.168.1.0/", // Empty prefix
        "192.168.1.0/33", // IPv4 prefix too large
        "2001:db8::/129", // IPv6 prefix too large
        "192.168.1.256/24", // Invalid IPv4 address
        "2001:db8::xyz/32", // Invalid IPv6 address
        "not.an.ip/24", // Invalid address format
        "/24", // Missing address
        "192.168.1.0/-1", // Negative prefix
        "192.168.1.0/a", // Non-numeric prefix
    };

    inline for (cases) |case| {
        if (CidrRange.parse(case)) |_| {
            try testing.expect(false); // Should not succeed
        } else |err| {
            try testing.expect(err == error.InvalidFormat or
                err == error.InvalidAddress or
                err == error.InvalidPrefixLength);
        }
    }
}

test "Edge cases" {
    const cases = .{
        .{
            "0.0.0.0/0", // Full IPv4 range
            [16]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 0, 0, 0, 0 },
            0,
            IpVersion.v4,
        },
        .{
            "::/0", // Full IPv6 range
            [16]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
            0,
            IpVersion.v6,
        },
        .{
            "255.255.255.255/32", // Single IPv4 address
            [16]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 255, 255, 255, 255 },
            32,
            IpVersion.v4,
        },
        .{
            "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/128", // Single IPv6 address
            [16]u8{ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff },
            128,
            IpVersion.v6,
        },
    };

    inline for (cases) |case| {
        const cidr = try CidrRange.parse(case[0]);
        try testing.expectEqual(case[1], cidr.first_address);
        try testing.expectEqual(case[2], cidr.prefix_len);
        try testing.expectEqual(case[3], cidr.version);
    }
}

test "Non-zero host bits" {
    const cases = .{
        .{
            "192.168.1.1/24", // Should clear to 192.168.1.0/24
            [16]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 0 },
        },
        .{
            "2001:db8::1/32", // Should clear to 2001:db8::/32
            [16]u8{ 0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
        },
    };

    inline for (cases) |case| {
        const cidr = try CidrRange.parse(case[0]);
        try testing.expectEqual(case[1], cidr.first_address);
    }
}

test "Boundary prefix lengths" {
    const ipv4_cases = .{
        "192.168.1.0/0",
        "192.168.1.0/1",
        "192.168.1.0/31",
        "192.168.1.0/32",
    };

    const ipv6_cases = .{
        "2001:db8::/0",
        "2001:db8::/1",
        "2001:db8::/127",
        "2001:db8::/128",
    };

    inline for (ipv4_cases) |case| {
        const cidr = try CidrRange.parse(case);
        try testing.expect(cidr.version == .v4);
    }

    inline for (ipv6_cases) |case| {
        const cidr = try CidrRange.parse(case);
        try testing.expect(cidr.version == .v6);
    }
}

fn ip4ToBytes(strAddr: []const u8, port: u16) [16]u8 {
    const ipv4 = std.net.Address.parseIp4(strAddr, port) catch unreachable;
    const addr = std.mem.toBytes(ipv4.in.sa.addr);

    var result: [16]u8 = [_]u8{0} ** 16;
    // Set the IPv4-mapped IPv6 prefix (::ffff:)
    result[10] = 0xff;
    result[11] = 0xff;
    // Copy the IPv4 address bytes
    @memcpy(result[12..16], &addr);
    return result;
}

test "IPv4 contains basic tests" {
    // Test a typical IPv4 /24 network
    const cidr = try CidrRange.parse("192.168.1.0/24");

    // These should be in the range
    try testing.expect(try cidr.contains(std.net.Address.parseIp4("192.168.1.0", 0) catch unreachable));
    try testing.expect(try cidr.contains(std.net.Address.parseIp4("192.168.1.1", 0) catch unreachable));
    try testing.expect(try cidr.contains(std.net.Address.parseIp4("192.168.1.255", 0) catch unreachable));

    // These should not be in the range
    try testing.expect(!try cidr.contains(std.net.Address.parseIp4("192.168.0.255", 0) catch unreachable));
    try testing.expect(!try cidr.contains(std.net.Address.parseIp4("192.168.2.0", 0) catch unreachable));
    try testing.expect(!try cidr.contains(std.net.Address.parseIp4("192.169.1.1", 0) catch unreachable));
}

test "IPv6 contains basic tests" {
    // Test a typical IPv6 /64 network
    const cidr = try CidrRange.parse("2001:db8::/64");

    // These should be in the range
    try testing.expect(try cidr.contains(try net.Address.parseIp6("2001:db8::", 0)));
    try testing.expect(try cidr.contains(try net.Address.parseIp6("2001:db8::1", 0)));
    try testing.expect(try cidr.contains(try net.Address.parseIp6("2001:db8::ffff", 0)));

    // These should not be in the range
    try testing.expect(!try cidr.contains(try net.Address.parseIp6("2001:db9::", 0)));
    try testing.expect(!try cidr.contains(try net.Address.parseIp6("2001:db8:1::", 0)));
    try testing.expect(!try cidr.contains(try net.Address.parseIp6("2002:db8::", 0)));
}

test "contains edge prefix lengths" {
    // Test /0 (entire address space)
    {
        const cidr_v4 = try CidrRange.parse("0.0.0.0/0");
        try testing.expect(try cidr_v4.contains(try std.net.Address.parseIp4("0.0.0.0", 0)));
        try testing.expect(try cidr_v4.contains(try std.net.Address.parseIp4("255.255.255.255", 0)));
        try testing.expect(try cidr_v4.contains(try std.net.Address.parseIp4("192.168.1.1", 0)));
    }

    {
        const cidr_v6 = try CidrRange.parse("::/0");
        try testing.expect(try cidr_v6.contains(try net.Address.parseIp6("::", 0)));
        try testing.expect(try cidr_v6.contains(try net.Address.parseIp6("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", 0)));
        try testing.expect(try cidr_v6.contains(try net.Address.parseIp6("2001:db8::1", 0)));
    }

    // Test single address (/32 for IPv4, /128 for IPv6)
    {
        const cidr_v4 = try CidrRange.parse("192.168.1.1/32");
        try testing.expect(try cidr_v4.contains(try std.net.Address.parseIp4("192.168.1.1", 0)));
        try testing.expect(!try cidr_v4.contains(try std.net.Address.parseIp4("192.168.1.2", 0)));
    }

    {
        const cidr_v6 = try CidrRange.parse("2001:db8::1/128");
        try testing.expect(try cidr_v6.contains(try net.Address.parseIp6("2001:db8::1", 0)));
        try testing.expect(!try cidr_v6.contains(try net.Address.parseIp6("2001:db8::2", 0)));
    }
}

test "contains non-aligned prefix lengths" {
    // Test IPv4 /23 (two /24 networks)
    {
        const cidr = try CidrRange.parse("192.168.0.0/23");
        try testing.expect(try cidr.contains(try std.net.Address.parseIp4("192.168.0.1", 0)));
        try testing.expect(try cidr.contains(try std.net.Address.parseIp4("192.168.1.1", 0)));
        try testing.expect(!try cidr.contains(try std.net.Address.parseIp4("192.168.2.1", 0)));
    }

    // Test IPv6 /63 (two /64 networks)
    {
        const cidr = try CidrRange.parse("2001:db8::/63");
        try testing.expect(try cidr.contains(try net.Address.parseIp6("2001:db8::", 0)));
        try testing.expect(try cidr.contains(try net.Address.parseIp6("2001:db8:0:1::", 0)));
        try testing.expect(!try cidr.contains(try net.Address.parseIp6("2001:db8:0:2::", 0)));
    }
}

test "contains byte boundary edge cases" {
    // Test IPv4 /16 (byte boundary)
    {
        const cidr = try CidrRange.parse("192.168.0.0/16");
        try testing.expect(try cidr.contains(try std.net.Address.parseIp4("192.168.0.0", 0)));
        try testing.expect(try cidr.contains(try std.net.Address.parseIp4("192.168.255.255", 0)));
        try testing.expect(!try cidr.contains(try std.net.Address.parseIp4("192.169.0.0", 0)));
    }

    // Test IPv6 /48 (byte boundary)
    {
        const cidr = try CidrRange.parse("2001:db8:1100::/48");
        try testing.expect(try cidr.contains(try net.Address.parseIp6("2001:db8:1100::", 0)));
        try testing.expect(try cidr.contains(try net.Address.parseIp6("2001:db8:1100:ffff::", 0)));
        try testing.expect(!try cidr.contains(try net.Address.parseIp6("2001:db8:1101::", 0)));
    }
}

test "IPv4 ranges only contain IPv4 addresses" {
    const cidr = try CidrRange.parse("192.168.1.0/24");

    // a native IPv6 address whose last four bytes equal 192.168.1.5
    try testing.expect(!try cidr.contains(try net.Address.parseIp6("2001:db8::c0a8:105", 0)));
    // the IPv4-mapped form of 192.168.1.5 is still accepted
    try testing.expect(try cidr.contains(try net.Address.parseIp6("::ffff:c0a8:105", 0)));
}

test "format" {
    const cases = .{
        .{ "192.168.1.0/24", "192.168.1.0/24" },
        .{ "2001:db8::/32", "2001:db8::/32" },
        .{ "fe80::/10", "fe80::/10" },
        .{ "::1/128", "::1/128" },
        .{ "::/0", "::/0" },
        .{ "2001:db8:0:1:1:1:1:1/128", "2001:db8:0:1:1:1:1:1/128" },
        .{ "2001:0:0:1:0:0:0:1/128", "2001:0:0:1::1/128" },
    };

    inline for (cases) |case| {
        var buffer: [64]u8 = undefined;
        const cidr = try CidrRange.parse(case[0]);
        try testing.expectEqualStrings(case[1], try std.fmt.bufPrint(&buffer, "{}", .{cidr}));
    }
}

// --------------------------------------------------------------------------------
// /src/name.zig:
// --------------------------------------------------------------------------------
const std = @import("std");
const dns = @import("lib.zig");

const logger = std.log.scoped(.dns_name);

/// Represents a raw
/// component of a DNS label.
pub const LabelComponent = union(enum) {
    /// The bytes of a length-prefixed label (the length octet itself is not kept).
    Full: []const u8,
    /// A message compression pointer (RFC 1035, section 4.1.4).
    ///
    /// Holds the complete 14-bit packet offset the pointer refers to: both
    /// pointer octets are already merged and the two prefix bits removed.
    Pointer: u16,
    /// The zero-length label that terminates a name.
    Null: void,
};

/// A raw name that may end in a Pointer LabelComponent.
pub const RawName = struct {
    components: []LabelComponent,

    /// Represents the index of that name in its packet's body.
    ///
    /// **This is an internal field for DNS name pointer resolution.**
    packet_index: ?usize = null,
};

/// Wrapper class for safer handling of names
pub const Name = union(enum) {
    raw: RawName,
    full: FullName,

    const Self = @This();

    /// Deinitializes the entire name, including the labels inside, given
    /// an allocator.
    pub fn deinit(self: Self, allocator: std.mem.Allocator) void {
        switch (self) {
            .raw => |raw| {
                for (raw.components) |label| switch (label) {
                    .Full => |data| allocator.free(data),
                    else => {},
                };

                allocator.free(raw.components);
            },
            .full => |full| {
                for (full.labels) |label| allocator.free(label);
                allocator.free(full.labels);
            },
        }
    }

    /// Read a single name from `reader`.
    ///
    /// When `options.allocator` is set, returns `.full` if the labels end in
    /// the null label, or `.raw` if they end in a compression pointer (resolve
    /// it through NamePool.transmuteName). Without an allocator the name is
    /// skipped and null is returned.
    ///
    /// Returns error.Overflow once the component count exceeds
    /// `options.max_label_size`.
    ///
    /// Caller owns returned memory.
    pub fn readFrom(
        reader: anytype,
        options: dns.ParserOptions,
    ) !?Self {
        const current_byte_index = reader.context.ctx.current_byte_count;

        if (options.allocator) |allocator| {
            var components = std.ArrayList(LabelComponent).init(allocator);
            defer components.deinit();
            // Label bytes read so far belong to this function until they are
            // handed to the returned Name; free them if anything below fails.
            errdefer {
                for (components.items) |component| {
                    if (component == .Full) allocator.free(component.Full);
                }
            }

            var has_pointer: bool = false;

            while (true) {
                if (components.items.len > options.max_label_size)
                    return error.Overflow;

                // Reserve room first so a freshly allocated label cannot be
                // lost if growing the list fails.
                try components.ensureUnusedCapacity(1);
                const component = (try Self.readLabelComponent(reader, allocator)).?;
                logger.debug("read name: component {}", .{component});
                components.appendAssumeCapacity(component);
                switch (component) {
                    .Null => break,
                    .Pointer => {
                        has_pointer = true;
                        break;
                    },
                    else => {},
                }
            }

            return if (has_pointer) .{ .raw = .{
                .components = try components.toOwnedSlice(),
                .packet_index = current_byte_index,
            } } else .{
                .full = try FullName.fromAssumedComponents(
                    allocator,
                    components.items,
                    current_byte_index,
                ),
            };
        } else {
            // skip the name in the reader
            var name_index: usize = 0;

            while (true) : (name_index += 1) {
                if (name_index > options.max_label_size)
                    return error.Overflow;

                const maybe_component = try Self.readLabelComponent(reader, null);
                if (maybe_component) |component| switch (component) {
                    .Null, .Pointer => break,
                    else => {},
                };
            }

            return null;
        }
    }

    /// Deserialize a single LabelComponent, which can be:
    /// - a pointer
    /// - a full label ([]const u8)
    /// - a null octet
    ///
    /// Returns null only for a full label read without an allocator (its
    /// bytes are skipped). Returns error.InvalidLabelType for the reserved
    /// 0b01/0b10 prefixes and error.EndOfStream for truncated input.
    fn readLabelComponent(
        reader: anytype,
        maybe_allocator: ?std.mem.Allocator,
    ) !?LabelComponent {
        // The two high bits of the first octet select the component kind:
        //   00     -> a label whose length (0-63) is the low six bits; 0 is the null label
        //   11     -> a pointer; the low six bits and the next octet form a 14-bit offset
        //             1 1 B B B B B B | B B B B B B B B
        //   01, 10 -> reserved
        logger.debug(
            "reading label component at {d} bytes",
            .{reader.context.ctx.current_byte_count},
        );
        const first_octet = try reader.readInt(u8, .big);
        if (first_octet == 0) return LabelComponent{ .Null = {} };

        switch (first_octet >> 6) {
            0b11 => {
                const second_octet = try reader.readInt(u8, .big);
                // widen before shifting: shifting the u8 itself would drop the
                // high offset bits and misdecode every offset >= 256
                const offset = (@as(u16, first_octet & 0x3F) << 8) | second_octet;
                return LabelComponent{ .Pointer = offset };
            },
            0b00 => {
                // the next `first_octet` bytes contain a full label
                if (maybe_allocator) |allocator| {
                    const label = try allocator.alloc(u8, first_octet);
                    errdefer allocator.free(label);
                    try reader.readNoEof(label);
                    return LabelComponent{ .Full = label };
                } else {
                    logger.debug("read_name: skip {d} bytes as no alloc", .{first_octet});
                    try reader.skipBytes(first_octet, .{});
                    return null;
                }
            },
            // reserved prefixes come from untrusted packets: reject, don't assert
            else => return error.InvalidLabelType,
        }
    }

    /// Write the network representation of a name onto a stream.
    pub fn writeTo(self: Self, writer: anytype) !usize {
        return switch (self) {
            // NOTE we don't serialize to pointers.
            .raw => unreachable, // must convert to full name to be able to write
            .full => |full| try full.writeTo(writer),
        };
    }

    /// Return the byte size of the network representation of the name.
    pub fn networkSize(self: Self) usize {
        return switch (self) {
            .raw => unreachable, // must resolve against original packet so that we know the full name
            .full => |full| full.networkSize(),
        };
    }

    /// Create a FullName with a given buffer that will hold its labels.
    pub fn fromString(domain: []const u8, buffer: [][]const u8) !Self {
        return .{ .full = try FullName.fromString(domain, buffer) };
    }

    pub fn format(
        self: Self,
        comptime f: []const u8,
        options: std.fmt.FormatOptions,
        writer: anytype,
    ) !void {
        return switch (self) {
            .full => |full| full.format(f, options, writer),
            .raw => |raw| for (raw.components) |component| switch (component) {
                .Pointer => |ptr| try std.fmt.format(writer, "(pointer={d}).", .{ptr}),
                .Full => |label| try std.fmt.format(writer, "{s}.", .{label}),
                .Null => break,
            },
        };
    }
};

/// Represents a single DNS domain name, which is a slice of strings.
///
/// The "www.google.com" friendly domain name can be represented in DNS as a
/// sequence of labels: first "www", then "google", then "com", with a length
/// prefix for all of them, ending in a null byte.
///
/// For RawName, the names may end in a pointer that is dependent on the overall
/// parsing context of the packet. To be able to turn a RawName into a FullName,
/// look at NamePool.transmuteName
pub const FullName = struct {
    /// The name's labels.
    labels: [][]const u8,

    /// Represents the index of that name in its packet's body.
    ///
    /// **This is an internal field for DNS name pointer resolution.**
    packet_index: ?usize = null,

    const Self = @This();

    /// Create a FullName from a []LabelComponent.
    ///
    /// Assumes that the slice does not end in a pointer.
    ///
    /// Does not take ownership of the returned slice.
    /// Caller owns returned memory.
    pub fn fromAssumedComponents(
        allocator: std.mem.Allocator,
        components: []LabelComponent,
        packet_index: ?usize,
    ) !Self {
        var labels = std.ArrayList([]const u8).init(allocator);
        defer labels.deinit();

        for (components) |component| switch (component) {
            .Full => |data| try labels.append(data),
            .Pointer => unreachable,
            .Null => break,
        };

        return Self{
            .labels = try labels.toOwnedSlice(),
            .packet_index = packet_index,
        };
    }

    /// Only use this if you have manually heap allocated a Name
    /// through the internal Packet.readName function.
    ///
    /// IncomingPacket.deinit already frees allocated Names.
    pub fn deinit(self: Self, allocator: std.mem.Allocator) void {
        for (self.labels) |label| allocator.free(label);
        allocator.free(self.labels);
    }

    /// Returns the total size in bytes of the DNS Name
    pub fn networkSize(self: Self) usize {
        // by default, add the null octet at the end of it
        var total_size: usize = 1;

        for (self.labels) |label| {
            // length octet + the actual label octets
            total_size += @sizeOf(u8);
            total_size += label.len * @sizeOf(u8);
        }

        return total_size;
    }

    /// Get a Name out of a domain name ("www.google.com", for example).
    ///
    /// The labels are slices of `domain` stored in `buffer`; nothing is
    /// allocated. A single trailing dot is accepted and ignored.
    pub fn fromString(domain: []const u8, buffer: [][]const u8) error{
        /// The given domain contains an empty label (e.g "google...com")
        ///
        /// Empty labels are disallowed in DNS, as "length 0" label is also
        /// determined as the Null label, which ends a name.
        EmptyLabelInName,
        /// The domain is longer than 255 bytes, contains a label longer than
        /// 63 bytes, or has too many labels for the given buffer to hold.
        Overflow,
    }!Self {
        if (domain.len > 255) return error.Overflow;

        // Count the dot-separated parts first so that an empty final part
        // (a trailing dot, i.e. the root label) can be told apart from an
        // empty label in the middle of the name.
        var part_count: usize = 0;
        var counter = std.mem.splitSequence(u8, domain, ".");
        while (counter.next()) |_| part_count += 1;

        var it = std.mem.splitSequence(u8, domain, ".");
        var part_index: usize = 0;
        var idx: usize = 0;
        while (it.next()) |label| : (part_index += 1) {
            if (label.len == 0) {
                if (part_index == part_count - 1) continue;
                return error.EmptyLabelInName;
            }
            // a length octet can only describe labels of up to 63 bytes
            if (label.len > 63) return error.Overflow;
            if (idx >= buffer.len) return error.Overflow;

            buffer[idx] = label;
            idx += 1;
        }

        return Self{ .labels = buffer[0..idx] };
    }

    /// Write the wire encoding (length-prefixed labels followed by the null
    /// octet) and return the number of bytes written.
    ///
    /// Returns error.Overflow for a label longer than 63 bytes, which a length
    /// octet cannot describe without colliding with the pointer prefix bits.
    pub fn writeTo(self: Self, writer: anytype) !usize {
        var size: usize = 0;
        for (self.labels) |label| {
            if (label.len > 63) return error.Overflow;

            try writer.writeByte(@intCast(label.len));
            try writer.writeAll(label);
            size += 1 + label.len;
        }

        // null-octet for the end of labels for this name
        try writer.writeByte(0);
        return size + 1;
    }

    /// Format the given DNS name.
    pub fn format(
        self: Self,
        comptime f: []const u8,
        options: std.fmt.FormatOptions,
        writer: anytype,
    ) !void {
        _ = f;
        _ = options;

        for (self.labels) |label| {
            try std.fmt.format(writer, "{s}.", .{label});
        }
    }
};

test "readLabelComponent decodes pointers and labels, rejects bad input" {
    // Minimal reader exposing the `context.ctx.current_byte_count` field
    // that readLabelComponent logs.
    const Context = struct {
        stream: std.io.FixedBufferStream([]const u8),
        ctx: struct { current_byte_count: usize = 0 } = .{},

        fn read(self: *@This(), buffer: []u8) error{}!usize {
            return self.stream.read(buffer);
        }
    };
    const Reader = std.io.GenericReader(*Context, error{}, Context.read);

    // 0xC1 0x02 points at offset 0x102 (needs both octets)
    {
        const bytes = [_]u8{ 0xC1, 0x02 };
        var context = Context{ .stream = std.io.fixedBufferStream(@as([]const u8, &bytes)) };
        const reader = Reader{ .context = &context };
        const component = (try Name.readLabelComponent(reader, null)).?;
        try std.testing.expectEqual(@as(u16, 0x0102), component.Pointer);
    }

    // a plain three-byte label
    {
        const bytes = [_]u8{ 3, 'w', 'w', 'w' };
        var context = Context{ .stream = std.io.fixedBufferStream(@as([]const u8, &bytes)) };
        const reader = Reader{ .context = &context };
        const component = (try Name.readLabelComponent(reader, std.testing.allocator)).?;
        defer std.testing.allocator.free(component.Full);
        try std.testing.expectEqualStrings("www", component.Full);
    }

    // truncated label: error instead of assertion, and no leak
    {
        const bytes = [_]u8{ 3, 'w' };
        var context = Context{ .stream = std.io.fixedBufferStream(@as([]const u8, &bytes)) };
        const reader = Reader{ .context = &context };
        try std.testing.expectError(
            error.EndOfStream,
            Name.readLabelComponent(reader, std.testing.allocator),
        );
    }

    // reserved 0b01 prefix
    {
        const bytes = [_]u8{0x40};
        var context = Context{ .stream = std.io.fixedBufferStream(@as([]const u8, &bytes)) };
        const reader = Reader{ .context = &context };
        try std.testing.expectError(
            error.InvalidLabelType,
            Name.readLabelComponent(reader, std.testing.allocator),
        );
    }
}

test "FullName.fromString enforces label and buffer limits" {
    var buffer: [4][]const u8 = undefined;
    try std.testing.expectError(
        error.Overflow,
        FullName.fromString(("a" ** 64) ++ ".com", &buffer),
    );

    const name = try FullName.fromString(("a" ** 63) ++ ".com", &buffer);
    try std.testing.expectEqual(@as(usize, 2), name.labels.len);

    var no_room: [0][]const u8 = .{};
    try std.testing.expectError(error.Overflow, FullName.fromString("example", &no_room));
}

test "FullName.writeTo emits length-prefixed labels" {
    var labels = [_][]const u8{ "www", "example", "com" };
    const name = FullName{ .labels = &labels };

    var buffer: [32]u8 = undefined;
    var stream = std.io.fixedBufferStream(&buffer);
    const written = try name.writeTo(stream.writer());

    try std.testing.expectEqual(name.networkSize(), written);
    try std.testing.expectEqualSlices(u8, "\x03www\x07example\x03com\x00", stream.getWritten());
}

const NameList = std.ArrayList(dns.Name);

/// Implements RFC1035, section 4.1.4 Message Compression.
///
/// This is an entity that holds Name entities inside, with their respective
/// locations inside the packet.
/// With that information, it is able to convert
/// from a RawName given by deserialization into a FullName that contains
/// all wanted labels.
pub const NamePool = struct {
    allocator: std.mem.Allocator,
    held_names: NameList,

    const Self = @This();
    pub fn init(allocator: std.mem.Allocator) Self {
        return Self{
            .allocator = allocator,
            .held_names = NameList.init(allocator),
        };
    }

    pub fn deinit(self: Self) void {
        self.held_names.deinit();
    }

    pub fn deinitWithNames(self: Self) void {
        for (self.held_names.items) |name| name.deinit(self.allocator);
        self.deinit();
    }

    /// Convert dns.RawName or FullName to FullName, applying pointer
    /// resolution, and storing the name for future pointers to be resolved.
    ///
    /// Takes ownership of the given name's memory: a full name is stored
    /// as-is, while a raw name is freed after its labels (and the labels its
    /// pointer refers to) are copied into a new FullName.
    ///
    /// Returns error.UnknownPointerOffset when the pointer does not land
    /// inside any name previously registered in this pool.
    pub fn transmuteName(self: *Self, name: dns.Name) !dns.Name {
        switch (name) {
            .full => {
                try self.held_names.append(name);
                return name;
            },
            .raw => |raw| {
                defer name.deinit(self.allocator);

                // this ends in a Pointer, create a new FullName
                var resolved_labels = std.ArrayList([]const u8).init(self.allocator);
                defer resolved_labels.deinit();
                // duplicated label strings are ours until moved into the result
                errdefer {
                    for (resolved_labels.items) |label| self.allocator.free(label);
                }

                for (raw.components) |raw_component| switch (raw_component) {
                    .Full => |text| try self.appendDupedLabel(&resolved_labels, text),
                    .Pointer => |packet_offset| {
                        const referenced_labels = self.labelsAtOffset(packet_offset) orelse {
                            logger.warn(
                                "unknown pointer offset: pointer has offset={d}",
                                .{packet_offset},
                            );

                            for (self.held_names.items) |held_name| {
                                logger.warn(
                                    "known name: {} at offset {?d}",
                                    .{ held_name, held_name.full.packet_index },
                                );
                            }

                            return error.UnknownPointerOffset;
                        };

                        for (referenced_labels) |referenced_label| {
                            try self.appendDupedLabel(&resolved_labels, referenced_label);
                        }
                    },
                    .Null => unreachable,
                };

                const full_name = dns.Name{ .full = dns.FullName{
                    .labels = try resolved_labels.toOwnedSlice(),
                    .packet_index = raw.packet_index,
                } };
                errdefer full_name.deinit(self.allocator);

                try self.held_names.append(full_name);
                return full_name;
            },
        }
    }

    /// Find the registered name whose wire encoding covers `packet_offset`
    /// and return its labels starting at the label that contains that offset.
    /// If several names match, the most recently registered one wins.
    fn labelsAtOffset(self: Self, packet_offset: u16) ?[]const []const u8 {
        var found: ?[]const []const u8 = null;

        for (self.held_names.items) |held_name| {
            const held = held_name.full;
            const start = held.packet_index orelse continue;
            if (packet_offset < start or packet_offset >= start + held.networkSize()) continue;

            // each label occupies its length octet plus its bytes; an offset
            // at the terminating null octet selects no labels at all
            var label_start = start;
            var first_label = held.labels.len;
            for (held.labels, 0..) |label, idx| {
                const next_start = label_start + 1 + label.len;
                if (packet_offset < next_start) {
                    first_label = idx;
                    break;
                }
                label_start = next_start;
            }

            found = held.labels[first_label..];
        }

        return found;
    }

    /// Append an owned copy of `label` to `list`, reserving space first so the
    /// copy is never leaked by a failed append.
    fn appendDupedLabel(self: Self, list: *std.ArrayList([]const u8), label: []const u8) !void {
        try list.ensureUnusedCapacity(1);
        list.appendAssumeCapacity(try self.allocator.dupe(u8, label));
    }

    /// given a dns.Question or dns.Resource, resolve pointers and return
    /// that same Question or Resource with a FullName inside of it.
    ///
    /// to be able to do this, ALL questions and resources must be registered
    /// in the NamePool.
    ///
    /// this takes ownership of the given resource, returning a new one with
    /// FullName set in the respective "name" union field.
    pub fn transmuteResource(self: *Self, resource: anytype) !@TypeOf(resource) {
        switch (@TypeOf(resource)) {
            dns.Question => {
                var new_question = resource;
                new_question.name = try self.transmuteName(resource.name.?);
                return new_question;
            },
            dns.Resource => {
                var new_resource = resource;
                new_resource.name = try self.transmuteName(resource.name.?);
                return new_resource;
            },
            else => @compileError("invalid type to resolve in name pool " ++ @typeName(@TypeOf(resource))),
        }
    }
};

test "localhost parses correctly" {
    var name_buffer: [128][]const u8 = undefined;
    const name = try dns.Name.fromString("localhost", &name_buffer);
    try std.testing.expectEqual(@as(usize, 1), name.full.labels.len);
    try std.testing.expectEqualStrings("localhost", name.full.labels[0]);

    const name2 = try dns.Name.fromString("localhost.", &name_buffer);
    try std.testing.expectEqual(@as(usize, 1), name2.full.labels.len);
    try std.testing.expectEqualStrings("localhost", name2.full.labels[0]);
}

test "NamePool resolves pointers into the middle of a name" {
    const allocator = std.testing.allocator;
    var pool = NamePool.init(allocator);
    defer pool.deinitWithNames();

    // "a.b.c.d.example.com" as if parsed at packet offset 12: label "b"
    // starts at 14 because every label is preceded by its length octet
    const texts = [_][]const u8{ "a", "b", "c", "d", "example", "com" };
    const labels = try allocator.alloc([]const u8, texts.len);
    for (labels, texts) |*label, text| label.* = try allocator.dupe(u8, text);
    _ = try pool.transmuteName(.{ .full = .{ .labels = labels, .packet_index = 12 } });

    // "www" followed by a compression pointer to offset 14
    const components = try allocator.alloc(LabelComponent, 2);
    components[0] = .{ .Full = try allocator.dupe(u8, "www") };
    components[1] = .{ .Pointer = 14 };
    const resolved = try pool.transmuteName(.{ .raw = .{ .components = components, .packet_index = 40 } });

    var buffer: [64]u8 = undefined;
    try std.testing.expectEqualStrings(
        "www.b.c.d.example.com.",
        try std.fmt.bufPrint(&buffer, "{}", .{resolved}),
    );
}

// --------------------------------------------------------------------------------
// /src/helpers.zig:
// --------------------------------------------------------------------------------
const std = @import("std");
const builtin = @import("builtin");
const ws2_32 = std.os.windows.ws2_32;
const dns = @import("lib.zig");

const CidrRange = @import("cidr.zig").CidrRange;

fn printList(
    name_pool: *dns.NamePool,
    writer: anytype,
    resource_list: []dns.Resource,
) !void {
    // TODO the formatting here is not good...
    try writer.print(";;name\t\t\trrtype\tclass\tttl\trdata\n", .{});

    for (resource_list) |resource| {
        const resource_data = try dns.ResourceData.fromOpaque(
            resource.typ,
            resource.opaque_rdata.?,
            .{
                .name_provider = .{ .full = name_pool },
                .allocator = name_pool.allocator,
            },
        );
        defer switch (resource_data) {
            .TXT => resource_data.deinit(name_pool.allocator),
            else => {}, // names are owned by given NamePool
        };

        try writer.print("{?}\t\t{s}\t{s}\t{d}\t{any}\n", .{
            resource.name.?,
            @tagName(resource.typ),
            @tagName(resource.class),
            resource.ttl,
            resource_data,
        });
    }

    try writer.print("\n", .{});
}

/// Print a packet in the format of a "zone file".
///
/// This will deserialize resourcedata in the resource sections, so
/// a NamePool instance is required.
///
/// This helper method will NOT free the memory created by name allocation,
/// you should do this manually in a defer block calling NamePool.deinitWithNames.
pub fn printAsZoneFile(
    packet: *dns.Packet,
    name_pool: *dns.NamePool,
    writer: anytype,
) !void {
    try writer.print(";; opcode: {}, status: {}, id: {}\n", .{
        packet.header.opcode,
        packet.header.response_code,
        packet.header.id,
    });

    try writer.print(";; QUERY: {}, ANSWER: {}, AUTHORITY: {}, ADDITIONAL: {}\n\n", .{
        packet.header.question_length,
        packet.header.answer_length,
        packet.header.nameserver_length,
        packet.header.additional_length,
    });

    // Questions carry no RDATA, so they are printed directly rather than
    // through printList. Nothing is printed when there are no questions.
    if (packet.header.question_length > 0) {
        try writer.print(";; QUESTION SECTION:\n", .{});
        try writer.print(";;name\ttype\tclass\n", .{});

        for (packet.questions) |question| {
            try writer.print(";{?}\t{s}\t{s}\n", .{
                question.name,
                @tagName(question.typ),
                @tagName(question.class),
            });
        }

        try writer.print("\n", .{});
    }

    // Every resource section is followed by a blank line, whether it has
    // records or only the ";; no ..." marker.
    try printSection(name_pool, writer, "ANSWER", "answer", packet.header.answer_length, packet.answers);
    try printSection(name_pool, writer, "AUTHORITY", "authority", packet.header.nameserver_length, packet.nameservers);
    try printSection(name_pool, writer, "ADDITIONAL", "additional", packet.header.additional_length, packet.additionals);
}

/// Print one resource section of a zone-file dump.
///
/// When `count` (the header's record count for this section) is non-zero,
/// print ";; <title> SECTION:" followed by `resources` via printList, which
/// ends with a blank line. Otherwise print ";; no <empty_label>" and a blank
/// line, so both cases leave the same separation before the next section.
fn printSection(
    name_pool: *dns.NamePool,
    writer: anytype,
    comptime title: []const u8,
    comptime empty_label: []const u8,
    count: usize,
    resources: []dns.Resource,
) !void {
    if (count > 0) {
        try writer.print(";; " ++ title ++ " SECTION:\n", .{});
        try printList(name_pool, writer, resources);
    } else {
        try writer.print(";; no " ++ empty_label ++ "\n\n", .{});
    }
}

/// Generate a random header ID to use in a query.
///
/// The ID is drawn from `std.crypto.random`, a cryptographically secure
/// generator, rather than from a PRNG seeded with the clock, so it cannot be
/// guessed from the time of the query. Unpredictable query IDs are one of
/// the defences against forged answers described in RFC 5452.
pub fn randomHeaderId() u16 {
    return std.crypto.random.int(u16);
}

/// High level wrapper around a single UDP connection to send and receive
/// DNS packets.
pub const DNSConnection = struct {
    /// Resolver address that `sendPacket` sends datagrams to.
    address: std.net.Address,
    /// UDP socket, wrapped in a Stream so replies can be `read` from it.
    socket: std.net.Stream,

    const Self = @This();

    /// Close the underlying socket.
    pub fn close(self: Self) void {
        self.socket.close();
    }

    /// Serialize `packet` and send it to `self.address` as one datagram.
    ///
    /// The packet is first written into a 1024-byte stack buffer, so a packet
    /// that does not fit cannot be sent (the write into the buffer fails).
    pub fn sendPacket(self: Self, packet: dns.Packet) !void {
        // Stream won't use sendto() when its UDP, so serialize it into
        // a buffer, and then send that
        var buffer: [1024]u8 = undefined;

        const typ = std.io.FixedBufferStream([]u8);
        var stream = typ{ .buffer = &buffer, .pos = 0 };

        const written_bytes = try packet.writeTo(stream.writer());

        const result = buffer[0..written_bytes];
        // sendto() needs the size of the concrete sockaddr variant.
        const dest_len: u32 = switch (self.address.any.family) {
            std.posix.AF.INET => @sizeOf(std.posix.sockaddr.in),
            std.posix.AF.INET6 => @sizeOf(std.posix.sockaddr.in6),
            else => unreachable,
        };

        _ = try std.posix.sendto(
            self.socket.handle,
            result,
            0,
            &self.address.any,
            dest_len,
        );
    }

    /// Deserializes and allocates an *entire* DNS packet.
    ///
    /// This function is not encouraged if you only wish to get A/AAAA
    /// records for a domain name through the system DNS resolver, as this
    /// allocates all the data of the packet. Use `receiveTrustedAddresses`
    /// for such.
    pub fn receiveFullPacket(
        self: Self,
        packet_allocator: std.mem.Allocator,
        /// Maximum size for the incoming UDP datagram
        comptime max_incoming_message_size: usize,
        options: ParseFullPacketOptions,
    ) !dns.IncomingPacket {
        var packet_buffer: [max_incoming_message_size]u8 = undefined;
        const read_bytes = try self.socket.read(&packet_buffer);
        const packet_bytes = packet_buffer[0..read_bytes];
        logger.debug("read {d} bytes", .{read_bytes});

        var stream = std.io.FixedBufferStream([]const u8){
            .buffer = packet_bytes,
            .pos = 0,
        };
        return parseFullPacket(stream.reader(), packet_allocator, options);
    }
};

pub const ParseFullPacketOptions = struct {
    /// Use this NamePool to let deserialization of names outlive the call
    /// to parseFullPacket.
    ///
    /// Useful if you need to parse RDATA sections after parseFullPacket.
    name_pool: ?*dns.NamePool = null,
};

/// Parse a whole DNS packet from `reader` into a heap-allocated dns.Packet.
///
/// The packet and its section slices are allocated with `allocator` and
/// returned wrapped in a dns.IncomingPacket. Names are resolved through
/// `parse_full_packet_options.name_pool` when given; otherwise a temporary
/// NamePool is used and deinitialized before returning.
pub fn parseFullPacket(
    reader: anytype,
    allocator: std.mem.Allocator,
    parse_full_packet_options: ParseFullPacketOptions,
) !dns.IncomingPacket {
    const parser_options = dns.ParserOptions{ .allocator = allocator };

    const packet = try allocator.create(dns.Packet);
    errdefer allocator.destroy(packet);
    const incoming_packet = dns.IncomingPacket{
        .allocator = allocator,
        .packet = packet,
    };

    var ctx = dns.ParserContext{};
    var parser = dns.parser(reader, &ctx, parser_options);

    var builtin_name_pool = dns.NamePool.init(allocator);
    defer builtin_name_pool.deinit();

    const name_pool = parse_full_packet_options.name_pool orelse &builtin_name_pool;

    var questions = std.ArrayList(dns.Question).init(allocator);
    defer questions.deinit();

    var answers = std.ArrayList(dns.Resource).init(allocator);
    defer answers.deinit();

    var nameservers = std.ArrayList(dns.Resource).init(allocator);
    defer nameservers.deinit();

    var additionals = std.ArrayList(dns.Resource).init(allocator);
    defer additionals.deinit();

    while (try parser.next()) |part| {
        switch (part) {
            .header => |header| packet.header = header,
            .question => |question_with_raw_names| {
                const question =
                    try name_pool.transmuteResource(question_with_raw_names);
                try questions.append(question);
            },
            .end_question => packet.questions = try questions.toOwnedSlice(),
            .answer, .nameserver, .additional => |raw_resource| {
                // since we give it an allocator, we don't receive rdata frames
                const resource = try name_pool.transmuteResource(raw_resource);
                // Append to the list of the section this record belongs to.
                const section: *std.ArrayList(dns.Resource) = switch (part) {
                    .answer => &answers,
                    .nameserver => &nameservers,
                    .additional => &additionals,
                    else => unreachable,
                };
                try section.append(resource);
            },
            .end_answer => packet.answers = try answers.toOwnedSlice(),
            .end_nameserver => packet.nameservers = try nameservers.toOwnedSlice(),
            .end_additional => packet.additionals = try additionals.toOwnedSlice(),
            .answer_rdata, .nameserver_rdata, .additional_rdata => unreachable,
        }
    }

    return incoming_packet;
}

const logger = std.log.scoped(.dns_helpers);

/// Open a socket to the DNS resolver specified in input parameter
///
/// `address` must be an IP address literal; `port` defaults to 53.
pub fn connectToResolver(address: []const u8, port: ?u16) !DNSConnection {
    const addr = blk: {
        if (builtin.os.tag == .windows) {
            // it is recommended to use `resolveIp`, but windows currently does
            // not support resolving ipv6 addresses. there is a PR for the
            // stdlib here: https://github.com/ziglang/zig/pull/22555 as soon
            // as that is merged, this can be removed. till then, as
            // `resolveIp` is only recommended in order to handle ipv6 link
            // local addresses, we can use `parseIp`, as the use case
            // `resolveIp` is intended for doesn't work.
            break :blk try std.net.Address.parseIp(address, port orelse 53);
        } else {
            break :blk try std.net.Address.resolveIp(address, port orelse 53);
        }
    };

    const flags: u32 = std.posix.SOCK.DGRAM;
    const fd = try std.posix.socket(addr.any.family, flags, std.posix.IPPROTO.UDP);

    return DNSConnection{
        .address = addr,
        .socket = std.net.Stream{ .handle = fd },
    };
}

/// Open a socket to a random DNS resolver declared in the systems'
/// "/etc/resolv.conf" file.
///
/// Returns error.NoNameserverConfigured when the file declares none.
pub fn connectToSystemResolver() !DNSConnection {
    if (builtin.os.tag != .linux) @compileError("connectToSystemResolver not supported on this target");

    var out_buffer: [256]u8 = undefined;
    const nameserver_address_string = (try randomNameserver(&out_buffer)) orelse
        return error.NoNameserverConfigured;

    return connectToResolver(nameserver_address_string, null);
}

/// Pick a uniformly random `nameserver` entry from /etc/resolv.conf and copy
/// its address into `output_buffer`.
///
/// Returns the filled prefix of `output_buffer`, or null when the file has no
/// `nameserver` line with an address. Returns error.NoSpaceLeft when the
/// chosen address does not fit in `output_buffer`.
pub fn randomNameserver(output_buffer: []u8) !?[]const u8 {
    var file = try std.fs.cwd().openFile(
        "/etc/resolv.conf",
        .{ .mode = .read_only },
    );
    defer file.close();

    // iterate through all lines to find the amount of nameservers, then select
    // a random one, then read AGAIN so that we can return it.
    //
    // this doesn't need any allocator or lists or whatever. just the
    // output buffer

    var line_buffer: [1024]u8 = undefined;

    try file.seekTo(0);
    var nameserver_amount: usize = 0;
    while (try file.reader().readUntilDelimiterOrEof(&line_buffer, '\n')) |line| {
        if (nameserverAddress(line) != null) nameserver_amount += 1;
    }

    // uintLessThan requires a non-zero bound.
    if (nameserver_amount == 0) return null;
    const selected = std.crypto.random.uintLessThan(usize, nameserver_amount);

    try file.seekTo(0);
    var current_nameserver: usize = 0;
    while (try file.reader().readUntilDelimiterOrEof(&line_buffer, '\n')) |line| {
        const nameserver_addr = nameserverAddress(line) orelse continue;

        if (current_nameserver == selected) {
            if (nameserver_addr.len > output_buffer.len) return error.NoSpaceLeft;
            @memcpy(output_buffer[0..nameserver_addr.len], nameserver_addr);
            return output_buffer[0..nameserver_addr.len];
        }

        current_nameserver += 1;
    }

    // Reached only if the file lost entries between the two passes.
    return null;
}

/// Return the address from a resolv.conf `nameserver` line, or null when the
/// line is a comment, another directive, or a `nameserver` line with no
/// address. Spaces, tabs and a trailing CR all separate fields.
fn nameserverAddress(line: []const u8) ?[]const u8 {
    var fields = std.mem.tokenizeAny(u8, line, " \t\r");
    // Comment lines start with '#' or ';', so their first field never equals
    // "nameserver".
    const keyword = fields.next() orelse return null;
    if (!std.mem.eql(u8, keyword, "nameserver")) return null;
    return fields.next();
}

/// Owned list of addresses returned by getAddressList.
const AddressList = struct {
    allocator: std.mem.Allocator,
    addrs: []std.net.Address,

    /// Free the address slice with the allocator it was created with.
    pub fn deinit(self: @This()) void {
        self.allocator.free(self.addrs);
    }

    /// Take ownership of the items in `addrs`, leaving the list empty.
    fn fromList(allocator: std.mem.Allocator, addrs: *std.ArrayList(std.net.Address)) !AddressList {
        return AddressList{ .allocator = allocator, .addrs = try addrs.toOwnedSlice() };
    }
};

const ReceiveTrustedAddressesOptions = struct {
    /// Size of the stack buffer the reply datagram is read into.
    max_incoming_message_size: usize = 4096,
    /// When set, a reply whose ID differs from this header's ID is rejected
    /// with error.InvalidReply.
    requested_packet_header: ?dns.Header = null,
};

/// This is an optimized deserializer that is only interested in A and AAAA
/// answers, returning a list of std.net.Address.
///
/// This function trusts the DNS connection to be returning answers related
/// to the given domain sent through DNSConnection.sendPacket.
///
/// This, however, does not allocate the packet. It is very memory efficient
/// in that regard. The returned addresses have port 0, and the caller owns
/// the returned slice (allocated with `allocator`).
pub fn receiveTrustedAddresses(
    allocator: std.mem.Allocator,
    connection: *const DNSConnection,
    /// Options to receive message and deserialize it
    comptime options: ReceiveTrustedAddressesOptions,
) ![]std.net.Address {
    var packet_buffer: [options.max_incoming_message_size]u8 = undefined;
    const read_bytes = try connection.socket.read(&packet_buffer);
    const packet_bytes = packet_buffer[0..read_bytes];
    logger.debug("read {d} bytes", .{read_bytes});

    var stream = std.io.FixedBufferStream([]const u8){
        .buffer = packet_bytes,
        .pos = 0,
    };

    var ctx = dns.ParserContext{};

    var parser = dns.parser(stream.reader(), &ctx, .{});

    var addrs = std.ArrayList(std.net.Address).init(allocator);
    errdefer addrs.deinit();

    // The answer record whose RDATA the parser yields next.
    var current_resource: ?dns.Resource = null;

    while (try parser.next()) |part| {
        switch (part) {
            .header => |header| {
                if (options.requested_packet_header) |given_header| {
                    if (given_header.id != header.id)
                        return error.InvalidReply;
                }

                if (!header.is_response) return error.InvalidResponse;

                switch (header.response_code) {
                    .NoError => {},
                    .FormatError => return error.ServerFormatError, // bug in implementation caught by server?
                    .ServerFailure => return error.ServerFailure,
                    .NameError => return error.ServerNameError,
                    .NotImplemented => return error.ServerNotImplemented,
                    .Refused => return error.ServerRefused,
                }
            },
            .answer => |raw_resource| {
                current_resource = raw_resource;
            },

            .answer_rdata => |rdata| {
                // TODO parser.reader()?
                var reader = parser.wrapper_reader.reader();
                defer current_resource = null;

                const resource_type = current_resource.?.typ;
                const maybe_addr: ?std.net.Address = blk: {
                    if (resource_type == .A and rdata.size == 4) {
                        var ip4_addr: [4]u8 = undefined;
                        try reader.readNoEof(&ip4_addr);
                        break :blk std.net.Address.initIp4(ip4_addr, 0);
                    }
                    if (resource_type == .AAAA and rdata.size == 16) {
                        var ip6_addr: [16]u8 = undefined;
                        try reader.readNoEof(&ip6_addr);
                        break :blk std.net.Address.initIp6(ip6_addr, 0, 0, 0);
                    }
                    // Any other record (or A/AAAA RDATA of the wrong length)
                    // is skipped whole so the parser stays aligned with the
                    // next record.
                    try reader.skipBytes(rdata.size, .{});
                    break :blk null;
                };

                if (maybe_addr) |addr| try addrs.append(addr);
            },
            else => {},
        }
    }

    return try addrs.toOwnedSlice();
}

/// Send a recursive `qtype` query for `name` to a random nameserver from
/// /etc/resolv.conf and return the A/AAAA addresses in the answer section.
/// Caller owns the returned slice (allocated with `allocator`).
///
/// NOTE(review): the reply's ID is not compared with the query's, because
/// `receiveTrustedAddresses` only takes the expected header as a comptime
/// option while the ID is generated at runtime.
fn fetchTrustedAddresses(
    allocator: std.mem.Allocator,
    name: dns.Name,
    qtype: dns.ResourceType,
) ![]std.net.Address {
    var questions = [_]dns.Question{
        .{
            .name = name,
            .typ = qtype,
            .class = .IN,
        },
    };

    const packet = dns.Packet{
        .header = .{
            .id = dns.helpers.randomHeaderId(),
            .is_response = false,
            .wanted_recursion = true,
            .question_length = 1,
        },
        .questions = &questions,
        .answers = &[_]dns.Resource{},
        .nameservers = &[_]dns.Resource{},
        .additionals = &[_]dns.Resource{},
    };

    const conn = try dns.helpers.connectToSystemResolver();
    defer conn.close();

    logger.debug("selected nameserver: {}", .{conn.address});
    try conn.sendPacket(packet);
    return try receiveTrustedAddresses(allocator, &conn, .{});
}

// implementation taken from std.net address resolution
/// Append to `addrs` every /etc/hosts address whose entry lists `name`,
/// parsed as `family` with `port` applied. A missing or inaccessible hosts
/// file is treated as having no entries.
fn lookupHosts(addrs: *std.ArrayList(std.net.Address), family: std.posix.sa_family_t, port: u16, name: []const u8) !void {
    const file = std.fs.openFileAbsoluteZ("/etc/hosts", .{}) catch |err| switch (err) {
        error.FileNotFound,
        error.NotDir,
        error.AccessDenied,
        => return,
        else => |e| return e,
    };
    defer file.close();

    var buffered_reader = std.io.bufferedReader(file.reader());
    const reader = buffered_reader.reader();
    var line_buf: [512]u8 = undefined;
    while (reader.readUntilDelimiterOrEof(&line_buf, '\n') catch |err| switch (err) {
        error.StreamTooLong => blk: {
            // Skip to the delimiter in the reader, to fix parsing
            try reader.skipUntilDelimiterOrEof('\n');
            // Use the truncated line. A truncated comment or hostname will be handled correctly.
            break :blk &line_buf;
        },
        else => |e| return e,
    }) |line| {
        // Drop any trailing comment, then split "<ip> <name> <aliases...>".
        var split_it = std.mem.splitScalar(u8, line, '#');
        const no_comment_line = split_it.first();

        var line_it = std.mem.tokenizeAny(u8, no_comment_line, " \t");
        const ip_text = line_it.next() orelse continue;
        const lists_name = while (line_it.next()) |name_text| {
            if (std.mem.eql(u8, name_text, name)) break true;
        } else false;
        if (!lists_name) continue;

        // Entries of the other address family fail to parse and are skipped.
        const addr = std.net.Address.parseExpectingFamily(ip_text, family, port) catch |err| switch (err) {
            error.Overflow,
            error.InvalidEnd,
            error.InvalidCharacter,
            error.Incomplete,
            error.InvalidIPAddressFormat,
            error.InvalidIpv4Mapping,
            error.NonCanonical,
            => continue,
        };
        try addrs.append(addr);
    }
}

/// A getAddressList-like function that resolves `incoming_name`:
/// - IPv4/IPv6 literals are parsed directly;
/// - names whose last label is "localhost" map to 127.0.0.1 and ::1 (RFC 6761);
/// - on Windows, the system resolver (getaddrinfo) is queried;
/// - on Linux, /etc/hosts is consulted first and, if it yields nothing, a
///   nameserver from /etc/resolv.conf is asked for A and AAAA records.
///
/// The returned AddressList is the only allocation that outlives the call;
/// free it with `AddressList.deinit`.
///
/// This function does not implement the "happy eyeballs" algorithm.
pub fn getAddressList(incoming_name: []const u8, port: u16, allocator: std.mem.Allocator) !AddressList {
    var name_buffer: [128][]const u8 = undefined;
    const name = try dns.Name.fromString(incoming_name, &name_buffer);

    var final_list = std.ArrayList(std.net.Address).init(allocator);
    defer final_list.deinit();

    // Guard against a name with no labels so the index cannot underflow.
    const labels = name.full.labels;
    const last_label: []const u8 = if (labels.len > 0) labels[labels.len - 1] else "";

    // see if we can short-circuit on parsing the name as addr
    if (std.net.Address.parseExpectingFamily(incoming_name, std.posix.AF.INET, port) catch null) |addr| {
        try final_list.append(addr);
    } else if (std.net.Address.parseExpectingFamily(incoming_name, std.posix.AF.INET6, port) catch null) |addr| {
        try final_list.append(addr);
    } else if (std.mem.eql(u8, last_label, "localhost")) {
        // RFC 6761 Section 6.3.3
        // Name resolution APIs and libraries SHOULD recognize localhost
        // names as special and SHOULD always return the IP loopback address
        // for address queries and negative responses for all other query
        // types.
        try final_list.append(std.net.Address.parseIp4("127.0.0.1", port) catch unreachable);
        try final_list.append(std.net.Address.parseIp6("::1", port) catch unreachable);
    } else {
        if (builtin.os.tag == .windows) {
            const name_c = try allocator.dupeZ(u8, incoming_name);
            defer allocator.free(name_c);

            const port_c = try std.fmt.allocPrintZ(allocator, "{}", .{port});
            defer allocator.free(port_c);

            var addr_info: ?*ws2_32.addrinfoa = null;

            const hints: ws2_32.addrinfo = .{
                .flags = .{ .NUMERICSERV = true },
                .family = ws2_32.AF.UNSPEC,
                .socktype = ws2_32.SOCK.STREAM,
                .protocol = ws2_32.IPPROTO.TCP,
                .addr = null,
                .canonname = null,
                .addrlen = 0,
                .next = null,
            };

            // Two attempts: the second runs only after initialising Winsock
            // in response to WSANOTINITIALISED.
            for (0..2) |_| {
                const res = ws2_32.getaddrinfo(name_c.ptr, port_c.ptr, &hints, &addr_info);

                if (res != 0) {
                    switch (@as(ws2_32.WinsockError, @enumFromInt(res))) {
                        .WSATRY_AGAIN => return error.TryAgain,
                        .WSAEINVAL => return error.InvalidArgument,
                        .WSANO_RECOVERY => return error.Fatal,
                        .WSAEAFNOSUPPORT => return error.FamilyNotSupported,
                        .WSA_NOT_ENOUGH_MEMORY => return error.NotEnoughMemory,
                        .WSAHOST_NOT_FOUND => return error.HostNotFound,
                        .WSATYPE_NOT_FOUND => return error.TypeNotFound,
                        .WSAESOCKTNOSUPPORT => return error.SocketTypeNotSupported,
                        .WSANOTINITIALISED => {
                            try std.os.windows.callWSAStartup();
                            continue;
                        },
                        else => return error.InternalUnexpected,
                    }
                } else break;
            } else return error.InternalUnexpected;

            // Free the head of the list. The loop below walks it with its own
            // cursor, so this deferred call still sees the original pointer
            // rather than the null left behind when iteration finishes.
            defer ws2_32.freeaddrinfo(addr_info);

            var cursor = addr_info;
            while (cursor) |ai| : (cursor = ai.next) {
                switch (ai.family) {
                    ws2_32.AF.INET => {
                        const sa: *ws2_32.sockaddr.in = @as(
                            *ws2_32.sockaddr.in,
                            @ptrCast(@alignCast(ai.addr orelse continue)),
                        );
                        // sa.port is in network byte order; initIp4 expects host order.
                        const addr = std.net.Address.initIp4(
                            @as([4]u8, @bitCast(sa.addr)),
                            std.mem.bigToNative(u16, sa.port),
                        );

                        try final_list.append(addr);
                    },
                    ws2_32.AF.INET6 => {
                        const sa: *ws2_32.sockaddr.in6 = @as(
                            *ws2_32.sockaddr.in6,
                            @ptrCast(@alignCast(ai.addr orelse continue)),
                        );
                        // sa.port is in network byte order; initIp6 expects host order.
                        const addr = std.net.Address.initIp6(
                            sa.addr,
                            std.mem.bigToNative(u16, sa.port),
                            0,
                            0,
                        );

                        try final_list.append(addr);
                    },
                    else => continue,
                }
            }
        } else if (builtin.os.tag == .linux) {
            // Look the name up in /etc/hosts for both address families.
            try lookupHosts(&final_list, std.posix.AF.INET, port, incoming_name);
            try lookupHosts(&final_list, std.posix.AF.INET6, port, incoming_name);

            if (final_list.items.len == 0) {
                // if that didn't work, go to dns server. Its addresses come
                // back with port 0, so the requested port is applied here.
                const addrs_v4 = try fetchTrustedAddresses(allocator, name, .A);
                defer allocator.free(addrs_v4);
                for (addrs_v4) |addr| try final_list.append(withPort(addr, port));

                const addrs_v6 = try fetchTrustedAddresses(allocator, name, .AAAA);
                defer allocator.free(addrs_v6);
                for (addrs_v6) |addr| try final_list.append(withPort(addr, port));
            }
        } else @compileError("getAddressList not supported on this target");
    }

    // RFC 6724 ordering is skipped for a single address or an all-IPv4 list.
    if (final_list.items.len == 1) return AddressList.fromList(allocator, &final_list);
    const all_ip4 = for (final_list.items) |addr| {
        if (addr.any.family != std.posix.AF.INET) break false;
    } else true;
    if (all_ip4) return AddressList.fromList(allocator, &final_list);

    std.mem.sort(std.net.Address, final_list.items, {}, addrCmpLessThan);

    return AddressList.fromList(allocator, &final_list);
}

/// Return a copy of `addr` with its port set to `port` (host byte order).
fn withPort(addr: std.net.Address, port: u16) std.net.Address {
    var result = addr;
    result.setPort(port);
    return result;
}

/// One row of the RFC 6724 policy table: destination addresses inside `cidr`
/// get the given precedence and label.
const Policy = struct {
    cidr: CidrRange,
    precedence: usize,
    label: usize,

    pub fn new(cidr: CidrRange, precedence: usize, label: usize) @This() {
        return .{ .cidr = cidr, .precedence = precedence, .label = label };
    }
};

// Default policy table from RFC 6724 Section 2.1
// Entries run from the longest prefix to the shortest. cmpGetPrecedence
// returns the first entry that matches, so the catch-all ::/0 must come last;
// anywhere earlier it would shadow every entry after it.
const policy_table = [_]Policy{
    Policy.new(CidrRange.parse("::1/128") catch unreachable, 50, 0), // Loopback
    Policy.new(CidrRange.parse("::ffff:0:0/96") catch unreachable, 35, 4), // IPv4-mapped
    Policy.new(CidrRange.parse("::/96") catch unreachable, 1, 3), // IPv4-compatible
    Policy.new(CidrRange.parse("2001::/32") catch unreachable, 5, 5), // Teredo
    Policy.new(CidrRange.parse("2002::/16") catch unreachable, 30, 2), // 6to4
    Policy.new(CidrRange.parse("fc00::/7") catch unreachable, 3, 13), // ULA
    Policy.new(CidrRange.parse("::/0") catch unreachable, 40, 1), // Default
};

/// Return `addr` unchanged unless it is IPv4. An IPv4 address is converted
/// to its IPv4-mapped IPv6 form (::ffff:a.b.c.d), which is how RFC 6724
/// compares IPv4 destinations. The port is not carried over.
fn asIp6(addr: std.net.Address) std.net.Address {
    if (addr.any.family != std.posix.AF.INET) return addr;
    var bytes = [_]u8{0} ** 16;
    bytes[10] = 0xff;
    bytes[11] = 0xff;
    // sa.addr holds the IPv4 address in network byte order.
    @memcpy(bytes[12..], std.mem.asBytes(&addr.in.sa.addr));
    return std.net.Address.initIp6(bytes, 0, 0, 0);
}

/// Precedence of `addr` from the first matching `policy_table` entry.
/// IPv4 addresses are looked up in their IPv4-mapped form.
fn cmpGetPrecedence(addr: std.net.Address) usize {
    const addr6 = asIp6(addr);
    for (policy_table) |policy| {
        // An entry that cannot be evaluated for this address counts as no match.
        if (policy.cidr.contains(addr6) catch false) {
            return policy.precedence;
        }
    }
    return 40; // Default precedence if no match
}

/// ff00::/8. Expects an IPv6 address.
fn isMulticast(a: std.net.Address) bool {
    return a.in6.sa.addr[0] == 0xff;
}

/// fe80::/10. Expects an IPv6 address.
fn isLinklocal(a: std.net.Address) bool {
    return a.in6.sa.addr[0] == 0xfe and (a.in6.sa.addr[1] & 0xc0) == 0x80;
}

/// Exactly ::1: bytes 0-14 zero and byte 15 equal to 1. Expects an IPv6 address.
fn isLoopback(a: std.net.Address) bool {
    const bytes = a.in6.sa.addr;
    return std.mem.allEqual(u8, bytes[0..15], 0) and bytes[15] == 1;
}

/// fec0::/10 (deprecated site-local). Expects an IPv6 address.
fn isSitelocal(a: std.net.Address) bool {
    return a.in6.sa.addr[0] == 0xfe and (a.in6.sa.addr[1] & 0xc0) == 0xc0;
}

/// Scope value used by RFC 6724 Rule 8. A smaller value means a narrower
/// scope. IPv4 addresses are evaluated in their IPv4-mapped form.
fn cmpGetScope(addr: std.net.Address) usize {
    const addr6 = asIp6(addr);
    if (isMulticast(addr6)) return addr6.in6.sa.addr[1] & 15;
    if (isLinklocal(addr6) or isLoopback(addr6)) return 2;
    if (isSitelocal(addr6)) return 5;
    return 14;
}

/// Return true when `b` should be ordered before `a`.
///
/// Only RFC 6724 Section 6 Rules 6, 8 and 10 are applied; the other rules are
/// not implemented here.
fn cmpAddresses(a: std.net.Address, b: std.net.Address) bool {
    // Rule 6: Prefer higher precedence
    const prec_a = cmpGetPrecedence(a);
    const prec_b = cmpGetPrecedence(b);
    if (prec_a != prec_b) return prec_b > prec_a;

    // Rule 8: Prefer smaller scope
    const scope_a = cmpGetScope(a);
    const scope_b = cmpGetScope(b);
    if (scope_a != scope_b) return scope_b < scope_a;

    // Rule 10: Otherwise, leave order unchanged (std.mem.sort is stable)
    return false;
}

/// `std.mem.sort` comparator: true when `lhs` should come before `rhs`.
fn addrCmpLessThan(context: void, lhs: std.net.Address, rhs: std.net.Address) bool {
    _ = context;
    return cmpAddresses(rhs, lhs);
}