├── .gitignore ├── examples └── jsonator │ ├── .gitignore │ ├── sample.proto │ ├── sample2.proto │ ├── generate.go │ ├── README.md │ └── main.go ├── fixtures ├── fileset.pb ├── generate.go ├── todo_import.proto ├── extend.proto ├── edition2023.proto ├── edition2023_implicit.proto ├── edition2024.proto ├── booking.proto └── todo.proto ├── parser_bench_test.go ├── .github ├── PULL_REQUEST_TEMPLATE.md ├── workflows │ ├── release.yaml │ └── ci.yaml └── CONTRIBUTING.md ├── doc.go ├── go.mod ├── example_plugin_test.go ├── LICENSE ├── comments_test.go ├── .goreleaser.yaml ├── flake.nix ├── plugin.go ├── flake.lock ├── context_test.go ├── .golangci.yaml ├── go.sum ├── context.go ├── README.md ├── taskfile.yaml ├── plugin_test.go ├── comments.go ├── utils ├── protobuf.go └── protobuf_test.go ├── parser.go ├── parser_test.go └── types.go /.gitignore: -------------------------------------------------------------------------------- 1 | .direnv/ 2 | bin/ 3 | dist/ 4 | -------------------------------------------------------------------------------- /examples/jsonator/.gitignore: -------------------------------------------------------------------------------- 1 | third_party 2 | jsonator 3 | output.json 4 | -------------------------------------------------------------------------------- /fixtures/fileset.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pseudomuto/protokit/HEAD/fixtures/fileset.pb -------------------------------------------------------------------------------- /fixtures/generate.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | //go:generate protoc --descriptor_set_out=fileset.pb --include_imports --include_source_info -I. 
./booking.proto ./todo.proto ./extend.proto ./edition2023.proto ./edition2024.proto ./edition2023_implicit.proto 4 | -------------------------------------------------------------------------------- /fixtures/todo_import.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | // This is really just in place to make sure imported files are included in the parsed result. 4 | package com.pseudomuto.protokit.v1; 5 | 6 | // Details for list items 7 | message ListItemDetails { 8 | string notes = 1; // Some notes for the item 9 | } 10 | 11 | // A dummy enum to ensure importing works. 12 | enum ListItemDetailEnum { 13 | DEFAULT = 0; // The default value. 14 | } 15 | -------------------------------------------------------------------------------- /parser_bench_test.go: -------------------------------------------------------------------------------- 1 | package protokit_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/pseudomuto/protokit" 7 | "github.com/pseudomuto/protokit/utils" 8 | ) 9 | 10 | func BenchmarkParseCodeGenRequest(b *testing.B) { 11 | fds, _ := utils.LoadDescriptorSet("fixtures", "fileset.pb") 12 | req := utils.CreateGenRequest(fds, "booking.proto", "todo.proto") 13 | 14 | for b.Loop() { 15 | protokit.ParseCodeGenRequest(req) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /examples/jsonator/sample.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "google/api/annotations.proto"; 4 | import "google/protobuf/empty.proto"; 5 | 6 | // This is just a sample proto for demonstrating how to use this library. 7 | // 8 | // There's nothing "fancy" here. 
9 | package com.jsonator.v1; 10 | 11 | service SampleService { 12 | rpc RandomSample(google.protobuf.Empty) returns (Sample) { 13 | option (google.api.http).get = "/v1/sample"; 14 | } 15 | } 16 | 17 | message Sample { 18 | int64 id = 1; 19 | } 20 | -------------------------------------------------------------------------------- /examples/jsonator/sample2.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "google/api/annotations.proto"; 4 | import "google/protobuf/empty.proto"; 5 | 6 | // This is just another sample proto for demonstrating how to use this library. 7 | // 8 | // There's also nothing "fancy" here. 9 | package com.jsonator.v2; 10 | 11 | service SampleService { 12 | rpc RandomSample(google.protobuf.Empty) returns (Sample) { 13 | option (google.api.http).get = "/v2/sample"; 14 | } 15 | } 16 | 17 | message Sample { 18 | int64 id = 1; 19 | } 20 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ping @pseudomuto 2 | 3 | ### What is Changing? 4 | 5 | _Make sure you spell out in as much detail as necessary what will happen to which systems when your PR is merged, 6 | and what the expected changes are._ 7 | 8 | ### How is it Changing? 9 | 10 | _Include any relevant implementation details and minimize surprises for the reviewers in this section. If you had to take some 11 | unorthodox approaches (read: hacks), explain why here._ 12 | 13 | ### What Could Go Wrong? 14 | 15 | _How has this change been tested?
In your opinion, what is the risk, if any, of merging these changes?_ 16 | -------------------------------------------------------------------------------- /examples/jsonator/generate.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | //go:generate go build -o ./jsonator main.go 4 | //go:generate mkdir -p third_party/google/api 5 | //go:generate curl -sSL -o third_party/google/api/annotations.proto https://raw.githubusercontent.com/googleapis/googleapis/master/google/api/annotations.proto 6 | //go:generate curl -sSL -o third_party/google/api/http.proto https://raw.githubusercontent.com/googleapis/googleapis/master/google/api/http.proto 7 | //go:generate protoc --plugin=protoc-gen-jsonator=./jsonator -I. -Ithird_party --jsonator_out=. ./sample.proto ./sample2.proto 8 | //go:generate rm -rf third_party 9 | //go:generate rm ./jsonator 10 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package protokit is a library that makes it easy to create your own protoc plugins. It has excellent test coverage, 2 | // and saves you so much time! 3 | // 4 | // There are two main things this library provides: a parser for parsing protobuf files into some well-defined structs, 5 | // and an abstraction to make it simple to write your own protoc plugins. 6 | // 7 | // # Getting Started 8 | // 9 | // For a quick view of how to get started, see https://pkg.go.dev/github.com/pseudomuto/protokit#example-RunPlugin 10 | // 11 | // If you want to see/try a working example, check out the examples in 12 | // https://github.com/pseudomuto/protokit/tree/master/examples.
13 | package protokit 14 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*" 7 | 8 | permissions: 9 | contents: write 10 | packages: write 11 | 12 | jobs: 13 | release: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Checkout 17 | uses: actions/checkout@v5 18 | with: 19 | fetch-depth: 0 20 | 21 | - name: Install Nix 22 | uses: DeterminateSystems/nix-installer-action@v20 23 | 24 | - name: Setup Nix cache 25 | uses: DeterminateSystems/magic-nix-cache-action@v13 26 | 27 | - name: Run GoReleaser 28 | run: nix develop -c goreleaser release --clean 29 | env: 30 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 31 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/pseudomuto/protokit 2 | 3 | go 1.25.1 4 | 5 | require ( 6 | github.com/stretchr/testify v1.2.1 7 | google.golang.org/genproto/googleapis/api v0.0.0-20241015192408-796eee8c2d53 8 | google.golang.org/protobuf v1.36.10 9 | ) 10 | 11 | require ( 12 | github.com/davecgh/go-spew v1.1.0 // indirect 13 | github.com/pmezard/go-difflib v1.0.0 // indirect 14 | golang.org/x/mod v0.26.0 // indirect 15 | golang.org/x/sync v0.16.0 // indirect 16 | golang.org/x/sys v0.34.0 // indirect 17 | golang.org/x/telemetry v0.0.0-20250710130107-8d8967aff50b // indirect 18 | golang.org/x/tools v0.35.1-0.20250728180453-01a3475a31bc // indirect 19 | golang.org/x/tools/gopls v0.20.0 // indirect 20 | ) 21 | 22 | tool golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize 23 | -------------------------------------------------------------------------------- /example_plugin_test.go: -------------------------------------------------------------------------------- 1 | package protokit_test 2 | 
3 | import ( 4 | "log" 5 | 6 | "github.com/pseudomuto/protokit" 7 | "google.golang.org/protobuf/proto" 8 | pluginpb "google.golang.org/protobuf/types/pluginpb" 9 | ) 10 | 11 | type plugin struct{} 12 | 13 | func (p *plugin) Generate(r *pluginpb.CodeGeneratorRequest) (*pluginpb.CodeGeneratorResponse, error) { 14 | descriptors := protokit.ParseCodeGenRequest(r) 15 | resp := new(pluginpb.CodeGeneratorResponse) 16 | 17 | for _, desc := range descriptors { 18 | resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{ 19 | Name: proto.String(desc.GetName() + ".out"), 20 | Content: proto.String("Some relevant output"), 21 | }) 22 | } 23 | 24 | return resp, nil 25 | } 26 | 27 | // An example of running a custom plugin. This would be in your main.go file. 28 | func ExampleRunPlugin() { 29 | // in func main() {} 30 | if err := protokit.RunPlugin(new(plugin)); err != nil { 31 | log.Fatal(err) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 David Muto (pseudomuto) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /fixtures/extend.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | import "google/protobuf/descriptor.proto"; 4 | 5 | package com.pseudomuto.protokit.v1; 6 | 7 | /** 8 | * Extension of protobuf file options. 9 | */ 10 | extend google.protobuf.FileOptions { 11 | optional bool extend_file = 20000; 12 | } 13 | 14 | /** 15 | * Extension of protobuf service options. 16 | */ 17 | extend google.protobuf.ServiceOptions { 18 | optional bool extend_service = 20000; 19 | } 20 | 21 | /** 22 | * Extension of protobuf method options. 23 | */ 24 | extend google.protobuf.MethodOptions { 25 | optional bool extend_method = 20000; 26 | } 27 | 28 | /** 29 | * Extension of protobuf enum options. 30 | */ 31 | extend google.protobuf.EnumOptions { 32 | optional bool extend_enum = 20000; 33 | } 34 | 35 | /** 36 | * Extension of protobuf enum value options. 37 | */ 38 | extend google.protobuf.EnumValueOptions { 39 | optional bool extend_enum_value = 20000; 40 | } 41 | 42 | /** 43 | * Extension of protobuf message options. 44 | */ 45 | extend google.protobuf.MessageOptions { 46 | optional bool extend_message = 20000; 47 | } 48 | 49 | /** 50 | * Extension of protobuf field options. 51 | */ 52 | extend google.protobuf.FieldOptions { 53 | optional bool extend_field = 20000; 54 | } 55 | -------------------------------------------------------------------------------- /examples/jsonator/README.md: -------------------------------------------------------------------------------- 1 | # jsonator 2 | 3 | Quite possibly the most useless protoc-gen plugin out there. 
4 | 5 | This is just a demo to show you how to use `protokit`. All this does, is generate a single output file (`output.json`) 6 | which contains the full name and description of each of the proto files to generate as well as any services and 7 | associated methods. 8 | 9 | **Running this example** 10 | 11 | * `go generate && cat output.json` 12 | * :rofl: 13 | 14 | You'll see something like this (assuming you made no changes to the proto files here): 15 | 16 | ```json 17 | [ 18 | { 19 | "name":"com.jsonator.v1.sample.proto", 20 | "description":"This is just a sample proto for demonstrating how to use this library.\n\nThere's nothing \"fancy\" here.", 21 | "services":[ 22 | { 23 | "name":"SampleService", 24 | "methods":[ 25 | "RandomSample" 26 | ] 27 | } 28 | ] 29 | }, 30 | { 31 | "name":"com.jsonator.v2.sample2.proto", 32 | "description":"This is just another sample proto for demonstrating how to use this library.\n\nThere's also nothing \"fancy\" here.", 33 | "services":[ 34 | { 35 | "name":"SampleService", 36 | "methods":[ 37 | "RandomSample" 38 | ] 39 | } 40 | ] 41 | } 42 | ] 43 | ``` 44 | -------------------------------------------------------------------------------- /comments_test.go: -------------------------------------------------------------------------------- 1 | package protokit_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/pseudomuto/protokit" 8 | "github.com/pseudomuto/protokit/utils" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestComments(t *testing.T) { 13 | t.Parallel() 14 | 15 | pf, err := utils.LoadDescriptor("todo.proto", "fixtures", "fileset.pb") 16 | require.NoError(t, err) 17 | 18 | comments := protokit.ParseComments(pf) 19 | 20 | tests := []struct { 21 | key string 22 | leading string 23 | trailing string 24 | }{ 25 | {"6.0.2.1", "Add an item to your list\n\nAdds a new item to the specified list.", ""}, // leading commend 26 | {"4.0.2.0", "", "The id of the list."}, // tailing comment 27 | } 28 | 29 
| for _, test := range tests { 30 | require.Equal(t, test.leading, comments[test.key].GetLeading()) 31 | require.Equal(t, test.trailing, comments[test.key].GetTrailing()) 32 | require.Empty(t, comments[test.key].GetDetached()) 33 | } 34 | 35 | require.NotNil(t, comments.Get("WONTBETHERE")) 36 | require.Empty(t, comments.Get("WONTBETHERE").String()) 37 | } 38 | 39 | // Join the leading and trailing comments together 40 | func ExampleComment_String() { 41 | c := &protokit.Comment{Leading: "Some leading comment", Trailing: "Some trailing comment"} 42 | fmt.Println(c.String()) 43 | // Output: Some leading comment 44 | // 45 | // Some trailing comment 46 | } 47 | -------------------------------------------------------------------------------- /.goreleaser.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | # Project metadata 4 | project_name: protokit 5 | 6 | # Build configuration 7 | builds: 8 | - skip: true 9 | 10 | # Archive configuration 11 | archives: 12 | - id: default 13 | ids: 14 | - protokit 15 | name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}" 16 | files: 17 | - LICENSE 18 | - README.md 19 | 20 | # Checksum configuration 21 | checksum: 22 | name_template: "checksums.txt" 23 | 24 | # Snapshot configuration 25 | snapshot: 26 | version_template: "{{ incpatch .Version }}-next" 27 | 28 | # Changelog configuration 29 | changelog: 30 | sort: asc 31 | use: github 32 | filters: 33 | exclude: 34 | - "^docs:" 35 | - "^test:" 36 | - "^ci:" 37 | - "^build:" 38 | - "^chore:" 39 | groups: 40 | - title: Features 41 | regexp: "^.*feat[(\\w)]*:+.*$" 42 | order: 0 43 | - title: Bug Fixes 44 | regexp: "^.*fix[(\\w)]*:+.*$" 45 | order: 1 46 | - title: Others 47 | order: 999 48 | 49 | # Release configuration 50 | release: 51 | github: 52 | owner: pseudomuto 53 | name: protokit 54 | draft: false 55 | prerelease: auto 56 | name_template: "Release {{ .Version }}" 57 | header: | 58 | ## Protokit - A 
starter kit for building protoc-plugins. 59 | 60 | ### Get it! 61 | 62 | `go get github.com/pseudomuto/protokit` 63 | 64 | footer: | 65 | **Full Changelog**: https://github.com/pseudomuto/protokit/compare/{{ .PreviousTag }}...{{ .Tag }} 66 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "Protokit - A starter kit for building protoc-plugins"; 3 | 4 | inputs = { 5 | nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; 6 | flake-utils.url = "github:numtide/flake-utils"; 7 | }; 8 | 9 | outputs = 10 | { 11 | self, 12 | nixpkgs, 13 | flake-utils, 14 | }: 15 | flake-utils.lib.eachSystem 16 | [ 17 | "x86_64-linux" 18 | "aarch64-linux" 19 | "x86_64-darwin" 20 | "aarch64-darwin" 21 | ] 22 | ( 23 | system: 24 | let 25 | pkgs = import nixpkgs { 26 | inherit system; 27 | config.allowUnfree = true; 28 | }; 29 | 30 | # Pick language/tool versions here (adjust as you like) 31 | go = pkgs.go_1_25; 32 | 33 | # Common build utils 34 | buildUtils = with pkgs; [ 35 | go-task 36 | golangci-lint 37 | goreleaser 38 | protobuf 39 | ]; 40 | in 41 | { 42 | # `nix develop` drops you into this shell 43 | devShells.default = pkgs.mkShell { 44 | packages = [ 45 | go 46 | buildUtils 47 | ]; 48 | 49 | CGO_ENABLED = "0"; 50 | 51 | # Helpful prompt when you enter the shell 52 | shellHook = '' 53 | echo "▶ Dev shell ready on ${system}" 54 | echo " Go: $(${go}/bin/go version)" 55 | ''; 56 | }; 57 | } 58 | ); 59 | } 60 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Protokit 2 | 3 | First off, glad you're here and want to contribute! :heart: 4 | 5 | ## Getting Started 6 | 7 | It's always good to start simple. Clone the repo and `make test` to make sure you're starting from a good place. 
8 | 9 | ## Submitting a PR 10 | 11 | Here are some general guidelines for making PRs for this repo. 12 | 13 | 1. [Fork this repo](https://github.com/pseudomuto/protokit/fork) 14 | 1. Make a branch off of master (`git checkout -b <branch-name>`) 15 | 1. Make focused commits with descriptive messages 16 | 1. Add tests that fail without your code and pass with it (`go test ./...` is your friend) 17 | 1. GoFmt your code! (Set up your editor to run `gofmt`/`goimports` on save so this happens automatically) 18 | 1. **Ping someone on the PR** (Lots of people, including myself, won't get a notification unless pinged directly) 19 | 20 | Every PR should have a well-detailed summary of the changes being made and the reasoning behind them. I've added a 21 | PR template that should help with this. 22 | 23 | ## Code Guidelines 24 | 25 | I don't want to be too dogmatic about this, but here are some general things I try to keep in mind: 26 | 27 | * GoFmt all the things! 28 | * Imports are grouped into stdlib, external, and internal groups in each file (see any go file in this repo for an example) - really just use `goimports` and be done with it. 29 | * Tests are defined in `_test` packages to ensure only the public interface is tested. 30 | * If you export something, make sure you add appropriate godoc comments and tests.
31 | 32 | ## Tagging a Release 33 | 34 | * Ensure you're on a clean master 35 | * Push a `v*` tag (e.g. `git tag v1.2.3 && git push origin v1.2.3`); the Release workflow runs GoReleaser on tag pushes 36 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v5 16 | 17 | - name: Install Nix 18 | uses: DeterminateSystems/nix-installer-action@v20 19 | 20 | - name: Setup Nix cache 21 | uses: DeterminateSystems/magic-nix-cache-action@v13 22 | 23 | - name: Run tests 24 | run: nix develop -c task test:ci 25 | 26 | - name: Upload coverage to Codecov 27 | uses: codecov/codecov-action@v5 28 | with: 29 | file: ./coverage.out 30 | flags: unittests 31 | name: codecov-umbrella 32 | fail_ci_if_error: false 33 | 34 | lint: 35 | name: Lint 36 | runs-on: ubuntu-latest 37 | 38 | steps: 39 | - name: Checkout code 40 | uses: actions/checkout@v5 41 | 42 | - name: Install Nix 43 | uses: DeterminateSystems/nix-installer-action@v20 44 | 45 | - name: Setup Nix cache 46 | uses: DeterminateSystems/magic-nix-cache-action@v13 47 | 48 | - name: Run linter 49 | run: nix develop -c task lint 50 | 51 | build: 52 | name: Build 53 | runs-on: ubuntu-latest 54 | 55 | steps: 56 | - name: Checkout code 57 | uses: actions/checkout@v5 58 | 59 | - name: Install Nix 60 | uses: DeterminateSystems/nix-installer-action@v20 61 | 62 | - name: Setup Nix cache 63 | uses: DeterminateSystems/magic-nix-cache-action@v13 64 | 65 | - name: Run GoReleaser (no publish) 66 | run: nix develop -c task build 67 | -------------------------------------------------------------------------------- /plugin.go: -------------------------------------------------------------------------------- 1 | package protokit 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "os" 7 | 8 |
"google.golang.org/protobuf/proto" 9 | pluginpb "google.golang.org/protobuf/types/pluginpb" 10 | ) 11 | 12 | // Plugin describes an interface for running protoc code generator plugins 13 | type Plugin interface { 14 | Generate(req *pluginpb.CodeGeneratorRequest) (*pluginpb.CodeGeneratorResponse, error) 15 | } 16 | 17 | // RunPlugin runs the supplied plugin by reading input from stdin and generating output to stdout. 18 | func RunPlugin(p Plugin) error { 19 | return RunPluginWithIO(p, os.Stdin, os.Stdout) 20 | } 21 | 22 | // RunPluginWithIO runs the supplied plugin using the supplied reader and writer for IO. 23 | func RunPluginWithIO(p Plugin, r io.Reader, w io.Writer) error { 24 | req, err := readRequest(r) 25 | if err != nil { 26 | return err 27 | } 28 | 29 | resp, err := p.Generate(req) 30 | if err != nil { 31 | return err 32 | } 33 | 34 | return writeResponse(w, resp) 35 | } 36 | 37 | func readRequest(r io.Reader) (*pluginpb.CodeGeneratorRequest, error) { 38 | data, err := io.ReadAll(r) 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | req := new(pluginpb.CodeGeneratorRequest) 44 | if err = proto.Unmarshal(data, req); err != nil { 45 | return nil, err 46 | } 47 | 48 | if len(req.GetFileToGenerate()) == 0 { 49 | return nil, errors.New("no files were supplied to the generator") 50 | } 51 | 52 | return req, nil 53 | } 54 | 55 | func writeResponse(w io.Writer, resp *pluginpb.CodeGeneratorResponse) error { 56 | data, err := proto.Marshal(resp) 57 | if err != nil { 58 | return err 59 | } 60 | 61 | if _, err := w.Write(data); err != nil { 62 | return err 63 | } 64 | 65 | return nil 66 | } 67 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "flake-utils": { 4 | "inputs": { 5 | "systems": "systems" 6 | }, 7 | "locked": { 8 | "lastModified": 1731533236, 9 | "narHash": 
"sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", 10 | "owner": "numtide", 11 | "repo": "flake-utils", 12 | "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", 13 | "type": "github" 14 | }, 15 | "original": { 16 | "owner": "numtide", 17 | "repo": "flake-utils", 18 | "type": "github" 19 | } 20 | }, 21 | "nixpkgs": { 22 | "locked": { 23 | "lastModified": 1760533177, 24 | "narHash": "sha256-OwM1sFustLHx+xmTymhucZuNhtq98fHIbfO8Swm5L8A=", 25 | "owner": "NixOS", 26 | "repo": "nixpkgs", 27 | "rev": "35f590344ff791e6b1d6d6b8f3523467c9217caf", 28 | "type": "github" 29 | }, 30 | "original": { 31 | "owner": "NixOS", 32 | "ref": "nixpkgs-unstable", 33 | "repo": "nixpkgs", 34 | "type": "github" 35 | } 36 | }, 37 | "root": { 38 | "inputs": { 39 | "flake-utils": "flake-utils", 40 | "nixpkgs": "nixpkgs" 41 | } 42 | }, 43 | "systems": { 44 | "locked": { 45 | "lastModified": 1681028828, 46 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 47 | "owner": "nix-systems", 48 | "repo": "default", 49 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 50 | "type": "github" 51 | }, 52 | "original": { 53 | "owner": "nix-systems", 54 | "repo": "default", 55 | "type": "github" 56 | } 57 | } 58 | }, 59 | "root": "root", 60 | "version": 7 61 | } 62 | -------------------------------------------------------------------------------- /context_test.go: -------------------------------------------------------------------------------- 1 | package protokit_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/pseudomuto/protokit" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestContextWithFileDescriptor(t *testing.T) { 12 | t.Parallel() 13 | 14 | ctx := context.Background() 15 | 16 | val, found := protokit.FileDescriptorFromContext(ctx) 17 | require.Nil(t, val) 18 | require.False(t, found) 19 | 20 | ctx = protokit.ContextWithFileDescriptor(ctx, new(protokit.FileDescriptor)) 21 | val, found = protokit.FileDescriptorFromContext(ctx) 22 | 
require.NotNil(t, val) 23 | require.True(t, found) 24 | } 25 | 26 | func TestContextWithEnumDescriptor(t *testing.T) { 27 | t.Parallel() 28 | 29 | ctx := context.Background() 30 | 31 | val, found := protokit.EnumDescriptorFromContext(ctx) 32 | require.Nil(t, val) 33 | require.False(t, found) 34 | 35 | ctx = protokit.ContextWithEnumDescriptor(ctx, new(protokit.EnumDescriptor)) 36 | val, found = protokit.EnumDescriptorFromContext(ctx) 37 | require.NotNil(t, val) 38 | require.True(t, found) 39 | } 40 | 41 | func TestContextWithDescriptor(t *testing.T) { 42 | t.Parallel() 43 | 44 | ctx := context.Background() 45 | 46 | val, found := protokit.DescriptorFromContext(ctx) 47 | require.Nil(t, val) 48 | require.False(t, found) 49 | 50 | ctx = protokit.ContextWithDescriptor(ctx, new(protokit.Descriptor)) 51 | val, found = protokit.DescriptorFromContext(ctx) 52 | require.NotNil(t, val) 53 | require.True(t, found) 54 | } 55 | 56 | func TestContextWithServiceDescriptor(t *testing.T) { 57 | t.Parallel() 58 | 59 | ctx := context.Background() 60 | 61 | val, found := protokit.ServiceDescriptorFromContext(ctx) 62 | require.Empty(t, val) 63 | require.False(t, found) 64 | 65 | ctx = protokit.ContextWithServiceDescriptor(ctx, new(protokit.ServiceDescriptor)) 66 | val, found = protokit.ServiceDescriptorFromContext(ctx) 67 | require.NotNil(t, val) 68 | require.True(t, found) 69 | } 70 | -------------------------------------------------------------------------------- /fixtures/edition2023.proto: -------------------------------------------------------------------------------- 1 | // Top-level comments are attached to the edition directive. 2 | edition = "2023"; 3 | 4 | import "google/protobuf/any.proto"; 5 | import "google/protobuf/timestamp.proto"; 6 | import "extend.proto"; 7 | option go_package = "edition2023"; 8 | 9 | // The official documentation for the Edition 2023 test. 
10 | // 11 | // This file demonstrates edition 2023 syntax and features, 12 | // including explicit field presence by default. 13 | package com.pseudomuto.protokit.edition2023; 14 | 15 | option (com.pseudomuto.protokit.v1.extend_file) = true; 16 | 17 | // A service for testing edition 2023 features. 18 | service Edition2023Service { 19 | option (com.pseudomuto.protokit.v1.extend_service) = true; 20 | 21 | // Test method for edition 2023 22 | rpc TestMethod(TestRequest) returns (TestResponse) { 23 | option (com.pseudomuto.protokit.v1.extend_method) = true; 24 | } 25 | } 26 | 27 | // Test enumeration for edition 2023 28 | enum TestEnum { 29 | option (com.pseudomuto.protokit.v1.extend_enum) = true; 30 | 31 | UNKNOWN = 0; // Unknown value 32 | VALUE_A = 1 [(com.pseudomuto.protokit.v1.extend_enum_value) = true]; // First value 33 | VALUE_B = 2; // Second value 34 | } 35 | 36 | // Test message for edition 2023 with explicit field presence by default 37 | message TestMessage { 38 | option (com.pseudomuto.protokit.v1.extend_message) = true; 39 | 40 | int64 id = 1; // Message ID with explicit presence 41 | string name = 2 [(com.pseudomuto.protokit.v1.extend_field) = true]; // Message name 42 | TestEnum type = 3; // Message type 43 | google.protobuf.Timestamp created_at = 4; // Creation timestamp 44 | google.protobuf.Any metadata = 5; // Additional metadata 45 | } 46 | 47 | // Request message for testing 48 | message TestRequest { 49 | string query = 1; // Search query 50 | int32 limit = 2; // Result limit 51 | } 52 | 53 | // Response message for testing 54 | message TestResponse { 55 | repeated TestMessage results = 1; // Search results 56 | int32 total_count = 2; // Total result count 57 | } -------------------------------------------------------------------------------- /fixtures/edition2023_implicit.proto: -------------------------------------------------------------------------------- 1 | // Edition 2023 with implicit field presence (proto3-like semantics) 2 | edition = 
"2023"; 3 | 4 | import "google/protobuf/any.proto"; 5 | import "google/protobuf/timestamp.proto"; 6 | import "extend.proto"; 7 | 8 | option go_package = "edition2023_implicit"; 9 | option features.field_presence = IMPLICIT; 10 | 11 | // Test package with proto3-like semantics using editions 12 | package com.pseudomuto.protokit.edition2023.implicit; 13 | 14 | option (com.pseudomuto.protokit.v1.extend_file) = true; 15 | 16 | // A service for testing edition 2023 with implicit field presence 17 | service Edition2023ImplicitService { 18 | option (com.pseudomuto.protokit.v1.extend_service) = true; 19 | 20 | // Test method for edition 2023 implicit 21 | rpc TestMethod(TestRequest) returns (TestResponse) { 22 | option (com.pseudomuto.protokit.v1.extend_method) = true; 23 | } 24 | } 25 | 26 | // Test enumeration for edition 2023 implicit 27 | enum TestEnum { 28 | option (com.pseudomuto.protokit.v1.extend_enum) = true; 29 | 30 | UNKNOWN = 0; // Unknown value 31 | VALUE_A = 1 [(com.pseudomuto.protokit.v1.extend_enum_value) = true]; // First value 32 | VALUE_B = 2; // Second value 33 | } 34 | 35 | // Test message for edition 2023 with implicit field presence (like proto3) 36 | message TestMessage { 37 | option (com.pseudomuto.protokit.v1.extend_message) = true; 38 | 39 | int64 id = 1; // Message ID with implicit presence 40 | string name = 2 [(com.pseudomuto.protokit.v1.extend_field) = true]; // Message name 41 | TestEnum type = 3; // Message type 42 | google.protobuf.Timestamp created_at = 4; // Creation timestamp 43 | google.protobuf.Any metadata = 5; // Additional metadata 44 | } 45 | 46 | // Request message for testing 47 | message TestRequest { 48 | string query = 1; // Search query 49 | int32 limit = 2; // Result limit 50 | } 51 | 52 | // Response message for testing 53 | message TestResponse { 54 | repeated TestMessage results = 1; // Search results 55 | int32 total_count = 2; // Total result count 56 | } 
-------------------------------------------------------------------------------- /.golangci.yaml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | 3 | run: 4 | modules-download-mode: readonly 5 | 6 | linters: 7 | default: none 8 | enable: 9 | - asasalint 10 | - asciicheck 11 | - bidichk 12 | - bodyclose 13 | - contextcheck 14 | - depguard 15 | - durationcheck 16 | - errcheck 17 | - errchkjson 18 | - errorlint 19 | - exhaustive 20 | - funlen 21 | - gocheckcompilerdirectives 22 | - gochecksumtype 23 | - gocognit 24 | - gocyclo 25 | - gomoddirectives 26 | - gomodguard 27 | - gosec 28 | - gosmopolitan 29 | - govet 30 | - ineffassign 31 | - loggercheck 32 | - maintidx 33 | - makezero 34 | - musttag 35 | - nestif 36 | - nilerr 37 | - nilnesserr 38 | - noctx 39 | - perfsprint 40 | - prealloc 41 | - reassign 42 | - recvcheck 43 | - rowserrcheck 44 | - spancheck 45 | - sqlclosecheck 46 | - staticcheck 47 | - testifylint 48 | - unparam 49 | - unused 50 | - zerologlint 51 | 52 | settings: 53 | depguard: 54 | rules: 55 | main: 56 | list-mode: lax 57 | deny: 58 | - pkg: reflect 59 | funlen: 60 | ignore-comments: true 61 | lines: 100 62 | statements: 60 63 | 64 | gocognit: 65 | min-complexity: 50 66 | 67 | testifylint: 68 | disable: 69 | - go-require 70 | 71 | exclusions: 72 | generated: lax 73 | paths: 74 | - "examples/(.+).go$" 75 | presets: 76 | - comments 77 | - common-false-positives 78 | - legacy 79 | - std-error-handling 80 | rules: 81 | - linters: 82 | - gomoddirectives 83 | path: go.mod 84 | - linters: 85 | - funlen 86 | path: (.+)main.go 87 | - linters: 88 | - tagalign 89 | path: (.+).go 90 | - linters: 91 | - funlen 92 | path: (.+)_test.go 93 | - linters: 94 | - errcheck 95 | - gocognit 96 | path: examples/(.+).go 97 | 98 | formatters: 99 | enable: 100 | - gci 101 | - gofmt 102 | - gofumpt 103 | - goimports 104 | -------------------------------------------------------------------------------- /go.sum: 
-------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= 4 | github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= 5 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 6 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 7 | github.com/stretchr/testify v1.2.1 h1:52QO5WkIUcHGIR7EnGagH88x1bUzqGXTC5/1bDTUQ7U= 8 | github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 9 | golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= 10 | golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= 11 | golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= 12 | golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 13 | golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= 14 | golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 15 | golang.org/x/telemetry v0.0.0-20250710130107-8d8967aff50b h1:DU+gwOBXU+6bO0sEyO7o/NeMlxZxCZEvI7v+J4a1zRQ= 16 | golang.org/x/telemetry v0.0.0-20250710130107-8d8967aff50b/go.mod h1:4ZwOYna0/zsOKwuR5X/m0QFOJpSZvAxFfkQT+Erd9D4= 17 | golang.org/x/tools v0.35.1-0.20250728180453-01a3475a31bc h1:ZRKyKRJl/YEWl9ScZwd6Ua6xSt7DE6tHp1I3ucMroGM= 18 | golang.org/x/tools v0.35.1-0.20250728180453-01a3475a31bc/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= 19 | golang.org/x/tools/gopls v0.20.0 h1:fxOYZXKl6IsOTKIh6IgjDbIDHlr5btOtOUkrGOgFDB4= 20 | golang.org/x/tools/gopls v0.20.0/go.mod h1:vxYUZ8l4swjbvTQJJONmVfbHsd1ovixCwB7sodBbTYI= 21 | google.golang.org/genproto/googleapis/api 
v0.0.0-20241015192408-796eee8c2d53 h1:fVoAXEKA4+yufmbdVYv+SE73+cPZbbbe8paLsHfkK+U= 22 | google.golang.org/genproto/googleapis/api v0.0.0-20241015192408-796eee8c2d53/go.mod h1:riSXTwQ4+nqmPGtobMFyW5FqVAmIs0St6VPp4Ug7CE4= 23 | google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= 24 | google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= 25 | -------------------------------------------------------------------------------- /fixtures/edition2024.proto: -------------------------------------------------------------------------------- 1 | // Top-level comments are attached to the edition directive. 2 | edition = "2024"; 3 | 4 | import "google/protobuf/any.proto"; 5 | import "google/protobuf/duration.proto"; 6 | import "extend.proto"; 7 | option go_package = "edition2024"; 8 | 9 | // The official documentation for the Edition 2024 test. 10 | // 11 | // This file demonstrates edition 2024 syntax and features, 12 | // including enhanced symbol visibility controls. 13 | package com.pseudomuto.protokit.edition2024; 14 | 15 | option (com.pseudomuto.protokit.v1.extend_file) = true; 16 | 17 | // A service for testing edition 2024 features. 
18 | service Edition2024Service { 19 | option (com.pseudomuto.protokit.v1.extend_service) = true; 20 | 21 | // Test method for edition 2024 22 | rpc TestMethod(TestRequest) returns (TestResponse) { 23 | option (com.pseudomuto.protokit.v1.extend_method) = true; 24 | } 25 | 26 | // Another test method 27 | rpc AnotherMethod(TestRequest) returns (TestResponse); 28 | } 29 | 30 | // Test enumeration for edition 2024 31 | enum TestEnum { 32 | option (com.pseudomuto.protokit.v1.extend_enum) = true; 33 | 34 | UNKNOWN = 0; // Unknown value 35 | OPTION_X = 1 [(com.pseudomuto.protokit.v1.extend_enum_value) = true]; // First option 36 | OPTION_Y = 2; // Second option 37 | OPTION_Z = 3; // Third option 38 | } 39 | 40 | // Test message for edition 2024 41 | message TestMessage { 42 | option (com.pseudomuto.protokit.v1.extend_message) = true; 43 | 44 | int64 id = 1; // Message ID 45 | string title = 2 [(com.pseudomuto.protokit.v1.extend_field) = true]; // Message title 46 | TestEnum category = 3; // Message category 47 | google.protobuf.Duration timeout = 4; // Timeout duration 48 | google.protobuf.Any payload = 5; // Message payload 49 | 50 | // Nested message for testing 51 | message NestedData { 52 | string key = 1; // Data key 53 | bytes value = 2; // Data value 54 | } 55 | 56 | repeated NestedData data = 6; // Nested data entries 57 | } 58 | 59 | // Request message for testing 60 | message TestRequest { 61 | string filter = 1; // Filter criteria 62 | int32 page_size = 2; // Page size 63 | string page_token = 3; // Page token for pagination 64 | } 65 | 66 | // Response message for testing 67 | message TestResponse { 68 | repeated TestMessage items = 1; // Response items 69 | string next_page_token = 2; // Next page token 70 | int64 total_size = 3; // Total items available 71 | } -------------------------------------------------------------------------------- /context.go: -------------------------------------------------------------------------------- 1 | package protokit 2 | 
3 | import ( 4 | "context" 5 | ) 6 | 7 | type contextKey string // unexported key type so these keys cannot collide with context keys defined by other packages 8 | 9 | const ( 10 | fileContextKey = contextKey("file") 11 | descriptorContextKey = contextKey("descriptor") 12 | enumContextKey = contextKey("enum") 13 | serviceContextKey = contextKey("service") 14 | ) 15 | 16 | // ContextWithFileDescriptor returns a new context with the attached `FileDescriptor`. 17 | func ContextWithFileDescriptor(ctx context.Context, fd *FileDescriptor) context.Context { 18 | return context.WithValue(ctx, fileContextKey, fd) 19 | } 20 | 21 | // FileDescriptorFromContext returns the `FileDescriptor` from the context and whether or not the key was found. 22 | func FileDescriptorFromContext(ctx context.Context) (*FileDescriptor, bool) { 23 | val, ok := ctx.Value(fileContextKey).(*FileDescriptor) 24 | return val, ok 25 | } 26 | 27 | // ContextWithDescriptor returns a new context with the specified `Descriptor`. 28 | func ContextWithDescriptor(ctx context.Context, d *Descriptor) context.Context { 29 | return context.WithValue(ctx, descriptorContextKey, d) 30 | } 31 | 32 | // DescriptorFromContext returns the associated `Descriptor` for the context and whether or not it was found. 33 | func DescriptorFromContext(ctx context.Context) (*Descriptor, bool) { 34 | val, ok := ctx.Value(descriptorContextKey).(*Descriptor) 35 | return val, ok 36 | } 37 | 38 | // ContextWithEnumDescriptor returns a new context with the specified `EnumDescriptor`. 39 | func ContextWithEnumDescriptor(ctx context.Context, d *EnumDescriptor) context.Context { 40 | return context.WithValue(ctx, enumContextKey, d) 41 | } 42 | 43 | // EnumDescriptorFromContext returns the associated `EnumDescriptor` for the context and whether or not it was found. 44 | func EnumDescriptorFromContext(ctx context.Context) (*EnumDescriptor, bool) { 45 | val, ok := ctx.Value(enumContextKey).(*EnumDescriptor) 46 | return val, ok 47 | } 48 | 49 | // ContextWithServiceDescriptor returns a new context with the specified `ServiceDescriptor` (`service`). 50 | func
ContextWithServiceDescriptor(ctx context.Context, service *ServiceDescriptor) context.Context { 51 | return context.WithValue(ctx, serviceContextKey, service) 52 | } 53 | 54 | // ServiceDescriptorFromContext returns the `ServiceDescriptor` from the context and whether or not the key was found. 55 | func ServiceDescriptorFromContext(ctx context.Context) (*ServiceDescriptor, bool) { 56 | val, ok := ctx.Value(serviceContextKey).(*ServiceDescriptor) 57 | return val, ok 58 | } 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # protokit 2 | 3 | [![CI][github-svg]][github-ci] 4 | [![codecov][codecov-svg]][codecov-url] 5 | [![GoDoc][godoc-svg]][godoc-url] 6 | [![Go Report Card][goreport-svg]][goreport-url] 7 | 8 | A starter kit for building protoc-plugins. Rather than write your own, you can just use an existing one. 9 | 10 | See the [examples](examples/) directory for uh...examples. 11 | 12 | ## Getting Started 13 | 14 | ```golang 15 | package main 16 | 17 | import ( 18 | "google.golang.org/protobuf/proto" 19 | "google.golang.org/protobuf/types/pluginpb" 20 | "github.com/pseudomuto/protokit" 21 | _ "google.golang.org/genproto/googleapis/api/annotations" // Support (google.api.http) option (from google/api/annotations.proto). 22 | 23 | "log" 24 | ) 25 | 26 | func main() { 27 | // all the heavy lifting done for you!
28 | if err := protokit.RunPlugin(new(plugin)); err != nil { 29 | log.Fatal(err) 30 | } 31 | } 32 | 33 | // plugin is an implementation of protokit.Plugin 34 | type plugin struct{} 35 | 36 | func (p *plugin) Generate(req *pluginpb.CodeGeneratorRequest) (*pluginpb.CodeGeneratorResponse, error) { 37 | descriptors := protokit.ParseCodeGenRequest(req) 38 | 39 | resp := new(pluginpb.CodeGeneratorResponse) 40 | 41 | for _, d := range descriptors { 42 | // TODO: YOUR WORK HERE 43 | fileName := // generate a file name based on d.GetName() 44 | content := // generate content for the output file 45 | 46 | resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{ 47 | Name: proto.String(fileName), 48 | Content: proto.String(content), 49 | }) 50 | } 51 | 52 | return resp, nil 53 | } 54 | ``` 55 | 56 | Then invoke your plugin via `protoc`. For example (assuming your app is called `thingy`): 57 | 58 | `protoc --plugin=protoc-gen-thingy=./thingy -I. --thingy_out=. rpc/*.proto` 59 | 60 | [github-svg]: https://github.com/pseudomuto/protokit/actions/workflows/ci.yaml/badge.svg?branch=master 61 | [github-ci]: https://github.com/pseudomuto/protokit/actions/workflows/ci.yaml 62 | [codecov-svg]: https://codecov.io/gh/pseudomuto/protokit/branch/master/graph/badge.svg 63 | [codecov-url]: https://codecov.io/gh/pseudomuto/protokit 64 | [godoc-svg]: https://godoc.org/github.com/pseudomuto/protokit?status.svg 65 | [godoc-url]: https://godoc.org/github.com/pseudomuto/protokit 66 | [goreport-svg]: https://goreportcard.com/badge/github.com/pseudomuto/protokit 67 | [goreport-url]: https://goreportcard.com/report/github.com/pseudomuto/protokit 68 | -------------------------------------------------------------------------------- /taskfile.yaml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | tasks: 4 | update: 5 | desc: Updates all dependencies 6 | aliases: [up] 7 | silent: true 8 | cmds: 9 | - go mod tidy 10 | 11 | build: 12 |
desc: Build a local snapshot with goreleaser 13 | silent: true 14 | cmds: 15 | - "goreleaser release --snapshot --clean --skip=publish" 16 | 17 | generate: 18 | aliases: [gen] 19 | silent: true 20 | cmd: go generate ./... 21 | 22 | lint: 23 | desc: Run golangci-lint on the codebase 24 | silent: true 25 | cmd: "golangci-lint run {{.CLI_ARGS}}" 26 | 27 | lint:fix: 28 | desc: Run golangci-lint --fix on the codebase 29 | silent: true 30 | cmds: 31 | - go tool modernize -fix -test ./... 32 | - golangci-lint run --fix 33 | 34 | test: 35 | desc: Run the test suite (unit tests only) 36 | silent: true 37 | cmd: go test ./... -cover -short 38 | 39 | test:ci: 40 | desc: Run the test suite for CI with coverage profile 41 | silent: true 42 | cmd: go test -v -coverprofile=coverage.out ./... 43 | 44 | tag: 45 | desc: Create and push a new signed tag for release 46 | prompt: This will create, sign, and push tag {{.TAG}}. Continue? 47 | requires: 48 | vars: [TAG] 49 | preconditions: 50 | - sh: git diff --quiet 51 | msg: "Working directory must be clean" 52 | - sh: git diff --cached --quiet 53 | msg: "No staged changes allowed" 54 | - sh: '[[ "{{.TAG}}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]' 55 | msg: "TAG must be in format vX.Y.Z (e.g., v1.0.0)" 56 | cmds: 57 | - git tag -s {{.TAG}} -m "Release {{.TAG}}" 58 | - git push origin {{.TAG}} 59 | - echo "Signed tag {{.TAG}} created and pushed. GitHub Actions will now build and release." 60 | 61 | tag:patch: 62 | desc: Create and push a patch version tag (v0.0.X) 63 | cmds: 64 | - task: tag 65 | vars: 66 | TAG: 67 | sh: (git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0") | awk -F. '{print $1"."$2"."$3+1}' # fall back to v0.0.0 when no tag exists 68 | 69 | tag:minor: 70 | desc: Create and push a minor version tag (v0.X.0) 71 | cmds: 72 | - task: tag 73 | vars: 74 | TAG: 75 | sh: (git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0") | awk -F.
'{print $1"."$2+1".0"}' # fall back to v0.0.0 when no tag exists 76 | 77 | tag:major: 78 | desc: Create and push a major version tag (vX.0.0) 79 | cmds: 80 | - task: tag 81 | vars: 82 | TAG: 83 | sh: (git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0") | awk -F. '{sub(/^v/, "", $1); print "v" ($1 + 1) ".0.0"}' # strip the leading v so the major number is incremented numerically 84 | -------------------------------------------------------------------------------- /fixtures/booking.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | import "extend.proto"; 4 | 5 | /** 6 | * Booking related messages. 7 | * 8 | * This file is really just an example. The data model is completely fictional. 9 | */ 10 | package com.pseudomuto.protokit.v1; 11 | 12 | option (com.pseudomuto.protokit.v1.extend_file) = true; 13 | 14 | /** 15 | * Service for handling vehicle bookings. 16 | */ 17 | service BookingService { 18 | option (com.pseudomuto.protokit.v1.extend_service) = true; 19 | 20 | /// Used to book a vehicle. Pass in a Booking and a BookingStatus will be returned. 21 | rpc BookVehicle (Booking) returns (BookingStatus) { 22 | option (com.pseudomuto.protokit.v1.extend_method) = true; 23 | }; 24 | } 25 | 26 | /** 27 | * Represents the status of a vehicle booking. 28 | */ 29 | message BookingStatus { 30 | /** 31 | * A flag for the status result. 32 | */ 33 | enum StatusCode { 34 | OK = 200; // OK result. 35 | BAD_REQUEST = 400; // BAD result. 36 | } 37 | 38 | required int32 id = 1; /// Unique booking status ID. 39 | required string description = 2; /// Booking status description. E.g. "Active". 40 | optional StatusCode status_code = 3; /// The status of this status? 41 | 42 | extensions 100 to max; 43 | } 44 | 45 | // File-level extension 46 | extend BookingStatus { 47 | /* The country the booking occurred in. */ 48 | optional string country = 100 [default = "china", (com.pseudomuto.protokit.v1.extend_field) = true]; 49 | } 50 | 51 | /** 52 | * The type of booking.
53 | */ 54 | enum BookingType { 55 | option (com.pseudomuto.protokit.v1.extend_enum) = true; 56 | 57 | IMMEDIATE = 100; // Immediate booking. 58 | FUTURE = 101 [(com.pseudomuto.protokit.v1.extend_enum_value) = true]; // Future booking. 59 | } 60 | 61 | /** 62 | * Represents the booking of a vehicle. 63 | * 64 | * Vehicles are some cool shit. But drive carefully! 65 | */ 66 | message Booking { 67 | option (com.pseudomuto.protokit.v1.extend_message) = true; 68 | 69 | required int32 vehicle_id = 1; /// ID of booked vehicle. 70 | required int32 customer_id = 2; /// Customer that booked the vehicle. 71 | required BookingStatus status = 3; /// Status of the booking. 72 | 73 | /** Has booking confirmation been sent? */ 74 | required bool confirmation_sent = 4; 75 | 76 | /** Has payment been received? */ 77 | optional bool payment_received = 5 [default = true, (com.pseudomuto.protokit.v1.extend_field) = true]; 78 | 79 | oneof things { 80 | int32 reference_num = 6; // the numeric reference number 81 | string reference_tag = 7; // the reference tag (string) 82 | } 83 | 84 | // Nested extentions are also a thing. 85 | 86 | extend BookingStatus { 87 | optional string optional_field_1 = 101; // An optional field to be used however you please. 
88 | } 89 | } 90 | -------------------------------------------------------------------------------- /plugin_test.go: -------------------------------------------------------------------------------- 1 | package protokit_test 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "testing" 7 | 8 | "github.com/pseudomuto/protokit" 9 | "github.com/pseudomuto/protokit/utils" 10 | "github.com/stretchr/testify/require" 11 | "google.golang.org/protobuf/proto" 12 | pluginpb "google.golang.org/protobuf/types/pluginpb" 13 | ) 14 | 15 | func TestRunPlugin(t *testing.T) { 16 | t.Parallel() 17 | 18 | fds, err := utils.LoadDescriptorSet("fixtures", "fileset.pb") 19 | require.NoError(t, err) 20 | 21 | req := utils.CreateGenRequest(fds, "booking.proto", "todo.proto") 22 | data, err := proto.Marshal(req) 23 | require.NoError(t, err) 24 | 25 | in := bytes.NewBuffer(data) 26 | out := new(bytes.Buffer) 27 | 28 | require.NoError(t, protokit.RunPluginWithIO(new(OkPlugin), in, out)) 29 | require.NotEmpty(t, out) 30 | } 31 | 32 | func TestRunPluginInputError(t *testing.T) { 33 | t.Parallel() 34 | 35 | in := bytes.NewBufferString("Not a codegen request") 36 | out := new(bytes.Buffer) 37 | 38 | err := protokit.RunPluginWithIO(nil, in, out) 39 | require.Error(t, err) 40 | require.Contains(t, err.Error(), "proto:") 41 | require.Empty(t, out) 42 | } 43 | 44 | func TestRunPluginNoFilesToGenerate(t *testing.T) { 45 | t.Parallel() 46 | 47 | fds, err := utils.LoadDescriptorSet("fixtures", "fileset.pb") 48 | require.NoError(t, err) 49 | 50 | req := utils.CreateGenRequest(fds) 51 | data, err := proto.Marshal(req) 52 | require.NoError(t, err) 53 | 54 | in := bytes.NewBuffer(data) 55 | out := new(bytes.Buffer) 56 | 57 | err = protokit.RunPluginWithIO(new(ErrorPlugin), in, out) 58 | require.EqualError(t, err, "no files were supplied to the generator") 59 | require.Empty(t, out) 60 | } 61 | 62 | func TestRunPluginGeneratorError(t *testing.T) { 63 | t.Parallel() 64 | 65 | fds, err := 
utils.LoadDescriptorSet("fixtures", "fileset.pb") 66 | require.NoError(t, err) 67 | 68 | req := utils.CreateGenRequest(fds, "booking.proto", "todo.proto") 69 | data, err := proto.Marshal(req) 70 | require.NoError(t, err) 71 | 72 | in := bytes.NewBuffer(data) 73 | out := new(bytes.Buffer) 74 | 75 | err = protokit.RunPluginWithIO(new(ErrorPlugin), in, out) 76 | require.EqualError(t, err, "generator error") 77 | require.Empty(t, out) 78 | } 79 | 80 | type ErrorPlugin struct{} 81 | 82 | func (ep *ErrorPlugin) Generate(r *pluginpb.CodeGeneratorRequest) (*pluginpb.CodeGeneratorResponse, error) { 83 | return nil, errors.New("generator error") 84 | } 85 | 86 | type OkPlugin struct{} 87 | 88 | func (op *OkPlugin) Generate(r *pluginpb.CodeGeneratorRequest) (*pluginpb.CodeGeneratorResponse, error) { 89 | resp := new(pluginpb.CodeGeneratorResponse) 90 | resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{ 91 | Name: proto.String("myfile.out"), 92 | Content: proto.String("someoutput"), 93 | }) 94 | 95 | return resp, nil 96 | } 97 | -------------------------------------------------------------------------------- /comments.go: -------------------------------------------------------------------------------- 1 | package protokit 2 | 3 | import ( 4 | "bytes" 5 | "strconv" 6 | "strings" 7 | 8 | "google.golang.org/protobuf/types/descriptorpb" 9 | ) 10 | 11 | // A Comment describes the leading, trailing, and detached comments for a proto object. See `SourceCodeInfo_Location` in 12 | // descriptor.proto for details on what those terms mean 13 | type Comment struct { 14 | Leading string 15 | Trailing string 16 | Detached []string 17 | } 18 | 19 | // String returns the leading and trailing comments joined by 2 line breaks (`\n\n`). If either are empty, the line 20 | // breaks are removed. 
21 | func (c *Comment) String() string { 22 | b := new(bytes.Buffer) 23 | if c.GetLeading() != "" { 24 | b.WriteString(c.GetLeading()) 25 | b.WriteString("\n\n") 26 | } 27 | 28 | b.WriteString(c.GetTrailing()) 29 | 30 | return strings.TrimSpace(b.String()) // also drops the "\n\n" separator when Trailing is empty 31 | } 32 | 33 | func newComment(loc *descriptorpb.SourceCodeInfo_Location) *Comment { 34 | detached := make([]string, len(loc.GetLeadingDetachedComments())) 35 | for i, c := range loc.GetLeadingDetachedComments() { 36 | detached[i] = scrub(c) 37 | } 38 | 39 | return &Comment{ 40 | Leading: scrub(loc.GetLeadingComments()), 41 | Trailing: scrub(loc.GetTrailingComments()), 42 | Detached: detached, 43 | } 44 | } 45 | 46 | // GetLeading returns the leading comment. 47 | func (c *Comment) GetLeading() string { return c.Leading } 48 | 49 | // GetTrailing returns the trailing comment. 50 | func (c *Comment) GetTrailing() string { return c.Trailing } 51 | 52 | // GetDetached returns the detached leading comments. 53 | func (c *Comment) GetDetached() []string { return c.Detached } 54 | 55 | // Comments maps source location paths (path elements joined with ".", e.g. `4.2.3.0`) to their comments. 56 | type Comments map[string]*Comment 57 | 58 | // ParseComments parses all comments within a proto file. The locations are encoded into the map by joining the paths 59 | // with a "." character. E.g. `4.2.3.0`.
60 | // 61 | // Leading/trailing spaces are trimmed for each comment type (leading, trailing, detached). Locations that have no leading, trailing, or detached comments are not added to the map. 62 | func ParseComments(fd *descriptorpb.FileDescriptorProto) Comments { 63 | comments := make(Comments) 64 | 65 | for _, loc := range fd.GetSourceCodeInfo().GetLocation() { 66 | if loc.GetLeadingComments() == "" && loc.GetTrailingComments() == "" && len(loc.GetLeadingDetachedComments()) == 0 { 67 | continue 68 | } 69 | 70 | path := loc.GetPath() 71 | key := make([]string, len(path)) 72 | for idx, p := range path { 73 | key[idx] = strconv.Itoa(int(p)) 74 | } 75 | 76 | comments[strings.Join(key, ".")] = newComment(loc) 77 | } 78 | 79 | return comments 80 | } 81 | 82 | // Get returns the comment stored for path, or an empty, non-nil Comment when there is none. 83 | func (c Comments) Get(path string) *Comment { 84 | if val, ok := c[path]; ok { 85 | return val 86 | } 87 | 88 | // Not found: return an empty comment rather than nil. 89 | return &Comment{Detached: make([]string, 0)} 90 | } 91 | 92 | // scrub removes one space following each newline in str and trims surrounding whitespace. 93 | func scrub(str string) string { 94 | return strings.TrimSpace(strings.ReplaceAll(str, "\n ", "\n")) 95 | } 96 | -------------------------------------------------------------------------------- /utils/protobuf.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "errors" 5 | "os" 6 | "path/filepath" 7 | "slices" 8 | 9 | "google.golang.org/protobuf/proto" 10 | "google.golang.org/protobuf/types/descriptorpb" 11 | pluginpb "google.golang.org/protobuf/types/pluginpb" 12 | ) 13 | 14 | // CreateGenRequest creates a codegen request from a `FileDescriptorSet`. Every file in fds becomes a proto file of the request (the slice is shared with fds, not copied); only files whose names appear in filesToGen are marked for generation, in fds order. 15 | func CreateGenRequest(fds *descriptorpb.FileDescriptorSet, filesToGen ...string) *pluginpb.CodeGeneratorRequest { 16 | req := new(pluginpb.CodeGeneratorRequest) 17 | req.ProtoFile = fds.GetFile() 18 | 19 | for _, f := range req.GetProtoFile() { 20 | if slices.Contains(filesToGen, f.GetName()) { 21 | req.FileToGenerate = append(req.FileToGenerate, f.GetName()) 22 | } 23 | } 24 | 25 | return req 26 | } 27 | 28 | // FilesToGenerate iterates through the proto files in the request
and returns only the ones that were requested on the 29 | command line. Only these protos should be generated by a codegen plugin. Results follow the order of FileToGenerate; requested names with no matching proto file are skipped. 30 | func FilesToGenerate(req *pluginpb.CodeGeneratorRequest) []*descriptorpb.FileDescriptorProto { 31 | protos := make([]*descriptorpb.FileDescriptorProto, 0) 32 | 33 | OUTERLOOP: 34 | for _, name := range req.GetFileToGenerate() { 35 | for _, f := range req.GetProtoFile() { 36 | if f.GetName() == name { 37 | protos = append(protos, f) 38 | continue OUTERLOOP // first match wins; move on to the next requested name 39 | } 40 | } 41 | } 42 | 43 | return protos 44 | } 45 | 46 | // LoadDescriptorSet loads a `FileDescriptorSet` from a file on disk. Such a file can be generated using the 47 | // `--descriptor_set_out` flag with `protoc`. Read and unmarshal errors are returned unwrapped. 48 | // 49 | // Example: 50 | // 51 | // protoc --descriptor_set_out=fileset.pb --include_imports --include_source_info ./booking.proto ./todo.proto 52 | func LoadDescriptorSet(pathSegments ...string) (*descriptorpb.FileDescriptorSet, error) { 53 | f, err := os.ReadFile(filepath.Join(pathSegments...)) 54 | if err != nil { 55 | return nil, err 56 | } 57 | 58 | set := new(descriptorpb.FileDescriptorSet) 59 | if err = proto.Unmarshal(f, set); err != nil { 60 | return nil, err 61 | } 62 | 63 | return set, nil 64 | } 65 | 66 | // FindDescriptor finds the named descriptor in the given set. Only base names are searched. The first match is 67 | // returned, or `nil` if not found. 68 | func FindDescriptor(set *descriptorpb.FileDescriptorSet, name string) *descriptorpb.FileDescriptorProto { 69 | for _, pf := range set.GetFile() { 70 | if filepath.Base(pf.GetName()) == name { 71 | return pf 72 | } 73 | } 74 | 75 | return nil 76 | } 77 | 78 | // LoadDescriptor loads file descriptor protos from a file on disk, and returns the named proto descriptor. This is 79 | // useful mostly for testing purposes. An error is returned if the set cannot be loaded or no file has the base name `name`. 80 | func LoadDescriptor(name string, pathSegments ...string) (*descriptorpb.FileDescriptorProto, error) { 81 | set, err := LoadDescriptorSet(pathSegments...)
82 | if err != nil { 83 | return nil, err 84 | } 85 | 86 | if pf := FindDescriptor(set, name); pf != nil { 87 | return pf, nil 88 | } 89 | 90 | return nil, errors.New("FileDescriptor not found") 91 | } 92 | -------------------------------------------------------------------------------- /examples/jsonator/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "log" 8 | 9 | "google.golang.org/protobuf/proto" 10 | pluginpb "google.golang.org/protobuf/types/pluginpb" 11 | "github.com/pseudomuto/protokit" 12 | "google.golang.org/genproto/googleapis/api/annotations" 13 | ) 14 | 15 | func main() { 16 | if err := protokit.RunPlugin(new(plugin)); err != nil { 17 | log.Fatal(err) 18 | } 19 | } 20 | 21 | type plugin struct{} 22 | 23 | func (p *plugin) Generate(req *pluginpb.CodeGeneratorRequest) (*pluginpb.CodeGeneratorResponse, error) { 24 | descriptors := protokit.ParseCodeGenRequest(req) 25 | files := make([]*file, len(descriptors)) 26 | 27 | for i, d := range descriptors { 28 | files[i] = newFile(d) 29 | } 30 | 31 | buf := new(bytes.Buffer) 32 | enc := json.NewEncoder(buf) 33 | enc.SetIndent("", " ") 34 | 35 | if err := enc.Encode(files); err != nil { 36 | return nil, err 37 | } 38 | 39 | resp := new(pluginpb.CodeGeneratorResponse) 40 | resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{ 41 | Name: proto.String("output.json"), 42 | Content: proto.String(buf.String()), 43 | }) 44 | 45 | return resp, nil 46 | } 47 | 48 | type file struct { 49 | Name string `json:"name"` 50 | Description string `json:"description"` 51 | Services []*service `json:"services"` 52 | } 53 | 54 | func newFile(fd *protokit.FileDescriptor) *file { 55 | svcs := make([]*service, len(fd.GetServices())) 56 | for i, sd := range fd.GetServices() { 57 | svcs[i] = newService(sd) 58 | } 59 | 60 | return &file{ 61 | Name: fmt.Sprintf("%s.%s", fd.GetPackage(), fd.GetName()), 62 | 
Description: fd.GetPackageComments().String(), 63 | Services: svcs, 64 | } 65 | } 66 | 67 | type service struct { 68 | Name string `json:"name"` 69 | Methods []*method `json:"methods"` 70 | } 71 | 72 | func newService(sd *protokit.ServiceDescriptor) *service { 73 | methods := make([]*method, len(sd.GetMethods())) 74 | for i, md := range sd.GetMethods() { 75 | methods[i] = newMethod(md) 76 | } 77 | 78 | return &service{Name: sd.GetName(), Methods: methods} 79 | } 80 | 81 | type method struct { 82 | Name string `json:"name"` 83 | HTTPRules []string `json:"http_rules"` 84 | } 85 | 86 | func newMethod(md *protokit.MethodDescriptor) *method { 87 | httpRules := make([]string, 0) 88 | if httpRule, ok := md.OptionExtensions["google.api.http"].(*annotations.HttpRule); ok { 89 | switch httpRule.GetPattern().(type) { 90 | case *annotations.HttpRule_Get: 91 | httpRules = append(httpRules, fmt.Sprintf("GET %s", httpRule.GetGet())) 92 | case *annotations.HttpRule_Put: 93 | httpRules = append(httpRules, fmt.Sprintf("PUT %s", httpRule.GetPut())) 94 | case *annotations.HttpRule_Post: 95 | httpRules = append(httpRules, fmt.Sprintf("POST %s", httpRule.GetPost())) 96 | case *annotations.HttpRule_Delete: 97 | httpRules = append(httpRules, fmt.Sprintf("DELETE %s", httpRule.GetDelete())) 98 | case *annotations.HttpRule_Patch: 99 | httpRules = append(httpRules, fmt.Sprintf("PATCH %s", httpRule.GetPatch())) 100 | } 101 | // Append more for each rule in httpRule.AdditionalBindings... 
102 | } 103 | 104 | return &method{Name: md.GetName(), HTTPRules: httpRules} 105 | } 106 | -------------------------------------------------------------------------------- /utils/protobuf_test.go: -------------------------------------------------------------------------------- 1 | package utils_test 2 | 3 | import ( 4 | "slices" 5 | "testing" 6 | 7 | "github.com/pseudomuto/protokit/utils" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestCreateGenRequest(t *testing.T) { 12 | t.Parallel() 13 | 14 | fds, err := utils.LoadDescriptorSet("..", "fixtures", "fileset.pb") 15 | require.NoError(t, err) 16 | 17 | req := utils.CreateGenRequest(fds, "booking.proto", "todo.proto") 18 | require.Equal(t, []string{"booking.proto", "todo.proto"}, req.GetFileToGenerate()) 19 | 20 | expectedProtos := []string{ 21 | "booking.proto", 22 | "google/protobuf/any.proto", 23 | "google/protobuf/descriptor.proto", 24 | "google/protobuf/timestamp.proto", 25 | "google/protobuf/duration.proto", 26 | "extend.proto", 27 | "todo.proto", 28 | "todo_import.proto", 29 | "edition2023.proto", 30 | "edition2024.proto", 31 | "edition2023_implicit.proto", 32 | } 33 | 34 | for _, pf := range req.GetProtoFile() { 35 | require.True(t, slices.Contains(expectedProtos, pf.GetName()), "Unexpected proto file: %s", pf.GetName()) 36 | } 37 | } 38 | 39 | func TestFilesToGenerate(t *testing.T) { 40 | t.Parallel() 41 | 42 | fds, err := utils.LoadDescriptorSet("..", "fixtures", "fileset.pb") 43 | require.NoError(t, err) 44 | 45 | req := utils.CreateGenRequest(fds, "booking.proto") 46 | protos := utils.FilesToGenerate(req) 47 | require.Len(t, protos, 1) 48 | require.Equal(t, "booking.proto", protos[0].GetName()) 49 | } 50 | 51 | func TestLoadDescriptorSet(t *testing.T) { 52 | t.Parallel() 53 | 54 | set, err := utils.LoadDescriptorSet("..", "fixtures", "fileset.pb") 55 | require.NoError(t, err) 56 | require.Len(t, set.GetFile(), 11) 57 | 58 | require.NotNil(t, utils.FindDescriptor(set, "todo.proto")) 59 | 
require.Nil(t, utils.FindDescriptor(set, "whodis.proto")) 60 | } 61 | 62 | func TestLoadDescriptorSetFileNotFound(t *testing.T) { 63 | t.Parallel() 64 | 65 | set, err := utils.LoadDescriptorSet("..", "fixtures", "notgonnadoit.pb") 66 | require.Nil(t, set) 67 | require.EqualError(t, err, "open ../fixtures/notgonnadoit.pb: no such file or directory") 68 | } 69 | 70 | func TestLoadDescriptorSetMarshalError(t *testing.T) { 71 | t.Parallel() 72 | 73 | set, err := utils.LoadDescriptorSet("..", "fixtures", "todo.proto") 74 | require.Nil(t, set) 75 | require.Error(t, err) 76 | require.Contains(t, err.Error(), "proto:") 77 | } 78 | 79 | func TestLoadDescriptor(t *testing.T) { 80 | t.Parallel() 81 | 82 | proto, err := utils.LoadDescriptor("todo.proto", "..", "fixtures", "fileset.pb") 83 | require.NotNil(t, proto) 84 | require.NoError(t, err) 85 | } 86 | 87 | func TestLoadDescriptorFileNotFound(t *testing.T) { 88 | t.Parallel() 89 | 90 | proto, err := utils.LoadDescriptor("todo.proto", "..", "fixtures", "notgonnadoit.pb") 91 | require.Nil(t, proto) 92 | require.EqualError(t, err, "open ../fixtures/notgonnadoit.pb: no such file or directory") 93 | } 94 | 95 | func TestLoadDescriptorMarshalError(t *testing.T) { 96 | t.Parallel() 97 | 98 | proto, err := utils.LoadDescriptor("todo.proto", "..", "fixtures", "todo.proto") 99 | require.Nil(t, proto) 100 | require.Error(t, err) 101 | require.Contains(t, err.Error(), "proto:") 102 | } 103 | 104 | func TestLoadDescriptorDescriptorNotFound(t *testing.T) { 105 | t.Parallel() 106 | 107 | proto, err := utils.LoadDescriptor("nothere.proto", "..", "fixtures", "fileset.pb") 108 | require.Nil(t, proto) 109 | require.EqualError(t, err, "FileDescriptor not found") 110 | } 111 | -------------------------------------------------------------------------------- /fixtures/todo.proto: -------------------------------------------------------------------------------- 1 | // Top-level comments are attached to the syntax directive. 
2 | syntax = "proto3"; 3 | 4 | import "google/protobuf/any.proto"; 5 | import "google/protobuf/timestamp.proto"; 6 | import "extend.proto"; 7 | import public "todo_import.proto"; 8 | option go_package = "todo"; 9 | 10 | // The official documentation for the Todo API. 11 | // 12 | // Some parts of this file are unnecessarily complicated. In order to have a test for nested messages, enums, etc. I've 13 | // added some odd looking implementation details. So you know, don't use this in real life for a todo service. 14 | // 15 | // The get started run the following: 16 | // 17 | // * `make setup` 18 | // * `make test` 19 | package com.pseudomuto.protokit.v1; 20 | 21 | option (com.pseudomuto.protokit.v1.extend_file) = true; 22 | 23 | // A service for managing "todo" items. 24 | // 25 | // Add, complete, and remove your items on your todo lists. 26 | service Todo { 27 | option (com.pseudomuto.protokit.v1.extend_service) = true; 28 | 29 | // Create a new todo list 30 | rpc CreateList(CreateListRequest) returns (CreateListResponse) { 31 | option (com.pseudomuto.protokit.v1.extend_method) = true; 32 | } 33 | 34 | // Add an item to your list 35 | // 36 | // Adds a new item to the specified list. 37 | rpc AddItem(AddItemRequest) returns (AddItemResponse); 38 | } 39 | 40 | // An enumeration of list types 41 | enum ListType { 42 | option (com.pseudomuto.protokit.v1.extend_enum) = true; 43 | 44 | REMINDERS = 0; // The reminders type. 45 | CHECKLIST = 1 [(com.pseudomuto.protokit.v1.extend_enum_value) = true]; // The checklist type. 46 | } 47 | 48 | // A list object. 49 | message List { 50 | option (com.pseudomuto.protokit.v1.extend_message) = true; 51 | 52 | int64 id = 1; // The id of the list. 53 | string name = 2 [(com.pseudomuto.protokit.v1.extend_field) = true]; // The name of the list. 54 | ListType type = 3; // The type of list 55 | google.protobuf.Timestamp created_at = 4; // The timestamp for creation. 56 | google.protobuf.Any details = 5; // Some arbitrary list details.
57 | } 58 | 59 | // A request object for creating todo lists. 60 | message CreateListRequest { 61 | // The name of the list. 62 | string name = 1; 63 | } 64 | 65 | // A successfully created list response. 66 | message CreateListResponse { 67 | // An internal status message 68 | message Status { 69 | sint32 code = 1; // The status code. 70 | } 71 | 72 | List list = 1; // The list that was created. 73 | Status status = 2; // The status for the response. 74 | } 75 | 76 | // A list item 77 | message Item { 78 | // An enumeration of possible statuses 79 | enum Status { 80 | PENDING = 0; // The pending status. 81 | COMPLETED = 1; // The completed status. 82 | } 83 | 84 | int64 id = 1; // The id of the item. 85 | string title = 2; // The title of the item. 86 | Status completed = 3; // The current status of the item. 87 | google.protobuf.Timestamp created_at = 4; // The timestamp for creation. 88 | ListItemDetails details = 5; // Item details. 89 | } 90 | 91 | // A request message for adding new items. 92 | message AddItemRequest { 93 | int64 list_id = 1; // The id of the list to add to. 94 | string title = 2; // The title of the item. 95 | bool completed = 3; // Whether or not the item is completed. 96 | } 97 | 98 | // A successfully added item response. 99 | message AddItemResponse { 100 | Item item = 1; // The list item that was added.
101 | } 102 | -------------------------------------------------------------------------------- /parser.go: -------------------------------------------------------------------------------- 1 | package protokit 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strconv" 7 | "strings" 8 | 9 | "google.golang.org/protobuf/types/descriptorpb" 10 | pluginpb "google.golang.org/protobuf/types/pluginpb" 11 | ) 12 | 13 | const ( 14 | // tag numbers in FileDescriptorProto 15 | packageCommentPath = 2 // package 16 | messageCommentPath = 4 // message_type 17 | enumCommentPath = 5 // enum_type 18 | serviceCommentPath = 6 // service 19 | extensionCommentPath = 7 // extension 20 | syntaxCommentPath = 12 // syntax 21 | editionCommentPath = 14 // edition 22 | 23 | // tag numbers in DescriptorProto 24 | messageFieldCommentPath = 2 // field 25 | messageMessageCommentPath = 3 // nested_type 26 | messageEnumCommentPath = 4 // enum_type 27 | messageExtensionCommentPath = 6 // extension 28 | 29 | // tag numbers in EnumDescriptorProto 30 | enumValueCommentPath = 2 // value 31 | 32 | // tag numbers in ServiceDescriptorProto 33 | serviceMethodCommentPath = 2 // method 34 | ) 35 | 36 | // ParseCodeGenRequest parses the given request into `FileDescriptor` objects. Only the files named in `req.FileToGenerate` 37 | // are returned, in the same order. 38 | // 39 | // For example, given the following invocation, only booking.proto will be returned even if it imports other protos: 40 | // 41 | // protoc --plugin=protoc-gen-test=./test -I.
protos/booking.proto 42 | func ParseCodeGenRequest(req *pluginpb.CodeGeneratorRequest) []*FileDescriptor { 43 | allFiles := make(map[string]*FileDescriptor) 44 | genFiles := make([]*FileDescriptor, len(req.GetFileToGenerate())) 45 | 46 | for _, pf := range req.GetProtoFile() { 47 | allFiles[pf.GetName()] = parseFile(context.Background(), pf) 48 | } 49 | 50 | for i, f := range req.GetFileToGenerate() { 51 | genFiles[i] = allFiles[f] // nil (and parseImports panics) if f is absent from req.ProtoFile; plugin.proto requires it be present 52 | parseImports(genFiles[i], allFiles) 53 | } 54 | 55 | return genFiles 56 | } 57 | 58 | func parseFile(ctx context.Context, fd *descriptorpb.FileDescriptorProto) *FileDescriptor { 59 | comments := ParseComments(fd) 60 | 61 | file := &FileDescriptor{ 62 | comments: comments, 63 | FileDescriptorProto: fd, 64 | PackageComments: comments.Get(strconv.Itoa(packageCommentPath)), 65 | SyntaxComments: comments.Get(strconv.Itoa(syntaxCommentPath)), 66 | EditionComments: comments.Get(strconv.Itoa(editionCommentPath)), 67 | } 68 | 69 | if fd.Options != nil { 70 | file.setOptions(fd.Options) 71 | } 72 | 73 | fileCtx := ContextWithFileDescriptor(ctx, file) 74 | file.Enums = parseEnums(fileCtx, fd.GetEnumType()) 75 | file.Extensions = parseExtensions(fileCtx, fd.GetExtension()) 76 | file.Messages = parseMessages(fileCtx, fd.GetMessageType()) 77 | file.Services = parseServices(fileCtx, fd.GetService()) 78 | 79 | return file 80 | } 81 | 82 | func parseEnums(ctx context.Context, protos []*descriptorpb.EnumDescriptorProto) []*EnumDescriptor { 83 | enums := make([]*EnumDescriptor, len(protos)) 84 | file, _ := FileDescriptorFromContext(ctx) 85 | parent, hasParent := DescriptorFromContext(ctx) 86 | 87 | for i, ed := range protos { 88 | longName := ed.GetName() 89 | commentPath := fmt.Sprintf("%d.%d", enumCommentPath, i) 90 | 91 | if hasParent { 92 | longName = fmt.Sprintf("%s.%s", parent.GetLongName(), longName) 93 | commentPath = fmt.Sprintf("%s.%d.%d", parent.path, messageEnumCommentPath, i) 94 | } 95 | 96 | enums[i] = &EnumDescriptor{ 97 | common:
newCommon(file, commentPath, longName), 98 | EnumDescriptorProto: ed, 99 | Comments: file.comments.Get(commentPath), 100 | Parent: parent, 101 | } 102 | if ed.Options != nil { 103 | enums[i].setOptions(ed.Options) 104 | } 105 | 106 | subCtx := ContextWithEnumDescriptor(ctx, enums[i]) 107 | enums[i].Values = parseEnumValues(subCtx, ed.GetValue()) 108 | } 109 | 110 | return enums 111 | } 112 | 113 | func parseEnumValues(ctx context.Context, protos []*descriptorpb.EnumValueDescriptorProto) []*EnumValueDescriptor { 114 | values := make([]*EnumValueDescriptor, len(protos)) 115 | file, _ := FileDescriptorFromContext(ctx) 116 | enum, _ := EnumDescriptorFromContext(ctx) 117 | 118 | for i, vd := range protos { 119 | longName := fmt.Sprintf("%s.%s", enum.GetLongName(), vd.GetName()) 120 | 121 | values[i] = &EnumValueDescriptor{ 122 | common: newCommon(file, "", longName), 123 | EnumValueDescriptorProto: vd, 124 | Enum: enum, 125 | Comments: file.comments.Get(fmt.Sprintf("%s.%d.%d", enum.path, enumValueCommentPath, i)), 126 | } 127 | if vd.Options != nil { 128 | values[i].setOptions(vd.Options) 129 | } 130 | } 131 | 132 | return values 133 | } 134 | 135 | func parseExtensions(ctx context.Context, protos []*descriptorpb.FieldDescriptorProto) []*ExtensionDescriptor { 136 | exts := make([]*ExtensionDescriptor, len(protos)) 137 | file, _ := FileDescriptorFromContext(ctx) 138 | parent, hasParent := DescriptorFromContext(ctx) 139 | 140 | for i, ext := range protos { 141 | commentPath := fmt.Sprintf("%d.%d", extensionCommentPath, i) 142 | longName := fmt.Sprintf("%s.%s", ext.GetExtendee(), ext.GetName()) 143 | 144 | // Extendees in this file's package are named relative to it. Only an exact package prefix is stripped, so nested types keep their path (Outer.Inner). 145 | if rel, ok := strings.CutPrefix(strings.TrimPrefix(ext.GetExtendee(), "."), file.GetPackage()+"."); ok || file.GetPackage() == "" { 146 | longName = fmt.Sprintf("%s.%s", rel, ext.GetName()) 147 | } 148 | 149 | if hasParent { 150 | commentPath = fmt.Sprintf("%s.%d.%d", parent.path, messageExtensionCommentPath, i) 151 | } 152 | 153 | exts[i] = &ExtensionDescriptor{ 154 | common:
newCommon(file, commentPath, longName), 155 | FieldDescriptorProto: ext, 156 | Comments: file.comments.Get(commentPath), 157 | Parent: parent, 158 | } 159 | if ext.Options != nil { 160 | exts[i].setOptions(ext.Options) 161 | } 162 | } 163 | 164 | return exts 165 | } 166 | 167 | func parseImports(fd *FileDescriptor, allFiles map[string]*FileDescriptor) { 168 | fd.Imports = make([]*ImportedDescriptor, 0) 169 | 170 | for _, index := range fd.GetPublicDependency() { 171 | file := allFiles[fd.GetDependency()[index]] 172 | 173 | for _, d := range file.GetMessages() { 174 | // skip map entry objects 175 | if !d.GetOptions().GetMapEntry() { 176 | fd.Imports = append(fd.Imports, &ImportedDescriptor{d.common}) 177 | } 178 | } 179 | 180 | for _, e := range file.GetEnums() { 181 | fd.Imports = append(fd.Imports, &ImportedDescriptor{e.common}) 182 | } 183 | 184 | for _, ext := range file.GetExtensions() { 185 | fd.Imports = append(fd.Imports, &ImportedDescriptor{ext.common}) 186 | } 187 | } 188 | } 189 | 190 | func parseMessages(ctx context.Context, protos []*descriptorpb.DescriptorProto) []*Descriptor { 191 | msgs := make([]*Descriptor, len(protos)) 192 | file, _ := FileDescriptorFromContext(ctx) 193 | parent, hasParent := DescriptorFromContext(ctx) 194 | 195 | for i, md := range protos { 196 | longName := md.GetName() 197 | commentPath := fmt.Sprintf("%d.%d", messageCommentPath, i) 198 | 199 | if hasParent { 200 | longName = fmt.Sprintf("%s.%s", parent.GetLongName(), longName) 201 | commentPath = fmt.Sprintf("%s.%d.%d", parent.path, messageMessageCommentPath, i) 202 | } 203 | 204 | msgs[i] = &Descriptor{ 205 | common: newCommon(file, commentPath, longName), 206 | DescriptorProto: md, 207 | Comments: file.comments.Get(commentPath), 208 | Parent: parent, 209 | } 210 | if md.Options != nil { 211 | msgs[i].setOptions(md.Options) 212 | } 213 | 214 | msgCtx := ContextWithDescriptor(ctx, msgs[i]) 215 | msgs[i].Enums = parseEnums(msgCtx, md.GetEnumType()) 216 | msgs[i].Extensions =
parseExtensions(msgCtx, md.GetExtension()) 217 | msgs[i].Fields = parseMessageFields(msgCtx, md.GetField()) 218 | msgs[i].Messages = parseMessages(msgCtx, md.GetNestedType()) 219 | } 220 | 221 | return msgs 222 | } 223 | 224 | func parseMessageFields(ctx context.Context, protos []*descriptorpb.FieldDescriptorProto) []*FieldDescriptor { 225 | fields := make([]*FieldDescriptor, len(protos)) 226 | file, _ := FileDescriptorFromContext(ctx) 227 | message, _ := DescriptorFromContext(ctx) 228 | 229 | for i, fd := range protos { 230 | longName := fmt.Sprintf("%s.%s", message.GetLongName(), fd.GetName()) 231 | 232 | fields[i] = &FieldDescriptor{ 233 | common: newCommon(file, "", longName), 234 | FieldDescriptorProto: fd, 235 | Comments: file.comments.Get(fmt.Sprintf("%s.%d.%d", message.path, messageFieldCommentPath, i)), 236 | Message: message, 237 | } 238 | if fd.Options != nil { 239 | fields[i].setOptions(fd.Options) 240 | } 241 | } 242 | 243 | return fields 244 | } 245 | 246 | func parseServices(ctx context.Context, protos []*descriptorpb.ServiceDescriptorProto) []*ServiceDescriptor { 247 | svcs := make([]*ServiceDescriptor, len(protos)) 248 | file, _ := FileDescriptorFromContext(ctx) 249 | 250 | for i, sd := range protos { 251 | longName := sd.GetName() 252 | commentPath := fmt.Sprintf("%d.%d", serviceCommentPath, i) 253 | 254 | svcs[i] = &ServiceDescriptor{ 255 | common: newCommon(file, commentPath, longName), 256 | ServiceDescriptorProto: sd, 257 | Comments: file.comments.Get(commentPath), 258 | } 259 | if sd.Options != nil { 260 | svcs[i].setOptions(sd.Options) 261 | } 262 | 263 | svcCtx := ContextWithServiceDescriptor(ctx, svcs[i]) 264 | svcs[i].Methods = parseServiceMethods(svcCtx, sd.GetMethod()) 265 | } 266 | 267 | return svcs 268 | } 269 | 270 | func parseServiceMethods(ctx context.Context, protos []*descriptorpb.MethodDescriptorProto) []*MethodDescriptor { 271 | methods := make([]*MethodDescriptor, len(protos)) 272 | 273 | file, _ :=
FileDescriptorFromContext(ctx) 274 | svc, _ := ServiceDescriptorFromContext(ctx) 275 | 276 | for i, md := range protos { 277 | longName := fmt.Sprintf("%s.%s", svc.GetLongName(), md.GetName()) 278 | 279 | methods[i] = &MethodDescriptor{ 280 | common: newCommon(file, "", longName), 281 | MethodDescriptorProto: md, 282 | Service: svc, 283 | Comments: file.comments.Get(fmt.Sprintf("%s.%d.%d", svc.path, serviceMethodCommentPath, i)), 284 | } 285 | if md.Options != nil { 286 | methods[i].setOptions(md.Options) 287 | } 288 | } 289 | 290 | return methods 291 | } 292 | -------------------------------------------------------------------------------- /parser_test.go: -------------------------------------------------------------------------------- 1 | package protokit_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/pseudomuto/protokit" 7 | "github.com/pseudomuto/protokit/utils" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func setupParserTest(t *testing.T) (*protokit.FileDescriptor, *protokit.FileDescriptor) { 12 | t.Helper() 13 | 14 | set, err := utils.LoadDescriptorSet("fixtures", "fileset.pb") 15 | require.NoError(t, err) 16 | 17 | req := utils.CreateGenRequest(set, "booking.proto", "todo.proto") 18 | files := protokit.ParseCodeGenRequest(req) 19 | proto2 := files[0] 20 | proto3 := files[1] 21 | 22 | return proto2, proto3 23 | } 24 | 25 | func TestFileParsing(t *testing.T) { 26 | t.Parallel() 27 | 28 | proto2, proto3 := setupParserTest(t) 29 | 30 | require.True(t, proto3.IsProto3()) 31 | require.Equal(t, "Top-level comments are attached to the syntax directive.", proto3.GetSyntaxComments().String()) 32 | require.Contains(t, proto3.GetPackageComments().String(), "The official documentation for the Todo API.\n\n") 33 | require.Empty(t, proto3.GetExtensions()) // no extensions in proto3 34 | 35 | require.False(t, proto2.IsProto3()) 36 | require.Len(t, proto2.GetExtensions(), 1) 37 | } 38 | 39 | func TestFileImports(t *testing.T) { 40 | t.Parallel() 41 |
42 | _, proto3 := setupParserTest(t) 43 | 44 | require.Len(t, proto3.GetImports(), 2) 45 | 46 | imp := proto3.GetImports()[0] 47 | require.NotNil(t, imp.GetFile()) 48 | require.Equal(t, "ListItemDetails", imp.GetLongName()) 49 | require.Equal(t, "com.pseudomuto.protokit.v1.ListItemDetails", imp.GetFullName()) 50 | } 51 | 52 | func TestFileEnums(t *testing.T) { 53 | t.Parallel() 54 | 55 | _, proto3 := setupParserTest(t) 56 | 57 | require.Len(t, proto3.GetEnums(), 1) 58 | require.Nil(t, proto3.GetEnum("swingandamiss")) 59 | 60 | enum := proto3.GetEnum("ListType") 61 | require.Equal(t, "ListType", enum.GetLongName()) 62 | require.Equal(t, "com.pseudomuto.protokit.v1.ListType", enum.GetFullName()) 63 | require.True(t, enum.IsProto3()) 64 | require.Nil(t, enum.GetParent()) 65 | require.NotNil(t, enum.GetFile()) 66 | require.Equal(t, "An enumeration of list types", enum.GetComments().String()) 67 | require.Equal(t, "com.pseudomuto.protokit.v1", enum.GetPackage()) 68 | require.Len(t, enum.GetValues(), 2) 69 | 70 | require.Equal(t, "REMINDERS", enum.GetValues()[0].GetName()) 71 | require.Equal(t, enum, enum.GetValues()[0].GetEnum()) 72 | require.Equal(t, "The reminders type.", enum.GetNamedValue("REMINDERS").GetComments().String()) 73 | 74 | require.Nil(t, enum.GetNamedValue("whodis")) 75 | } 76 | 77 | func TestFileExtensions(t *testing.T) { 78 | t.Parallel() 79 | 80 | proto2, _ := setupParserTest(t) 81 | 82 | ext := proto2.GetExtensions()[0] 83 | require.Nil(t, ext.GetParent()) 84 | require.Equal(t, "country", ext.GetName()) 85 | require.Equal(t, "BookingStatus.country", ext.GetLongName()) 86 | require.Equal(t, "com.pseudomuto.protokit.v1.BookingStatus.country", ext.GetFullName()) 87 | require.Equal(t, "The country the booking occurred in.", ext.GetComments().String()) 88 | } 89 | 90 | func TestServices(t *testing.T) { 91 | t.Parallel() 92 | 93 | _, proto3 := setupParserTest(t) 94 | 95 | require.Len(t, proto3.GetServices(), 1) 96 | require.Nil(t,
proto3.GetService("swingandamiss")) 97 | 98 | svc := proto3.GetService("Todo") 99 | require.Equal(t, "Todo", svc.GetLongName()) 100 | require.Equal(t, "com.pseudomuto.protokit.v1.Todo", svc.GetFullName()) 101 | require.NotNil(t, svc.GetFile()) 102 | require.True(t, svc.IsProto3()) 103 | require.Contains(t, svc.GetComments().String(), "A service for managing \"todo\" items.\n\n") 104 | require.Equal(t, "com.pseudomuto.protokit.v1", svc.GetPackage()) 105 | require.Len(t, svc.GetMethods(), 2) 106 | 107 | m := svc.GetNamedMethod("CreateList") 108 | require.Equal(t, "CreateList", m.GetName()) 109 | require.Equal(t, "Todo.CreateList", m.GetLongName()) 110 | require.Equal(t, "com.pseudomuto.protokit.v1.Todo.CreateList", m.GetFullName()) 111 | require.NotNil(t, m.GetFile()) 112 | require.Equal(t, svc, m.GetService()) 113 | require.Equal(t, "Create a new todo list", m.GetComments().String()) 114 | 115 | m = svc.GetNamedMethod("Todo.AddItem") 116 | require.Equal(t, "Add an item to your list\n\nAdds a new item to the specified list.", m.GetComments().String()) 117 | 118 | require.Nil(t, svc.GetNamedMethod("wat")) 119 | } 120 | 121 | func TestFileMessages(t *testing.T) { 122 | t.Parallel() 123 | 124 | proto2, proto3 := setupParserTest(t) 125 | 126 | require.Len(t, proto3.GetMessages(), 6) 127 | require.Nil(t, proto3.GetMessage("swingandamiss")) 128 | 129 | m := proto3.GetMessage("AddItemRequest") 130 | require.Equal(t, "AddItemRequest", m.GetName()) 131 | require.Equal(t, "AddItemRequest", m.GetLongName()) 132 | require.Equal(t, "com.pseudomuto.protokit.v1.AddItemRequest", m.GetFullName()) 133 | require.NotNil(t, m.GetFile()) 134 | require.Nil(t, m.GetParent()) 135 | require.Equal(t, "A request message for adding new items.", m.GetComments().String()) 136 | require.Equal(t, "com.pseudomuto.protokit.v1", m.GetPackage()) 137 | require.Len(t, m.GetMessageFields(), 3) 138 | require.Nil(t, m.GetMessageField("swingandamiss")) 139 | 140 | // no extensions in proto3 141 |
require.Empty(t, m.GetExtensions()) 142 | 143 | f := m.GetMessageField("completed") 144 | require.Equal(t, "completed", f.GetName()) 145 | require.Equal(t, "AddItemRequest.completed", f.GetLongName()) 146 | require.Equal(t, "com.pseudomuto.protokit.v1.AddItemRequest.completed", f.GetFullName()) 147 | require.NotNil(t, f.GetFile()) 148 | require.Equal(t, m, f.GetMessage()) 149 | require.Equal(t, "Whether or not the item is completed.", f.GetComments().String()) 150 | 151 | // just making sure google.protobuf.Any fields aren't special 152 | m = proto3.GetMessage("List") 153 | f = m.GetMessageField("details") 154 | require.Equal(t, "details", f.GetName()) 155 | require.Equal(t, "List.details", f.GetLongName()) 156 | require.Equal(t, "com.pseudomuto.protokit.v1.List.details", f.GetFullName()) 157 | 158 | // oneof fields should just expand to fields 159 | m = proto2.GetMessage("Booking") 160 | require.NotNil(t, m.GetMessageField("reference_num")) 161 | require.NotNil(t, m.GetMessageField("reference_tag")) 162 | require.Equal(t, "the numeric reference number", m.GetMessageField("reference_num").GetComments().String()) 163 | } 164 | 165 | func TestMessageEnums(t *testing.T) { 166 | t.Parallel() 167 | 168 | _, proto3 := setupParserTest(t) 169 | 170 | m := proto3.GetMessage("Item") 171 | require.NotNil(t, m.GetFile()) 172 | require.Len(t, m.GetEnums(), 1) 173 | require.Nil(t, m.GetEnum("whodis")) 174 | 175 | e := m.GetEnum("Status") 176 | require.Equal(t, "Status", e.GetName()) 177 | require.Equal(t, "Item.Status", e.GetLongName()) 178 | require.Equal(t, "com.pseudomuto.protokit.v1.Item.Status", e.GetFullName()) 179 | require.NotNil(t, e.GetFile()) 180 | require.Equal(t, m, e.GetParent()) 181 | require.Equal(t, e, m.GetEnum("Item.Status")) 182 | require.Equal(t, "An enumeration of possible statuses", e.GetComments().String()) 183 | require.Len(t, e.GetValues(), 2) 184 | 185 | val := e.GetNamedValue("COMPLETED") 186 | require.Equal(t, "COMPLETED", val.GetName()) 187 |
require.Equal(t, "Item.Status.COMPLETED", val.GetLongName()) 188 | require.Equal(t, "com.pseudomuto.protokit.v1.Item.Status.COMPLETED", val.GetFullName()) 189 | require.Equal(t, "The completed status.", val.GetComments().String()) 190 | require.NotNil(t, val.GetFile()) 191 | } 192 | 193 | func TestMessageExtensions(t *testing.T) { 194 | t.Parallel() 195 | 196 | proto2, _ := setupParserTest(t) 197 | 198 | m := proto2.GetMessage("Booking") 199 | ext := m.GetExtensions()[0] 200 | require.Equal(t, m, ext.GetParent()) 201 | require.Equal(t, int32(101), ext.GetNumber()) 202 | require.Equal(t, "optional_field_1", ext.GetName()) 203 | require.Equal(t, "BookingStatus.optional_field_1", ext.GetLongName()) 204 | require.Equal(t, "com.pseudomuto.protokit.v1.BookingStatus.optional_field_1", ext.GetFullName()) 205 | require.Equal(t, "An optional field to be used however you please.", ext.GetComments().String()) 206 | } 207 | 208 | func TestNestedMessages(t *testing.T) { 209 | t.Parallel() 210 | 211 | _, proto3 := setupParserTest(t) 212 | 213 | m := proto3.GetMessage("CreateListResponse") 214 | require.Len(t, m.GetMessages(), 1) 215 | require.Nil(t, m.GetMessage("whodis")) 216 | 217 | n := m.GetMessage("Status") 218 | require.Equal(t, n, m.GetMessage("CreateListResponse.Status")) 219 | 220 | // no extensions in proto3 221 | require.Empty(t, n.GetExtensions()) 222 | 223 | require.Equal(t, "Status", n.GetName()) 224 | require.Equal(t, "CreateListResponse.Status", n.GetLongName()) 225 | require.Equal(t, "com.pseudomuto.protokit.v1.CreateListResponse.Status", n.GetFullName()) 226 | require.Equal(t, "An internal status message", n.GetComments().String()) 227 | require.NotNil(t, n.GetFile()) 228 | require.Equal(t, m, n.GetParent()) 229 | 230 | f := n.GetMessageField("code") 231 | require.Equal(t, "CreateListResponse.Status.code", f.GetLongName()) 232 | require.Equal(t, "com.pseudomuto.protokit.v1.CreateListResponse.Status.code", f.GetFullName()) 233 | require.NotNil(t, f.GetFile()) 234
| require.Equal(t, "The status code.", f.GetComments().String()) 235 | } 236 | 237 | func TestExtendedOptions(t *testing.T) { 238 | t.Parallel() 239 | 240 | proto2, proto3 := setupParserTest(t) 241 | 242 | require.Contains(t, proto2.OptionExtensions, "com.pseudomuto.protokit.v1.extend_file") 243 | 244 | extendedValue, ok := proto2.OptionExtensions["com.pseudomuto.protokit.v1.extend_file"].(*bool) 245 | require.True(t, ok) 246 | require.True(t, *extendedValue) 247 | 248 | service := proto2.GetService("BookingService") 249 | require.Contains(t, service.OptionExtensions, "com.pseudomuto.protokit.v1.extend_service") 250 | 251 | extendedValue, ok = service.OptionExtensions["com.pseudomuto.protokit.v1.extend_service"].(*bool) 252 | require.True(t, ok) 253 | require.True(t, *extendedValue) 254 | 255 | method := service.GetNamedMethod("BookVehicle") 256 | require.Contains(t, method.OptionExtensions, "com.pseudomuto.protokit.v1.extend_method") 257 | 258 | extendedValue, ok = method.OptionExtensions["com.pseudomuto.protokit.v1.extend_method"].(*bool) 259 | require.True(t, ok) 260 | require.True(t, *extendedValue) 261 | 262 | message := proto2.GetMessage("Booking") 263 | require.Contains(t, message.OptionExtensions, "com.pseudomuto.protokit.v1.extend_message") 264 | 265 | extendedValue, ok = message.OptionExtensions["com.pseudomuto.protokit.v1.extend_message"].(*bool) 266 | require.True(t, ok) 267 | require.True(t, *extendedValue) 268 | 269 | field := message.GetMessageField("payment_received") 270 | require.Contains(t, field.OptionExtensions, "com.pseudomuto.protokit.v1.extend_field") 271 | 272 | extendedValue, ok = field.OptionExtensions["com.pseudomuto.protokit.v1.extend_field"].(*bool) 273 | require.True(t, ok) 274 | require.True(t, *extendedValue) 275 | 276 | enum := proto2.GetEnum("BookingType") 277 | require.Contains(t, enum.OptionExtensions, "com.pseudomuto.protokit.v1.extend_enum") 278 | 279 | extendedValue, ok =
enum.OptionExtensions["com.pseudomuto.protokit.v1.extend_enum"].(*bool) 280 | require.True(t, ok) 281 | require.True(t, *extendedValue) 282 | 283 | enumValue := enum.GetNamedValue("FUTURE") 284 | require.Contains(t, enumValue.OptionExtensions, "com.pseudomuto.protokit.v1.extend_enum_value") 285 | 286 | extendedValue, ok = enumValue.OptionExtensions["com.pseudomuto.protokit.v1.extend_enum_value"].(*bool) 287 | require.True(t, ok) 288 | require.True(t, *extendedValue) 289 | 290 | // Repeat the same checks against the proto3 fixture (todo.proto). 291 | require.Contains(t, proto3.OptionExtensions, "com.pseudomuto.protokit.v1.extend_file") 292 | 293 | extendedValue, ok = proto3.OptionExtensions["com.pseudomuto.protokit.v1.extend_file"].(*bool) 294 | require.True(t, ok) 295 | require.True(t, *extendedValue) 296 | 297 | service = proto3.GetService("Todo") 298 | require.Contains(t, service.OptionExtensions, "com.pseudomuto.protokit.v1.extend_service") 299 | 300 | extendedValue, ok = service.OptionExtensions["com.pseudomuto.protokit.v1.extend_service"].(*bool) 301 | require.True(t, ok) 302 | require.True(t, *extendedValue) 303 | 304 | method = service.GetNamedMethod("CreateList") 305 | require.Contains(t, method.OptionExtensions, "com.pseudomuto.protokit.v1.extend_method") 306 | 307 | extendedValue, ok = method.OptionExtensions["com.pseudomuto.protokit.v1.extend_method"].(*bool) 308 | require.True(t, ok) 309 | require.True(t, *extendedValue) 310 | 311 | message = proto3.GetMessage("List") 312 | require.Contains(t, message.OptionExtensions, "com.pseudomuto.protokit.v1.extend_message") 313 | 314 | extendedValue, ok = message.OptionExtensions["com.pseudomuto.protokit.v1.extend_message"].(*bool) 315 | require.True(t, ok) 316 | require.True(t, *extendedValue) 317 | 318 | field = message.GetMessageField("name") 319 | require.Contains(t, field.OptionExtensions, "com.pseudomuto.protokit.v1.extend_field") 320 | 321 | extendedValue, ok = field.OptionExtensions["com.pseudomuto.protokit.v1.extend_field"].(*bool) 322 |
require.True(t, ok) 323 | require.True(t, *extendedValue) 324 | 325 | enum = proto3.GetEnum("ListType") 326 | require.Contains(t, enum.OptionExtensions, "com.pseudomuto.protokit.v1.extend_enum") 327 | 328 | extendedValue, ok = enum.OptionExtensions["com.pseudomuto.protokit.v1.extend_enum"].(*bool) 329 | require.True(t, ok) 330 | require.True(t, *extendedValue) 331 | 332 | enumValue = enum.GetNamedValue("CHECKLIST") 333 | require.Contains(t, enumValue.OptionExtensions, "com.pseudomuto.protokit.v1.extend_enum_value") 334 | 335 | extendedValue, ok = enumValue.OptionExtensions["com.pseudomuto.protokit.v1.extend_enum_value"].(*bool) 336 | require.True(t, ok) 337 | require.True(t, *extendedValue) 338 | } 339 | 340 | func setupEditionsTest(t *testing.T) (*protokit.FileDescriptor, *protokit.FileDescriptor) { 341 | t.Helper() 342 | 343 | set, err := utils.LoadDescriptorSet("fixtures", "fileset.pb") 344 | require.NoError(t, err) 345 | 346 | req := utils.CreateGenRequest(set, "edition2023.proto", "edition2024.proto") 347 | files := protokit.ParseCodeGenRequest(req) 348 | 349 | return files[0], files[1] // edition2023.proto, edition2024.proto (request order) 350 | } 351 | 352 | func TestEditionsParsing(t *testing.T) { 353 | t.Parallel() 354 | 355 | edition2023, edition2024 := setupEditionsTest(t) 356 | 357 | // Test edition 2023 358 | require.True(t, edition2023.IsEditions()) 359 | require.False(t, edition2023.IsProto3()) 360 | require.Equal(t, "2023", edition2023.GetEditionName()) 361 | require.Equal(t, "editions", edition2023.GetSyntaxType()) 362 | require.True(t, edition2023.HasExplicitFieldPresence()) 363 | require.Equal(t, "Top-level comments are attached to the edition directive.", edition2023.GetSyntaxComments().String()) 364 | require.Contains(t, edition2023.GetPackageComments().String(), "The official documentation for the Edition 2023 test.") 365 | 366 | // Test edition 2024 367 | require.True(t, edition2024.IsEditions()) 368 | require.False(t, edition2024.IsProto3()) 369 |
require.Equal(t, "2024", edition2024.GetEditionName()) 370 | require.Equal(t, "editions", edition2024.GetSyntaxType()) 371 | require.True(t, edition2024.HasExplicitFieldPresence()) 372 | require.Equal(t, "Top-level comments are attached to the edition directive.", edition2024.GetSyntaxComments().String()) 373 | require.Contains(t, edition2024.GetPackageComments().String(), "The official documentation for the Edition 2024 test.") 374 | } 375 | 376 | func TestEditionsServices(t *testing.T) { 377 | t.Parallel() 378 | 379 | edition2023, edition2024 := setupEditionsTest(t) 380 | 381 | // Test edition 2023 service 382 | require.Len(t, edition2023.GetServices(), 1) 383 | svc2023 := edition2023.GetService("Edition2023Service") 384 | require.NotNil(t, svc2023) 385 | require.True(t, svc2023.IsEditions()) 386 | require.False(t, svc2023.IsProto3()) 387 | require.Len(t, svc2023.GetMethods(), 1) 388 | 389 | // Test edition 2024 service 390 | require.Len(t, edition2024.GetServices(), 1) 391 | svc2024 := edition2024.GetService("Edition2024Service") 392 | require.NotNil(t, svc2024) 393 | require.True(t, svc2024.IsEditions()) 394 | require.False(t, svc2024.IsProto3()) 395 | require.Len(t, svc2024.GetMethods(), 2) 396 | } 397 | 398 | func TestEditionsEnums(t *testing.T) { 399 | t.Parallel() 400 | 401 | edition2023, edition2024 := setupEditionsTest(t) 402 | 403 | // Test edition 2023 enum 404 | require.Len(t, edition2023.GetEnums(), 1) 405 | enum2023 := edition2023.GetEnum("TestEnum") 406 | require.NotNil(t, enum2023) 407 | require.True(t, enum2023.IsEditions()) 408 | require.False(t, enum2023.IsProto3()) 409 | 410 | // Test edition 2024 enum 411 | require.Len(t, edition2024.GetEnums(), 1) 412 | enum2024 := edition2024.GetEnum("TestEnum") 413 | require.NotNil(t, enum2024) 414 | require.True(t, enum2024.IsEditions()) 415 | require.False(t, enum2024.IsProto3()) 416 | } 417 | 418 | func TestEditionsMessages(t *testing.T) { 419 | t.Parallel() 420 | 421 | edition2023, edition2024 :=
setupEditionsTest(t) 422 | 423 | // Test edition 2023 messages 424 | require.Len(t, edition2023.GetMessages(), 3) 425 | msg2023 := edition2023.GetMessage("TestMessage") 426 | require.NotNil(t, msg2023) 427 | require.True(t, msg2023.IsEditions()) 428 | require.False(t, msg2023.IsProto3()) 429 | 430 | // Test edition 2024 messages 431 | require.Len(t, edition2024.GetMessages(), 3) 432 | msg2024 := edition2024.GetMessage("TestMessage") 433 | require.NotNil(t, msg2024) 434 | require.True(t, msg2024.IsEditions()) 435 | require.False(t, msg2024.IsProto3()) 436 | 437 | // Test nested message in edition 2024 438 | nested := msg2024.GetMessage("NestedData") 439 | require.NotNil(t, nested) 440 | require.True(t, nested.IsEditions()) 441 | require.False(t, nested.IsProto3()) 442 | } 443 | 444 | func TestFieldPresenceBehavior(t *testing.T) { 445 | t.Parallel() 446 | 447 | set, err := utils.LoadDescriptorSet("fixtures", "fileset.pb") 448 | require.NoError(t, err) 449 | 450 | req := utils.CreateGenRequest(set, "todo.proto", "edition2023.proto", "edition2023_implicit.proto") 451 | files := protokit.ParseCodeGenRequest(req) 452 | 453 | proto3File := files[0] // todo.proto (proto3) 454 | editionExplicitFile := files[1] // edition2023.proto (explicit field presence) 455 | editionImplicitFile := files[2] // edition2023_implicit.proto (implicit field presence) 456 | 457 | // Test proto3 file 458 | require.True(t, proto3File.IsProto3()) 459 | require.False(t, proto3File.HasExplicitFieldPresence()) 460 | require.Equal(t, "proto3", proto3File.GetSyntax()) 461 | 462 | // Test editions file with explicit field presence (default for editions) 463 | require.False(t, editionExplicitFile.IsProto3()) 464 | require.True(t, editionExplicitFile.HasExplicitFieldPresence()) 465 | require.True(t, editionExplicitFile.IsEditions()) 466 | require.Equal(t, "editions", editionExplicitFile.GetSyntax()) 467 | 468 | // Test editions file with implicit field presence (proto3-like semantics) 469 |
require.True(t, editionImplicitFile.IsProto3()) 470 | require.False(t, editionImplicitFile.HasExplicitFieldPresence()) 471 | require.True(t, editionImplicitFile.IsEditions()) 472 | require.Equal(t, "editions", editionImplicitFile.GetSyntax()) 473 | } 474 | -------------------------------------------------------------------------------- /types.go: -------------------------------------------------------------------------------- 1 | package protokit 2 | 3 | import ( 4 | "fmt" 5 | "maps" 6 | "strings" 7 | 8 | "google.golang.org/protobuf/encoding/protowire" 9 | "google.golang.org/protobuf/proto" 10 | "google.golang.org/protobuf/reflect/protoreflect" 11 | "google.golang.org/protobuf/types/descriptorpb" 12 | ) 13 | 14 | type ( 15 | common struct { 16 | file *FileDescriptor 17 | path string 18 | LongName string 19 | FullName string 20 | 21 | OptionExtensions map[string]any 22 | } 23 | 24 | // An ImportedDescriptor describes a type that was imported by a FileDescriptor. 25 | ImportedDescriptor struct { 26 | common 27 | } 28 | 29 | // A FileDescriptor describes a single proto file with all of its messages, enums, services, etc.
FileDescriptor struct {
		// comments holds the comments collected for this file.
		// NOTE(review): presumably indexed by source-location path (see comments.go) — confirm.
		comments Comments
		*descriptorpb.FileDescriptorProto

		// The file's package, syntax and edition comments, as returned by
		// GetPackageComments, GetSyntaxComments and GetEditionComments.
		PackageComments *Comment
		SyntaxComments  *Comment
		EditionComments *Comment

		// Top-level enums and (file) extensions; nested ones live on the
		// owning Descriptor instead.
		Enums      []*EnumDescriptor
		Extensions []*ExtensionDescriptor
		// Imports describes what this file imports (see ImportedDescriptor).
		Imports []*ImportedDescriptor
		// Top-level messages and the services defined in this file.
		Messages []*Descriptor
		Services []*ServiceDescriptor

		// OptionExtensions maps extension full names to the values found on the
		// file's options; it is filled in by setOptions.
		OptionExtensions map[string]any
	}

	// An EnumDescriptor describes an enum type
	EnumDescriptor struct {
		common
		*descriptorpb.EnumDescriptorProto
		// Parent is the message that contains this enum, or nil for a top-level enum.
		Parent *Descriptor
		// Values are the values declared by this enum.
		Values []*EnumValueDescriptor
		// Comments is the description of this enum.
		Comments *Comment
	}

	// An EnumValueDescriptor describes an enum value
	EnumValueDescriptor struct {
		common
		*descriptorpb.EnumValueDescriptorProto
		// Enum is the enumeration that contains this value.
		Enum *EnumDescriptor
		// Comments is the description of this value.
		Comments *Comment
	}

	// An ExtensionDescriptor describes a protobuf extension. If it's a top-level extension its parent will be `nil`
	ExtensionDescriptor struct {
		common
		*descriptorpb.FieldDescriptorProto
		// Parent is the message that declared this extension (nil at file level).
		Parent *Descriptor
		// Comments is the description of this extension.
		Comments *Comment
	}

	// A Descriptor describes a message
	Descriptor struct {
		common
		*descriptorpb.DescriptorProto
		// Parent is the enclosing message, or nil for a top-level message.
		Parent *Descriptor
		// Comments is the description of this message.
		Comments *Comment
		// Nested enums, message-level extensions, fields and nested messages.
		Enums      []*EnumDescriptor
		Extensions []*ExtensionDescriptor
		Fields     []*FieldDescriptor
		Messages   []*Descriptor
	}

	// A FieldDescriptor describes a message field
	FieldDescriptor struct {
		common
		*descriptorpb.FieldDescriptorProto
		// Comments is the description of this field.
		Comments *Comment
		// Message is the message that declares this field.
		Message *Descriptor
	}

	// A ServiceDescriptor describes a service
	ServiceDescriptor struct {
		common
		*descriptorpb.ServiceDescriptorProto
		// Comments is the description of this service.
		Comments *Comment
		// Methods are the RPC methods declared by this service.
		Methods []*MethodDescriptor
	}

	// A MethodDescriptor describes a method in a service
	MethodDescriptor struct {
		common
		*descriptorpb.MethodDescriptorProto
		// Comments is the description of this method.
		Comments *Comment
		// Service is the service that declares this method.
		Service *ServiceDescriptor
	}
)

