├── .gitignore ├── README.md ├── build.zig ├── protobuf.zig └── test.zig /.gitignore: -------------------------------------------------------------------------------- 1 | zig-cache/ 2 | zig-out/ 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # zigpb 2 | 3 | This is a simple Protobuf encoding and decoding library written in Zig. It aims to support Protobuf 3, but also supports custom defaults, meaning any Protobuf 2 4 | message which does not use the deprecated groups feature can also be used. 5 | 6 | Details of the API are currently subject to change. Some planned changes are as follows: 7 | 8 | - Add a comptime parser for `.proto` files 9 | -------------------------------------------------------------------------------- /build.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | 3 | pub fn build(b: *std.build.Builder) void { 4 | // Standard target options allow the person running `zig build` to choose 5 | // what target to build for. Here we do not override the defaults, which 6 | // means any target is allowed, and the default is native. Other options 7 | // for restricting the supported target set are available. 8 | const target = b.standardTargetOptions(.{}); 9 | 10 | // Standard release options allow the person running `zig build` to select 11 | // between Debug, ReleaseSafe, ReleaseFast, and ReleaseSmall. 
12 | const mode = b.standardReleaseOptions(); 13 | 14 | const exe_tests = b.addTest("test.zig"); 15 | exe_tests.setTarget(target); 16 | exe_tests.setBuildMode(mode); 17 | 18 | const test_step = b.step("test", "Run unit tests"); 19 | test_step.dependOn(&exe_tests.step); 20 | } 21 | -------------------------------------------------------------------------------- /protobuf.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | 3 | /// Describes how a single struct field is encoded. For 'repeated' fields, this specifies whether 4 | /// the values are packed; otherwise it just tells us the representation of the actual type (e.g. 5 | /// to differentiate between u32 and fixed32). 6 | /// If modifying this, also look at createFieldEncoding! 7 | const FieldEncoding = union(enum) { 8 | default, // bool; float/double (f32/f64); enum; submessage (struct with pb_desc) 9 | fixed, // [s]fixed[32/64] 10 | varint, // [u]int[32/64] 11 | zigzag, // sint[32/64] 12 | string, // string 13 | repeat: *const FieldEncoding, // repeated, one tagged record per element (child encoding) 14 | repeat_pack: *const FieldEncoding, // packed repeated, one length-delimited record (child encoding) 15 | bytes, // bytes 16 | map: [*]const FieldEncoding, // map; points at two encodings: [0] is the key, [1] is the value 17 | }; 18 | 19 | /// A descriptor for a single field, giving its field number and encoding. These should be stored in 20 | /// a 'pb_desc' decl on the message struct. 'oneof' values, represented as optional tagged unions, 21 | /// are the only field type which should not have a corresponding descriptor index, but they must 22 | /// contain their own 'pb_desc' decl describing the fields within them. 23 | const FieldDescriptor = struct { 24 | field_num: u29, 25 | encoding: FieldEncoding, 26 | }; 27 | 28 | /// Convenience wrapper for constructing protobuf maps. 
/// Keys are hashed deeply (slice contents rather than pointer values) and slice keys compare by
/// content, so e.g. `[]const u8` works as a key type.
pub fn Map(comptime K: type, comptime V: type) type {
    const Context = struct {
        const Self = @This();

        /// Wyhash (seed 0) over the key, following slices into their elements.
        pub fn hash(_: Self, key: K) u64 {
            var state = std.hash.Wyhash.init(0);
            std.hash.autoHashStrat(&state, key, .Deep);
            return state.final();
        }

        /// Slice keys are equal when their elements are; every other key type uses `==`.
        pub fn eql(_: Self, a: K, b: K) bool {
            if (comptime std.meta.trait.isSlice(K)) {
                return std.mem.eql(std.meta.Child(K), a, b);
            } else {
                return a == b;
            }
        }
    };

    return std.HashMapUnmanaged(K, V, Context, std.hash_map.default_max_load_percentage);
}

/// The wire types defined by Protobuf; every encoded record carries one of these in its tag.
const WireType = enum(u3) {
    varint = 0,
    i64 = 1,
    len = 2,
    sgroup = 3, // DEPRECATED
    egroup = 4, // DEPRECATED
    i32 = 5,
};

/// Write 'val' to the writer as an LEB128 varint: seven payload bits per byte, least significant
/// group first, with the high bit set on every byte except the last.
fn encodeVarInt(w: anytype, val: u64) !void {
    var rest = val;
    while (true) {
        const low: u8 = @truncate(u7, rest);
        rest >>= 7;
        if (rest == 0) {
            // Final group: continuation bit clear. This also covers val == 0.
            try w.writeByte(low);
            return;
        }
        try w.writeByte(low | 0x80);
    }
}

/// Write a field tag: the field number shifted left by three bits, or'ed with the wire type,
/// emitted as a varint.
fn encodeTag(w: anytype, field_num: u29, wire_type: WireType) !void {
    const key: u32 = (@as(u32, field_num) << 3) | @enumToInt(wire_type);
    try encodeVarInt(w, key);
}

