├── .circleci └── config.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── channel.go ├── cmd └── protoc-gen-grpchan │ ├── gen_test.go │ ├── gen_test.proto │ ├── gen_test_grpc_pb_test.go │ ├── gen_test_pb_test.go │ └── protoc-gen-grpchan.go ├── doc.go ├── download_protoc.sh ├── go.mod ├── go.sum ├── grpchantesting ├── channel_test_cases.go ├── channel_test_cases_test.go ├── doc.go ├── test.pb.go ├── test.pb.grpchan.go ├── test.proto ├── test_grpc.pb.go └── test_service.go ├── httpgrpc ├── client.go ├── codes.go ├── doc.go ├── httpgrpc.pb.go ├── httpgrpc.proto ├── httpgrpc_test.go ├── io.go ├── json.go ├── protocol_versions.go └── server.go ├── inprocgrpc ├── cloner.go ├── cloner_test.go ├── in_process.go ├── in_process_test.go └── no_values_context_test.go ├── intercept.go ├── intercept_client_test.go ├── intercept_server_test.go ├── interceptor_chain_client_test.go ├── interceptor_chain_server_test.go ├── internal ├── call_options.go ├── misc.go └── transport_stream.go └── server.go /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | shared_configs: 2 | simple_job_steps: &simple_job_steps 3 | - checkout 4 | - run: 5 | name: Run tests 6 | command: | 7 | make test 8 | 9 | # Use the latest 2.1 version of CircleCI pipeline process engine. 
See: https://circleci.com/docs/2.0/configuration-reference 10 | version: 2.1 11 | jobs: 12 | build-1-21: 13 | working_directory: ~/repo 14 | docker: 15 | - image: cimg/go:1.21 16 | steps: *simple_job_steps 17 | 18 | build-1-22: 19 | working_directory: ~/repo 20 | docker: 21 | - image: cimg/go:1.22 22 | steps: *simple_job_steps 23 | 24 | build-1-23: 25 | working_directory: ~/repo 26 | docker: 27 | - image: cimg/go:1.23 28 | steps: 29 | - checkout 30 | - run: 31 | name: Run tests and linters 32 | command: | 33 | make ci 34 | 35 | workflows: 36 | pr-build-test: 37 | jobs: 38 | - build-1-21 39 | - build-1-22 40 | - build-1-23 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .tmp/ 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2018 Fullstory, Inc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | export PATH := $(shell pwd)/.tmp/protoc/bin:$(PATH) 2 | export PROTOC_VERSION := 22.0 3 | 4 | .PHONY: ci 5 | ci: deps checkgofmt vet staticcheck ineffassign predeclared test 6 | 7 | .PHONY: deps 8 | deps: 9 | go get -d -v -t ./... 10 | go mod tidy 11 | 12 | .PHONY: updatedeps 13 | updatedeps: 14 | go get -d -v -t -u -f ./... 15 | go mod tidy 16 | 17 | .PHONY: install 18 | install: 19 | go install ./... 20 | 21 | .PHONY: checkgofmt 22 | checkgofmt: 23 | gofmt -s -l . 24 | @if [ -n "$$(gofmt -s -l .)" ]; then \ 25 | exit 1; \ 26 | fi 27 | 28 | .PHONY: generate 29 | generate: .tmp/protoc/bin/protoc 30 | @go install ./cmd/protoc-gen-grpchan 31 | @go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26 32 | @go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1.0 33 | go generate ./... 34 | 35 | .PHONY: vet 36 | vet: 37 | go vet ./... 38 | 39 | .PHONY: staticcheck 40 | staticcheck: 41 | @go install honnef.co/go/tools/cmd/staticcheck@v0.5.1 42 | staticcheck -checks "inherit,-SA1019" ./... 43 | 44 | .PHONY: ineffassign 45 | ineffassign: 46 | @go install github.com/gordonklaus/ineffassign@7953dde2c7bf 47 | ineffassign . 48 | 49 | .PHONY: predeclared 50 | predeclared: 51 | @go install github.com/nishanths/predeclared@245576f9a85c96ea16c750df3887f1d827f01e9c 52 | predeclared . 53 | 54 | # Intentionally omitted from CI, but target here for ad-hoc reports. 
55 | .PHONY: golint 56 | golint: 57 | @go install golang.org/x/lint/golint@v0.0.0-20210508222113-6edffad5e616 58 | golint -min_confidence 0.9 -set_exit_status ./... 59 | 60 | # Intentionally omitted from CI, but target here for ad-hoc reports. 61 | .PHONY: errcheck 62 | errcheck: 63 | @go install github.com/kisielk/errcheck@v1.2.0 64 | errcheck ./... 65 | 66 | .PHONY: test 67 | test: 68 | go test -race ./... 69 | 70 | .tmp/protoc/bin/protoc: ./Makefile ./download_protoc.sh 71 | ./download_protoc.sh 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gRPC Channels 2 | 3 | [![Build Status](https://circleci.com/gh/fullstorydev/grpchan/tree/master.svg?style=svg)](https://circleci.com/gh/fullstorydev/grpchan/tree/master) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/fullstorydev/grpchan)](https://goreportcard.com/report/github.com/fullstorydev/grpchan) 5 | [![GoDoc](https://godoc.org/github.com/fullstorydev/grpchan?status.svg)](https://godoc.org/github.com/fullstorydev/grpchan) 6 | 7 | This repo provides an abstraction for an RPC connection: the `Channel`. 8 | Implementations of `Channel` can provide alternate transports -- different 9 | from the standard HTTP/2-based transport provided by the `google.golang.org/grpc` 10 | package. 11 | 12 | This can be useful for providing new transports, such as HTTP 1.1, web sockets, 13 | or (notably) in-process channels for testing. 14 | 15 | This repo also contains two such alternate transports: an HTTP 1.1 implementation 16 | of gRPC (which supports all stream kinds other than full-duplex bidi streams) and 17 | an in-process transport (which allows a process to dispatch handlers implemented 18 | in the same program without needing to serialize and deserialize messages over the 19 | loopback network interface).
20 | 21 | In order to use channels with your proto-defined gRPC services, you need to use a 22 | protoc plugin included in this repo: `protoc-gen-grpchan`. 23 | 24 | ```bash 25 | go install github.com/fullstorydev/grpchan/cmd/protoc-gen-grpchan 26 | ``` 27 | 28 | You use the plugin via a `--grpchan_out` parameter to protoc. Specify the same 29 | output directory to this parameter as you supply to `--go_out`. The plugin will 30 | then generate `*.pb.grpchan.go` files, alongside the `*.pb.go` files. These 31 | additional files contain additional methods that let you use the proto-defined 32 | service methods with alternate transports. 33 | 34 | ```go 35 | //go:generate protoc --go_out=plugins=grpc:. --grpchan_out=. my.proto 36 | ``` 37 | -------------------------------------------------------------------------------- /channel.go: -------------------------------------------------------------------------------- 1 | package grpchan 2 | 3 | import ( 4 | "google.golang.org/grpc" 5 | ) 6 | 7 | // Channel is an abstraction of a GRPC transport. With corresponding generated 8 | // code, it can provide an alternate transport to the standard HTTP/2-based one. 9 | // For example, a Channel implementation could instead provide an HTTP 1.1-based 10 | // transport, or an in-process transport. 11 | // 12 | // Deprecated: Use grpc.ClientConnInterface instead. 13 | type Channel = grpc.ClientConnInterface 14 | -------------------------------------------------------------------------------- /cmd/protoc-gen-grpchan/gen_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jhump/protoreflect/grpcreflect" 7 | ) 8 | 9 | // we use bash to mv the file after it is generated so that it will end in "_test.go" 10 | // (so it's just a test file and not linked into the actual protoc-gen-grpchan command) 11 | //go:generate bash -c "protoc --go_out=. --go-grpc_out=. --go_opt=paths=source_relative gen_test.proto && mv ./gen_test.pb.go ./gen_test_pb_test.go && mv ./gen_test_grpc.pb.go ./gen_test_grpc_pb_test.go"
12 | 13 | // TestStreamOrder verifies the assumption protoc-gen-grpchan makes when it 14 | // emits stream indexes: the ServiceDesc's Streams slice must contain exactly 15 | // the streaming methods, in the order they are declared in the service. 16 | func TestStreamOrder(t *testing.T) { 17 | // we get the service descriptor (same descriptor that protoc-gen-grpchan processes) 18 | sd, err := grpcreflect.LoadServiceDescriptor(&TestStreams_ServiceDesc) 19 | if err != nil { 20 | t.Fatalf("failed to load service descriptor: %v", err) 21 | } 22 | streams := TestStreams_ServiceDesc.Streams 23 | 24 | // loop through stream methods just as protoc-gen-grpchan does to emit code 25 | streamCount := 0 26 | for _, md := range sd.GetMethods() { 27 | if md.IsClientStreaming() || md.IsServerStreaming() { 28 | // guard the index so a missing stream fails the test instead of panicking 29 | if streamCount >= len(streams) { 30 | t.Fatalf("stream method %s has index %d, but service desc has only %d streams", md.GetName(), streamCount, len(streams)) 31 | } 32 | // verify that the stream at current index is correct 33 | // (code emits this index when querying the serviceDesc, so we must 34 | // be certain that this index is right!) 35 | strDesc := streams[streamCount] 36 | if md.GetName() != strDesc.StreamName { 37 | t.Fatalf("wrong stream at %d: %s != %s", streamCount, md.GetName(), strDesc.StreamName) 38 | } 39 | if md.IsClientStreaming() != strDesc.ClientStreams || md.IsServerStreaming() != strDesc.ServerStreams { 40 | t.Fatalf("wrong stream type at %d", streamCount) 41 | } 42 | 43 | streamCount++ 44 | } 45 | } 46 | 47 | // sanity check that we saw all of the streams: gen_test.proto declares 7 48 | // streaming methods, and the service desc must not list any extra ones 49 | const wantStreams = 7 50 | if streamCount != wantStreams || len(streams) != wantStreams { 51 | t.Fatalf("wrong number of streams: processed %d, service desc has %d, want %d", streamCount, len(streams), wantStreams) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /cmd/protoc-gen-grpchan/gen_test.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package main; 4 | 5 | option go_package = "./;main"; 6 | 7 | message Test { 8 | int64 id = 1; 9 | string name = 2; 10 | } 11 | 12 | service TestStreams { 13 | // We need to verify the order in which streams are stored in a 14 | // service desc in generated GRPC code.
So we concoct a bunch of 15 | // methods with streams interleaved therein, and then have a test 16 | // that verifies the indexes of each stream method, ensuring they 17 | // match the assumption made by protoc-gen-grpchan code gen. 18 | 19 | rpc Unary1 (Test) returns (Test); 20 | rpc Unary2 (Test) returns (Test); 21 | rpc Stream1 (stream Test) returns (Test); 22 | rpc Unary3 (Test) returns (Test); 23 | rpc Stream2 (stream Test) returns (stream Test); 24 | rpc Unary4 (Test) returns (Test); 25 | rpc Unary5 (Test) returns (Test); 26 | rpc Stream3 (Test) returns (stream Test); 27 | rpc Unary6 (Test) returns (Test); 28 | rpc Unary7 (Test) returns (Test); 29 | rpc Unary8 (Test) returns (Test); 30 | rpc Stream4 (stream Test) returns (Test); 31 | rpc Stream5 (Test) returns (stream Test); 32 | rpc Unary9 (Test) returns (Test); 33 | rpc Unary10 (Test) returns (Test); 34 | rpc Unary11 (Test) returns (Test); 35 | rpc Stream6 (stream Test) returns (stream Test); 36 | rpc Stream7 (stream Test) returns (stream Test); 37 | rpc Unary12 (Test) returns (Test); 38 | } 39 | -------------------------------------------------------------------------------- /cmd/protoc-gen-grpchan/gen_test_pb_test.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. DO NOT EDIT. 2 | // versions: 3 | // protoc-gen-go v1.26.0 4 | // protoc v4.22.0 5 | // source: gen_test.proto 6 | 7 | package main 8 | 9 | import ( 10 | protoreflect "google.golang.org/protobuf/reflect/protoreflect" 11 | protoimpl "google.golang.org/protobuf/runtime/protoimpl" 12 | reflect "reflect" 13 | sync "sync" 14 | ) 15 | 16 | const ( 17 | // Verify that this generated code is sufficiently up-to-date. 18 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) 19 | // Verify that runtime/protoimpl is sufficiently up-to-date.
20 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) 21 | ) 22 | 23 | type Test struct { 24 | state protoimpl.MessageState 25 | sizeCache protoimpl.SizeCache 26 | unknownFields protoimpl.UnknownFields 27 | 28 | Id int64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` 29 | Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` 30 | } 31 | 32 | func (x *Test) Reset() { 33 | *x = Test{} 34 | if protoimpl.UnsafeEnabled { 35 | mi := &file_gen_test_proto_msgTypes[0] 36 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 37 | ms.StoreMessageInfo(mi) 38 | } 39 | } 40 | 41 | func (x *Test) String() string { 42 | return protoimpl.X.MessageStringOf(x) 43 | } 44 | 45 | func (*Test) ProtoMessage() {} 46 | 47 | func (x *Test) ProtoReflect() protoreflect.Message { 48 | mi := &file_gen_test_proto_msgTypes[0] 49 | if protoimpl.UnsafeEnabled && x != nil { 50 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 51 | if ms.LoadMessageInfo() == nil { 52 | ms.StoreMessageInfo(mi) 53 | } 54 | return ms 55 | } 56 | return mi.MessageOf(x) 57 | } 58 | 59 | // Deprecated: Use Test.ProtoReflect.Descriptor instead.
60 | func (*Test) Descriptor() ([]byte, []int) { 61 | return file_gen_test_proto_rawDescGZIP(), []int{0} 62 | } 63 | 64 | func (x *Test) GetId() int64 { 65 | if x != nil { 66 | return x.Id 67 | } 68 | return 0 69 | } 70 | 71 | func (x *Test) GetName() string { 72 | if x != nil { 73 | return x.Name 74 | } 75 | return "" 76 | } 77 | 78 | var File_gen_test_proto protoreflect.FileDescriptor 79 | 80 | var file_gen_test_proto_rawDesc = []byte{ 81 | 0x0a, 0x0e, 0x67, 0x65, 0x6e, 0x5f, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 82 | 0x12, 0x04, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x2a, 0x0a, 0x04, 0x54, 0x65, 0x73, 0x74, 0x12, 0x0e, 83 | 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 84 | 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 85 | 0x6d, 0x65, 0x32, 0xb1, 0x05, 0x0a, 0x0b, 0x54, 0x65, 0x73, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 86 | 0x6d, 0x73, 0x12, 0x20, 0x0a, 0x06, 0x55, 0x6e, 0x61, 0x72, 0x79, 0x31, 0x12, 0x0a, 0x2e, 0x6d, 87 | 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x1a, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 88 | 0x54, 0x65, 0x73, 0x74, 0x12, 0x20, 0x0a, 0x06, 0x55, 0x6e, 0x61, 0x72, 0x79, 0x32, 0x12, 0x0a, 89 | 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x1a, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 90 | 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x12, 0x23, 0x0a, 0x07, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 91 | 0x31, 0x12, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x1a, 0x0a, 0x2e, 92 | 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x28, 0x01, 0x12, 0x20, 0x0a, 0x06, 0x55, 93 | 0x6e, 0x61, 0x72, 0x79, 0x33, 0x12, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 94 | 0x74, 0x1a, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x12, 0x25, 0x0a, 95 | 0x07, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x32, 0x12, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 96 | 0x54, 0x65, 0x73, 0x74, 0x1a, 0x0a, 
0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 97 | 0x28, 0x01, 0x30, 0x01, 0x12, 0x20, 0x0a, 0x06, 0x55, 0x6e, 0x61, 0x72, 0x79, 0x34, 0x12, 0x0a, 98 | 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x1a, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 99 | 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x12, 0x20, 0x0a, 0x06, 0x55, 0x6e, 0x61, 0x72, 0x79, 0x35, 100 | 0x12, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x1a, 0x0a, 0x2e, 0x6d, 101 | 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x12, 0x23, 0x0a, 0x07, 0x53, 0x74, 0x72, 0x65, 102 | 0x61, 0x6d, 0x33, 0x12, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x1a, 103 | 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x30, 0x01, 0x12, 0x20, 0x0a, 104 | 0x06, 0x55, 0x6e, 0x61, 0x72, 0x79, 0x36, 0x12, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 105 | 0x65, 0x73, 0x74, 0x1a, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x12, 106 | 0x20, 0x0a, 0x06, 0x55, 0x6e, 0x61, 0x72, 0x79, 0x37, 0x12, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 107 | 0x2e, 0x54, 0x65, 0x73, 0x74, 0x1a, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 108 | 0x74, 0x12, 0x20, 0x0a, 0x06, 0x55, 0x6e, 0x61, 0x72, 0x79, 0x38, 0x12, 0x0a, 0x2e, 0x6d, 0x61, 109 | 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x1a, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 110 | 0x65, 0x73, 0x74, 0x12, 0x23, 0x0a, 0x07, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x34, 0x12, 0x0a, 111 | 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x1a, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 112 | 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x28, 0x01, 0x12, 0x23, 0x0a, 0x07, 0x53, 0x74, 0x72, 0x65, 113 | 0x61, 0x6d, 0x35, 0x12, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x1a, 114 | 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x30, 0x01, 0x12, 0x20, 0x0a, 115 | 0x06, 0x55, 0x6e, 0x61, 0x72, 0x79, 0x39, 0x12, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 116 
| 0x65, 0x73, 0x74, 0x1a, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x12, 117 | 0x21, 0x0a, 0x07, 0x55, 0x6e, 0x61, 0x72, 0x79, 0x31, 0x30, 0x12, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 118 | 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x1a, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 119 | 0x73, 0x74, 0x12, 0x21, 0x0a, 0x07, 0x55, 0x6e, 0x61, 0x72, 0x79, 0x31, 0x31, 0x12, 0x0a, 0x2e, 120 | 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x1a, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 121 | 0x2e, 0x54, 0x65, 0x73, 0x74, 0x12, 0x25, 0x0a, 0x07, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x36, 122 | 0x12, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x1a, 0x0a, 0x2e, 0x6d, 123 | 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x28, 0x01, 0x30, 0x01, 0x12, 0x25, 0x0a, 0x07, 124 | 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x37, 0x12, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 125 | 0x65, 0x73, 0x74, 0x1a, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x28, 126 | 0x01, 0x30, 0x01, 0x12, 0x21, 0x0a, 0x07, 0x55, 0x6e, 0x61, 0x72, 0x79, 0x31, 0x32, 0x12, 0x0a, 127 | 0x2e, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x1a, 0x0a, 0x2e, 0x6d, 0x61, 0x69, 128 | 0x6e, 0x2e, 0x54, 0x65, 0x73, 0x74, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x3b, 0x6d, 0x61, 0x69, 129 | 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 130 | } 131 | 132 | var ( 133 | file_gen_test_proto_rawDescOnce sync.Once 134 | file_gen_test_proto_rawDescData = file_gen_test_proto_rawDesc 135 | ) 136 | 137 | func file_gen_test_proto_rawDescGZIP() []byte { 138 | file_gen_test_proto_rawDescOnce.Do(func() { 139 | file_gen_test_proto_rawDescData = protoimpl.X.CompressGZIP(file_gen_test_proto_rawDescData) 140 | }) 141 | return file_gen_test_proto_rawDescData 142 | } 143 | 144 | var file_gen_test_proto_msgTypes = make([]protoimpl.MessageInfo, 1) 145 | var file_gen_test_proto_goTypes = []interface{}{ 146 | (*Test)(nil), // 0: main.Test 147 | } 148 | var 
file_gen_test_proto_depIdxs = []int32{ 149 | 0, // 0: main.TestStreams.Unary1:input_type -> main.Test 150 | 0, // 1: main.TestStreams.Unary2:input_type -> main.Test 151 | 0, // 2: main.TestStreams.Stream1:input_type -> main.Test 152 | 0, // 3: main.TestStreams.Unary3:input_type -> main.Test 153 | 0, // 4: main.TestStreams.Stream2:input_type -> main.Test 154 | 0, // 5: main.TestStreams.Unary4:input_type -> main.Test 155 | 0, // 6: main.TestStreams.Unary5:input_type -> main.Test 156 | 0, // 7: main.TestStreams.Stream3:input_type -> main.Test 157 | 0, // 8: main.TestStreams.Unary6:input_type -> main.Test 158 | 0, // 9: main.TestStreams.Unary7:input_type -> main.Test 159 | 0, // 10: main.TestStreams.Unary8:input_type -> main.Test 160 | 0, // 11: main.TestStreams.Stream4:input_type -> main.Test 161 | 0, // 12: main.TestStreams.Stream5:input_type -> main.Test 162 | 0, // 13: main.TestStreams.Unary9:input_type -> main.Test 163 | 0, // 14: main.TestStreams.Unary10:input_type -> main.Test 164 | 0, // 15: main.TestStreams.Unary11:input_type -> main.Test 165 | 0, // 16: main.TestStreams.Stream6:input_type -> main.Test 166 | 0, // 17: main.TestStreams.Stream7:input_type -> main.Test 167 | 0, // 18: main.TestStreams.Unary12:input_type -> main.Test 168 | 0, // 19: main.TestStreams.Unary1:output_type -> main.Test 169 | 0, // 20: main.TestStreams.Unary2:output_type -> main.Test 170 | 0, // 21: main.TestStreams.Stream1:output_type -> main.Test 171 | 0, // 22: main.TestStreams.Unary3:output_type -> main.Test 172 | 0, // 23: main.TestStreams.Stream2:output_type -> main.Test 173 | 0, // 24: main.TestStreams.Unary4:output_type -> main.Test 174 | 0, // 25: main.TestStreams.Unary5:output_type -> main.Test 175 | 0, // 26: main.TestStreams.Stream3:output_type -> main.Test 176 | 0, // 27: main.TestStreams.Unary6:output_type -> main.Test 177 | 0, // 28: main.TestStreams.Unary7:output_type -> main.Test 178 | 0, // 29: main.TestStreams.Unary8:output_type -> main.Test 179 | 0, // 30: 
main.TestStreams.Stream4:output_type -> main.Test 180 | 0, // 31: main.TestStreams.Stream5:output_type -> main.Test 181 | 0, // 32: main.TestStreams.Unary9:output_type -> main.Test 182 | 0, // 33: main.TestStreams.Unary10:output_type -> main.Test 183 | 0, // 34: main.TestStreams.Unary11:output_type -> main.Test 184 | 0, // 35: main.TestStreams.Stream6:output_type -> main.Test 185 | 0, // 36: main.TestStreams.Stream7:output_type -> main.Test 186 | 0, // 37: main.TestStreams.Unary12:output_type -> main.Test 187 | 19, // [19:38] is the sub-list for method output_type 188 | 0, // [0:19] is the sub-list for method input_type 189 | 0, // [0:0] is the sub-list for extension type_name 190 | 0, // [0:0] is the sub-list for extension extendee 191 | 0, // [0:0] is the sub-list for field type_name 192 | } 193 | 194 | func init() { file_gen_test_proto_init() } 195 | func file_gen_test_proto_init() { 196 | if File_gen_test_proto != nil { 197 | return 198 | } 199 | if !protoimpl.UnsafeEnabled { 200 | file_gen_test_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { 201 | switch v := v.(*Test); i { 202 | case 0: 203 | return &v.state 204 | case 1: 205 | return &v.sizeCache 206 | case 2: 207 | return &v.unknownFields 208 | default: 209 | return nil 210 | } 211 | } 212 | } 213 | type x struct{} 214 | out := protoimpl.TypeBuilder{ 215 | File: protoimpl.DescBuilder{ 216 | GoPackagePath: reflect.TypeOf(x{}).PkgPath(), 217 | RawDescriptor: file_gen_test_proto_rawDesc, 218 | NumEnums: 0, 219 | NumMessages: 1, 220 | NumExtensions: 0, 221 | NumServices: 1, 222 | }, 223 | GoTypes: file_gen_test_proto_goTypes, 224 | DependencyIndexes: file_gen_test_proto_depIdxs, 225 | MessageInfos: file_gen_test_proto_msgTypes, 226 | }.Build() 227 | File_gen_test_proto = out.File 228 | file_gen_test_proto_rawDesc = nil 229 | file_gen_test_proto_goTypes = nil 230 | file_gen_test_proto_depIdxs = nil 231 | } 232 | 
-------------------------------------------------------------------------------- /cmd/protoc-gen-grpchan/protoc-gen-grpchan.go: -------------------------------------------------------------------------------- 1 | // Command protoc-gen-grpchan is a protoc plugin that generates gRPC client stubs 2 | // in Go that use github.com/fullstorydev/grpchan.Channel as their transport 3 | // abstraction, instead of using *grpc.ClientConn. This can be used to carry RPC 4 | // requests and streams over other transports, such as HTTP 1.1 or in-process. 5 | package main 6 | 7 | import ( 8 | "fmt" 9 | "path" 10 | "strings" 11 | "text/template" 12 | 13 | "github.com/jhump/gopoet" 14 | "github.com/jhump/goprotoc/plugins" 15 | "github.com/jhump/protoreflect/desc" 16 | "google.golang.org/protobuf/types/pluginpb" 17 | ) 18 | 19 | func main() { 20 | plugins.PluginMain(doCodeGen) 21 | } 22 | 23 | func doCodeGen(req *plugins.CodeGenRequest, resp *plugins.CodeGenResponse) error { 24 | resp.SupportsFeatures(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) 25 | args, err := parseArgs(req.Args) 26 | if err != nil { 27 | return err 28 | } 29 | names := plugins.GoNames{ 30 | ImportMap: args.importMap, 31 | ModuleRoot: args.moduleRoot, 32 | SourceRelative: args.sourceRelative, 33 | } 34 | if args.importPath != "" { 35 | // if we're overriding import path, go ahead and query 36 | // package for each file, which will cache the override name 37 | // so all subsequent queries are consistent 38 | for _, fd := range req.Files { 39 | // Only use the override for files that don't otherwise have an 40 | // entry in the specified import map 41 | if _, ok := args.importMap[fd.GetName()]; !ok { 42 | names.GoPackageForFileWithOverride(fd, args.importPath) 43 | } 44 | } 45 | } 46 | for _, fd := range req.Files { 47 | if err := generateChanStubs(fd, &names, resp, args); err != nil { 48 | if fe, ok := err.(*gopoet.FormatError); ok { 49 | if args.debug { 50 | return fmt.Errorf("%s: error in generated Go 
code: %v:\n%s", fd.GetName(), err, fe.Unformatted) 51 | } else { 52 | return fmt.Errorf("%s: error in generated Go code: %v (use debug=true arg to show full source)", fd.GetName(), err) 53 | } 54 | } else { 55 | return fmt.Errorf("%s: %v", fd.GetName(), err) 56 | } 57 | } 58 | } 59 | return nil 60 | } 61 | 62 | var typeOfRegistry = gopoet.NamedType(gopoet.NewSymbol("github.com/fullstorydev/grpchan", "ServiceRegistry")) 63 | var typeOfClientConn = gopoet.NamedType(gopoet.NewSymbol("google.golang.org/grpc", "ClientConnInterface")) 64 | var typeOfContext = gopoet.NamedType(gopoet.NewSymbol("context", "Context")) 65 | var typeOfCallOptions = gopoet.SliceType(gopoet.NamedType(gopoet.NewSymbol("google.golang.org/grpc", "CallOption"))) 66 | 67 | func generateChanStubs(fd *desc.FileDescriptor, names *plugins.GoNames, resp *plugins.CodeGenResponse, args codeGenArgs) error { 68 | if len(fd.GetServices()) == 0 { 69 | return nil 70 | } 71 | 72 | pkg := names.GoPackageForFile(fd) 73 | filename := names.OutputFilenameFor(fd, ".pb.grpchan.go") 74 | f := gopoet.NewGoFile(path.Base(filename), pkg.ImportPath, pkg.Name) 75 | 76 | f.FileComment = "Code generated by protoc-gen-grpchan. DO NOT EDIT.\n" + 77 | "source: " + fd.GetName() 78 | 79 | for _, sd := range fd.GetServices() { 80 | svcName := names.CamelCase(sd.GetName()) 81 | lowerSvcName := gopoet.Unexport(svcName) 82 | 83 | f.AddElement(gopoet.NewFunc(fmt.Sprintf("RegisterHandler%s", svcName)). 84 | SetComment(fmt.Sprintf("Deprecated: Use Register%sServer instead.", svcName)). 85 | AddArg("reg", typeOfRegistry). 86 | AddArg("srv", names.GoTypeForServiceServer(sd)). 
87 | Printlnf("reg.RegisterService(&%s, srv)", serviceDescVarName(sd, names, args.legacyDescNames))) 88 | 89 | if !args.legacyStubs { 90 | continue 91 | } 92 | 93 | cc := gopoet.NewStructTypeSpec(fmt.Sprintf("%sChannelClient", lowerSvcName), 94 | gopoet.NewField("ch", typeOfClientConn)) 95 | f.AddType(cc) 96 | 97 | f.AddElement(gopoet.NewFunc(fmt.Sprintf("New%sChannelClient", svcName)). 98 | SetComment(fmt.Sprintf("Deprecated: Use New%sClient instead.", svcName)). 99 | AddArg("ch", typeOfClientConn). 100 | AddResult("", names.GoTypeForServiceClient(sd)). 101 | SetComment(fmt.Sprintf("Deprecated: Use New%sClient instead.", svcName)). 102 | Printlnf("return &%s{ch: ch}", cc)) 103 | 104 | streamCount := 0 105 | tmpls := templates{} 106 | for _, md := range sd.GetMethods() { 107 | methodInfo := struct { 108 | ServiceName string 109 | MethodName string 110 | ServiceDesc string 111 | StreamClient string 112 | StreamIndex int 113 | RequestType gopoet.TypeName 114 | }{ 115 | ServiceName: sd.GetFullyQualifiedName(), 116 | MethodName: md.GetName(), 117 | ServiceDesc: serviceDescVarName(sd, names, args.legacyDescNames), 118 | StreamClient: names.GoTypeForStreamClientImpl(md), 119 | StreamIndex: streamCount, 120 | RequestType: names.GoTypeForMessage(md.GetOutputType()), 121 | } 122 | mtdName := names.CamelCase(md.GetName()) 123 | if md.IsClientStreaming() { 124 | // bidi or client streaming method 125 | f.AddElement(gopoet.NewMethod(gopoet.NewPointerReceiverForType("c", cc), mtdName). 126 | AddArg("ctx", typeOfContext). 127 | AddArg("opts", typeOfCallOptions). 128 | SetVariadic(true). 129 | AddResult("", names.GoTypeForStreamClient(md)). 130 | AddResult("", gopoet.ErrorType). 131 | RenderCode(tmpls.makeTemplate( 132 | `stream, err := c.ch.NewStream(ctx, &{{.ServiceDesc}}.Streams[{{.StreamIndex}}], "/{{.ServiceName}}/{{.MethodName}}", opts...) 
133 | if err != nil { 134 | return nil, err 135 | } 136 | x := &{{.StreamClient}}{stream} 137 | return x, nil`), &methodInfo)) 138 | streamCount++ 139 | } else if md.IsServerStreaming() { 140 | // server streaming method 141 | f.AddElement(gopoet.NewMethod(gopoet.NewPointerReceiverForType("c", cc), mtdName). 142 | AddArg("ctx", typeOfContext). 143 | AddArg("in", names.GoTypeOfRequest(md)). 144 | AddArg("opts", typeOfCallOptions). 145 | SetVariadic(true). 146 | AddResult("", names.GoTypeForStreamClient(md)). 147 | AddResult("", gopoet.ErrorType). 148 | RenderCode(tmpls.makeTemplate( 149 | `stream, err := c.ch.NewStream(ctx, &{{.ServiceDesc}}.Streams[{{.StreamIndex}}], "/{{.ServiceName}}/{{.MethodName}}", opts...) 150 | if err != nil { 151 | return nil, err 152 | } 153 | x := &{{.StreamClient}}{stream} 154 | if err := x.ClientStream.SendMsg(in); err != nil { 155 | return nil, err 156 | } 157 | if err := x.ClientStream.CloseSend(); err != nil { 158 | return nil, err 159 | } 160 | return x, nil`), &methodInfo)) 161 | streamCount++ 162 | } else { 163 | // unary method 164 | f.AddElement(gopoet.NewMethod(gopoet.NewPointerReceiverForType("c", cc), mtdName). 165 | AddArg("ctx", typeOfContext). 166 | AddArg("in", names.GoTypeOfRequest(md)). 167 | AddArg("opts", typeOfCallOptions). 168 | SetVariadic(true). 169 | AddResult("", names.GoTypeOfResponse(md)). 170 | AddResult("", gopoet.ErrorType). 171 | RenderCode(tmpls.makeTemplate( 172 | `out := new({{.RequestType}}) 173 | err := c.ch.Invoke(ctx, "/{{.ServiceName}}/{{.MethodName}}", in, out, opts...) 
174 | if err != nil { 175 | return nil, err 176 | } 177 | return out, nil`), &methodInfo)) 178 | } 179 | } 180 | } 181 | 182 | out := resp.OutputFile(filename) 183 | return gopoet.WriteGoFile(out, f) 184 | } 185 | 186 | func serviceDescVarName(sd *desc.ServiceDescriptor, names *plugins.GoNames, legacyNames bool) string { 187 | if legacyNames { 188 | return names.GoNameOfServiceDesc(sd) 189 | } 190 | return names.GoNameOfExportedServiceDesc(sd).Name 191 | } 192 | 193 | type templates map[string]*template.Template 194 | 195 | func (t templates) makeTemplate(templateText string) *template.Template { 196 | tpl := t[templateText] 197 | if tpl == nil { 198 | tpl = template.Must(template.New("code").Parse(templateText)) 199 | t[templateText] = tpl 200 | } 201 | return tpl 202 | } 203 | 204 | type codeGenArgs struct { 205 | debug bool 206 | legacyStubs bool 207 | legacyDescNames bool 208 | importPath string 209 | importMap map[string]string 210 | moduleRoot string 211 | sourceRelative bool 212 | } 213 | 214 | func parseArgs(args []string) (codeGenArgs, error) { 215 | var result codeGenArgs 216 | for _, arg := range args { 217 | vals := strings.SplitN(arg, "=", 2) 218 | switch vals[0] { 219 | case "debug": 220 | val, err := boolVal(vals) 221 | if err != nil { 222 | return result, err 223 | } 224 | result.debug = val 225 | 226 | case "legacy_stubs": 227 | val, err := boolVal(vals) 228 | if err != nil { 229 | return result, err 230 | } 231 | result.legacyStubs = val 232 | 233 | case "legacy_desc_names": 234 | val, err := boolVal(vals) 235 | if err != nil { 236 | return result, err 237 | } 238 | result.legacyDescNames = val 239 | 240 | case "import_path": 241 | if len(vals) == 1 { 242 | return result, fmt.Errorf("plugin option 'import_path' requires an argument") 243 | } 244 | result.importPath = vals[1] 245 | 246 | case "module": 247 | if len(vals) == 1 { 248 | return result, fmt.Errorf("plugin option 'module' requires an argument") 249 | } 250 | result.moduleRoot = vals[1] 
251 | 252 | case "paths": 253 | if len(vals) == 1 { 254 | return result, fmt.Errorf("plugin option 'paths' requires an argument") 255 | } 256 | switch vals[1] { 257 | case "import": 258 | result.sourceRelative = false 259 | case "source_relative": 260 | result.sourceRelative = true 261 | default: 262 | return result, fmt.Errorf("plugin option 'paths' accepts 'import' or 'source_relative' as value, got %q", vals[1]) 263 | } 264 | 265 | default: 266 | if len(vals[0]) > 1 && vals[0][0] == 'M' { 267 | if len(vals) == 1 { 268 | return result, fmt.Errorf("plugin 'M' options require an argument: %s", vals[0]) 269 | } 270 | if result.importMap == nil { 271 | result.importMap = map[string]string{} 272 | } 273 | result.importMap[vals[0][1:]] = vals[1] 274 | break 275 | } 276 | 277 | return result, fmt.Errorf("unknown plugin option: %s", vals[0]) 278 | } 279 | } 280 | 281 | if result.sourceRelative && result.moduleRoot != "" { 282 | return result, fmt.Errorf("plugin option 'module' cannot be used with 'paths=source_relative'") 283 | } 284 | 285 | return result, nil 286 | } 287 | 288 | func boolVal(vals []string) (bool, error) { 289 | if len(vals) == 1 { 290 | // if no value, assume "true" 291 | return true, nil 292 | } 293 | switch strings.ToLower(vals[1]) { 294 | case "true", "on", "yes", "1": 295 | return true, nil 296 | case "false", "off", "no", "0": 297 | return false, nil 298 | default: 299 | return false, fmt.Errorf("invalid boolean arg for option '%s': %s", vals[0], vals[1]) 300 | } 301 | } 302 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package grpchan provides an abstraction for a gRPC transport, called a 2 | // Channel. The channel is more general than the concrete *grpc.ClientConn 3 | // and *grpc.Server types provided by gRPC. 
It allows gRPC over alternate 4 | // substrates and includes sub-packages that provide two such alternatives: 5 | // in-process and HTTP 1.1. 6 | // 7 | // The key type in this package is an alternate implementation of the 8 | // grpc.ServiceRegistrar interface that allows you to accumulate service 9 | // registrations, for use with an implementation other than *grpc.Server. 10 | // 11 | // # Protoc Plugin 12 | // 13 | // This repo also includes a deprecated protoc plugin. This is no longer 14 | // needed now that the standard protoc-gen-go-grpc plugin generates code that 15 | // uses interfaces: grpc.ClientConnInterface and grpc.ServiceRegistrar. In 16 | // older versions, the generated code only supported concrete types 17 | // (*grpc.ClientConn and *grpc.Server) so this repo's protoc plugin would 18 | // generate alternate code that used interfaces (and thus supported other 19 | // concrete implementations). 20 | // 21 | // The plugin should only be used to keep supporting existing code that 22 | // depends on the functions it generates. 23 | // 24 | // To use the protoc plugin, you need to first build and install it and make sure its 25 | // location is in your PATH. 26 | // 27 | // go install github.com/fullstorydev/grpchan/cmd/protoc-gen-grpchan@latest 28 | // # If necessary, make sure its location is on your path like so: 29 | // # export PATH=$PATH:$(go env GOPATH)/bin 30 | // 31 | // When you invoke protoc, include a --grpchan_out parameter that indicates 32 | // the same output directory as used for your --go_out parameter. Alongside 33 | // the *.pb.go files generated, the grpchan plugin will also create 34 | // *.pb.grpchan.go files. 35 | // 36 | // In older versions of the Go plugin (when emitting gRPC code), a server 37 | // registration function for each RPC service defined in the proto source files 38 | // was generated that looked like so: 39 | // 40 | // // A function with this signature is generated, for registering 41 | // // server handlers with the given server.
42 | // func RegisterServer(s *grpc.Server, srv Server) { 43 | // s.RegisterService(&__serviceDesc, srv) 44 | // } 45 | // 46 | // The grpchan plugin produces a similarly named function that accepts the 47 | // ServiceRegistry interface: 48 | // 49 | // func RegisterHandler(sr grpchan.ServiceRegistry, srv Server) { 50 | // sr.RegisterService(&__serviceDesc, srv) 51 | // } 52 | // 53 | // A new transport can then be implemented by just implementing two interfaces: 54 | // grpc.ClientConnInterface for the client side and grpchan.ServiceRegistry for 55 | // the server side. 56 | // 57 | // The alternate function also works just fine with *grpc.Server as it implements 58 | // the ServiceRegistry interface. 59 | // 60 | // NOTE: If you have code relying on NewChannelClient functions that 61 | // earlier versions of this package produced, they can still be generated by passing 62 | // a "legacy_stubs" option to the plugin. Example: 63 | // 64 | // protoc foo.proto --grpchan_out=legacy_stubs:./output/dir 65 | // 66 | // # Client-Side Channels 67 | // 68 | // The client-side implementation of a transport is done with just the two 69 | // methods in grpc.ClientConnInterface: one for unary RPCs and the other for 70 | // streaming RPCs. 71 | // 72 | // Note that when a unary interceptor is invoked for an RPC on a channel that 73 | // is *not* a *grpc.ClientConn, the parameter of that type will be nil. 74 | // 75 | // Not all client call options will make sense for all transports. This repo 76 | // chooses to ignore call options that do not apply (as opposed to failing 77 | // the RPC or panicking). However, several call options are likely important 78 | // to support: those for accessing header and trailer metadata. The peer, 79 | // per-RPC credentials, and message size limits are other options that are 80 | // reasonably straightforward to apply to other transports. But the other 81 | // options (dealing with on-the-wire encoding, compression, etc.) may not be 82 | // applicable.
83 | // 84 | // # Server-Side Service Registries 85 | // 86 | // The server-side implementation of a transport must be able to invoke 87 | // method and stream handlers for a given service implementation. This is done 88 | // by implementing the grpc.ServiceRegistrar interface. When a service is 89 | // registered, a service description is provided that includes access to method 90 | // and stream handlers. When the transport receives requests for RPC operations, 91 | // it in turn invokes these handlers. For streaming operations, it must also 92 | // supply a grpc.ServerStream implementation for exchanging messages on the 93 | // stream. 94 | // 95 | // Note that the server stream's context will need a custom implementation of 96 | // the grpc.ServerTransportStream in it, too. Sadly, this interface is just 97 | // different enough from grpc.ServerStream that they cannot be implemented by 98 | // the same type. This is particularly necessary for unary calls since this is 99 | // how a unary handler indicates what headers and trailers to send back to the 100 | // client. 101 | package grpchan 102 | -------------------------------------------------------------------------------- /download_protoc.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | cd $(dirname $0) 6 | 7 | if [[ -z "$PROTOC_VERSION" ]]; then 8 | echo "Set PROTOC_VERSION env var to indicate the version to download" >&2 9 | exit 1 10 | fi 11 | PROTOC_OS="$(uname -s)" 12 | PROTOC_ARCH="$(uname -m)" 13 | case "${PROTOC_OS}" in 14 | Darwin) PROTOC_OS="osx" ;; 15 | Linux) PROTOC_OS="linux" ;; 16 | *) 17 | echo "Invalid value for uname -s: ${PROTOC_OS}" >&2 18 | exit 1 19 | esac 20 | 21 | # This is for Macs with Apple Silicon (M1 and later) chips. Precompiled binaries for osx/arm64 are not available for download, so for that case 22 | # we download the x86_64 version instead. This will work as long as Rosetta 2 is installed.
23 | if [ "$PROTOC_OS" = "osx" ] && [ "$PROTOC_ARCH" = "arm64" ]; then 24 | PROTOC_ARCH="x86_64" 25 | fi 26 | 27 | PROTOC="${PWD}/.tmp/protoc/bin/protoc" 28 | 29 | if [[ "$(${PROTOC} --version 2>/dev/null)" != "libprotoc 3.${PROTOC_VERSION}" ]]; then 30 | rm -rf ./.tmp/protoc 31 | mkdir -p .tmp/protoc 32 | curl -L "https://github.com/google/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-${PROTOC_OS}-${PROTOC_ARCH}.zip" > .tmp/protoc/protoc.zip 33 | pushd ./.tmp/protoc && unzip protoc.zip && popd 34 | touch -c ./.tmp/protoc/bin/protoc 35 | fi 36 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/fullstorydev/grpchan 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.23.5 6 | 7 | require ( 8 | github.com/golang/protobuf v1.5.4 9 | github.com/jhump/gopoet v0.1.0 10 | github.com/jhump/goprotoc v0.5.0 11 | github.com/jhump/protoreflect v1.15.6 12 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241202173237-19429a94021a 13 | google.golang.org/grpc v1.70.0 14 | google.golang.org/protobuf v1.35.2 15 | ) 16 | 17 | require ( 18 | github.com/bufbuild/protocompile v0.9.0 // indirect 19 | golang.org/x/net v0.38.0 // indirect 20 | golang.org/x/sys v0.31.0 // indirect 21 | golang.org/x/text v0.23.0 // indirect 22 | ) 23 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= 2 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 3 | github.com/bufbuild/protocompile v0.9.0 h1:DI8qLG5PEO0Mu1Oj51YFPqtx6I3qYXUAhJVJ/IzAVl0= 4 | github.com/bufbuild/protocompile v0.9.0/go.mod h1:s89m1O8CqSYpyE/YaSGtg1r1YFMF5nLTwh4vlj6O444= 5 | github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod 
h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= 6 | github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= 7 | github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= 8 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 9 | github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= 10 | github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= 11 | github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= 12 | github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= 13 | github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= 14 | github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= 15 | github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= 16 | github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= 17 | github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= 18 | github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= 19 | github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 20 | github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 21 | github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= 22 | github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= 23 | github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= 24 | github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod 
h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= 25 | github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= 26 | github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= 27 | github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= 28 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= 29 | github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= 30 | github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= 31 | github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= 32 | github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= 33 | github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= 34 | github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 35 | github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 36 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 37 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 38 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 39 | github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 40 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 41 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 42 | github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= 43 | github.com/jhump/gopoet v0.1.0 h1:gYjOPnzHd2nzB37xYQZxj4EIQNpBrBskRqQQ3q4ZgSg= 44 | github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= 45 | github.com/jhump/goprotoc v0.5.0 h1:Y1UgUX+txUznfqcGdDef8ZOVlyQvnV0pKWZH08RmZuo= 46 | 
github.com/jhump/goprotoc v0.5.0/go.mod h1:VrbvcYrQOrTi3i0Vf+m+oqQWk9l72mjkJCYo7UvLHRQ= 47 | github.com/jhump/protoreflect v1.11.0/go.mod h1:U7aMIjN0NWq9swDP7xDdoMfRHb35uiuTd3Z9nFXJf5E= 48 | github.com/jhump/protoreflect v1.15.6 h1:WMYJbw2Wo+KOWwZFvgY0jMoVHM6i4XIvRs2RcBj5VmI= 49 | github.com/jhump/protoreflect v1.15.6/go.mod h1:jCHoyYQIJnaabEYnbGwyo9hUqfyUMTbJw/tAut5t97E= 50 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 51 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 52 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 53 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 54 | github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= 55 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 56 | github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= 57 | go.opentelemetry.io/otel v1.32.0 h1:WnBN+Xjcteh0zdk01SVqV55d/m62NJLJdIyb4y/WO5U= 58 | go.opentelemetry.io/otel v1.32.0/go.mod h1:00DCVSB0RQcnzlwyTfqtxSm+DRr9hpYrHjNGiBHVQIg= 59 | go.opentelemetry.io/otel/metric v1.32.0 h1:xV2umtmNcThh2/a/aCP+h64Xx5wsj8qqnkYZktzNa0M= 60 | go.opentelemetry.io/otel/metric v1.32.0/go.mod h1:jH7CIbbK6SH2V2wE16W05BHCtIDzauciCRLoc/SyMv8= 61 | go.opentelemetry.io/otel/sdk v1.32.0 h1:RNxepc9vK59A8XsgZQouW8ue8Gkb4jpWtJm9ge5lEG4= 62 | go.opentelemetry.io/otel/sdk v1.32.0/go.mod h1:LqgegDBjKMmb2GC6/PrTnteJG39I8/vJCAP9LlJXEjU= 63 | go.opentelemetry.io/otel/sdk/metric v1.32.0 h1:rZvFnvmvawYb0alrYkjraqJq0Z4ZUJAiyYCU9snn1CU= 64 | go.opentelemetry.io/otel/sdk/metric v1.32.0/go.mod h1:PWeZlq0zt9YkYAp3gjKZ0eicRYvOh1Gd+X99x6GHpCQ= 65 | go.opentelemetry.io/otel/trace v1.32.0 h1:WIC9mYrXf8TmY/EXuULKc8hR17vE+Hjv2cssQDe03fM= 66 | go.opentelemetry.io/otel/trace v1.32.0/go.mod h1:+i4rkvCraA+tG6AzwloGaCtkx53Fa+L+V8e9a7YvhT8= 
67 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 68 | golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 69 | golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= 70 | golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= 71 | golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= 72 | golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= 73 | golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 74 | golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 75 | golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 76 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 77 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 78 | golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= 79 | golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= 80 | golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= 81 | golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= 82 | golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 83 | golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 84 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 85 | golang.org/x/sync 
v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 86 | golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= 87 | golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 88 | golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 89 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 90 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 91 | golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 92 | golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= 93 | golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 94 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 95 | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= 96 | golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= 97 | golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= 98 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 99 | golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 100 | golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= 101 | golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 102 | golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= 103 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 104 | google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= 105 | 
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= 106 | google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= 107 | google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= 108 | google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= 109 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241202173237-19429a94021a h1:hgh8P4EuoxpsuKMXX/To36nOFD7vixReXgn8lPGnt+o= 110 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241202173237-19429a94021a/go.mod h1:5uTbfoYQed2U9p3KIj2/Zzm02PYhndfdmML0qC3q3FU= 111 | google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= 112 | google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= 113 | google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= 114 | google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= 115 | google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= 116 | google.golang.org/grpc v1.70.0 h1:pWFv03aZoHzlRKHWicjsZytKAiYCtNS0dHbXnIdq7jQ= 117 | google.golang.org/grpc v1.70.0/go.mod h1:ofIJqVKDXx/JiXrwr2IG4/zwdH9txy3IlF40RmcJSQw= 118 | google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= 119 | google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= 120 | google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= 121 | google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= 122 | google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= 123 | google.golang.org/protobuf 
v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 124 | google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 125 | google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 126 | google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= 127 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= 128 | google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= 129 | google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io= 130 | google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= 131 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 132 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 133 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 134 | gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 135 | honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= 136 | honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= 137 | -------------------------------------------------------------------------------- /grpchantesting/channel_test_cases_test.go: -------------------------------------------------------------------------------- 1 | package grpchantesting 2 | 3 | import ( 4 | "context" 5 | "google.golang.org/grpc/credentials/insecure" 6 | "net" 7 | "testing" 8 | "time" 9 | 10 | "google.golang.org/grpc" 11 | ) 12 | 13 | // We test all of our channel test cases by running them against a normal 14 | // *grpc.Server and *grpc.ClientConn, to make sure they are asserting the 15 | // same behavior exhibited by the standard 
HTTP/2 channel implementation. 16 | func TestChannelTestCases(t *testing.T) { 17 | s := grpc.NewServer() 18 | RegisterTestServiceServer(s, &TestServer{}) 19 | 20 | l, err := net.Listen("tcp", "127.0.0.1:0") 21 | if err != nil { 22 | t.Fatalf("failed to listen on socket: %v", err) 23 | } 24 | go s.Serve(l) 25 | defer s.Stop() 26 | 27 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 28 | defer cancel() 29 | 30 | addr := l.Addr().String() 31 | cc, err := grpc.DialContext(ctx, addr, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock(), grpc.FailOnNonTempDialError(true)) 32 | if err != nil { 33 | t.Fatalf("failed to dial address %s: %v", addr, err) 34 | } 35 | defer cc.Close() 36 | 37 | RunChannelTestCases(t, cc, true) 38 | } 39 | -------------------------------------------------------------------------------- /grpchantesting/doc.go: -------------------------------------------------------------------------------- 1 | // Package grpchantesting helps with testing implementations of alternate gRPC 2 | // transports. Its main value is in a method that, given a channel, will ensure 3 | // the channel behaves correctly under various conditions. 4 | // 5 | // It tests successful RPCs, failures, timeouts and client-side cancellations. 6 | // It also covers all kinds of RPCs: unary, client-streaming, server-streaming 7 | // and bidirectional-streaming. It can optionally test full-duplex bidi streams 8 | // if the underlying channel supports that. 9 | // 10 | // The channel must be connected to a server that exposes the test server 11 | // implementation contained in this package: &grpchantesting.TestServer{} 12 | package grpchantesting 13 | -------------------------------------------------------------------------------- /grpchantesting/test.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. DO NOT EDIT.
2 | // versions: 3 | // protoc-gen-go v1.26.0 4 | // protoc v4.22.0 5 | // source: test.proto 6 | 7 | package grpchantesting 8 | 9 | import ( 10 | protoreflect "google.golang.org/protobuf/reflect/protoreflect" 11 | protoimpl "google.golang.org/protobuf/runtime/protoimpl" 12 | anypb "google.golang.org/protobuf/types/known/anypb" 13 | emptypb "google.golang.org/protobuf/types/known/emptypb" 14 | reflect "reflect" 15 | sync "sync" 16 | ) 17 | 18 | const ( 19 | // Verify that this generated code is sufficiently up-to-date. 20 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) 21 | // Verify that runtime/protoimpl is sufficiently up-to-date. 22 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) 23 | ) 24 | 25 | type Message struct { 26 | state protoimpl.MessageState 27 | sizeCache protoimpl.SizeCache 28 | unknownFields protoimpl.UnknownFields 29 | 30 | Payload []byte `protobuf:"bytes,1,opt,name=payload,proto3" json:"payload,omitempty"` 31 | Count int32 `protobuf:"varint,2,opt,name=count,proto3" json:"count,omitempty"` 32 | Code int32 `protobuf:"varint,3,opt,name=code,proto3" json:"code,omitempty"` 33 | DelayMillis int32 `protobuf:"varint,4,opt,name=delay_millis,json=delayMillis,proto3" json:"delay_millis,omitempty"` 34 | Headers map[string][]byte `protobuf:"bytes,5,rep,name=headers,proto3" json:"headers,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` 35 | Trailers map[string][]byte `protobuf:"bytes,6,rep,name=trailers,proto3" json:"trailers,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` 36 | ErrorDetails []*anypb.Any `protobuf:"bytes,7,rep,name=error_details,json=errorDetails,proto3" json:"error_details,omitempty"` 37 | } 38 | 39 | func (x *Message) Reset() { 40 | *x = Message{} 41 | if protoimpl.UnsafeEnabled { 42 | mi := &file_test_proto_msgTypes[0] 43 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 44 | ms.StoreMessageInfo(mi) 45 | } 46 | } 
47 | 48 | func (x *Message) String() string { 49 | return protoimpl.X.MessageStringOf(x) 50 | } 51 | 52 | func (*Message) ProtoMessage() {} 53 | 54 | func (x *Message) ProtoReflect() protoreflect.Message { 55 | mi := &file_test_proto_msgTypes[0] 56 | if protoimpl.UnsafeEnabled && x != nil { 57 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 58 | if ms.LoadMessageInfo() == nil { 59 | ms.StoreMessageInfo(mi) 60 | } 61 | return ms 62 | } 63 | return mi.MessageOf(x) 64 | } 65 | 66 | // Deprecated: Use Message.ProtoReflect.Descriptor instead. 67 | func (*Message) Descriptor() ([]byte, []int) { 68 | return file_test_proto_rawDescGZIP(), []int{0} 69 | } 70 | 71 | func (x *Message) GetPayload() []byte { 72 | if x != nil { 73 | return x.Payload 74 | } 75 | return nil 76 | } 77 | 78 | func (x *Message) GetCount() int32 { 79 | if x != nil { 80 | return x.Count 81 | } 82 | return 0 83 | } 84 | 85 | func (x *Message) GetCode() int32 { 86 | if x != nil { 87 | return x.Code 88 | } 89 | return 0 90 | } 91 | 92 | func (x *Message) GetDelayMillis() int32 { 93 | if x != nil { 94 | return x.DelayMillis 95 | } 96 | return 0 97 | } 98 | 99 | func (x *Message) GetHeaders() map[string][]byte { 100 | if x != nil { 101 | return x.Headers 102 | } 103 | return nil 104 | } 105 | 106 | func (x *Message) GetTrailers() map[string][]byte { 107 | if x != nil { 108 | return x.Trailers 109 | } 110 | return nil 111 | } 112 | 113 | func (x *Message) GetErrorDetails() []*anypb.Any { 114 | if x != nil { 115 | return x.ErrorDetails 116 | } 117 | return nil 118 | } 119 | 120 | var File_test_proto protoreflect.FileDescriptor 121 | 122 | var file_test_proto_rawDesc = []byte{ 123 | 0x0a, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0e, 0x67, 0x72, 124 | 0x70, 0x63, 0x68, 0x61, 0x6e, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67, 0x1a, 0x19, 0x67, 0x6f, 125 | 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x61, 0x6e, 126 | 0x79, 0x2e, 0x70, 
0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 127 | 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 128 | 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xa7, 0x03, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 129 | 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 130 | 0x0c, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x63, 0x6f, 131 | 0x75, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x63, 0x6f, 0x75, 0x6e, 0x74, 132 | 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 133 | 0x63, 0x6f, 0x64, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x64, 0x65, 0x6c, 0x61, 0x79, 0x5f, 0x6d, 0x69, 134 | 0x6c, 0x6c, 0x69, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0b, 0x64, 0x65, 0x6c, 0x61, 135 | 0x79, 0x4d, 0x69, 0x6c, 0x6c, 0x69, 0x73, 0x12, 0x3e, 0x0a, 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, 136 | 0x72, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x68, 137 | 0x61, 0x6e, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 138 | 0x65, 0x2e, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x07, 139 | 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x12, 0x41, 0x0a, 0x08, 0x74, 0x72, 0x61, 0x69, 0x6c, 140 | 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x67, 0x72, 0x70, 0x63, 141 | 0x68, 0x61, 0x6e, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 142 | 0x67, 0x65, 0x2e, 0x54, 0x72, 0x61, 0x69, 0x6c, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 143 | 0x52, 0x08, 0x74, 0x72, 0x61, 0x69, 0x6c, 0x65, 0x72, 0x73, 0x12, 0x39, 0x0a, 0x0d, 0x65, 0x72, 144 | 0x72, 0x6f, 0x72, 0x5f, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 145 | 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 
0x6f, 0x74, 0x6f, 146 | 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x44, 0x65, 147 | 0x74, 0x61, 0x69, 0x6c, 0x73, 0x1a, 0x3a, 0x0a, 0x0c, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 148 | 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 149 | 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 150 | 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 151 | 0x01, 0x1a, 0x3b, 0x0a, 0x0d, 0x54, 0x72, 0x61, 0x69, 0x6c, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 152 | 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 153 | 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 154 | 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x32, 0xdf, 155 | 0x02, 0x0a, 0x0b, 0x54, 0x65, 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x39, 156 | 0x0a, 0x05, 0x55, 0x6e, 0x61, 0x72, 0x79, 0x12, 0x17, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x68, 0x61, 157 | 0x6e, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 158 | 0x1a, 0x17, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x68, 0x61, 0x6e, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 159 | 0x67, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x42, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 160 | 0x65, 0x6e, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x17, 0x2e, 0x67, 0x72, 0x70, 0x63, 161 | 0x68, 0x61, 0x6e, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 162 | 0x67, 0x65, 0x1a, 0x17, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x68, 0x61, 0x6e, 0x74, 0x65, 0x73, 0x74, 163 | 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x28, 0x01, 0x12, 0x42, 0x0a, 164 | 0x0c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x17, 0x2e, 165 | 0x67, 0x72, 0x70, 0x63, 0x68, 0x61, 
0x6e, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 166 | 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x17, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x68, 0x61, 0x6e, 167 | 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x30, 168 | 0x01, 0x12, 0x42, 0x0a, 0x0a, 0x42, 0x69, 0x64, 0x69, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 169 | 0x17, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x68, 0x61, 0x6e, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67, 170 | 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x17, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x68, 171 | 0x61, 0x6e, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 172 | 0x65, 0x28, 0x01, 0x30, 0x01, 0x12, 0x49, 0x0a, 0x17, 0x55, 0x73, 0x65, 0x45, 0x78, 0x74, 0x65, 173 | 0x72, 0x6e, 0x61, 0x6c, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x54, 0x77, 0x69, 0x63, 0x65, 174 | 0x12, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 175 | 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 176 | 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 177 | 0x42, 0x13, 0x5a, 0x11, 0x2e, 0x2f, 0x3b, 0x67, 0x72, 0x70, 0x63, 0x68, 0x61, 0x6e, 0x74, 0x65, 178 | 0x73, 0x74, 0x69, 0x6e, 0x67, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 179 | } 180 | 181 | var ( 182 | file_test_proto_rawDescOnce sync.Once 183 | file_test_proto_rawDescData = file_test_proto_rawDesc 184 | ) 185 | 186 | func file_test_proto_rawDescGZIP() []byte { 187 | file_test_proto_rawDescOnce.Do(func() { 188 | file_test_proto_rawDescData = protoimpl.X.CompressGZIP(file_test_proto_rawDescData) 189 | }) 190 | return file_test_proto_rawDescData 191 | } 192 | 193 | var file_test_proto_msgTypes = make([]protoimpl.MessageInfo, 3) 194 | var file_test_proto_goTypes = []interface{}{ 195 | (*Message)(nil), // 0: grpchantesting.Message 196 | nil, // 1: grpchantesting.Message.HeadersEntry 197 | nil, 
// 2: grpchantesting.Message.TrailersEntry 198 | (*anypb.Any)(nil), // 3: google.protobuf.Any 199 | (*emptypb.Empty)(nil), // 4: google.protobuf.Empty 200 | } 201 | var file_test_proto_depIdxs = []int32{ 202 | 1, // 0: grpchantesting.Message.headers:type_name -> grpchantesting.Message.HeadersEntry 203 | 2, // 1: grpchantesting.Message.trailers:type_name -> grpchantesting.Message.TrailersEntry 204 | 3, // 2: grpchantesting.Message.error_details:type_name -> google.protobuf.Any 205 | 0, // 3: grpchantesting.TestService.Unary:input_type -> grpchantesting.Message 206 | 0, // 4: grpchantesting.TestService.ClientStream:input_type -> grpchantesting.Message 207 | 0, // 5: grpchantesting.TestService.ServerStream:input_type -> grpchantesting.Message 208 | 0, // 6: grpchantesting.TestService.BidiStream:input_type -> grpchantesting.Message 209 | 4, // 7: grpchantesting.TestService.UseExternalMessageTwice:input_type -> google.protobuf.Empty 210 | 0, // 8: grpchantesting.TestService.Unary:output_type -> grpchantesting.Message 211 | 0, // 9: grpchantesting.TestService.ClientStream:output_type -> grpchantesting.Message 212 | 0, // 10: grpchantesting.TestService.ServerStream:output_type -> grpchantesting.Message 213 | 0, // 11: grpchantesting.TestService.BidiStream:output_type -> grpchantesting.Message 214 | 4, // 12: grpchantesting.TestService.UseExternalMessageTwice:output_type -> google.protobuf.Empty 215 | 8, // [8:13] is the sub-list for method output_type 216 | 3, // [3:8] is the sub-list for method input_type 217 | 3, // [3:3] is the sub-list for extension type_name 218 | 3, // [3:3] is the sub-list for extension extendee 219 | 0, // [0:3] is the sub-list for field type_name 220 | } 221 | 222 | func init() { file_test_proto_init() } 223 | func file_test_proto_init() { 224 | if File_test_proto != nil { 225 | return 226 | } 227 | if !protoimpl.UnsafeEnabled { 228 | file_test_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { 229 | switch v := v.(*Message); i 
{ 230 | case 0: 231 | return &v.state 232 | case 1: 233 | return &v.sizeCache 234 | case 2: 235 | return &v.unknownFields 236 | default: 237 | return nil 238 | } 239 | } 240 | } 241 | type x struct{} 242 | out := protoimpl.TypeBuilder{ 243 | File: protoimpl.DescBuilder{ 244 | GoPackagePath: reflect.TypeOf(x{}).PkgPath(), 245 | RawDescriptor: file_test_proto_rawDesc, 246 | NumEnums: 0, 247 | NumMessages: 3, 248 | NumExtensions: 0, 249 | NumServices: 1, 250 | }, 251 | GoTypes: file_test_proto_goTypes, 252 | DependencyIndexes: file_test_proto_depIdxs, 253 | MessageInfos: file_test_proto_msgTypes, 254 | }.Build() 255 | File_test_proto = out.File 256 | file_test_proto_rawDesc = nil 257 | file_test_proto_goTypes = nil 258 | file_test_proto_depIdxs = nil 259 | } 260 | -------------------------------------------------------------------------------- /grpchantesting/test.pb.grpchan.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-grpchan. DO NOT EDIT. 2 | // source: test.proto 3 | 4 | package grpchantesting 5 | 6 | import "context" 7 | import "github.com/fullstorydev/grpchan" 8 | import "google.golang.org/grpc" 9 | import "google.golang.org/protobuf/types/known/emptypb" 10 | 11 | // Deprecated: Use RegisterTestServiceServer instead. 12 | func RegisterHandlerTestService(reg grpchan.ServiceRegistry, srv TestServiceServer) { 13 | reg.RegisterService(&TestService_ServiceDesc, srv) 14 | } 15 | 16 | type testServiceChannelClient struct { 17 | ch grpc.ClientConnInterface 18 | } 19 | 20 | // Deprecated: Use NewTestServiceClient instead. 21 | func NewTestServiceChannelClient(ch grpc.ClientConnInterface) TestServiceClient { 22 | return &testServiceChannelClient{ch: ch} 23 | } 24 | 25 | func (c *testServiceChannelClient) Unary(ctx context.Context, in *Message, opts ...grpc.CallOption) (*Message, error) { 26 | out := new(Message) 27 | err := c.ch.Invoke(ctx, "/grpchantesting.TestService/Unary", in, out, opts...) 
28 | if err != nil { 29 | return nil, err 30 | } 31 | return out, nil 32 | } 33 | 34 | func (c *testServiceChannelClient) ClientStream(ctx context.Context, opts ...grpc.CallOption) (TestService_ClientStreamClient, error) { 35 | stream, err := c.ch.NewStream(ctx, &TestService_ServiceDesc.Streams[0], "/grpchantesting.TestService/ClientStream", opts...) 36 | if err != nil { 37 | return nil, err 38 | } 39 | x := &testServiceClientStreamClient{stream} 40 | return x, nil 41 | } 42 | 43 | func (c *testServiceChannelClient) ServerStream(ctx context.Context, in *Message, opts ...grpc.CallOption) (TestService_ServerStreamClient, error) { 44 | stream, err := c.ch.NewStream(ctx, &TestService_ServiceDesc.Streams[1], "/grpchantesting.TestService/ServerStream", opts...) 45 | if err != nil { 46 | return nil, err 47 | } 48 | x := &testServiceServerStreamClient{stream} 49 | if err := x.ClientStream.SendMsg(in); err != nil { 50 | return nil, err 51 | } 52 | if err := x.ClientStream.CloseSend(); err != nil { 53 | return nil, err 54 | } 55 | return x, nil 56 | } 57 | 58 | func (c *testServiceChannelClient) BidiStream(ctx context.Context, opts ...grpc.CallOption) (TestService_BidiStreamClient, error) { 59 | stream, err := c.ch.NewStream(ctx, &TestService_ServiceDesc.Streams[2], "/grpchantesting.TestService/BidiStream", opts...) 60 | if err != nil { 61 | return nil, err 62 | } 63 | x := &testServiceBidiStreamClient{stream} 64 | return x, nil 65 | } 66 | 67 | func (c *testServiceChannelClient) UseExternalMessageTwice(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error) { 68 | out := new(emptypb.Empty) 69 | err := c.ch.Invoke(ctx, "/grpchantesting.TestService/UseExternalMessageTwice", in, out, opts...) 
70 | if err != nil { 71 | return nil, err 72 | } 73 | return out, nil 74 | } 75 | -------------------------------------------------------------------------------- /grpchantesting/test.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "google/protobuf/any.proto"; 4 | import "google/protobuf/empty.proto"; 5 | 6 | package grpchantesting; 7 | 8 | option go_package = "./;grpchantesting"; 9 | 10 | service TestService { 11 | rpc Unary (Message) returns (Message); 12 | rpc ClientStream (stream Message) returns (Message); 13 | rpc ServerStream (Message) returns (stream Message); 14 | rpc BidiStream (stream Message) returns (stream Message); 15 | 16 | // UseExternalMessageTwice is here purely to test the protoc-gen-grpchan plug-in 17 | rpc UseExternalMessageTwice (google.protobuf.Empty) returns (google.protobuf.Empty); 18 | } 19 | 20 | message Message { 21 | bytes payload = 1; 22 | int32 count = 2; 23 | int32 code = 3; 24 | int32 delay_millis = 4; 25 | map headers = 5; 26 | map trailers = 6; 27 | repeated google.protobuf.Any error_details = 7; 28 | } 29 | -------------------------------------------------------------------------------- /grpchantesting/test_grpc.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go-grpc. DO NOT EDIT. 2 | 3 | package grpchantesting 4 | 5 | import ( 6 | context "context" 7 | grpc "google.golang.org/grpc" 8 | codes "google.golang.org/grpc/codes" 9 | status "google.golang.org/grpc/status" 10 | emptypb "google.golang.org/protobuf/types/known/emptypb" 11 | ) 12 | 13 | // This is a compile-time assertion to ensure that this generated file 14 | // is compatible with the grpc package it is being compiled against. 15 | // Requires gRPC-Go v1.32.0 or later. 16 | const _ = grpc.SupportPackageIsVersion7 17 | 18 | // TestServiceClient is the client API for TestService service. 
19 | // 20 | // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 21 | type TestServiceClient interface { 22 | Unary(ctx context.Context, in *Message, opts ...grpc.CallOption) (*Message, error) 23 | ClientStream(ctx context.Context, opts ...grpc.CallOption) (TestService_ClientStreamClient, error) 24 | ServerStream(ctx context.Context, in *Message, opts ...grpc.CallOption) (TestService_ServerStreamClient, error) 25 | BidiStream(ctx context.Context, opts ...grpc.CallOption) (TestService_BidiStreamClient, error) 26 | // UseExternalMessageTwice is here purely to test the protoc-gen-grpchan plug-in 27 | UseExternalMessageTwice(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error) 28 | } 29 | 30 | type testServiceClient struct { 31 | cc grpc.ClientConnInterface 32 | } 33 | 34 | func NewTestServiceClient(cc grpc.ClientConnInterface) TestServiceClient { 35 | return &testServiceClient{cc} 36 | } 37 | 38 | func (c *testServiceClient) Unary(ctx context.Context, in *Message, opts ...grpc.CallOption) (*Message, error) { 39 | out := new(Message) 40 | err := c.cc.Invoke(ctx, "/grpchantesting.TestService/Unary", in, out, opts...) 41 | if err != nil { 42 | return nil, err 43 | } 44 | return out, nil 45 | } 46 | 47 | func (c *testServiceClient) ClientStream(ctx context.Context, opts ...grpc.CallOption) (TestService_ClientStreamClient, error) { 48 | stream, err := c.cc.NewStream(ctx, &TestService_ServiceDesc.Streams[0], "/grpchantesting.TestService/ClientStream", opts...) 
49 | if err != nil { 50 | return nil, err 51 | } 52 | x := &testServiceClientStreamClient{stream} 53 | return x, nil 54 | } 55 | 56 | type TestService_ClientStreamClient interface { 57 | Send(*Message) error 58 | CloseAndRecv() (*Message, error) 59 | grpc.ClientStream 60 | } 61 | 62 | type testServiceClientStreamClient struct { 63 | grpc.ClientStream 64 | } 65 | 66 | func (x *testServiceClientStreamClient) Send(m *Message) error { 67 | return x.ClientStream.SendMsg(m) 68 | } 69 | 70 | func (x *testServiceClientStreamClient) CloseAndRecv() (*Message, error) { 71 | if err := x.ClientStream.CloseSend(); err != nil { 72 | return nil, err 73 | } 74 | m := new(Message) 75 | if err := x.ClientStream.RecvMsg(m); err != nil { 76 | return nil, err 77 | } 78 | return m, nil 79 | } 80 | 81 | func (c *testServiceClient) ServerStream(ctx context.Context, in *Message, opts ...grpc.CallOption) (TestService_ServerStreamClient, error) { 82 | stream, err := c.cc.NewStream(ctx, &TestService_ServiceDesc.Streams[1], "/grpchantesting.TestService/ServerStream", opts...) 
83 | if err != nil { 84 | return nil, err 85 | } 86 | x := &testServiceServerStreamClient{stream} 87 | if err := x.ClientStream.SendMsg(in); err != nil { 88 | return nil, err 89 | } 90 | if err := x.ClientStream.CloseSend(); err != nil { 91 | return nil, err 92 | } 93 | return x, nil 94 | } 95 | 96 | type TestService_ServerStreamClient interface { 97 | Recv() (*Message, error) 98 | grpc.ClientStream 99 | } 100 | 101 | type testServiceServerStreamClient struct { 102 | grpc.ClientStream 103 | } 104 | 105 | func (x *testServiceServerStreamClient) Recv() (*Message, error) { 106 | m := new(Message) 107 | if err := x.ClientStream.RecvMsg(m); err != nil { 108 | return nil, err 109 | } 110 | return m, nil 111 | } 112 | 113 | func (c *testServiceClient) BidiStream(ctx context.Context, opts ...grpc.CallOption) (TestService_BidiStreamClient, error) { 114 | stream, err := c.cc.NewStream(ctx, &TestService_ServiceDesc.Streams[2], "/grpchantesting.TestService/BidiStream", opts...) 115 | if err != nil { 116 | return nil, err 117 | } 118 | x := &testServiceBidiStreamClient{stream} 119 | return x, nil 120 | } 121 | 122 | type TestService_BidiStreamClient interface { 123 | Send(*Message) error 124 | Recv() (*Message, error) 125 | grpc.ClientStream 126 | } 127 | 128 | type testServiceBidiStreamClient struct { 129 | grpc.ClientStream 130 | } 131 | 132 | func (x *testServiceBidiStreamClient) Send(m *Message) error { 133 | return x.ClientStream.SendMsg(m) 134 | } 135 | 136 | func (x *testServiceBidiStreamClient) Recv() (*Message, error) { 137 | m := new(Message) 138 | if err := x.ClientStream.RecvMsg(m); err != nil { 139 | return nil, err 140 | } 141 | return m, nil 142 | } 143 | 144 | func (c *testServiceClient) UseExternalMessageTwice(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error) { 145 | out := new(emptypb.Empty) 146 | err := c.cc.Invoke(ctx, "/grpchantesting.TestService/UseExternalMessageTwice", in, out, opts...) 
147 | if err != nil { 148 | return nil, err 149 | } 150 | return out, nil 151 | } 152 | 153 | // TestServiceServer is the server API for TestService service. 154 | // All implementations must embed UnimplementedTestServiceServer 155 | // for forward compatibility 156 | type TestServiceServer interface { 157 | Unary(context.Context, *Message) (*Message, error) 158 | ClientStream(TestService_ClientStreamServer) error 159 | ServerStream(*Message, TestService_ServerStreamServer) error 160 | BidiStream(TestService_BidiStreamServer) error 161 | // UseExternalMessageTwice is here purely to test the protoc-gen-grpchan plug-in 162 | UseExternalMessageTwice(context.Context, *emptypb.Empty) (*emptypb.Empty, error) 163 | mustEmbedUnimplementedTestServiceServer() 164 | } 165 | 166 | // UnimplementedTestServiceServer must be embedded to have forward compatible implementations. 167 | type UnimplementedTestServiceServer struct { 168 | } 169 | 170 | func (UnimplementedTestServiceServer) Unary(context.Context, *Message) (*Message, error) { 171 | return nil, status.Errorf(codes.Unimplemented, "method Unary not implemented") 172 | } 173 | func (UnimplementedTestServiceServer) ClientStream(TestService_ClientStreamServer) error { 174 | return status.Errorf(codes.Unimplemented, "method ClientStream not implemented") 175 | } 176 | func (UnimplementedTestServiceServer) ServerStream(*Message, TestService_ServerStreamServer) error { 177 | return status.Errorf(codes.Unimplemented, "method ServerStream not implemented") 178 | } 179 | func (UnimplementedTestServiceServer) BidiStream(TestService_BidiStreamServer) error { 180 | return status.Errorf(codes.Unimplemented, "method BidiStream not implemented") 181 | } 182 | func (UnimplementedTestServiceServer) UseExternalMessageTwice(context.Context, *emptypb.Empty) (*emptypb.Empty, error) { 183 | return nil, status.Errorf(codes.Unimplemented, "method UseExternalMessageTwice not implemented") 184 | } 185 | func (UnimplementedTestServiceServer) 
mustEmbedUnimplementedTestServiceServer() {} 186 | 187 | // UnsafeTestServiceServer may be embedded to opt out of forward compatibility for this service. 188 | // Use of this interface is not recommended, as added methods to TestServiceServer will 189 | // result in compilation errors. 190 | type UnsafeTestServiceServer interface { 191 | mustEmbedUnimplementedTestServiceServer() 192 | } 193 | 194 | func RegisterTestServiceServer(s grpc.ServiceRegistrar, srv TestServiceServer) { 195 | s.RegisterService(&TestService_ServiceDesc, srv) 196 | } 197 | 198 | func _TestService_Unary_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { 199 | in := new(Message) 200 | if err := dec(in); err != nil { 201 | return nil, err 202 | } 203 | if interceptor == nil { 204 | return srv.(TestServiceServer).Unary(ctx, in) 205 | } 206 | info := &grpc.UnaryServerInfo{ 207 | Server: srv, 208 | FullMethod: "/grpchantesting.TestService/Unary", 209 | } 210 | handler := func(ctx context.Context, req interface{}) (interface{}, error) { 211 | return srv.(TestServiceServer).Unary(ctx, req.(*Message)) 212 | } 213 | return interceptor(ctx, in, info, handler) 214 | } 215 | 216 | func _TestService_ClientStream_Handler(srv interface{}, stream grpc.ServerStream) error { 217 | return srv.(TestServiceServer).ClientStream(&testServiceClientStreamServer{stream}) 218 | } 219 | 220 | type TestService_ClientStreamServer interface { 221 | SendAndClose(*Message) error 222 | Recv() (*Message, error) 223 | grpc.ServerStream 224 | } 225 | 226 | type testServiceClientStreamServer struct { 227 | grpc.ServerStream 228 | } 229 | 230 | func (x *testServiceClientStreamServer) SendAndClose(m *Message) error { 231 | return x.ServerStream.SendMsg(m) 232 | } 233 | 234 | func (x *testServiceClientStreamServer) Recv() (*Message, error) { 235 | m := new(Message) 236 | if err := x.ServerStream.RecvMsg(m); err != nil { 237 | return nil, err 238 | 
} 239 | return m, nil 240 | } 241 | 242 | func _TestService_ServerStream_Handler(srv interface{}, stream grpc.ServerStream) error { 243 | m := new(Message) 244 | if err := stream.RecvMsg(m); err != nil { 245 | return err 246 | } 247 | return srv.(TestServiceServer).ServerStream(m, &testServiceServerStreamServer{stream}) 248 | } 249 | 250 | type TestService_ServerStreamServer interface { 251 | Send(*Message) error 252 | grpc.ServerStream 253 | } 254 | 255 | type testServiceServerStreamServer struct { 256 | grpc.ServerStream 257 | } 258 | 259 | func (x *testServiceServerStreamServer) Send(m *Message) error { 260 | return x.ServerStream.SendMsg(m) 261 | } 262 | 263 | func _TestService_BidiStream_Handler(srv interface{}, stream grpc.ServerStream) error { 264 | return srv.(TestServiceServer).BidiStream(&testServiceBidiStreamServer{stream}) 265 | } 266 | 267 | type TestService_BidiStreamServer interface { 268 | Send(*Message) error 269 | Recv() (*Message, error) 270 | grpc.ServerStream 271 | } 272 | 273 | type testServiceBidiStreamServer struct { 274 | grpc.ServerStream 275 | } 276 | 277 | func (x *testServiceBidiStreamServer) Send(m *Message) error { 278 | return x.ServerStream.SendMsg(m) 279 | } 280 | 281 | func (x *testServiceBidiStreamServer) Recv() (*Message, error) { 282 | m := new(Message) 283 | if err := x.ServerStream.RecvMsg(m); err != nil { 284 | return nil, err 285 | } 286 | return m, nil 287 | } 288 | 289 | func _TestService_UseExternalMessageTwice_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { 290 | in := new(emptypb.Empty) 291 | if err := dec(in); err != nil { 292 | return nil, err 293 | } 294 | if interceptor == nil { 295 | return srv.(TestServiceServer).UseExternalMessageTwice(ctx, in) 296 | } 297 | info := &grpc.UnaryServerInfo{ 298 | Server: srv, 299 | FullMethod: "/grpchantesting.TestService/UseExternalMessageTwice", 300 | } 301 | handler := func(ctx 
context.Context, req interface{}) (interface{}, error) { 302 | return srv.(TestServiceServer).UseExternalMessageTwice(ctx, req.(*emptypb.Empty)) 303 | } 304 | return interceptor(ctx, in, info, handler) 305 | } 306 | 307 | // TestService_ServiceDesc is the grpc.ServiceDesc for TestService service. 308 | // It's only intended for direct use with grpc.RegisterService, 309 | // and not to be introspected or modified (even as a copy) 310 | var TestService_ServiceDesc = grpc.ServiceDesc{ 311 | ServiceName: "grpchantesting.TestService", 312 | HandlerType: (*TestServiceServer)(nil), 313 | Methods: []grpc.MethodDesc{ 314 | { 315 | MethodName: "Unary", 316 | Handler: _TestService_Unary_Handler, 317 | }, 318 | { 319 | MethodName: "UseExternalMessageTwice", 320 | Handler: _TestService_UseExternalMessageTwice_Handler, 321 | }, 322 | }, 323 | Streams: []grpc.StreamDesc{ 324 | { 325 | StreamName: "ClientStream", 326 | Handler: _TestService_ClientStream_Handler, 327 | ClientStreams: true, 328 | }, 329 | { 330 | StreamName: "ServerStream", 331 | Handler: _TestService_ServerStream_Handler, 332 | ServerStreams: true, 333 | }, 334 | { 335 | StreamName: "BidiStream", 336 | Handler: _TestService_BidiStream_Handler, 337 | ServerStreams: true, 338 | ClientStreams: true, 339 | }, 340 | }, 341 | Metadata: "test.proto", 342 | } 343 | -------------------------------------------------------------------------------- /grpchantesting/test_service.go: -------------------------------------------------------------------------------- 1 | package grpchantesting 2 | 3 | //go:generate protoc --go_out=./ --go-grpc_out=./ --grpchan_out=legacy_stubs:./ --go_opt=paths=source_relative test.proto 4 | 5 | import ( 6 | "context" 7 | "io" 8 | "time" 9 | 10 | "github.com/golang/protobuf/ptypes/empty" 11 | spb "google.golang.org/genproto/googleapis/rpc/status" 12 | "google.golang.org/grpc" 13 | "google.golang.org/grpc/metadata" 14 | "google.golang.org/grpc/status" 15 | ) 16 | 17 | // TestServer has default 
responses to the various kinds of methods. 18 | type TestServer struct { 19 | UnimplementedTestServiceServer 20 | } 21 | 22 | // Unary implements the TestService server interface. 23 | func (s *TestServer) Unary(ctx context.Context, req *Message) (*Message, error) { 24 | if req.DelayMillis > 0 { 25 | time.Sleep(time.Millisecond * time.Duration(req.DelayMillis)) 26 | } 27 | grpc.SetHeader(ctx, MetadataNew(req.Headers)) 28 | grpc.SetTrailer(ctx, MetadataNew(req.Trailers)) 29 | if req.Code != 0 { 30 | return nil, statusFromRequest(req) 31 | } 32 | md, _ := metadata.FromIncomingContext(ctx) 33 | return &Message{ 34 | Headers: asMap(md), 35 | Payload: req.Payload, 36 | }, nil 37 | } 38 | 39 | func statusFromRequest(req *Message) error { 40 | statProto := spb.Status{ 41 | Code: req.Code, 42 | Message: "error", 43 | Details: req.ErrorDetails, 44 | } 45 | return status.FromProto(&statProto).Err() 46 | } 47 | 48 | // ClientStream implements the TestService server interface. 49 | func (s *TestServer) ClientStream(cs TestService_ClientStreamServer) error { 50 | var req *Message 51 | count := int32(0) 52 | for { 53 | r, err := cs.Recv() 54 | if err == io.EOF { 55 | break 56 | } else if err != nil { 57 | return err 58 | } 59 | req = r 60 | count++ 61 | if req.Code != 0 { 62 | break 63 | } 64 | } 65 | if req == nil { 66 | req = &Message{} 67 | } 68 | if req.DelayMillis > 0 { 69 | time.Sleep(time.Millisecond * time.Duration(req.DelayMillis)) 70 | } 71 | if err := cs.SetHeader(MetadataNew(req.Headers)); err != nil { 72 | return err 73 | } 74 | cs.SetTrailer(MetadataNew(req.Trailers)) 75 | if req.Code != 0 { 76 | return statusFromRequest(req) 77 | } 78 | md, _ := metadata.FromIncomingContext(cs.Context()) 79 | return cs.SendAndClose(&Message{ 80 | Headers: asMap(md), 81 | Payload: req.Payload, 82 | Count: count, 83 | }) 84 | } 85 | 86 | // ServerStream implements the TestService server interface. 
87 | func (s *TestServer) ServerStream(req *Message, ss TestService_ServerStreamServer) error { 88 | if req.DelayMillis > 0 { 89 | time.Sleep(time.Millisecond * time.Duration(req.DelayMillis)) 90 | } 91 | md, _ := metadata.FromIncomingContext(ss.Context()) 92 | if err := ss.SetHeader(MetadataNew(req.Headers)); err != nil { 93 | return err 94 | } 95 | for i := 0; i < int(req.Count); i++ { 96 | err := ss.Send(&Message{ 97 | Headers: asMap(md), 98 | Payload: req.Payload, 99 | }) 100 | if err != nil { 101 | return err 102 | } 103 | } 104 | ss.SetTrailer(MetadataNew(req.Trailers)) 105 | if req.Code != 0 { 106 | return statusFromRequest(req) 107 | } 108 | return nil 109 | } 110 | 111 | // BidiStream implements the TestService server interface. 112 | func (s *TestServer) BidiStream(str TestService_BidiStreamServer) error { 113 | md, _ := metadata.FromIncomingContext(str.Context()) 114 | var req *Message 115 | count := int32(0) 116 | var responses []*Message 117 | isHalfDuplex := false 118 | for { 119 | r, err := str.Recv() 120 | if err == io.EOF { 121 | break 122 | } else if err != nil { 123 | return err 124 | } 125 | req = r 126 | if req.DelayMillis > 0 { 127 | time.Sleep(time.Millisecond * time.Duration(req.DelayMillis)) 128 | } 129 | if count == 0 { 130 | if err := str.SetHeader(MetadataNew(req.Headers)); err != nil { 131 | return err 132 | } 133 | isHalfDuplex = req.Count < 0 134 | } 135 | count++ 136 | if req.Code != 0 { 137 | break 138 | } 139 | replyMsg := &Message{ 140 | Headers: asMap(md), 141 | Payload: req.Payload, 142 | Count: count, 143 | } 144 | if isHalfDuplex { 145 | // half duplex means we fully consume the client stream before we 146 | // start sending responses, so buffer these messages in a slice 147 | responses = append(responses, replyMsg) 148 | } else if err = str.Send(replyMsg); err != nil { 149 | return err 150 | } 151 | } 152 | if isHalfDuplex { 153 | // now we can send out all buffered responses 154 | for _, response := range responses { 155 | 
if err := str.Send(response); err != nil { 156 | return err 157 | } 158 | } 159 | } 160 | if req != nil { 161 | str.SetTrailer(MetadataNew(req.Trailers)) 162 | if req.Code != 0 { 163 | return statusFromRequest(req) 164 | } 165 | } 166 | return nil 167 | } 168 | 169 | // UseExternalMessageTwice implements the TestService server interface. 170 | func (s *TestServer) UseExternalMessageTwice(ctx context.Context, in *empty.Empty) (*empty.Empty, error) { 171 | return &empty.Empty{}, nil 172 | } 173 | 174 | func asMap(md metadata.MD) map[string][]byte { 175 | m := map[string][]byte{} 176 | for k, vs := range md { 177 | if len(vs) == 0 { 178 | continue 179 | } 180 | m[k] = []byte(vs[len(vs)-1]) 181 | } 182 | return m 183 | } 184 | -------------------------------------------------------------------------------- /httpgrpc/codes.go: -------------------------------------------------------------------------------- 1 | package httpgrpc 2 | 3 | import ( 4 | "net/http" 5 | 6 | "google.golang.org/grpc/codes" 7 | ) 8 | 9 | // httpStatusFromCode translates the given GRPC code into an HTTP 10 | // response. This is used to set the HTTP status code for unary RPCs. 11 | // (Streaming RPCs cannot convey a GRPC status code until the stream 12 | // completes, so they use a 200 HTTP status code and then encode the 13 | // actual status, along with any trailer metadata, at the end of the 14 | // response stream.) 
15 | func httpStatusFromCode(code codes.Code) int { 16 | switch code { 17 | case codes.OK: 18 | return http.StatusOK 19 | case codes.Canceled: 20 | return http.StatusBadGateway 21 | case codes.Unknown: 22 | return http.StatusInternalServerError 23 | case codes.InvalidArgument: 24 | return http.StatusBadRequest 25 | case codes.DeadlineExceeded: 26 | return http.StatusGatewayTimeout 27 | case codes.NotFound: 28 | return http.StatusNotFound 29 | case codes.AlreadyExists: 30 | return http.StatusConflict 31 | case codes.PermissionDenied: 32 | return http.StatusForbidden 33 | case codes.Unauthenticated: 34 | return http.StatusUnauthorized 35 | case codes.ResourceExhausted: 36 | return http.StatusTooManyRequests 37 | case codes.FailedPrecondition: 38 | return http.StatusPreconditionFailed 39 | case codes.Aborted: 40 | return http.StatusConflict 41 | case codes.OutOfRange: 42 | return http.StatusUnprocessableEntity 43 | case codes.Unimplemented: 44 | return http.StatusNotImplemented 45 | case codes.Internal: 46 | return http.StatusInternalServerError 47 | case codes.Unavailable: 48 | return http.StatusServiceUnavailable 49 | case codes.DataLoss: 50 | return http.StatusInternalServerError 51 | default: 52 | return http.StatusInternalServerError 53 | } 54 | } 55 | 56 | // codeFromHttpStatus translates the given HTTP status code into a GRPC code. 57 | // This is used for unary RPCs where the server failed to include the actual 58 | // GRPC status code via response header. 
59 | func codeFromHttpStatus(stat int) codes.Code { 60 | switch { 61 | case stat >= 200 && stat < 300: 62 | return codes.OK 63 | case stat >= 400 && stat < 500: 64 | switch stat { 65 | case http.StatusBadRequest: 66 | return codes.InvalidArgument 67 | case http.StatusUnauthorized: 68 | return codes.Unauthenticated 69 | case http.StatusForbidden: 70 | return codes.PermissionDenied 71 | case http.StatusNotFound: 72 | return codes.NotFound 73 | case http.StatusMethodNotAllowed: 74 | return codes.InvalidArgument 75 | case http.StatusRequestTimeout: 76 | return codes.DeadlineExceeded 77 | case http.StatusConflict: 78 | return codes.Aborted 79 | case http.StatusRequestedRangeNotSatisfiable: 80 | return codes.OutOfRange 81 | case http.StatusLocked: 82 | return codes.Aborted 83 | case http.StatusPreconditionFailed, http.StatusExpectationFailed: 84 | return codes.FailedPrecondition 85 | case http.StatusTooManyRequests: 86 | return codes.ResourceExhausted 87 | case 499: 88 | return codes.Canceled 89 | default: 90 | return codes.InvalidArgument 91 | } 92 | case stat >= 500 && stat < 600: 93 | switch stat { 94 | case http.StatusInternalServerError: 95 | return codes.Internal 96 | case http.StatusNotImplemented: 97 | return codes.Unimplemented 98 | case http.StatusBadGateway: 99 | return codes.Unknown 100 | case http.StatusGatewayTimeout: 101 | return codes.DeadlineExceeded 102 | case http.StatusServiceUnavailable: 103 | return codes.Unavailable 104 | default: 105 | return codes.Internal 106 | } 107 | default: 108 | // 1XX (not supported by GRPC), 3xx/redirects (not supported by GRPC), other codes 109 | return codes.Unknown 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /httpgrpc/doc.go: -------------------------------------------------------------------------------- 1 | // Package httpgrpc contains code for using HTTP 1.1 for GRPC calls. 
This is intended only 2 | // for environments where real GRPC is not possible or prohibitively expensive, like Google 3 | // App Engine. It could possibly be used to perform GRPC operations from a browser; however, 4 | // no client implementation, other than a Go client, is provided. 5 | // 6 | // For servers, RPC handlers will be invoked directly from an HTTP request, optionally 7 | // transiting through a server interceptor. Importantly, this does not transform the 8 | // request and then proxy it on loopback to the actual GRPC server. So GRPC service 9 | // handlers are dispatched directly from HTTP server handlers. 10 | // 11 | // Clients create a *httpgrpc.Channel value, setting the two required fields to configure 12 | // it. This value can then be used to create RPC stubs, for sending RPCs to the configured 13 | // destination. 14 | // 15 | // Servers can expose individual methods using httpgrpc.HandleMethod or httpgrpc.HandleStream; 16 | // they can accumulate (1 or more) servers in a grpchan.HandlerMap and then use HandleServices 17 | // to expose them all; or they can create a httpgrpc.NewServer and register services on that 18 | // and use that to expose the services. The httpgrpc.NewServer route is the simplest and most 19 | // intuitive and really supersedes the other two options. 20 | // 21 | // # Caveats 22 | // 23 | // There are a couple of limitations when using this package: 24 | // 1. True bidi streams are not supported. The best that can be done are half-duplex 25 | // bidi streams, where the client uploads its entire streaming request and then the 26 | // server can reply with a streaming response. Interleaved reading and writing does 27 | // not work with HTTP 1.1. (Even if there were clients that supported it, the Go HTTP 28 | // server APIs do not -- once a server handler starts writing to the response body, 29 | // the request body is closed and no more messages can be read from it). 30 | // 2.
Client-side interceptors that interact with the *grpc.ClientConn, such as examining 31 | // connection states or querying static method configs, will not work. No GRPC 32 | // client connection is actually established, and HTTP 1.1 calls will supply a nil 33 | // *grpc.ClientConn to any interceptor. 34 | // 35 | // Note that for environments like Google App Engine, which do not support streaming, use 36 | // of streaming RPCs may result in high latency and high memory usage as entire streams must 37 | // be buffered in memory. Use streams judiciously when interoperating with App Engine. 38 | // 39 | // This package does not attempt to block use of full-duplex streaming. So if HTTP 1.1 is 40 | // used to invoke a bidi streaming method, the RPC will almost certainly fail because the 41 | // server's sending of headers and the first response message will immediately close the 42 | // request side for reading. So later attempts to read a request message will fail. 43 | // 44 | // # Anatomy of GRPC-over-HTTP 45 | // 46 | // A unary RPC is the simplest: the request will be a POST request and the request path 47 | // will be the base URL's path (if any) plus "/service.name/method" (where service.name and 48 | // method represent the fully-qualified proto service name and the method name for the 49 | // unary method being invoked). Request metadata are used directly as HTTP request headers. 50 | // The request payload is the binary-encoded form of the request proto, and the content-type 51 | // is "application/x-protobuf". The response includes the best match for an HTTP status code 52 | // based on the GRPC status code. But the response also includes a special response header, 53 | // "X-GRPC-Status", that encodes the actual GRPC status code and message in a "code:message" 54 | // string format. The response body is the binary-encoded form of the response proto, but 55 | // will be empty when the GRPC status code is not "OK".
If the RPC failed and the error 56 | // includes details, they are attached via one or more headers named "X-GRPC-Details". If 57 | // more than one error detail is associated with the status, there will be more than one 58 | // header, and they will be added to the response in the same order as they appear in the 59 | // server-side status. The value for the details header is a base64-encoded 60 | // google.protobuf.Any message, which contains the error detail message. If the handler 61 | // sends trailers, not just headers, they are encoded as HTTP 1.1 headers, but their names 62 | // are prefixed with "X-GRPC-Trailer-". This allows clients to recover headers and trailers 63 | // independently, as the server handler intended them. 64 | // 65 | // Streaming RPCs are a bit more complex. Since the payloads can include multiple messages, 66 | // the content type is not "application/x-protobuf". It is instead "application/x-httpgrpc-proto+v1". 67 | // The actual request and response bodies consist of a sequence of length-delimited proto 68 | // messages, each of which is binary encoded. The length delimiter is a 32-bit big-endian prefix that 69 | // indicates the size of the subsequent message. Response sequences have a special final 70 | // message that is encoded with a negative size (e.g. if the message size were 15, it would 71 | // be written as -15 on the wire in the 32-bit prefix). The type of this special final 72 | // message is always HttpTrailer, whereas the types of all other messages in the sequence 73 | // are that of the method's response proto. The HttpTrailer final message indicates the final 74 | // disposition of the stream (e.g. a GRPC status code and error details) as well as any 75 | // trailing metadata. Because the status code is not encoded until the end of the response 76 | // payload, the HTTP status code (which is the first line of the reply) will be 200 OK.
77 | // 78 | // For clients that support streaming, client and server streams both work over HTTP 1.1. 79 | // However, bidirectional streaming methods can only work if they are "half-duplex", where 80 | // the client fully sends all request messages and then the server fully sends all response 81 | // messages (e.g. the invocation timeline can have no interleaving/overlapping of request 82 | // and response messages). 83 | package httpgrpc 84 | 85 | //go:generate protoc --go_out=. --go_opt=paths=source_relative httpgrpc.proto 86 | -------------------------------------------------------------------------------- /httpgrpc/httpgrpc.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. DO NOT EDIT. 2 | // versions: 3 | // protoc-gen-go v1.26.0 4 | // protoc v4.22.0 5 | // source: httpgrpc.proto 6 | 7 | package httpgrpc 8 | 9 | import ( 10 | protoreflect "google.golang.org/protobuf/reflect/protoreflect" 11 | protoimpl "google.golang.org/protobuf/runtime/protoimpl" 12 | anypb "google.golang.org/protobuf/types/known/anypb" 13 | reflect "reflect" 14 | sync "sync" 15 | ) 16 | 17 | const ( 18 | // Verify that this generated code is sufficiently up-to-date. 19 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) 20 | // Verify that runtime/protoimpl is sufficiently up-to-date. 21 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) 22 | ) 23 | 24 | // HttpTrailer is the last message sent in a streaming GRPC-over-HTTP 25 | // call, to encode the GRPC status code and any trailer metadata. This 26 | // message is only used in GRPC responses, not in requests. 
27 | type HttpTrailer struct { 28 | state protoimpl.MessageState 29 | sizeCache protoimpl.SizeCache 30 | unknownFields protoimpl.UnknownFields 31 | 32 | Metadata map[string]*TrailerValues `protobuf:"bytes,1,rep,name=metadata,proto3" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` 33 | Code int32 `protobuf:"varint,2,opt,name=code,proto3" json:"code,omitempty"` 34 | Message string `protobuf:"bytes,3,opt,name=message,proto3" json:"message,omitempty"` 35 | Details []*anypb.Any `protobuf:"bytes,4,rep,name=details,proto3" json:"details,omitempty"` 36 | } 37 | 38 | func (x *HttpTrailer) Reset() { 39 | *x = HttpTrailer{} 40 | if protoimpl.UnsafeEnabled { 41 | mi := &file_httpgrpc_proto_msgTypes[0] 42 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 43 | ms.StoreMessageInfo(mi) 44 | } 45 | } 46 | 47 | func (x *HttpTrailer) String() string { 48 | return protoimpl.X.MessageStringOf(x) 49 | } 50 | 51 | func (*HttpTrailer) ProtoMessage() {} 52 | 53 | func (x *HttpTrailer) ProtoReflect() protoreflect.Message { 54 | mi := &file_httpgrpc_proto_msgTypes[0] 55 | if protoimpl.UnsafeEnabled && x != nil { 56 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 57 | if ms.LoadMessageInfo() == nil { 58 | ms.StoreMessageInfo(mi) 59 | } 60 | return ms 61 | } 62 | return mi.MessageOf(x) 63 | } 64 | 65 | // Deprecated: Use HttpTrailer.ProtoReflect.Descriptor instead. 
66 | func (*HttpTrailer) Descriptor() ([]byte, []int) { 67 | return file_httpgrpc_proto_rawDescGZIP(), []int{0} 68 | } 69 | 70 | func (x *HttpTrailer) GetMetadata() map[string]*TrailerValues { 71 | if x != nil { 72 | return x.Metadata 73 | } 74 | return nil 75 | } 76 | 77 | func (x *HttpTrailer) GetCode() int32 { 78 | if x != nil { 79 | return x.Code 80 | } 81 | return 0 82 | } 83 | 84 | func (x *HttpTrailer) GetMessage() string { 85 | if x != nil { 86 | return x.Message 87 | } 88 | return "" 89 | } 90 | 91 | func (x *HttpTrailer) GetDetails() []*anypb.Any { 92 | if x != nil { 93 | return x.Details 94 | } 95 | return nil 96 | } 97 | 98 | type TrailerValues struct { 99 | state protoimpl.MessageState 100 | sizeCache protoimpl.SizeCache 101 | unknownFields protoimpl.UnknownFields 102 | 103 | Values []string `protobuf:"bytes,1,rep,name=values,proto3" json:"values,omitempty"` 104 | } 105 | 106 | func (x *TrailerValues) Reset() { 107 | *x = TrailerValues{} 108 | if protoimpl.UnsafeEnabled { 109 | mi := &file_httpgrpc_proto_msgTypes[1] 110 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 111 | ms.StoreMessageInfo(mi) 112 | } 113 | } 114 | 115 | func (x *TrailerValues) String() string { 116 | return protoimpl.X.MessageStringOf(x) 117 | } 118 | 119 | func (*TrailerValues) ProtoMessage() {} 120 | 121 | func (x *TrailerValues) ProtoReflect() protoreflect.Message { 122 | mi := &file_httpgrpc_proto_msgTypes[1] 123 | if protoimpl.UnsafeEnabled && x != nil { 124 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 125 | if ms.LoadMessageInfo() == nil { 126 | ms.StoreMessageInfo(mi) 127 | } 128 | return ms 129 | } 130 | return mi.MessageOf(x) 131 | } 132 | 133 | // Deprecated: Use TrailerValues.ProtoReflect.Descriptor instead. 
134 | func (*TrailerValues) Descriptor() ([]byte, []int) { 135 | return file_httpgrpc_proto_rawDescGZIP(), []int{1} 136 | } 137 | 138 | func (x *TrailerValues) GetValues() []string { 139 | if x != nil { 140 | return x.Values 141 | } 142 | return nil 143 | } 144 | 145 | var File_httpgrpc_proto protoreflect.FileDescriptor 146 | 147 | var file_httpgrpc_proto_rawDesc = []byte{ 148 | 0x0a, 0x0e, 0x68, 0x74, 0x74, 0x70, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 149 | 0x12, 0x1d, 0x66, 0x75, 0x6c, 0x6c, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x64, 0x65, 0x76, 0x2e, 0x67, 150 | 0x72, 0x70, 0x63, 0x68, 0x61, 0x6e, 0x2e, 0x68, 0x74, 0x74, 0x70, 0x67, 0x72, 0x70, 0x63, 0x1a, 151 | 0x19, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 152 | 0x2f, 0x61, 0x6e, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xac, 0x02, 0x0a, 0x0b, 0x48, 153 | 0x74, 0x74, 0x70, 0x54, 0x72, 0x61, 0x69, 0x6c, 0x65, 0x72, 0x12, 0x54, 0x0a, 0x08, 0x6d, 0x65, 154 | 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x38, 0x2e, 0x66, 155 | 0x75, 0x6c, 0x6c, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x64, 0x65, 0x76, 0x2e, 0x67, 0x72, 0x70, 0x63, 156 | 0x68, 0x61, 0x6e, 0x2e, 0x68, 0x74, 0x74, 0x70, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x48, 0x74, 0x74, 157 | 0x70, 0x54, 0x72, 0x61, 0x69, 0x6c, 0x65, 0x72, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 158 | 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 159 | 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 160 | 0x63, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 161 | 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x2e, 162 | 0x0a, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 163 | 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 
0x62, 0x75, 164 | 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x1a, 0x69, 165 | 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 166 | 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 167 | 0x79, 0x12, 0x42, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 168 | 0x32, 0x2c, 0x2e, 0x66, 0x75, 0x6c, 0x6c, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x64, 0x65, 0x76, 0x2e, 169 | 0x67, 0x72, 0x70, 0x63, 0x68, 0x61, 0x6e, 0x2e, 0x68, 0x74, 0x74, 0x70, 0x67, 0x72, 0x70, 0x63, 170 | 0x2e, 0x54, 0x72, 0x61, 0x69, 0x6c, 0x65, 0x72, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x52, 0x05, 171 | 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x27, 0x0a, 0x0d, 0x54, 0x72, 0x61, 172 | 0x69, 0x6c, 0x65, 0x72, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x76, 0x61, 173 | 0x6c, 0x75, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x76, 0x61, 0x6c, 0x75, 174 | 0x65, 0x73, 0x42, 0x0d, 0x5a, 0x0b, 0x2e, 0x2f, 0x3b, 0x68, 0x74, 0x74, 0x70, 0x67, 0x72, 0x70, 175 | 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 176 | } 177 | 178 | var ( 179 | file_httpgrpc_proto_rawDescOnce sync.Once 180 | file_httpgrpc_proto_rawDescData = file_httpgrpc_proto_rawDesc 181 | ) 182 | 183 | func file_httpgrpc_proto_rawDescGZIP() []byte { 184 | file_httpgrpc_proto_rawDescOnce.Do(func() { 185 | file_httpgrpc_proto_rawDescData = protoimpl.X.CompressGZIP(file_httpgrpc_proto_rawDescData) 186 | }) 187 | return file_httpgrpc_proto_rawDescData 188 | } 189 | 190 | var file_httpgrpc_proto_msgTypes = make([]protoimpl.MessageInfo, 3) 191 | var file_httpgrpc_proto_goTypes = []interface{}{ 192 | (*HttpTrailer)(nil), // 0: fullstorydev.grpchan.httpgrpc.HttpTrailer 193 | (*TrailerValues)(nil), // 1: fullstorydev.grpchan.httpgrpc.TrailerValues 194 | nil, // 2: fullstorydev.grpchan.httpgrpc.HttpTrailer.MetadataEntry 195 | 
(*anypb.Any)(nil), // 3: google.protobuf.Any 196 | } 197 | var file_httpgrpc_proto_depIdxs = []int32{ 198 | 2, // 0: fullstorydev.grpchan.httpgrpc.HttpTrailer.metadata:type_name -> fullstorydev.grpchan.httpgrpc.HttpTrailer.MetadataEntry 199 | 3, // 1: fullstorydev.grpchan.httpgrpc.HttpTrailer.details:type_name -> google.protobuf.Any 200 | 1, // 2: fullstorydev.grpchan.httpgrpc.HttpTrailer.MetadataEntry.value:type_name -> fullstorydev.grpchan.httpgrpc.TrailerValues 201 | 3, // [3:3] is the sub-list for method output_type 202 | 3, // [3:3] is the sub-list for method input_type 203 | 3, // [3:3] is the sub-list for extension type_name 204 | 3, // [3:3] is the sub-list for extension extendee 205 | 0, // [0:3] is the sub-list for field type_name 206 | } 207 | 208 | func init() { file_httpgrpc_proto_init() } 209 | func file_httpgrpc_proto_init() { 210 | if File_httpgrpc_proto != nil { 211 | return 212 | } 213 | if !protoimpl.UnsafeEnabled { 214 | file_httpgrpc_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { 215 | switch v := v.(*HttpTrailer); i { 216 | case 0: 217 | return &v.state 218 | case 1: 219 | return &v.sizeCache 220 | case 2: 221 | return &v.unknownFields 222 | default: 223 | return nil 224 | } 225 | } 226 | file_httpgrpc_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { 227 | switch v := v.(*TrailerValues); i { 228 | case 0: 229 | return &v.state 230 | case 1: 231 | return &v.sizeCache 232 | case 2: 233 | return &v.unknownFields 234 | default: 235 | return nil 236 | } 237 | } 238 | } 239 | type x struct{} 240 | out := protoimpl.TypeBuilder{ 241 | File: protoimpl.DescBuilder{ 242 | GoPackagePath: reflect.TypeOf(x{}).PkgPath(), 243 | RawDescriptor: file_httpgrpc_proto_rawDesc, 244 | NumEnums: 0, 245 | NumMessages: 3, 246 | NumExtensions: 0, 247 | NumServices: 0, 248 | }, 249 | GoTypes: file_httpgrpc_proto_goTypes, 250 | DependencyIndexes: file_httpgrpc_proto_depIdxs, 251 | MessageInfos: file_httpgrpc_proto_msgTypes, 252 | 
}.Build() 253 | File_httpgrpc_proto = out.File 254 | file_httpgrpc_proto_rawDesc = nil 255 | file_httpgrpc_proto_goTypes = nil 256 | file_httpgrpc_proto_depIdxs = nil 257 | } 258 | -------------------------------------------------------------------------------- /httpgrpc/httpgrpc.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package fullstorydev.grpchan.httpgrpc; 4 | 5 | import "google/protobuf/any.proto"; 6 | 7 | option go_package = "./;httpgrpc"; 8 | 9 | // HttpTrailer is the last message sent in a streaming GRPC-over-HTTP 10 | // call, to encode the GRPC status code and any trailer metadata. This 11 | // message is only used in GRPC responses, not in requests. 12 | message HttpTrailer { 13 | map metadata = 1; 14 | int32 code = 2; 15 | string message = 3; 16 | repeated google.protobuf.Any details = 4; 17 | } 18 | 19 | message TrailerValues { 20 | repeated string values = 1; 21 | } 22 | -------------------------------------------------------------------------------- /httpgrpc/httpgrpc_test.go: -------------------------------------------------------------------------------- 1 | package httpgrpc_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "net" 8 | "net/http" 9 | "net/url" 10 | "testing" 11 | 12 | "github.com/fullstorydev/grpchan" 13 | "github.com/fullstorydev/grpchan/grpchantesting" 14 | "github.com/fullstorydev/grpchan/httpgrpc" 15 | "google.golang.org/grpc/status" 16 | ) 17 | 18 | func TestGrpcOverHttp(t *testing.T) { 19 | svr := &grpchantesting.TestServer{} 20 | reg := grpchan.HandlerMap{} 21 | grpchantesting.RegisterTestServiceServer(reg, svr) 22 | 23 | var mux http.ServeMux 24 | httpgrpc.HandleServices(mux.HandleFunc, "/", reg, nil, nil) 25 | 26 | l, err := net.Listen("tcp", "127.0.0.1:0") 27 | if err != nil { 28 | t.Fatalf("failed it listen on socket: %v", err) 29 | } 30 | httpServer := http.Server{Handler: &mux} 31 | go httpServer.Serve(l) 32 | defer httpServer.Close() 33 
| 34 | // now setup client stub 35 | u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", l.Addr().(*net.TCPAddr).Port)) 36 | if err != nil { 37 | t.Fatalf("failed to parse base URL: %v", err) 38 | } 39 | cc := httpgrpc.Channel{ 40 | Transport: http.DefaultTransport, 41 | BaseURL: u, 42 | } 43 | 44 | grpchantesting.RunChannelTestCases(t, &cc, false) 45 | 46 | t.Run("empty-trailer", func(t *testing.T) { 47 | // test RPC w/ streaming response where trailer message is empty 48 | // (e.g. no trailer metadata and code == 0 [OK]) 49 | cli := grpchantesting.NewTestServiceClient(&cc) 50 | str, err := cli.ServerStream(context.Background(), &grpchantesting.Message{}) 51 | if err != nil { 52 | t.Fatalf("failed to initiate server stream: %v", err) 53 | } 54 | // if there is an issue with trailer message, it will appear to be 55 | // a regular message and err would be nil 56 | _, err = str.Recv() 57 | if err != io.EOF { 58 | t.Fatalf("server stream should not have returned any messages") 59 | } 60 | }) 61 | } 62 | 63 | // This test is nearly identical to TestGrpcOverHttp, except that it uses 64 | // *httpgrpc.Server instead of httpgrpc.HandleServices. 
65 | func TestServer(t *testing.T) { 66 | errFunc := func(reqCtx context.Context, st *status.Status, response http.ResponseWriter) { 67 | 68 | } 69 | 70 | svc := &grpchantesting.TestServer{} 71 | svr := httpgrpc.NewServer(httpgrpc.WithBasePath("/foo/"), httpgrpc.ErrorRenderer(errFunc)) 72 | grpchantesting.RegisterTestServiceServer(svr, svc) 73 | 74 | l, err := net.Listen("tcp", "127.0.0.1:0") 75 | if err != nil { 76 | t.Fatalf("failed it listen on socket: %v", err) 77 | } 78 | httpServer := http.Server{Handler: svr} 79 | go httpServer.Serve(l) 80 | defer httpServer.Close() 81 | 82 | // now setup client stub 83 | u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d/foo/", l.Addr().(*net.TCPAddr).Port)) 84 | if err != nil { 85 | t.Fatalf("failed to parse base URL: %v", err) 86 | } 87 | cc := httpgrpc.Channel{ 88 | Transport: http.DefaultTransport, 89 | BaseURL: u, 90 | } 91 | 92 | grpchantesting.RunChannelTestCases(t, &cc, false) 93 | 94 | t.Run("empty-trailer", func(t *testing.T) { 95 | // test RPC w/ streaming response where trailer message is empty 96 | // (e.g. 
no trailer metadata and code == 0 [OK]) 97 | cli := grpchantesting.NewTestServiceClient(&cc) 98 | str, err := cli.ServerStream(context.Background(), &grpchantesting.Message{}) 99 | if err != nil { 100 | t.Fatalf("failed to initiate server stream: %v", err) 101 | } 102 | // if there is an issue with trailer message, it will appear to be 103 | // a regular message and err would be nil 104 | _, err = str.Recv() 105 | if err != io.EOF { 106 | t.Fatalf("server stream should not have returned any messages") 107 | } 108 | }) 109 | } 110 | -------------------------------------------------------------------------------- /httpgrpc/io.go: -------------------------------------------------------------------------------- 1 | package httpgrpc 2 | 3 | import ( 4 | "encoding/base64" 5 | "encoding/binary" 6 | "fmt" 7 | "google.golang.org/grpc/mem" 8 | "io" 9 | "math" 10 | "net/http" 11 | "strings" 12 | 13 | "google.golang.org/grpc/encoding" 14 | "google.golang.org/grpc/metadata" 15 | ) 16 | 17 | const ( 18 | maxMessageSize = 100 * 1024 * 1024 // 100mb 19 | ) 20 | 21 | // writeSizePreface writes the given 32-bit size to the given writer. 22 | func writeSizePreface(w io.Writer, sz int32) error { 23 | return binary.Write(w, binary.BigEndian, sz) 24 | } 25 | 26 | // writeProtoMessage writes a length-delimited proto message to the given 27 | // writer. This writes the size preface, indicating the size of the encoded 28 | // message, followed by the actual message contents. If end is true, the 29 | // size is written as a negative value, indicating to the receiver that this 30 | // is the last message in the stream. (The last message should be an instance 31 | // of HttpTrailer.) 
32 | func writeProtoMessage(w io.Writer, codec encoding.CodecV2, m interface{}, end bool) error { 33 | buf, err := codec.Marshal(m) 34 | if err != nil { 35 | return err 36 | } 37 | b := buf.Materialize() 38 | 39 | sz := len(b) 40 | if sz > math.MaxInt32 { 41 | return fmt.Errorf("message too large to send: %d bytes", sz) 42 | } 43 | if end { 44 | // trailer message is indicated w/ negative size 45 | sz = -sz 46 | } 47 | err = writeSizePreface(w, int32(sz)) 48 | if err != nil { 49 | return err 50 | } 51 | 52 | _, err = w.Write(b) 53 | if err == nil { 54 | if f, ok := w.(http.Flusher); ok { 55 | f.Flush() 56 | } 57 | } 58 | return err 59 | } 60 | 61 | // readSizePreface reads a 32-bit size from the given reader. If the value is 62 | // negative, it indicates the last message in the stream. Messages can have zero 63 | // size, but the last message in the stream should never have zero size (so its 64 | // size will be negative). 65 | func readSizePreface(in io.Reader) (int32, error) { 66 | var sz int32 67 | err := binary.Read(in, binary.BigEndian, &sz) 68 | return sz, err 69 | } 70 | 71 | // readProtoMessage reads data from the given reader and decodes it into the given 72 | // message. The sz parameter indicates the number of bytes that must be read to 73 | // decode the proto. This does not first call readSizePreface; callers must do that 74 | // first. 75 | func readProtoMessage(in io.Reader, codec encoding.CodecV2, sz int32, m interface{}) error { 76 | if sz < 0 { 77 | return fmt.Errorf("bad size preface: size cannot be negative: %d", sz) 78 | } else if sz > maxMessageSize { 79 | return fmt.Errorf("bad size preface: indicated size is too large: %d", sz) 80 | } 81 | msg := make([]byte, sz) 82 | _, err := io.ReadAtLeast(in, msg, int(sz)) 83 | if err != nil { 84 | return err 85 | } 86 | return codec.Unmarshal(mem.BufferSlice{mem.SliceBuffer(msg)}, m) 87 | } 88 | 89 | // asMetadata converts the given HTTP headers into GRPC metadata. 
90 | func asMetadata(header http.Header) (metadata.MD, error) { 91 | // metadata has same shape as http.Header, 92 | md := metadata.MD{} 93 | for k, vs := range header { 94 | k = strings.ToLower(k) 95 | for _, v := range vs { 96 | if strings.HasSuffix(k, "-bin") { 97 | vv, err := base64.URLEncoding.DecodeString(v) 98 | if err != nil { 99 | return nil, err 100 | } 101 | v = string(vv) 102 | } 103 | md[k] = append(md[k], v) 104 | } 105 | } 106 | return md, nil 107 | } 108 | 109 | var reservedHeaders = map[string]struct{}{ 110 | "accept-encoding": {}, 111 | "connection": {}, 112 | "content-type": {}, 113 | "content-length": {}, 114 | "keep-alive": {}, 115 | "te": {}, 116 | "trailer": {}, 117 | "transfer-encoding": {}, 118 | "upgrade": {}, 119 | } 120 | 121 | func toHeaders(md metadata.MD, h http.Header, prefix string) { 122 | // binary headers must be base-64-encoded 123 | for k, vs := range md { 124 | lowerK := strings.ToLower(k) 125 | if _, ok := reservedHeaders[lowerK]; ok { 126 | // ignore reserved header keys 127 | continue 128 | } 129 | isBin := strings.HasSuffix(lowerK, "-bin") 130 | for _, v := range vs { 131 | if isBin { 132 | v = base64.URLEncoding.EncodeToString([]byte(v)) 133 | } 134 | h.Add(prefix+k, v) 135 | } 136 | } 137 | } 138 | 139 | type strAddr string 140 | 141 | func (a strAddr) Network() string { 142 | if a != "" { 143 | // Per the documentation on net/http.Request.RemoteAddr, if this is 144 | // set, it's set to the IP:port of the peer (hence, TCP): 145 | // https://golang.org/pkg/net/http/#Request 146 | // 147 | // If we want to support Unix sockets later, we can 148 | // add our own grpc-specific convention within the 149 | // grpc codebase to set RemoteAddr to a different 150 | // format, or probably better: we can attach it to the 151 | // context and use that from serverHandlerTransport.RemoteAddr. 
152 | return "tcp" 153 | } 154 | return "" 155 | } 156 | 157 | func (a strAddr) String() string { return string(a) } 158 | -------------------------------------------------------------------------------- /httpgrpc/json.go: -------------------------------------------------------------------------------- 1 | package httpgrpc 2 | 3 | import ( 4 | //lint:ignore SA1019 we use the old v1 package because 5 | // we need to support older generated messages 6 | "github.com/golang/protobuf/proto" 7 | "google.golang.org/grpc/encoding" 8 | "google.golang.org/grpc/mem" 9 | "google.golang.org/protobuf/encoding/protojson" 10 | ) 11 | 12 | var ( 13 | grpcJsonMarshaler = protojson.MarshalOptions{ 14 | UseEnumNumbers: true, 15 | EmitUnpopulated: true, 16 | } 17 | 18 | grpcJsonUnmarshaler = protojson.UnmarshalOptions{ 19 | DiscardUnknown: true, 20 | } 21 | ) 22 | 23 | func init() { 24 | encoding.RegisterCodecV2(jsonCodec{}) 25 | } 26 | 27 | type jsonCodec struct{} 28 | 29 | func (c jsonCodec) Marshal(v interface{}) (mem.BufferSlice, error) { 30 | msg := proto.MessageV2(v.(proto.Message)) 31 | bb, err := grpcJsonMarshaler.Marshal(msg) 32 | return mem.BufferSlice{mem.SliceBuffer(bb)}, err 33 | } 34 | 35 | func (c jsonCodec) Unmarshal(data mem.BufferSlice, v interface{}) error { 36 | msg := proto.MessageV2(v.(proto.Message)) 37 | return grpcJsonUnmarshaler.Unmarshal(data.Materialize(), msg) 38 | } 39 | 40 | func (c jsonCodec) Name() string { 41 | return "json" 42 | } 43 | -------------------------------------------------------------------------------- /httpgrpc/protocol_versions.go: -------------------------------------------------------------------------------- 1 | package httpgrpc 2 | 3 | import ( 4 | "mime" 5 | 6 | "google.golang.org/grpc/encoding" 7 | grpcproto "google.golang.org/grpc/encoding/proto" 8 | ) 9 | 10 | // If the on-the-wire encoding every needs to be changed in a backwards-incompatible way, 11 | // here are the steps for doing so: 12 | // 13 | // 1. 
Define new content-types that represent the new encoding. If the current encoding is 14 | // "v1" then increment (e.g. "v2"). (No semver here, just a single integer version...) 15 | // 2. Update server code to switch on incoming content-type. It must continue to support 16 | // the previous version of the protocol if it sees the corresponding content-types. 17 | // NOTE: Servers should only support two versions at a time; let's call them Version-Now 18 | // and Version-Next. Version-Next should be fully deployed (in all clients and servers) 19 | // before any third version is conceived. That way, Version-Now support can be removed. 20 | // The code will be simpler if we only support 2 versions, instead of up to N. 21 | // 3. Create a new implementation of github.com/fullstorydev/grpchan.Channel named channelNext 22 | // that implements Version-Next for clients. IMPORTANT: note that the recommended name 23 | // is NOT exported. It should not yet be usable and should only be used for tests inside 24 | // this package. 25 | // 4. Update tests so that they perform the same cases with BOTH client versions: e.g. with 26 | // both Channel and channelNext instances. This confirms that servers correctly continue 27 | // to support Version-Now and also ensures that Version-Next is functional end-to-end. 28 | // 5. Update the package documentation to describe the Version-Next protocol anatomy. 29 | // 6. Deploy servers!!! 30 | // 7. Only after all servers support Version-Next is it safe to export the Version-Next 31 | // channel implementation and use it outside of testing. At this time, you can safely 32 | // change channelNext to be named Channel and then remove the old code for Version-Now. 33 | // 34 | // These are the content-types used for "version 1" (hopefully the only version ever?)
35 | // of the gRPC-over-HTTP transport 36 | const ( 37 | UnaryRpcContentType_V1 = "application/x-protobuf" 38 | StreamRpcContentType_V1 = "application/x-httpgrpc-proto+v1" 39 | ) 40 | 41 | const ( 42 | // Non-standard and experimental; uses the `jsonpb.Marshaler` by default. 43 | // Only unary calls are supported; streams with JSON encoding are not supported. 44 | // Use `encoding.RegisterCodecV2` to override the default encoder with a custom encoder. 45 | ApplicationJson = "application/json" 46 | ) 47 | 48 | func getUnaryCodec(contentType string) encoding.CodecV2 { 49 | // Ignore any errors or charsets for now, just parse the main type. 50 | // TODO: should this be more picky / return an error? Maybe charset utf8 only? 51 | mediaType, _, _ := mime.ParseMediaType(contentType) 52 | 53 | if mediaType == UnaryRpcContentType_V1 { 54 | return encoding.GetCodecV2(grpcproto.Name) 55 | } 56 | 57 | if mediaType == ApplicationJson { 58 | return encoding.GetCodecV2("json") 59 | } 60 | 61 | return nil 62 | } 63 | 64 | func getStreamingCodec(contentType string) encoding.CodecV2 { 65 | // Ignore any errors or charsets for now, just parse the main type. 66 | // TODO: should this be more picky / return an error? Maybe charset utf8 only? 67 | mediaType, _, _ := mime.ParseMediaType(contentType) 68 | 69 | if mediaType == StreamRpcContentType_V1 { 70 | return encoding.GetCodecV2(grpcproto.Name) 71 | } 72 | 73 | if mediaType == ApplicationJson { 74 | // TODO: support half-duplix JSON streaming? 
75 | // https://en.wikipedia.org/wiki/JSON_streaming#Record_separator-delimited_JSON 76 | return nil 77 | } 78 | 79 | return nil 80 | } 81 | -------------------------------------------------------------------------------- /inprocgrpc/cloner.go: -------------------------------------------------------------------------------- 1 | package inprocgrpc 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | 7 | //lint:ignore SA1019 we use the old v1 package because 8 | // we need to support older generated messages 9 | "github.com/golang/protobuf/proto" 10 | "google.golang.org/grpc/encoding" 11 | grpcproto "google.golang.org/grpc/encoding/proto" 12 | 13 | "github.com/fullstorydev/grpchan/internal" 14 | ) 15 | 16 | // Cloner knows how to make copies of messages. It can be asked to copy one 17 | // value into another, and it can also be asked to simply synthesize a new 18 | // value that is a copy of some input value. 19 | // 20 | // This is used to copy messages between in-process client and server. Copying 21 | // will usually be more efficient than marshalling to bytes and back (though 22 | // that is a valid strategy that a custom Cloner implementation could take). 23 | // Copies are made to avoid sharing values across client and server goroutines. 24 | type Cloner interface { 25 | Copy(out, in interface{}) error 26 | Clone(interface{}) (interface{}, error) 27 | } 28 | 29 | // ProtoCloner is the default cloner used by an in-process channel. This 30 | // implementation can correctly handle protobuf messages. Copy and clone 31 | // operations will fail if the input message is not a protobuf message (in 32 | // which case a custom cloner must be used). 
33 | type ProtoCloner struct{} 34 | 35 | var _ Cloner = ProtoCloner{} 36 | 37 | func (ProtoCloner) Copy(out, in interface{}) error { 38 | _, outIsProto := out.(proto.Message) 39 | _, inIsProto := in.(proto.Message) 40 | if inIsProto && outIsProto { 41 | return internal.CopyMessage(out, in) 42 | } 43 | // maybe the user has registered a gRPC codec that can 44 | // handle this thing 45 | if codec := encoding.GetCodecV2(grpcproto.Name); codec != nil { 46 | return CodecClonerV2(codec).Copy(out, in) 47 | } 48 | if codec := encoding.GetCodec(grpcproto.Name); codec != nil { 49 | return CodecCloner(codec).Copy(out, in) 50 | } 51 | panic("no codec found") 52 | } 53 | 54 | func (ProtoCloner) Clone(in interface{}) (interface{}, error) { 55 | if _, isProto := in.(proto.Message); isProto { 56 | return internal.CloneMessage(in) 57 | } 58 | // maybe the user has registered a gRPC codec that can 59 | // handle this thing 60 | if codec := encoding.GetCodecV2(grpcproto.Name); codec != nil { 61 | return CodecClonerV2(codec).Clone(in) 62 | } 63 | if codec := encoding.GetCodec(grpcproto.Name); codec != nil { 64 | return CodecCloner(codec).Clone(in) 65 | } 66 | panic("no codec found") 67 | } 68 | 69 | // CloneFunc adapts a single clone function to the Cloner interface. The given 70 | // function implements the Clone method. To implement the Copy method, the given 71 | // function is invoked and then reflection is used to shallow copy the clone to 72 | // the output. 
73 | func CloneFunc(fn func(interface{}) (interface{}, error)) Cloner { 74 | copyFn := func(out, in interface{}) error { 75 | in, err := fn(in) // deep copy input 76 | if err != nil { 77 | return err 78 | } 79 | 80 | // then shallow-copy into out via reflection 81 | src := reflect.Indirect(reflect.ValueOf(in)) 82 | dest := reflect.Indirect(reflect.ValueOf(out)) 83 | if src.Type() != dest.Type() { 84 | return fmt.Errorf("incompatible types: %v != %v", src.Type(), dest.Type()) 85 | } 86 | if !dest.CanSet() { 87 | return fmt.Errorf("unable to set destination: %v", reflect.ValueOf(out).Type()) 88 | } 89 | dest.Set(src) 90 | return nil 91 | 92 | } 93 | return &funcCloner{clone: fn, copy: copyFn} 94 | } 95 | 96 | // CopyFunc adapts a single copy function to the Cloner interface. The given 97 | // function implements the Copy method. To implement the Clone method, a new 98 | // value of the same type is created using reflection and then the given 99 | // function is used to copy the input to the newly created value. 100 | func CopyFunc(fn func(out, in interface{}) error) Cloner { 101 | cloneFn := func(in interface{}) (interface{}, error) { 102 | clone := reflect.New(reflect.TypeOf(in).Elem()).Interface() 103 | if err := fn(clone, in); err != nil { 104 | return nil, err 105 | } 106 | return clone, nil 107 | } 108 | return &funcCloner{clone: cloneFn, copy: fn} 109 | } 110 | 111 | // CodecCloner uses the given codec to implement the Cloner interface. The Copy 112 | // method is implemented by using the code to marshal the input to bytes and 113 | // then unmarshal from bytes into the output value. The Clone method then uses 114 | // reflection to create a new value of the same type and uses this strategy to 115 | // then copy the input to the newly created value. 
116 | func CodecCloner(codec encoding.Codec) Cloner { 117 | return CopyFunc(func(out, in interface{}) error { 118 | if b, err := codec.Marshal(in); err != nil { 119 | return err 120 | } else if err := codec.Unmarshal(b, out); err != nil { 121 | return err 122 | } 123 | return nil 124 | }) 125 | } 126 | 127 | // CodecClonerV2 uses the given codec to implement the Cloner interface. The Copy 128 | // method is implemented by using the code to marshal the input to bytes and 129 | // then unmarshal from bytes into the output value. The Clone method then uses 130 | // reflection to create a new value of the same type and uses this strategy to 131 | // then copy the input to the newly created value. 132 | func CodecClonerV2(codec encoding.CodecV2) Cloner { 133 | return CopyFunc(func(out, in interface{}) error { 134 | if b, err := codec.Marshal(in); err != nil { 135 | return err 136 | } else if err := codec.Unmarshal(b, out); err != nil { 137 | return err 138 | } 139 | return nil 140 | }) 141 | } 142 | 143 | type funcCloner struct { 144 | clone func(interface{}) (interface{}, error) 145 | copy func(in, out interface{}) error 146 | } 147 | 148 | var _ Cloner = (*funcCloner)(nil) 149 | 150 | func (c *funcCloner) Copy(out, in interface{}) error { 151 | return c.copy(out, in) 152 | } 153 | 154 | func (c *funcCloner) Clone(in interface{}) (interface{}, error) { 155 | return c.clone(in) 156 | } 157 | -------------------------------------------------------------------------------- /inprocgrpc/cloner_test.go: -------------------------------------------------------------------------------- 1 | package inprocgrpc 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | 8 | "google.golang.org/protobuf/encoding/protojson" 9 | "google.golang.org/protobuf/proto" 10 | 11 | "github.com/fullstorydev/grpchan/httpgrpc" 12 | ) 13 | 14 | var ( 15 | source *httpgrpc.HttpTrailer 16 | sourceJs string // snapshot of source as JSON 17 | 18 | jsm = &protojson.MarshalOptions{} 19 | ) 20 | 21 | func 
init() { 22 | source = &httpgrpc.HttpTrailer{ 23 | Code: 123, 24 | Message: "foobar", 25 | Metadata: map[string]*httpgrpc.TrailerValues{ 26 | "abc": {Values: []string{"a", "b", "c"}}, 27 | "def": {Values: []string{"foo", "bar", "baz"}}, 28 | "ghi": {Values: []string{"xyz", "123"}}, 29 | }, 30 | } 31 | sourceJsBytes, err := jsm.Marshal(source) 32 | if err != nil { 33 | panic(err) 34 | } 35 | sourceJs = string(sourceJsBytes) 36 | } 37 | 38 | func TestProtoCloner(t *testing.T) { 39 | testCloner(t, ProtoCloner{}) 40 | } 41 | 42 | type protoCodec struct{} 43 | 44 | func (protoCodec) Marshal(v interface{}) ([]byte, error) { 45 | return proto.Marshal(v.(proto.Message)) 46 | } 47 | 48 | func (protoCodec) Unmarshal(data []byte, v interface{}) error { 49 | return proto.Unmarshal(data, v.(proto.Message)) 50 | } 51 | 52 | func (protoCodec) Name() string { 53 | return "proto" 54 | } 55 | 56 | func TestCodecCloner(t *testing.T) { 57 | testCloner(t, CodecCloner(protoCodec{})) 58 | } 59 | 60 | func TestCloneFunc(t *testing.T) { 61 | testCloner(t, CloneFunc(func(in interface{}) (interface{}, error) { 62 | return proto.Clone(in.(proto.Message)), nil 63 | })) 64 | } 65 | 66 | func TestCopyFunc(t *testing.T) { 67 | testCloner(t, CopyFunc(func(out, in interface{}) error { 68 | if reflect.TypeOf(in) != reflect.TypeOf(out) { 69 | return fmt.Errorf("type mismatch: %T != %T", in, out) 70 | } 71 | if reflect.ValueOf(out).IsNil() { 72 | return fmt.Errorf("out must not be nil") 73 | } 74 | inM := in.(proto.Message) 75 | outM := out.(proto.Message) 76 | proto.Reset(outM) 77 | proto.Merge(outM, inM) 78 | return nil 79 | })) 80 | } 81 | 82 | func testCloner(t *testing.T, cloner Cloner) { 83 | dest := &httpgrpc.HttpTrailer{} 84 | err := cloner.Copy(dest, source) 85 | if err != nil { 86 | t.Fatalf("Copy returned unexpected error: %v", err) 87 | } 88 | if !proto.Equal(source, dest) { 89 | t.Fatalf("Copy failed to produce a value equal to input") 90 | } 91 | checkIndependence(t, dest) 92 | 93 | 
clone, err := cloner.Clone(source) 94 | if err != nil { 95 | t.Fatalf("Clone returned unexpected error: %v", err) 96 | } 97 | if !proto.Equal(source, clone.(proto.Message)) { 98 | t.Fatalf("Clone failed to produce a value equal to input") 99 | } 100 | checkIndependence(t, clone.(*httpgrpc.HttpTrailer)) 101 | } 102 | 103 | func checkIndependence(t *testing.T, dest *httpgrpc.HttpTrailer) { 104 | // mutate copy and make sure we don't see it in original 105 | // (e.g. verifies the copy is a deep copy) 106 | dest.Message += "baz" 107 | dest.Metadata["ghi"].Values = append(dest.Metadata["ghi"].Values, "456") 108 | dest.Metadata["jkl"] = &httpgrpc.TrailerValues{Values: []string{"zomg!"}} 109 | 110 | sourceJs2, err := jsm.Marshal(source) 111 | if err != nil { 112 | t.Fatalf("Failed to marsal message to JSON: %v", err) 113 | } 114 | if string(sourceJs2) != sourceJs { 115 | t.Errorf("source changed after mutating dest!\nExpecting:\n%s\nGot:\n%s\n", sourceJs, string(sourceJs2)) 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /inprocgrpc/in_process_test.go: -------------------------------------------------------------------------------- 1 | package inprocgrpc_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "reflect" 7 | "runtime" 8 | "testing" 9 | "time" 10 | 11 | "github.com/fullstorydev/grpchan/grpchantesting" 12 | "github.com/fullstorydev/grpchan/inprocgrpc" 13 | "github.com/jhump/protoreflect/desc" 14 | "github.com/jhump/protoreflect/dynamic" 15 | "github.com/jhump/protoreflect/dynamic/grpcdynamic" 16 | "google.golang.org/grpc" 17 | "google.golang.org/grpc/metadata" 18 | ) 19 | 20 | func TestInProcessChannel(t *testing.T) { 21 | svr := &grpchantesting.TestServer{} 22 | 23 | var cc inprocgrpc.Channel 24 | grpchantesting.RegisterTestServiceServer(&cc, svr) 25 | 26 | before := runtime.NumGoroutine() 27 | 28 | grpchantesting.RunChannelTestCases(t, &cc, true) 29 | 30 | // check for goroutine leaks 31 | deadline := 
time.Now().Add(time.Second * 5) 32 | after := 0 33 | for deadline.After(time.Now()) { 34 | after = runtime.NumGoroutine() 35 | if after <= before { 36 | // number of goroutines returned to previous level: no leak! 37 | return 38 | } 39 | time.Sleep(time.Millisecond * 50) 40 | } 41 | t.Errorf("%d goroutines leaked", after-before) 42 | } 43 | 44 | func TestUseDynamicMessage(t *testing.T) { 45 | // This uses dynamic messages for request and response and 46 | // ensures the in-process channel works correctly that way. 47 | 48 | svr := &grpchantesting.TestServer{} 49 | 50 | var cc inprocgrpc.Channel 51 | grpchantesting.RegisterTestServiceServer(&cc, svr) 52 | stub := grpcdynamic.NewStub(&cc) 53 | 54 | fd, err := desc.LoadFileDescriptor("test.proto") 55 | if err != nil { 56 | t.Fatalf("failed to load descriptor for test.proto: %v", err) 57 | } 58 | md := fd.FindMessage("grpchantesting.Message") 59 | if md == nil { 60 | t.Fatalf("could not find descriptor for grpchantesting.Message") 61 | } 62 | sd := fd.FindService("grpchantesting.TestService") 63 | if sd == nil { 64 | t.Fatalf("could not find descriptor for grpchantesting.TestService") 65 | } 66 | mtd := sd.FindMethodByName("Unary") 67 | if mtd == nil { 68 | t.Fatalf("could not find descriptor for grpchantesting.TestService/Unary") 69 | } 70 | 71 | testPayload := []byte{100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 0} 72 | testOutgoingMd := map[string][]byte{ 73 | "foo": []byte("bar"), 74 | } 75 | testMdHeaders := map[string][]byte{ 76 | "foo1": []byte("bar2"), 77 | } 78 | testMdTrailers := map[string][]byte{ 79 | "foo3": []byte("bar4"), 80 | } 81 | 82 | ctx := metadata.NewOutgoingContext(context.Background(), grpchantesting.MetadataNew(testOutgoingMd)) 83 | req := dynamic.NewMessage(md) 84 | req.SetFieldByName("payload", testPayload) 85 | req.SetFieldByName("headers", testMdHeaders) 86 | req.SetFieldByName("trailers", testMdTrailers) 87 | 88 | var hdr, tlr metadata.MD 89 | rsp, err := stub.InvokeRpc(ctx, mtd, req, 
grpc.Header(&hdr), grpc.Trailer(&tlr)) 90 | if err != nil { 91 | t.Fatalf("RPC failed: %v", err) 92 | } 93 | msg := rsp.(*dynamic.Message) 94 | 95 | payload := msg.GetFieldByName("payload") 96 | if !bytes.Equal(testPayload, payload.([]byte)) { 97 | t.Fatalf("wrong payload returned: expecting %v; got %v", testPayload, payload) 98 | } 99 | reqHeaders := map[string][]byte{} 100 | for k, v := range msg.GetFieldByName("headers").(map[interface{}]interface{}) { 101 | reqHeaders[k.(string)] = v.([]byte) 102 | } 103 | if !reflect.DeepEqual(testOutgoingMd, reqHeaders) { 104 | t.Fatalf("wrong request headers echoed back: expecting %v; got %v", testOutgoingMd, reqHeaders) 105 | } 106 | 107 | actualHdrs := map[string][]byte{} 108 | for k, v := range hdr { 109 | if len(v) > 1 { 110 | t.Fatalf("too many values for response header %q", k) 111 | } 112 | actualHdrs[k] = []byte(v[0]) 113 | } 114 | if !reflect.DeepEqual(testMdHeaders, actualHdrs) { 115 | t.Fatalf("wrong response headers echoed back: expecting %v; got %v", testMdHeaders, actualHdrs) 116 | } 117 | 118 | actualTlrs := map[string][]byte{} 119 | for k, v := range tlr { 120 | if len(v) > 1 { 121 | t.Fatalf("too many values for response trailer %q", k) 122 | } 123 | actualTlrs[k] = []byte(v[0]) 124 | } 125 | if !reflect.DeepEqual(testMdTrailers, actualTlrs) { 126 | t.Fatalf("wrong response trailers echoed back: expecting %v; got %v", testMdTrailers, actualTlrs) 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /inprocgrpc/no_values_context_test.go: -------------------------------------------------------------------------------- 1 | package inprocgrpc 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | ) 7 | 8 | //lint:file-ignore SA1029 context values are just for tests 9 | 10 | func TestNoValuesContext(t *testing.T) { 11 | ctx := context.WithValue(context.Background(), "abc", "def") 12 | ctx = context.WithValue(ctx, "xyz", "123") 13 | ctx = context.WithValue(ctx, "foo", 
"bar") 14 | ctx, cancel := context.WithCancel(ctx) 15 | 16 | nvCtx := context.Context(noValuesContext{ctx}) 17 | nvCtx = context.WithValue(nvCtx, "frob", "nitz") 18 | 19 | // make sure no values are supplied by wrapped context 20 | if nvCtx.Value("abc") != nil { 21 | t.Errorf(`noValuesContext should not have value for key "abc"`) 22 | } 23 | if nvCtx.Value("xyz") != nil { 24 | t.Errorf(`noValuesContext should not have value for key "xyz"`) 25 | } 26 | if nvCtx.Value("foo") != nil { 27 | t.Errorf(`noValuesContext should not have value for key "foo"`) 28 | } 29 | // it should, of course, have its own value 30 | if nvCtx.Value("frob") != "nitz" { 31 | t.Errorf(`noValuesContext returned wrong value for key "frob": expecting "nitz", got %v`, nvCtx.Value("frob")) 32 | } 33 | 34 | // and it still respect's cancellation/deadlines of the parent context 35 | if nvCtx.Err() != nil { 36 | t.Errorf(`noValuesContext should not be done!`) 37 | } 38 | cancel() 39 | if nvCtx.Err() != context.Canceled { 40 | t.Errorf(`noValuesContext should be canceled!`) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /intercept.go: -------------------------------------------------------------------------------- 1 | package grpchan 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "google.golang.org/grpc" 8 | ) 9 | 10 | // WrappedClientConn is a channel that wraps another. It provides an Unwrap method 11 | // for access the underlying wrapped implementation. 12 | type WrappedClientConn interface { 13 | grpc.ClientConnInterface 14 | Unwrap() grpc.ClientConnInterface 15 | } 16 | 17 | // InterceptChannel returns a new channel that intercepts RPCs with the given 18 | // interceptors. If both given interceptors are nil, returns ch. Otherwise, the 19 | // returned value will implement WrappedClientConn and its Unwrap() method will 20 | // return ch. 21 | // 22 | // Deprecated: Use InterceptClientConn instead. 
23 | func InterceptChannel(ch grpc.ClientConnInterface, unaryInt grpc.UnaryClientInterceptor, streamInt grpc.StreamClientInterceptor) grpc.ClientConnInterface { 24 | return InterceptClientConn(ch, unaryInt, streamInt) 25 | } 26 | 27 | // InterceptClientConn returns a new channel that intercepts RPCs with the given 28 | // interceptors, which may be nil. If both given interceptors are nil, returns ch. 29 | // Otherwise, the returned value will implement WrappedClientConn and its Unwrap() 30 | // method will return ch. 31 | func InterceptClientConn(ch grpc.ClientConnInterface, unaryInt grpc.UnaryClientInterceptor, streamInt grpc.StreamClientInterceptor) grpc.ClientConnInterface { 32 | if unaryInt != nil { 33 | ch = InterceptClientConnUnary(ch, unaryInt) 34 | } 35 | if streamInt != nil { 36 | ch = InterceptClientConnStream(ch, streamInt) 37 | } 38 | return ch 39 | } 40 | 41 | // InterceptClientConnUnary returns a new channel that intercepts unary RPCs 42 | // with the given chain of interceptors. If the given set of interceptors is 43 | // empty, this returns ch. Otherwise, the returned value will implement 44 | // WrappedClientConn and its Unwrap() method will return ch. 45 | // 46 | // The first interceptor in the set will be the first one invoked when an RPC 47 | // is called. When that interceptor delegates to the provided invoker, it will 48 | // call the second interceptor, and so on. 49 | func InterceptClientConnUnary(ch grpc.ClientConnInterface, unaryInt ...grpc.UnaryClientInterceptor) grpc.ClientConnInterface { 50 | if len(unaryInt) == 0 { 51 | return ch 52 | } 53 | var streamInt grpc.StreamClientInterceptor 54 | intCh, ok := ch.(*interceptedChannel) 55 | if ok { 56 | // Instead of building a chain of multiple interceptedChannels, build 57 | // a single interceptedChannel with the combined set of interceptors. 
58 | ch = intCh.ch 59 | if intCh.unaryInt != nil { 60 | unaryInt = append(unaryInt, intCh.unaryInt) 61 | } 62 | streamInt = intCh.streamInt 63 | } 64 | return &interceptedChannel{ch: ch, unaryInt: chainUnaryClient(unaryInt), streamInt: streamInt} 65 | } 66 | 67 | // InterceptClientConnStream returns a new channel that intercepts streaming 68 | // RPCs with the given chain of interceptors. If the given set of interceptors 69 | // is empty, this returns ch. Otherwise, the returned value will implement 70 | // WrappedClientConn and its Unwrap() method will return ch. 71 | // 72 | // The first interceptor in the set will be the first one invoked when an RPC 73 | // is called. When that interceptor delegates to the provided invoker, it will 74 | // call the second interceptor, and so on. 75 | func InterceptClientConnStream(ch grpc.ClientConnInterface, streamInt ...grpc.StreamClientInterceptor) grpc.ClientConnInterface { 76 | if len(streamInt) == 0 { 77 | return ch 78 | } 79 | var unaryInt grpc.UnaryClientInterceptor 80 | intCh, ok := ch.(*interceptedChannel) 81 | if ok { 82 | // Instead of building a chain of multiple interceptedChannels, build 83 | // a single interceptedChannel with the combined set of interceptors. 
84 | ch = intCh.ch 85 | unaryInt = intCh.unaryInt 86 | if intCh.streamInt != nil { 87 | streamInt = append(streamInt, intCh.streamInt) 88 | } 89 | } 90 | return &interceptedChannel{ch: ch, unaryInt: unaryInt, streamInt: chainStreamClient(streamInt)} 91 | } 92 | 93 | type interceptedChannel struct { 94 | ch grpc.ClientConnInterface 95 | unaryInt grpc.UnaryClientInterceptor 96 | streamInt grpc.StreamClientInterceptor 97 | } 98 | 99 | func (intch *interceptedChannel) Unwrap() grpc.ClientConnInterface { 100 | return intch.ch 101 | } 102 | 103 | func unwrap(ch grpc.ClientConnInterface) grpc.ClientConnInterface { 104 | // completely unwrap to find the root ClientConn 105 | for { 106 | w, ok := ch.(WrappedClientConn) 107 | if !ok { 108 | return ch 109 | } 110 | unwrapped := w.Unwrap() 111 | if unwrapped == nil { 112 | return ch 113 | } 114 | ch = unwrapped 115 | } 116 | } 117 | 118 | func (intch *interceptedChannel) Invoke(ctx context.Context, methodName string, req, resp any, opts ...grpc.CallOption) error { 119 | if intch.unaryInt == nil { 120 | return intch.ch.Invoke(ctx, methodName, req, resp, opts...) 121 | } 122 | cc, _ := unwrap(intch.ch).(*grpc.ClientConn) 123 | return intch.unaryInt(ctx, methodName, req, resp, cc, intch.unaryInvoker, opts...) 124 | } 125 | 126 | func (intch *interceptedChannel) unaryInvoker(ctx context.Context, methodName string, req, resp any, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 127 | return intch.ch.Invoke(ctx, methodName, req, resp, opts...) 128 | } 129 | 130 | func (intch *interceptedChannel) NewStream(ctx context.Context, desc *grpc.StreamDesc, methodName string, opts ...grpc.CallOption) (grpc.ClientStream, error) { 131 | if intch.streamInt == nil { 132 | return intch.ch.NewStream(ctx, desc, methodName, opts...) 133 | } 134 | cc, _ := intch.ch.(*grpc.ClientConn) 135 | return intch.streamInt(ctx, desc, cc, methodName, intch.streamer, opts...) 
136 | } 137 | 138 | func (intch *interceptedChannel) streamer(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, methodName string, opts ...grpc.CallOption) (grpc.ClientStream, error) { 139 | return intch.ch.NewStream(ctx, desc, methodName, opts...) 140 | } 141 | 142 | var _ grpc.ClientConnInterface = (*interceptedChannel)(nil) 143 | 144 | func chainUnaryClient(unaryInt []grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor { 145 | if len(unaryInt) == 1 { 146 | return unaryInt[0] 147 | } 148 | return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 149 | for i := range unaryInt { 150 | currInterceptor := unaryInt[len(unaryInt)-i-1] // going backwards through the chain 151 | currInvoker := invoker 152 | invoker = func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 153 | return currInterceptor(ctx, method, req, reply, cc, currInvoker, opts...) 154 | } 155 | } 156 | return invoker(ctx, method, req, reply, cc, opts...) 157 | } 158 | } 159 | 160 | func chainStreamClient(streamInt []grpc.StreamClientInterceptor) grpc.StreamClientInterceptor { 161 | if len(streamInt) == 1 { 162 | return streamInt[0] 163 | } 164 | return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { 165 | for i := range streamInt { 166 | currInterceptor := streamInt[len(streamInt)-i-1] // going backwards through the chain 167 | currStreamer := streamer 168 | streamer = func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { 169 | return currInterceptor(ctx, desc, cc, method, currStreamer, opts...) 170 | } 171 | } 172 | return streamer(ctx, desc, cc, method, opts...) 
173 | } 174 | } 175 | 176 | // WithInterceptor returns a view of the given ServiceRegistrar that will 177 | // automatically apply the given interceptors to all registered services. 178 | func WithInterceptor(reg grpc.ServiceRegistrar, unaryInt grpc.UnaryServerInterceptor, streamInt grpc.StreamServerInterceptor) grpc.ServiceRegistrar { 179 | if unaryInt != nil { 180 | reg = WithUnaryInterceptors(reg, unaryInt) 181 | } 182 | if streamInt != nil { 183 | reg = WithStreamInterceptors(reg, streamInt) 184 | } 185 | return reg 186 | } 187 | 188 | // WithUnaryInterceptors returns a view of the given ServiceRegistrar that will 189 | // automatically apply the given interceptors to all registered services. 190 | func WithUnaryInterceptors(reg grpc.ServiceRegistrar, unaryInt ...grpc.UnaryServerInterceptor) grpc.ServiceRegistrar { 191 | if len(unaryInt) == 0 { 192 | return reg 193 | } 194 | var streamInt grpc.StreamServerInterceptor 195 | intReg, ok := reg.(*interceptingRegistry) 196 | if ok { 197 | // Instead of building a chain of multiple interceptingRegistry instances, 198 | // build a single interceptingRegistry with the combined set of interceptors. 199 | reg = intReg.reg 200 | if intReg.unaryInt != nil { 201 | unaryInt = append(unaryInt, intReg.unaryInt) 202 | } 203 | streamInt = intReg.streamInt 204 | } 205 | return &interceptingRegistry{reg: reg, unaryInt: chainUnaryServer(unaryInt), streamInt: streamInt} 206 | } 207 | 208 | func WithStreamInterceptors(reg grpc.ServiceRegistrar, streamInt ...grpc.StreamServerInterceptor) grpc.ServiceRegistrar { 209 | if len(streamInt) == 0 { 210 | return reg 211 | } 212 | var unaryInt grpc.UnaryServerInterceptor 213 | intReg, ok := reg.(*interceptingRegistry) 214 | if ok { 215 | // Instead of building a chain of multiple interceptingRegistry instances, 216 | // build a single interceptingRegistry with the combined set of interceptors. 
217 | reg = intReg.reg 218 | unaryInt = intReg.unaryInt 219 | if intReg.streamInt != nil { 220 | streamInt = append(streamInt, intReg.streamInt) 221 | } 222 | } 223 | return &interceptingRegistry{reg: reg, unaryInt: unaryInt, streamInt: chainStreamServer(streamInt)} 224 | } 225 | 226 | type interceptingRegistry struct { 227 | reg grpc.ServiceRegistrar 228 | unaryInt grpc.UnaryServerInterceptor 229 | streamInt grpc.StreamServerInterceptor 230 | } 231 | 232 | func (r *interceptingRegistry) RegisterService(desc *grpc.ServiceDesc, srv any) { 233 | r.reg.RegisterService(InterceptServer(desc, r.unaryInt, r.streamInt), srv) 234 | } 235 | 236 | // InterceptServer returns a new service description that will intercepts RPCs 237 | // with the given interceptors. If both given interceptors are nil, returns 238 | // svcDesc. 239 | func InterceptServer(svcDesc *grpc.ServiceDesc, unaryInt grpc.UnaryServerInterceptor, streamInt grpc.StreamServerInterceptor) *grpc.ServiceDesc { 240 | if unaryInt == nil && streamInt == nil { 241 | return svcDesc 242 | } 243 | intercepted := *svcDesc 244 | 245 | if unaryInt != nil { 246 | intercepted.Methods = make([]grpc.MethodDesc, len(svcDesc.Methods)) 247 | for i, md := range svcDesc.Methods { 248 | origHandler := md.Handler 249 | intercepted.Methods[i] = grpc.MethodDesc{ 250 | MethodName: md.MethodName, 251 | Handler: func(srv any, ctx context.Context, dec func(any) error, interceptor grpc.UnaryServerInterceptor) (any, error) { 252 | combinedInterceptor := unaryInt 253 | if interceptor != nil { 254 | // combine unaryInt with the interceptor provided to handler 255 | combinedInterceptor = func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { 256 | h := func(ctx context.Context, req any) (any, error) { 257 | return unaryInt(ctx, req, info, handler) 258 | } 259 | // we first call provided interceptor, but supply a handler that will call unaryInt 260 | return interceptor(ctx, req, info, h) 
261 | } 262 | } 263 | return origHandler(srv, ctx, dec, combinedInterceptor) 264 | }, 265 | } 266 | } 267 | } 268 | 269 | if streamInt != nil { 270 | intercepted.Streams = make([]grpc.StreamDesc, len(svcDesc.Streams)) 271 | for i, sd := range svcDesc.Streams { 272 | origHandler := sd.Handler 273 | info := &grpc.StreamServerInfo{ 274 | FullMethod: fmt.Sprintf("/%s/%s", svcDesc.ServiceName, sd.StreamName), 275 | IsClientStream: sd.ClientStreams, 276 | IsServerStream: sd.ServerStreams, 277 | } 278 | intercepted.Streams[i] = grpc.StreamDesc{ 279 | StreamName: sd.StreamName, 280 | ClientStreams: sd.ClientStreams, 281 | ServerStreams: sd.ServerStreams, 282 | Handler: func(srv any, stream grpc.ServerStream) error { 283 | return streamInt(srv, stream, info, origHandler) 284 | }, 285 | } 286 | } 287 | } 288 | 289 | return &intercepted 290 | } 291 | 292 | func chainUnaryServer(unaryInt []grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor { 293 | if len(unaryInt) == 1 { 294 | return unaryInt[0] 295 | } 296 | return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 297 | for i := range unaryInt { 298 | currInterceptor := unaryInt[len(unaryInt)-i-1] // going backwards through the chain 299 | currHandler := handler 300 | handler = func(ctx context.Context, req any) (any, error) { 301 | return currInterceptor(ctx, req, info, currHandler) 302 | } 303 | } 304 | return handler(ctx, req) 305 | } 306 | } 307 | 308 | func chainStreamServer(streamInt []grpc.StreamServerInterceptor) grpc.StreamServerInterceptor { 309 | if len(streamInt) == 1 { 310 | return streamInt[0] 311 | } 312 | return func(impl any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 313 | for i := range streamInt { 314 | currInterceptor := streamInt[len(streamInt)-i-1] // going backwards through the chain 315 | currHandler := handler 316 | handler = func(impl any, stream grpc.ServerStream) error { 317 | return 
currInterceptor(impl, stream, info, currHandler) 318 | } 319 | } 320 | return handler(impl, stream) 321 | } 322 | } 323 | -------------------------------------------------------------------------------- /intercept_client_test.go: -------------------------------------------------------------------------------- 1 | package grpchan_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "testing" 8 | 9 | "google.golang.org/grpc" 10 | "google.golang.org/grpc/codes" 11 | "google.golang.org/grpc/metadata" 12 | "google.golang.org/grpc/status" 13 | "google.golang.org/protobuf/proto" 14 | 15 | "github.com/fullstorydev/grpchan" 16 | "github.com/fullstorydev/grpchan/grpchantesting" 17 | "github.com/fullstorydev/grpchan/internal" 18 | ) 19 | 20 | func TestInterceptClientConnUnary(t *testing.T) { 21 | tc := testConn{} 22 | 23 | var successCount, failCount int 24 | intercepted := grpchan.InterceptClientConn(&tc, 25 | func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 26 | if err := invoker(ctx, method, req, reply, cc, opts...); err != nil { 27 | failCount++ 28 | return err 29 | } 30 | successCount++ 31 | return nil 32 | }, nil) 33 | 34 | cli := grpchantesting.NewTestServiceClient(intercepted) 35 | 36 | // success 37 | tc.resp = &grpchantesting.Message{Count: 123} 38 | resp, err := cli.Unary(context.Background(), &grpchantesting.Message{}) 39 | if err != nil { 40 | t.Fatalf("RPC failed: %v", err) 41 | } 42 | if !proto.Equal(resp, tc.resp.(proto.Message)) { 43 | t.Fatalf("unexpected reply: %v != %v", resp, tc.resp) 44 | } 45 | 46 | // failure 47 | ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("foo", "bar")) 48 | tc.code = codes.Aborted 49 | _, err = cli.Unary(ctx, &grpchantesting.Message{Count: 456}) 50 | if err == nil { 51 | t.Fatalf("expected RPC to fail") 52 | } 53 | s, ok := status.FromError(err) 54 | if !ok { 55 | t.Fatalf("wrong type of error %T: %v", 
err, err) 56 | } 57 | if s.Code() != codes.Aborted { 58 | t.Fatalf("wrong error code: %v != %v", s.Code(), codes.Aborted) 59 | } 60 | 61 | // check observed state 62 | if successCount != 1 { 63 | t.Fatalf("interceptor observed wrong number of successful RPCs: expecting %d, got %d", 1, successCount) 64 | } 65 | if failCount != 1 { 66 | t.Fatalf("interceptor observed wrong number of failed RPCs: expecting %d, got %d", 1, failCount) 67 | } 68 | 69 | expected := []*call{ 70 | { 71 | methodName: "/grpchantesting.TestService/Unary", 72 | reqs: []proto.Message{&grpchantesting.Message{}}, 73 | headers: nil, 74 | }, 75 | { 76 | methodName: "/grpchantesting.TestService/Unary", 77 | reqs: []proto.Message{&grpchantesting.Message{Count: 456}}, 78 | headers: metadata.Pairs("foo", "bar"), 79 | }, 80 | } 81 | 82 | checkCalls(t, expected, tc.calls) 83 | } 84 | 85 | func TestInterceptClientConnStream(t *testing.T) { 86 | tc := testConn{} 87 | 88 | var messageCount, successCount, failCount int 89 | intercepted := grpchan.InterceptClientConn(&tc, nil, 90 | func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { 91 | cs, err := streamer(ctx, desc, cc, method, opts...) 
92 | if err != nil { 93 | return nil, err 94 | } 95 | return &testInterceptClientStream{ 96 | ClientStream: cs, 97 | messageCount: &messageCount, 98 | successCount: &successCount, 99 | failCount: &failCount, 100 | serverStreams: desc.ServerStreams, 101 | }, nil 102 | }) 103 | 104 | cli := grpchantesting.NewTestServiceClient(intercepted) 105 | 106 | // client stream, success 107 | tc.resp = &grpchantesting.Message{Count: 123} 108 | cs, err := cli.ClientStream(context.Background()) 109 | if err != nil { 110 | t.Fatalf("RPC failed: %v", err) 111 | } 112 | 113 | err = cs.Send(&grpchantesting.Message{}) 114 | if err != nil { 115 | t.Fatalf("sending request #1 failed: %v", err) 116 | } 117 | err = cs.Send(&grpchantesting.Message{Count: 1}) 118 | if err != nil { 119 | t.Fatalf("sending request #2 failed: %v", err) 120 | } 121 | err = cs.Send(&grpchantesting.Message{Count: 42}) 122 | if err != nil { 123 | t.Fatalf("sending request #3 failed: %v", err) 124 | } 125 | resp, err := cs.CloseAndRecv() 126 | if err != nil { 127 | t.Fatalf("failed to receive response: %v", err) 128 | } 129 | if !proto.Equal(resp, tc.resp.(proto.Message)) { 130 | t.Fatalf("unexpected reply: %v != %v", resp, tc.resp) 131 | } 132 | 133 | // server stream, success 134 | ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("foo", "bar")) 135 | tc.respCount = 5 136 | ss, err := cli.ServerStream(ctx, &grpchantesting.Message{Count: 456}) 137 | if err != nil { 138 | t.Fatalf("RPC failed: %v", err) 139 | } 140 | 141 | for i := 0; i < 5; i++ { 142 | resp, err = ss.Recv() 143 | if err != nil { 144 | t.Fatalf("failed to receive response #%d: %v", i+1, err) 145 | } 146 | if !proto.Equal(resp, tc.resp.(proto.Message)) { 147 | t.Fatalf("unexpected reply #%d: %v != %v", i+1, resp, tc.resp) 148 | } 149 | } 150 | 151 | _, err = ss.Recv() 152 | if err != io.EOF { 153 | t.Fatalf("expected EOF, instead got %v", err) 154 | } 155 | 156 | // bidi stream, failure 157 | ctx = 
metadata.NewOutgoingContext(context.Background(), metadata.Pairs("foo", "baz")) 158 | tc.code = codes.Aborted 159 | bs, err := cli.BidiStream(ctx) 160 | if err != nil { 161 | t.Fatalf("RPC failed: %v", err) 162 | } 163 | 164 | err = bs.Send(&grpchantesting.Message{Count: 333}) 165 | if err != nil { 166 | t.Fatalf("sending request #1 failed: %v", err) 167 | } 168 | err = bs.Send(&grpchantesting.Message{Count: 222}) 169 | if err != nil { 170 | t.Fatalf("sending request #2 failed: %v", err) 171 | } 172 | err = bs.Send(&grpchantesting.Message{Count: 111}) 173 | if err != nil { 174 | t.Fatalf("sending request #3 failed: %v", err) 175 | } 176 | 177 | for i := 0; i < 5; i++ { 178 | resp, err = bs.Recv() 179 | if err != nil { 180 | t.Fatalf("failed to receive response #%d: %v", i+1, err) 181 | } 182 | if !proto.Equal(resp, tc.resp.(proto.Message)) { 183 | t.Fatalf("unexpected reply #%d: %v != %v", i+1, resp, tc.resp) 184 | } 185 | } 186 | 187 | _, err = bs.Recv() 188 | if err == nil { 189 | t.Fatalf("expected RPC to fail") 190 | } 191 | s, ok := status.FromError(err) 192 | if !ok { 193 | t.Fatalf("wrong type of error %T: %v", err, err) 194 | } 195 | if s.Code() != codes.Aborted { 196 | t.Fatalf("wrong error code: %v != %v", s.Code(), codes.Aborted) 197 | } 198 | 199 | // check observed state 200 | expectedMessages := 1 + 5 + 5 201 | if messageCount != expectedMessages { 202 | t.Fatalf("interceptor observed wrong number of response messages: expecting %d, got %d", expectedMessages, messageCount) 203 | } 204 | if successCount != 2 { 205 | t.Fatalf("interceptor observed wrong number of successful RPCs: expecting %d, got %d", 2, successCount) 206 | } 207 | if failCount != 1 { 208 | t.Fatalf("interceptor observed wrong number of failed RPCs: expecting %d, got %d", 1, failCount) 209 | } 210 | 211 | expected := []*call{ 212 | { 213 | methodName: "/grpchantesting.TestService/ClientStream", 214 | reqs: []proto.Message{ 215 | &grpchantesting.Message{}, 216 | 
&grpchantesting.Message{Count: 1}, 217 | &grpchantesting.Message{Count: 42}, 218 | }, 219 | headers: nil, 220 | }, 221 | { 222 | methodName: "/grpchantesting.TestService/ServerStream", 223 | reqs: []proto.Message{&grpchantesting.Message{Count: 456}}, 224 | headers: metadata.Pairs("foo", "bar"), 225 | }, 226 | { 227 | methodName: "/grpchantesting.TestService/BidiStream", 228 | reqs: []proto.Message{ 229 | &grpchantesting.Message{Count: 333}, 230 | &grpchantesting.Message{Count: 222}, 231 | &grpchantesting.Message{Count: 111}, 232 | }, 233 | headers: metadata.Pairs("foo", "baz"), 234 | }, 235 | } 236 | 237 | checkCalls(t, expected, tc.calls) 238 | } 239 | 240 | type testInterceptClientStream struct { 241 | grpc.ClientStream 242 | messageCount, successCount, failCount *int 243 | serverStreams, closed bool 244 | } 245 | 246 | func (s *testInterceptClientStream) RecvMsg(m interface{}) error { 247 | err := s.ClientStream.RecvMsg(m) 248 | if err == nil { 249 | *s.messageCount++ 250 | if !s.serverStreams { 251 | s.closed = true 252 | *s.successCount++ 253 | } 254 | } else if !s.closed { 255 | s.closed = true 256 | if err == io.EOF { 257 | *s.successCount++ 258 | } else { 259 | *s.failCount++ 260 | } 261 | } 262 | return err 263 | } 264 | 265 | // testConn is a dummy channel that just records all incoming activity. 266 | // 267 | // If code is set and not codes.OK, RPCs will fail with that code. 268 | // 269 | // If resp is set, unary RPCs will reply with that value. If unset, unary 270 | // RPCs will reply with empty response message. 271 | // 272 | // If resp is set and respCount is non-zero, server-streaming RPCs (including 273 | // bidi streams) will reply with the given number of responses. Otherwise, 274 | // they reply with an empty stream. 275 | // 276 | // Streaming RPCs will receive the specified headers and trailers as response 277 | // metadata, if those fields are set. 278 | // 279 | // testConn is not thread-safe, and neither are any returned streams. 
280 | type testConn struct { 281 | code codes.Code 282 | resp interface{} 283 | respCount int 284 | headers metadata.MD 285 | trailers metadata.MD 286 | calls []*call 287 | } 288 | 289 | type call struct { 290 | methodName string 291 | headers metadata.MD 292 | reqs []proto.Message 293 | } 294 | 295 | func (ch *testConn) Invoke(ctx context.Context, methodName string, req, resp interface{}, _ ...grpc.CallOption) error { 296 | headers, _ := metadata.FromOutgoingContext(ctx) 297 | reqClone, err := internal.CloneMessage(req) 298 | if err != nil { 299 | return err 300 | } 301 | ch.calls = append(ch.calls, &call{methodName: methodName, headers: headers, reqs: []proto.Message{reqClone.(proto.Message)}}) 302 | if ch.code != codes.OK { 303 | return status.Error(ch.code, ch.code.String()) 304 | } 305 | if ch.resp != nil { 306 | return internal.CopyMessage(resp, ch.resp) 307 | } 308 | return internal.ClearMessage(resp) 309 | } 310 | 311 | func (ch *testConn) NewStream(ctx context.Context, desc *grpc.StreamDesc, methodName string, _ ...grpc.CallOption) (grpc.ClientStream, error) { 312 | headers, _ := metadata.FromOutgoingContext(ctx) 313 | call := &call{methodName: methodName, headers: headers} 314 | ch.calls = append(ch.calls, call) 315 | count := ch.respCount 316 | if !desc.ServerStreams { 317 | if ch.code == codes.OK { 318 | count = 1 319 | } else { 320 | count = 0 321 | } 322 | } 323 | return &testClientStream{ 324 | ctx: ctx, 325 | code: ch.code, 326 | resp: ch.resp, 327 | respCount: count, 328 | headers: ch.headers, 329 | trailers: ch.trailers, 330 | call: call, 331 | }, nil 332 | } 333 | 334 | type testClientStream struct { 335 | ctx context.Context 336 | code codes.Code 337 | resp interface{} 338 | respCount int 339 | headers metadata.MD 340 | trailers metadata.MD 341 | call *call 342 | halfClosed bool 343 | closed bool 344 | } 345 | 346 | func (s *testClientStream) Header() (metadata.MD, error) { 347 | return s.headers, nil 348 | } 349 | 350 | func (s 
*testClientStream) Trailer() metadata.MD { 351 | return s.trailers 352 | } 353 | 354 | func (s *testClientStream) CloseSend() error { 355 | s.halfClosed = true 356 | return nil 357 | } 358 | 359 | func (s *testClientStream) Context() context.Context { 360 | return s.ctx 361 | } 362 | 363 | func (s *testClientStream) SendMsg(m interface{}) error { 364 | if s.halfClosed { 365 | return fmt.Errorf("stream closed") 366 | } 367 | if s.closed { 368 | return io.EOF 369 | } 370 | if err := s.ctx.Err(); err != nil { 371 | return internal.TranslateContextError(err) 372 | } 373 | mClone, err := internal.CloneMessage(m) 374 | if err != nil { 375 | return err 376 | } 377 | s.call.reqs = append(s.call.reqs, mClone.(proto.Message)) 378 | return nil 379 | } 380 | 381 | func (s *testClientStream) RecvMsg(m interface{}) error { 382 | if err := s.ctx.Err(); err != nil { 383 | return internal.TranslateContextError(err) 384 | } 385 | if s.respCount == 0 { 386 | s.closed = true 387 | if s.code == codes.OK { 388 | return io.EOF 389 | } else { 390 | return status.Error(s.code, s.code.String()) 391 | } 392 | } 393 | 394 | s.respCount-- 395 | if s.resp != nil { 396 | return internal.CopyMessage(m, s.resp) 397 | } 398 | return internal.ClearMessage(m) 399 | } 400 | -------------------------------------------------------------------------------- /intercept_server_test.go: -------------------------------------------------------------------------------- 1 | package grpchan_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "reflect" 8 | "testing" 9 | 10 | "google.golang.org/grpc" 11 | "google.golang.org/grpc/codes" 12 | "google.golang.org/grpc/metadata" 13 | "google.golang.org/grpc/status" 14 | "google.golang.org/protobuf/proto" 15 | "google.golang.org/protobuf/types/known/emptypb" 16 | 17 | "github.com/fullstorydev/grpchan" 18 | "github.com/fullstorydev/grpchan/grpchantesting" 19 | "github.com/fullstorydev/grpchan/internal" 20 | ) 21 | 22 | func TestInterceptServerUnary(t *testing.T) { 
23 | svr := &testServer{} 24 | handlers := grpchan.HandlerMap{} 25 | 26 | // this will make sure unary interceptors are composed correctly 27 | var lastSeen string 28 | outerInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 29 | lastSeen = "a" 30 | return handler(ctx, req) 31 | } 32 | 33 | var successCount, failCount int 34 | grpchantesting.RegisterTestServiceServer(grpchan.WithInterceptor(handlers, 35 | func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 36 | if lastSeen != "a" { 37 | // interceptor above should have been invoked first! 38 | return nil, fmt.Errorf("interceptor not correctly invoked!") 39 | } 40 | lastSeen = "b" 41 | resp, err := handler(ctx, req) 42 | if err != nil { 43 | failCount++ 44 | } else { 45 | successCount++ 46 | } 47 | return resp, err 48 | }, nil), svr) 49 | 50 | sd, ss := handlers.QueryService("grpchantesting.TestService") 51 | // sanity check 52 | if ss != svr { 53 | t.Fatalf("queried handler does not match registered handler! 
%v != %v", ss, svr) 54 | } 55 | if sd == nil { 56 | t.Fatalf("service descriptor not found") 57 | } 58 | 59 | // get handler for the method we're going to invoke 60 | md := internal.FindUnaryMethod("Unary", sd.Methods) 61 | if md == nil { 62 | t.Fatalf("method descriptor not found") 63 | } 64 | 65 | // success 66 | svr.resp = &grpchantesting.Message{Count: 123} 67 | var m grpchantesting.Message 68 | dec := func(req interface{}) error { 69 | reqMsg := req.(*grpchantesting.Message) 70 | proto.Reset(reqMsg) 71 | proto.Merge(reqMsg, &m) 72 | return nil 73 | } 74 | resp, err := md.Handler(svr, context.Background(), dec, outerInt) 75 | if err != nil { 76 | t.Fatalf("RPC failed: %v", err) 77 | } 78 | if !reflect.DeepEqual(resp, svr.resp) { 79 | t.Fatalf("unexpected reply: expecting %v; got %v", svr.resp, resp) 80 | } 81 | if lastSeen != "b" { 82 | t.Fatalf("interceptors not composed correctly") 83 | } 84 | 85 | // failure 86 | ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("foo", "bar")) 87 | svr.code = codes.Aborted 88 | m = grpchantesting.Message{Count: 456} 89 | _, err = md.Handler(svr, ctx, dec, outerInt) 90 | if err == nil { 91 | t.Fatalf("expected RPC to fail") 92 | } 93 | s, ok := status.FromError(err) 94 | if !ok { 95 | t.Fatalf("wrong type of error %T: %v", err, err) 96 | } 97 | if s.Code() != codes.Aborted { 98 | t.Fatalf("wrong error code: %v != %v", s.Code(), codes.Aborted) 99 | } 100 | if lastSeen != "b" { 101 | t.Fatalf("interceptors not composed correctly") 102 | } 103 | 104 | // check observed state 105 | if successCount != 1 { 106 | t.Fatalf("interceptor observed wrong number of successful RPCs: expecting %d, got %d", 1, successCount) 107 | } 108 | if failCount != 1 { 109 | t.Fatalf("interceptor observed wrong number of failed RPCs: expecting %d, got %d", 1, failCount) 110 | } 111 | 112 | expected := []*call{ 113 | { 114 | methodName: "Unary", 115 | reqs: []proto.Message{&grpchantesting.Message{}}, 116 | headers: nil, 117 | }, 118 
| { 119 | methodName: "Unary", 120 | reqs: []proto.Message{&grpchantesting.Message{Count: 456}}, 121 | headers: metadata.Pairs("foo", "bar"), 122 | }, 123 | } 124 | 125 | checkCalls(t, expected, svr.calls) 126 | } 127 | 128 | func TestInterceptServerStream(t *testing.T) { 129 | svr := &testServer{} 130 | handlers := grpchan.HandlerMap{} 131 | 132 | var messageCount, successCount, failCount int 133 | grpchantesting.RegisterTestServiceServer(grpchan.WithInterceptor(handlers, nil, 134 | func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 135 | err := handler(srv, &testInterceptServerStream{ 136 | ServerStream: ss, 137 | messageCount: &messageCount, 138 | }) 139 | if err != nil { 140 | failCount++ 141 | } else { 142 | successCount++ 143 | } 144 | return err 145 | }), svr) 146 | 147 | sd, ss := handlers.QueryService("grpchantesting.TestService") 148 | // sanity check 149 | if ss != svr { 150 | t.Fatalf("queried handler does not match registered handler! 
%v != %v", ss, svr) 151 | } 152 | if sd == nil { 153 | t.Fatalf("service descriptor not found") 154 | } 155 | 156 | // get handlers for the methods we're going to invoke 157 | csdesc := internal.FindStreamingMethod("ClientStream", sd.Streams) 158 | if csdesc == nil { 159 | t.Fatalf("ClientStream stream descriptor not found") 160 | } 161 | ssdesc := internal.FindStreamingMethod("ServerStream", sd.Streams) 162 | if ssdesc == nil { 163 | t.Fatalf("ServerStream stream descriptor not found") 164 | } 165 | bsdesc := internal.FindStreamingMethod("BidiStream", sd.Streams) 166 | if bsdesc == nil { 167 | t.Fatalf("BidiStream stream descriptor not found") 168 | } 169 | 170 | // client stream, success 171 | svr.resp = &grpchantesting.Message{Count: 123} 172 | str := &testServerStream{ 173 | ctx: context.Background(), 174 | reqs: []proto.Message{ 175 | &grpchantesting.Message{}, 176 | &grpchantesting.Message{Count: 1}, 177 | &grpchantesting.Message{Count: 42}, 178 | }, 179 | } 180 | err := csdesc.Handler(svr, str) 181 | if err != nil { 182 | t.Fatalf("RPC failed: %v", err) 183 | } 184 | 185 | replies := []proto.Message{svr.resp} 186 | checkProtosEqual(t, replies, str.resps) 187 | 188 | // server stream, success 189 | ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("foo", "bar")) 190 | svr.respCount = 5 191 | str = &testServerStream{ 192 | ctx: ctx, 193 | reqs: []proto.Message{ 194 | &grpchantesting.Message{Count: 456}, 195 | }, 196 | } 197 | err = ssdesc.Handler(svr, str) 198 | if err != nil { 199 | t.Fatalf("RPC failed: %v", err) 200 | } 201 | 202 | replies = []proto.Message{svr.resp, svr.resp, svr.resp, svr.resp, svr.resp} // five of 'em 203 | checkProtosEqual(t, replies, str.resps) 204 | 205 | // bidi stream, failure 206 | ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs("foo", "baz")) 207 | svr.code = codes.Aborted 208 | str = &testServerStream{ 209 | ctx: ctx, 210 | reqs: []proto.Message{ 211 | &grpchantesting.Message{Count: 
333}, 212 | &grpchantesting.Message{Count: 222}, 213 | &grpchantesting.Message{Count: 111}, 214 | }, 215 | } 216 | err = bsdesc.Handler(svr, str) 217 | if err == nil { 218 | t.Fatalf("expected RPC to fail") 219 | } 220 | s, ok := status.FromError(err) 221 | if !ok { 222 | t.Fatalf("wrong type of error %T: %v", err, err) 223 | } 224 | if s.Code() != codes.Aborted { 225 | t.Fatalf("wrong error code: %v != %v", s.Code(), codes.Aborted) 226 | } 227 | 228 | checkProtosEqual(t, replies, str.resps) 229 | 230 | // check observed state 231 | expectedMessages := 1 + 5 + 5 232 | if messageCount != expectedMessages { 233 | t.Fatalf("interceptor observed wrong number of response messages: expecting %d, got %d", expectedMessages, messageCount) 234 | } 235 | if successCount != 2 { 236 | t.Fatalf("interceptor observed wrong number of successful RPCs: expecting %d, got %d", 2, successCount) 237 | } 238 | if failCount != 1 { 239 | t.Fatalf("interceptor observed wrong number of failed RPCs: expecting %d, got %d", 1, failCount) 240 | } 241 | 242 | expected := []*call{ 243 | { 244 | methodName: "ClientStream", 245 | reqs: []proto.Message{ 246 | &grpchantesting.Message{}, 247 | &grpchantesting.Message{Count: 1}, 248 | &grpchantesting.Message{Count: 42}, 249 | }, 250 | headers: nil, 251 | }, 252 | { 253 | methodName: "ServerStream", 254 | reqs: []proto.Message{&grpchantesting.Message{Count: 456}}, 255 | headers: metadata.Pairs("foo", "bar"), 256 | }, 257 | { 258 | methodName: "BidiStream", 259 | reqs: []proto.Message{ 260 | &grpchantesting.Message{Count: 333}, 261 | &grpchantesting.Message{Count: 222}, 262 | &grpchantesting.Message{Count: 111}, 263 | }, 264 | headers: metadata.Pairs("foo", "baz"), 265 | }, 266 | } 267 | 268 | checkCalls(t, expected, svr.calls) 269 | } 270 | 271 | func checkProtosEqual(t *testing.T, expected, actual []proto.Message) { 272 | t.Helper() 273 | if len(actual) != len(expected) { 274 | t.Fatalf("unexpected number of replies: expecting %d; got %d", 
len(expected), len(actual)) 275 | } 276 | for i := range expected { 277 | if !proto.Equal(expected[i], actual[i]) { 278 | t.Fatalf("unexpected reply[%d]: expecting %v; got %v", i+1, expected[i], actual[i]) 279 | } 280 | } 281 | } 282 | 283 | func checkCalls(t *testing.T, expected, actual []*call) { 284 | t.Helper() 285 | if len(actual) != len(expected) { 286 | t.Fatalf("unexpected number of calls: expecting %d; got %d", len(expected), len(actual)) 287 | } 288 | for i := range expected { 289 | exp := expected[i] 290 | act := actual[i] 291 | if exp.methodName != act.methodName { 292 | t.Fatalf("unexpected call[%d]: expecting %q; got %q", i+1, exp.methodName, act.methodName) 293 | } 294 | if !reflect.DeepEqual(exp.headers, act.headers) { 295 | t.Fatalf("unexpected call[%d] headers: expecting %v; got %v", i+1, exp.headers, act.headers) 296 | } 297 | checkProtosEqual(t, exp.reqs, act.reqs) 298 | } 299 | } 300 | 301 | type testInterceptServerStream struct { 302 | grpc.ServerStream 303 | messageCount *int 304 | } 305 | 306 | func (s *testInterceptServerStream) SendMsg(m interface{}) error { 307 | err := s.ServerStream.SendMsg(m) 308 | if err == nil { 309 | *s.messageCount++ 310 | } 311 | return err 312 | } 313 | 314 | // testServer is a dummy server that just records all incoming activity. 315 | // 316 | // If code is set and not codes.OK, RPCs will fail with that code. 317 | // 318 | // If resp is set, unary RPCs will reply with that value. If unset, unary 319 | // RPCs will reply with empty response message. 320 | // 321 | // If resp is set and respCount is non-zero, server-streaming RPCs (including 322 | // bidi streams) will reply with the given number of responses. Otherwise, 323 | // they reply with an empty stream. 324 | // 325 | // Streaming RPCs will receive the specified headers and trailers as response 326 | // metadata, if those fields are set. 327 | // 328 | // testServer is not thread-safe. 
329 | type testServer struct { 330 | grpchantesting.UnimplementedTestServiceServer 331 | code codes.Code 332 | resp proto.Message 333 | respCount int 334 | headers metadata.MD 335 | trailers metadata.MD 336 | calls []*call 337 | } 338 | 339 | func (s *testServer) Unary(ctx context.Context, req *grpchantesting.Message) (*grpchantesting.Message, error) { 340 | resp := grpchantesting.Message{} 341 | err := s.unary(ctx, "Unary", req, &resp) 342 | if err != nil { 343 | return nil, err 344 | } 345 | return &resp, nil 346 | } 347 | 348 | func (s *testServer) ClientStream(stream grpchantesting.TestService_ClientStreamServer) error { 349 | return s.stream(&grpc.StreamDesc{ 350 | StreamName: "ClientStream", 351 | ClientStreams: true, 352 | }, nil, stream) 353 | } 354 | 355 | func (s *testServer) ServerStream(req *grpchantesting.Message, stream grpchantesting.TestService_ServerStreamServer) error { 356 | return s.stream(&grpc.StreamDesc{ 357 | StreamName: "ServerStream", 358 | ServerStreams: true, 359 | }, req, stream) 360 | } 361 | 362 | func (s *testServer) BidiStream(stream grpchantesting.TestService_BidiStreamServer) error { 363 | return s.stream(&grpc.StreamDesc{ 364 | StreamName: "BidiStream", 365 | ClientStreams: true, 366 | ServerStreams: true, 367 | }, nil, stream) 368 | } 369 | 370 | func (s *testServer) UseExternalMessageTwice(ctx context.Context, req *emptypb.Empty) (*emptypb.Empty, error) { 371 | resp := emptypb.Empty{} 372 | err := s.unary(ctx, "UseExternalMessageTwice", req, &resp) 373 | if err != nil { 374 | return nil, err 375 | } 376 | return &resp, nil 377 | } 378 | 379 | func (s *testServer) unary(ctx context.Context, methodName string, req, resp proto.Message) error { 380 | headers, _ := metadata.FromIncomingContext(ctx) 381 | reqClone, err := internal.CloneMessage(req) 382 | if err != nil { 383 | return err 384 | } 385 | s.calls = append(s.calls, &call{methodName: methodName, headers: headers, reqs: []proto.Message{reqClone.(proto.Message)}}) 386 | if 
s.code != codes.OK { 387 | return status.Error(s.code, s.code.String()) 388 | } 389 | if s.resp != nil { 390 | return internal.CopyMessage(resp, s.resp) 391 | } 392 | return internal.ClearMessage(resp) 393 | } 394 | 395 | func (s *testServer) stream(desc *grpc.StreamDesc, req *grpchantesting.Message, stream grpc.ServerStream) error { 396 | headers, _ := metadata.FromIncomingContext(stream.Context()) 397 | call := &call{methodName: desc.StreamName, headers: headers} 398 | s.calls = append(s.calls, call) 399 | 400 | // consume requests 401 | if desc.ClientStreams { 402 | for { 403 | m := &grpchantesting.Message{} 404 | err := stream.RecvMsg(m) 405 | if err == io.EOF { 406 | break 407 | } else if err != nil { 408 | return err 409 | } 410 | call.reqs = append(call.reqs, m) 411 | } 412 | } else { 413 | call.reqs = append(call.reqs, req) 414 | } 415 | 416 | // produce responses 417 | if len(s.headers) > 0 { 418 | if err := stream.SetHeader(s.headers); err != nil { 419 | return err 420 | } 421 | } 422 | 423 | count := s.respCount 424 | if !desc.ServerStreams { 425 | if s.code == codes.OK { 426 | count = 1 427 | } else { 428 | count = 0 429 | } 430 | } 431 | for count > 0 { 432 | m := s.resp 433 | if m == nil { 434 | m = &grpchantesting.Message{} 435 | } 436 | if err := stream.SendMsg(m); err != nil { 437 | return err 438 | } 439 | count-- 440 | } 441 | 442 | if len(s.trailers) > 0 { 443 | stream.SetTrailer(s.trailers) 444 | } 445 | 446 | if s.code != codes.OK { 447 | return status.Error(s.code, s.code.String()) 448 | } 449 | return nil 450 | } 451 | 452 | type testServerStream struct { 453 | ctx context.Context 454 | reqs []proto.Message 455 | resps []proto.Message 456 | headers metadata.MD 457 | headersSent bool 458 | trailers metadata.MD 459 | } 460 | 461 | func (s *testServerStream) SetHeader(md metadata.MD) error { 462 | if s.headersSent { 463 | return fmt.Errorf("headers already sent") 464 | } 465 | s.headers = metadata.Join(s.headers, md) 466 | return nil 467 | } 
468 | 469 | func (s *testServerStream) SendHeader(md metadata.MD) error { 470 | if err := s.SetHeader(md); err != nil { 471 | return err 472 | } 473 | s.headersSent = true 474 | return nil 475 | } 476 | 477 | func (s *testServerStream) SetTrailer(md metadata.MD) { 478 | s.trailers = metadata.Join(s.trailers, md) 479 | } 480 | 481 | func (s *testServerStream) Context() context.Context { 482 | return s.ctx 483 | } 484 | 485 | func (s *testServerStream) SendMsg(m interface{}) error { 486 | if err := s.ctx.Err(); err != nil { 487 | return internal.TranslateContextError(err) 488 | } 489 | mClone, err := internal.CloneMessage(m) 490 | if err != nil { 491 | return err 492 | } 493 | s.resps = append(s.resps, mClone.(proto.Message)) 494 | return nil 495 | } 496 | 497 | func (s *testServerStream) RecvMsg(m interface{}) error { 498 | if len(s.reqs) == 0 { 499 | return io.EOF 500 | } 501 | req := s.reqs[0] 502 | s.reqs = s.reqs[1:] 503 | if req != nil { 504 | return internal.CopyMessage(m, req) 505 | } 506 | return internal.ClearMessage(m) 507 | } 508 | -------------------------------------------------------------------------------- /interceptor_chain_client_test.go: -------------------------------------------------------------------------------- 1 | package grpchan 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "reflect" 9 | "strconv" 10 | "testing" 11 | 12 | "google.golang.org/grpc" 13 | "google.golang.org/grpc/codes" 14 | "google.golang.org/grpc/metadata" 15 | "google.golang.org/grpc/status" 16 | 17 | "github.com/fullstorydev/grpchan/internal" 18 | ) 19 | 20 | var interceptorChainCases = []struct { 21 | name string 22 | setupClient func(grpc.ClientConnInterface) grpc.ClientConnInterface 23 | setupServer func(grpc.ServiceRegistrar) grpc.ServiceRegistrar 24 | unaryIntercepted bool 25 | streamIntercepted bool 26 | }{ 27 | { 28 | name: "batch", 29 | setupClient: func(conn grpc.ClientConnInterface) grpc.ClientConnInterface { 30 | return 
setupClientChainBatch(conn) 31 | }, 32 | setupServer: func(reg grpc.ServiceRegistrar) grpc.ServiceRegistrar { 33 | return setupServerChainBatch(reg) 34 | }, 35 | unaryIntercepted: true, 36 | streamIntercepted: true, 37 | }, 38 | { 39 | name: "singles", 40 | setupClient: func(conn grpc.ClientConnInterface) grpc.ClientConnInterface { 41 | return setupClientChainSingles(conn) 42 | }, 43 | setupServer: func(reg grpc.ServiceRegistrar) grpc.ServiceRegistrar { 44 | return setupServerChainSingles(reg) 45 | }, 46 | unaryIntercepted: true, 47 | streamIntercepted: true, 48 | }, 49 | { 50 | name: "pairs", 51 | setupClient: func(conn grpc.ClientConnInterface) grpc.ClientConnInterface { 52 | return setupClientChainPairs(conn) 53 | }, 54 | setupServer: func(reg grpc.ServiceRegistrar) grpc.ServiceRegistrar { 55 | return setupServerChainPairs(reg) 56 | }, 57 | unaryIntercepted: true, 58 | streamIntercepted: true, 59 | }, 60 | { 61 | name: "unary-only", 62 | setupClient: func(conn grpc.ClientConnInterface) grpc.ClientConnInterface { 63 | return setupClientChainUnaryOnly(conn) 64 | }, 65 | setupServer: func(reg grpc.ServiceRegistrar) grpc.ServiceRegistrar { 66 | return setupServerChainUnaryOnly(reg) 67 | }, 68 | unaryIntercepted: true, 69 | streamIntercepted: false, 70 | }, 71 | { 72 | name: "stream-only", 73 | setupClient: func(conn grpc.ClientConnInterface) grpc.ClientConnInterface { 74 | return setupClientChainStreamOnly(conn) 75 | }, 76 | setupServer: func(reg grpc.ServiceRegistrar) grpc.ServiceRegistrar { 77 | return setupServerChainStreamOnly(reg) 78 | }, 79 | unaryIntercepted: false, 80 | streamIntercepted: true, 81 | }, 82 | { 83 | name: "none", 84 | setupClient: func(conn grpc.ClientConnInterface) grpc.ClientConnInterface { 85 | return conn 86 | }, 87 | setupServer: func(reg grpc.ServiceRegistrar) grpc.ServiceRegistrar { 88 | return reg 89 | }, 90 | unaryIntercepted: false, 91 | streamIntercepted: false, 92 | }, 93 | } 94 | 95 | func TestInterceptorChainClient_Unary(t 
*testing.T) { 96 | for _, testCase := range interceptorChainCases { 97 | t.Run(testCase.name, func(t *testing.T) { 98 | ctx := context.Background() 99 | tc := testChainConn{t: t} 100 | intercepted := testCase.setupClient(&tc) 101 | var req, expectReply string 102 | var expectHeaders, expectTrailers metadata.MD 103 | if testCase.unaryIntercepted { 104 | req = "req" 105 | ctx = metadata.AppendToOutgoingContext(ctx, "header", "value") 106 | expectHeaders = metadata.Pairs("header", "value", "header-A", "value-A", "header-B", "value-B", "header-C", "value-C") 107 | expectTrailers = metadata.Pairs("trailer", "value", "trailer-A", "value-A", "trailer-B", "value-B", "trailer-C", "value-C") 108 | expectReply = "reply,C,B,A" 109 | } else { 110 | // Need to add stuff to the request and headers, etc that would otherwise be added by 111 | // interceptors for the test channel to be happy. 112 | req = "req,A,B,C" 113 | ctx = metadata.AppendToOutgoingContext(ctx, "header", "value", "header-A", "value-A", "header-B", "value-B", "header-C", "value-C") 114 | // And we don't expect anything to be added to the response stuff. 
115 | expectHeaders = metadata.Pairs("header", "value") 116 | expectTrailers = metadata.Pairs("trailer", "value") 117 | expectReply = "reply" 118 | } 119 | var reply string 120 | var headers, trailers metadata.MD 121 | if err := intercepted.Invoke(ctx, "/foo/bar", req, &reply, grpc.Header(&headers), grpc.Trailer(&trailers)); err != nil { 122 | t.Fatalf("unexpected RPC error: %v", err) 123 | } else if reply != expectReply { 124 | t.Errorf("unexpected reply: %s", reply) 125 | } 126 | if !reflect.DeepEqual(expectHeaders, headers) { 127 | t.Errorf("unexpected headers: %s", headers) 128 | } 129 | if !reflect.DeepEqual(expectTrailers, trailers) { 130 | t.Errorf("unexpected trailers: %s", trailers) 131 | } 132 | }) 133 | } 134 | } 135 | 136 | func TestInterceptorChainClient_Stream(t *testing.T) { 137 | for _, testCase := range interceptorChainCases { 138 | t.Run(testCase.name, func(t *testing.T) { 139 | ctx := context.Background() 140 | tc := testChainConn{t: t} 141 | intercepted := testCase.setupClient(&tc) 142 | var reqSuffix, expectReplySuffix string 143 | var expectHeaders, expectTrailers metadata.MD 144 | if testCase.streamIntercepted { 145 | reqSuffix = "" 146 | ctx = metadata.AppendToOutgoingContext(ctx, "header", "value") 147 | expectHeaders = metadata.Pairs("header", "value", "header-A", "value-A", "header-B", "value-B", "header-C", "value-C") 148 | expectTrailers = metadata.Pairs("trailer", "value", "trailer-A", "value-A", "trailer-B", "value-B", "trailer-C", "value-C") 149 | expectReplySuffix = ",C,B,A" 150 | } else { 151 | // Need to add stuff to the request and headers, etc that would otherwise be added by 152 | // interceptors for the test channel to be happy. 153 | reqSuffix = ",A,B,C" 154 | ctx = metadata.AppendToOutgoingContext(ctx, "header", "value", "header-A", "value-A", "header-B", "value-B", "header-C", "value-C") 155 | // And we don't expect anything to be added to the response stuff. 
156 | expectHeaders = metadata.Pairs("header", "value") 157 | expectTrailers = metadata.Pairs("trailer", "value") 158 | expectReplySuffix = "" 159 | } 160 | 161 | desc := &grpc.StreamDesc{StreamName: "bar", ClientStreams: true, ServerStreams: true} 162 | stream, err := intercepted.NewStream(ctx, desc, "/foo/bar") 163 | if err != nil { 164 | t.Errorf("unexpected RPC error: %v", err) 165 | } 166 | if err := stream.SendMsg("req1" + reqSuffix); err != nil { 167 | t.Fatalf("unexpected RPC error: %v", err) 168 | } 169 | if err := stream.SendMsg("req2" + reqSuffix); err != nil { 170 | t.Fatalf("unexpected RPC error: %v", err) 171 | } 172 | if err := stream.SendMsg("req3" + reqSuffix); err != nil { 173 | t.Fatalf("unexpected RPC error: %v", err) 174 | } 175 | if err := stream.CloseSend(); err != nil { 176 | t.Fatalf("unexpected RPC error: %v", err) 177 | } 178 | 179 | var reply string 180 | if err := stream.RecvMsg(&reply); err != nil { 181 | t.Fatalf("unexpected RPC error: %v", err) 182 | } 183 | if reply != "reply3"+expectReplySuffix { 184 | t.Errorf("unexpected reply: %s", reply) 185 | } 186 | if err := stream.RecvMsg(&reply); err != nil { 187 | t.Fatalf("unexpected RPC error: %v", err) 188 | } 189 | if reply != "reply2"+expectReplySuffix { 190 | t.Errorf("unexpected reply: %s", reply) 191 | } 192 | if err := stream.RecvMsg(&reply); err != nil { 193 | t.Fatalf("unexpected RPC error: %v", err) 194 | } 195 | if reply != "reply1"+expectReplySuffix { 196 | t.Errorf("unexpected reply: %s", reply) 197 | } 198 | if err := stream.RecvMsg(&reply); err == nil { 199 | t.Error("expecting io.EOF but got no error") 200 | } else if !errors.Is(err, io.EOF) { 201 | t.Errorf("expecting io.EOF but got %v", err) 202 | } 203 | 204 | headers, err := stream.Header() 205 | if err != nil { 206 | t.Errorf("unexpected RPC error: %v", err) 207 | } 208 | if !reflect.DeepEqual(expectHeaders, headers) { 209 | t.Errorf("unexpected headers: %s", headers) 210 | } 211 | trailers := stream.Trailer() 212 | 
if !reflect.DeepEqual(expectTrailers, trailers) { 213 | t.Errorf("unexpected trailers: %s", trailers) 214 | } 215 | }) 216 | } 217 | } 218 | 219 | func setupClientChainBatch(clientConn grpc.ClientConnInterface) grpc.ClientConnInterface { 220 | int1 := chainClientInterceptor{id: "A"} 221 | int2 := chainClientInterceptor{id: "B"} 222 | int3 := chainClientInterceptor{id: "C"} 223 | return InterceptClientConnUnary( 224 | InterceptClientConnStream( 225 | clientConn, 226 | int1.doStream, int2.doStream, int3.doStream, 227 | ), 228 | int1.doUnary, int2.doUnary, int3.doUnary, 229 | ) 230 | } 231 | 232 | func setupClientChainSingles(clientConn grpc.ClientConnInterface) grpc.ClientConnInterface { 233 | int1 := chainClientInterceptor{id: "A"} 234 | int2 := chainClientInterceptor{id: "B"} 235 | int3 := chainClientInterceptor{id: "C"} 236 | return InterceptClientConnStream( 237 | InterceptClientConnUnary( 238 | InterceptClientConnStream( 239 | InterceptClientConnUnary( 240 | InterceptClientConnStream( 241 | InterceptClientConnUnary( 242 | clientConn, 243 | int3.doUnary, 244 | ), 245 | int3.doStream, 246 | ), 247 | int2.doUnary, 248 | ), 249 | int2.doStream, 250 | ), 251 | int1.doUnary, 252 | ), 253 | int1.doStream, 254 | ) 255 | } 256 | 257 | func setupClientChainPairs(clientConn grpc.ClientConnInterface) grpc.ClientConnInterface { 258 | int1 := chainClientInterceptor{id: "A"} 259 | int2 := chainClientInterceptor{id: "B"} 260 | int3 := chainClientInterceptor{id: "C"} 261 | return InterceptClientConn( 262 | InterceptClientConn( 263 | InterceptClientConn( 264 | clientConn, 265 | int3.doUnary, int3.doStream, 266 | ), 267 | int2.doUnary, int2.doStream, 268 | ), 269 | int1.doUnary, int1.doStream, 270 | ) 271 | } 272 | 273 | func setupClientChainUnaryOnly(clientConn grpc.ClientConnInterface) grpc.ClientConnInterface { 274 | int1 := chainClientInterceptor{id: "A"} 275 | int2 := chainClientInterceptor{id: "B"} 276 | int3 := chainClientInterceptor{id: "C"} 277 | return 
InterceptClientConnUnary( 278 | InterceptClientConnUnary( 279 | InterceptClientConnUnary( 280 | clientConn, 281 | int3.doUnary, 282 | ), 283 | int2.doUnary, 284 | ), 285 | int1.doUnary, 286 | ) 287 | } 288 | 289 | func setupClientChainStreamOnly(clientConn grpc.ClientConnInterface) grpc.ClientConnInterface { 290 | int1 := chainClientInterceptor{id: "A"} 291 | int2 := chainClientInterceptor{id: "B"} 292 | int3 := chainClientInterceptor{id: "C"} 293 | return InterceptClientConnStream( 294 | InterceptClientConnStream( 295 | InterceptClientConnStream( 296 | clientConn, 297 | int3.doStream, 298 | ), 299 | int2.doStream, 300 | ), 301 | int1.doStream, 302 | ) 303 | } 304 | 305 | type chainClientInterceptor struct { 306 | id string 307 | } 308 | 309 | func (c *chainClientInterceptor) doUnary(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 310 | ctx = metadata.AppendToOutgoingContext(ctx, "header-"+c.id, "value-"+c.id) 311 | str, ok := req.(string) 312 | if !ok { 313 | return status.Errorf(codes.Internal, "unexpected request type: %T", req) 314 | } 315 | str += "," + c.id 316 | strPtr, ok := reply.(*string) 317 | if !ok { 318 | return status.Errorf(codes.Internal, "unexpected response type: %T", reply) 319 | } 320 | var headers, trailers metadata.MD 321 | opts = append(opts, grpc.Header(&headers), grpc.Trailer(&trailers)) 322 | if err := invoker(ctx, method, str, strPtr, cc, opts...); err != nil { 323 | return err 324 | } 325 | if headers == nil { 326 | return errors.New("response headers are nil") 327 | } 328 | headers.Append("header-"+c.id, "value-"+c.id) 329 | if trailers == nil { 330 | return errors.New("response trailers are nil") 331 | } 332 | trailers.Append("trailer-"+c.id, "value-"+c.id) 333 | *strPtr += "," + c.id 334 | return nil 335 | } 336 | 337 | func (c *chainClientInterceptor) doStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer 
grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { 338 | ctx = metadata.AppendToOutgoingContext(ctx, "header-"+c.id, "value-"+c.id) 339 | stream, err := streamer(ctx, desc, cc, method, opts...) 340 | if err != nil { 341 | return nil, err 342 | } 343 | return &chainClientStream{ClientStream: stream, id: c.id}, nil 344 | } 345 | 346 | type chainClientStream struct { 347 | grpc.ClientStream 348 | id string 349 | } 350 | 351 | func (c *chainClientStream) Header() (metadata.MD, error) { 352 | md, err := c.ClientStream.Header() 353 | if err != nil { 354 | return nil, err 355 | } 356 | if md == nil { 357 | return nil, errors.New("response header is nil") 358 | } 359 | md.Append("header-"+c.id, "value-"+c.id) 360 | return md, nil 361 | } 362 | 363 | func (c *chainClientStream) Trailer() metadata.MD { 364 | md := c.ClientStream.Trailer() 365 | if md != nil { 366 | md.Append("trailer-"+c.id, "value-"+c.id) 367 | } 368 | return md 369 | } 370 | 371 | func (c *chainClientStream) SendMsg(msg any) error { 372 | str, ok := msg.(string) 373 | if !ok { 374 | return status.Errorf(codes.Internal, "unexpected message type: %T", msg) 375 | } 376 | str += "," + c.id 377 | return c.ClientStream.SendMsg(str) 378 | } 379 | 380 | func (c *chainClientStream) RecvMsg(msg any) error { 381 | str, ok := msg.(*string) 382 | if !ok { 383 | return status.Errorf(codes.Internal, "unexpected message type: %T", msg) 384 | } 385 | if err := c.ClientStream.RecvMsg(str); err != nil { 386 | return err 387 | } 388 | *str += "," + c.id 389 | return nil 390 | } 391 | 392 | type testChainConn struct { 393 | t *testing.T 394 | } 395 | 396 | func (t *testChainConn) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error { 397 | if method != "/foo/bar" { 398 | return fmt.Errorf("unexpected method: %s", method) 399 | } 400 | str, ok := args.(string) 401 | if !ok { 402 | return fmt.Errorf("invalid request: %T", args) 403 | } 404 | if str != "req,A,B,C" { 405 
| t.t.Errorf("unexpected request: %s", str) 406 | } 407 | headers, _ := metadata.FromOutgoingContext(ctx) 408 | expectHeaders := metadata.Pairs("header", "value", "header-A", "value-A", "header-B", "value-B", "header-C", "value-C") 409 | if !reflect.DeepEqual(expectHeaders, headers) { 410 | t.t.Errorf("unexpected headers: %s", headers) 411 | } 412 | strPtr, ok := reply.(*string) 413 | if !ok { 414 | return fmt.Errorf("invalid response: %T", args) 415 | } 416 | *strPtr = "reply" 417 | callOpts := internal.GetCallOptions(opts) 418 | callOpts.SetHeaders(metadata.Pairs("header", "value")) 419 | callOpts.SetTrailers(metadata.Pairs("trailer", "value")) 420 | return nil 421 | } 422 | 423 | func (t *testChainConn) NewStream(ctx context.Context, _ *grpc.StreamDesc, method string, _ ...grpc.CallOption) (grpc.ClientStream, error) { 424 | if method != "/foo/bar" { 425 | return nil, fmt.Errorf("unexpected method: %s", method) 426 | } 427 | headers, _ := metadata.FromOutgoingContext(ctx) 428 | expectHeaders := metadata.Pairs("header", "value", "header-A", "value-A", "header-B", "value-B", "header-C", "value-C") 429 | if !reflect.DeepEqual(expectHeaders, headers) { 430 | t.t.Errorf("unexpected headers: %s", headers) 431 | } 432 | return &testChainStreamClient{t: t.t, ctx: ctx}, nil 433 | } 434 | 435 | type testChainStreamClient struct { 436 | t *testing.T 437 | ctx context.Context 438 | count int 439 | done bool 440 | } 441 | 442 | func (t *testChainStreamClient) Header() (metadata.MD, error) { 443 | return metadata.Pairs("header", "value"), nil 444 | } 445 | 446 | func (t *testChainStreamClient) Trailer() metadata.MD { 447 | return metadata.Pairs("trailer", "value") 448 | } 449 | 450 | func (t *testChainStreamClient) CloseSend() error { 451 | t.done = true 452 | return nil 453 | } 454 | 455 | func (t *testChainStreamClient) Context() context.Context { 456 | return t.ctx 457 | } 458 | 459 | func (t *testChainStreamClient) SendMsg(m any) error { 460 | if t.done { 461 | return 
io.EOF 462 | } 463 | str, ok := m.(string) 464 | if !ok { 465 | return fmt.Errorf("invalid message: %T", m) 466 | } 467 | t.count++ 468 | if str != "req"+strconv.Itoa(t.count)+",A,B,C" { 469 | t.t.Errorf("unexpected request: %s", str) 470 | } 471 | return nil 472 | 473 | } 474 | 475 | func (t *testChainStreamClient) RecvMsg(m any) error { 476 | if t.count == 0 { 477 | t.done = true 478 | return io.EOF 479 | } 480 | str, ok := m.(*string) 481 | if !ok { 482 | return fmt.Errorf("invalid message: %T", m) 483 | } 484 | *str = "reply" + strconv.Itoa(t.count) 485 | t.count-- 486 | return nil 487 | } 488 | -------------------------------------------------------------------------------- /interceptor_chain_server_test.go: -------------------------------------------------------------------------------- 1 | package grpchan 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "reflect" 9 | "strconv" 10 | "testing" 11 | 12 | "google.golang.org/grpc" 13 | "google.golang.org/grpc/codes" 14 | "google.golang.org/grpc/metadata" 15 | "google.golang.org/grpc/status" 16 | 17 | "github.com/fullstorydev/grpchan/internal" 18 | ) 19 | 20 | func TestInterceptorChainServer_Unary(t *testing.T) { 21 | for _, testCase := range interceptorChainCases { 22 | t.Run(testCase.name, func(t *testing.T) { 23 | ctx := context.Background() 24 | handlers := HandlerMap{} 25 | intercepted := testCase.setupServer(handlers) 26 | th := testChainHandler{t: t} 27 | intercepted.RegisterService(th.serviceDesc(), 0) 28 | var req, expectReply string 29 | var expectHeaders, expectTrailers metadata.MD 30 | if testCase.unaryIntercepted { 31 | req = "req" 32 | ctx = metadata.NewIncomingContext(ctx, metadata.Pairs("header", "value")) 33 | expectHeaders = metadata.Pairs("header", "value", "header-A", "value-A", "header-B", "value-B", "header-C", "value-C") 34 | expectTrailers = metadata.Pairs("trailer", "value", "trailer-A", "value-A", "trailer-B", "value-B", "trailer-C", "value-C") 35 | expectReply = 
"reply,C,B,A" 36 | } else { 37 | // Need to add stuff to the request and headers, etc that would otherwise be added by 38 | // interceptors for the test channel to be happy. 39 | req = "req,A,B,C" 40 | ctx = metadata.NewIncomingContext(ctx, metadata.Pairs("header", "value", "header-A", "value-A", "header-B", "value-B", "header-C", "value-C")) 41 | // And we don't expect anything to be added to the response stuff. 42 | expectHeaders = metadata.Pairs("header", "value") 43 | expectTrailers = metadata.Pairs("trailer", "value") 44 | expectReply = "reply" 45 | } 46 | var reply string 47 | dec := func(reqPtr any) error { 48 | str, ok := reqPtr.(*string) 49 | if !ok { 50 | t.Errorf("invalid request type: %T", reqPtr) 51 | } 52 | *str = req 53 | return nil 54 | } 55 | var stream testChainStreamServer 56 | sts := &internal.ServerTransportStream{Name: "/foo/bar", Stream: &stream} 57 | sd, srv := handlers.QueryService("foo") 58 | resp, err := sd.Methods[0].Handler(srv, grpc.NewContextWithServerTransportStream(ctx, sts), dec, nil) 59 | if err != nil { 60 | t.Fatalf("unexpected RPC error: %v", err) 61 | } 62 | reply, ok := resp.(string) 63 | if !ok { 64 | t.Errorf("invalid reply type: %T", resp) 65 | } else if reply != expectReply { 66 | t.Errorf("unexpected reply: %s", reply) 67 | } 68 | if !reflect.DeepEqual(expectHeaders, stream.headers) { 69 | t.Errorf("unexpected headers: %s", stream.headers) 70 | } 71 | if !reflect.DeepEqual(expectTrailers, stream.trailers) { 72 | t.Errorf("unexpected trailers: %s", stream.trailers) 73 | } 74 | }) 75 | } 76 | } 77 | 78 | func TestInterceptorChainServer_Stream(t *testing.T) { 79 | for _, testCase := range interceptorChainCases { 80 | t.Run(testCase.name, func(t *testing.T) { 81 | ctx := context.Background() 82 | handlers := HandlerMap{} 83 | intercepted := testCase.setupServer(handlers) 84 | th := testChainHandler{t: t} 85 | intercepted.RegisterService(th.serviceDesc(), 0) 86 | var reqSuffix, expectReplySuffix string 87 | var expectHeaders, 
expectTrailers metadata.MD 88 | if testCase.streamIntercepted { 89 | reqSuffix = "" 90 | ctx = metadata.NewIncomingContext(ctx, metadata.Pairs("header", "value")) 91 | expectHeaders = metadata.Pairs("header", "value", "header-A", "value-A", "header-B", "value-B", "header-C", "value-C") 92 | expectTrailers = metadata.Pairs("trailer", "value", "trailer-A", "value-A", "trailer-B", "value-B", "trailer-C", "value-C") 93 | expectReplySuffix = ",C,B,A" 94 | } else { 95 | // Need to add stuff to the request and headers, etc that would otherwise be added by 96 | // interceptors for the test channel to be happy. 97 | reqSuffix = ",A,B,C" 98 | ctx = metadata.NewIncomingContext(ctx, metadata.Pairs("header", "value", "header-A", "value-A", "header-B", "value-B", "header-C", "value-C")) 99 | // And we don't expect anything to be added to the response stuff. 100 | expectHeaders = metadata.Pairs("header", "value") 101 | expectTrailers = metadata.Pairs("trailer", "value") 102 | expectReplySuffix = "" 103 | } 104 | sd, srv := handlers.QueryService("foo") 105 | stream := &testChainStreamServer{ 106 | t: t, 107 | ctx: ctx, 108 | reqSuffix: reqSuffix, 109 | } 110 | err := sd.Streams[0].Handler(srv, stream) 111 | if err != nil { 112 | t.Fatalf("unexpected RPC error: %v", err) 113 | } 114 | expectedSent := []string{ 115 | "reply3" + expectReplySuffix, 116 | "reply2" + expectReplySuffix, 117 | "reply1" + expectReplySuffix, 118 | } 119 | if !reflect.DeepEqual(expectedSent, stream.sent) { 120 | t.Errorf("unexpected sent: %s", stream.sent) 121 | } 122 | if !reflect.DeepEqual(expectHeaders, stream.headers) { 123 | t.Errorf("unexpected headers: %s", stream.headers) 124 | } 125 | if !reflect.DeepEqual(expectTrailers, stream.trailers) { 126 | t.Errorf("unexpected trailers: %s", stream.trailers) 127 | } 128 | }) 129 | } 130 | } 131 | 132 | func setupServerChainBatch(reg grpc.ServiceRegistrar) grpc.ServiceRegistrar { 133 | int1 := chainServerInterceptor{id: "A"} 134 | int2 := 
chainServerInterceptor{id: "B"} 135 | int3 := chainServerInterceptor{id: "C"} 136 | return WithUnaryInterceptors( 137 | WithStreamInterceptors( 138 | reg, 139 | int1.doStream, int2.doStream, int3.doStream, 140 | ), 141 | int1.doUnary, int2.doUnary, int3.doUnary, 142 | ) 143 | } 144 | 145 | func setupServerChainSingles(reg grpc.ServiceRegistrar) grpc.ServiceRegistrar { 146 | int1 := chainServerInterceptor{id: "A"} 147 | int2 := chainServerInterceptor{id: "B"} 148 | int3 := chainServerInterceptor{id: "C"} 149 | return WithStreamInterceptors( 150 | WithUnaryInterceptors( 151 | WithStreamInterceptors( 152 | WithUnaryInterceptors( 153 | WithStreamInterceptors( 154 | WithUnaryInterceptors( 155 | reg, 156 | int3.doUnary, 157 | ), 158 | int3.doStream, 159 | ), 160 | int2.doUnary, 161 | ), 162 | int2.doStream, 163 | ), 164 | int1.doUnary, 165 | ), 166 | int1.doStream, 167 | ) 168 | } 169 | 170 | func setupServerChainPairs(reg grpc.ServiceRegistrar) grpc.ServiceRegistrar { 171 | int1 := chainServerInterceptor{id: "A"} 172 | int2 := chainServerInterceptor{id: "B"} 173 | int3 := chainServerInterceptor{id: "C"} 174 | return WithInterceptor( 175 | WithInterceptor( 176 | WithInterceptor( 177 | reg, 178 | int3.doUnary, int3.doStream, 179 | ), 180 | int2.doUnary, int2.doStream, 181 | ), 182 | int1.doUnary, int1.doStream, 183 | ) 184 | } 185 | 186 | func setupServerChainUnaryOnly(reg grpc.ServiceRegistrar) grpc.ServiceRegistrar { 187 | int1 := chainServerInterceptor{id: "A"} 188 | int2 := chainServerInterceptor{id: "B"} 189 | int3 := chainServerInterceptor{id: "C"} 190 | return WithUnaryInterceptors( 191 | WithUnaryInterceptors( 192 | WithUnaryInterceptors( 193 | reg, 194 | int3.doUnary, 195 | ), 196 | int2.doUnary, 197 | ), 198 | int1.doUnary, 199 | ) 200 | } 201 | 202 | func setupServerChainStreamOnly(reg grpc.ServiceRegistrar) grpc.ServiceRegistrar { 203 | int1 := chainServerInterceptor{id: "A"} 204 | int2 := chainServerInterceptor{id: "B"} 205 | int3 := 
chainServerInterceptor{id: "C"} 206 | return WithStreamInterceptors( 207 | WithStreamInterceptors( 208 | WithStreamInterceptors( 209 | reg, 210 | int3.doStream, 211 | ), 212 | int2.doStream, 213 | ), 214 | int1.doStream, 215 | ) 216 | } 217 | 218 | type chainServerInterceptor struct { 219 | id string 220 | } 221 | 222 | func (c *chainServerInterceptor) doUnary(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { 223 | md, ok := metadata.FromIncomingContext(ctx) 224 | if ok { 225 | md.Append("header-"+c.id, "value-"+c.id) 226 | ctx = metadata.NewIncomingContext(ctx, md) 227 | } 228 | str, ok := req.(string) 229 | if !ok { 230 | return nil, status.Errorf(codes.Internal, "unexpected request type: %T", req) 231 | } 232 | str += "," + c.id 233 | if err := grpc.SetHeader(ctx, metadata.Pairs("header-"+c.id, "value-"+c.id)); err != nil { 234 | return nil, err 235 | } 236 | if grpc.SetTrailer(ctx, metadata.Pairs("trailer-"+c.id, "value-"+c.id)) != nil { 237 | return nil, err 238 | } 239 | reply, err := handler(ctx, str) 240 | if err != nil { 241 | return nil, err 242 | } 243 | str, ok = reply.(string) 244 | if !ok { 245 | return nil, fmt.Errorf("unexpected response type: %T", reply) 246 | } 247 | str += "," + c.id 248 | return str, nil 249 | } 250 | 251 | func (c *chainServerInterceptor) doStream(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 252 | ctx := ss.Context() 253 | md, ok := metadata.FromIncomingContext(ss.Context()) 254 | if ok { 255 | md.Append("header-"+c.id, "value-"+c.id) 256 | ctx = metadata.NewIncomingContext(ctx, md) 257 | } 258 | ss = &chainServerStream{ServerStream: ss, ctx: ctx, id: c.id} 259 | if err := ss.SetHeader(metadata.Pairs("header-"+c.id, "value-"+c.id)); err != nil { 260 | return err 261 | } 262 | if err := handler(srv, ss); err != nil { 263 | return err 264 | } 265 | ss.SetTrailer(metadata.Pairs("trailer-"+c.id, "value-"+c.id)) 266 | return 
nil 267 | } 268 | 269 | type chainServerStream struct { 270 | grpc.ServerStream 271 | ctx context.Context 272 | id string 273 | } 274 | 275 | func (c *chainServerStream) Context() context.Context { 276 | return c.ctx 277 | } 278 | 279 | func (c *chainServerStream) SendMsg(msg any) error { 280 | str, ok := msg.(string) 281 | if !ok { 282 | return status.Errorf(codes.Internal, "unexpected message type: %T", msg) 283 | } 284 | str += "," + c.id 285 | return c.ServerStream.SendMsg(str) 286 | } 287 | 288 | func (c *chainServerStream) RecvMsg(msg any) error { 289 | str, ok := msg.(*string) 290 | if !ok { 291 | return status.Errorf(codes.Internal, "unexpected message type: %T", msg) 292 | } 293 | if err := c.ServerStream.RecvMsg(str); err != nil { 294 | return err 295 | } 296 | *str += "," + c.id 297 | return nil 298 | } 299 | 300 | type testChainHandler struct { 301 | t *testing.T 302 | } 303 | 304 | func (t *testChainHandler) handleUnary(ctx context.Context, req any) (any, error) { 305 | str, ok := req.(string) 306 | if !ok { 307 | return nil, fmt.Errorf("invalid request: %T", req) 308 | } 309 | if str != "req,A,B,C" { 310 | t.t.Errorf("unexpected request: %s", str) 311 | } 312 | headers, _ := metadata.FromIncomingContext(ctx) 313 | expectHeaders := metadata.Pairs("header", "value", "header-A", "value-A", "header-B", "value-B", "header-C", "value-C") 314 | if !reflect.DeepEqual(expectHeaders, headers) { 315 | t.t.Errorf("unexpected headers: %s", headers) 316 | } 317 | if err := grpc.SetHeader(ctx, metadata.Pairs("header", "value")); err != nil { 318 | return nil, err 319 | } 320 | if err := grpc.SetTrailer(ctx, metadata.Pairs("trailer", "value")); err != nil { 321 | return nil, err 322 | } 323 | return "reply", nil 324 | } 325 | 326 | func (t *testChainHandler) handleStream(_ any, stream grpc.ServerStream) error { 327 | headers, _ := metadata.FromIncomingContext(stream.Context()) 328 | expectHeaders := metadata.Pairs("header", "value", "header-A", "value-A", "header-B", 
"value-B", "header-C", "value-C") 329 | if !reflect.DeepEqual(expectHeaders, headers) { 330 | t.t.Errorf("unexpected headers: %s", headers) 331 | } 332 | var count int 333 | for { 334 | var req string 335 | if err := stream.RecvMsg(&req); err != nil { 336 | if errors.Is(err, io.EOF) { 337 | break 338 | } 339 | return err 340 | } 341 | count++ 342 | if req != "req"+strconv.Itoa(count)+",A,B,C" { 343 | t.t.Errorf("unexpected request: %s", req) 344 | } 345 | } 346 | if err := stream.SetHeader(metadata.Pairs("header", "value")); err != nil { 347 | return err 348 | } 349 | for i := range count { 350 | if err := stream.SendMsg("reply" + strconv.Itoa(count-i)); err != nil { 351 | return err 352 | } 353 | } 354 | stream.SetTrailer(metadata.Pairs("trailer", "value")) 355 | return nil 356 | } 357 | 358 | func (t *testChainHandler) serviceDesc() *grpc.ServiceDesc { 359 | unaryHandler := t.handleUnary 360 | streamHandler := t.handleStream 361 | return &grpc.ServiceDesc{ 362 | ServiceName: "foo", 363 | HandlerType: (*any)(nil), 364 | Methods: []grpc.MethodDesc{ 365 | { 366 | MethodName: "bar", 367 | Handler: func(srv any, ctx context.Context, dec func(any) error, interceptor grpc.UnaryServerInterceptor) (any, error) { 368 | var req string 369 | if err := dec(&req); err != nil { 370 | return nil, err 371 | } 372 | if interceptor != nil { 373 | return interceptor(ctx, req, &grpc.UnaryServerInfo{Server: srv, FullMethod: "/foo/bar"}, unaryHandler) 374 | } 375 | return unaryHandler(ctx, req) 376 | }, 377 | }, 378 | }, 379 | Streams: []grpc.StreamDesc{ 380 | { 381 | StreamName: "bar", 382 | Handler: streamHandler, 383 | }, 384 | }, 385 | } 386 | } 387 | 388 | type testChainStreamServer struct { 389 | t *testing.T 390 | ctx context.Context 391 | reqSuffix string 392 | done bool 393 | count int 394 | 395 | headers, trailers metadata.MD 396 | sent []string 397 | } 398 | 399 | func (t *testChainStreamServer) SetHeader(md metadata.MD) error { 400 | if t.done { 401 | return io.EOF 402 | } 
403 | if t.headers == nil { 404 | t.headers = metadata.MD{} 405 | } 406 | for k, v := range md { 407 | t.headers[k] = append(t.headers[k], v...) 408 | } 409 | return nil 410 | } 411 | 412 | func (t *testChainStreamServer) SendHeader(md metadata.MD) error { 413 | if err := t.SetHeader(md); err != nil { 414 | return err 415 | } 416 | t.done = true 417 | return nil 418 | } 419 | 420 | func (t *testChainStreamServer) SetTrailer(md metadata.MD) { 421 | if t.trailers == nil { 422 | t.trailers = metadata.MD{} 423 | } 424 | for k, v := range md { 425 | t.trailers[k] = append(t.trailers[k], v...) 426 | } 427 | } 428 | 429 | func (t *testChainStreamServer) Context() context.Context { 430 | return t.ctx 431 | } 432 | 433 | func (t *testChainStreamServer) SendMsg(m any) error { 434 | str, ok := m.(string) 435 | if !ok { 436 | return fmt.Errorf("invalid message: %T", m) 437 | } 438 | t.sent = append(t.sent, str) 439 | return nil 440 | } 441 | 442 | func (t *testChainStreamServer) RecvMsg(m any) error { 443 | if t.count == 3 { 444 | return io.EOF 445 | } 446 | str, ok := m.(*string) 447 | if !ok { 448 | return fmt.Errorf("invalid message: %T", m) 449 | } 450 | t.count++ 451 | *str = "req" + strconv.Itoa(t.count) + t.reqSuffix 452 | return nil 453 | } 454 | -------------------------------------------------------------------------------- /internal/call_options.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "google.golang.org/grpc" 8 | "google.golang.org/grpc/credentials" 9 | "google.golang.org/grpc/metadata" 10 | "google.golang.org/grpc/peer" 11 | ) 12 | 13 | // CallOptions represents the state of in-effect grpc.CallOptions. 14 | type CallOptions struct { 15 | // Headers is a slice of metadata pointers which should all be set when 16 | // response header metadata is received. 
17 | Headers []*metadata.MD 18 | // Trailers is a slice of metadata pointers which should all be set when 19 | // response trailer metadata is received. 20 | Trailers []*metadata.MD 21 | // Peer is a slice of peer pointers which should all be set when the 22 | // remote peer is known. 23 | Peer []*peer.Peer 24 | // Creds are per-RPC credentials to use for a call. 25 | Creds credentials.PerRPCCredentials 26 | // MaxRecv is the maximum number of bytes to receive for a single message 27 | // in a call. 28 | MaxRecv int 29 | // MaxSend is the maximum number of bytes to send for a single message in 30 | // a call. 31 | MaxSend int 32 | } 33 | 34 | // SetHeaders sets all accumulated header addresses to the given metadata. This 35 | // satisfies grpc.Header call options. 36 | func (co *CallOptions) SetHeaders(md metadata.MD) { 37 | for _, hdr := range co.Headers { 38 | *hdr = md 39 | } 40 | } 41 | 42 | // SetTrailers sets all accumulated trailer addresses to the given metadata. 43 | // This satisfies grpc.Trailer call options. 44 | func (co *CallOptions) SetTrailers(md metadata.MD) { 45 | for _, tlr := range co.Trailers { 46 | *tlr = md 47 | } 48 | } 49 | 50 | // SetPeer sets all accumulated peer addresses to the given peer. This satisfies 51 | // grpc.Peer call options. 52 | func (co *CallOptions) SetPeer(p *peer.Peer) { 53 | for _, pr := range co.Peer { 54 | *pr = *p 55 | } 56 | } 57 | 58 | // GetCallOptions converts the given slice of grpc.CallOptions into a 59 | // CallOptions struct. 
60 | func GetCallOptions(opts []grpc.CallOption) *CallOptions { 61 | var copts CallOptions 62 | for _, o := range opts { 63 | switch o := o.(type) { 64 | case grpc.HeaderCallOption: 65 | copts.Headers = append(copts.Headers, o.HeaderAddr) 66 | case grpc.TrailerCallOption: 67 | copts.Trailers = append(copts.Trailers, o.TrailerAddr) 68 | case grpc.PeerCallOption: 69 | copts.Peer = append(copts.Peer, o.PeerAddr) 70 | case grpc.PerRPCCredsCallOption: 71 | copts.Creds = o.Creds 72 | case grpc.MaxRecvMsgSizeCallOption: 73 | copts.MaxRecv = o.MaxRecvMsgSize 74 | case grpc.MaxSendMsgSizeCallOption: 75 | copts.MaxSend = o.MaxSendMsgSize 76 | } 77 | } 78 | return &copts 79 | } 80 | 81 | // ApplyPerRPCCreds applies any per-RPC credentials in the given call options and 82 | // returns a new context with the additional metadata. It will return an error if 83 | // isChannelSecure is false but the per-RPC credentials require a secure channel. 84 | func ApplyPerRPCCreds(ctx context.Context, copts *CallOptions, uri string, isChannelSecure bool) (context.Context, error) { 85 | if copts.Creds != nil { 86 | if copts.Creds.RequireTransportSecurity() && !isChannelSecure { 87 | return nil, fmt.Errorf("transport security is required") 88 | } 89 | md, err := copts.Creds.GetRequestMetadata(ctx, uri) 90 | if err != nil { 91 | return nil, err 92 | } 93 | if len(md) > 0 { 94 | reqHeaders, ok := metadata.FromOutgoingContext(ctx) 95 | if ok { 96 | reqHeaders = metadata.Join(reqHeaders, metadata.New(md)) 97 | } else { 98 | reqHeaders = metadata.New(md) 99 | } 100 | ctx = metadata.NewOutgoingContext(ctx, reqHeaders) 101 | } 102 | } 103 | return ctx, nil 104 | } 105 | -------------------------------------------------------------------------------- /internal/misc.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "reflect" 7 | 8 | //lint:ignore SA1019 we use the old v1 package because 9 | // we need 
to support older generated messages 10 | "github.com/golang/protobuf/proto" 11 | "github.com/jhump/protoreflect/dynamic" 12 | "google.golang.org/grpc" 13 | "google.golang.org/grpc/codes" 14 | "google.golang.org/grpc/status" 15 | ) 16 | 17 | // CopyMessage copies data from the given in value to the given out value. It returns an 18 | // error if the two values do not have the same type or if the given out value is not 19 | // settable. 20 | func CopyMessage(out, in interface{}) error { 21 | pmIn, ok := in.(proto.Message) 22 | if !ok { 23 | return fmt.Errorf("value to copy is not a proto.Message: %T; use a custom cloner", in) 24 | } 25 | pmOut, ok := out.(proto.Message) 26 | if !ok { 27 | return fmt.Errorf("destination for copy is not a proto.Message: %T; use a custom cloner", in) 28 | } 29 | 30 | pmOut.Reset() 31 | // This will check that types are compatible and return an error if not. 32 | // Unlike proto.Merge, this allows one or the other to be a dynamic message. 33 | return dynamic.TryMerge(pmOut, pmIn) 34 | } 35 | 36 | // CloneMessage returns a copy of the given value. 37 | func CloneMessage(m interface{}) (interface{}, error) { 38 | pm, ok := m.(proto.Message) 39 | if !ok { 40 | return nil, fmt.Errorf("value to clone is not a proto.Message: %T; use a custom cloner", m) 41 | } 42 | 43 | // this does a proper deep copy 44 | return proto.Clone(pm), nil 45 | } 46 | 47 | // ClearMessage resets the given value to its zero-value state. It returns an error 48 | // if the given out value is not settable. 49 | func ClearMessage(m interface{}) error { 50 | dest := reflect.Indirect(reflect.ValueOf(m)) 51 | if !dest.CanSet() { 52 | return fmt.Errorf("unable to set destination: %v", reflect.ValueOf(m).Type()) 53 | } 54 | dest.Set(reflect.Zero(dest.Type())) 55 | return nil 56 | } 57 | 58 | // TranslateContextError converts the given error to a gRPC status error if it 59 | // is a context error. 
If it is context.DeadlineExceeded, it is converted to an 60 | // error with a status code of DeadlineExceeded. If it is context.Canceled, it 61 | // is converted to an error with a status code of Canceled. If it is not a 62 | // context error, it is returned without any conversion. 63 | func TranslateContextError(err error) error { 64 | switch err { 65 | case context.DeadlineExceeded: 66 | return status.Errorf(codes.DeadlineExceeded, err.Error()) 67 | case context.Canceled: 68 | return status.Errorf(codes.Canceled, err.Error()) 69 | } 70 | return err 71 | } 72 | 73 | // FindUnaryMethod returns the method descriptor for the named method. If the 74 | // method is not found in the given slice of descriptors, nil is returned. 75 | func FindUnaryMethod(methodName string, methods []grpc.MethodDesc) *grpc.MethodDesc { 76 | for i := range methods { 77 | if methods[i].MethodName == methodName { 78 | return &methods[i] 79 | } 80 | } 81 | return nil 82 | } 83 | 84 | // FindStreamingMethod returns the stream descriptor for the named method. If 85 | // the method is not found in the given slice of descriptors, nil is returned. 86 | func FindStreamingMethod(methodName string, methods []grpc.StreamDesc) *grpc.StreamDesc { 87 | for i := range methods { 88 | if methods[i].StreamName == methodName { 89 | return &methods[i] 90 | } 91 | } 92 | return nil 93 | } 94 | -------------------------------------------------------------------------------- /internal/transport_stream.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | 7 | "google.golang.org/grpc" 8 | "google.golang.org/grpc/metadata" 9 | ) 10 | 11 | // UnaryServerTransportStream implements grpc.ServerTransportStream and can be 12 | // used by unary calls to collect headers and trailers from a handler. 13 | type UnaryServerTransportStream struct { 14 | // Name is the full method name in "/service/method" format. 
15 | Name string 16 | 17 | mu sync.Mutex 18 | hdrs metadata.MD 19 | hdrsSent bool 20 | tlrs metadata.MD 21 | tlrsSent bool 22 | } 23 | 24 | // Method satisfies the grpc.ServerTransportStream, returning the full path of 25 | // the method invocation that the stream represents. 26 | func (sts *UnaryServerTransportStream) Method() string { 27 | return sts.Name 28 | } 29 | 30 | // Finish marks headers and trailers as sent, so that subsequent calls to 31 | // SetHeader, SendHeader, or SetTrailer will fail. 32 | func (sts *UnaryServerTransportStream) Finish() { 33 | sts.mu.Lock() 34 | defer sts.mu.Unlock() 35 | sts.hdrsSent = true 36 | sts.tlrsSent = true 37 | } 38 | 39 | // SetHeader satisfies the grpc.ServerTransportStream, adding the given metadata 40 | // to the set of response headers that will be sent to the client. 41 | func (sts *UnaryServerTransportStream) SetHeader(md metadata.MD) error { 42 | sts.mu.Lock() 43 | defer sts.mu.Unlock() 44 | return sts.setHeaderLocked(md) 45 | } 46 | 47 | // SendHeader satisfies the grpc.ServerTransportStream, adding the given 48 | // metadata to the set of response headers. This implementation does not 49 | // actually send the headers but rather marks the headers as sent so that future 50 | // calls to SetHeader or SendHeader will return an error. 51 | func (sts *UnaryServerTransportStream) SendHeader(md metadata.MD) error { 52 | sts.mu.Lock() 53 | defer sts.mu.Unlock() 54 | if err := sts.setHeaderLocked(md); err != nil { 55 | return err 56 | } 57 | sts.hdrsSent = true 58 | return nil 59 | } 60 | 61 | func (sts *UnaryServerTransportStream) setHeaderLocked(md metadata.MD) error { 62 | if sts.hdrsSent { 63 | return fmt.Errorf("headers already sent") 64 | } 65 | if sts.hdrs == nil { 66 | sts.hdrs = metadata.MD{} 67 | } 68 | for k, v := range md { 69 | sts.hdrs[k] = append(sts.hdrs[k], v...) 70 | } 71 | return nil 72 | } 73 | 74 | // GetHeaders returns the cumulative set of headers set by calls to SetHeader 75 | // and SendHeader. 
This is used by a server to gather the headers that must 76 | // actually be sent to a client. 77 | func (sts *UnaryServerTransportStream) GetHeaders() metadata.MD { 78 | sts.mu.Lock() 79 | defer sts.mu.Unlock() 80 | return sts.hdrs 81 | } 82 | 83 | // SetTrailer satisfies the grpc.ServerTransportStream, adding the given 84 | // metadata to the set of response trailers that will be sent to the client. 85 | func (sts *UnaryServerTransportStream) SetTrailer(md metadata.MD) error { 86 | sts.mu.Lock() 87 | defer sts.mu.Unlock() 88 | if sts.tlrsSent { 89 | return fmt.Errorf("trailers already sent") 90 | } 91 | if sts.tlrs == nil { 92 | sts.tlrs = metadata.MD{} 93 | } 94 | for k, v := range md { 95 | sts.tlrs[k] = append(sts.tlrs[k], v...) 96 | } 97 | return nil 98 | } 99 | 100 | // GetTrailers returns the cumulative set of trailers set by calls to 101 | // SetTrailer. This is used by a server to gather the headers that must actually 102 | // be sent to a client. 103 | func (sts *UnaryServerTransportStream) GetTrailers() metadata.MD { 104 | sts.mu.Lock() 105 | defer sts.mu.Unlock() 106 | return sts.tlrs 107 | } 108 | 109 | // ServerTransportStream implements grpc.ServerTransportStream and wraps a 110 | // grpc.ServerStream, delegating most calls to it. 111 | type ServerTransportStream struct { 112 | // Name is the full method name in "/service/method" format. 113 | Name string 114 | // Stream is the underlying stream to which header and trailer calls are 115 | // delegated. 116 | Stream grpc.ServerStream 117 | } 118 | 119 | // Method satisfies the grpc.ServerTransportStream, returning the full path of 120 | // the method invocation that the stream represents. 121 | func (sts *ServerTransportStream) Method() string { 122 | return sts.Name 123 | } 124 | 125 | // SetHeader satisfies the grpc.ServerTransportStream and delegates to the 126 | // underlying grpc.ServerStream. 
127 | func (sts *ServerTransportStream) SetHeader(md metadata.MD) error { 128 | return sts.Stream.SetHeader(md) 129 | } 130 | 131 | // SendHeader satisfies the grpc.ServerTransportStream and delegates to the 132 | // underlying grpc.ServerStream. 133 | func (sts *ServerTransportStream) SendHeader(md metadata.MD) error { 134 | return sts.Stream.SendHeader(md) 135 | } 136 | 137 | // SetTrailer satisfies the grpc.ServerTransportStream and delegates to the 138 | // underlying grpc.ServerStream. If the underlying stream provides a 139 | // TrySetTrailer(metadata.MD) error method, it will be used to set trailers. 140 | // Otherwise, the normal SetTrailer(metadata.MD) method will be used and a nil 141 | // error will always be returned. 142 | func (sts *ServerTransportStream) SetTrailer(md metadata.MD) error { 143 | type trailerWithErrors interface { 144 | TrySetTrailer(md metadata.MD) error 145 | } 146 | if t, ok := sts.Stream.(trailerWithErrors); ok { 147 | return t.TrySetTrailer(md) 148 | } 149 | sts.Stream.SetTrailer(md) 150 | return nil 151 | } 152 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package grpchan 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | 7 | "google.golang.org/grpc" 8 | ) 9 | 10 | // ServiceRegistry accumulates service definitions. Servers typically have this 11 | // interface for accumulating the services they expose. 12 | // 13 | // Deprecated: Use grpc.ServiceRegistrar instead. 14 | type ServiceRegistry = grpc.ServiceRegistrar 15 | 16 | // HandlerMap is used to accumulate service handlers into a map. The handlers 17 | // can be registered once in the map, and then re-used to configure multiple 18 | // servers that should expose the same handlers. HandlerMap can also be used 19 | // as the internal store of registered handlers for a server implementation. 
20 | type HandlerMap map[string]service 21 | 22 | var _ grpc.ServiceRegistrar = HandlerMap(nil) 23 | 24 | type service struct { 25 | desc *grpc.ServiceDesc 26 | handler interface{} 27 | } 28 | 29 | // RegisterService registers the given handler to be used for the given service. 30 | // Only a single handler can be registered for a given service. And services are 31 | // identified by their fully-qualified name (e.g. "package.name.Service"). 32 | func (m HandlerMap) RegisterService(desc *grpc.ServiceDesc, h interface{}) { 33 | ht := reflect.TypeOf(desc.HandlerType).Elem() 34 | st := reflect.TypeOf(h) 35 | if !st.Implements(ht) { 36 | panic(fmt.Sprintf("service %s: handler of type %v does not satisfy %v", desc.ServiceName, st, ht)) 37 | } 38 | if _, ok := m[desc.ServiceName]; ok { 39 | panic(fmt.Sprintf("service %s: handler already registered", desc.ServiceName)) 40 | } 41 | m[desc.ServiceName] = service{desc: desc, handler: h} 42 | } 43 | 44 | // QueryService returns the service descriptor and handler for the named 45 | // service. If no handler has been registered for the named service, then 46 | // nil, nil is returned. 47 | func (m HandlerMap) QueryService(name string) (*grpc.ServiceDesc, interface{}) { 48 | svc := m[name] 49 | return svc.desc, svc.handler 50 | } 51 | 52 | // GetServiceInfo returns a snapshot of information about the currently 53 | // registered services in the map. 54 | // 55 | // This mirrors the method of the same name on *grpc.Server. 
56 | func (m HandlerMap) GetServiceInfo() map[string]grpc.ServiceInfo { 57 | ret := make(map[string]grpc.ServiceInfo, len(m)) 58 | for _, svc := range m { 59 | methods := make([]grpc.MethodInfo, 0, len(svc.desc.Methods)+len(svc.desc.Streams)) 60 | for _, mtd := range svc.desc.Methods { 61 | methods = append(methods, grpc.MethodInfo{Name: mtd.MethodName}) 62 | } 63 | for _, mtd := range svc.desc.Streams { 64 | methods = append(methods, grpc.MethodInfo{ 65 | Name: mtd.StreamName, 66 | IsClientStream: mtd.ClientStreams, 67 | IsServerStream: mtd.ServerStreams, 68 | }) 69 | } 70 | ret[svc.desc.ServiceName] = grpc.ServiceInfo{ 71 | Methods: methods, 72 | Metadata: svc.desc.Metadata, 73 | } 74 | } 75 | return ret 76 | } 77 | 78 | // ForEach calls the given function for each registered handler. The function is 79 | // provided the service description, and the handler. This can be used to 80 | // contribute all registered handlers to a server and means that applications 81 | // can easily expose the same services and handlers via multiple channels after 82 | // registering the handlers once, with the map: 83 | // 84 | // // Register all handlers once with the map: 85 | // reg := channel.HandlerMap{} 86 | // // (these registration functions are generated) 87 | // foo.RegisterHandlerFooBar(newFooBarImpl()) 88 | // fu.RegisterHandlerFuBaz(newFuBazImpl()) 89 | // 90 | // // Now we can re-use these handlers for multiple channels: 91 | // // Normal gRPC 92 | // svr := grpc.NewServer() 93 | // reg.ForEach(svr.RegisterService) 94 | // // In-process 95 | // ipch := &inprocgrpc.Channel{} 96 | // reg.ForEach(ipch.RegisterService) 97 | // // And HTTP 1.1 98 | // httpgrpc.HandleServices(http.HandleFunc, "/rpc/", reg, nil, nil) 99 | func (m HandlerMap) ForEach(fn func(desc *grpc.ServiceDesc, svr interface{})) { 100 | for _, svc := range m { 101 | fn(svc.desc, svc.handler) 102 | } 103 | } 104 | --------------------------------------------------------------------------------