// GetFile returns the FileDescriptor that contains this object
func (c *common) GetFile() *FileDescriptor { return c.file }

// GetPackage returns the package this object is in
func (c *common) GetPackage() string { return c.file.GetPackage() }

// GetLongName returns the name prefixed with the dot-separated parent descriptor's name (if any)
func (c *common) GetLongName() string { return c.LongName }

// GetFullName returns the `LongName` prefixed with the package this object is in
func (c *common) GetFullName() string { return c.FullName }

// IsProto3 returns whether or not this is a proto3 object or uses proto3-like semantics.
// It delegates to the containing file.
func (c *common) IsProto3() bool { return c.file.IsProto3() }

// GetEdition returns the edition of the file this object belongs to
func (c *common) GetEdition() descriptorpb.Edition { return c.file.GetEdition() }

// IsEditions returns whether or not this object belongs to a file using editions syntax
func (c *common) IsEditions() bool { return c.file.IsEditions() }

// getOptions collects the extension values set on an options message, keyed by
// the extension's full name. It returns nil when none are found.
//
// Extensions the message already knows about are read via reflection
// (msg.Range); a fixed set of unregistered ones (see getKnownExtensions) is then
// decoded from the raw unknown fields.
// NOTE(review): the two paths can yield different Go types for the same logical
// value (the reflection path stores v.Interface(), the fallback a *bool).
func getOptions(options proto.Message) (m map[string]any) {
	// In protobuf v2, we need to access extension fields through reflection
	// and parse unknown fields that contain extension data
	msg := options.ProtoReflect()

	// First, check for any known extension fields that are set
	msg.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
		if fd.IsExtension() {
			if m == nil {
				m = make(map[string]any)
			}
			m[string(fd.FullName())] = v.Interface()
		}
		return true
	})

	// For custom extensions that might not be registered, we need to parse
	// the unknown fields. This is more complex in v2 but necessary for
	// backward compatibility with v1 behavior.