/// Encode 'val' of scalar type (integer, float, bool, string, or bytes) with field descriptor
/// 'desc' into the given writer. If 'encode_default' is false, the field will be omitted if it
/// corresponds to its type's default value. If 'include_tag' is false, the field's tag is not
/// included in the output.
/// Supported types are bool, u32/u64/i32/i64, f32/f64, []u8/[]const u8 and enums (encoded as the
/// varint of their integer value). 'override_default', when non-null, replaces the zero value as
/// the default that is omitted when 'encode_default' is false; it cannot be used with []u8.
fn encodeSingleScalar(w: anytype, val: anytype, comptime desc: FieldDescriptor, comptime encode_default: bool, comptime override_default: ?@TypeOf(val), comptime include_tag: bool) !void {
    const T = @TypeOf(val);

    if (@typeInfo(T) == .Enum) {
        if (desc.encoding != .default) @compileError("Enum types must use FieldEncoding.default");
        const Tag = @typeInfo(T).Enum.tag_type;
        if (@bitSizeOf(Tag) > 32) @compileError("Enum types must have a tag type of no more than 32 bits");
        // Widen the tag to a 32-bit integer of matching signedness and encode that as a varint.
        const Tag32 = if (@typeInfo(Tag).Int.signedness == .signed) i32 else u32;
        const ival: Tag32 = @enumToInt(val);
        return encodeSingleScalar(
            w,
            ival,
            .{ .field_num = desc.field_num, .encoding = .varint },
            encode_default,
            if (override_default) |x| @enumToInt(x) else null,
            include_tag,
        );
    }

    switch (T) {
        bool => {
            if (desc.encoding != .default) @compileError("Boolean types must use FieldEncoding.default");
            if (!encode_default and val == (override_default orelse false)) return;
            if (include_tag) try encodeTag(w, desc.field_num, .varint);
            try w.writeByte(@boolToInt(val));
        },

        u32, u64, i32, i64 => {
            if (!encode_default and val == (override_default orelse 0)) return;
            switch (desc.encoding) {
                .fixed => {
                    if (include_tag) try encodeTag(w, desc.field_num, switch (T) {
                        u32, i32 => .i32,
                        u64, i64 => .i64,
                        else => unreachable,
                    });
                    try w.writeIntLittle(T, val);
                },
                .varint => {
                    if (include_tag) try encodeTag(w, desc.field_num, .varint);
                    const val64: u64 = switch (T) {
                        u32, u64 => val,
                        i32 => @bitCast(u64, @as(i64, val)), // sign-extend
                        i64 => @bitCast(u64, val),
                        else => unreachable,
                    };
                    try encodeVarInt(w, val64);
                },
                .zigzag => {
                    if (@typeInfo(T).Int.signedness != .signed) @compileError("Only signed integral types can use FieldEncoding.zigzag");
                    if (include_tag) try encodeTag(w, desc.field_num, .varint);
                    // Canonical zigzag: (n << 1) ^ (n >> 63), computed on the 64-bit pattern.
                    // Unlike negating the value, this cannot overflow for minInt(i32)/minInt(i64).
                    const wide: i64 = val;
                    const sign_mask = @bitCast(u64, wide >> 63); // all ones when negative, else zero
                    try encodeVarInt(w, (@bitCast(u64, wide) << 1) ^ sign_mask);
                },
                else => @compileError("Integral types must use FieldEncoding.fixed, FieldEncoding.varint, or FieldEncoding.zigzag"),
            }
            return;
        },

        f32, f64 => {
            if (desc.encoding != .default) @compileError("Floating types must use FieldEncoding.default");
            if (!encode_default and val == (override_default orelse 0)) return;
            // Floats are written as their raw IEEE bit pattern, little-endian.
            if (T == f32) {
                if (include_tag) try encodeTag(w, desc.field_num, .i32);
                try w.writeIntLittle(u32, @bitCast(u32, val));
            } else {
                if (include_tag) try encodeTag(w, desc.field_num, .i64);
                try w.writeIntLittle(u64, @bitCast(u64, val));
            }
        },

        []u8, []const u8 => {
            if (override_default != null) @compileError("Cannot override default for []u8");
            if (!encode_default and val.len == 0) return;
            switch (desc.encoding) {
                .string, .bytes => {
                    // Length-delimited: byte count as a varint, then the raw bytes.
                    if (include_tag) try encodeTag(w, desc.field_num, .len);
                    try encodeVarInt(w, val.len);
                    try w.writeAll(val);
                },
                else => @compileError("[]u8 must use FieldEncoding.string or FieldEncoding.bytes"),
            }
        },

        else => @compileError("Type '" ++ @typeName(T) ++ "' cannot be encoded as a primitive"),
    }
}

/// Encode a single value of scalar or submessage type. 'map's are not included here since
/// they're sugar for a 'repeated' submessage (and cannot themselves be repeated), meaning they are
/// really multiple values.
/// Struct values are encoded as a length-delimited submessage (always emitted, using 'ally' for a
/// temporary buffer); everything else is forwarded to encodeSingleScalar with its tag included.
fn encodeSingleValue(w: anytype, ally: std.mem.Allocator, val: anytype, comptime desc: FieldDescriptor, comptime encode_default: bool, comptime override_default: ?@TypeOf(val)) !void {
    const T = @TypeOf(val);

    if (@typeInfo(T) == .Struct) {
        if (desc.encoding != .default) @compileError("Sub-messages must use FieldEncoding.default");

        // The payload length must precede the payload, so encode into a scratch buffer first.
        var buf = std.ArrayList(u8).init(ally);
        defer buf.deinit();

        try encodeMessage(buf.writer(), ally, val);

        try encodeTag(w, desc.field_num, .len);
        try encodeVarInt(w, buf.items.len);
        try w.writeAll(buf.items);
    } else {
        try encodeSingleScalar(w, val, desc, encode_default, override_default, true);
    }
}

