├── .github ├── dependabot.yml └── workflows │ ├── build.yml │ └── codeql.yml ├── .gitignore ├── LICENSE ├── README.md ├── SECURITY.md ├── client.go ├── encoding ├── codec │ ├── codec.go │ ├── codec_test.go │ ├── header.go │ ├── http.go │ ├── protobuf.go │ └── raw.go ├── compress │ ├── compresser.go │ ├── gzip.go │ └── zstd.go ├── encoding.go └── encrypt │ └── encrypter.go ├── example ├── consts.go ├── ping │ ├── client │ │ └── main.go │ └── server │ │ └── main.go └── streaming │ ├── client │ └── main.go │ └── server │ └── main.go ├── go.mod ├── go.sum ├── header.go ├── internal ├── join │ ├── bytes.go │ └── join.go └── utils │ └── bytes_writer.go ├── network ├── conn.go └── stream.go ├── server.go ├── stream.go ├── test_data └── pb │ ├── msg.pb.go │ └── msg.proto └── transport.go /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "gomod" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | 11 | build: 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | matrix: 15 | os: 16 | - ubuntu-20.04 17 | - ubuntu-22.04 18 | - ubuntu-24.04 19 | - windows-2019 20 | - windows-2022 21 | - windows-2025 22 | - macos-13 23 | - macos-14 24 | - macos-15 25 | go: 26 | - '1.22' 27 | steps: 28 | - uses: actions/checkout@v3 29 | 30 | - name: Set up Go 31 | uses: actions/setup-go@v3 32 | with: 33 | go-version: ${{ matrix.go }} 34 | 35 | - name: Lint 36 | run: | 37 | go install golang.org/x/lint/golint@latest 38 | golint -set_exit_status ./... 39 | go install github.com/gordonklaus/ineffassign@latest 40 | ineffassign ./... 41 | 42 | - name: Build 43 | run: go build ./... 44 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. 
Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "master" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "master" ] 20 | 21 | jobs: 22 | analyze: 23 | name: Analyze 24 | runs-on: ubuntu-latest 25 | permissions: 26 | actions: read 27 | contents: read 28 | security-events: write 29 | 30 | strategy: 31 | fail-fast: false 32 | matrix: 33 | language: [ 'go' ] 34 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 35 | # Use only 'java' to analyze code written in Java, Kotlin or both 36 | # Use only 'javascript' to analyze code written in JavaScript, TypeScript or both 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | 52 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 53 | # queries: security-extended,security-and-quality 54 | 55 | 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). 
57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | - name: Autobuild 59 | uses: github/codeql-action/autobuild@v2 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 63 | 64 | # If the Autobuild fails above, remove it and uncomment the following three lines. 65 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 66 | 67 | # - run: | 68 | # echo "Run, Build Application using script" 69 | # ./location_of_script_within_repo/buildscript.sh 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@v2 73 | with: 74 | category: "/language:${{matrix.language}}" 75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /test -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 李文超 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # crpc 2 | 3 | [![crpc](https://github.com/lwch/crpc/actions/workflows/build.yml/badge.svg)](https://github.com/lwch/crpc/actions/workflows/build.yml) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/lwch/crpc)](https://goreportcard.com/report/github.com/lwch/crpc) 5 | [![Go Reference](https://pkg.go.dev/badge/github.com/lwch/crpc.svg)](https://pkg.go.dev/github.com/lwch/crpc) 6 | [![license](https://img.shields.io/github/license/lwch/crpc)](https://opensource.org/licenses/MIT) 7 | 8 | golang rpc框架,支持以下功能: 9 | 10 | 1. 流式传输 11 | 2. 数据加密 12 | 3. 数据压缩 13 | 4.
结构序列化,已支持数据类型: 14 | - []byte 15 | - http.Request, http.Response 16 | - proto.Message 17 | 18 | ## 分层设计 19 | 20 | crpc框架采用以下多层设计,每一个层次有其相应的数据结构 21 | 22 | +------------+-------------------+--------------------+-------+ 23 | | data frame | encrypt(optional) | compress(optional) | codec | 24 | +------------+-------------------+--------------------+-------+ 25 | 26 | - `data frame`: 数据帧,最底层数据结构,直接面向于tcp协议 27 | - `encrypt`: 数据加密层,目前已支持aes和des加密算法 28 | - `compress`: 数据压缩层,目前已支持gzip和zstd压缩算法 29 | - `codec`: 数据序列化层,目前支持`[]byte`、`http.Request`、`http.Response`、`proto.Message`四种数据结构的序列化 30 | 31 | ### 数据帧(network) 32 | 33 | 数据帧为最基础数据结构,直接作用于tcp链路,其封装格式如下 34 | 35 | +-------------+---------+----------+---------+---------+ 36 | | Sequence(4) | Size(2) | Crc32(4) | Flag(4) | Payload | 37 | +-------------+---------+----------+---------+---------+ 38 | 39 | 以上内容括号中的数字表示字节数,其中`Flag`字段按比特位划分,定义如下 40 | 41 | +---------+------------+----------+---------+---------+---------+-----------+---------------+ 42 | | Open(1) | OpenAck(1) | Close(1) | Data(1) | Ping(1) | Pong(1) | Unused(2) | Stream ID(24) | 43 | +---------+------------+----------+---------+---------+---------+-----------+---------------+ 44 | 45 | 以上内容括号中的数字表示比特位,其中每一个比特位代表一个标志位,互相之间是互斥关系,目前仅使用了`Flag`字段第一字节的高6位,由于Stream ID字段仅有3字节,因此crpc中最多支持16777215个stream`同时`传输数据 46 | 47 | ### 数据加密层(encoding/encrypt) 48 | 49 | 数据加密层用于将原始数据进行加密,在数据加密前会将原始数据的crc32校验码添加到数据尾部作为解密后的校验依据,其封装格式如下: 50 | 51 | +----------+----------+ 52 | | Src Data | Crc32(4) | 53 | +----------+----------+ 54 | 55 | - `aes`加密算法: aes加密算法使用32字节长度密钥以及16字节的iv进行CBC算法加密 56 | - `des`加密算法: des加密算法使用24字节长度密钥以及8字节的iv进行TripleDES算法加密 57 | 58 | 当给定密钥长度不足时,底层会重复多次密钥内容以保证加密运算的进行;iv同样取自重复后的密钥内容(紧跟在密钥之后的字节),因此同一密钥下iv固定不变 59 | 60 | ### 数据压缩层(encoding/compress) 61 | 62 | 数据压缩层用于将原始数据进行压缩,在数据压缩前会将原始数据的crc32校验码添加到数据尾部作为解压后的校验依据,其封装格式如下: 63 | 64 | +----------+----------+ 65 | | Src Data | Crc32(4) | 66 | +----------+----------+ 67 | 68 | ### 数据编码层(encoding/codec) 69 | 70 | 数据编码层用于描述原始数据类型,主要作用于数据的序列化和反序列化过程,数据结构如下: 71 | 72 | +---------+---------+
73 | | Type(1) | Payload | 74 | +---------+---------+ 75 | 76 | 其中Type字段为1字节,表示当前数据类型,定义如下: 77 | 78 | - `0`: 未知数据类型 79 | - `1`: raw data,可反序列化到[]byte 80 | - `2`: http request,可反序列化到http.Request 81 | - `3`: http response,可反序列化到http.Response 82 | - `4`: protobuf,可反序列化到proto.Message 83 | 84 | #### http请求 85 | 86 | crpc框架底层使用`X-Crpc-Request-Id`字段进行request与response的关联,因此在使用过程中请勿自行设置该字段。 87 | 88 | ## 示例 89 | 90 | 参见`example`目录下的`ping`与`streaming`示例。 -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Supported Versions 4 | 5 | Use this section to tell people about which versions of your project are 6 | currently being supported with security updates. 7 | 8 | | Version | Supported | 9 | | ------- | ------------------ | 10 | | 5.1.x | :white_check_mark: | 11 | | 5.0.x | :x: | 12 | | 4.0.x | :white_check_mark: | 13 | | < 4.0 | :x: | 14 | 15 | ## Reporting a Vulnerability 16 | 17 | Use this section to tell people how to report a vulnerability. 18 | 19 | Tell them where to go, how often they can expect to get an update on a 20 | reported vulnerability, what to expect if the vulnerability is accepted or 21 | declined, etc.
22 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package crpc 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "net/http" 9 | "sync" 10 | "time" 11 | 12 | "github.com/lwch/crpc/encoding" 13 | "github.com/lwch/logging" 14 | ) 15 | 16 | // ErrReconnecting reconnecting error 17 | var ErrReconnecting = errors.New("reconnecting") 18 | 19 | // ErrClosed closed error 20 | var ErrClosed = errors.New("closed") 21 | 22 | // Client rpc client 23 | type Client struct { 24 | sync.RWMutex 25 | addr string 26 | tp *transport 27 | // encrypter and compresser are re-applied to the transport created after a reconnect 28 | encrypter encoding.Encrypter 29 | compresser encoding.Compresser 30 | // runtime 31 | ctx context.Context 32 | cancel context.CancelFunc 33 | } 34 | 35 | // NewClient create client 36 | func NewClient(addr string) (*Client, error) { 37 | conn, err := dial(addr, 1) 38 | if err != nil { 39 | return nil, err 40 | } 41 | ctx, cancel := context.WithCancel(context.Background()) 42 | cli := &Client{ 43 | addr: addr, 44 | tp: new(conn), 45 | ctx: ctx, 46 | cancel: cancel, 47 | } 48 | go cli.serve() 49 | return cli, nil 50 | } 51 | 52 | // SetEncrypter set encrypter, it is safe to call while the client is reconnecting 53 | func (cli *Client) SetEncrypter(encrypter encoding.Encrypter) { 54 | cli.Lock() 55 | defer cli.Unlock() 56 | cli.encrypter = encrypter 57 | if cli.tp != nil { 58 | cli.tp.SetEncrypter(encrypter) 59 | } 60 | } 61 | 62 | // SetCompresser set compresser, it is safe to call while the client is reconnecting 63 | func (cli *Client) SetCompresser(compresser encoding.Compresser) { 64 | cli.Lock() 65 | defer cli.Unlock() 66 | cli.compresser = compresser 67 | if cli.tp != nil { 68 | cli.tp.SetCompresser(compresser) 69 | } 70 | } 71 | 72 | // dial dials addr without cancellation, see dialContext 73 | func dial(addr string, retry int) (net.Conn, error) { 74 | return dialContext(context.Background(), addr, retry) 75 | } 76 | 77 | // dialContext dials addr over tcp up to retry times (forever when retry is 0), waiting one second after each failure and giving up as soon as ctx is done 78 | func dialContext(ctx context.Context, addr string, retry int) (net.Conn, error) { 79 | dialer := net.Dialer{Timeout: 3 * time.Second} 80 | var lastErr error 81 | for i := 0; retry == 0 || i < retry; i++ { 82 | conn, err := dialer.DialContext(ctx, "tcp", addr) 83 | if err == nil { 84 | return conn, nil 85 | } 86 | lastErr = err 87 | select { 88 | case <-ctx.Done(): 89 | return nil, ctx.Err() 90 | case <-time.After(time.Second): 91 | } 92 | } 93 | return nil, fmt.Errorf("transport: dial more than %d times: %w", retry, lastErr) 94 | } 95 | 96 | // Close close client: it stops the reconnect loop and closes the current transport 97 | func (cli *Client) Close() error { 98 | // cancel first: serve checks ctx under the lock before installing a transport, so after this no new one can appear unseen 99 | cli.cancel() 100 | cli.RLock() 101 | tp := cli.tp 102 | cli.RUnlock() 103 | if tp == nil { 104 | return nil 105 | } 106 | return tp.Close() 107 | } 108 |
109 | // serve runs the current transport and reconnects whenever it fails, until the client is closed 110 | func (cli *Client) serve() error { 111 | defer cli.cancel() 112 | for { 113 | select { 114 | case <-cli.ctx.Done(): 115 | return ErrClosed 116 | default: 117 | } 118 | cli.RLock() 119 | tp := cli.tp 120 | cli.RUnlock() 121 | err := tp.Serve() 122 | if err != nil { 123 | logging.Error("serve %s: %v", cli.addr, err) 124 | } 125 | cli.Lock() 126 | // remember the current settings so the next transport gets the same ones 127 | cli.encrypter = tp.encrypter 128 | cli.compresser = tp.compresser 129 | tp.Close() 130 | cli.tp = nil 131 | cli.Unlock() 132 | conn, err := dialContext(cli.ctx, cli.addr, 0) 133 | if err != nil { 134 | // with unlimited retries dialing only fails once the client is closed 135 | return ErrClosed 136 | } 137 | tp = new(conn) 138 | cli.Lock() 139 | select { 140 | case <-cli.ctx.Done(): 141 | // closed while dialing: drop the fresh connection instead of leaking it 142 | cli.Unlock() 143 | tp.Close() 144 | return ErrClosed 145 | default: 146 | } 147 | tp.SetEncrypter(cli.encrypter) 148 | tp.SetCompresser(cli.compresser) 149 | cli.tp = tp 150 | cli.Unlock() 151 | } 152 | } 153 | 154 | // Call call http request 155 | func (cli *Client) Call(ctx context.Context, req *http.Request) (*http.Response, error) { 156 | select { 157 | case <-cli.ctx.Done(): 158 | return nil, ErrClosed 159 | default: 160 | } 161 | cli.RLock() 162 | tp := cli.tp 163 | cli.RUnlock() 164 | if tp == nil { 165 | return nil, ErrReconnecting 166 | } 167 | return tp.Call(ctx, req) 168 | } 169 | 170 | // OpenStream open stream 171 | func (cli *Client) OpenStream(ctx context.Context) (*Stream, error) { 172 | select { 173 | case <-cli.ctx.Done(): 174 | return nil, ErrClosed 175 | default: 176 | } 177 | cli.RLock() 178 | tp := cli.tp 179 | cli.RUnlock() 180 | if tp == nil { 181 | return nil, ErrReconnecting 182 | } 183 | return tp.OpenStream(ctx) 184 | } 185 | -------------------------------------------------------------------------------- /encoding/codec/codec.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "errors" 7 | "net/http" 8 | "reflect" 9 | "sync" 10 | 11 | "github.com/lwch/crpc/encoding" 12 | "github.com/lwch/crpc/internal/join" 13 | "google.golang.org/protobuf/proto" 14 | ) 15 | 16 | var errUnsupportedType = errors.New("codec: unsupported type") 17 | var errIsNotPointer = errors.New("codec: the specify variable is not pointer") 18 | var
errProtoMessage = errors.New("codec: the specify variable is not proto.Message") 19 | 20 | // Codec serializer 21 | type Codec struct { 22 | bufPool sync.Pool 23 | joinPool sync.Pool 24 | } 25 | 26 | // New create codec 27 | func New() encoding.Codec { 28 | c := &Codec{} 29 | c.bufPool.New = func() any { 30 | return new(join.BytesBuffer) 31 | } 32 | c.joinPool.New = func() any { 33 | return join.New() 34 | } 35 | return c 36 | } 37 | 38 | // Marshal serialize data 39 | func (c *Codec) Marshal(v any) ([]byte, error) { 40 | switch value := v.(type) { 41 | case []byte: 42 | return c.marshalRaw(value) 43 | case http.Request, *http.Request: 44 | return c.marshalHTTPRequest(value) 45 | case http.Response, *http.Response: 46 | return c.marshalHTTPResponse(value) 47 | case proto.Message: 48 | return c.marshalProtoMessage(value) 49 | default: 50 | return nil, errUnsupportedType 51 | } 52 | } 53 | 54 | // Unmarshal deserialize data 55 | func (c *Codec) Unmarshal(data []byte, v any) (int, error) { 56 | vv := reflect.ValueOf(v) 57 | if vv.Kind() != reflect.Ptr { 58 | return 0, errIsNotPointer 59 | } 60 | r := bytes.NewReader(data) 61 | var hdr header 62 | err := binary.Read(r, binary.BigEndian, &hdr) 63 | if err != nil { 64 | return 0, err 65 | } 66 | 67 | switch hdr.Type { 68 | case TypeRaw: 69 | return c.unmarshalRaw(r, v, len(data)-1) 70 | case TypeHTTPRequest: 71 | return c.unmarshalHTTPRequest(r, v) 72 | case TypeHTTPResponse: 73 | return c.unmarshalHTTPResponse(r, v) 74 | case TypeProtobuf: 75 | return c.unmarshalProtoMessage(r, v) 76 | default: 77 | return 0, errUnsupportedType 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /encoding/codec/codec_test.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/lwch/crpc/test_data/pb" 10 | ) 11 | 12 | func TestRequest(t *testing.T) { 13 | c 
:= New() 14 | req, err := http.NewRequest(http.MethodGet, "http://localhost/ping", nil) 15 | if err != nil { 16 | t.Fatal(err) 17 | } 18 | data, err := c.Marshal(req) 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | 23 | var newReq http.Request 24 | _, err = c.Unmarshal(data, &newReq) 25 | if err != nil { 26 | t.Fatal(err) 27 | } 28 | if req.Method != newReq.Method { 29 | t.Fatal("invalid method") 30 | } 31 | if req.URL.Path != newReq.URL.Path { 32 | t.Fatal("invalid url") 33 | } 34 | 35 | var newReq2 any 36 | _, err = c.Unmarshal(data, &newReq2) 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | newReq3, ok := newReq2.(*http.Request) 41 | if !ok { 42 | t.Fatal("invalid type") 43 | } 44 | if req.Method != newReq3.Method { 45 | t.Fatal("invalid method") 46 | } 47 | if req.URL.Path != newReq3.URL.Path { 48 | t.Fatal("invalid url") 49 | } 50 | } 51 | 52 | func TestResponse(t *testing.T) { 53 | c := New() 54 | rep := &http.Response{ 55 | StatusCode: http.StatusOK, 56 | Body: io.NopCloser(strings.NewReader("pong")), 57 | } 58 | data, err := c.Marshal(rep) 59 | if err != nil { 60 | t.Fatal(err) 61 | } 62 | 63 | var newRep http.Response 64 | _, err = c.Unmarshal(data, &newRep) 65 | if err != nil { 66 | t.Fatal(err) 67 | } 68 | if rep.StatusCode != newRep.StatusCode { 69 | t.Fatal("invalid status code") 70 | } 71 | if newRep.Body == nil { 72 | t.Fatal("invalid body") 73 | } 74 | var buf strings.Builder 75 | if _, err := io.Copy(&buf, newRep.Body); err != nil { 76 | t.Fatal(err) 77 | } 78 | if buf.String() != "pong" { 79 | t.Fatal("invalid body") 80 | } 81 | 82 | var newRep2 any 83 | _, err = c.Unmarshal(data, &newRep2) 84 | if err != nil { 85 | t.Fatal(err) 86 | } 87 | newRep3, ok := newRep2.(*http.Response) 88 | if !ok { 89 | t.Fatal("invalid type") 90 | } 91 | if rep.StatusCode != newRep3.StatusCode { 92 | t.Fatal("invalid status code") 93 | } 94 | if newRep3.Body == nil { 95 | t.Fatal("invalid body") 96 | } 97 | var buf2 strings.Builder 98 | if _, err := 
io.Copy(&buf2, newRep3.Body); err != nil { 99 | t.Fatal(err) 100 | } 101 | if buf2.String() != "pong" { 102 | t.Fatal("invalid body") 103 | } 104 | } 105 | 106 | func TestBytes(t *testing.T) { 107 | c := New() 108 | data := []byte("codec") 109 | buf, err := c.Marshal(data) 110 | if err != nil { 111 | t.Fatal(err) 112 | } 113 | 114 | var newBuf []byte 115 | _, err = c.Unmarshal(buf, &newBuf) 116 | if err != nil { 117 | t.Fatal(err) 118 | } 119 | if string(data) != string(newBuf) { 120 | t.Fatal("invalid bytes") 121 | } 122 | 123 | var newBuf2 any 124 | _, err = c.Unmarshal(buf, &newBuf2) 125 | if err != nil { 126 | t.Fatal(err) 127 | } 128 | newBuf3, ok := newBuf2.([]byte) 129 | if !ok { 130 | t.Fatal("invalid type") 131 | } 132 | if string(data) != string(newBuf3) { 133 | t.Fatal("invalid bytes") 134 | } 135 | } 136 | 137 | func TestProtobuf(t *testing.T) { 138 | c := New() 139 | data := pb.Request{ 140 | Id: 1, 141 | Uri: "/ping", 142 | Args: map[string]string{ 143 | "key": "value", 144 | }, 145 | } 146 | buf, err := c.Marshal(&data) 147 | if err != nil { 148 | t.Fatal(err) 149 | } 150 | 151 | var newReq pb.Request 152 | _, err = c.Unmarshal(buf, &newReq) 153 | if err != nil { 154 | t.Fatal(err) 155 | } 156 | if data.Id != newReq.Id { 157 | t.Fatal("invalid id") 158 | } 159 | if data.Uri != newReq.Uri { 160 | t.Fatal("invalid uri") 161 | } 162 | if data.Args["key"] != newReq.Args["key"] { 163 | t.Fatal("invalid args") 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /encoding/codec/header.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | // DataType data type 4 | type DataType byte 5 | 6 | const ( 7 | // TypeUnknown unknown type 8 | TypeUnknown DataType = iota 9 | // TypeRaw raw data 10 | TypeRaw 11 | // TypeHTTPRequest http request data 12 | TypeHTTPRequest 13 | // TypeHTTPResponse http response data 14 | TypeHTTPResponse 15 | // TypeProtobuf 
protobuf data 16 | TypeProtobuf 17 | ) 18 | 19 | type header struct { 20 | Type DataType 21 | } 22 | 23 | func (h *header) Marshal() []byte { 24 | return []byte{byte(h.Type)} 25 | } 26 | -------------------------------------------------------------------------------- /encoding/codec/http.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "bufio" 5 | "io" 6 | "net/http" 7 | "reflect" 8 | 9 | "github.com/lwch/crpc/internal/join" 10 | ) 11 | 12 | func (c *Codec) marshalHTTPRequest(v any) ([]byte, error) { 13 | if req, ok := v.(http.Request); ok { 14 | v = &req // Request.Write has a pointer receiver, so a plain http.Request would fail the assertion below 15 | } 16 | var hdr header 17 | payload := c.bufPool.Get().(*join.BytesBuffer) 18 | defer c.bufPool.Put(payload) 19 | payload.Reset() 20 | hdr.Type = TypeHTTPRequest 21 | if err := v.(interface{ Write(io.Writer) error }).Write(payload); err != nil { 22 | return nil, err 23 | } 24 | joiner := c.joinPool.Get().(*join.Joiner) 25 | defer c.joinPool.Put(joiner) 26 | joiner.SetHeader(&hdr) 27 | joiner.SetPayload(payload) 28 | return joiner.Marshal() 29 | } 30 | 31 | func (c *Codec) marshalHTTPResponse(v any) ([]byte, error) { 32 | if resp, ok := v.(http.Response); ok { 33 | v = &resp // Response.Write has a pointer receiver, so a plain http.Response would fail the assertion below 34 | } 35 | var hdr header 36 | payload := c.bufPool.Get().(*join.BytesBuffer) 37 | defer c.bufPool.Put(payload) 38 | payload.Reset() 39 | hdr.Type = TypeHTTPResponse 40 | if err := v.(interface{ Write(io.Writer) error }).Write(payload); err != nil { 41 | return nil, err 42 | } 43 | joiner := c.joinPool.Get().(*join.Joiner) 44 | defer c.joinPool.Put(joiner) 45 | joiner.SetHeader(&hdr) 46 | joiner.SetPayload(payload) 47 | return joiner.Marshal() 48 | } 49 | 50 | func (c *Codec) unmarshalHTTPRequest(r io.Reader, v any) (int, error) { 51 | if req, ok := v.(*http.Request); ok { 52 | v, err := http.ReadRequest(bufio.NewReader(r)) 53 | if err != nil { 54 | return 0, err 55 | } 56 | *req = *v 57 | return 0, nil 58 | } 59 | req, err := http.ReadRequest(bufio.NewReader(r)) 60 | if err != nil { 61 | return 0, err 62 | } 63 |
reflect.ValueOf(v).Elem().Set(reflect.ValueOf(req)) 64 | return 0, nil 65 | } 66 | 67 | func (c *Codec) unmarshalHTTPResponse(r io.Reader, v any) (int, error) { 68 | if resp, ok := v.(*http.Response); ok { 69 | v, err := http.ReadResponse(bufio.NewReader(r), nil) 70 | if err != nil { 71 | return 0, err 72 | } 73 | *resp = *v 74 | return 0, nil 75 | } 76 | resp, err := http.ReadResponse(bufio.NewReader(r), nil) 77 | if err != nil { 78 | return 0, err 79 | } 80 | reflect.ValueOf(v).Elem().Set(reflect.ValueOf(resp)) 81 | return 0, nil 82 | } 83 | -------------------------------------------------------------------------------- /encoding/codec/protobuf.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | 7 | "github.com/lwch/crpc/internal/join" 8 | "google.golang.org/protobuf/proto" 9 | ) 10 | 11 | func (c *Codec) marshalProtoMessage(v any) ([]byte, error) { 12 | var hdr header 13 | payload := c.bufPool.Get().(*join.BytesBuffer) 14 | defer c.bufPool.Put(payload) 15 | payload.Reset() 16 | hdr.Type = TypeProtobuf 17 | enc, err := proto.Marshal(v.(proto.Message)) 18 | if err != nil { 19 | return nil, err 20 | } 21 | _, err = io.Copy(payload, bytes.NewReader(enc)) 22 | if err != nil { 23 | return nil, err 24 | } 25 | joiner := c.joinPool.Get().(*join.Joiner) 26 | defer c.joinPool.Put(joiner) 27 | joiner.SetHeader(&hdr) 28 | joiner.SetPayload(payload) 29 | return joiner.Marshal() 30 | } 31 | 32 | func (c *Codec) unmarshalProtoMessage(r io.Reader, v any) (int, error) { 33 | msg, ok := v.(proto.Message) 34 | if !ok { 35 | return 0, errProtoMessage 36 | } 37 | data, err := io.ReadAll(r) 38 | if err != nil { 39 | return 0, err 40 | } 41 | if err := proto.Unmarshal(data, msg); err != nil { 42 | return 0, err 43 | } 44 | return 0, nil 45 | } 46 | -------------------------------------------------------------------------------- /encoding/codec/raw.go: 
-------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "reflect" 7 | 8 | "github.com/lwch/crpc/internal/join" 9 | "github.com/lwch/crpc/internal/utils" 10 | ) 11 | 12 | func (c *Codec) marshalRaw(v any) ([]byte, error) { 13 | var hdr header 14 | payload := c.bufPool.Get().(*join.BytesBuffer) 15 | defer c.bufPool.Put(payload) 16 | payload.Reset() 17 | hdr.Type = TypeRaw 18 | if _, err := io.Copy(payload, bytes.NewReader(v.([]byte))); err != nil { 19 | return nil, err 20 | } 21 | joiner := c.joinPool.Get().(*join.Joiner) 22 | defer c.joinPool.Put(joiner) 23 | joiner.SetHeader(&hdr) 24 | joiner.SetPayload(payload) 25 | return joiner.Marshal() 26 | } 27 | 28 | func (*Codec) unmarshalRaw(r io.Reader, v any, len int) (int, error) { 29 | if vv, ok := v.(*[]byte); ok { 30 | if cap(*vv) == 0 { 31 | *vv = make([]byte, len) 32 | } 33 | n, err := io.Copy(utils.BytesWriter(*vv), r) 34 | return int(n), err 35 | } 36 | var buf bytes.Buffer 37 | n, err := io.Copy(&buf, r) 38 | if err != nil { 39 | return 0, err 40 | } 41 | reflect.ValueOf(v).Elem().Set(reflect.ValueOf(buf.Bytes())) 42 | return int(n), nil 43 | } 44 | -------------------------------------------------------------------------------- /encoding/compress/compresser.go: -------------------------------------------------------------------------------- 1 | package compress 2 | 3 | import ( 4 | "bytes" 5 | "compress/gzip" 6 | "encoding/binary" 7 | "errors" 8 | "hash/crc32" 9 | "io" 10 | "sync" 11 | 12 | "github.com/klauspost/compress/zstd" 13 | ) 14 | 15 | var errNoDecompresser = errors.New("compress: no decompresser") 16 | var errInvalidChecksum = errors.New("compress: invalid checksum") 17 | 18 | // Method compress method 19 | type Method byte 20 | 21 | const ( 22 | // Gzip gzip method 23 | Gzip Method = iota 24 | // Zstd zstd method 25 | Zstd 26 | ) 27 | 28 | type compresser interface { 29 | io.Writer 30 | Reset(io.Writer) 31 | Close() 
error 32 | } 33 | 34 | type decompresser interface { 35 | io.Reader 36 | Reset(io.Reader) error 37 | } 38 | 39 | // Compresser compresser 40 | type Compresser struct { 41 | nc func(int) (compresser, error) 42 | nd func(io.Reader) (io.Reader, error) 43 | level int 44 | poolCompresser map[int]*sync.Pool 45 | mPoolCompresser sync.RWMutex 46 | poolDecompresser sync.Pool 47 | } 48 | 49 | // New create new compresser 50 | func New(m Method) *Compresser { 51 | switch m { 52 | case Gzip: 53 | cp := &Compresser{ 54 | nc: newGzipCompresser, 55 | nd: newGzipDecompresser, 56 | level: gzip.DefaultCompression, 57 | poolCompresser: make(map[int]*sync.Pool), 58 | } 59 | return cp 60 | case Zstd: 61 | cp := &Compresser{ 62 | nc: newZstdCompresser, 63 | level: int(zstd.SpeedDefault), 64 | poolCompresser: make(map[int]*sync.Pool), 65 | } 66 | cp.poolDecompresser.New = func() any { 67 | decompresser, err := newZstdDecompresser() 68 | if err != nil { 69 | return nil 70 | } 71 | return decompresser 72 | } 73 | return cp 74 | default: 75 | return nil 76 | } 77 | } 78 | 79 | // Compress compress func 80 | func (cp *Compresser) Compress(data []byte) ([]byte, error) { 81 | cp.mPoolCompresser.RLock() 82 | level := cp.level 83 | pool := cp.poolCompresser[level] 84 | cp.mPoolCompresser.RUnlock() 85 | if pool == nil { 86 | pool = new(sync.Pool) 87 | pool.New = func() any { 88 | compresser, err := cp.nc(level) 89 | if err != nil { 90 | return nil 91 | } 92 | return compresser 93 | } 94 | cp.mPoolCompresser.Lock() 95 | cp.poolCompresser[level] = pool 96 | cp.mPoolCompresser.Unlock() 97 | } 98 | obj := pool.Get() 99 | var w compresser 100 | if obj == nil { 101 | var err error 102 | w, err = cp.nc(level) 103 | if err != nil { 104 | return nil, err 105 | } 106 | } else { 107 | w = obj.(compresser) 108 | } 109 | defer pool.Put(w) 110 | data = binary.BigEndian.AppendUint32(data, crc32.ChecksumIEEE(data)) 111 | var buf bytes.Buffer 112 | w.Reset(&buf) 113 | _, err := io.Copy(w, bytes.NewReader(data)) 
114 | if err != nil { 115 | return nil, err 116 | } 117 | err = w.Close() 118 | if err != nil { 119 | return nil, err 120 | } 121 | return buf.Bytes(), nil 122 | } 123 | 124 | // Decompress decompress func 125 | func (cp *Compresser) Decompress(data []byte) ([]byte, error) { 126 | obj := cp.poolDecompresser.Get() 127 | if obj == nil { 128 | if cp.nd == nil { 129 | return nil, errNoDecompresser 130 | } 131 | var err error 132 | obj, err = cp.nd(bytes.NewReader(data)) 133 | if err != nil { 134 | return nil, err 135 | } 136 | } 137 | defer cp.poolDecompresser.Put(obj) 138 | r := obj.(decompresser) 139 | err := r.Reset(bytes.NewReader(data)) 140 | if err != nil { 141 | return nil, err 142 | } 143 | data, err = io.ReadAll(r) 144 | if err != nil { 145 | return nil, err 146 | } 147 | sum := binary.BigEndian.Uint32(data[len(data)-4:]) 148 | data = data[:len(data)-4] 149 | if crc32.ChecksumIEEE(data) != sum { 150 | return nil, errInvalidChecksum 151 | } 152 | return data, nil 153 | } 154 | 155 | // SetLevel set compress level 156 | func (cp *Compresser) SetLevel(level int) { 157 | cp.level = level 158 | } 159 | -------------------------------------------------------------------------------- /encoding/compress/gzip.go: -------------------------------------------------------------------------------- 1 | package compress 2 | 3 | import ( 4 | "compress/gzip" 5 | "io" 6 | ) 7 | 8 | func newGzipCompresser(level int) (compresser, error) { 9 | return gzip.NewWriterLevel(io.Discard, level) 10 | } 11 | 12 | func newGzipDecompresser(r io.Reader) (io.Reader, error) { 13 | return gzip.NewReader(r) 14 | } 15 | -------------------------------------------------------------------------------- /encoding/compress/zstd.go: -------------------------------------------------------------------------------- 1 | package compress 2 | 3 | import ( 4 | "io" 5 | 6 | "github.com/klauspost/compress/zstd" 7 | ) 8 | 9 | func newZstdCompresser(level int) (compresser, error) { 10 | return 
zstd.NewWriter(io.Discard, zstd.WithEncoderLevel(zstd.EncoderLevel(level))) 11 | } 12 | 13 | func newZstdDecompresser() (io.Reader, error) { 14 | return zstd.NewReader(nil) 15 | } 16 | -------------------------------------------------------------------------------- /encoding/encoding.go: -------------------------------------------------------------------------------- 1 | package encoding 2 | 3 | // Codec serializer 4 | type Codec interface { 5 | Marshal(any) ([]byte, error) 6 | Unmarshal([]byte, any) (int, error) 7 | } 8 | 9 | // Encrypter encrypter 10 | type Encrypter interface { 11 | Encrypt([]byte) ([]byte, error) 12 | Decrypt([]byte) ([]byte, error) 13 | } 14 | 15 | // Compresser compresser 16 | type Compresser interface { 17 | Compress([]byte) ([]byte, error) 18 | Decompress([]byte) ([]byte, error) 19 | } 20 | -------------------------------------------------------------------------------- /encoding/encrypt/encrypter.go: -------------------------------------------------------------------------------- 1 | package encrypt 2 | 3 | import ( 4 | "bytes" 5 | "crypto/aes" 6 | "crypto/cipher" 7 | "crypto/des" 8 | "encoding/binary" 9 | "errors" 10 | "hash/crc32" 11 | ) 12 | 13 | var errInvalidChecksum = errors.New("encrypt: invalid checksum") 14 | var errInvalidBlockSize = errors.New("encrypt: invalid block size") 15 | 16 | // Method encrypt method 17 | type Method byte 18 | 19 | const ( 20 | // Aes aes method 21 | Aes Method = iota 22 | // Des des method 23 | Des 24 | ) 25 | 26 | type padFunc func([]byte) []byte 27 | 28 | // Encrypter encrypter 29 | type Encrypter struct { 30 | block cipher.Block 31 | iv []byte 32 | pad padFunc 33 | unpad padFunc 34 | } 35 | 36 | func makePad(size int) padFunc { 37 | return func(p []byte) []byte { 38 | if len(p) == 0 { 39 | return p 40 | } 41 | padSize := size - (len(p) % size) 42 | pad := bytes.Repeat([]byte{byte(padSize)}, padSize) 43 | return append(p, pad...) 
44 | } 45 | } 46 | 47 | // unpad strips the padding added by makePad; the final byte holds the pad 48 | // length. A zero or oversized length (corrupt or forged ciphertext) yields 49 | // nil so the caller rejects the message instead of panicking on the slice. 50 | func unpad(p []byte) []byte { 51 | if len(p) == 0 { 52 | return nil 53 | } 54 | padSize := int(p[len(p)-1]) 55 | if padSize == 0 || padSize > len(p) { 56 | return nil 57 | } 58 | return p[:len(p)-padSize] 59 | } 60 | 61 | // repeat doubles str until it is at least limit bytes long. An empty str is 62 | // returned unchanged, since doubling it would never reach limit. 63 | func repeat(str string, limit int) string { 64 | if str == "" { 65 | return str 66 | } 67 | for len(str) < limit { 68 | str += str 69 | } 70 | return str 71 | } 72 | 73 | // New creates an Encrypter for method m. The cipher key and the CBC IV are 74 | // both taken from key, repeated until it is long enough. It returns nil when 75 | // key is empty, m is not a known Method, or the cipher cannot be built. 76 | // 77 | // NOTE(review): the IV is fixed for a given key and integrity relies on a 78 | // CRC32 rather than a MAC, so identical leading plaintext blocks encrypt 79 | // identically and messages are not authenticated. 80 | func New(m Method, key string) *Encrypter { 81 | if key == "" { 82 | return nil 83 | } 84 | var block cipher.Block 85 | var iv []byte 86 | var err error 87 | var pad padFunc 88 | switch m { 89 | case Aes: 90 | key = repeat(key, 32+aes.BlockSize) 91 | block, err = aes.NewCipher([]byte(key[:32])) 92 | if err != nil { 93 | return nil 94 | } 95 | iv = []byte(key[32 : 32+aes.BlockSize]) 96 | pad = makePad(aes.BlockSize) 97 | case Des: 98 | key = repeat(key, 24+des.BlockSize) 99 | block, err = des.NewTripleDESCipher([]byte(key[:24])) 100 | if err != nil { 101 | return nil 102 | } 103 | iv = []byte(key[24 : 24+des.BlockSize]) 104 | pad = makePad(des.BlockSize) 105 | default: 106 | return nil 107 | } 108 | return &Encrypter{ 109 | block: block, 110 | iv: iv, 111 | pad: pad, 112 | unpad: unpad, 113 | } 114 | } 115 | 116 | // Encrypt appends a big-endian CRC32 (IEEE) of src, pads the result to the 117 | // cipher block size and CBC-encrypts it into a newly allocated slice. 118 | func (enc *Encrypter) Encrypt(src []byte) ([]byte, error) { 119 | bm := cipher.NewCBCEncrypter(enc.block, enc.iv) 120 | // Build into a private buffer: appending to src directly could overwrite 121 | // bytes in the caller's backing array beyond len(src). 122 | buf := make([]byte, 0, len(src)+4+bm.BlockSize()) 123 | buf = append(buf, src...) 124 | buf = binary.BigEndian.AppendUint32(buf, crc32.ChecksumIEEE(src)) 125 | buf = enc.pad(buf) 126 | dst := make([]byte, len(buf)) 127 | bm.CryptBlocks(dst, buf) 128 | return dst, nil 129 | } 130 | 131 | // Decrypt reverses Encrypt: it CBC-decrypts src, strips the padding and the 132 | // trailing CRC32, and verifies the checksum. Empty input is returned as-is; 133 | // malformed input yields errInvalidBlockSize or errInvalidChecksum. 134 | func (enc *Encrypter) Decrypt(src []byte) ([]byte, error) { 135 | if len(src) == 0 { 136 | return src, nil 137 | } 138 | bm := cipher.NewCBCDecrypter(enc.block, enc.iv) 139 | if len(src)%bm.BlockSize() != 0 { 140 | return nil, errInvalidBlockSize 141 | } 142 | dst := make([]byte, len(src)) 143 | bm.CryptBlocks(dst, src) 144 | dst = enc.unpad(dst) 145 | // Invalid padding leaves nil; anything shorter than the 4-byte checksum 146 | // cannot have been produced by Encrypt. 147 | if len(dst) < 4 { 148 | return nil, errInvalidChecksum 149 | } 150 | sum := binary.BigEndian.Uint32(dst[len(dst)-4:]) 151 | dst = dst[:len(dst)-4] 152 | if crc32.ChecksumIEEE(dst) != sum { 153 | return nil, errInvalidChecksum 154 | } 155 | return dst, nil 156 | } 157 |
-------------------------------------------------------------------------------- /example/consts.go: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | // Key encrypt key 4 | const Key = "crpc encrypt key" 5 | 6 | // Listen listen address 7 | const Listen = "localhost:8080" 8 | -------------------------------------------------------------------------------- /example/ping/client/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "log" 7 | "net/http" 8 | "time" 9 | 10 | "github.com/lwch/crpc" 11 | "github.com/lwch/crpc/encoding/compress" 12 | "github.com/lwch/crpc/encoding/encrypt" 13 | "github.com/lwch/crpc/example" 14 | ) 15 | 16 | func assert(err error) { 17 | if err != nil { 18 | panic(err) 19 | } 20 | } 21 | 22 | func main() { 23 | cli, err := crpc.NewClient(example.Listen) 24 | assert(err) 25 | defer cli.Close() 26 | cli.SetEncrypter(encrypt.New(encrypt.Aes, example.Key)) 27 | cli.SetCompresser(compress.New(compress.Gzip)) 28 | for { 29 | func() { 30 | req, err := http.NewRequest(http.MethodGet, "http://localhost/ping", nil) 31 | assert(err) 32 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 33 | defer cancel() 34 | rep, err := cli.Call(ctx, req) 35 | assert(err) 36 | data, err := io.ReadAll(rep.Body) 37 | assert(err) 38 | log.Printf("status_code=%d, data=%s", rep.StatusCode, string(data)) 39 | }() 40 | time.Sleep(time.Second) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /example/ping/server/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io" 5 | "log" 6 | "net/http" 7 | "strings" 8 | 9 | "github.com/lwch/crpc" 10 | "github.com/lwch/crpc/encoding/compress" 11 | "github.com/lwch/crpc/encoding/encrypt" 12 | "github.com/lwch/crpc/example" 13 | ) 14 | 
15 | func main() { 16 | svr := crpc.NewServer(crpc.ServerConfig{ 17 | Encrypter: encrypt.New(encrypt.Aes, example.Key), 18 | Compresser: compress.New(compress.Gzip), 19 | OnRequest: func(r *http.Request) (*http.Response, error) { 20 | log.Println("ping recved") 21 | return &http.Response{ 22 | StatusCode: http.StatusOK, 23 | Body: io.NopCloser(strings.NewReader("pong")), 24 | }, nil 25 | }, 26 | }) 27 | defer svr.Close() 28 | svr.ListenAndServe(example.Listen) 29 | } 30 | -------------------------------------------------------------------------------- /example/streaming/client/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "time" 7 | 8 | "github.com/lwch/crpc" 9 | "github.com/lwch/crpc/encoding/compress" 10 | "github.com/lwch/crpc/encoding/encrypt" 11 | "github.com/lwch/crpc/example" 12 | ) 13 | 14 | func assert(err error) { 15 | if err != nil { 16 | panic(err) 17 | } 18 | } 19 | 20 | func main() { 21 | cli, err := crpc.NewClient(example.Listen) 22 | assert(err) 23 | defer cli.Close() 24 | cli.SetEncrypter(encrypt.New(encrypt.Aes, example.Key)) 25 | cli.SetCompresser(compress.New(compress.Gzip)) 26 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 27 | defer cancel() 28 | s, err := cli.OpenStream(ctx) 29 | assert(err) 30 | defer s.Close() 31 | buf := make([]byte, 1024) 32 | for { 33 | _, err := s.Write([]byte("ping")) 34 | if err != nil { 35 | return 36 | } 37 | n, err := s.Read(buf) 38 | if err != nil { 39 | return 40 | } 41 | log.Println(string(buf[:n])) 42 | time.Sleep(time.Second) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /example/streaming/server/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/lwch/crpc" 7 | "github.com/lwch/crpc/encoding/compress" 8 | 
"github.com/lwch/crpc/encoding/encrypt" 9 | "github.com/lwch/crpc/example" 10 | ) 11 | 12 | func main() { 13 | svr := crpc.NewServer(crpc.ServerConfig{ 14 | Encrypter: encrypt.New(encrypt.Aes, example.Key), 15 | Compresser: compress.New(compress.Gzip), 16 | OnAccept: func(s *crpc.Stream) { 17 | defer s.Close() 18 | buf := make([]byte, 1024) 19 | for { 20 | n, err := s.Read(buf) 21 | if err != nil { 22 | return 23 | } 24 | log.Printf("%s recved", string(buf[:n])) 25 | _, err = s.Write([]byte("pong")) 26 | if err != nil { 27 | return 28 | } 29 | } 30 | }, 31 | }) 32 | defer svr.Close() 33 | svr.ListenAndServe(example.Listen) 34 | } 35 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/lwch/crpc 2 | 3 | go 1.22 4 | 5 | require ( 6 | github.com/klauspost/compress v1.18.0 7 | github.com/lwch/logging v1.1.3 8 | google.golang.org/protobuf v1.36.6 9 | ) 10 | 11 | require ( 12 | github.com/dustin/go-humanize v1.0.1 // indirect 13 | github.com/lwch/runtime v1.0.1 // indirect 14 | gopkg.in/yaml.v3 v3.0.1 // indirect 15 | ) 16 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= 2 | github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= 3 | github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= 4 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 5 | github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= 6 | github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= 7 | github.com/lwch/logging v1.1.3 h1:fYppr1h3TxSwH4WNHeO5g1gt1KDg4omd0rYw4ttddcU= 8 | 
github.com/lwch/logging v1.1.3/go.mod h1:MxaC1CKm3o5EZcgRvPMXxS2ogwXTEPeo/3SCBiTqn3o= 9 | github.com/lwch/runtime v1.0.1 h1:xfurs9IGzkTWfdum1K5GaDkEuGopohOQhk7roELmbf4= 10 | github.com/lwch/runtime v1.0.1/go.mod h1:mJuSABS7wUvRK3rUyV624NZ3+rV5eiAhsGEuU1iNrtk= 11 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= 12 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 13 | google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= 14 | google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= 15 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 16 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 17 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 18 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 19 | -------------------------------------------------------------------------------- /header.go: -------------------------------------------------------------------------------- 1 | package crpc 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | ) 7 | 8 | const keyRequestID = "X-Crpc-Request-Id" 9 | 10 | func (tp *transport) buildRequest(req *http.Request) ([]byte, uint64, error) { 11 | if req.Header == nil { 12 | req.Header = make(http.Header) 13 | } 14 | seq := tp.sequence.Add(1) 15 | req.Header.Set(keyRequestID, fmt.Sprintf("%d", seq)) 16 | payload, err := tp.codec.Marshal(req) 17 | if err != nil { 18 | return nil, 0, err 19 | } 20 | if tp.compresser != nil { 21 | payload, err = tp.compresser.Compress(payload) 22 | if err != nil { 23 | return nil, 0, err 24 | } 25 | } 26 | if tp.encrypter != nil { 27 | payload, err = tp.encrypter.Encrypt(payload) 28 | if err != nil { 29 | return nil, 0, err 30 | } 31 | } 32 | return payload, seq, nil 33 | 
} 34 | 35 | func (tp *transport) buildResponse(rep *http.Response, reqID uint64) ([]byte, error) { 36 | if rep.Header == nil { 37 | rep.Header = make(http.Header) 38 | } 39 | rep.Header.Set(keyRequestID, fmt.Sprintf("%d", reqID)) 40 | payload, err := tp.codec.Marshal(rep) 41 | if err != nil { 42 | return nil, err 43 | } 44 | if tp.compresser != nil { 45 | payload, err = tp.compresser.Compress(payload) 46 | if err != nil { 47 | return nil, err 48 | } 49 | } 50 | if tp.encrypter != nil { 51 | payload, err = tp.encrypter.Encrypt(payload) 52 | if err != nil { 53 | return nil, err 54 | } 55 | } 56 | return payload, nil 57 | } 58 | 59 | func (tp *transport) decode(data []byte) (any, error) { 60 | if tp.encrypter != nil { 61 | var err error 62 | data, err = tp.encrypter.Decrypt(data) 63 | if err != nil { 64 | return nil, err 65 | } 66 | } 67 | if tp.compresser != nil { 68 | var err error 69 | data, err = tp.compresser.Decompress(data) 70 | if err != nil { 71 | return nil, err 72 | } 73 | } 74 | var value any 75 | _, err := tp.codec.Unmarshal(data, &value) 76 | if err != nil { 77 | return nil, err 78 | } 79 | return value, nil 80 | } 81 | -------------------------------------------------------------------------------- /internal/join/bytes.go: -------------------------------------------------------------------------------- 1 | package join 2 | 3 | import "bytes" 4 | 5 | // BytesBuffer bytes buffer 6 | type BytesBuffer struct { 7 | bytes.Buffer 8 | } 9 | 10 | // Marshal marshal bytes 11 | func (b *BytesBuffer) Marshal() []byte { 12 | return b.Bytes() 13 | } 14 | 15 | // Bytes bytes 16 | type Bytes []byte 17 | 18 | // Marshal marshal bytes 19 | func (b Bytes) Marshal() []byte { 20 | return b 21 | } 22 | -------------------------------------------------------------------------------- /internal/join/join.go: -------------------------------------------------------------------------------- 1 | package join 2 | 3 | import "bytes" 4 | 5 | // Marshaler marshaler 6 | type Marshaler 
interface { 7 | Marshal() []byte 8 | } 9 | 10 | // Joiner joiner 11 | type Joiner struct { 12 | buf bytes.Buffer 13 | header Marshaler 14 | payload Marshaler 15 | } 16 | 17 | // New create joiner 18 | func New() *Joiner { 19 | return &Joiner{} 20 | } 21 | 22 | // SetHeader set header 23 | func (j *Joiner) SetHeader(header Marshaler) { 24 | j.header = header 25 | } 26 | 27 | // SetPayload set payload 28 | func (j *Joiner) SetPayload(body Marshaler) { 29 | j.payload = body 30 | } 31 | 32 | // Marshal marshal 33 | func (j *Joiner) Marshal() ([]byte, error) { 34 | j.buf.Reset() 35 | if _, err := j.buf.Write(j.header.Marshal()); err != nil { 36 | return nil, err 37 | } 38 | if _, err := j.buf.Write(j.payload.Marshal()); err != nil { 39 | return nil, err 40 | } 41 | cpy := make([]byte, j.buf.Len()) 42 | copy(cpy, j.buf.Bytes()) 43 | return cpy, nil 44 | } 45 | -------------------------------------------------------------------------------- /internal/utils/bytes_writer.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | // BytesWriter is a simple io.Writer implementation that writes to a []byte. 4 | type BytesWriter []byte 5 | 6 | // Write implements io.Writer. 
7 | func (w BytesWriter) Write(p []byte) (int, error) { 8 | return copy(w, p), nil 9 | } 10 | -------------------------------------------------------------------------------- /network/conn.go: -------------------------------------------------------------------------------- 1 | package network 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/binary" 7 | "errors" 8 | "fmt" 9 | "hash/crc32" 10 | "io" 11 | "math" 12 | "net" 13 | "sync" 14 | "sync/atomic" 15 | 16 | "github.com/lwch/logging" 17 | ) 18 | 19 | var errTooLarge = errors.New("network: too large") 20 | var errBufferTooShort = errors.New("network: buffer too short") 21 | var errInvalidPacketChecksum = errors.New("network: invalid packet checksum") 22 | var errOpenStreamDone = errors.New("network: open stream done") 23 | var errStreamNotFound = errors.New("network: stream not found") 24 | 25 | const ( 26 | flagData = 0 27 | flagStreamOpen = 1 << 31 // 32位表示open请求 28 | flagStreamOpenAck = 1 << 30 // 31位表示open成功 29 | flagStreamClose = 1 << 29 // 30位表示关闭请求 30 | flagStreamData = 1 << 28 // 29位表示数据传输 31 | flagPing = 1 << 27 // 28位表示ping请求 32 | flagPong = 1 << 26 // 27位表示pong响应 33 | ) 34 | 35 | type writeArgs struct { 36 | flag uint32 37 | data []byte 38 | } 39 | 40 | type writeControlArgs struct { 41 | id uint32 42 | flag uint32 43 | } 44 | 45 | // Conn connection 46 | type Conn struct { 47 | conn net.Conn 48 | mRead sync.Mutex 49 | sequence atomic.Uint64 50 | streamID atomic.Uint32 51 | chRead chan []byte 52 | chWrite chan writeArgs 53 | chWriteControl chan writeControlArgs 54 | // stream 55 | streams map[uint32]*Stream 56 | mStreams sync.RWMutex 57 | chStreamOpened chan *Stream 58 | // runtime 59 | err error 60 | ctx context.Context 61 | } 62 | 63 | // 封包格式 64 | // +-------------+---------+----------+---------+---------+ 65 | // | Sequence(4) | Size(2) | Crc32(4) | Flag(4) | Payload | 66 | // +-------------+---------+----------+---------+---------+ 67 | // Flag字段格式 68 | // 
+---------+------------+----------+---------+---------+---------+-----------+---------------+ 69 | // | Open(1) | OpenAck(1) | Close(1) | Data(1) | Ping(1) | Pong(1) | Unused(2) | Stream ID(24) | 70 | // +---------+------------+----------+---------+---------+---------+-----------+---------------+ 71 | // 高6位为标志位,后2位暂未使用,低24位为stream id 72 | // Stream ID由Accept方进行分配,在Open请求中Stream ID为0 73 | 74 | type header struct { 75 | Sequence uint64 76 | Size uint16 77 | Crc32 uint32 78 | Flag uint32 79 | } 80 | 81 | // New new connection 82 | func New(conn net.Conn) *Conn { 83 | ctx, cancel := context.WithCancel(context.Background()) 84 | ret := &Conn{ 85 | conn: conn, 86 | chRead: make(chan []byte, 10000), 87 | chWrite: make(chan writeArgs, 10000), 88 | chWriteControl: make(chan writeControlArgs, 100), 89 | streams: make(map[uint32]*Stream), 90 | chStreamOpened: make(chan *Stream), 91 | ctx: ctx, 92 | } 93 | go ret.loopRead(cancel) 94 | go ret.loopWrite(cancel) 95 | return ret 96 | } 97 | 98 | // Close close connection 99 | func (c *Conn) Close() error { 100 | logging.Error("connection closed: %s", c.conn.RemoteAddr().String()) 101 | return c.conn.Close() 102 | } 103 | 104 | // AcceptStream accept stream 105 | func (c *Conn) AcceptStream() (*Stream, error) { 106 | select { 107 | case <-c.ctx.Done(): 108 | return nil, c.err 109 | case stream := <-c.chStreamOpened: 110 | return stream, nil 111 | } 112 | } 113 | 114 | // OpenStream open stream 115 | func (c *Conn) OpenStream(ctx context.Context) (*Stream, error) { 116 | c.chWriteControl <- writeControlArgs{ 117 | id: 0, 118 | flag: flagStreamOpen, 119 | } 120 | select { 121 | case <-ctx.Done(): 122 | return nil, errOpenStreamDone 123 | case <-c.ctx.Done(): 124 | return nil, c.err 125 | case stream := <-c.chStreamOpened: 126 | return stream, nil 127 | } 128 | } 129 | 130 | // Write send data 131 | func (c *Conn) Write(p []byte) (int, error) { 132 | if len(p) > math.MaxUint16 { 133 | return 0, errTooLarge 134 | } 135 | data := 
make([]byte, len(p)) 136 | copy(data, p) 137 | c.chWrite <- writeArgs{ 138 | flag: flagData, 139 | data: data, 140 | } 141 | return len(p), nil 142 | } 143 | 144 | // Read read data 145 | func (c *Conn) Read(p []byte) (int, error) { 146 | select { 147 | case <-c.ctx.Done(): 148 | return 0, c.err 149 | case data := <-c.chRead: 150 | if len(p) < len(data) { 151 | return 0, errBufferTooShort 152 | } 153 | return copy(p, data), nil 154 | } 155 | } 156 | 157 | func (c *Conn) read(p []byte) (*header, int, error) { 158 | c.mRead.Lock() 159 | defer c.mRead.Unlock() 160 | var hdr header 161 | err := binary.Read(c.conn, binary.BigEndian, &hdr) 162 | if err != nil { 163 | return nil, 0, fmt.Errorf("network: read packet header: %v", err) 164 | } 165 | if len(p) < int(hdr.Size) { 166 | return nil, 0, errBufferTooShort 167 | } 168 | if hdr.Size == 0 { 169 | return &hdr, 0, nil 170 | } 171 | n, err := io.ReadFull(c.conn, p[:hdr.Size]) 172 | if err != nil { 173 | return nil, 0, fmt.Errorf("network: read packet payload[%d]: %v", hdr.Sequence, err) 174 | } 175 | if crc32.ChecksumIEEE(p[:hdr.Size]) != hdr.Crc32 { 176 | return nil, 0, errInvalidPacketChecksum 177 | } 178 | return &hdr, n, nil 179 | } 180 | 181 | func dup(data []byte) []byte { 182 | ret := make([]byte, len(data)) 183 | copy(ret, data) 184 | return ret 185 | } 186 | 187 | func (c *Conn) onClose(err error) { 188 | logging.Error("connection closed: %s", c.conn.RemoteAddr().String()) 189 | c.conn.Close() 190 | var streams []*Stream 191 | c.mStreams.RLock() 192 | for _, stream := range c.streams { 193 | streams = append(streams, stream) 194 | } 195 | c.mStreams.RUnlock() 196 | for _, stream := range streams { 197 | stream.onClose(err) 198 | } 199 | } 200 | 201 | func (c *Conn) loopRead(cancel context.CancelFunc) { 202 | var err error 203 | defer func() { 204 | c.err = err 205 | cancel() 206 | }() 207 | defer func() { c.onClose(err) }() // read err at exit, not at defer time (when it is still nil) 208 | buf := make([]byte, math.MaxUint16) 209 | for { 210 | var hdr *header 211 | var n int 212 | hdr,
n, err = c.read(buf) 213 | if err != nil { 214 | logging.Error("loop read => %s: %v", c.conn.RemoteAddr().String(), err) 215 | return 216 | } 217 | if hdr.Flag&flagPing != 0 { 218 | err = c.handlePing() 219 | if err != nil { 220 | logging.Error("handle ping => %s: %v", c.conn.RemoteAddr().String(), err) 221 | return 222 | } 223 | continue 224 | } 225 | if hdr.Flag&flagStreamOpen != 0 { 226 | err = c.handleOpenStream() 227 | if err != nil { 228 | logging.Error("handle open stream => %s: %v", c.conn.RemoteAddr().String(), err) 229 | return 230 | } 231 | continue 232 | } 233 | if hdr.Flag&flagStreamOpenAck != 0 { 234 | err = c.handleOpenStreamAck(hdr.Flag) 235 | if err != nil { 236 | logging.Error("handle open stream ack => %s: %v", c.conn.RemoteAddr().String(), err) 237 | return 238 | } 239 | continue 240 | } 241 | if hdr.Flag&flagStreamClose != 0 { 242 | err = c.handleCloseStream(hdr.Flag) 243 | if err != nil { 244 | if err == errStreamNotFound { 245 | continue 246 | } 247 | logging.Error("handle close stream => %s: %v", c.conn.RemoteAddr().String(), err) 248 | return 249 | } 250 | continue 251 | } 252 | if hdr.Flag&flagStreamData != 0 { 253 | err = c.handleStreamData(hdr.Flag, buf[:n]) 254 | if err != nil { 255 | if err == errStreamNotFound { 256 | continue 257 | } 258 | logging.Error("handle data => %s: %v", c.conn.RemoteAddr().String(), err) 259 | return 260 | } 261 | continue 262 | } 263 | if hdr.Size == 0 { 264 | continue 265 | } 266 | c.chRead <- dup(buf[:n]) 267 | } 268 | } 269 | 270 | func (c *Conn) loopWrite(cancel context.CancelFunc) { 271 | var err error 272 | defer func() { 273 | c.err = err 274 | cancel() 275 | }() 276 | defer func() { c.onClose(err) }() // read err at exit, not at defer time (when it is still nil) 277 | for { 278 | select { 279 | case args := <-c.chWrite: 280 | err = c.writeData(args.flag, args.data) // assign (not :=) so the deferred handlers see the failure 281 | if err != nil { 282 | logging.Error("network: %v", err) 283 | return 284 | } 285 | case ctrl := <-c.chWriteControl: 286 | err = c.writeControl(ctrl.id | ctrl.flag) // assign (not :=) so the deferred handlers see the failure 287 | if err != nil { 288 |
logging.Error("network: %v", err) 289 | return 290 | } 291 | case <-c.ctx.Done(): 292 | err = c.ctx.Err() 293 | return 294 | } 295 | } 296 | } 297 | 298 | func (c *Conn) writeData(flag uint32, p []byte) error { 299 | var buf bytes.Buffer 300 | sequence := c.sequence.Add(1) 301 | err := binary.Write(&buf, binary.BigEndian, header{ 302 | Size: uint16(len(p)), 303 | Crc32: crc32.ChecksumIEEE(p), 304 | Sequence: sequence, 305 | Flag: flag, 306 | }) 307 | if err != nil { 308 | return fmt.Errorf("build packet header[%d]: %v", sequence, err) 309 | } 310 | _, err = io.Copy(&buf, bytes.NewReader(p)) 311 | if err != nil { 312 | return fmt.Errorf("build packet payload[%d]: %v", sequence, err) 313 | } 314 | _, err = c.conn.Write(buf.Bytes()) 315 | if err != nil { 316 | return fmt.Errorf("write packet[%d]: %v", sequence, err) 317 | } 318 | return nil 319 | } 320 | 321 | func (c *Conn) writeControl(flag uint32) error { 322 | sequence := c.sequence.Add(1) 323 | err := binary.Write(c.conn, binary.BigEndian, header{ 324 | Size: 0, 325 | Crc32: 0, 326 | Sequence: sequence, 327 | Flag: flag, 328 | }) 329 | if err != nil { 330 | return fmt.Errorf("send openstream: %v", err) 331 | } 332 | return nil 333 | } 334 | 335 | func (c *Conn) getStream(stream uint32) *Stream { 336 | stream = stream & 0xffffff 337 | c.mStreams.RLock() 338 | defer c.mStreams.RUnlock() 339 | return c.streams[stream] 340 | } 341 | 342 | // SendKeepalive send keepalive packet 343 | func (c *Conn) SendKeepalive() error { 344 | var buf bytes.Buffer 345 | sequence := c.sequence.Add(1) 346 | err := binary.Write(&buf, binary.BigEndian, header{ 347 | Size: 0, 348 | Crc32: 0, 349 | Sequence: sequence, 350 | Flag: flagPing, 351 | }) 352 | if err != nil { 353 | return fmt.Errorf("network: build ping packet header[%d]: %v", sequence, err) 354 | } 355 | _, err = c.conn.Write(buf.Bytes()) 356 | if err != nil { 357 | return fmt.Errorf("network: write ping packet[%d]: %v", sequence, err) 358 | } 359 | return nil 360 | } 361 | 362 
| func (c *Conn) handlePing() error { 363 | var buf bytes.Buffer 364 | sequence := c.sequence.Add(1) 365 | err := binary.Write(&buf, binary.BigEndian, header{ 366 | Size: 0, 367 | Crc32: 0, 368 | Sequence: sequence, 369 | Flag: flagPong, 370 | }) 371 | if err != nil { 372 | return fmt.Errorf("network: build pong packet header[%d]: %v", sequence, err) 373 | } 374 | _, err = c.conn.Write(buf.Bytes()) 375 | if err != nil { 376 | return fmt.Errorf("network: write pong packet[%d]: %v", sequence, err) 377 | } 378 | return nil 379 | } 380 | -------------------------------------------------------------------------------- /network/stream.go: -------------------------------------------------------------------------------- 1 | package network 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "math" 7 | "sync/atomic" 8 | ) 9 | 10 | // ErrStreamClosed stream closed error 11 | var ErrStreamClosed = errors.New("network: stream closed") 12 | 13 | // ErrClosedByRemote closed by remote error 14 | var ErrClosedByRemote = errors.New("network: closed by remote") 15 | 16 | // Stream stream 17 | type Stream struct { 18 | parent *Conn 19 | id uint32 20 | closed atomic.Bool 21 | chRead chan []byte 22 | // runtime 23 | err error 24 | ctx context.Context 25 | cancel context.CancelFunc 26 | } 27 | 28 | func newStream(parent *Conn, id uint32) *Stream { 29 | ctx, cancel := context.WithCancel(context.Background()) 30 | return &Stream{ 31 | parent: parent, 32 | id: id & 0xffffff, 33 | chRead: make(chan []byte, 1000), 34 | ctx: ctx, 35 | cancel: cancel, 36 | } 37 | } 38 | 39 | // ID get stream id 40 | func (s *Stream) ID() uint32 { 41 | return s.id & 0xffffff 42 | } 43 | 44 | // Close close stream 45 | func (s *Stream) Close() error { 46 | s.onClose(nil) 47 | return nil 48 | } 49 | 50 | func (s *Stream) onClose(err error) { 51 | s.closed.Store(true) 52 | s.err = err 53 | s.cancel() 54 | s.parent.chWriteControl <- writeControlArgs{ 55 | id: s.ID(), 56 | flag: flagStreamClose, 57 | } 58 | 
s.parent.mStreams.Lock() 59 | delete(s.parent.streams, s.ID()) 60 | s.parent.mStreams.Unlock() 61 | } 62 | 63 | // Read read data 64 | func (s *Stream) Read(p []byte) (int, error) { 65 | if s.closed.Load() { 66 | return 0, ErrStreamClosed 67 | } 68 | select { 69 | case data := <-s.chRead: 70 | if len(data) > len(p) { 71 | return 0, errBufferTooShort 72 | } 73 | return copy(p, data), nil 74 | case <-s.ctx.Done(): 75 | return 0, s.err 76 | } 77 | } 78 | 79 | // Write write data 80 | func (s *Stream) Write(p []byte) (int, error) { 81 | if s.closed.Load() { 82 | return 0, ErrStreamClosed 83 | } 84 | if len(p) > math.MaxUint16 { 85 | return 0, errTooLarge 86 | } 87 | data := make([]byte, len(p)) 88 | copy(data, p) 89 | s.parent.chWrite <- writeArgs{ 90 | flag: s.ID() | flagStreamData, 91 | data: data, 92 | } 93 | return len(p), nil 94 | } 95 | 96 | func (c *Conn) handleOpenStream() error { 97 | stream := newStream(c, c.streamID.Add(1)) 98 | c.chWriteControl <- writeControlArgs{ 99 | id: stream.id, 100 | flag: flagStreamOpenAck, 101 | } 102 | c.mStreams.Lock() 103 | c.streams[stream.id] = stream 104 | c.mStreams.Unlock() 105 | c.chStreamOpened <- stream 106 | return nil 107 | } 108 | 109 | func (c *Conn) handleOpenStreamAck(flag uint32) error { 110 | s := newStream(c, flag) 111 | c.mStreams.Lock() 112 | c.streams[s.id] = s 113 | c.mStreams.Unlock() 114 | c.chStreamOpened <- s 115 | return nil 116 | } 117 | 118 | func (c *Conn) handleCloseStream(flag uint32) error { 119 | s := c.getStream(flag) 120 | if s == nil { 121 | return errStreamNotFound 122 | } 123 | s.onClose(ErrClosedByRemote) 124 | return nil 125 | } 126 | 127 | func (c *Conn) handleStreamData(flag uint32, data []byte) error { 128 | s := c.getStream(flag) 129 | if s == nil { 130 | return errStreamNotFound 131 | } 132 | s.chRead <- dup(data) 133 | return nil 134 | } 135 | -------------------------------------------------------------------------------- /server.go: 
-------------------------------------------------------------------------------- 1 | package crpc 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | "time" 7 | 8 | "github.com/lwch/crpc/encoding" 9 | "github.com/lwch/logging" 10 | ) 11 | 12 | // AcceptStreamHandlerFunc handles a stream opened by a client; the server 13 | // runs it in its own goroutine for every accepted stream. 14 | type AcceptStreamHandlerFunc func(*Stream) 15 | 16 | // Server is an RPC server: it accepts TCP connections and serves requests 17 | // and streams over each of them. 18 | type Server struct { 19 | listener net.Listener 20 | encrypter encoding.Encrypter 21 | compresser encoding.Compresser 22 | onRequest RequestHandlerFunc 23 | onAcceptStream AcceptStreamHandlerFunc 24 | } 25 | 26 | // ServerConfig configures a Server created by NewServer. 27 | type ServerConfig struct { 28 | Encrypter encoding.Encrypter 29 | Compresser encoding.Compresser 30 | OnRequest RequestHandlerFunc 31 | OnAccept AcceptStreamHandlerFunc // optional; incoming streams are closed when nil 32 | } 33 | 34 | // NewServer creates a Server from cfg; call ListenAndServe to start it. 35 | func NewServer(cfg ServerConfig) *Server { 36 | return &Server{ 37 | encrypter: cfg.Encrypter, 38 | compresser: cfg.Compresser, 39 | onRequest: cfg.OnRequest, 40 | onAcceptStream: cfg.OnAccept, 41 | } 42 | } 43 | 44 | // ListenAndServe listens on the TCP address addr and serves every accepted 45 | // connection in its own goroutine. It returns the listen error, or the accept 46 | // error (wrapping net.ErrClosed) once Close has been called. Other accept 47 | // failures are logged and retried with a capped exponential backoff. 48 | func (svr *Server) ListenAndServe(addr string) error { 49 | var err error 50 | svr.listener, err = net.Listen("tcp", addr) 51 | if err != nil { 52 | return err 53 | } 54 | defer svr.listener.Close() 55 | var delay time.Duration 56 | for { 57 | conn, err := svr.listener.Accept() 58 | if err != nil { 59 | if errors.Is(err, net.ErrClosed) { 60 | return err 61 | } 62 | // Retrying immediately would spin the CPU on persistent failures 63 | // (e.g. file descriptor exhaustion); back off from 5ms up to 1s. 64 | delay = min(max(2*delay, 5*time.Millisecond), time.Second) 65 | logging.Error("accept: %v, retrying in %v", err, delay) 66 | time.Sleep(delay) 67 | continue 68 | } 69 | delay = 0 70 | go svr.handle(conn) 71 | } 72 | } 73 | 74 | // Close closes the listener, which makes ListenAndServe return. It is a 75 | // no-op returning nil when no listener exists (e.g. net.Listen failed). 76 | func (svr *Server) Close() error { 77 | if svr.listener == nil { 78 | return nil 79 | } 80 | return svr.listener.Close() 81 | } 82 | 83 | // handle serves a single connection until tp.Serve returns. 84 | func (svr *Server) handle(conn net.Conn) { 85 | defer conn.Close() 86 | tp := new(conn) 87 | tp.SetEncrypter(svr.encrypter) 88 | tp.SetCompresser(svr.compresser) 89 | defer tp.Close() 90 | tp.SetOnRequest(svr.onRequest) 91 | go svr.acceptStream(tp) 92 | tp.Serve() 93 | } 94 | 95 | // acceptStream dispatches every stream opened on tp to the OnAccept handler 96 | // until tp.AcceptStream returns an error. 97 | func (svr *Server) acceptStream(tp *transport) { 98 | for { 99 | stream, err := tp.AcceptStream() 100 | if err != nil { 101 | logging.Error("accept stream: %v", err) 102 | return 103 | } 104 | if svr.onAcceptStream == nil { 105 | // No handler configured: refuse the stream instead of calling a nil 106 | // func, which would panic and take down the whole process. 107 | stream.Close() 108 | continue 109 | } 110 | go
svr.onAcceptStream(stream) 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /stream.go: -------------------------------------------------------------------------------- 1 | package crpc 2 | 3 | import ( 4 | "github.com/lwch/crpc/network" 5 | ) 6 | 7 | // Stream stream 8 | type Stream struct { 9 | parent *transport 10 | s *network.Stream 11 | } 12 | 13 | // Close close stream 14 | func (s *Stream) Close() error { 15 | return s.s.Close() 16 | } 17 | 18 | // Write write data in stream 19 | func (s *Stream) Write(p []byte) (int, error) { 20 | data, err := s.parent.codec.Marshal(p) 21 | if err != nil { 22 | return 0, err 23 | } 24 | if s.parent.compresser != nil { 25 | data, err = s.parent.compresser.Compress(data) 26 | if err != nil { 27 | return 0, err 28 | } 29 | } 30 | if s.parent.encrypter != nil { 31 | data, err = s.parent.encrypter.Encrypt(data) 32 | if err != nil { 33 | return 0, err 34 | } 35 | } 36 | _, err = s.s.Write(data) 37 | if err != nil { 38 | return 0, err 39 | } 40 | return len(p), nil 41 | } 42 | 43 | // Read read data from stream 44 | func (s *Stream) Read(p []byte) (int, error) { 45 | buf := make([]byte, 65535) 46 | n, err := s.s.Read(buf) 47 | if err != nil { 48 | return 0, err 49 | } 50 | if n == 0 { 51 | return 0, nil 52 | } 53 | buf = buf[:n] 54 | if s.parent.encrypter != nil { 55 | buf, err = s.parent.encrypter.Decrypt(buf) 56 | if err != nil { 57 | return 0, err 58 | } 59 | } 60 | if s.parent.compresser != nil { 61 | buf, err = s.parent.compresser.Decompress(buf) 62 | if err != nil { 63 | return 0, err 64 | } 65 | } 66 | n, err = s.parent.codec.Unmarshal(buf, &p) 67 | if err != nil { 68 | return 0, err 69 | } 70 | return n, nil 71 | } 72 | -------------------------------------------------------------------------------- /test_data/pb/msg.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. DO NOT EDIT.
2 | // versions: 3 | // protoc-gen-go v1.28.1 4 | // protoc v3.21.2 5 | // source: msg.proto 6 | 7 | package pb 8 | 9 | import ( 10 | protoreflect "google.golang.org/protobuf/reflect/protoreflect" 11 | protoimpl "google.golang.org/protobuf/runtime/protoimpl" 12 | reflect "reflect" 13 | sync "sync" 14 | ) 15 | 16 | const ( 17 | // Verify that this generated code is sufficiently up-to-date. 18 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) 19 | // Verify that runtime/protoimpl is sufficiently up-to-date. 20 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) 21 | ) 22 | 23 | type Request struct { 24 | state protoimpl.MessageState 25 | sizeCache protoimpl.SizeCache 26 | unknownFields protoimpl.UnknownFields 27 | 28 | Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` 29 | Uri string `protobuf:"bytes,2,opt,name=uri,proto3" json:"uri,omitempty"` 30 | Args map[string]string `protobuf:"bytes,3,rep,name=args,proto3" json:"args,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` 31 | } 32 | 33 | func (x *Request) Reset() { 34 | *x = Request{} 35 | if protoimpl.UnsafeEnabled { 36 | mi := &file_msg_proto_msgTypes[0] 37 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 38 | ms.StoreMessageInfo(mi) 39 | } 40 | } 41 | 42 | func (x *Request) String() string { 43 | return protoimpl.X.MessageStringOf(x) 44 | } 45 | 46 | func (*Request) ProtoMessage() {} 47 | 48 | func (x *Request) ProtoReflect() protoreflect.Message { 49 | mi := &file_msg_proto_msgTypes[0] 50 | if protoimpl.UnsafeEnabled && x != nil { 51 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 52 | if ms.LoadMessageInfo() == nil { 53 | ms.StoreMessageInfo(mi) 54 | } 55 | return ms 56 | } 57 | return mi.MessageOf(x) 58 | } 59 | 60 | // Deprecated: Use Request.ProtoReflect.Descriptor instead. 
61 | func (*Request) Descriptor() ([]byte, []int) { 62 | return file_msg_proto_rawDescGZIP(), []int{0} 63 | } 64 | 65 | func (x *Request) GetId() uint64 { 66 | if x != nil { 67 | return x.Id 68 | } 69 | return 0 70 | } 71 | 72 | func (x *Request) GetUri() string { 73 | if x != nil { 74 | return x.Uri 75 | } 76 | return "" 77 | } 78 | 79 | func (x *Request) GetArgs() map[string]string { 80 | if x != nil { 81 | return x.Args 82 | } 83 | return nil 84 | } 85 | 86 | var File_msg_proto protoreflect.FileDescriptor 87 | 88 | var file_msg_proto_rawDesc = []byte{ 89 | 0x0a, 0x09, 0x6d, 0x73, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x02, 0x70, 0x62, 0x22, 90 | 0x8f, 0x01, 0x0a, 0x07, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 91 | 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x69, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x75, 92 | 0x72, 0x69, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, 0x12, 0x29, 0x0a, 93 | 0x04, 0x61, 0x72, 0x67, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x70, 0x62, 94 | 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x41, 0x72, 0x67, 0x73, 0x45, 0x6e, 0x74, 95 | 0x72, 0x79, 0x52, 0x04, 0x61, 0x72, 0x67, 0x73, 0x1a, 0x37, 0x0a, 0x09, 0x41, 0x72, 0x67, 0x73, 96 | 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 97 | 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 98 | 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 99 | 0x01, 0x42, 0x07, 0x5a, 0x05, 0x2e, 0x2f, 0x3b, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 100 | 0x6f, 0x33, 101 | } 102 | 103 | var ( 104 | file_msg_proto_rawDescOnce sync.Once 105 | file_msg_proto_rawDescData = file_msg_proto_rawDesc 106 | ) 107 | 108 | func file_msg_proto_rawDescGZIP() []byte { 109 | file_msg_proto_rawDescOnce.Do(func() { 110 | file_msg_proto_rawDescData = 
protoimpl.X.CompressGZIP(file_msg_proto_rawDescData) 111 | }) 112 | return file_msg_proto_rawDescData 113 | } 114 | 115 | var file_msg_proto_msgTypes = make([]protoimpl.MessageInfo, 2) 116 | var file_msg_proto_goTypes = []interface{}{ 117 | (*Request)(nil), // 0: pb.Request 118 | nil, // 1: pb.Request.ArgsEntry 119 | } 120 | var file_msg_proto_depIdxs = []int32{ 121 | 1, // 0: pb.Request.args:type_name -> pb.Request.ArgsEntry 122 | 1, // [1:1] is the sub-list for method output_type 123 | 1, // [1:1] is the sub-list for method input_type 124 | 1, // [1:1] is the sub-list for extension type_name 125 | 1, // [1:1] is the sub-list for extension extendee 126 | 0, // [0:1] is the sub-list for field type_name 127 | } 128 | 129 | func init() { file_msg_proto_init() } 130 | func file_msg_proto_init() { 131 | if File_msg_proto != nil { 132 | return 133 | } 134 | if !protoimpl.UnsafeEnabled { 135 | file_msg_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { 136 | switch v := v.(*Request); i { 137 | case 0: 138 | return &v.state 139 | case 1: 140 | return &v.sizeCache 141 | case 2: 142 | return &v.unknownFields 143 | default: 144 | return nil 145 | } 146 | } 147 | } 148 | type x struct{} 149 | out := protoimpl.TypeBuilder{ 150 | File: protoimpl.DescBuilder{ 151 | GoPackagePath: reflect.TypeOf(x{}).PkgPath(), 152 | RawDescriptor: file_msg_proto_rawDesc, 153 | NumEnums: 0, 154 | NumMessages: 2, 155 | NumExtensions: 0, 156 | NumServices: 0, 157 | }, 158 | GoTypes: file_msg_proto_goTypes, 159 | DependencyIndexes: file_msg_proto_depIdxs, 160 | MessageInfos: file_msg_proto_msgTypes, 161 | }.Build() 162 | File_msg_proto = out.File 163 | file_msg_proto_rawDesc = nil 164 | file_msg_proto_goTypes = nil 165 | file_msg_proto_depIdxs = nil 166 | } 167 | -------------------------------------------------------------------------------- /test_data/pb/msg.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | // msg.pb.go is generated from this file 
by protoc-gen-go; edit here and regenerate instead of editing the .pb.go. 3 |
package pb; 4 | option go_package="./;pb"; 5 | 6 | message Request { 7 | uint64 id = 1; 8 | string uri = 2; 9 | map<string, string> args = 3; 10 | } -------------------------------------------------------------------------------- /transport.go: -------------------------------------------------------------------------------- 1 | package crpc 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "io" 8 | "net" 9 | "net/http" 10 | "net/http/httputil" 11 | "strconv" 12 | "strings" 13 | "sync" 14 | "sync/atomic" 15 | "time" 16 | 17 | "github.com/lwch/crpc/encoding" 18 | "github.com/lwch/crpc/encoding/codec" 19 | "github.com/lwch/crpc/network" 20 | "github.com/lwch/logging" 21 | ) 22 | 23 | // ErrDone done error 24 | var ErrDone = errors.New("transport: done") 25 | 26 | var errDataType = errors.New("transport: data type error") 27 | 28 | // RequestHandlerFunc request handler 29 | type RequestHandlerFunc func(*http.Request) (*http.Response, error) 30 | 31 | // transport carries request/response calls and streams over a single 32 | // network.Conn. Responses are routed to pending Calls by the request id 33 | // found in the keyRequestID header. 34 | type transport struct { 35 | conn *network.Conn 36 | codec encoding.Codec 37 | encrypter encoding.Encrypter // nil unless SetEncrypter is called 38 | compresser encoding.Compresser // nil unless SetCompresser is called 39 | sequence atomic.Uint64 40 | onResponse map[uint64]chan *http.Response // pending Calls by request id; guarded by mResponse 41 | mResponse sync.RWMutex 42 | onRequest RequestHandlerFunc 43 | // runtime 44 | err error // error that ended Serve; read only after ctx is done 45 | ctx context.Context // cancelled by Serve on return 46 | cancel context.CancelFunc 47 | } 48 | 49 | // new wraps conn in a transport that uses the default codec and a handler 50 | // returning an empty response, then starts the keepalive goroutine, which 51 | // runs until Serve returns. 
52 | // NOTE(review): this shadows the builtin new inside package crpc, and if 53 | // Serve is never run the keepalive goroutine is never stopped (Close does 54 | // not cancel ctx). 55 | func new(conn net.Conn) *transport { 56 | conn.SetDeadline(time.Time{}) // no timeout; error deliberately ignored (best effort) 57 | ctx, cancel := context.WithCancel(context.Background()) 58 | t := &transport{ 59 | conn: network.New(conn), 60 | codec: codec.New(), 61 | onResponse: make(map[uint64]chan *http.Response), 62 | onRequest: func(r *http.Request) (*http.Response, error) { 63 | return &http.Response{}, nil 64 | }, 65 | ctx: ctx, 66 | cancel: cancel, 67 | } 68 | go t.keepalive() 69 | return t 70 | } 71 | 72 | // SetEncrypter sets the encrypter used by the transport and its streams. 73 | func (tp *transport) SetEncrypter(encrypter encoding.Encrypter) { 74 | tp.encrypter = encrypter 75 | } 76 | 77 | // SetCompresser sets the compresser used by the transport and its streams. 78 | func (tp *transport) SetCompresser(compresser encoding.Compresser) { 79 | tp.compresser = compresser 80 | } 81 | 82 | // AcceptStream accepts the next stream opened by the peer. 83 | func (tp *transport) AcceptStream() (*Stream, error) { 84 | s, err := tp.conn.AcceptStream() 85 | if err != nil { 86 | return nil, err 87 | } 88 | return &Stream{ 89 | parent: tp, 90 | s: s, 91 | }, nil 92 | } 93 | 94 | // OpenStream opens a new stream to the peer. 95 | func (tp *transport) OpenStream(ctx context.Context) (*Stream, error) { 96 | s, err := tp.conn.OpenStream(ctx) 97 | if err != nil { 98 | return nil, err 99 | } 100 | return &Stream{ 101 | parent: tp, 102 | s: s, 103 | }, nil 104 | } 105 | 106 | // SetOnRequest sets the handler invoked for each incoming request. 107 | func (tp *transport) SetOnRequest(fn RequestHandlerFunc) { 108 | tp.onRequest = fn 109 | } 110 | 111 | // Close closes the underlying connection. 
112 | func (tp *transport) Close() error { 113 | return tp.conn.Close() 114 | } 115 | 
116 | // Call sends req to the peer and waits for the response carrying the same 117 | // request id. It returns ErrDone if ctx ends first, or the error that 118 | // stopped Serve if the transport shuts down while waiting. 119 | func (tp *transport) Call(ctx context.Context, req *http.Request) (*http.Response, error) { 120 | data, reqID, err := tp.buildRequest(req) 121 | if err != nil { 122 | return nil, err 123 | } 124 | // Capacity 1 lets parse deliver without blocking the read loop. 125 | ch := make(chan *http.Response, 1) 126 | tp.mResponse.Lock() 127 | tp.onResponse[reqID] = ch 128 | tp.mResponse.Unlock() 129 | // Only unregister: ch is never closed, so a response that parse looked 130 | // up just before this cleanup cannot hit a send on a closed channel. 131 | defer func() { 132 | tp.mResponse.Lock() 133 | delete(tp.onResponse, reqID) 134 | tp.mResponse.Unlock() 135 | }() 136 | hdr, _ := httputil.DumpRequest(req, false) 137 | logging.Debug("< http call(%d):\n%s", reqID, string(hdr)) 138 | _, err = tp.conn.Write(data) 139 | if err != nil { 140 | return nil, err 141 | } 142 | select { 143 | case <-tp.ctx.Done(): 144 | return nil, tp.err 145 | case <-ctx.Done(): 146 | return nil, ErrDone 147 | case resp := <-ch: 148 | return resp, nil 149 | } 150 | } 151 | 152 | // Serve reads frames from the connection and dispatches them until a read 153 | // or parse error occurs. On return it stores that error in tp.err and 154 | // cancels tp.ctx, which wakes pending Calls and stops keepalive. 155 | func (tp *transport) Serve() error { 156 | var err error 157 | defer func() { 158 | tp.err = err 159 | tp.cancel() 160 | }() 161 | // NOTE(review): buf is reused for every read; this assumes decode copies 162 | // whatever it keeps from the frame; confirm. 163 | buf := make([]byte, 65535) 164 | for { 165 | var n int 166 | n, err = tp.conn.Read(buf) 167 | if err != nil { 168 | logging.Error("serve: %v", err) 169 | return err 170 | } 171 | // Plain assignment (not :=) so the deferred func records parse errors; 172 | // a shadowed err left tp.err nil and made Call return (nil, nil). 
173 | err = tp.parse(buf[:n]) 174 | if err != nil { 175 | logging.Error("parse: %v", err) 176 | return err 177 | } 178 | } 179 | } 180 | 181 | // parse decodes one frame. Requests are handled on a new goroutine; 182 | // responses are handed to the Call waiting on the same request id, or 183 | // dropped if none is. Any other payload type yields errDataType. 184 | func (tp *transport) parse(data []byte) error { 185 | payload, err := tp.decode(data) 186 | if err != nil { 187 | logging.Error("decode: %v", err) 188 | return err 189 | } 190 | switch v := payload.(type) { 191 | case *http.Request: 192 | str := v.Header.Get(keyRequestID) 193 | seq, _ := strconv.ParseUint(str, 10, 64) // missing/invalid id parses as 0 194 | go tp.handleRequest(v, seq) 195 | case *http.Response: 196 | str := v.Header.Get(keyRequestID) 197 | seq, _ := strconv.ParseUint(str, 10, 64) 198 | tp.mResponse.RLock() 199 | ch := tp.onResponse[seq] 200 | tp.mResponse.RUnlock() 201 | if ch == nil { 202 | return nil 203 | } 204 | // Non-blocking send: one response per id is expected, so if the 205 | // buffer is already full this is a duplicate and is dropped rather 206 | // than stalling the read loop. 207 | select { 208 | case ch <- v: 209 | default: 210 | } 211 | default: 212 | return errDataType 213 | } 214 | return nil 215 | } 216 | 
217 | // handleRequest runs the request handler and writes its response back to 218 | // the peer tagged with reqID. A handler error, or a nil response, becomes a 219 | // 500 response whose body is the error text. The response body is buffered 220 | // so ContentLength can be set. 221 | func (tp *transport) handleRequest(req *http.Request, reqID uint64) { 222 | hdr, _ := httputil.DumpRequest(req, false) 223 | logging.Debug("> received http call(%d):\n%s", reqID, string(hdr)) 224 | if tp.onRequest == nil { 225 | // NOTE(review): no response is sent, so the caller's Call waits until 226 | // its own context ends. 227 | return 228 | } 229 | resp, err := tp.onRequest(req) 230 | if err == nil && resp == nil { 231 | // Avoid a nil dereference below, which would crash the process. 232 | err = errors.New("transport: handler returned nil response") 233 | } 234 | if err != nil { 235 | resp = &http.Response{ 236 | StatusCode: http.StatusInternalServerError, 237 | Body: io.NopCloser(strings.NewReader(err.Error())), 238 | } 239 | } 240 | resp.ProtoMajor = req.ProtoMajor 241 | resp.ProtoMinor = req.ProtoMinor 242 | if resp.Body != nil { 243 | var buf bytes.Buffer 244 | _, copyErr := io.Copy(&buf, resp.Body) 245 | _ = resp.Body.Close() // replaced below; release the handler's body 246 | if copyErr != nil { 247 | logging.Error("read body(%d): %v", reqID, copyErr) 248 | return 249 | } 250 | resp.ContentLength = int64(buf.Len()) 251 | resp.Body = io.NopCloser(&buf) 252 | } 253 | data, err := tp.buildResponse(resp, reqID) 254 | if err != nil { 255 | logging.Error("build response(%d): %v", reqID, err) 256 | return 257 | } 258 | hdr, _ = httputil.DumpResponse(resp, false) 259 | logging.Debug("< http response(%d):\n%s", reqID, string(hdr)) 260 | _, err = tp.conn.Write(data) 261 | if err != nil { 262 | logging.Error("write response(%d): %v", reqID, err) 263 | } 264 | } 265 | 
266 | // keepalive sends a keepalive on the connection every 10 seconds until 267 | // tp.ctx is cancelled. Send errors are logged and do not stop the loop. 268 | func (tp *transport) keepalive() { 269 | ticker := time.NewTicker(10 * time.Second) 270 | defer ticker.Stop() 271 | for { 272 | select { 273 | case <-tp.ctx.Done(): 274 | return 275 | case <-ticker.C: 276 | if err := tp.conn.SendKeepalive(); err != nil { 277 | logging.Error("keepalive: %v", err) 278 | } 279 | } 280 | } 281 | } 282 | --------------------------------------------------------------------------------