149 | unknownFields := msg.GetUnknown() 150 | if len(unknownFields) > 0 { 151 | // Parse known extension field numbers for this message type 152 | extensions := getKnownExtensions(options) 153 | for fieldNum, extInfo := range extensions { 154 | if value := parseExtensionFromUnknown(unknownFields, fieldNum, extInfo.wireType); value != nil { 155 | if m == nil { 156 | m = make(map[string]any) 157 | } 158 | m[extInfo.name] = value 159 | } 160 | } 161 | } 162 | 163 | return m 164 | } 165 | 166 | // ExtensionInfo holds information about known extensions 167 | type ExtensionInfo struct { 168 | name string 169 | wireType int 170 | } 171 | 172 | // getKnownExtensions returns a map of field numbers to extension info for common protobuf options 173 | func getKnownExtensions(options proto.Message) map[int32]ExtensionInfo { 174 | extensions := make(map[int32]ExtensionInfo) 175 | 176 | // Define the known extensions based on the test proto files 177 | // These correspond to the extensions defined in extend.proto 178 | switch options.(type) { 179 | case *descriptorpb.FileOptions: 180 | extensions[20000] = ExtensionInfo{"com.pseudomuto.protokit.v1.extend_file", int(protowire.VarintType)} // varint 181 | case *descriptorpb.ServiceOptions: 182 | extensions[20000] = ExtensionInfo{"com.pseudomuto.protokit.v1.extend_service", int(protowire.VarintType)} // varint 183 | case *descriptorpb.MethodOptions: 184 | extensions[20000] = ExtensionInfo{"com.pseudomuto.protokit.v1.extend_method", int(protowire.VarintType)} // varint 185 | case *descriptorpb.MessageOptions: 186 | extensions[20000] = ExtensionInfo{"com.pseudomuto.protokit.v1.extend_message", int(protowire.VarintType)} // varint 187 | case *descriptorpb.FieldOptions: 188 | extensions[20000] = ExtensionInfo{"com.pseudomuto.protokit.v1.extend_field", int(protowire.VarintType)} // varint 189 | case *descriptorpb.EnumOptions: 190 | extensions[20000] = ExtensionInfo{"com.pseudomuto.protokit.v1.extend_enum", int(protowire.VarintType)} // 
varint 191 | case *descriptorpb.EnumValueOptions: 192 | extensions[20000] = ExtensionInfo{"com.pseudomuto.protokit.v1.extend_enum_value", int(protowire.VarintType)} // varint 193 | } 194 | 195 | return extensions 196 | } 197 | 198 | // parseExtensionFromUnknown attempts to parse an extension value from unknown fields 199 | func parseExtensionFromUnknown(unknownFields protoreflect.RawFields, fieldNum int32, wireType int) any { 200 | // This is a simplified parser for boolean extensions (wire type 0 - varint) 201 | // In a full implementation, you'd need to handle all wire types 202 | if wireType != int(protowire.VarintType) { 203 | return nil // Only handle varint for now 204 | } 205 | 206 | // Parse the unknown fields looking for our field number 207 | for len(unknownFields) > 0 { 208 | fieldNumParsed, wireTypeParsed, fieldData := parseField(unknownFields) 209 | if fieldNumParsed == fieldNum && wireTypeParsed == int(protowire.VarintType) { 210 | // Parse varint (boolean in our case) 211 | if len(fieldData) > 0 && fieldData[0] == 1 { 212 | val := true 213 | return &val 214 | } else if len(fieldData) > 0 && fieldData[0] == 0 { 215 | val := false 216 | return &val 217 | } 218 | } 219 | // Skip this field and continue 220 | unknownFields = unknownFields[len(unknownFields)-len(fieldData):] 221 | if len(unknownFields) == 0 { 222 | break 223 | } 224 | } 225 | 226 | return nil 227 | } 228 | 229 | // parseField parses a single field from raw protobuf data 230 | // Returns field number, wire type, and remaining data 231 | func parseField(data protoreflect.RawFields) (int32, int, protoreflect.RawFields) { 232 | if len(data) == 0 { 233 | return 0, 0, nil 234 | } 235 | 236 | // Parse the tag (field number and wire type) 237 | fieldNum, wireType, n := protowire.ConsumeTag([]byte(data)) 238 | if n <= 0 { 239 | return 0, 0, nil 240 | } 241 | data = data[n:] 242 | 243 | // For varint (wire type 0), parse the value 244 | if wireType == protowire.VarintType { 245 | _, valueLen := 
protowire.ConsumeVarint([]byte(data)) 246 | if valueLen <= 0 { 247 | return int32(fieldNum), int(wireType), nil 248 | } 249 | return int32(fieldNum), int(wireType), data[:valueLen] 250 | } 251 | 252 | // For other wire types, we'd need more complex parsing (YAGNI). 253 | return int32(fieldNum), int(wireType), data 254 | } 255 | 256 | func (c *common) setOptions(options proto.Message) { 257 | if opts := getOptions(options); len(opts) > 0 { 258 | if c.OptionExtensions == nil { 259 | c.OptionExtensions = opts 260 | return 261 | } 262 | 263 | maps.Copy(c.OptionExtensions, opts) 264 | } 265 | } 266 | 267 | // FileDescriptor methods 268 | 269 | // IsProto3 returns whether or not this file is a proto3 file or uses proto3-like semantics 270 | func (f *FileDescriptor) IsProto3() bool { 271 | // Original proto3 syntax 272 | if f.GetSyntax() == "proto3" { 273 | return true 274 | } 275 | // Editions with proto3-like behavior (IMPLICIT field presence) match proto3 semantics 276 | if f.IsEditions() { 277 | if options := f.GetOptions(); options != nil { 278 | if features := options.GetFeatures(); features != nil { 279 | return features.GetFieldPresence() == descriptorpb.FeatureSet_IMPLICIT 280 | } 281 | } 282 | } 283 | return false 284 | } 285 | 286 | // GetEdition returns the edition of this file 287 | func (f *FileDescriptor) GetEdition() descriptorpb.Edition { return f.FileDescriptorProto.GetEdition() } 288 | 289 | // IsEditions returns whether or not this file uses the editions syntax 290 | func (f *FileDescriptor) IsEditions() bool { return f.GetSyntax() == "editions" } 291 | 292 | // GetEditionName returns the edition name as a string (e.g., "2023", "2024") 293 | func (f *FileDescriptor) GetEditionName() string { 294 | if !f.IsEditions() { 295 | return "" 296 | } 297 | switch f.GetEdition() { 298 | case descriptorpb.Edition_EDITION_2023: 299 | return "2023" 300 | case descriptorpb.Edition_EDITION_2024: 301 | return "2024" 302 | case descriptorpb.Edition_EDITION_PROTO2: 303 
| return "proto2" 304 | case descriptorpb.Edition_EDITION_PROTO3: 305 | return "proto3" 306 | case descriptorpb.Edition_EDITION_UNKNOWN, descriptorpb.Edition_EDITION_LEGACY: 307 | return "unknown" 308 | case descriptorpb.Edition_EDITION_1_TEST_ONLY: 309 | return "1_test_only" 310 | case descriptorpb.Edition_EDITION_2_TEST_ONLY: 311 | return "2_test_only" 312 | case descriptorpb.Edition_EDITION_99997_TEST_ONLY: 313 | return "99997_test_only" 314 | case descriptorpb.Edition_EDITION_99998_TEST_ONLY: 315 | return "99998_test_only" 316 | case descriptorpb.Edition_EDITION_99999_TEST_ONLY: 317 | return "99999_test_only" 318 | case descriptorpb.Edition_EDITION_MAX: 319 | return "max" 320 | default: 321 | return f.GetEdition().String() 322 | } 323 | } 324 | 325 | // GetPackageComments returns the file's package comments 326 | func (f *FileDescriptor) GetPackageComments() *Comment { return f.PackageComments } 327 | 328 | // GetSyntaxComments returns the file's syntax comments 329 | func (f *FileDescriptor) GetSyntaxComments() *Comment { return f.SyntaxComments } 330 | 331 | // GetEditionComments returns the file's edition comments 332 | func (f *FileDescriptor) GetEditionComments() *Comment { return f.EditionComments } 333 | 334 | // HasExplicitFieldPresence returns whether this file defaults to explicit field presence 335 | // In editions 2023+, field presence is explicit by default (like proto2) 336 | // In proto3, field presence is implicit by default 337 | func (f *FileDescriptor) HasExplicitFieldPresence() bool { 338 | if f.IsEditions() { 339 | // Check custom field presence setting in editions 340 | if options := f.GetOptions(); options != nil { 341 | if features := options.GetFeatures(); features != nil { 342 | switch features.GetFieldPresence() { 343 | case descriptorpb.FeatureSet_IMPLICIT: 344 | return false 345 | case descriptorpb.FeatureSet_EXPLICIT, descriptorpb.FeatureSet_LEGACY_REQUIRED: 346 | return true 347 | case 
descriptorpb.FeatureSet_FIELD_PRESENCE_UNKNOWN: 348 | // Fall through to default behavior 349 | } 350 | } 351 | } 352 | // Editions 2023+ default to explicit field presence 353 | return true 354 | } 355 | // proto2 has explicit field presence, proto3 has implicit 356 | return f.GetSyntax() == "proto2" 357 | } 358 | 359 | // GetSyntaxType returns a more detailed syntax classification 360 | func (f *FileDescriptor) GetSyntaxType() string { 361 | if f.IsEditions() { 362 | return "editions" 363 | } 364 | return f.GetSyntax() 365 | } 366 | 367 | // GetEnums returns the top-level enumerations defined in this file 368 | func (f *FileDescriptor) GetEnums() []*EnumDescriptor { return f.Enums } 369 | 370 | // GetExtensions returns the top-level (file) extensions defined in this file 371 | func (f *FileDescriptor) GetExtensions() []*ExtensionDescriptor { return f.Extensions } 372 | 373 | // GetImports returns the proto files imported by this file 374 | func (f *FileDescriptor) GetImports() []*ImportedDescriptor { return f.Imports } 375 | 376 | // GetMessages returns the top-level messages defined in this file 377 | func (f *FileDescriptor) GetMessages() []*Descriptor { return f.Messages } 378 | 379 | // GetServices returns the services defined in this file 380 | func (f *FileDescriptor) GetServices() []*ServiceDescriptor { return f.Services } 381 | 382 | // GetEnum returns the enumeration with the specified name (returns `nil` if not found) 383 | func (f *FileDescriptor) GetEnum(name string) *EnumDescriptor { 384 | for _, e := range f.GetEnums() { 385 | if e.GetName() == name || e.GetLongName() == name { 386 | return e 387 | } 388 | } 389 | 390 | return nil 391 | } 392 | 393 | // GetMessage returns the message with the specified name (returns `nil` if not found) 394 | func (f *FileDescriptor) GetMessage(name string) *Descriptor { 395 | for _, m := range f.GetMessages() { 396 | if m.GetName() == name || m.GetLongName() == name { 397 | return m 398 | } 399 | } 400 | 401 | 
return nil
}