/// Encode the field 'val' with 'desc_opt' as its descriptor (null if none exists) into the given
/// writer. 'field_name' is used only for error messages. 'field_default' is the struct field's
/// declared default, if any; it is only consulted for plain (non-optional, non-repeated) fields.
///
/// Dispatch, in order: optional union (oneof), 'repeat', 'repeat_pack', 'map', optional, plain.
/// Repeated fields may be either a list type exposing '.items' or a plain slice.
fn encodeAnyField(
    w: anytype,
    ally: std.mem.Allocator,
    val: anytype,
    comptime desc_opt: ?FieldDescriptor,
    comptime field_name: []const u8,
    comptime field_default: ?@TypeOf(val),
) !void {
    const T = @TypeOf(val);

    // Nicer error message if you forgot to make your union optional
    if (@typeInfo(T) == .Union) {
        @compileError("Only optional unions can be encoded");
    }

    if (@typeInfo(T) == .Optional and
        @typeInfo(std.meta.Child(T)) == .Union)
    {
        // oneof
        const U = std.meta.Child(T);
        if (desc_opt != null) @compileError("Union '" ++ field_name ++ "' must not have a field descriptor");
        if (val) |un| {
            const pb_desc = comptime getPbDesc(U) orelse @compileError("Union '" ++ @typeName(U) ++ "' must have a pb_desc decl");
            switch (un) {
                inline else => |payload, tag| {
                    const sub_desc = comptime pb_desc.getField(@tagName(tag)) orelse
                        @compileError("Missing descriptor for field '" ++ @typeName(U) ++ "." ++ @tagName(tag) ++ "'");

                    // The active member is always written, even when it holds a default value.
                    try encodeSingleValue(w, ally, payload, sub_desc, true, null);
                },
            }
        }

        return;
    }

    const desc = desc_opt orelse @compileError("Missing descriptor for field '" ++ field_name ++ "'");

    if (desc.encoding == .repeat) {
        // One tagged record per element. Both branches read '.items' for list types so that
        // packed and non-packed repeated fields accept the same container.
        const items = if (comptime std.meta.trait.isSlice(T)) val else val.items;
        for (items) |x| {
            try encodeSingleValue(w, ally, x, .{
                .field_num = desc.field_num,
                .encoding = desc.encoding.repeat.*,
            }, true, null);
        }
    } else if (desc.encoding == .repeat_pack) {
        const items = if (comptime std.meta.trait.isSlice(T)) val else val.items;
        // A packed field with no elements does not appear in the encoded message.
        if (items.len == 0) return;

        // All elements go untagged into one length-delimited record.
        var buf = std.ArrayList(u8).init(ally);
        defer buf.deinit();

        for (items) |x| {
            try encodeSingleScalar(buf.writer(), x, .{
                .field_num = desc.field_num,
                .encoding = desc.encoding.repeat_pack.*,
            }, true, null, false);
        }

        try encodeTag(w, desc.field_num, .len);
        try encodeVarInt(w, buf.items.len);
        try w.writeAll(buf.items);
    } else if (desc.encoding == .map) {
        // Each entry is written as a submessage with the key as field 1 and the value as field 2.
        var it = val.iterator();
        while (it.next()) |pair| {
            try encodeSingleValue(w, ally, struct {
                k: std.meta.FieldType(T.KV, .key),
                v: std.meta.FieldType(T.KV, .value),
                const pb_desc = .{
                    .k = .{ 1, desc.encoding.map[0] },
                    .v = .{ 2, desc.encoding.map[1] },
                };
            }{ .k = pair.key_ptr.*, .v = pair.value_ptr.* }, .{
                .field_num = desc.field_num,
                .encoding = .default,
            }, true, null);
        }
    } else if (@typeInfo(T) == .Optional) {
        // Present optionals are always written, so an explicit default survives a round trip.
        if (val) |x| {
            try encodeSingleValue(w, ally, x, desc, true, null);
        }
    } else {
        try encodeSingleValue(w, ally, val, desc, false, field_default);
    }
}

/// Encode an entire Protobuf message 'msg' into the given writer. Only temporary allocations are
/// performed, all of which are cleaned up before this function returns.
/// 'Msg' must declare a 'pb_desc'; its descriptors are validated at compile time and its fields
/// are then encoded in declaration order.
pub fn encodeMessage(w: anytype, ally: std.mem.Allocator, msg: anytype) !void {
    const Msg = @TypeOf(msg);
    const pb_desc = comptime getPbDesc(Msg) orelse
        @compileError("Message type '" ++ @typeName(Msg) ++ "' must have a pb_desc decl");

    validateDescriptors(Msg);

    inline for (@typeInfo(Msg).Struct.fields) |field| {
        const desc: ?FieldDescriptor = comptime pb_desc.getField(field.name);

        // A default declared on the struct field overrides the type's zero value when deciding
        // whether a plain field can be omitted.
        const default: ?field.type = if (field.default_value) |ptr|
            @ptrCast(*const field.type, ptr).*
        else
            null;

        try encodeAnyField(w, ally, @field(msg, field.name), desc, @typeName(Msg) ++ "." ++ field.name, default);
    }
}

/// Perform some basic checks on the field descriptors in the message type 'Msg', ensuring every
/// descriptor corresponds to a field and that field numbers appear at most once.
fn validateDescriptors(comptime Msg: type) void {
    comptime {
        var seen_field_nums: []const u29 = &.{};
        validateDescriptorsInner(Msg, &seen_field_nums);
        for (seen_field_nums) |x, i| {
            for (seen_field_nums[i + 1 ..]) |y| {
                if (x == y) {
                    @compileError(std.fmt.comptimePrint("Duplicate field number {} in type '{s}'", .{ x, @typeName(Msg) }));
                }
            }
        }
    }
}