// GetService looks up a service of this file by its simple or long name and
// returns nil when nothing matches.
func (f *FileDescriptor) GetService(name string) *ServiceDescriptor {
	services := f.GetServices()
	for i := range services {
		if candidate := services[i]; candidate.GetName() == name || candidate.GetLongName() == name {
			return candidate
		}
	}
	return nil
}

// setOptions merges the extension values extracted from options into
// f.OptionExtensions. The extracted map is adopted directly when nothing was
// recorded yet; otherwise its entries are copied in, overwriting equal keys.
func (f *FileDescriptor) setOptions(options proto.Message) {
	extracted := getOptions(options)
	if len(extracted) == 0 {
		return
	}
	if f.OptionExtensions != nil {
		maps.Copy(f.OptionExtensions, extracted)
		return
	}
	f.OptionExtensions = extracted
}

// EnumDescriptor methods

// GetComments returns the comment describing this enum.
func (e *EnumDescriptor) GetComments() *Comment {
	return e.Comments
}

// GetParent returns the message enclosing this enum, or nil for a top-level enum.
func (e *EnumDescriptor) GetParent() *Descriptor {
	return e.Parent
}

// GetValues returns every value declared by this enum.
func (e *EnumDescriptor) GetValues() []*EnumValueDescriptor {
	return e.Values
}