/// Append every field number declared by 'Msg' to 'seen_field_nums'. Oneof members (optional
/// unions with a pb_desc) are appended to the same list as their parent; nested message structs
/// are validated separately with a fresh list.
fn validateDescriptorsInner(comptime Msg: type, comptime seen_field_nums: *[]const u29) void {
    const pb_desc = comptime getPbDesc(Msg).?;
    for (pb_desc.fields) |field_desc| {
        const name = field_desc[0];
        if (!@hasField(Msg, name)) {
            @compileError("Descriptor '" ++ name ++ "' does not correspond to any field in type '" ++ @typeName(Msg) ++ "'");
        }
        seen_field_nums.* = seen_field_nums.* ++ &[1]u29{field_desc[1].field_num};
    }

    for (std.meta.fields(Msg)) |field| {
        if (@typeInfo(field.type) == .Struct and comptime getPbDesc(field.type) != null) {
            validateDescriptors(field.type);
        } else if (@typeInfo(field.type) == .Optional and
            @typeInfo(std.meta.Child(field.type)) == .Union and
            comptime getPbDesc(std.meta.Child(field.type)) != null)
        {
            validateDescriptorsInner(std.meta.Child(field.type), seen_field_nums);
        }
    }
}

/// A small wrapper around a decoded message. You must call 'deinit' once you're done with the
/// message to free all its allocated memory.
pub fn Decoded(comptime Msg: type) type {
    return struct {
        msg: Msg,
        arena: std.heap.ArenaAllocator,

        const Self = @This();

        /// Release the arena backing 'msg'.
        pub fn deinit(self: Self) void {
            self.arena.deinit();
        }
    };
}

/// Build a 'Msg' with every field set to its starting value for decoding: the field's declared
/// default if it has one, otherwise empty slice / null / 0 / false / enum value 0. Nested
/// messages are initialised recursively; other structs are default-initialised with `.{}`.
fn initDefault(comptime Msg: type, arena: std.mem.Allocator) Msg {
    var result: Msg = undefined;

    inline for (comptime std.meta.fields(Msg)) |field| {
        // Slices always start empty, regardless of any declared default.
        if (comptime std.meta.trait.isSlice(field.type)) {
            @field(result, field.name) = &.{};
            continue;
        }

        const default: ?field.type = if (field.default_value) |ptr|
            @ptrCast(*const field.type, ptr).*
        else
            null;

        @field(result, field.name) = switch (@typeInfo(field.type)) {
            .Optional => default orelse null,
            .Int, .Float => default orelse 0,
            .Enum => |e| default orelse if (e.is_exhaustive)
                comptime std.meta.intToEnum(field.type, 0) catch
                    @compileError("Enum '" ++ @typeName(field.type) ++ "' has no 0 default")
            else
                @intToEnum(field.type, 0),
            .Bool => default orelse false,
            .Struct => if (comptime getPbDesc(field.type) != null)
                initDefault(field.type, arena)
            else
                field.type{},
            else => @compileError("Type '" ++ @typeName(field.type) ++ "' cannot be deserialized"),
        };
    }

    return result;
}

/// Read one LEB128 varint from 'r'. Returns error.MalformedInput if the varint runs past ten
/// bytes (the most a 64-bit value can occupy); reader errors such as EndOfStream are propagated.
fn decodeVarInt(r: anytype) !u64 {
    // Wider than the u6 shift operand so that stepping past 63 is detectable instead of
    // overflowing on hostile input.
    var shift: u7 = 0;
    var x: u64 = 0;
    while (true) {
        if (shift > 63) return error.MalformedInput;
        const b = try r.readByte();
        x |= @as(u64, @truncate(u7, b)) << @intCast(u6, shift);
        if (b >> 7 == 0) break;
        shift += 7;
    }
    return x;
}

/// Consume and discard one field value of the given wire type. For the deprecated group wire
/// type, fields are skipped recursively until the matching end-group tag for 'field_num'.
fn skipField(r: anytype, wire_type: WireType, field_num: u29) !void {
    switch (wire_type) {
        .varint => _ = try decodeVarInt(r),
        .i64 => _ = try r.readIntLittle(u64),
        .len => {
            const len = try decodeVarInt(r);
            try r.skipBytes(len, .{});
        },
        .sgroup => {
            while (true) {
                const tag = try decodeVarInt(r);
                const sub_wire = std.meta.intToEnum(WireType, @truncate(u3, tag)) catch return error.MalformedInput;
                const sub_num = std.math.cast(u29, tag >> 3) orelse return error.MalformedInput;
                if (sub_wire == .egroup and sub_num == field_num) {
                    break;
                }
                try skipField(r, sub_wire, sub_num);
            }
        },
        // An end-group tag with no group open is invalid.
        .egroup => return error.MalformedInput,
        .i32 => _ = try r.readIntLittle(u32),
    }
}

/// Decode one scalar of type 'T' (bool, 32/64-bit integer, f32/f64, []u8/[]const u8, or enum)
/// whose tag has already been read and carried 'wire_type'. Returns error.MalformedInput when the
/// wire type does not match the encoding, and error.UnknownEnumTag for an unnamed value of an
/// exhaustive enum. Byte slices are allocated from 'arena'.
fn decodeSingleScalar(comptime T: type, comptime encoding: FieldEncoding, r: anytype, arena: std.mem.Allocator, wire_type: WireType) !T {
    if (@typeInfo(T) == .Enum) {
        if (encoding != .default) @compileError("Enum types must use FieldEncoding.default");
        const Tag = @typeInfo(T).Enum.tag_type;
        if (@bitSizeOf(Tag) > 32) @compileError("Enum types must have a tag type of no more than 32 bits");
        const Tag32 = if (@typeInfo(Tag).Int.signedness == .signed) i32 else u32;
        const ival = try decodeSingleScalar(Tag32, .varint, r, arena, wire_type);
        if (@typeInfo(T).Enum.is_exhaustive) {
            return std.meta.intToEnum(T, ival) catch return error.UnknownEnumTag;
        } else {
            return @intToEnum(T, ival);
        }
    }

    switch (T) {
        bool => {
            if (encoding != .default) @compileError("Boolean types must use FieldEncoding.default");
            if (wire_type != .varint) return error.MalformedInput;
            const x = try decodeVarInt(r);
            return @truncate(u32, x) != 0;
        },

        u32, u64, i32, i64 => {
            switch (encoding) {
                .fixed => {
                    switch (T) {
                        u32, i32 => if (wire_type != .i32) return error.MalformedInput,
                        u64, i64 => if (wire_type != .i64) return error.MalformedInput,
                        else => unreachable,
                    }
                    return r.readIntLittle(T);
                },
                .varint => {
                    if (wire_type != .varint) return error.MalformedInput;
                    const Unsigned = switch (T) {
                        u32, i32 => u32,
                        u64, i64 => u64,
                        else => unreachable,
                    };
                    // Keep the low bits and reinterpret; this undoes the encoder's sign extension.
                    return @bitCast(T, @truncate(Unsigned, try decodeVarInt(r)));
                },
                .zigzag => {
                    if (@typeInfo(T).Int.signedness != .signed) @compileError("Only signed integral types can use FieldEncoding.zigzag");
                    if (wire_type != .varint) return error.MalformedInput;
                    const raw = try decodeVarInt(r);
                    // Odd values map to negatives, even values to non-negatives.
                    const val = if (raw % 2 == 1)
                        -@intCast(i64, raw / 2) - 1
                    else
                        @intCast(i64, raw / 2);
                    return @truncate(T, val);
                },
                else => @compileError("Integral types must use FieldEncoding.fixed, FieldEncoding.varint, or FieldEncoding.zigzag"),
            }
        },

        f32, f64 => {
            if (encoding != .default) @compileError("Floating types must use FieldEncoding.default");
            if (T == f32) {
                if (wire_type != .i32) return error.MalformedInput;
                return @bitCast(f32, try r.readIntLittle(u32));
            } else {
                if (wire_type != .i64) return error.MalformedInput;
                return @bitCast(f64, try r.readIntLittle(u64));
            }
        },

        []u8, []const u8 => {
            if (encoding != .string and encoding != .bytes) @compileError("[]u8 must use FieldEncoding.string or FieldEncoding.bytes");
            if (wire_type != .len) return error.MalformedInput;
            // The length comes from the wire as a u64; reject anything that cannot be a usize
            // rather than relying on an implicit conversion.
            const len = std.math.cast(usize, try decodeVarInt(r)) orelse return error.MalformedInput;
            const buf = try arena.alloc(u8, len);
            try r.readNoEof(buf);
            return buf;
        },

        else => @compileError("Type '" ++ @typeName(T) ++ "' cannot be decoded as a primitive"),
    }
}

/// Decodes
/// a value of scalar or submessage type, returning the result.
/// Submessages (struct types) must arrive with wire type 'len'; their payload is decoded through
/// a reader limited to the declared length. Everything else is forwarded to decodeSingleScalar.
fn decodeSingleValue(comptime T: type, comptime encoding: FieldEncoding, r: anytype, arena: std.mem.Allocator, wire_type: WireType) !T {
    if (@typeInfo(T) == .Struct) {
        if (encoding != .default) @compileError("Sub-messages must use FieldEncoding.default");
        if (wire_type != .len) return error.MalformedInput;
        const len = try decodeVarInt(r);
        var lr = std.io.limitedReader(r, len);
        return decodeMessageInner(T, lr.reader(), arena);
    } else {
        return decodeSingleScalar(T, encoding, r, arena, wire_type);
    }
}