// GetNamedValue looks up a value by its simple name and returns nil when
// nothing matches.
func (e *EnumDescriptor) GetNamedValue(name string) *EnumValueDescriptor {
	values := e.GetValues()
	for i := range values {
		if values[i].GetName() == name {
			return values[i]
		}
	}
	return nil
}

// EnumValueDescriptor methods

// GetComments returns the comment describing this value.
func (v *EnumValueDescriptor) GetComments() *Comment {
	return v.Comments
}

// GetEnum returns the enumeration this value belongs to.
func (v *EnumValueDescriptor) GetEnum() *EnumDescriptor {
	return v.Enum
}

// ExtensionDescriptor methods

// GetComments returns the comment describing this extension.
func (e *ExtensionDescriptor) GetComments() *Comment {
	return e.Comments
}

// GetParent returns the message that declared this extension, or nil for a
// file-level extension.
func (e *ExtensionDescriptor) GetParent() *Descriptor {
	return e.Parent
}

// Descriptor methods

// GetComments returns the comment describing this message.
func (m *Descriptor) GetComments() *Comment {
	return m.Comments
}

// GetParent returns the enclosing message, or nil for a top-level message.
func (m *Descriptor) GetParent() *Descriptor {
	return m.Parent
}

// GetEnums returns the enumerations nested inside this message.
func (m *Descriptor) GetEnums() []*EnumDescriptor {
	return m.Enums
}