/// Attempts to decode a field of any type, modifying the result location as necessary (either
/// overwriting the value or appending data). Returns true if this message corresponded to the given
/// field (and was decoded).
///
/// Repeated fields append to 'result' (a list type exposing 'Slice' and 'append'); map fields
/// insert one entry; oneof fields (optional unions) are matched against their members' numbers.
fn maybeDecodeAnyField(comptime T: type, comptime desc_opt: ?FieldDescriptor, comptime field_name: []const u8, r: anytype, arena: std.mem.Allocator, wire_type: WireType, field_num: u29, result: *T) !bool {
    // Nicer error message if you forgot to make your union optional
    if (@typeInfo(T) == .Union) {
        @compileError("Only optional unions can be decoded");
    }

    if (@typeInfo(T) == .Optional and @typeInfo(std.meta.Child(T)) == .Union) {
        if (desc_opt != null) @compileError("Union must not have a field descriptor");
        if (try maybeDecodeOneOf(std.meta.Child(T), r, arena, wire_type, field_num)) |val| {
            result.* = val;
            return true;
        } else {
            return false;
        }
    }

    const desc = desc_opt orelse @compileError("Missing descriptor for field '" ++ field_name ++ "'");

    if (field_num != desc.field_num) return false;

    if (desc.encoding == .repeat or desc.encoding == .repeat_pack) {
        const Elem = std.meta.Child(T.Slice);
        // Numeric scalars (including bool and enum) never use wire type 'len' on their own, so a
        // 'len' record for them is unambiguously a packed payload.
        const numeric_elem = switch (@typeInfo(Elem)) {
            .Int, .Bool, .Float, .Enum => true,
            else => false,
        };
        // Strings/bytes are themselves length-delimited, so a single element and a packed
        // payload look identical on the wire.
        const bytes_elem = Elem == []u8 or Elem == []const u8;
        if (desc.encoding == .repeat_pack and !(numeric_elem or bytes_elem)) {
            @compileError("Packed repeated fields must be slices of scalar types");
        }
        const child_enc = switch (desc.encoding) {
            .repeat, .repeat_pack => |e| e.*,
            else => unreachable,
        };
        // By spec, decoders should be able to decode non-packed repeated fields as packed and vice
        // versa, so that the protocol can be changed whilst preserving forwards and backwards
        // compatibility. That is only possible for numeric elements; for strings/bytes we have to
        // trust the declared encoding, otherwise every plain repeated string would be misread as
        // a packed payload.
        const accept_packed = numeric_elem or (bytes_elem and desc.encoding == .repeat_pack);
        if (accept_packed) {
            if (wire_type == .len) {
                const len = try decodeVarInt(r);
                var lr = std.io.limitedReader(r, len);
                const expect_wire: WireType = switch (child_enc) {
                    .fixed => switch (Elem) {
                        u32, i32, f32 => .i32,
                        u64, i64, f64 => .i64,
                        else => undefined, // not unreachable to defer to nice error handling in decodeSingleScalar
                    },
                    .varint, .zigzag => .varint,
                    .string, .bytes => .len,
                    .default => switch (Elem) {
                        bool => .varint,
                        f32 => .i32,
                        f64 => .i64,
                        // Enums travel as varints; anything else is rejected by decodeSingleScalar.
                        else => if (@typeInfo(Elem) == .Enum) @as(WireType, .varint) else undefined,
                    },
                    else => undefined,
                };

                // Elements are untagged; keep reading until the limited reader is exhausted.
                while (decodeSingleScalar(Elem, child_enc, lr.reader(), arena, expect_wire)) |elem| {
                    try result.*.append(arena, elem);
                } else |err| switch (err) {
                    error.EndOfStream => {},
                    else => |e| return e,
                }

                return true;
            }
        }

        // Single tagged element. decodeSingleValue also handles submessage elements, mirroring
        // the encoder, which writes repeated elements with encodeSingleValue.
        const elem = try decodeSingleValue(Elem, child_enc, r, arena, wire_type);
        try result.*.append(arena, elem);
    } else if (desc.encoding == .map) {
        // Each map entry is a submessage with the key as field 1 and the value as field 2.
        const val = try decodeSingleValue(struct {
            k: std.meta.FieldType(T.KV, .key),
            v: std.meta.FieldType(T.KV, .value),
            const pb_desc = .{
                .k = .{ 1, desc.encoding.map[0] },
                .v = .{ 2, desc.encoding.map[1] },
            };
        }, .default, r, arena, wire_type);
        try result.put(arena, val.k, val.v);
    } else if (@typeInfo(T) == .Optional) {
        result.* = try decodeSingleValue(std.meta.Child(T), desc.encoding, r, arena, wire_type);
    } else {
        result.* = try decodeSingleValue(T, desc.encoding, r, arena, wire_type);
    }

    return true;
}

/// If 'field_num' names a member of the oneof union 'U', decode that member's payload and return
/// the initialised union; otherwise return null without consuming any input.
fn maybeDecodeOneOf(comptime U: type, r: anytype, arena: std.mem.Allocator, wire_type: WireType, field_num: u29) !?U {
    const pb_desc = comptime getPbDesc(U) orelse
        @compileError("Union '" ++ @typeName(U) ++ "' must have a pb_desc decl");

    inline for (std.meta.fields(U)) |field| {
        const desc = comptime pb_desc.getField(field.name) orelse
            @compileError("Missing descriptor for field '" ++ @typeName(U) ++ "." ++ field.name ++ "'");

        if (desc.field_num == field_num) {
            const payload = try decodeSingleValue(field.type, desc.encoding, r, arena, wire_type);
            return @unionInit(U, field.name, payload);
        }
    }

    return null;
}

/// Decode a 'Msg' from 'r' until the reader reports end of stream. Fields start at their
/// defaults (see initDefault); records whose field number matches no field are skipped.
fn decodeMessageInner(comptime Msg: type, r: anytype, arena: std.mem.Allocator) !Msg {
    const pb_desc = comptime getPbDesc(Msg) orelse @compileError("Message type '" ++ @typeName(Msg) ++ "' must have a pb_desc decl");
    validateDescriptors(Msg);

    var result = initDefault(Msg, arena);

    while (decodeVarInt(r)) |tag| {
        // Low three bits: wire type. Remaining bits: field number.
        const wire_type = std.meta.intToEnum(WireType, @truncate(u3, tag)) catch return error.MalformedInput;
        const field_num = std.math.cast(u29, tag >> 3) orelse return error.MalformedInput;

        inline for (std.meta.fields(Msg)) |field| {
            const desc_opt: ?FieldDescriptor = comptime pb_desc.getField(field.name);

            if (try maybeDecodeAnyField(field.type, desc_opt, @typeName(Msg) ++ "." ++ field.name, r, arena, wire_type, field_num, &@field(result, field.name))) {
                break;
            }
        } else {
            // No field claimed this record.
            try skipField(r, wire_type, field_num);
        }
    } else |err| switch (err) {
        // Running out of input while reading a tag is how a message ends.
        error.EndOfStream => {},
        else => |e| return e,
    }

    return result;
}