// GetExtensions returns the extensions declared inside this message.
func (m *Descriptor) GetExtensions() []*ExtensionDescriptor {
	return m.Extensions
}

// GetMessages returns the messages nested inside this message.
func (m *Descriptor) GetMessages() []*Descriptor {
	return m.Messages
}

// GetMessageFields returns the fields declared by this message.
func (m *Descriptor) GetMessageFields() []*FieldDescriptor {
	return m.Fields
}

// GetEnum looks up a nested enum by its simple name or its long name (the name
// prefixed with the enclosing message names) and returns nil when nothing matches.
func (m *Descriptor) GetEnum(name string) *EnumDescriptor {
	enums := m.GetEnums()
	for i := range enums {
		if candidate := enums[i]; candidate.GetName() == name || candidate.GetLongName() == name {
			return candidate
		}
	}
	return nil
}

// GetMessage looks up a nested message by its simple name or its long name (the
// name prefixed with the enclosing message names) and returns nil when nothing
// matches.
func (m *Descriptor) GetMessage(name string) *Descriptor {
	nested := m.GetMessages()
	for i := range nested {
		if candidate := nested[i]; candidate.GetName() == name || candidate.GetLongName() == name {
			return candidate
		}
	}
	return nil
}

// GetMessageField looks up a field by its simple or long name and returns nil
// when nothing matches.
func (m *Descriptor) GetMessageField(name string) *FieldDescriptor {
	fields := m.GetMessageFields()
	for i := range fields {
		if candidate := fields[i]; candidate.GetName() == name || candidate.GetLongName() == name {
			return candidate
		}
	}
	return nil
}

// FieldDescriptor methods

// GetComments returns the comment describing this field.
func (mf *FieldDescriptor) GetComments() *Comment {
	return mf.Comments
}

// GetMessage returns the message that declares this field.
func (mf *FieldDescriptor) GetMessage() *Descriptor {
	return mf.Message
}

// ServiceDescriptor methods

// GetComments returns the comment describing this service.
func (s *ServiceDescriptor) GetComments() *Comment {
	return s.Comments
}

// GetMethods returns the RPC methods declared by this service.
func (s *ServiceDescriptor) GetMethods() []*MethodDescriptor {
	return s.Methods
}

// GetNamedMethod looks up a method by its simple or long name and returns nil
// when nothing matches.
func (s *ServiceDescriptor) GetNamedMethod(name string) *MethodDescriptor {
	methods := s.GetMethods()
	for i := range methods {
		if candidate := methods[i]; candidate.GetName() == name || candidate.GetLongName() == name {
			return candidate
		}
	}
	return nil
}

// MethodDescriptor methods

// GetComments returns the comment describing this method.
func (m *MethodDescriptor) GetComments() *Comment {
	return m.Comments
}

// GetService returns the service that declares this method.
func (m *MethodDescriptor) GetService() *ServiceDescriptor {
	return m.Service
}

// newCommon builds the shared common struct for a descriptor in file f.
//
// FullName starts out as longName and, unless longName already begins with ".",
// is replaced by "<package>.<longName>".
// NOTE(review): for a file without a package this yields ".<longName>" (leading
// dot); confirm that is the intended full-name format.
func newCommon(f *FileDescriptor, path, longName string) common {
	c := common{
		file:     f,
		path:     path,
		LongName: longName,
		FullName: longName,
	}
	if !strings.HasPrefix(longName, ".") {
		c.FullName = fmt.Sprintf("%s.%s", f.GetPackage(), longName)
	}
	return c
}

--------------------------------------------------------------------------------