/// Decode a message of type 'Msg' from 'r'. All memory owned by the result lives in an arena
/// created from 'ally'; call 'deinit' on the returned value to free it. The arena is released
/// here if decoding fails.
pub fn decodeMessage(comptime Msg: type, r: anytype, ally: std.mem.Allocator) !Decoded(Msg) {
    var arena = std.heap.ArenaAllocator.init(ally);
    errdefer arena.deinit();

    // Finish decoding before copying the arena into the result, so the copy carries every
    // allocation made while decoding.
    const msg = try decodeMessageInner(Msg, r, arena.allocator());

    return .{
        .msg = msg,
        .arena = arena,
    };
}

/// The parsed form of a type's 'pb_desc' decl: field names paired with their descriptors.
const PbDesc = struct {
    const Entry = struct { []const u8, FieldDescriptor };
    fields: []const Entry,

    /// Look up the descriptor for the field called 'name', or null if there is none.
    fn getField(self: PbDesc, name: []const u8) ?FieldDescriptor {
        for (self.fields) |f| {
            if (std.mem.eql(u8, f[0], name)) return f[1];
        }
        return null;
    }
};

// Directly making a pb_desc with fields of type FieldDescriptor is quite inconvenient, so instead
// we'll take big literals in the same shape and parse them into the real descriptors.

/// Parse 'T.pb_desc' into a PbDesc at compile time, or return null if 'T' has no such decl.
fn getPbDesc(comptime T: type) ?PbDesc {
    comptime {
        if (!@hasDecl(T, "pb_desc")) return null;
        const desc = T.pb_desc;

        var fields: []const PbDesc.Entry = &.{};

        for (std.meta.fields(@TypeOf(desc))) |field| {
            const fd = createFieldDesc(@field(desc, field.name), @typeName(T) ++ "." ++ field.name);
            fields = fields ++ &[1]PbDesc.Entry{.{ field.name, fd }};
        }

        return .{ .fields = fields };
    }
}

/// Turn one descriptor literal, a tuple of `.{ field_num, encoding }`, into a FieldDescriptor.
/// 'field_name' is used only for error messages.
fn createFieldDesc(comptime desc: anytype, comptime field_name: []const u8) FieldDescriptor {
    if (!std.meta.trait.isTuple(@TypeOf(desc))) {
        @compileError("Bad descriptor format for field '" ++ field_name ++ "'");
    }

    return .{
        .field_num = desc[0],
        .encoding = createFieldEncoding(desc[1], field_name),
    };
}

/// Turn an encoding literal into a FieldEncoding. Accepted shapes: an existing FieldEncoding, an
/// enum literal naming a tag (e.g. `.varint`), or a single-field struct literal
/// `.{ .repeat = child }`, `.{ .repeat_pack = child }` or `.{ .map = .{ key, value } }`.
/// Anything else is a compile error naming 'field_name'.
fn createFieldEncoding(comptime enc: anytype, comptime field_name: []const u8) FieldEncoding {
    if (@TypeOf(enc) == FieldEncoding) {
        return enc;
    } else if (@TypeOf(enc) == @Type(.EnumLiteral)) {
        // try to match with an encoding type
        for (std.meta.fields(FieldEncoding)) |field| {
            if (std.mem.eql(u8, @tagName(enc), field.name)) {
                return @field(FieldEncoding, field.name);
            }
        }
    } else if (@typeInfo(@TypeOf(enc)) == .Struct) {
        // nested encoding types
        const fields = @typeInfo(@TypeOf(enc)).Struct.fields;
        if (fields.len == 1) {
            const tag = fields[0].name;
            const val = @field(enc, tag);
            if (std.mem.eql(u8, tag, "repeat")) {
                const child = createFieldEncoding(val, field_name);
                return .{ .repeat = &child };
            } else if (std.mem.eql(u8, tag, "repeat_pack")) {
                const child = createFieldEncoding(val, field_name);
                return .{ .repeat_pack = &child };
            } else if (std.mem.eql(u8, tag, "map")) {
                if (std.meta.trait.isTuple(@TypeOf(val)) and val.len == 2) {
                    const child0 = createFieldEncoding(val[0], field_name);
                    const child1 = createFieldEncoding(val[1], field_name);
                    return .{ .map = &[2]FieldEncoding{ child0, child1 } };
                }
            }
        }
    }

    @compileError("Bad encoding for field '" ++ field_name ++ "'");
}

--------------------------------------------------------------------------------
/test.zig:
--------------------------------------------------------------------------------
const std = @import("std");
const pb = @import("protobuf.zig");

/// Recursively compare two values of message-related type 'T': optionals, unions, message
/// structs (those with a pb_desc), maps, lists, byte slices and plain scalars.
fn expectEqualMessages(comptime T: type, expected: T, actual: T) !void {
    if (@typeInfo(T) == .Optional) {
        try std.testing.expectEqual(expected == null, actual == null);
        // Past this point both are null or both are set; only unwrap when there is a payload.
        const exp = expected orelse return;
        return expectEqualMessages(std.meta.Child(T), exp, actual.?);
    }

    if (@typeInfo(T) == .Union) {
        try std.testing.expectEqual(std.meta.activeTag(expected), std.meta.activeTag(actual));
        switch (expected) {
            inline else => |val, tag| {
                return expectEqualMessages(@TypeOf(val), val, @field(actual, @tagName(tag)));
            },
        }
    }

    if (@typeInfo(T) == .Struct) {
        if (@hasDecl(T, "pb_desc")) {
            // Message: compare field by field.
            inline for (comptime std.meta.fields(T)) |field| {
                try expectEqualMessages(field.type, @field(expected, field.name), @field(actual, field.name));
            }
        } else if (@hasDecl(T, "GetOrPutResult")) {
            // Hash map: same entry count and every expected key maps to an equal value.
            try std.testing.expectEqual(expected.count(), actual.count());
            var it = expected.iterator();
            while (it.next()) |pair| {
                const val = actual.get(pair.key_ptr.*) orelse return error.TestExpectedEqual;
                try std.testing.expectEqual(pair.value_ptr.*, val);
            }
        } else if (@hasDecl(T, "Slice")) {
            // List: compare the live items.
            try std.testing.expectEqualSlices(std.meta.Child(T.Slice), expected.items, actual.items);
        } else {
            @compileError("Cannot test equality of type '" ++ @typeName(T) ++ "'");
        }
        return;
    }

    if (T == []const u8 or T == []u8) {
        return std.testing.expectEqualSlices(u8, expected, actual);
    }

    switch (@typeInfo(T)) {
        .Int, .Float, .Enum, .Bool => try std.testing.expectEqual(expected, actual),
        else => @compileError("Cannot test equality of type '" ++ @typeName(T) ++ "'"),
    }
}

/// Build a runtime value of type 'T' from the comptime literal 'val'. Maps are given as tuples of
/// `.{ key, value }` pairs and lists as tuples of elements; containers allocate from 'arena'.
fn initMessage(comptime T: type, comptime val: anytype, arena: std.mem.Allocator) !T {
    if (@typeInfo(T) == .Optional) {
        if (@typeInfo(@TypeOf(val)) == .Optional) {
            return if (val) |x| try initMessage(std.meta.Child(T), x, arena) else null;
        } else {
            return try initMessage(std.meta.Child(T), val, arena);
        }
    }

    if (@typeInfo(T) == .Union) {
        if (@typeInfo(@TypeOf(val)) != .Struct) @compileError("Expected struct literal to initialize union");
        const fields = @typeInfo(@TypeOf(val)).Struct.fields;
        if (fields.len != 1) @compileError("Expected single-element struct to initialize union");
        return @unionInit(T, fields[0].name, try initMessage(
            std.meta.TagPayload(T, @field(std.meta.Tag(T), fields[0].name)),
            @field(val, fields[0].name),
            arena,
        ));
    }

    if (@typeInfo(T) == .Struct) {
        if (@hasDecl(T, "pb_desc")) {
            var result: T = undefined;
            inline for (comptime std.meta.fields(T)) |field| {
                @field(result, field.name) = try initMessage(field.type, @field(val, field.name), arena);
            }
            return result;
        } else if (@hasDecl(T, "GetOrPutResult")) {
            var result: T = .{};
            inline for (val) |pair| {
                try result.put(arena, pair[0], pair[1]);
            }
            return result;
        } else if (@hasDecl(T, "Slice")) {
            var result: T = .{};
            try result.appendSlice(arena, &val);
            return result;
        } else {
            @compileError("Cannot initalize type '" ++ @typeName(T) ++ "'");
        }
    }

    if (T == []u8 and @TypeOf(val) == []const u8) {
        // A mutable slice cannot alias the comptime literal; copy it into the arena.
        return arena.dupe(u8, val);
    } else {
        return val;
    }
}

/// Round-trip check: build 'Msg' from 'val', encode it, decode the bytes, and require the
/// decoded message to equal the original.
fn testEncodeDecode(comptime Msg: type, comptime val: anytype) !void {
    var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
    defer arena.deinit();

    const msg = try initMessage(Msg, val, arena.allocator());

    var buf = std.ArrayList(u8).init(std.testing.allocator);
    defer buf.deinit();

    try pb.encodeMessage(buf.writer(), std.testing.allocator, msg);

    var fbs = std.io.fixedBufferStream(buf.items);
    const decoded = try pb.decodeMessage(Msg, fbs.reader(), std.testing.allocator);
    defer decoded.deinit();

    try expectEqualMessages(Msg, msg, decoded.msg);
}

test {
    try testEncodeDecode(struct {
        single1: u32,
        single2: u32,
        opt: ?u64,
        rep: std.ArrayListUnmanaged(i32),
        map: pb.Map([]const u8, f32),
        options: ?union(enum) {
            foo: u32,
            bar: []const u8,
            pub const pb_desc = .{
                .foo = .{ 5, .varint },
                .bar = .{ 10, .bytes },
            };
        },
        embedded: struct {
            x: u32,
            y: i64,
            pub const pb_desc = .{
                .x = .{ 1, .varint },
                .y = .{ 2, .zigzag },
            };
        },
        en: enum {
            val1,
            val2,
            val3,
        },

        pub const pb_desc = .{
            .single1 = .{ 1, .varint },
            .single2 = .{ 42, .fixed },
            .opt = .{ 2, .varint },
            .rep = .{ 3, .{ .repeat_pack = .zigzag } },
            .map = .{ 4, .{ .map = .{ .string, .default } } },
            .embedded = .{ 6, .default },
            .en = .{ 7, .default },
        };
    }, .{
        .single1 = 256,
        .single2 = 0,
        .opt = 0,
        .rep = .{ 69, 0, 42 },
        .map = .{
            .{ "hello", 1 },
            .{ "", 2 },
            .{ "this has\x00embedded nuls", 3 },
        },
        .options = .{
            .bar = "ziggify all the kingdoms",
        },
        .embedded = .{
            .x = 1 << 20,
            .y = -2048,
        },
        .en = .val2,
    });
}

--------------------------------------------------------------------------------