├── .github └── workflows │ ├── pre_check.yaml │ └── tests.yaml ├── .gitignore ├── .licenserc.yaml ├── CONTIRBUTING.md ├── LICENSE ├── README.md ├── README_CN.md ├── ROADMAP.md ├── bench_test.go ├── block_io.go ├── block_io_test.go ├── buffer.go ├── buffer_manager.go ├── buffer_manager_test.go ├── buffer_slice.go ├── buffer_slice_test.go ├── buffer_test.go ├── config.go ├── config_test.go ├── const.go ├── debug.go ├── debug_test.go ├── epoll_linux.go ├── epoll_linux_arm64.go ├── errors.go ├── event_dispatcher.go ├── event_dispatcher_linux.go ├── event_dispatcher_race_linux.go ├── event_dispatcher_test.go ├── example ├── best_practice │ ├── idl │ │ └── example.go │ ├── net_client │ │ └── main.go │ ├── net_server │ │ └── main.go │ ├── run_net_client_server.sh │ ├── run_shmipc_async_client_server.sh │ ├── run_shmipc_client_server.sh │ ├── shmipc_async_client │ │ └── client.go │ ├── shmipc_async_server │ │ └── server.go │ ├── shmipc_client │ │ └── main.go │ └── shmipc_server │ │ └── main.go ├── helloworld │ ├── greeter_client │ │ └── main.go │ └── greeter_server │ │ └── main.go └── hot_restart_test │ ├── README.md │ ├── client │ ├── bootstrap.sh │ ├── build.sh │ └── client.go │ └── server │ ├── bootstrap.sh │ ├── bootstrap_hot_restart.sh │ ├── build.sh │ ├── callback_impl.go │ └── server.go ├── go.mod ├── go.sum ├── listener.go ├── listener_test.go ├── net_listener.go ├── protocol_event.go ├── protocol_initializer.go ├── protocol_manager.go ├── protocol_manager_test.go ├── queue.go ├── queue_test.go ├── session.go ├── session_manager.go ├── session_manager_test.go ├── session_test.go ├── stats.go ├── stream.go ├── stream_test.go ├── sys_memfd_create_bsd.go ├── sys_memfd_create_linux.go ├── util.go └── util_test.go /.github/workflows/pre_check.yaml: -------------------------------------------------------------------------------- 1 | name: Pull Request Check 2 | 3 | on: [ pull_request ] 4 | 5 | jobs: 6 | compliant: 7 | runs-on: [ self-hosted, X64 ] 8 | steps: 9 | - uses: 
actions/checkout@v3 10 | 11 | - name: Check License Header 12 | uses: apache/skywalking-eyes/header@v0.4.0 13 | env: 14 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 15 | 16 | - name: Check Spell 17 | uses: crate-ci/typos@v1.13.14 18 | 19 | staticcheck: 20 | runs-on: [ self-hosted, X64 ] 21 | steps: 22 | - uses: actions/checkout@v3 23 | - name: Set up Go 24 | uses: actions/setup-go@v3 25 | with: 26 | go-version: 1.19 27 | 28 | - uses: actions/cache@v3 29 | with: 30 | path: ~/go/pkg/mod 31 | key: reviewdog-${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} 32 | restore-keys: | 33 | reviewdog-${{ runner.os }}-go- 34 | - uses: reviewdog/action-staticcheck@v1 35 | with: 36 | github_token: ${{ secrets.github_token }} 37 | # Change reviewdog reporter if you need [github-pr-check,github-check,github-pr-review]. 38 | reporter: github-pr-review 39 | # Report all results. 40 | filter_mode: added 41 | # Exit with 1 when it find at least one finding. 42 | fail_on_error: true 43 | # Set staticcheck flags 44 | staticcheck_flags: -checks=inherit,-SA1029,-SA2002,-SA4006,-SA2002,-SA9003,-S1024 -exclude=example -------------------------------------------------------------------------------- /.github/workflows/tests.yaml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: [ push, pull_request ] 4 | 5 | jobs: 6 | unit-scenario-test: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v3 10 | - name: Set up Go 11 | uses: actions/setup-go@v3 12 | with: 13 | go-version: '1.17' 14 | - name: Unit Test 15 | run: go test -gcflags=-l -covermode=atomic -coverprofile=coverage.txt 16 | - name: Codecov 17 | run: bash <(curl -s https://codecov.io/bash) 18 | 19 | benchmark-test: 20 | runs-on: ubuntu-latest 21 | steps: 22 | - uses: actions/checkout@v3 23 | - name: Set up Go 24 | uses: actions/setup-go@v3 25 | with: 26 | go-version: '1.17' 27 | - name: Benchmark 28 | run: go test -gcflags='all=-N -l' -bench=. 
-benchmem -run=none 29 | 30 | compatibility-test: 31 | strategy: 32 | matrix: 33 | go: [1.13, 1.14, 1.15, 1.16, 1.18, 1.19, "1.20" ] 34 | os: [ X64 ] 35 | runs-on: ${{ matrix.os }} 36 | steps: 37 | - uses: actions/checkout@v3 38 | - name: Set up Go 39 | uses: actions/setup-go@v3 40 | with: 41 | go-version: ${{ matrix.go }} 42 | - name: Unit Test 43 | run: go test -gcflags=-l -covermode=atomic -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | -------------------------------------------------------------------------------- /.licenserc.yaml: -------------------------------------------------------------------------------- 1 | header: 2 | license: 3 | spdx-id: Apache-2.0 4 | copyright-owner: CloudWeGo Authors 5 | 6 | paths: 7 | - '**/*.go' 8 | - '**/*.s' 9 | 10 | paths-ignore: 11 | 12 | 13 | comment: on-failure 14 | -------------------------------------------------------------------------------- /CONTIRBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Your First Pull Request 4 | We use github for our codebase. You can start by reading [How To Pull Request](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-requests). 5 | 6 | ## Branch Organization 7 | We use [git-flow](https://nvie.com/posts/a-successful-git-branching-model/) as our branch organization, as known as [FDD](https://en.wikipedia.org/wiki/Feature-driven_development) 8 | 9 | ## Bugs 10 | ### 1. How to Find Known Issues 11 | We are using [Github Issues](https://github.com/cloudwego/shmipc-go/issues) for our public bugs. We keep a close eye on this and try to make it clear when we have an internal fix in progress. Before filing a new task, try to make sure your problem doesn't already exist. 12 | 13 | ### 2. 
Reporting New Issues 14 | Providing a reduced test case is the recommended way to report issues. It can be placed in: 15 | - Just in issues 16 | - [Golang Playground](https://play.golang.org/) 17 | 18 | ### 3. Security Bugs 19 | Please do not report security bugs in public issues; to disclose them safely, contact us by [Support Email](mailto:conduct@cloudwego.io) 20 | 21 | ## How to Get in Touch 22 | - [Email](mailto:conduct@cloudwego.io) 23 | 24 | ## Submit a Pull Request 25 | Before you submit your Pull Request (PR) consider the following guidelines: 26 | 1. Search [GitHub](https://github.com/cloudwego/shmipc-go/pulls) for an open or closed PR that relates to your submission. You don't want to duplicate existing efforts. 27 | 2. Please submit an issue instead of a PR if you have a better suggestion for format tools. We won't accept a lot of file changes directly without an issue statement and assignment. 28 | 3. Be sure that the issue describes the problem you're fixing, or documents the design for the feature you'd like to add. Before accepting your work, we need to conduct some checks and evaluations, so it will be better if you can discuss the design with us. 29 | 4. [Fork](https://docs.github.com/en/github/getting-started-with-github/fork-a-repo) the cloudwego/shmipc-go repo. 30 | 5. In your forked repository, make your changes in a new git branch: 31 | ``` 32 | git checkout -b my-fix-branch develop 33 | ``` 34 | 6. Create your patch, including appropriate test cases. 35 | 7. Follow our [Style Guides](#code-style-guides). 36 | 8. Commit your changes using a descriptive commit message that follows [AngularJS Git Commit Message Conventions](https://docs.google.com/document/d/1QrDFcIiPjSLDn3EL15IJygNPiHORgU1_OOAqWjiDU5Y/edit). 37 | Adherence to these conventions is necessary because release notes are automatically generated from these messages. 38 | 9. Push your branch to GitHub: 39 | ``` 40 | git push origin my-fix-branch 41 | ``` 42 | 10.
In GitHub, send a pull request to `shmipc:develop` with a clear and unambiguous title. 43 | 44 | ## Contribution Prerequisites 45 | - Our development environment keeps up with [Go Official](https://golang.org/project/). 46 | - You need to fully check your changes with lint tools before submitting your pull request: [gofmt](https://golang.org/pkg/cmd/gofmt/) and [golangci-lint](https://github.com/golangci/golangci-lint) 47 | - You are familiar with [Github](https://github.com) 48 | - You may also need to be familiar with [Actions](https://github.com/features/actions) (our default workflow tool). 49 | 50 | ## Code Style Guides 51 | 52 | See [Go Code Review Comments](https://github.com/golang/go/wiki/CodeReviewComments). 53 | 54 | Good resources: 55 | - [Effective Go](https://golang.org/doc/effective_go) 56 | - [Pingcap General advice](https://pingcap.github.io/style-guide/general.html) 57 | - [Uber Go Style Guide](https://github.com/uber-go/guide/blob/master/style.md) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Shmipc 3 | 4 | English | [中文](README_CN.md) 5 | 6 | ## Introduction 7 | 8 | Shmipc is a high performance inter-process communication library developed by ByteDance. 9 | It is built on Linux's shared memory technology and uses a Unix or TCP connection for process synchronization, which lets it implement zero-copy communication between processes. 10 | In IO-intensive or large-package scenarios, it has better performance. 11 | 12 | ## Features 13 | 14 | ### Zero copy 15 | 16 | In an industrial production environment, the Unix domain socket and TCP loopback are often used for inter-process communication. Their read and write operations need to copy data between the user-space buffer and the kernel-space buffer. Shmipc, however, stores data directly in shared memory, so it avoids those copies.
17 | 18 | ### Batch IO 19 | 20 | An IO queue is mapped into shared memory to describe the metadata of the communication data, 21 | so that one process can put many requests into the IO queue and the other process can handle a batch of IO per synchronization. This effectively reduces the system calls caused by process synchronization. 22 | 23 | ## Performance Testing 24 | 25 | bench_test.go in the source code compares the performance of shmipc and Unix domain sockets in a ping-pong scenario with different package sizes. The results are as follows: shmipc performs better for both small and large packages. 26 | 27 | ``` 28 | go test -bench=BenchmarkParallelPingPong -run BenchmarkParallelPingPong 29 | goos: linux 30 | goarch: amd64 31 | pkg: github.com/cloudwego/shmipc-go 32 | cpu: Intel(R) Xeon(R) CPU E5-2630 v4 @ 2.20GHz 33 | BenchmarkParallelPingPongByShmipc64B-40 733821 1970 ns/op 64.97 MB/s 0 B/op 0 allocs/op 34 | BenchmarkParallelPingPongByShmipc512B-40 536190 1990 ns/op 514.45 MB/s 0 B/op 0 allocs/op 35 | BenchmarkParallelPingPongByShmipc1KB-40 540517 2045 ns/op 1001.62 MB/s 0 B/op 0 allocs/op 36 | BenchmarkParallelPingPongByShmipc4KB-40 509047 2063 ns/op 3970.91 MB/s 0 B/op 0 allocs/op 37 | BenchmarkParallelPingPongByShmipc16KB-40 590398 1996 ns/op 16419.46 MB/s 0 B/op 0 allocs/op 38 | BenchmarkParallelPingPongByShmipc32KB-40 607756 1937 ns/op 33829.82 MB/s 0 B/op 0 allocs/op 39 | BenchmarkParallelPingPongByShmipc64KB-40 609824 1995 ns/op 65689.31 MB/s 0 B/op 0 allocs/op 40 | BenchmarkParallelPingPongByShmipc256KB-40 622755 1793 ns/op 292363.56 MB/s 0 B/op 0 allocs/op 41 | BenchmarkParallelPingPongByShmipc512KB-40 695401 1993 ns/op 526171.77 MB/s 0 B/op 0 allocs/op 42 | BenchmarkParallelPingPongByShmipc1MB-40 538208 1873 ns/op 1119401.64 MB/s 0 B/op 0 allocs/op 43 | BenchmarkParallelPingPongByShmipc4MB-40 606144 1891 ns/op 4436936.93 MB/s 0 B/op 0 allocs/op 44 | BenchmarkParallelPingPongByUds64B-40 446019 2657 ns/op 48.18 MB/s 0 B/op 0
allocs/op 45 | BenchmarkParallelPingPongByUds512B-40 450124 2665 ns/op 384.30 MB/s 0 B/op 0 allocs/op 46 | BenchmarkParallelPingPongByUds1KB-40 446389 2680 ns/op 764.29 MB/s 0 B/op 0 allocs/op 47 | BenchmarkParallelPingPongByUds4KB-40 383552 3093 ns/op 2648.83 MB/s 1 B/op 0 allocs/op 48 | BenchmarkParallelPingPongByUds16KB-40 307816 3884 ns/op 8436.27 MB/s 8 B/op 0 allocs/op 49 | BenchmarkParallelPingPongByUds64KB-40 103027 10259 ns/op 12776.17 MB/s 102 B/op 0 allocs/op 50 | BenchmarkParallelPingPongByUds256KB-40 25286 46352 ns/op 11311.01 MB/s 1661 B/op 0 allocs/op 51 | BenchmarkParallelPingPongByUds512KB-40 9788 122873 ns/op 8533.84 MB/s 8576 B/op 0 allocs/op 52 | BenchmarkParallelPingPongByUds1MB-40 4177 283729 ns/op 7391.38 MB/s 40178 B/op 0 allocs/op 53 | BenchmarkParallelPingPongByUds4MB-40 919 1253338 ns/op 6693.01 MB/s 730296 B/op 1 allocs/op 54 | PASS 55 | ok github.com/cloudwego/shmipc 42.138s 56 | 57 | ``` 58 | 59 | - BenchmarkParallelPingPongByUds, the ping-pong communication base on Unix domain socket. 60 | - BenchmarkParallelPingPongByShmipc, the ping-pong communication base on shmipc. 61 | - the suffix of the testing case name is the package size of communication, which from 64 Byte to 4 MB. 
62 | 63 | ### Quick start 64 | 65 | #### HelloWorld 66 | 67 | - [HelloWorldClient](https://github.com/cloudwego/shmipc-go/blob/main/example/helloworld/greeter_client/main.go) 68 | - [HelloWorldServer](https://github.com/cloudwego/shmipc-go/blob/main/example/helloworld/greeter_server/main.go) 69 | 70 | #### Integrate with application 71 | 72 | - [serialization and deserialization](https://github.com/cloudwego/shmipc-go/blob/main/example/best_practice/idl/example.go) 73 | - [client which using synchronous interface](https://github.com/cloudwego/shmipc-go/blob/main/example/best_practice/shmipc_client/main.go) 74 | - [server which using synchronous interface](https://github.com/cloudwego/shmipc-go/blob/main/example/best_practice/shmipc_server/main.go) 75 | - [client which using asynchronous interface](https://github.com/cloudwego/shmipc-go/blob/main/example/best_practice/shmipc_async_client/client.go) 76 | - [server which using asynchronous interface](https://github.com/cloudwego/shmipc-go/blob/main/example/best_practice/shmipc_async_server/server.go) 77 | 78 | #### HotRestart 79 | 80 | [hot restart demo](https://github.com/cloudwego/shmipc-go/blob/main/example/hot_restart_test/README.md) 81 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | # Shmipc 2 | 3 | [English](README.md) | 中文 4 | 5 | ## 简介 6 | 7 | Shmipc是一个由字节跳动开发的高性能进程间通讯库。它基于Linux的共享内存构建,使用unix/tcp连接进行进程同步,实现进程间通讯零拷贝。在IO密集型场景或大包场景能够获得显著的性能收益。 8 | 9 | ## 特性 10 | 11 | ### 零拷贝 12 | 13 | 在工业生产环境中,Unix domain socket和Tcp loopback常用于进程间通讯,读写均涉及通讯数据在用户态buffer与内核态buffer的来回拷贝。而Shmipc使用共享内存存放通讯数据,相对于前者没有数据拷贝。 14 | 15 | ### 批量收割IO 16 | 17 | Shmipc在共享内存中引入了一个IO队列来描述通讯数据的元信息,一个进程可以并发地将多个请求的元信息放入IO队列,另外一个进程只要需要一次同步就能批量收割IO.这在IO密集的场景下能够有效减少进程同步带来的system call。 18 | 19 | ## 性能测试 20 | 21 | 源码中 bench_test.go 进行了Shmipc与Unix domain socket在ping-pong场景下不同数据包大小的性能对比,结果如下所示: 从小包到大包均有性能提升。 22 | 23 | ``` 24 
| go test -bench=BenchmarkParallelPingPong -run BenchmarkParallelPingPong 25 | goos: linux 26 | goarch: amd64 27 | pkg: github.com/cloudwego/shmipc-go 28 | cpu: Intel(R) Xeon(R) CPU E5-2630 v4 @ 2.20GHz 29 | BenchmarkParallelPingPongByShmipc64B-40 733821 1970 ns/op 64.97 MB/s 0 B/op 0 allocs/op 30 | BenchmarkParallelPingPongByShmipc512B-40 536190 1990 ns/op 514.45 MB/s 0 B/op 0 allocs/op 31 | BenchmarkParallelPingPongByShmipc1KB-40 540517 2045 ns/op 1001.62 MB/s 0 B/op 0 allocs/op 32 | BenchmarkParallelPingPongByShmipc4KB-40 509047 2063 ns/op 3970.91 MB/s 0 B/op 0 allocs/op 33 | BenchmarkParallelPingPongByShmipc16KB-40 590398 1996 ns/op 16419.46 MB/s 0 B/op 0 allocs/op 34 | BenchmarkParallelPingPongByShmipc32KB-40 607756 1937 ns/op 33829.82 MB/s 0 B/op 0 allocs/op 35 | BenchmarkParallelPingPongByShmipc64KB-40 609824 1995 ns/op 65689.31 MB/s 0 B/op 0 allocs/op 36 | BenchmarkParallelPingPongByShmipc256KB-40 622755 1793 ns/op 292363.56 MB/s 0 B/op 0 allocs/op 37 | BenchmarkParallelPingPongByShmipc512KB-40 695401 1993 ns/op 526171.77 MB/s 0 B/op 0 allocs/op 38 | BenchmarkParallelPingPongByShmipc1MB-40 538208 1873 ns/op 1119401.64 MB/s 0 B/op 0 allocs/op 39 | BenchmarkParallelPingPongByShmipc4MB-40 606144 1891 ns/op 4436936.93 MB/s 0 B/op 0 allocs/op 40 | BenchmarkParallelPingPongByUds64B-40 446019 2657 ns/op 48.18 MB/s 0 B/op 0 allocs/op 41 | BenchmarkParallelPingPongByUds512B-40 450124 2665 ns/op 384.30 MB/s 0 B/op 0 allocs/op 42 | BenchmarkParallelPingPongByUds1KB-40 446389 2680 ns/op 764.29 MB/s 0 B/op 0 allocs/op 43 | BenchmarkParallelPingPongByUds4KB-40 383552 3093 ns/op 2648.83 MB/s 1 B/op 0 allocs/op 44 | BenchmarkParallelPingPongByUds16KB-40 307816 3884 ns/op 8436.27 MB/s 8 B/op 0 allocs/op 45 | BenchmarkParallelPingPongByUds64KB-40 103027 10259 ns/op 12776.17 MB/s 102 B/op 0 allocs/op 46 | BenchmarkParallelPingPongByUds256KB-40 25286 46352 ns/op 11311.01 MB/s 1661 B/op 0 allocs/op 47 | BenchmarkParallelPingPongByUds512KB-40 9788 122873 ns/op 8533.84 MB/s 8576 
B/op 0 allocs/op 48 | BenchmarkParallelPingPongByUds1MB-40 4177 283729 ns/op 7391.38 MB/s 40178 B/op 0 allocs/op 49 | BenchmarkParallelPingPongByUds4MB-40 919 1253338 ns/op 6693.01 MB/s 730296 B/op 1 allocs/op 50 | PASS 51 | ok github.com/cloudwego/shmipc 42.138s 52 | ``` 53 | 54 | - BenchmarkParallelPingPongByUds,基于Unix domain socket进行ping-pong通讯。 55 | - BenchmarkParallelPingPongByShmipc,基于Shmipc进行ping-pong通讯。 56 | - 后缀为ping-pong的数据包大小, 从 64 Byte ~ 4MB 不等。 57 | 58 | ### 快速开始 59 | 60 | #### HelloWorld 61 | 62 | - [HelloWorld客户端](https://github.com/cloudwego/shmipc-go/blob/main/example/helloworld/greeter_client/main.go) 63 | - [HelloWorld服务端](https://github.com/cloudwego/shmipc-go/blob/main/example/helloworld/greeter_server/main.go) 64 | 65 | #### 与应用集成 66 | 67 | - [使用Stream的Buffer接口进行对象的序列化与反序列化](https://github.com/cloudwego/shmipc-go/blob/main/example/best_practice/idl/example.go) 68 | - [使用Shmipc同步接口的Client实现](https://github.com/cloudwego/shmipc-go/blob/main/example/best_practice/shmipc_client/main.go) 69 | - [使用Shmipc同步接口的Server实现](https://github.com/cloudwego/shmipc-go/blob/main/example/best_practice/shmipc_server/main.go) 70 | - [使用Shmipc异步接口的Client实现](https://github.com/cloudwego/shmipc-go/blob/main/example/best_practice/shmipc_async_client/client.go) 71 | - [使用Shmipc异步接口的Server实现](https://github.com/cloudwego/shmipc-go/blob/main/example/best_practice/shmipc_async_server/server.go) 72 | 73 | #### 热升级 74 | 75 | [热升级demo](https://github.com/cloudwego/shmipc-go/blob/main/example/hot_restart_test/README.md) 76 | -------------------------------------------------------------------------------- /ROADMAP.md: -------------------------------------------------------------------------------- 1 | # Shmipc RoadMap 2 | 3 | ## New Features: 4 | 5 | ### 1.max write buffer of stream 6 | Give restrictions to the write buffer of the stream, so that the writing side will be blocked when the write buffer is full. 
The writing side will only be unblocked and allowed to continue writing when the receiving process finishes reading the data. 7 | 8 | This can prevent a situation where the writing rate is faster than the reading rate of the receiving process, which would quickly fill up the shared memory and trigger the slower fallback path, resulting in degraded program performance. 9 | 10 | 11 | ### 2. Abstract process synchronization mechanisms 12 | Currently, the main methods used for process synchronization are Unix domain sockets or TCP loopback, which are suitable for online ping-pong scenarios. However, by abstracting the process synchronization mechanisms, we can introduce different synchronization mechanisms to adapt to different scenarios and improve program performance. 13 | 14 | ### 3. Add process synchronization mechanisms with timed synchronization 15 | 16 | For offline scenarios (not sensitive to latency), we can use high-interval sleep and polling of flag bits in shared memory for synchronization, which can effectively reduce the overhead of process synchronization and improve program performance. 17 | 18 | ### 4. Support for ARM architecture 19 | 20 | 21 | ## Optimization: 22 | 23 | - Optimize the performance of the fallback path by reducing unnecessary data packet copies. 24 | 25 | - Implementation of a lock-free IO queue. 26 | 27 | - Consider removing the BufferWriter and BufferReader interfaces and replacing them with concrete implementations to improve performance when serializing and deserializing data. 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /block_io.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
// blockReadFull reads from the descriptor connFd until data is completely
// filled.
//
// A read interrupted by a signal (EINTR) is retried. If Read reports zero
// bytes before data is full, io.EOF is returned. Any other failure is returned
// wrapped, together with the number of bytes read so far, so callers can still
// reach the underlying errno through errors.Is / errors.As.
// NOTE(review): the name suggests connFd is in blocking mode; EAGAIN is not
// treated specially here — confirm against callers.
func blockReadFull(connFd int, data []byte) error {
	readSize := 0
	for readSize < len(data) {
		n, err := syscall.Read(connFd, data[readSize:])
		if err == syscall.EINTR {
			// Nothing was transferred; simply issue the read again.
			continue
		}
		if err != nil {
			// %w keeps the message text identical to the previous %s form
			// while preserving the error chain.
			return fmt.Errorf("ReadFull failed, had readSize:%d reason:%w", readSize, err)
		}
		if n == 0 {
			return io.EOF
		}
		readSize += n
	}
	return nil
}

// blockWriteFull writes all of data to the descriptor connFd, issuing as many
// Write calls as needed.
//
// A write interrupted by a signal (EINTR) is retried. A write that reports
// success without making progress yields io.ErrShortWrite instead of spinning
// forever. Other errors are returned unchanged.
func blockWriteFull(connFd int, data []byte) error {
	written := 0
	for written < len(data) {
		n, err := syscall.Write(connFd, data[written:])
		if err == syscall.EINTR {
			continue
		}
		if err != nil {
			return err
		}
		if n <= 0 {
			// Guard against an endless loop on a zero-progress write.
			return io.ErrShortWrite
		}
		written += n
	}
	return nil
}

// sendFd sends oob as the out-of-band (control) part of a message on connFd,
// with no regular payload.
// NOTE(review): presumably oob carries file-descriptor rights built by the
// caller — confirm against call sites.
func sendFd(connFd int, oob []byte) error {
	return syscall.Sendmsg(connFd, nil, oob, nil, 0)
}

// blockReadOutOfBoundForFd receives only the out-of-band (control) part of a
// message from connFd into oob and returns the number of control bytes
// received.
func blockReadOutOfBoundForFd(connFd int, oob []byte) (oobn int, err error) {
	_, oobn, _, _, err = syscall.Recvmsg(connFd, nil, oob, 0)
	return oobn, err
}
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "net" 21 | "os" 22 | "testing" 23 | 24 | "github.com/stretchr/testify/assert" 25 | ) 26 | 27 | func TestBlockReadFullAndBlockWriteFull(t *testing.T) { 28 | content := "hello,shmipc!" 29 | // Create a local Unix socket listener 30 | laddr, err := net.ResolveUnixAddr("unix", "/tmp/testBlockRWFull.sock") 31 | if err != nil { 32 | t.Fatalf("failed to resolve unix address: %v\n", err) 33 | } 34 | listener, err := net.ListenUnix("unix", laddr) 35 | if err != nil { 36 | t.Fatalf("failed to listen unix: %v\n", err) 37 | return 38 | } 39 | defer func() { 40 | listener.Close() 41 | os.Remove("/tmp/testBlockRWFull.sock") 42 | }() 43 | 44 | // Start a goroutine to accept a connection and write data 45 | go func() { 46 | conn, err := listener.Accept() 47 | if err != nil { 48 | t.Errorf("failed to accept connection: %v\n", err) 49 | } 50 | defer conn.Close() 51 | fd, err := getConnDupFd(conn) 52 | if err != nil { 53 | t.Errorf("failed to getConnDupFd: %v", err) 54 | } 55 | 56 | // Write data using blockWriteFull 57 | data := []byte(content) 58 | if err := blockWriteFull(int(fd.Fd()), data); err != nil { 59 | t.Errorf("failed to write data: %v\n", err) 60 | } 61 | }() 62 | 63 | // Dial the Unix socket and read data 64 | conn, err := net.DialUnix("unix", nil, laddr) 65 | if err != nil { 66 | t.Errorf("failed to dial unix: %v\n", err) 67 | } 68 | defer conn.Close() 69 | fd, err := getConnDupFd(conn) 70 | if err != nil { 71 | t.Errorf("failed to getConnDupFd: %v", err) 72 | } 73 | 74 | 
// Read data using blockReadFull 75 | buf := make([]byte, 1024) 76 | if err := blockReadFull(int(fd.Fd()), buf[:len(content)]); err != nil { 77 | t.Errorf("failed to read data: %v\n", err) 78 | } 79 | // Check if the read data is correct 80 | assert.Equal(t, buf[:len(content)], []byte(content)) 81 | 82 | } 83 | -------------------------------------------------------------------------------- /buffer_manager_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "math/rand" 21 | "sync" 22 | "testing" 23 | "time" 24 | 25 | "github.com/stretchr/testify/assert" 26 | ) 27 | 28 | func TestBufferManager_CreateAndMapping(t *testing.T) { 29 | //create 30 | mem := make([]byte, 32<<20) 31 | bm1, err := createBufferManager([]*SizePercentPair{ 32 | {4096, 70}, 33 | {16 * 1024, 20}, 34 | {64 * 1024, 10}, 35 | }, "", mem, 0) 36 | if err != nil { 37 | t.Fatal("create buffer manager failed, err=" + err.Error()) 38 | } 39 | 40 | allocateFunc := func(bm *bufferManager) { 41 | for i := 0; i < 10; i++ { 42 | _, err := bm.allocShmBuffer(4096) 43 | assert.Equal(t, nil, err) 44 | _, err = bm.allocShmBuffer(16 * 1024) 45 | assert.Equal(t, nil, err) 46 | _, err = bm.allocShmBuffer(64 * 1024) 47 | assert.Equal(t, nil, err) 48 | } 49 | } 50 | allocateFunc(bm1) 51 | 52 | //mapping 53 | bm2, err := mappingBufferManager("", mem, 0) 54 | if err != nil { 55 | t.Fatal("mapping buffer manager failed, err=" + err.Error()) 56 | } 57 | 58 | for i := range bm1.lists { 59 | assert.Equal(t, *bm1.lists[i].capPerBuffer, *bm2.lists[i].capPerBuffer) 60 | assert.Equal(t, *bm1.lists[i].size, *bm2.lists[i].size) 61 | assert.Equal(t, bm1.lists[i].offsetInShm, bm2.lists[i].offsetInShm) 62 | } 63 | 64 | allocateFunc(bm2) 65 | 66 | for i := range bm1.lists { 67 | assert.Equal(t, *bm1.lists[i].capPerBuffer, *bm2.lists[i].capPerBuffer) 68 | assert.Equal(t, *bm1.lists[i].size, *bm2.lists[i].size) 69 | assert.Equal(t, bm1.lists[i].offsetInShm, bm2.lists[i].offsetInShm) 70 | } 71 | } 72 | 73 | func TestBufferManager_ReadBufferSlice(t *testing.T) { 74 | mem := make([]byte, 1<<20) 75 | bm, err := createBufferManager([]*SizePercentPair{ 76 | {Size: uint32(4096), Percent: 100}, 77 | }, "", mem, 0) 78 | assert.Equal(t, nil, err) 79 | 80 | s, err := bm.allocShmBuffer(4096) 81 | assert.Equal(t, nil, err) 82 | data := make([]byte, 4096) 83 | rand.Read(data) 84 | assert.Equal(t, 4096, s.append(data...)) 85 | assert.Equal(t, 4096, 
s.size()) 86 | s.update() 87 | 88 | s2, err := bm.readBufferSlice(s.offsetInShm) 89 | assert.Equal(t, nil, err) 90 | assert.Equal(t, s.capacity(), s2.capacity()) 91 | assert.Equal(t, s.size(), s2.size()) 92 | 93 | getData, err := s2.read(4096) 94 | assert.Equal(t, nil, err) 95 | assert.Equal(t, data, getData) 96 | 97 | s3, err := bm.readBufferSlice(s.offsetInShm + 1<<20) 98 | assert.NotEqual(t, nil, err) 99 | assert.Equal(t, (*bufferSlice)(nil), s3) 100 | 101 | s4, err := bm.readBufferSlice(s.offsetInShm + 4096) 102 | assert.NotEqual(t, nil, err) 103 | assert.Equal(t, (*bufferSlice)(nil), s4) 104 | } 105 | 106 | func TestBufferManager_AllocRecycle(t *testing.T) { 107 | //allocBuffer 108 | mem := make([]byte, 1<<20) 109 | bm, err := createBufferManager([]*SizePercentPair{ 110 | {Size: 4096, Percent: 50}, 111 | {Size: 8192, Percent: 50}, 112 | }, "", mem, 0) 113 | assert.Equal(t, nil, err) 114 | // use first two buffer to record buffer list info(List header) 115 | assert.Equal(t, uint32(1<<20-4096-8192), bm.remainSize()) 116 | 117 | numOfSlice := bm.sliceSize() 118 | buffers := make([]*bufferSlice, 0, 1024) 119 | for { 120 | buf, err := bm.allocShmBuffer(4096) 121 | if err != nil { 122 | break 123 | } 124 | buffers = append(buffers, buf) 125 | } 126 | for i := range buffers { 127 | bm.recycleBuffer(buffers[i]) 128 | } 129 | buffers = buffers[:0] 130 | 131 | //allocBuffers, recycleBuffers 132 | slices := newSliceList() 133 | size := bm.allocShmBuffers(slices, 256*1024) 134 | assert.Equal(t, int(size), 256*1024) 135 | linkedBufferSlices := newEmptyLinkedBuffer(bm) 136 | for slices.size() > 0 { 137 | linkedBufferSlices.appendBufferSlice(slices.popFront()) 138 | } 139 | linkedBufferSlices.done(false) 140 | bm.recycleBuffers(linkedBufferSlices.sliceList.popFront()) 141 | assert.Equal(t, numOfSlice, bm.sliceSize()) 142 | } 143 | 144 | func TestBufferList_PutPop(t *testing.T) { 145 | capPerBuffer := uint32(4096) 146 | bufferNum := uint32(1000) 147 | mem := make([]byte, 
// TestBufferList_ConcurrentPutPop has 100 goroutines pop and push buffers on
// a free list that only holds 10 buffers, then checks that the list's size
// counter is back at bufferNum once every goroutine has finished.
func TestBufferList_ConcurrentPutPop(t *testing.T) {
	capPerBuffer := uint32(10)
	bufferNum := uint32(10)
	mem := make([]byte, countBufferListMemSize(bufferNum, capPerBuffer))
	l, err := createFreeBufferList(bufferNum, capPerBuffer, mem, 0)
	if err != nil {
		t.Fatal(err)
	}

	// start is closed once every worker has signalled startWg, so that all
	// workers begin hammering the list at the same moment.
	start := make(chan struct{})
	var finishedWg sync.WaitGroup
	var startWg sync.WaitGroup
	concurrency := 100
	finishedWg.Add(concurrency)
	startWg.Add(concurrency)
	for i := 0; i < concurrency; i++ {
		go func() {
			defer finishedWg.Done()
			// signal readiness, wait for the common start, then pop and push
			startWg.Done()
			<-start
			for j := 0; j < 10000; j++ {
				var err error
				var b *bufferSlice
				b, err = l.pop()
				// there are more workers than buffers: retry until pop succeeds
				for err != nil {
					time.Sleep(time.Millisecond)
					b, err = l.pop()
				}
				assert.Equal(t, capPerBuffer, b.cap)
				assert.Equal(t, 0, b.size())
				assert.Equal(t, false, b.hasNext(), "offset:%d next:%d", b.offsetInShm, b.nextBufferOffset())
				l.push(b)
			}
		}()
	}
	startWg.Wait()
	close(start)
	finishedWg.Wait()
	assert.Equal(t, bufferNum, uint32(*l.size))
}

// TestBufferList_CreateAndMappingFreeBufferList checks which argument
// combinations createFreeBufferList and mappingFreeBufferList reject (error
// plus nil list) and which they accept.
func TestBufferList_CreateAndMappingFreeBufferList(t *testing.T) {
	capPerBuffer := uint32(10)
	bufferNum := uint32(10)
	// a buffer count of 0 is expected to fail
	mem := make([]byte, countBufferListMemSize(bufferNum, capPerBuffer))
	l, err := createFreeBufferList(0, capPerBuffer, mem, 0)
	assert.NotEqual(t, nil, err)
	assert.Equal(t, (*bufferList)(nil), l)

	// one more buffer than mem was sized for is expected to fail
	mem = make([]byte, countBufferListMemSize(bufferNum, capPerBuffer))
	l, err = createFreeBufferList(bufferNum+1, capPerBuffer, mem, 0)
	assert.NotEqual(t, nil, err)
	assert.Equal(t, (*bufferList)(nil), l)

	// arguments that match the size of mem are expected to succeed
	mem = make([]byte, countBufferListMemSize(bufferNum, capPerBuffer))
	l, err = createFreeBufferList(bufferNum, capPerBuffer, mem, 0)
	assert.Equal(t, nil, err)
	assert.NotEqual(t, (*bufferList)(nil), l)

	// mapping a fresh 10-byte region is expected to fail
	testMem := make([]byte, 10)
	ml, err := mappingFreeBufferList(testMem, 0)
	assert.NotEqual(t, nil, err)
	assert.Equal(t, (*bufferList)(nil), ml)

	// mapping the created region at offset 10 instead of 0 is expected to fail
	ml, err = mappingFreeBufferList(mem, 10)
	assert.NotEqual(t, nil, err)
	assert.Equal(t, (*bufferList)(nil), ml)

	// mapping the created region at the offset it was created with succeeds
	ml, err = mappingFreeBufferList(mem, 0)
	assert.Equal(t, nil, err)
	assert.NotEqual(t, (*bufferList)(nil), ml)

	if err != nil {
		t.Fatalf("fail to mapping bufferlist:%s", err.Error())
	}

}

// BenchmarkBufferList_PutPop measures one pop followed by one push on a
// 10000-buffer list from a single goroutine.
func BenchmarkBufferList_PutPop(b *testing.B) {
	capPerBuffer := uint32(10)
	bufferNum := uint32(10000)
	mem := make([]byte, countBufferListMemSize(bufferNum, capPerBuffer))
	l, err := createFreeBufferList(bufferNum, capPerBuffer, mem, 0)
	if err != nil {
		b.Fatal(err)
	}
	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		buf, err := l.pop()
		if err != nil {
			b.Fatal(err)
		}
		l.push(buf)
	}
}

// BenchmarkBufferList_PutPopParallel measures pop/push pairs issued through
// b.RunParallel on a list of one million 1-byte buffers.
func BenchmarkBufferList_PutPopParallel(b *testing.B) {
	capPerBuffer := uint32(1)
	bufferNum := uint32(100 * 10000)
	mem := make([]byte, countBufferListMemSize(bufferNum, capPerBuffer))
	l, err := createFreeBufferList(bufferNum, capPerBuffer, mem, 0)
	if err != nil {
		b.Fatal(err)
	}

	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			var err error
			var buf *bufferSlice
			buf, err = l.pop()
			// a failed pop is retried after a 1ms sleep, which is part of the
			// measured time
			for err != nil {
				time.Sleep(time.Millisecond)
				buf, err = l.pop()
			}
			l.push(buf)
		}
	})
}

// TestCreateFreeBufferList feeds createFreeBufferList near-maximum uint32
// arguments together with a 1-byte region and expects an error.
func TestCreateFreeBufferList(t *testing.T) {
	_, err := createFreeBufferList(4294967295, 4294967295, []byte{'w'}, 4294967279)
	assert.NotNil(t, err)
}
15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "sync" 21 | "unsafe" 22 | ) 23 | 24 | var ( 25 | bufferSlicePool = &sync.Pool{ 26 | New: func() interface{} { 27 | return &bufferSlice{} 28 | }, 29 | } 30 | ) 31 | 32 | type bufferHeader []byte 33 | 34 | type bufferSlice struct { 35 | //bufferHeader layout: cap 4 byte | size 4 byte | start 4 byte | next 4 byte | flag 4 byte 36 | bufferHeader 37 | data []byte 38 | cap uint32 39 | //use for prepend 40 | start uint32 41 | offsetInShm uint32 42 | readIndex int 43 | writeIndex int 44 | isFromShm bool 45 | nextSlice *bufferSlice 46 | } 47 | 48 | func (s *bufferSlice) next() *bufferSlice { 49 | return s.nextSlice 50 | } 51 | 52 | func newBufferSlice(header []byte, data []byte, offsetInShm uint32, isFromShm bool) *bufferSlice { 53 | s := bufferSlicePool.Get().(*bufferSlice) 54 | if isFromShm && header != nil { 55 | s.cap = *(*uint32)(unsafe.Pointer(&header[bufferCapOffset])) 56 | s.start = *(*uint32)(unsafe.Pointer(&header[bufferDataStartOffset])) 57 | s.readIndex = int(s.start) 58 | s.writeIndex = int(s.start + *(*uint32)(unsafe.Pointer(&(header[bufferSizeOffset])))) 59 | } else { 60 | s.cap = uint32(cap(data)) 61 | } 62 | s.bufferHeader = header 63 | s.data = data 64 | s.offsetInShm = offsetInShm 65 | s.isFromShm = isFromShm 66 | return s 67 | } 68 | 69 | func putBackBufferSlice(s *bufferSlice) { 70 | s.isFromShm = false 71 | s.offsetInShm = 0 72 | s.data = nil 73 | s.bufferHeader = nil 74 | s.cap = 0 75 | s.writeIndex = 0 76 | s.readIndex = 0 77 | s.start = 0 78 | s.nextSlice = nil 79 | bufferSlicePool.Put(s) 80 | } 81 | 82 | func (s bufferHeader) nextBufferOffset() uint32 { 83 | return *(*uint32)(unsafe.Pointer(&s[nextBufferOffset])) 84 | } 85 | 86 | func (s bufferHeader) hasNext() bool { 87 | return (s[bufferFlagOffset] & hasNextBufferFlag) > 0 88 | } 89 | 90 | func (s bufferHeader) clearFlag() { 91 | s[bufferFlagOffset] = 0 92 | } 93 | 94 | func (s bufferHeader) setInUsed() { 95 | s[bufferFlagOffset] |= 
sliceInUsedFlag 96 | } 97 | 98 | func (s bufferHeader) isInUsed() bool { 99 | return (s[bufferFlagOffset] & sliceInUsedFlag) > 0 100 | } 101 | 102 | func (s bufferHeader) linkNext(next uint32) { 103 | *(*uint32)(unsafe.Pointer(&s[nextBufferOffset])) = next 104 | s[bufferFlagOffset] |= hasNextBufferFlag 105 | } 106 | 107 | func (s *bufferSlice) update() { 108 | if s.bufferHeader != nil { 109 | *(*uint32)(unsafe.Pointer(&s.bufferHeader[bufferSizeOffset])) = uint32(s.size()) 110 | *(*uint32)(unsafe.Pointer(&s.bufferHeader[bufferDataStartOffset])) = s.start 111 | if s.nextSlice != nil { 112 | s.linkNext(s.nextSlice.offsetInShm) 113 | } 114 | } 115 | } 116 | 117 | func (s *bufferSlice) reset() { 118 | if s.bufferHeader != nil { 119 | *(*uint32)(unsafe.Pointer(&s.bufferHeader[bufferSizeOffset])) = 0 120 | *(*uint32)(unsafe.Pointer(&s.bufferHeader[bufferDataStartOffset])) = 0 121 | s.bufferHeader.clearFlag() 122 | } 123 | s.writeIndex = 0 124 | s.readIndex = 0 125 | s.nextSlice = nil 126 | } 127 | 128 | func (s *bufferSlice) size() int { 129 | return s.writeIndex - s.readIndex 130 | } 131 | 132 | func (s *bufferSlice) remain() int { 133 | return int(s.cap) - s.writeIndex 134 | } 135 | 136 | func (s *bufferSlice) capacity() int { 137 | return int(s.cap) 138 | } 139 | 140 | func (s *bufferSlice) reserve(size int) ([]byte, error) { 141 | start := s.writeIndex 142 | remain := s.remain() 143 | if remain >= size { 144 | s.writeIndex += size 145 | return s.data[start:s.writeIndex], nil 146 | } 147 | return nil, ErrNoMoreBuffer 148 | } 149 | 150 | func (s *bufferSlice) prepend() { 151 | panic("TODO") 152 | } 153 | 154 | func (s *bufferSlice) append(data ...byte) int { 155 | if len(data) == 0 { 156 | return 0 157 | } 158 | copySize := copy(s.data[s.writeIndex:], data) 159 | s.writeIndex += copySize 160 | return copySize 161 | } 162 | 163 | func (s *bufferSlice) read(size int) (data []byte, err error) { 164 | unRead := s.size() 165 | if unRead < size { 166 | size = unRead 167 | err 
= ErrNotEnoughData 168 | } 169 | data = s.data[s.readIndex : s.readIndex+size] 170 | s.readIndex += size 171 | return 172 | } 173 | 174 | func (s *bufferSlice) peek(size int) (data []byte, err error) { 175 | origin := s.readIndex 176 | data, err = s.read(size) 177 | s.readIndex = origin 178 | return 179 | } 180 | 181 | func (s *bufferSlice) skip(size int) int { 182 | unRead := s.size() 183 | if unRead > size { 184 | s.readIndex += size 185 | return size 186 | } 187 | s.readIndex += unRead 188 | return unRead 189 | } 190 | 191 | type sliceList struct { 192 | frontSlice *bufferSlice 193 | writeSlice *bufferSlice 194 | backSlice *bufferSlice 195 | len int 196 | } 197 | 198 | func newSliceList() *sliceList { 199 | return &sliceList{} 200 | } 201 | 202 | func (l *sliceList) front() *bufferSlice { 203 | return l.frontSlice 204 | } 205 | 206 | func (l *sliceList) back() *bufferSlice { 207 | return l.backSlice 208 | } 209 | 210 | func (l *sliceList) size() int { 211 | return l.len 212 | } 213 | 214 | func (l *sliceList) pushBack(s *bufferSlice) { 215 | if s == nil { 216 | return 217 | } 218 | 219 | if l.len > 0 { 220 | l.backSlice.nextSlice = s 221 | } else { 222 | l.frontSlice = s 223 | } 224 | 225 | l.backSlice = s 226 | l.len++ 227 | } 228 | 229 | func (l *sliceList) popFront() *bufferSlice { 230 | r := l.frontSlice 231 | 232 | if l.len > 0 { 233 | l.len-- 234 | l.frontSlice = l.frontSlice.nextSlice 235 | } 236 | 237 | if l.len == 0 { 238 | l.frontSlice = nil 239 | l.backSlice = nil 240 | } 241 | 242 | return r 243 | } 244 | 245 | func (l *sliceList) splitFromWrite() *bufferSlice { 246 | nextListHead := l.writeSlice.nextSlice 247 | l.backSlice = l.writeSlice 248 | l.backSlice.nextSlice = nil 249 | nextListSize := 0 250 | for s := nextListHead; s != nil; s = s.nextSlice { 251 | nextListSize++ 252 | } 253 | l.len -= nextListSize 254 | return nextListHead 255 | } 256 | -------------------------------------------------------------------------------- /buffer_slice_test.go: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "math/rand" 21 | "testing" 22 | "unsafe" 23 | 24 | "github.com/stretchr/testify/assert" 25 | ) 26 | 27 | func TestBufferSlice_ReadWrite(t *testing.T) { 28 | size := 8192 29 | slice := newBufferSlice(nil, make([]byte, size), 0, false) 30 | 31 | for i := 0; i < size; i++ { 32 | n := slice.append(byte(i)) 33 | assert.Equal(t, 1, n) 34 | } 35 | n := slice.append(byte(10)) 36 | assert.Equal(t, 0, n) 37 | 38 | data, err := slice.read(size * 10) 39 | assert.Equal(t, size, len(data)) 40 | assert.Equal(t, err, ErrNotEnoughData) 41 | 42 | //verify read data. 43 | for i := 0; i < size; i++ { 44 | assert.Equal(t, byte(i), data[i]) 45 | } 46 | } 47 | 48 | func TestBufferSlice_Skip(t *testing.T) { 49 | slice := newBufferSlice(nil, make([]byte, 8192), 0, false) 50 | slice.append(make([]byte, slice.capacity())...) 
51 | remain := slice.capacity() 52 | 53 | n := slice.skip(10) 54 | remain -= n 55 | assert.Equal(t, remain, slice.size()) 56 | 57 | n = slice.skip(100) 58 | remain -= n 59 | assert.Equal(t, remain, slice.size()) 60 | 61 | n = slice.skip(10000) 62 | assert.Equal(t, 0, slice.size()) 63 | } 64 | 65 | func TestBufferSlice_Reserve(t *testing.T) { 66 | size := 8192 67 | slice := newBufferSlice(nil, make([]byte, size), 0, false) 68 | data1, err := slice.reserve(100) 69 | assert.Equal(t, nil, err) 70 | assert.Equal(t, 100, len(data1)) 71 | 72 | data2, err := slice.reserve(size) 73 | assert.Equal(t, ErrNoMoreBuffer, err) 74 | assert.Equal(t, 0, len(data2)) 75 | 76 | for i := range data1 { 77 | data1[i] = byte(i) 78 | } 79 | for i := range data2 { 80 | data2[i] = byte(i) 81 | } 82 | 83 | readData, err := slice.read(100) 84 | assert.Equal(t, nil, err) 85 | assert.Equal(t, len(data1), len(readData)) 86 | 87 | for i := 0; i < len(data1); i++ { 88 | assert.Equal(t, data1[i], readData[i]) 89 | } 90 | 91 | readData, err = slice.read(10000) 92 | assert.Equal(t, ErrNotEnoughData, err) 93 | assert.Equal(t, len(data2), len(readData)) 94 | 95 | for i := 0; i < len(data2); i++ { 96 | assert.Equal(t, data2[i], readData[i]) 97 | } 98 | } 99 | 100 | func TestBufferSlice_Update(t *testing.T) { 101 | size := 8192 102 | header := make([]byte, bufferHeaderSize) 103 | *(*uint32)(unsafe.Pointer(&header[bufferCapOffset])) = uint32(size) 104 | slice := newBufferSlice(header, make([]byte, size), 0, true) 105 | 106 | n := slice.append(make([]byte, size)...) 
107 | assert.Equal(t, size, n) 108 | slice.update() 109 | assert.Equal(t, size, int(*(*uint32)(unsafe.Pointer(&slice.bufferHeader[bufferSizeOffset])))) 110 | } 111 | 112 | func TestBufferSlice_linkedNext(t *testing.T) { 113 | size := 8192 114 | sliceNum := 100 115 | 116 | slices := make([]*bufferSlice, 0, sliceNum) 117 | mem := make([]byte, 10<<20) 118 | bm, err := createBufferManager([]*SizePercentPair{ 119 | {Size: uint32(size), Percent: 100}, 120 | }, "", mem, 0) 121 | assert.Equal(t, nil, err) 122 | 123 | writeDataArray := make([][]byte, 0, sliceNum) 124 | 125 | for i := 0; i < sliceNum; i++ { 126 | s, err := bm.allocShmBuffer(uint32(size)) 127 | assert.Equal(t, nil, err, "i:%d", i) 128 | data := make([]byte, size) 129 | rand.Read(data) 130 | writeDataArray = append(writeDataArray, data) 131 | assert.Equal(t, size, s.append(data...)) 132 | s.update() 133 | slices = append(slices, s) 134 | } 135 | 136 | for i := 0; i <= len(slices)-2; i++ { 137 | slices[i].linkNext(slices[i+1].offsetInShm) 138 | } 139 | 140 | next := slices[0].offsetInShm 141 | for i := 0; i < sliceNum; i++ { 142 | s, err := bm.readBufferSlice(next) 143 | assert.Equal(t, nil, err) 144 | assert.Equal(t, size, s.capacity()) 145 | assert.Equal(t, size, s.size()) 146 | readData, err := s.read(size) 147 | assert.Equal(t, nil, err, "i:%d offset:%d", i, next) 148 | assert.Equal(t, readData, writeDataArray[i]) 149 | isLastSlice := i == sliceNum-1 150 | assert.Equal(t, !isLastSlice, s.hasNext()) 151 | next = s.nextBufferOffset() 152 | } 153 | } 154 | 155 | func TestSliceList_PushPop(t *testing.T) { 156 | //1. 
twice push , twice pop 157 | l := newSliceList() 158 | l.pushBack(newBufferSlice(nil, make([]byte, 1024), 0, false)) 159 | assert.Equal(t, l.front(), l.back()) 160 | assert.Equal(t, 1, l.size()) 161 | 162 | l.pushBack(newBufferSlice(nil, make([]byte, 1024), 0, false)) 163 | assert.Equal(t, 2, l.size()) 164 | assert.Equal(t, false, l.front() == l.back()) 165 | 166 | assert.Equal(t, l.front(), l.popFront()) 167 | assert.Equal(t, 1, l.size()) 168 | assert.Equal(t, l.front(), l.back()) 169 | 170 | assert.Equal(t, l.front(), l.popFront()) 171 | assert.Equal(t, 0, l.size()) 172 | assert.Equal(t, true, l.front() == nil) 173 | assert.Equal(t, true, l.back() == nil) 174 | 175 | // multi push and pop 176 | const iterCount = 100 177 | for i := 0; i < iterCount; i++ { 178 | l.pushBack(newBufferSlice(nil, make([]byte, 1024), 0, false)) 179 | assert.Equal(t, i+1, l.size()) 180 | } 181 | for i := 0; i < iterCount; i++ { 182 | l.popFront() 183 | assert.Equal(t, iterCount-(i+1), l.size()) 184 | } 185 | assert.Equal(t, 0, l.size()) 186 | assert.Equal(t, true, l.front() == nil) 187 | assert.Equal(t, true, l.back() == nil) 188 | } 189 | 190 | func TestSliceList_SplitFromWrite(t *testing.T) { 191 | //1. sliceList's size == 1 192 | l := newSliceList() 193 | l.pushBack(newBufferSlice(nil, make([]byte, 1024), 0, false)) 194 | l.writeSlice = l.front() 195 | assert.Equal(t, true, l.splitFromWrite() == nil) 196 | assert.Equal(t, 1, l.size()) 197 | assert.Equal(t, l.front(), l.back()) 198 | assert.Equal(t, l.back(), l.writeSlice) 199 | 200 | //2. sliceList's size == 2, writeSlice's index is 0 201 | l = newSliceList() 202 | l.pushBack(newBufferSlice(nil, make([]byte, 1024), 0, false)) 203 | l.pushBack(newBufferSlice(nil, make([]byte, 1024), 0, false)) 204 | l.writeSlice = l.front() 205 | assert.Equal(t, l.back(), l.splitFromWrite()) 206 | assert.Equal(t, 1, l.size()) 207 | assert.Equal(t, l.front(), l.back()) 208 | assert.Equal(t, l.back(), l.writeSlice) 209 | 210 | //2. 
sliceList's size == 2, writeSlice's index is 1 211 | l = newSliceList() 212 | l.pushBack(newBufferSlice(nil, make([]byte, 1024), 0, false)) 213 | l.pushBack(newBufferSlice(nil, make([]byte, 1024), 0, false)) 214 | l.writeSlice = l.back() 215 | assert.Equal(t, true, l.splitFromWrite() == nil) 216 | assert.Equal(t, 2, l.size()) 217 | assert.Equal(t, l.back(), l.writeSlice) 218 | 219 | //3. sliceList's size == 2, writeSlice's index is 50 220 | l = newSliceList() 221 | for i := 0; i < 100; i++ { 222 | l.pushBack(newBufferSlice(nil, make([]byte, 1024), 0, false)) 223 | if i == 50 { 224 | l.writeSlice = l.back() 225 | } 226 | } 227 | l.splitFromWrite() 228 | assert.Equal(t, l.writeSlice, l.back()) 229 | assert.Equal(t, 51, l.size()) 230 | } 231 | -------------------------------------------------------------------------------- /buffer_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
var (
	// shmPath is removed by initShm; nothing visible in this file creates it.
	shmPath = "/tmp/ipc.test"
	// bm is the buffer manager shared by the tests that call initShm.
	bm *bufferManager
)

// initShm (re)creates the package-level bm over a 10MB heap slice with three
// size classes (70% / 20% / 10%) and panics if that fails.
func initShm() {
	os.Remove(shmPath)
	shmSize := 10 << 20
	var err error
	mem := make([]byte, shmSize)
	bm, err = createBufferManager([]*SizePercentPair{
		{defaultSingleBufferSize, 70},
		{16 * 1024, 20},
		{64 * 1024, 10},
	}, "", mem, 0)
	if err != nil {
		panic(err)
	}
}

// newLinkedBufferWithSlice builds a linkedBuffer whose slice list contains
// only slice, which is also made the write slice.
func newLinkedBufferWithSlice(manager *bufferManager, slice *bufferSlice) *linkedBuffer {
	l := &linkedBuffer{
		sliceList:     newSliceList(),
		pinnedList:    newSliceList(),
		bufferManager: manager,
		endStream:     true,
		isFromShm:     slice.isFromShm,
	}
	l.sliceList.pushBack(slice)
	l.sliceList.writeSlice = l.sliceList.back()
	return l
}

// TestLinkedBuffer_ReadWrite runs the write/read round trip of
// testLinkedBufferReadBytes on buffers that start with one 1024-byte slice
// allocated from bm.
func TestLinkedBuffer_ReadWrite(t *testing.T) {
	initShm()
	factory := func() *linkedBuffer {
		buf, err := bm.allocShmBuffer(1024)
		if err != nil {
			t.Fatal(err)
			return nil
		}
		return newLinkedBufferWithSlice(bm, buf)
	}
	testLinkedBufferReadBytes(t, factory)
}

// TestLinkedBuffer_ReleasePreviousRead writes 100*4096 bytes, reads half of
// them, and checks the pinned/slice list sizes around
// releasePreviousReadAndReserve and ReleasePreviousRead.
func TestLinkedBuffer_ReleasePreviousRead(t *testing.T) {
	initShm()
	slice, err := bm.allocShmBuffer(1024)
	if err != nil {
		t.Fatal(err)
	}
	buf := newLinkedBufferWithSlice(bm, slice)
	sliceNum := 100
	for i := 0; i < sliceNum*4096; i++ {
		assert.Equal(t, nil, buf.WriteByte(byte(i)))
	}
	buf.done(true)

	for i := 0; i < sliceNum/2; i++ {
		r, err := buf.ReadBytes(4096)
		assert.Equal(t, 4096, len(r), "buf.Len():%d", buf.Len())
		assert.Equal(t, nil, err)
	}
	assert.Equal(t, (sliceNum/2)-1, buf.pinnedList.size())
	_, _ = buf.Discard(buf.Len())

	buf.releasePreviousReadAndReserve()
	assert.Equal(t, 0, buf.pinnedList.size())
	assert.Equal(t, 0, buf.Len())
	//the last slice shouldn't release
	assert.Equal(t, 1, buf.sliceList.size())
	assert.Equal(t, true, buf.sliceList.writeSlice != nil)

	buf.ReleasePreviousRead()
	assert.Equal(t, 0, buf.sliceList.size())
	assert.Equal(t, true, buf.sliceList.writeSlice == nil)
}

// TestLinkedBuffer_FallbackWhenWrite writes 100KB into a buffer backed by a
// manager created over only 10KB, expects isFromShm to become false, and
// verifies the data reads back intact.
func TestLinkedBuffer_FallbackWhenWrite(t *testing.T) {
	mem := make([]byte, 10*1024)
	bm, err := createBufferManager([]*SizePercentPair{
		{Size: 1024, Percent: 100},
	}, "", mem, 0)
	assert.Equal(t, nil, err)

	buf, err := bm.allocShmBuffer(1024)
	assert.Equal(t, nil, err)
	writer := newLinkedBufferWithSlice(bm, buf)
	dataSize := 1024
	mockDataArray := make([][]byte, 100)
	for i := range mockDataArray {
		mockDataArray[i] = make([]byte, dataSize)
		rand.Read(mockDataArray[i])
		n, err := writer.WriteBytes(mockDataArray[i])
		assert.Equal(t, dataSize, n)
		assert.Equal(t, err, nil)
		assert.Equal(t, dataSize*(i+1), writer.Len())
	}
	assert.Equal(t, false, writer.isFromShm)

	reader := writer.done(false)
	all := dataSize * len(mockDataArray)
	assert.Equal(t, all, writer.Len())

	for i := range mockDataArray {
		assert.Equal(t, all-i*dataSize, reader.Len())
		get, err := reader.ReadBytes(dataSize)
		if err != nil {
			t.Fatal("reader.ReadBytes error", err, i)
		}
		assert.Equal(t, mockDataArray[i], get)
	}
}

// testBufferReadString writes 100 random 16KB chunks through WriteBytes and
// reads each one back with ReadString.
// NOTE(review): no caller is visible in this part of the file — confirm it
// is still used.
func testBufferReadString(t *testing.T, createBufferWriter func() *linkedBuffer) {
	writer := createBufferWriter()
	oneSliceSize := 16 << 10
	strBytesArray := make([][]byte, 100)
	for i := 0; i < len(strBytesArray); i++ {
		strBytesArray[i] = make([]byte, oneSliceSize)
		rand.Read(strBytesArray[i])
		_, err := writer.WriteBytes(strBytesArray[i])
		assert.Equal(t, true, nil == err)
	}

	reader := writer.done(false)
	for i := 0; i < len(strBytesArray); i++ {
		str, err := reader.ReadString(oneSliceSize)
		assert.Equal(t, true, nil == err)
		assert.Equal(t, string(strBytesArray[i]), str)
	}
}

// TestLinkedBuffer_Reserve checks which slice becomes the write slice, and
// whether the buffer still reports isFromShm, across three successive
// Reserve calls.
// TODO: ensure reserving logic
func TestLinkedBuffer_Reserve(t *testing.T) {
	initShm()

	// alloc 3 buffer slice
	buffer := newLinkedBuffer(bm, (64+64+64)*1024)
	assert.Equal(t, 3, buffer.sliceList.size())
	assert.Equal(t, true, buffer.isFromShm)
	assert.Equal(t, buffer.sliceList.front(), buffer.sliceList.writeSlice)

	// reserve a buf in first slice
	ret, err := buffer.Reserve(60 * 1024)
	if err != nil {
		t.Fatal("LinkedBuffer Reserve error", err)
	}
	assert.Equal(t, len(ret), 60*1024)
	assert.Equal(t, 3, buffer.sliceList.size())
	assert.Equal(t, true, buffer.isFromShm)
	assert.Equal(t, buffer.sliceList.front(), buffer.sliceList.writeSlice)

	// reserve a buf in the second slice when the first one is not enough
	ret, err = buffer.Reserve(6 * 1024)
	if err != nil {
		t.Fatal("LinkedBuffer Reserve error", err)
	}
	assert.Equal(t, len(ret), 6*1024)
	assert.Equal(t, 3, buffer.sliceList.size())
	assert.Equal(t, true, buffer.isFromShm)
	assert.Equal(t, buffer.sliceList.front().next(), buffer.sliceList.writeSlice)

	// reserve a buf in a new allocated slice
	ret, err = buffer.Reserve(128 * 1024)
	if err != nil {
		t.Fatal("LinkedBuffer Reserve error", err)
	}
	assert.Equal(t, len(ret), 128*1024)
	assert.Equal(t, 4, buffer.sliceList.size())
	assert.Equal(t, false, buffer.isFromShm)
	assert.Equal(t, buffer.sliceList.back(), buffer.sliceList.writeSlice)
}

// newLinkedBuffer builds a linkedBuffer, calls alloc(size) on it, and makes
// the front slice the write slice.
func newLinkedBuffer(manager *bufferManager, size uint32) *linkedBuffer {
	l := &linkedBuffer{
		sliceList:     newSliceList(),
		pinnedList:    newSliceList(),
		bufferManager: manager,
		isFromShm:     true,
	}
	l.alloc(size)
	l.sliceList.writeSlice = l.sliceList.front()
	return l
}

// TestLinkedBuffer_Done writes 128KB into a buffer holding three slices and
// expects done(true) to leave two slices and the data to read back intact.
func TestLinkedBuffer_Done(t *testing.T) {
	initShm()
	mockDataSize := 128 * 1024
	mockData := make([]byte, mockDataSize)
	rand.Read(mockData)
	// alloc 3 buffer slice
	buffer := newLinkedBuffer(bm, (64+64+64)*1024)
	assert.Equal(t, 3, buffer.sliceList.size())

	// write data to full 2 slice, remove one
	_, _ = buffer.WriteBytes(mockData)
	reader := buffer.done(true)
	assert.Equal(t, 2, buffer.sliceList.size())
	getBytes, _ := reader.ReadBytes(mockDataSize)
	assert.Equal(t, mockData, getBytes)
}

// testLinkedBufferReadBytes runs 100 rounds in which 2MB of random data is
// written in random-sized chunks and read back in random-sized chunks, with
// every chunk compared against the source data.
func testLinkedBufferReadBytes(t *testing.T, createBufferWriter func() *linkedBuffer) {

	writeAndRead := func(buf *linkedBuffer) {
		//2 MB
		size := 1 << 21
		data := make([]byte, size)
		rand.Read(data)
		for buf.Len() < size {
			oneWriteSize := rand.Intn(size / 10)
			// clamp the last write so that exactly size bytes are written
			if buf.Len()+oneWriteSize > size {
				oneWriteSize = size - buf.Len()
			}
			n, err := buf.WriteBytes(data[buf.Len() : buf.Len()+oneWriteSize])
			assert.Equal(t, true, err == nil, err)
			assert.Equal(t, oneWriteSize, n)
		}

		buf.done(false)
		read := 0
		for buf.Len() > 0 {
			oneReadSize := rand.Intn(size / 10000)
			// NOTE(review): this compares bytes already read plus the chunk
			// against the bytes still buffered, so once about half has been
			// read the rest is consumed in one call; a clamp to the remaining
			// length was presumably intended — confirm.
			if read+oneReadSize > buf.Len() {
				oneReadSize = buf.Len()
			}
			//do nothing
			_, _ = buf.Peek(oneReadSize)

			readData, err := buf.ReadBytes(oneReadSize)
			assert.Equal(t, true, err == nil, err)
			if len(readData) == 0 {
				assert.Equal(t, oneReadSize, 0)
			} else {
				assert.Equal(t, data[read:read+oneReadSize], readData)
			}
			read += oneReadSize
		}
		assert.Equal(t, 1<<21, read)
		// negative and zero sizes are exercised only to show they are tolerated
		_, _ = buf.ReadBytes(-10)
		_, _ = buf.ReadBytes(0)
		buf.ReleasePreviousRead()
	}

	for i := 0; i < 100; i++ {
		writeAndRead(createBufferWriter())
	}
}
0; i < 100; i++ { 274 | writeAndRead(createBufferWriter()) 275 | } 276 | } 277 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "errors" 21 | "fmt" 22 | "io" 23 | "os" 24 | "runtime" 25 | "time" 26 | ) 27 | 28 | // Config is used to tune the shmipc session 29 | type Config struct { 30 | // ConnectionWriteTimeout is meant to be a "safety valve" timeout after 31 | // we which will suspect a problem with the underlying connection and 32 | // close it. This is only applied to writes, where's there's generally 33 | // an expectation that things will move along quickly. 34 | ConnectionWriteTimeout time.Duration 35 | 36 | //In the initialization phase, client and server will exchange metadata and mapping share memory. 37 | //the InitializeTimeout specify how long time could use in this phase. 38 | InitializeTimeout time.Duration 39 | 40 | //The max number of pending io request. default is 8192 41 | QueueCap uint32 42 | 43 | //Share memory path of the underlying queue. 44 | QueuePath string 45 | 46 | //The capacity of buffer in share memory. default is 32MB 47 | ShareMemoryBufferCap uint32 48 | 49 | //The share memory path prefix of buffer. 
50 | ShareMemoryPathPrefix string 51 | 52 | //LogOutput is used to control the log destination. 53 | LogOutput io.Writer 54 | 55 | //BufferSliceSizes could adjust share memory buffer slice size. 56 | //which could improve performance if most of all request or response's could write into single buffer slice instead of multi buffer slices. 57 | //Because multi buffer slices mean that more allocate and free operation, 58 | //and if the payload cross different buffer slice, it mean that payload in memory isn't continuous. 59 | //Default value is: 60 | // 1. 50% share memory hold on buffer slices that every slice is near to 8KB. 61 | // 2. 30% share memory hold on buffer slices that every slice is near to 32KB. 62 | // 3. 20% share memory hold on buffer slices that every slice is near to 128KB. 63 | BufferSliceSizes []*SizePercentPair 64 | 65 | // Server side's asynchronous API 66 | listenCallback ListenCallback 67 | 68 | //MemMapTypeDevShmFile or MemMapTypeMemFd (client set) 69 | MemMapType MemMapType 70 | 71 | //Session will emit some metrics to the Monitor with periodically (default 30s) 72 | Monitor Monitor 73 | 74 | // client rebuild session interval 75 | rebuildInterval time.Duration 76 | } 77 | 78 | //DefaultConfig is used to return a default configuration 79 | func DefaultConfig() *Config { 80 | return &Config{ 81 | ConnectionWriteTimeout: 10 * time.Second, 82 | InitializeTimeout: 1000 * time.Millisecond, 83 | QueueCap: defaultQueueCap, 84 | ShareMemoryBufferCap: defaultShareMemoryCap, 85 | ShareMemoryPathPrefix: "/dev/shm/shmipc", 86 | QueuePath: "/dev/shm/shmipc_queue", 87 | LogOutput: os.Stdout, 88 | MemMapType: MemMapTypeDevShmFile, 89 | BufferSliceSizes: []*SizePercentPair{ 90 | {8192 - bufferHeaderSize, 50}, 91 | {32*1024 - bufferHeaderSize, 30}, 92 | {128*1024 - bufferHeaderSize, 20}, 93 | }, 94 | rebuildInterval: sessionRebuildInterval, 95 | } 96 | } 97 | 98 | //VerifyConfig is used to verify the sanity of configuration 99 | func VerifyConfig(config 
*Config) error { 100 | if config.ShareMemoryBufferCap < (1 << 20) { 101 | return fmt.Errorf("share memory size is too small:%d, must greater than %d", config.ShareMemoryBufferCap, 1<<20) 102 | } 103 | if len(config.BufferSliceSizes) == 0 { 104 | return fmt.Errorf("BufferSliceSizes could not be nil") 105 | } 106 | 107 | sum := 0 108 | for _, pair := range config.BufferSliceSizes { 109 | sum += int(pair.Percent) 110 | if pair.Size > config.ShareMemoryBufferCap { 111 | return fmt.Errorf("BufferSliceSizes's Size:%d couldn't greater than ShareMemoryBufferCap:%d", 112 | pair.Size, config.ShareMemoryBufferCap) 113 | } 114 | 115 | if isArmArch() && pair.Size%4 != 0 { 116 | return fmt.Errorf("the SizePercentPair.Size must be a multiple of 4") 117 | } 118 | } 119 | if sum != 100 { 120 | return errors.New("the sum of BufferSliceSizes's Percent should be 100") 121 | } 122 | 123 | if isArmArch() && config.QueueCap%8 != 0 { 124 | return fmt.Errorf("the QueueCap must be a multiple of 8") 125 | } 126 | 127 | if config.ShareMemoryPathPrefix == "" || config.QueuePath == "" { 128 | return errors.New("buffer path or queue path could not be nil") 129 | } 130 | 131 | if runtime.GOOS != "linux" { 132 | return ErrOSNonSupported 133 | } 134 | 135 | if runtime.GOARCH != "amd64" && runtime.GOARCH != "arm64" { 136 | return ErrArchNonSupported 137 | } 138 | 139 | return nil 140 | } 141 | -------------------------------------------------------------------------------- /config_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
// Test_VerifyConfig walks VerifyConfig through a series of invalid
// configurations and ends with one that is expected to pass.
func Test_VerifyConfig(t *testing.T) {
	config := DefaultConfig()
	// shm too small, err
	config.ShareMemoryBufferCap = 1
	err := VerifyConfig(config)
	assert.NotEqual(t, nil, err)
	config.ShareMemoryBufferCap = 1 << 20

	// empty BufferSliceSizes, err
	config.BufferSliceSizes = []*SizePercentPair{}
	err = VerifyConfig(config)
	assert.NotEqual(t, nil, err)
	// percents sum to 99, err
	config.BufferSliceSizes = []*SizePercentPair{
		{4096, 70},
		{16 << 10, 20},
		{64 << 10, 9},
	}
	err = VerifyConfig(config)
	assert.NotEqual(t, nil, err)

	// percents sum to 101, err
	config.BufferSliceSizes = []*SizePercentPair{
		{4096, 70},
		{16 << 10, 20},
		{64 << 10, 11},
	}
	err = VerifyConfig(config)
	assert.NotEqual(t, nil, err)

	// one slice size (32MB) exceeds the 1MB buffer cap; the percents also
	// sum to 101, err
	config.BufferSliceSizes = []*SizePercentPair{
		{4096, 70},
		{16 << 10, 20},
		{defaultShareMemoryCap, 11},
	}
	err = VerifyConfig(config)
	assert.NotEqual(t, nil, err)

	// sizes fit and percents sum to 100, ok
	config.BufferSliceSizes = []*SizePercentPair{
		{4096, 70},
		{16 << 10, 20},
		{64 << 10, 10},
	}
	err = VerifyConfig(config)
	assert.Equal(t, nil, err)

}

// Test_CreateCSByWrongConfig expects newSession and Server to both fail, and
// return a nil session, for a config with ShareMemoryBufferCap set to 1.
func Test_CreateCSByWrongConfig(t *testing.T) {
	conn1, conn2 := testConn()
	config := DefaultConfig()
	config.ShareMemoryBufferCap = 1
	c, err := newSession(config, conn1, true)
	assert.NotEqual(t, nil, err)
	assert.Equal(t, (*Session)(nil), c)

	// ok is closed when the server-side check has finished
	ok := make(chan struct{})
	go func() {
		s, err := Server(conn2, config)
		assert.NotEqual(t, nil, err)
		assert.Equal(t, (*Session)(nil), s)
		close(ok)
	}()
	<-ok
}

// Test_CreateCSWithoutConfig expects both Server and newSession to succeed
// when they are given a nil config.
func Test_CreateCSWithoutConfig(t *testing.T) {
	conn1, conn2 := testConn()
	ok := make(chan struct{})
	go func() {
		s, err := Server(conn2, nil)
		assert.Equal(t, nil, err)
		assert.NotEqual(t, (*Session)(nil), s)
		if err == nil {
			defer s.Close()
		}
		close(ok)
	}()

	c, err := newSession(nil, conn1, true)
	assert.Equal(t, nil, err)
	if err == nil {
		defer c.Close()
	}
	assert.NotEqual(t, (*Session)(nil), c)
	time.Sleep(time.Second)
	<-ok
}
const (
	// protoVersion is documented by the original authors as "the only version we support".
	// NOTE(review): maxSupportProtoVersion below is 3, which seems to contradict that — confirm.
	protoVersion           uint8  = 2
	maxSupportProtoVersion uint8  = 3
	magicNumber            uint16 = 0x7758
)

type eventType uint8

type sessionSateType uint32

// MemMapType is the mapping type of shared memory.
type MemMapType uint8

const (
	// MemMapTypeDevShmFile maps share memory to /dev/shm (tmpfs).
	MemMapTypeDevShmFile MemMapType = 0
	// MemMapTypeMemFd maps share memory to memfd (Linux OS v3.17+).
	MemMapTypeMemFd MemMapType = 1
)

const (
	defaultState sessionSateType = iota
	// server: sent the hot restart event; client: received the hot restart event.
	hotRestartState
	// server: received the hot restart ack event; client: sent the hot restart ack event.
	hotRestartDoneState
)

const (
	memfdCreateName = "shmipc"

	memfdDataLen = 4
	memfdCount   = 2

	bufferPathSuffix = "_buffer"
	unixNetwork      = "unix"

	hotRestartCheckTimeout  = 2 * time.Second
	hotRestartCheckInterval = 100 * time.Millisecond

	sessionRebuildInterval = time.Second * 60

	epochIDLen = 8
	// Maximum length of a file name on Linux.
	fileNameMaxLen = 255
	// buffer path = %s_epoch_${epochID}_${randID}
	// len("_epoch_") + maxUint64StrLength + len("_") + maxUint64StrLength
	epochInfoMaxLen = 7 + 20 + 1 + 20
	// _queue_{sessionID int}
	// len("_queue_") + 20 digits for the session ID
	queueInfoMaxLen = 7 + 20
)

const (
	defaultQueueCap         = 8192
	defaultShareMemoryCap   = 32 * 1024 * 1024
	defaultSingleBufferSize = 4096
	queueElementLen         = 12
	queueCount              = 2
)

const (
	sizeOfLength  = 4
	sizeOfMagic   = 2
	sizeOfVersion = 1
	sizeOfType    = 1

	// headerSize is the sum of the four field sizes above (8 bytes).
	headerSize = sizeOfLength + sizeOfMagic + sizeOfVersion + sizeOfType
)

var (
	// zeroTime is the zero value of time.Time.
	zeroTime = time.Time{}
	// pollingEventWithVersion has one header slot per protocol version, 0 through maxSupportProtoVersion.
	pollingEventWithVersion [maxSupportProtoVersion + 1]header
)
// logger is a small leveled logger that writes colorized lines to out.
type logger struct {
	name      string    // label written at the end of every line prefix
	out       io.Writer // destination of the formatted lines
	callDepth int       // skip count passed to runtime.Caller by location()
}

var (
	// internalLogger is the unnamed logger used across the package; it writes to stdout.
	internalLogger = &logger{"", os.Stdout, 3}
	// protocolLogger labels its lines "protocol trace" and skips one more stack frame than internalLogger.
	protocolLogger = &logger{"protocol trace", os.Stdout, 4}
	// level is the lowest severity that is printed; see init and SetLogLevel.
	level int
	// debugMode is turned on in init when SHMIPC_DEBUG_MODE is set to a non-empty value.
	debugMode = false

	// ANSI escape sequences (ESC [ 9x m) used to color each level, plus the reset sequence (ESC [ 0 m).
	magenta = string([]byte{27, 91, 57, 53, 109}) // Trace
	green   = string([]byte{27, 91, 57, 50, 109}) // Debug
	blue    = string([]byte{27, 91, 57, 52, 109}) // Info
	yellow  = string([]byte{27, 91, 57, 51, 109}) // Warn
	red     = string([]byte{27, 91, 57, 49, 109}) // Error
	reset   = string([]byte{27, 91, 48, 109})

	// colors is indexed by the level constants below (levelTrace..levelError).
	colors = []string{
		magenta,
		green,
		blue,
		yellow,
		red,
	}

	// levelName is indexed by the level constants below (levelTrace..levelError).
	levelName = []string{
		"Trace",
		"Debug",
		"Info",
		"Warn",
		"Error",
	}
)

// Log levels in increasing severity. levelNoPrint is above every real level,
// so selecting it silences all output.
const (
	levelTrace = iota
	levelDebug
	levelInfo
	levelWarn
	levelError
	levelNoPrint
)

// init sets the default level to Warn and then applies the SHMIPC_LOG_LEVEL and
// SHMIPC_DEBUG_MODE environment variables.
func init() {
	level = levelWarn
	if os.Getenv("SHMIPC_LOG_LEVEL") != "" {
		// Values that are not integers, or that exceed levelNoPrint, are ignored.
		if n, err := strconv.Atoi(os.Getenv("SHMIPC_LOG_LEVEL")); err == nil {
			if n <= levelNoPrint {
				level = n
			}
		}
	}

	if os.Getenv("SHMIPC_DEBUG_MODE") != "" {
		debugMode = true
	}
}

// SetLogLevel changes the internal logger's level; the default level is Warn.
// The process environment variable `SHMIPC_LOG_LEVEL` can also set the level.
// Values greater than levelNoPrint are ignored.
func SetLogLevel(l int) {
	if l <= levelNoPrint {
		level = l
	}
}

// newLogger returns a logger labeled name that writes to out, or to os.Stdout
// when out is nil. Its callDepth is 3, the same as internalLogger.
func newLogger(name string, out io.Writer) *logger {
	if out == nil {
		out = os.Stdout
	}
	return &logger{
		name:      name,
		out:       out,
		callDepth: 3,
	}
}

// errorf writes a formatted Error line unless the current level is above levelError.
func (l *logger) errorf(format string, a ...interface{}) {
	if level > levelError {
		return
	}
	fmt.Fprintf(l.out, l.prefix(levelError)+format+reset+"\n", a...)
}

// error writes v as an Error line unless the current level is above levelError.
func (l *logger) error(v interface{}) {
	if level > levelError {
		return
	}
	fmt.Fprintln(l.out, l.prefix(levelError), v, reset)
}

// warnf writes a formatted Warn line unless the current level is above levelWarn.
func (l *logger) warnf(format string, a ...interface{}) {
	if level > levelWarn {
		return
	}
	fmt.Fprintf(l.out, l.prefix(levelWarn)+format+reset+"\n", a...)
}

// warn writes v as a Warn line unless the current level is above levelWarn.
func (l *logger) warn(v interface{}) {
	if level > levelWarn {
		return
	}
	fmt.Fprintln(l.out, l.prefix(levelWarn), v, reset)
}

// infof writes a formatted Info line unless the current level is above levelInfo.
func (l *logger) infof(format string, a ...interface{}) {
	if level > levelInfo {
		return
	}
	fmt.Fprintf(l.out, l.prefix(levelInfo)+format+reset+"\n", a...)
}

// info writes v as an Info line unless the current level is above levelInfo.
func (l *logger) info(v interface{}) {
	if level > levelInfo {
		return
	}
	fmt.Fprintln(l.out, l.prefix(levelInfo), v, reset)
}

// debugf writes a formatted Debug line unless the current level is above levelDebug.
func (l *logger) debugf(format string, a ...interface{}) {
	if level > levelDebug {
		return
	}
	fmt.Fprintf(l.out, l.prefix(levelDebug)+format+reset+"\n", a...)
}

// debug writes v as a Debug line unless the current level is above levelDebug.
func (l *logger) debug(v interface{}) {
	if level > levelDebug {
		return
	}
	fmt.Fprintln(l.out, l.prefix(levelDebug), v, reset)
}

// tracef writes a formatted Trace line unless the current level is above levelTrace.
func (l *logger) tracef(format string, a ...interface{}) {
	if level > levelTrace {
		return
	}
	// TODO: optimize
	fmt.Fprintf(l.out, l.prefix(levelTrace)+format+reset+"\n", a...)
}

// trace writes v as a Trace line unless the current level is above levelTrace.
func (l *logger) trace(v interface{}) {
	if level > levelTrace {
		return
	}
	fmt.Fprintln(l.out, l.prefix(levelTrace), v, reset)
}

// prefix builds "<color><LevelName> <timestamp> <file:line> <name> " for the
// given level. The level parameter shadows the package-level variable and must
// be a valid index into colors and levelName.
func (l *logger) prefix(level int) string {
	// Stack-backed scratch space; bytes.Buffer grows past it if the prefix is longer.
	var buffer [64]byte
	buf := bytes.NewBuffer(buffer[:0])
	_, _ = buf.WriteString(colors[level])
	_, _ = buf.WriteString(levelName[level])
	_ = buf.WriteByte(' ')
	_, _ = buf.WriteString(time.Now().Format("2006-01-02 15:04:05.999999"))
	_ = buf.WriteByte(' ')
	_, _ = buf.WriteString(l.location())
	_ = buf.WriteByte(' ')
	_, _ = buf.WriteString(l.name)
	_ = buf.WriteByte(' ')
	return buf.String()
}

// location returns "file:line" for the frame l.callDepth levels above this
// function, or "???:0" when runtime.Caller cannot resolve it. The result
// therefore depends on the exact chain log method -> prefix -> location.
func (l *logger) location() string {
	_, file, line, ok := runtime.Caller(l.callDepth)
	if !ok {
		file = "???"
		line = 0
	}
	file = filepath.Base(file)
	return file + ":" + strconv.Itoa(line)
}
202 | line = 0 203 | } 204 | file = filepath.Base(file) 205 | return file + ":" + strconv.Itoa(line) 206 | } 207 | 208 | func computeFreeSliceNum(list *bufferList) int { 209 | freeSlices := 0 210 | offset := atomic.LoadUint32(list.head) 211 | for { 212 | freeSlices++ 213 | bh := bufferHeader(list.bufferRegion[offset:]) 214 | hasNext := bh.hasNext() 215 | if !hasNext && offset != atomic.LoadUint32(list.tail) { 216 | fmt.Printf("something error, expectedTailOffset:%d but:%d current freeSlices:%d \n", 217 | atomic.LoadUint32(list.tail), offset, freeSlices) 218 | } 219 | if !hasNext { 220 | break 221 | } 222 | offset = bh.nextBufferOffset() 223 | if offset >= uint32(len(list.bufferRegion)) { 224 | fmt.Printf("something error , next offset is :%d ,greater than bufferRegion length:%d\n", 225 | offset, len(list.bufferRegion)) 226 | break 227 | } 228 | } 229 | return freeSlices 230 | } 231 | 232 | //1.occurred memory leak, if list's free slice number != expect free slice number. 233 | //2.print the metadata and payload of the leaked slice. 234 | func debugBufferListDetail(path string, bufferMgrHeaderSize int, bufferHeaderSize int) { 235 | mem, err := ioutil.ReadFile(path) 236 | if err != nil { 237 | fmt.Println(err) 238 | return 239 | } 240 | bm, err := mappingBufferManager(path, mem, 0) 241 | expectAllSliceNum := uint32(0) 242 | for i, list := range bm.lists { 243 | fmt.Printf("%d. 
list capacity:%d size:%d realSize:%d perSliceCap:%d headOffset:%d tailOffset:%d\n", 244 | i, *list.cap, *list.size, computeFreeSliceNum(list), *list.capPerBuffer, *list.head, *list.tail) 245 | expectAllSliceNum += *list.cap 246 | } 247 | freeSliceNum := uint32(bm.sliceSize()) 248 | fmt.Printf("summary:memory leak:%t all free slice num:%d expect num:%d\n", 249 | freeSliceNum != expectAllSliceNum, freeSliceNum, expectAllSliceNum) 250 | 251 | if freeSliceNum != expectAllSliceNum { 252 | fmt.Println("now check the buffer slice which is in used") 253 | } 254 | printLeakShareMemory(bm, bufferMgrHeaderSize, bufferHeaderSize) 255 | } 256 | func printLeakShareMemory(bm *bufferManager, bufferMgrHeaderSize int, bufferHeaderSize int) { 257 | offsetInShm := bufferMgrHeaderSize 258 | for i := range bm.lists { 259 | offsetInShm += bufferListHeaderSize 260 | data := bm.lists[i].bufferRegion 261 | offset := 0 262 | realSize := 0 263 | for offset < len(data) { 264 | realSize++ 265 | bh := bufferHeader(data[offset:]) 266 | size := int(*(*uint32)(unsafe.Pointer(&data[offset+bufferSizeOffset]))) 267 | flag := int(*(*uint8)(unsafe.Pointer(&data[offset+bufferFlagOffset]))) 268 | if !bh.hasNext() || bh.isInUsed() { 269 | fmt.Printf("offset in shm :%d next:%d len:%d flag:%d inused:%t data:%s\n", 270 | offset+offsetInShm, bh.nextBufferOffset(), size, flag, bh.isInUsed(), 271 | string(data[offset+bufferHeaderSize:offset+bufferHeaderSize+size])) 272 | } 273 | offset += int(*bm.lists[i].capPerBuffer) + bufferHeaderSize 274 | } 275 | offsetInShm += offset 276 | } 277 | } 278 | 279 | //DebugBufferListDetail print all BufferList's status in share memory located in the `path` 280 | //if MemMapType is MemMapTypeMemFd, you could using the command that 281 | //`lsof -p $PID` to found the share memory which was mmap by memfd, 282 | //and the command `cat /proc/$PID/$MEMFD > $path` dump the share memory to file system. 
283 | func DebugBufferListDetail(path string) { 284 | debugBufferListDetail(path, 8, 20) 285 | } 286 | 287 | //DebugQueueDetail print IO-Queue's status which was mmap in the `path` 288 | func DebugQueueDetail(path string) { 289 | mem, err := ioutil.ReadFile(path) 290 | if err != nil { 291 | fmt.Println(err) 292 | return 293 | } 294 | sendQueue := mappingQueueFromBytes(mem[len(mem)/2:]) 295 | recvQueue := mappingQueueFromBytes(mem[:len(mem)/2]) 296 | printFunc := func(name string, q *queue) { 297 | fmt.Printf("path:%s name:%s, cap:%d head:%d tail:%d size:%d flag:%d\n", 298 | name, path, q.cap, *q.head, *q.tail, q.size(), *q.workingFlag) 299 | } 300 | printFunc("sendQueue", sendQueue) 301 | printFunc("recvQueue", recvQueue) 302 | } 303 | -------------------------------------------------------------------------------- /debug_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "testing" 21 | ) 22 | 23 | func TestLogColor(t *testing.T) { 24 | SetLogLevel(levelTrace) 25 | 26 | internalLogger.tracef("this is tracef %s", "hello world") 27 | internalLogger.trace("this is trace") 28 | 29 | internalLogger.infof("this is infof %s", "hello world") 30 | internalLogger.info("this is info") 31 | 32 | internalLogger.debugf("this is debugf %s", "hello world") 33 | internalLogger.debug("this is debug") 34 | 35 | internalLogger.warnf("this is warnf %s", "hello world") 36 | internalLogger.warn("this is warn") 37 | 38 | internalLogger.errorf("this is errorf %s", "hello world") 39 | internalLogger.error("this is error") 40 | } 41 | -------------------------------------------------------------------------------- /epoll_linux.go: -------------------------------------------------------------------------------- 1 | // +build !arm64 2 | 3 | /* 4 | * Copyright 2023 CloudWeGo Authors 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
// epollModeET is the edge-triggered flag as a positive constant.
// NOTE(review): the negation presumably compensates for syscall.EPOLLET being
// declared negative on this architecture — confirm against the syscall package.
const epollModeET = -syscall.EPOLLET

// epollEvent is the event record exchanged with epoll_ctl and epoll_wait:
// the event mask followed by 8 bytes of user data.
type epollEvent struct {
	events uint32
	data   [8]byte
}

// epollCtl issues epoll_ctl(epfd, op, fd, event) and returns the errno as an
// error, or nil on success.
func epollCtl(epfd int, op int, fd int, event *epollEvent) (err error) {
	_, _, errCode := syscall.RawSyscall6(syscall.SYS_EPOLL_CTL, uintptr(epfd), uintptr(op), uintptr(fd), uintptr(unsafe.Pointer(event)), 0, 0)
	if errCode != syscall.Errno(0) {
		err = errCode
	}
	return err
}

// epollWait issues epoll_wait on epfd with a timeout of msec milliseconds and
// returns the number of events written to events.
//
// An interrupted wait (EINTR) is reported as zero events with a nil error so
// that polling loops simply retry. Any other failure returns n == 0 and the
// errno. An empty events slice is rejected with EINVAL.
func epollWait(epfd int, events []epollEvent, msec int) (n int, err error) {
	if len(events) == 0 {
		return 0, syscall.EINVAL
	}
	ready, _, errNo := syscall.Syscall6(syscall.SYS_EPOLL_WAIT, uintptr(epfd), uintptr(unsafe.Pointer(&events[0])),
		uintptr(len(events)), uintptr(msec), 0, 0)
	switch errNo {
	case 0:
		return int(ready), nil
	case syscall.EINTR:
		return 0, nil
	default:
		return 0, errNo
	}
}
// epollModeET is the edge-triggered flag for this architecture.
const epollModeET = syscall.EPOLLET

// epollEvent is the event record exchanged with epoll_ctl and epoll_pwait:
// the event mask, 4 bytes of padding, then 8 bytes of user data.
type epollEvent struct {
	events uint32
	_      int32
	data   [8]byte
}

// epollCtl issues epoll_ctl(epfd, op, fd, event) and returns the errno as an
// error, or nil on success.
func epollCtl(epfd int, op int, fd int, event *epollEvent) (err error) {
	_, _, errCode := syscall.RawSyscall6(syscall.SYS_EPOLL_CTL, uintptr(epfd), uintptr(op), uintptr(fd),
		uintptr(unsafe.Pointer(event)), 0, 0)
	if errCode != syscall.Errno(0) {
		err = errCode
	}
	return err
}

// epollWait issues epoll_pwait (with no signal mask) on epfd with a timeout of
// msec milliseconds and returns the number of events written to events.
//
// An interrupted wait (EINTR) is reported as zero events with a nil error so
// that polling loops simply retry. Any other failure returns n == 0 and the
// errno. An empty events slice is rejected with EINVAL.
func epollWait(epfd int, events []epollEvent, msec int) (n int, err error) {
	if len(events) == 0 {
		return 0, syscall.EINVAL
	}
	ready, _, errNo := syscall.Syscall6(syscall.SYS_EPOLL_PWAIT, uintptr(epfd), uintptr(unsafe.Pointer(&events[0])),
		uintptr(len(events)), uintptr(msec), 0, 0)
	switch errNo {
	case 0:
		return int(ready), nil
	case syscall.EINTR:
		return 0, nil
	default:
		return 0, errNo
	}
}
var (
	// ErrInvalidVersion means that a message with an invalid version was received.
	ErrInvalidVersion = errors.New("invalid protocol version")

	// ErrInvalidMsgType means that a message with an invalid message type was received.
	ErrInvalidMsgType = errors.New("invalid msg type")

	// ErrSessionShutdown means that the session is shut down.
	ErrSessionShutdown = errors.New("session shutdown")

	// ErrStreamsExhausted means that the stream IDs are used up; some streams may have leaked.
	ErrStreamsExhausted = errors.New("streams exhausted")

	// ErrTimeout is used when an IO deadline is reached.
	ErrTimeout = errors.New("i/o deadline reached")

	// ErrStreamClosed is returned when a closed stream is used.
	ErrStreamClosed = errors.New("stream closed")

	// ErrConnectionWriteTimeout means that a write timeout happened on the tcp/unix connection.
	ErrConnectionWriteTimeout = errors.New("connection write timeout")

	// ErrEndOfStream means that the stream has ended; the user should not read from it any more.
	ErrEndOfStream = errors.New("end of stream")

	// ErrSessionUnhealthy occurs at Session.OpenStream() and means that the session is overloaded.
	// The user should retry after 60 seconds (for now). The following situations result in
	// ErrSessionUnhealthy on the client side:
	//  1. local share memory is not enough, so the client sends request data via the unix domain socket.
	//  2. peer share memory is not enough, so the client receives response data from the unix domain socket.
	ErrSessionUnhealthy = errors.New("now the session is unhealthy, please retry later")

	// ErrNotEnoughData means that the real read size is smaller than the expected read size.
	ErrNotEnoughData = errors.New("current buffer is not enough data to read")

	// ErrNoMoreBuffer means that the share memory is busy and there is no more buffer to allocate.
	ErrNoMoreBuffer = errors.New("share memory not more buffer")

	// ErrShareMemoryHadNotLeftSpace means that the limit of the file system was reached when
	// using MemMapTypeDevShmFile.
	ErrShareMemoryHadNotLeftSpace = errors.New("share memory had not left space")

	// ErrStreamCallbackHadExisted is returned if the stream's callbacks already exist.
	ErrStreamCallbackHadExisted = errors.New("stream callbacks had existed")

	// ErrOSNonSupported means that shmipc cannot work on the current OS (only Linux is supported now).
	ErrOSNonSupported = errors.New("shmipc just support linux OS now")

	// ErrArchNonSupported means that shmipc only supports amd64 and arm64.
	ErrArchNonSupported = errors.New("shmipc just support amd64 or arm64 arch")

	// ErrHotRestartInProgress is returned by Listener.HotRestart when the session is already
	// in the hot restart state.
	ErrHotRestartInProgress = errors.New("hot restart in progress, try again later")

	// ErrInHandshakeStage happens when a session that is not yet initialized does a hot restart.
	ErrInHandshakeStage = errors.New("session in handshake stage, try again later")

	// ErrFileNameTooLong means that the length of Config.ShareMemoryPathPrefixFile reached the
	// limit of the OS.
	ErrFileNameTooLong = errors.New("share memory path prefix too long")

	// ErrQueueFull means that the server is so busy that the io queue is full.
	ErrQueueFull = errors.New("the io queue is full")

	// errQueueEmpty is the unexported counterpart of ErrQueueFull for an empty io queue.
	errQueueEmpty = errors.New("the io queue is empty")
)
// defaultDispatcher is the process-wide dispatcher. It stays nil unless
// another file of the package assigns it.
var defaultDispatcher dispatcher

// eventConnCallback receives the events of one eventConn.
type eventConnCallback interface {
	onEventData(buf []byte, conn eventConn) error
	onRemoteClose()
	onLocalClose()
}

// eventConn is a connection driven by a dispatcher.
type eventConn interface {
	commitRead(n int)
	setCallback(cb eventConnCallback) error
	write(data []byte) error
	writev(data ...[]byte) error
	close() error
}

// dispatcher runs an event loop; for now it only serves connections.
type dispatcher interface {
	runLoop() error
	newConnection(connFd *os.File) eventConn
	shutdown() error
	post(f func())
}

// dispatcherInitOnce guards the one-time start of defaultDispatcher's loop.
var (
	dispatcherInitOnce sync.Once
)

// ensureDefaultDispatcherInit starts defaultDispatcher's loop the first time
// it is called with a non-nil defaultDispatcher, and panics if the loop cannot
// be started. With a nil defaultDispatcher it does nothing.
func ensureDefaultDispatcherInit() {
	if defaultDispatcher == nil {
		return
	}
	dispatcherInitOnce.Do(func() {
		if startErr := defaultDispatcher.runLoop(); startErr != nil {
			panic("shmipc run diapcher failed, reason:" + startErr.Error())
		}
	})
}

// getConnDupFd returns the result of conn.File() when conn provides that
// method, and an error otherwise.
func getConnDupFd(conn net.Conn) (*os.File, error) {
	fileProvider, hasFileMethod := conn.(interface {
		File() (f *os.File, err error)
	})
	if hasFileMethod {
		return fileProvider.File()
	}
	return nil, fmt.Errorf("conn has no method File() (f *os.File, err error)")
}
8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | 19 | package shmipc 20 | 21 | import ( 22 | "fmt" 23 | "os" 24 | "runtime" 25 | "sync" 26 | "sync/atomic" 27 | "syscall" 28 | "unsafe" 29 | ) 30 | 31 | var ( 32 | _ dispatcher = &epollDispatcher{} 33 | _ eventConn = &connEventHandler{} 34 | ) 35 | 36 | func init() { 37 | defaultDispatcher = newEpollDispatcher() 38 | } 39 | 40 | type connEventHandler struct { 41 | dispatcher *epollDispatcher 42 | file *os.File 43 | callback eventConnCallback 44 | readBuffer []byte 45 | onWriteReadyCh chan struct{} 46 | ioves [256]syscall.Iovec 47 | readStartOff int 48 | readEndOff int 49 | fd int 50 | isClose uint32 51 | } 52 | 53 | func (c *connEventHandler) handleEvent(events int, d *epollDispatcher) { 54 | if events&syscall.EPOLLRDHUP != 0 { 55 | c.onRemoteClose() 56 | return 57 | } 58 | 59 | if events&syscall.EPOLLIN != 0 { 60 | if err := c.onReadReady(); err != nil { 61 | internalLogger.warnf("read failed fd:%d reason:%s", c.fd, err.Error()) 62 | } 63 | } 64 | 65 | if events&syscall.EPOLLOUT != 0 { 66 | if err := c.onWriteReady(); err != nil { 67 | internalLogger.warnf("write failed fd:%d reason:%s", c.fd, err.Error()) 68 | } 69 | } 70 | } 71 | 72 | func (c *connEventHandler) onRemoteClose() { 73 | c.callback.onRemoteClose() 74 | c.deferredClose() 75 | } 76 | 77 | //avoiding write concurrently, blocking until return 78 | func (c *connEventHandler) writev(data ...[]byte) error { 79 | if len(data) == 0 { 80 | return nil 81 | } 82 | writtenSliceNum := 0 83 | for writtenSliceNum < len(data) { 84 | n, err := 
c.doWritev(data[writtenSliceNum:]...) 85 | if err != nil { 86 | return err 87 | } 88 | writtenSliceNum += n 89 | } 90 | return nil 91 | } 92 | 93 | func (c *connEventHandler) doWritev(data ...[]byte) (int, error) { 94 | needSubmitIovecLen := 0 95 | for needSubmitIovecLen < len(data) && needSubmitIovecLen < len(c.ioves) { 96 | sliceLen := len(data[needSubmitIovecLen]) 97 | c.ioves[needSubmitIovecLen].Len = uint64(sliceLen) 98 | c.ioves[needSubmitIovecLen].Base = &data[needSubmitIovecLen][0] 99 | needSubmitIovecLen++ 100 | } 101 | writtenSliceNum := needSubmitIovecLen 102 | 103 | writtenVec := 0 104 | for needSubmitIovecLen > 0 { 105 | if atomic.LoadUint32(&c.isClose) == 1 { 106 | return -1, syscall.EPIPE 107 | } 108 | n, _, err := syscall.Syscall(syscall.SYS_WRITEV, uintptr(c.fd), 109 | uintptr(unsafe.Pointer(&c.ioves[writtenVec])), uintptr(needSubmitIovecLen)) 110 | 111 | if err == syscall.EAGAIN { 112 | <-c.onWriteReadyCh 113 | continue 114 | } 115 | 116 | if err != syscall.Errno(0) { 117 | return -1, err 118 | } 119 | 120 | //ack write 121 | for writtenSize := uint64(n); writtenSize > 0; { 122 | if writtenSize >= c.ioves[writtenVec].Len { 123 | writtenSize -= c.ioves[writtenVec].Len 124 | needSubmitIovecLen-- 125 | writtenVec++ 126 | } else { 127 | c.ioves[writtenVec].Len -= writtenSize 128 | startOff := uint64(len(data[writtenVec])) - c.ioves[writtenVec].Len 129 | c.ioves[writtenVec].Base = &data[writtenVec][startOff] 130 | break 131 | } 132 | } 133 | 134 | } 135 | return writtenSliceNum, nil 136 | } 137 | 138 | //avoiding write concurrently, blocking until return 139 | func (c *connEventHandler) write(data []byte) error { 140 | written, size := 0, len(data) 141 | for written < size { 142 | if atomic.LoadUint32(&c.isClose) == 1 { 143 | return syscall.EPIPE 144 | } 145 | n, _, err := syscall.Syscall(syscall.SYS_WRITE, uintptr(c.fd), uintptr(unsafe.Pointer(&data[written])), 146 | uintptr(size-written)) 147 | if err == syscall.EAGAIN { 148 | <-c.onWriteReadyCh 149 
| continue 150 | } 151 | if err != syscall.Errno(0) { 152 | return err 153 | } 154 | written += int(n) 155 | } 156 | 157 | return nil 158 | } 159 | 160 | func (c *connEventHandler) maybeExpandReadBuffer() { 161 | bufRemain := len(c.readBuffer) - c.readEndOff 162 | if bufRemain == 0 { 163 | newBuf := make([]byte, 2*len(c.readBuffer)) 164 | c.readEndOff = copy(newBuf, c.readBuffer[c.readStartOff:c.readEndOff]) 165 | c.readStartOff = 0 166 | c.readBuffer = newBuf 167 | } 168 | } 169 | 170 | // we could ensure that when the onReadReady() was called, the connection's fd must be open. 171 | func (c *connEventHandler) onReadReady() (err error) { 172 | const onDataThreshold = 1 * 1024 * 1024 173 | for { 174 | c.maybeExpandReadBuffer() 175 | n, _, errCode := syscall.RawSyscall(syscall.SYS_READ, uintptr(c.fd), 176 | uintptr(unsafe.Pointer(&c.readBuffer[c.readEndOff])), uintptr(len(c.readBuffer)-c.readEndOff)) 177 | 178 | if errCode == syscall.EAGAIN { 179 | break 180 | } 181 | 182 | if errCode != 0 { 183 | return 184 | } 185 | if n == 0 { 186 | c.onRemoteClose() 187 | break 188 | } 189 | 190 | c.readEndOff += int(n) 191 | if c.readEndOff-c.readStartOff >= onDataThreshold { 192 | if err = c.callback.onEventData(c.readBuffer[c.readStartOff:c.readEndOff], c); err != nil { 193 | return err 194 | } 195 | } 196 | } 197 | 198 | return c.callback.onEventData(c.readBuffer[c.readStartOff:c.readEndOff], c) 199 | } 200 | 201 | func (c *connEventHandler) onWriteReady() error { 202 | //fmt.Println("onWriteReady fd:", c.fd) 203 | if atomic.LoadUint32(&c.isClose) == 0 { 204 | asyncNotify(c.onWriteReadyCh) 205 | } 206 | return nil 207 | } 208 | 209 | func (c *connEventHandler) commitRead(n int) { 210 | c.readStartOff += n 211 | if c.readStartOff == c.readEndOff { 212 | // The threshold for resizing is 4M 213 | minResizedBufferSize := 4 * 1024 * 1024 214 | // The size of readBuffer is even times of 64K 215 | if len(c.readBuffer) > minResizedBufferSize { 216 | // We need not reset buffer size 
to 64k immediately 217 | // since we may expand next time. 218 | // Reduce it gradually 219 | minResizedBufferSize = len(c.readBuffer) / 2 220 | c.readBuffer = c.readBuffer[:minResizedBufferSize] 221 | } 222 | c.readStartOff = 0 223 | c.readEndOff = 0 224 | } 225 | } 226 | 227 | func (c *connEventHandler) close() error { 228 | c.callback.onLocalClose() 229 | c.deferredClose() 230 | return nil 231 | } 232 | 233 | func (c *connEventHandler) deferredClose() { 234 | if atomic.CompareAndSwapUint32(&c.isClose, 0, 1) { 235 | close(c.onWriteReadyCh) 236 | c.dispatcher.post(func() { 237 | epollCtl(c.dispatcher.epollFd, syscall.EPOLL_CTL_DEL, c.fd, nil) 238 | c.dispatcher.lock.Lock() 239 | delete(c.dispatcher.conns, c.fd) 240 | c.dispatcher.lock.Unlock() 241 | c.file.Close() 242 | }) 243 | } 244 | } 245 | 246 | func (c *connEventHandler) setCallback(cb eventConnCallback) error { 247 | if err := syscall.SetNonblock(c.fd, true); err != nil { 248 | return fmt.Errorf("fd:%d couldn't set nobloking,reason=%s", c.fd, err) 249 | } 250 | event := &epollEvent{ 251 | events: syscall.EPOLLIN | syscall.EPOLLOUT | epollModeET | syscall.EPOLLRDHUP, 252 | } 253 | *(**connEventHandler)(unsafe.Pointer(&event.data)) = c 254 | c.dispatcher.lock.Lock() 255 | defer c.dispatcher.lock.Unlock() 256 | c.callback = cb 257 | if err := epollCtl(c.dispatcher.epollFd, syscall.EPOLL_CTL_ADD, c.fd, event); err != nil { 258 | return fmt.Errorf("epollCt fd:%d failed, reason=%s", c.fd, err) 259 | } 260 | c.dispatcher.conns[c.fd] = c 261 | return nil 262 | } 263 | 264 | type epollDispatcher struct { 265 | epollFd int 266 | epollFile *os.File 267 | conns map[int]*connEventHandler 268 | lock sync.Mutex 269 | waitLoopExitWg sync.WaitGroup 270 | toCloseConns []*connEventHandler 271 | pendingLambda []func() 272 | runningLambda []func() 273 | lambdaLock sync.Mutex 274 | isShutdown bool 275 | } 276 | 277 | func newEpollDispatcher() *epollDispatcher { 278 | return &epollDispatcher{ 279 | conns: 
make(map[int]*connEventHandler, 8), 280 | pendingLambda: make([]func(), 0, 32), 281 | runningLambda: make([]func(), 0, 32), 282 | } 283 | } 284 | 285 | func (d *epollDispatcher) post(f func()) { 286 | d.lambdaLock.Lock() 287 | d.pendingLambda = append(d.pendingLambda, f) 288 | d.lambdaLock.Unlock() 289 | } 290 | 291 | func (d *epollDispatcher) newConnection(file *os.File) eventConn { 292 | return &connEventHandler{ 293 | fd: int(file.Fd()), //notice: the function file.Fd() will set fd blocking 294 | file: file, 295 | dispatcher: d, 296 | readBuffer: make([]byte, 64*1024), 297 | onWriteReadyCh: make(chan struct{}, 1), 298 | isClose: 0, 299 | } 300 | } 301 | 302 | func (d *epollDispatcher) runLambda() { 303 | d.lambdaLock.Lock() 304 | d.runningLambda, d.pendingLambda = d.pendingLambda, d.runningLambda 305 | d.lambdaLock.Unlock() 306 | for _, f := range d.runningLambda { 307 | if d.isShutdown { 308 | return 309 | } 310 | f() 311 | } 312 | d.runningLambda = d.runningLambda[:0] 313 | } 314 | 315 | func (d *epollDispatcher) runLoop() error { 316 | epollFd, err := syscall.EpollCreate1(0) 317 | if err != nil { 318 | return err 319 | } 320 | d.epollFd = epollFd 321 | d.waitLoopExitWg.Add(1) 322 | go func() { 323 | defer d.waitLoopExitWg.Done() 324 | var events [128]epollEvent 325 | timeout := 0 326 | for { 327 | n, err := epollWait(d.epollFd, events[:], timeout) 328 | if err != nil { 329 | fmt.Println("epoll wait error", err) 330 | return 331 | } 332 | if n <= 0 { 333 | timeout = 1000 334 | runtime.Gosched() 335 | d.runLambda() 336 | continue 337 | } 338 | timeout = 0 339 | d.lock.Lock() 340 | for i := 0; i < n; i++ { 341 | var h *connEventHandler = *(**connEventHandler)(unsafe.Pointer(&events[i].data)) 342 | h.handleEvent(int(events[i].events), d) 343 | } 344 | d.lock.Unlock() 345 | d.runLambda() 346 | } 347 | runtime.KeepAlive(d) 348 | }() 349 | return nil 350 | } 351 | 352 | func (d *epollDispatcher) shutdown() error { 353 | d.epollFile.Close() 354 | 
d.waitLoopExitWg.Wait() 355 | d.lock.Lock() 356 | defer d.lock.Unlock() 357 | d.isShutdown = true 358 | for fd := range d.conns { 359 | syscall.Close(fd) 360 | delete(d.conns, fd) 361 | } 362 | return nil 363 | } 364 | -------------------------------------------------------------------------------- /event_dispatcher_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "fmt" 21 | "github.com/stretchr/testify/assert" 22 | "math/rand" 23 | "net" 24 | "runtime" 25 | "testing" 26 | "time" 27 | ) 28 | 29 | var ( 30 | expectData []byte 31 | writevData [][]byte 32 | done = make(chan struct{}) 33 | ) 34 | 35 | type serverConnCallback struct { 36 | t *testing.T 37 | readBuffer []byte 38 | } 39 | 40 | type clientConnCallback struct { 41 | t *testing.T 42 | } 43 | 44 | func (c *clientConnCallback) onEventData(buf []byte, conn eventConn) error { return nil } 45 | func (c *clientConnCallback) onRemoteClose() { fmt.Println("client onRemoteClose") } 46 | func (c *clientConnCallback) onLocalClose() { fmt.Println("client onLocalClose") } 47 | 48 | func (c *serverConnCallback) onEventData(buf []byte, conn eventConn) error { 49 | c.readBuffer = append(c.readBuffer, buf...) 
50 | conn.commitRead(len(buf)) 51 | time.Sleep(time.Millisecond) 52 | if len(c.readBuffer) == len(expectData) { 53 | //fmt.Println("c.readBufferLen", len(c.readBuffer), "expectData len", len(expectData)) 54 | assert.Equal(c.t, c.readBuffer, expectData) 55 | close(done) 56 | } 57 | //fmt.Println("c.readBufferLen", len(c.readBuffer), "expectData len", len(expectData)) 58 | return nil 59 | } 60 | 61 | func (c *serverConnCallback) onRemoteClose() { fmt.Println("server onRemoteClose") } 62 | func (c *serverConnCallback) onLocalClose() { fmt.Println("server onLocalClose") } 63 | 64 | var _ eventConnCallback = &serverConnCallback{} 65 | var _ eventConnCallback = &clientConnCallback{} 66 | 67 | func fillTestingData() { 68 | const msgN = 1020 69 | writevData = make([][]byte, msgN) 70 | expectData = make([]byte, 0, 1024*1024*msgN) 71 | for i := 0; i < msgN; i++ { 72 | writevData[i] = make([]byte, rand.Intn(1*1024*1024)) 73 | rand.Read(writevData[i]) 74 | expectData = append(expectData, writevData[i]...) 
75 | //fmt.Println("slice i ", i, len(writevData[i])) 76 | } 77 | } 78 | 79 | func Test_EventDispatcher(t *testing.T) { 80 | ensureDefaultDispatcherInit() 81 | fillTestingData() 82 | d := defaultDispatcher 83 | var clientConn, serverConn eventConn 84 | go func() { 85 | ln, err := net.Listen("tcp", ":7777") 86 | if err != nil { 87 | fmt.Println(err) 88 | return 89 | } 90 | for { 91 | conn, err := ln.Accept() 92 | if err != nil { 93 | fmt.Println(err) 94 | return 95 | } 96 | fd, err := getConnDupFd(conn) 97 | if err != nil { 98 | fmt.Println(err) 99 | return 100 | } 101 | conn.Close() 102 | serverConn = d.newConnection(fd) 103 | if err := serverConn.setCallback(&serverConnCallback{ 104 | t: t, 105 | readBuffer: make([]byte, 0, len(expectData)), 106 | }); err != nil { 107 | panic(err) 108 | } 109 | runtime.KeepAlive(fd) 110 | } 111 | }() 112 | 113 | time.Sleep(100 * time.Millisecond) 114 | 115 | go func() { 116 | conn, err := net.Dial("tcp", ":7777") 117 | if err != nil { 118 | fmt.Println(err) 119 | return 120 | } 121 | 122 | fd, err := getConnDupFd(conn) 123 | if err != nil { 124 | fmt.Println(err) 125 | return 126 | } 127 | conn.Close() 128 | clientConn = d.newConnection(fd) 129 | if err := clientConn.setCallback(&clientConnCallback{t}); err != nil { 130 | fmt.Println(err) 131 | return 132 | } 133 | 134 | err = clientConn.write(writevData[0]) 135 | if err != nil { 136 | fmt.Println("write error", err) 137 | return 138 | } 139 | 140 | err = clientConn.writev(writevData[1:]...) 
141 | if err != nil { 142 | fmt.Println("writev error", err) 143 | return 144 | } 145 | runtime.KeepAlive(fd) 146 | }() 147 | 148 | <-done 149 | clientConn.close() 150 | serverConn.close() 151 | } 152 | -------------------------------------------------------------------------------- /example/best_practice/idl/example.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package idl 18 | 19 | import ( 20 | "encoding/binary" 21 | "net" 22 | "sync" 23 | 24 | shmipc "github.com/cloudwego/shmipc-go" 25 | ) 26 | 27 | // the idl is the Request struct and Response struct. 28 | // the demo will show how to use Stream's BufferWriter() and BufferReader() to do serialization and deserialization. 
29 | 30 | type Request struct { 31 | ID uint64 32 | Name string 33 | Key []byte 34 | } 35 | 36 | type Response struct { 37 | ID uint64 38 | Name string 39 | Image []byte 40 | } 41 | 42 | func (r *Request) ReadFromShm(reader shmipc.BufferReader) error { 43 | //1.read ID 44 | data, err := reader.ReadBytes(8) 45 | if err != nil { 46 | return err 47 | } 48 | r.ID = binary.BigEndian.Uint64(data) 49 | 50 | //2.read Name 51 | data, err = reader.ReadBytes(4) 52 | if err != nil { 53 | return err 54 | } 55 | strLen := binary.BigEndian.Uint32(data) 56 | data, err = reader.ReadBytes(int(strLen)) 57 | if err != nil { 58 | return err 59 | } 60 | r.Name = string(data) 61 | 62 | //3. Read Key 63 | data, err = reader.ReadBytes(4) 64 | if err != nil { 65 | return err 66 | } 67 | keyLen := binary.BigEndian.Uint32(data) 68 | data, err = reader.ReadBytes(int(keyLen)) 69 | if err != nil { 70 | return err 71 | } 72 | r.Key = r.Key[:0] 73 | r.Key = append(r.Key, data...) 74 | 75 | //4. release share memory buffer 76 | reader.ReleasePreviousRead() 77 | return nil 78 | } 79 | 80 | func (r *Request) WriteToShm(writer shmipc.BufferWriter) error { 81 | //1.write ID 82 | data, err := writer.Reserve(8) 83 | if err != nil { 84 | return err 85 | } 86 | binary.BigEndian.PutUint64(data, r.ID) 87 | 88 | //2.write Name 89 | data, err = writer.Reserve(4) 90 | if err != nil { 91 | return err 92 | } 93 | binary.BigEndian.PutUint32(data, uint32(len(r.Name))) 94 | if err = writer.WriteString(r.Name); err != nil { 95 | return err // propagate the WriteString failure instead of reporting success 96 | } 97 | 98 | //3.write Key 99 | data, err = writer.Reserve(4) 100 | if err != nil { 101 | return err 102 | } 103 | binary.BigEndian.PutUint32(data, uint32(len(r.Key))) 104 | if _, err = writer.WriteBytes(r.Key); err != nil { 105 | return err 106 | } 107 | return nil 108 | } 109 | 110 | var BufferPool = sync.Pool{New: func() interface{} { 111 | return make([]byte, 4096) 112 | }} 113 | 114 | func (r *Request) Serialize() []byte { 115 | data := BufferPool.Get().([]byte) 116 | 
offset := 0 117 | binary.BigEndian.PutUint64(data, r.ID) 118 | offset += 8 119 | binary.BigEndian.PutUint32(data[offset:], uint32(len(r.Name))) 120 | offset += 4 121 | copy(data[offset:], r.Name) 122 | offset += len(r.Name) 123 | binary.BigEndian.PutUint32(data[offset:], uint32(len(r.Key))) 124 | offset += 4 125 | copy(data[offset:], r.Key) 126 | offset += len(r.Key) 127 | return data[:offset] 128 | } 129 | 130 | func (r *Request) Deserialize(data []byte) { 131 | //1.read ID 132 | offset := 0 133 | r.ID = binary.BigEndian.Uint64(data[offset:]) 134 | offset += 8 135 | 136 | strLen := binary.BigEndian.Uint32(data[offset:]) 137 | offset += 4 138 | r.Name = string(data[offset : offset+int(strLen)]) 139 | offset += int(strLen) 140 | 141 | //3. Read Key 142 | keyLen := binary.BigEndian.Uint32(data[offset:]) 143 | offset += 4 144 | r.Key = r.Key[:0] 145 | r.Key = append(r.Key, data[offset:offset+int(keyLen)]...) 146 | } 147 | 148 | func (r *Request) Reset() { 149 | r.ID = 0 150 | r.Name = "" 151 | r.Key = r.Key[:0] 152 | } 153 | 154 | func (r *Response) ReadFromShm(reader shmipc.BufferReader) error { 155 | //1.read ID 156 | data, err := reader.ReadBytes(8) 157 | if err != nil { 158 | return err 159 | } 160 | r.ID = binary.BigEndian.Uint64(data) 161 | 162 | //2.read Name 163 | data, err = reader.ReadBytes(4) 164 | if err != nil { 165 | return err 166 | } 167 | strLen := binary.BigEndian.Uint32(data) 168 | data, err = reader.ReadBytes(int(strLen)) 169 | if err != nil { 170 | return err 171 | } 172 | r.Name = string(data) 173 | 174 | //3. Read Image 175 | data, err = reader.ReadBytes(4) 176 | if err != nil { 177 | return err 178 | } 179 | imageLen := binary.BigEndian.Uint32(data) 180 | data, err = reader.ReadBytes(int(imageLen)) 181 | if err != nil { 182 | return err 183 | } 184 | r.Image = r.Image[:0] 185 | r.Image = append(r.Image, data...) 186 | 187 | //4. 
release share memory buffer 188 | reader.ReleasePreviousRead() 189 | return nil 190 | } 191 | 192 | func (r *Response) Deserialize(data []byte) { 193 | //1.read ID 194 | offset := 0 195 | r.ID = binary.BigEndian.Uint64(data) 196 | offset += 8 197 | 198 | //2.read Name 199 | strLen := binary.BigEndian.Uint32(data[offset:]) 200 | offset += 4 201 | r.Name = string(data[offset : offset+int(strLen)]) 202 | offset += int(strLen) 203 | 204 | //3. Read Image 205 | imageLen := binary.BigEndian.Uint32(data[offset:]) 206 | offset += 4 207 | r.Image = r.Image[:0] 208 | r.Image = append(r.Image, data[offset:offset+int(imageLen)]...) 209 | 210 | } 211 | 212 | func (r *Response) Serialize() []byte { 213 | //1.write ID 214 | data := BufferPool.Get().([]byte) 215 | offset := 0 216 | binary.BigEndian.PutUint64(data, r.ID) 217 | offset += 8 218 | 219 | //2.write Name 220 | binary.BigEndian.PutUint32(data[offset:], uint32(len(r.Name))) 221 | offset += 4 222 | offset += copy(data[offset:], r.Name) 223 | 224 | //3.write Image 225 | binary.BigEndian.PutUint32(data[offset:], uint32(len(r.Image))) 226 | offset += 4 227 | copy(data[offset:], r.Image) 228 | offset += len(r.Image) 229 | return data[:offset] 230 | } 231 | 232 | func (r *Response) WriteToShm(writer shmipc.BufferWriter) error { 233 | //1.write ID 234 | data, err := writer.Reserve(8) 235 | if err != nil { 236 | return err 237 | } 238 | binary.BigEndian.PutUint64(data, r.ID) 239 | 240 | //2.write Name 241 | data, err = writer.Reserve(4) 242 | if err != nil { 243 | return err 244 | } 245 | binary.BigEndian.PutUint32(data, uint32(len(r.Name))) 246 | if err = writer.WriteString(r.Name); err != nil { 247 | return err // propagate the WriteString failure instead of reporting success 248 | } 249 | 250 | //3.write Image 251 | data, err = writer.Reserve(4) 252 | if err != nil { 253 | return err 254 | } 255 | binary.BigEndian.PutUint32(data, uint32(len(r.Image))) 256 | if _, err = writer.WriteBytes(r.Image); err != nil { 257 | return err 258 | } 259 | return nil 260 | } 261 | 262 | func (r *Response) 
Reset() { 263 | r.ID = 0 264 | r.Name = "" 265 | r.Image = r.Image[:0] 266 | } 267 | 268 | func MustWrite(conn net.Conn, data []byte) { 269 | written := 0 270 | for written < len(data) { 271 | n, err := conn.Write(data[written:]) 272 | if err != nil { 273 | panic(err) 274 | } 275 | written += n 276 | } 277 | } 278 | -------------------------------------------------------------------------------- /example/best_practice/net_client/main.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package main 18 | 19 | import ( 20 | "flag" 21 | "fmt" 22 | "math/rand" 23 | "net" 24 | "net/http" 25 | _ "net/http/pprof" 26 | "os" 27 | "path/filepath" 28 | "runtime" 29 | "sync/atomic" 30 | "time" 31 | 32 | "github.com/cloudwego/shmipc-go/example/best_practice/idl" 33 | ) 34 | 35 | var count uint64 36 | 37 | func init() { 38 | go func() { 39 | lastCount := count 40 | for range time.Tick(time.Second) { 41 | curCount := atomic.LoadUint64(&count) 42 | fmt.Println("net_client qps:", curCount-lastCount) 43 | lastCount = curCount 44 | } 45 | }() 46 | go func() { 47 | http.ListenAndServe(":20001", nil) //nolint:errcheck 48 | }() 49 | runtime.GOMAXPROCS(1) 50 | } 51 | 52 | func main() { 53 | packageSize := flag.Int("p", 1024, "-p 1024 mean that request and response's size are both near 1KB") 54 | flag.Parse() 55 | 56 | randContent := make([]byte, *packageSize) 57 | rand.Read(randContent) 58 | 59 | // 1. dial unix domain socket 60 | dir, err := os.Getwd() 61 | if err != nil { 62 | panic(err) 63 | } 64 | 65 | address := filepath.Join(dir, "../ipc_test.sock") 66 | network := "unix" 67 | 68 | concurrency := 500 69 | qps := 50000000 70 | 71 | for i := 0; i < concurrency; i++ { 72 | go func() { 73 | req := &idl.Request{} 74 | resp := &idl.Response{} 75 | n := qps / concurrency 76 | conn, err := net.Dial(network, address) 77 | if err != nil { 78 | fmt.Println("dial error", err) 79 | return 80 | } 81 | for range time.Tick(time.Second) { 82 | for k := 0; k < n; k++ { 83 | now := time.Now() 84 | //serialize request 85 | req.Reset() 86 | req.ID = uint64(now.UnixNano()) 87 | req.Name = "xxx" 88 | req.Key = randContent 89 | writeBuffer := req.Serialize() 90 | idl.MustWrite(conn, writeBuffer) 91 | idl.BufferPool.Put(writeBuffer) 92 | 93 | //wait and read response 94 | buf := idl.BufferPool.Get().([]byte) 95 | n, err := conn.Read(buf) 96 | if err != nil { 97 | fmt.Println("conn.Read error ", err) 98 | return 99 | } 100 | resp.Reset() 101 | resp.Deserialize(buf[:n]) 
102 | idl.BufferPool.Put(buf) 103 | 104 | { 105 | //handle response... 106 | atomic.AddUint64(&count, 1) 107 | } 108 | } 109 | } 110 | }() 111 | } 112 | 113 | time.Sleep(1200 * time.Second) 114 | } 115 | -------------------------------------------------------------------------------- /example/best_practice/net_server/main.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package main 18 | 19 | import ( 20 | "fmt" 21 | "net" 22 | "net/http" 23 | _ "net/http/pprof" 24 | "os" 25 | "path/filepath" 26 | "runtime" 27 | "sync/atomic" 28 | "syscall" 29 | "time" 30 | 31 | "github.com/cloudwego/shmipc-go/example/best_practice/idl" 32 | ) 33 | 34 | var count uint64 35 | 36 | func handleConn(conn net.Conn) { 37 | defer conn.Close() 38 | 39 | req := &idl.Request{} 40 | resp := &idl.Response{} 41 | for { 42 | //1.deserialize Request 43 | readBuffer := idl.BufferPool.Get().([]byte) 44 | n, err := conn.Read(readBuffer) 45 | if err != nil { 46 | fmt.Println("conn.Read", err) 47 | return 48 | } 49 | req.Deserialize(readBuffer[:n]) 50 | idl.BufferPool.Put(readBuffer) 51 | 52 | { 53 | //2.handle request 54 | atomic.AddUint64(&count, 1) 55 | } 56 | 57 | //3.serialize Response 58 | resp.ID = req.ID 59 | resp.Name = req.Name 60 | resp.Image = req.Key 61 | writeBuffer := resp.Serialize() // serialize the Response populated above, not the Request 62 | idl.MustWrite(conn, writeBuffer) 63 | idl.BufferPool.Put(writeBuffer) 64 | 65 | req.Reset() 66 | resp.Reset() 67 | } 68 | } 69 | 70 | func init() { 71 | go func() { 72 | lastCount := count 73 | for range time.Tick(time.Second) { 74 | curCount := atomic.LoadUint64(&count) 75 | fmt.Println("net_server qps:", curCount-lastCount) 76 | lastCount = curCount 77 | } 78 | }() 79 | go func() { 80 | http.ListenAndServe(":20000", nil)//nolint:errcheck 81 | }() 82 | runtime.GOMAXPROCS(1) 83 | } 84 | 85 | func main() { 86 | // 1. listen unix domain socket 87 | dir, err := os.Getwd() 88 | if err != nil { 89 | panic(err) 90 | } 91 | udsPath := filepath.Join(dir, "../ipc_test.sock") 92 | 93 | _ = syscall.Unlink(udsPath) 94 | ln, err := net.ListenUnix("unix", &net.UnixAddr{Name: udsPath, Net: "unix"}) 95 | if err != nil { 96 | panic(err) 97 | } 98 | defer ln.Close() 99 | 100 | // 2. 
accept a unix domain socket 101 | for { 102 | conn, err := ln.Accept() 103 | if err != nil { 104 | fmt.Printf("accept error:%s now exit \n", err.Error()) 105 | return 106 | } 107 | go handleConn(conn) 108 | } 109 | 110 | fmt.Println("shmipc server exited") 111 | } 112 | -------------------------------------------------------------------------------- /example/best_practice/run_net_client_server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd net_server 4 | go build 5 | cd ../net_client 6 | go build 7 | 8 | cd ../net_server 9 | ./net_server & 10 | SERVER_PID=$! 11 | echo "server pid is $SERVER_PID" 12 | sleep 1s 13 | 14 | cd ../net_client 15 | ./net_client & 16 | CLIENT_PID=$! 17 | echo "client pid is $CLIENT_PID" 18 | 19 | trap 'echo "exiting, now kill client and server";kill $CLIENT_PID;kill $SERVER_PID' SIGHUP SIGINT SIGQUIT SIGALRM SIGTERM 20 | cd ../ 21 | 22 | sleep 1000s 23 | -------------------------------------------------------------------------------- /example/best_practice/run_shmipc_async_client_server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd shmipc_async_server 4 | go build 5 | cd ../shmipc_async_client 6 | go build 7 | 8 | cd ../shmipc_async_server 9 | ./shmipc_async_server & 10 | SERVER_PID=$! 11 | echo "server pid is $SERVER_PID" 12 | sleep 1s 13 | 14 | cd ../shmipc_async_client 15 | ./shmipc_async_client & 16 | CLIENT_PID=$! 
17 | echo "client pid is $CLIENT_PID" 18 | 19 | trap 'echo "exiting, now kill client and server";kill $CLIENT_PID;kill $SERVER_PID' SIGHUP SIGINT SIGQUIT SIGALRM SIGTERM 20 | cd ../ 21 | 22 | sleep 1000s 23 | -------------------------------------------------------------------------------- /example/best_practice/run_shmipc_client_server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd shmipc_server 4 | go build 5 | cd ../shmipc_client 6 | go build 7 | 8 | cd ../shmipc_server 9 | ./shmipc_server & 10 | SERVER_PID=$! 11 | echo "server pid is $SERVER_PID" 12 | sleep 1s 13 | 14 | cd ../shmipc_client 15 | ./shmipc_client & 16 | CLIENT_PID=$! 17 | echo "client pid is $CLIENT_PID" 18 | 19 | trap 'echo "exiting, now kill client and server";kill $CLIENT_PID;kill $SERVER_PID' SIGHUP SIGINT SIGQUIT SIGALRM SIGTERM 20 | cd ../ 21 | 22 | sleep 1000s 23 | -------------------------------------------------------------------------------- /example/best_practice/shmipc_async_client/client.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package main 18 | 19 | import ( 20 | "fmt" 21 | "math" 22 | "math/rand" 23 | "net/http" 24 | _ "net/http/pprof" 25 | "os" 26 | "path/filepath" 27 | "runtime" 28 | "sync/atomic" 29 | "time" 30 | 31 | "github.com/cloudwego/shmipc-go" 32 | "github.com/cloudwego/shmipc-go/example/best_practice/idl" 33 | ) 34 | 35 | var ( 36 | count uint64 37 | _ shmipc.StreamCallbacks = &streamCbImpl{} 38 | ) 39 | 40 | func init() { 41 | go func() { 42 | lastCount := count 43 | for range time.Tick(time.Second) { 44 | curCount := atomic.LoadUint64(&count) 45 | fmt.Println("shmipc_async_client qps:", curCount-lastCount) 46 | lastCount = curCount 47 | } 48 | }() 49 | runtime.GOMAXPROCS(1) 50 | 51 | go func() { 52 | http.ListenAndServe(":20001", nil) //nolint:errcheck 53 | }() 54 | } 55 | 56 | type streamCbImpl struct { 57 | req idl.Request 58 | resp idl.Response 59 | stream *shmipc.Stream 60 | smgr *shmipc.SessionManager 61 | key []byte 62 | loop uint64 63 | n uint64 64 | } 65 | 66 | func (s *streamCbImpl) OnData(reader shmipc.BufferReader) { 67 | //wait and read response 68 | s.resp.Reset() 69 | if err := s.resp.ReadFromShm(reader); err != nil { 70 | fmt.Println("read response from share memory failed, err=" + err.Error()) 71 | return 72 | } 73 | s.stream.ReleaseReadAndReuse() 74 | 75 | { 76 | //handle response... 
77 | atomic.AddUint64(&count, 1) 78 | } 79 | s.send() 80 | } 81 | 82 | func (s *streamCbImpl) send() { 83 | s.n++ 84 | if s.n >= s.loop { 85 | return 86 | } 87 | now := time.Now() 88 | //serialize request 89 | s.req.Reset() 90 | s.req.ID = uint64(now.UnixNano()) 91 | s.req.Name = "xxx" 92 | s.req.Key = s.key 93 | if err := s.req.WriteToShm(s.stream.BufferWriter()); err != nil { 94 | fmt.Println("write request to share memory failed, err=" + err.Error()) 95 | return 96 | } 97 | if err := s.stream.Flush(false); err != nil { 98 | fmt.Println(" flush request to peer failed, err=" + err.Error()) 99 | return 100 | } 101 | } 102 | 103 | func (s *streamCbImpl) OnLocalClose() { 104 | //fmt.Println("stream OnLocalClose") 105 | } 106 | 107 | func (s *streamCbImpl) OnRemoteClose() { 108 | //fmt.Println("stream OnRemoteClose") 109 | 110 | } 111 | 112 | func main() { 113 | // 1. dial unix domain socket 114 | dir, err := os.Getwd() 115 | if err != nil { 116 | panic(err) 117 | } 118 | 119 | // 2. init session manager 120 | conf := shmipc.DefaultSessionManagerConfig() 121 | conf.Address = filepath.Join(dir, "../ipc_test.sock") 122 | conf.Network = "unix" 123 | conf.SessionNum = 1 124 | conf.ShareMemoryBufferCap = 32 << 20 125 | conf.MemMapType = shmipc.MemMapTypeMemFd 126 | smgr, err := shmipc.InitGlobalSessionManager(conf) 127 | if err != nil { 128 | panic(err) 129 | } 130 | 131 | concurrency := 500 132 | 133 | for i := 0; i < concurrency; i++ { 134 | go func() { 135 | key := make([]byte, 1024) 136 | rand.Read(key) 137 | s := &streamCbImpl{key: key, smgr: smgr, loop: math.MaxUint64} 138 | stream, err := smgr.GetStream() 139 | if err != nil { 140 | fmt.Println("get stream error:" + err.Error()) 141 | return 142 | } 143 | s.stream = stream 144 | s.n = 0 145 | stream.SetCallbacks(s) 146 | s.send() 147 | // and maybe call `smgr.PutBack()` or `stream.Close()` when you are no longer using the stream. 
148 | }() 149 | } 150 | 151 | time.Sleep(10030 * time.Second) 152 | fmt.Println("smgr.Close():", smgr.Close()) 153 | } 154 | -------------------------------------------------------------------------------- /example/best_practice/shmipc_async_server/server.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package main 18 | 19 | import ( 20 | "fmt" 21 | "net/http" 22 | _ "net/http/pprof" 23 | "os" 24 | "path/filepath" 25 | "runtime" 26 | "sync/atomic" 27 | "syscall" 28 | "time" 29 | 30 | "github.com/cloudwego/shmipc-go" 31 | "github.com/cloudwego/shmipc-go/example/best_practice/idl" 32 | ) 33 | 34 | var ( 35 | count uint64 36 | _ shmipc.StreamCallbacks = &streamCbImpl{} 37 | _ shmipc.ListenCallback = &listenCbImpl{} 38 | ) 39 | 40 | type listenCbImpl struct{} 41 | 42 | func (l *listenCbImpl) OnNewStream(s *shmipc.Stream) { 43 | s.SetCallbacks(&streamCbImpl{stream: s}) 44 | } 45 | 46 | func (l *listenCbImpl) OnShutdown(reason string) { 47 | fmt.Println("OnShutdown reason:" + reason) 48 | } 49 | 50 | type streamCbImpl struct { 51 | req idl.Request 52 | resp idl.Response 53 | stream *shmipc.Stream 54 | } 55 | 56 | func (s *streamCbImpl) OnData(reader shmipc.BufferReader) { 57 | //1.deserialize Request 58 | if err := s.req.ReadFromShm(reader); err != nil { 59 | fmt.Println("stream read 
request, err=" + err.Error()) 60 | return 61 | } 62 | 63 | { 64 | //2.handle request 65 | atomic.AddUint64(&count, 1) 66 | } 67 | 68 | //3.serialize Response 69 | s.resp.ID = s.req.ID 70 | s.resp.Name = s.req.Name 71 | s.resp.Image = s.req.Key 72 | if err := s.resp.WriteToShm(s.stream.BufferWriter()); err != nil { 73 | fmt.Println("stream write response failed, err=" + err.Error()) 74 | return 75 | } 76 | if err := s.stream.Flush(false); err != nil { 77 | fmt.Println("stream write response failed, err=" + err.Error()) 78 | return 79 | } 80 | s.stream.ReleaseReadAndReuse() 81 | s.req.Reset() 82 | s.resp.Reset() 83 | } 84 | 85 | func (s *streamCbImpl) OnLocalClose() { 86 | //fmt.Println("stream OnLocalClose") 87 | } 88 | 89 | func (s *streamCbImpl) OnRemoteClose() { 90 | //fmt.Println("stream OnRemoteClose") 91 | } 92 | 93 | func init() { 94 | go func() { 95 | lastCount := count 96 | for range time.Tick(time.Second) { 97 | curCount := atomic.LoadUint64(&count) 98 | fmt.Println("shmipc_async_server qps:", curCount-lastCount) 99 | lastCount = curCount 100 | } 101 | }() 102 | runtime.GOMAXPROCS(1) 103 | 104 | go func() { 105 | http.ListenAndServe(":20000", nil)//nolint:errcheck 106 | }() 107 | } 108 | 109 | func main() { 110 | // 1. 
listen unix domain socket 111 | dir, err := os.Getwd() 112 | if err != nil { 113 | panic(err) 114 | } 115 | udsPath := filepath.Join(dir, "../ipc_test.sock") 116 | 117 | _ = syscall.Unlink(udsPath) 118 | config := shmipc.NewDefaultListenerConfig(udsPath, "unix") 119 | ln, err := shmipc.NewListener(&listenCbImpl{}, config) 120 | if err != nil { 121 | fmt.Println(err) 122 | return 123 | } 124 | ln.Run() 125 | } 126 | -------------------------------------------------------------------------------- /example/best_practice/shmipc_client/main.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package main 18 | 19 | import ( 20 | "flag" 21 | "fmt" 22 | "math/rand" 23 | "net/http" 24 | _ "net/http/pprof" 25 | "os" 26 | "path/filepath" 27 | "runtime" 28 | "sync/atomic" 29 | "time" 30 | 31 | "github.com/cloudwego/shmipc-go" 32 | "github.com/cloudwego/shmipc-go/example/best_practice/idl" 33 | ) 34 | 35 | var count uint64 36 | 37 | func init() { 38 | go func() { 39 | lastCount := count 40 | for range time.Tick(time.Second) { 41 | curCount := atomic.LoadUint64(&count) 42 | fmt.Println("shmipc_client qps:", curCount-lastCount) 43 | lastCount = curCount 44 | } 45 | }() 46 | 47 | go func() { 48 | http.ListenAndServe(":20001", nil) //nolint:errcheck 49 | }() 50 | 51 | runtime.GOMAXPROCS(1) 52 | } 53 | 54 | func main() { 55 | packageSize := flag.Int("p", 1024, "-p 1024 mean that request and response's size are both near 1KB") 56 | flag.Parse() 57 | 58 | randContent := make([]byte, *packageSize) 59 | rand.Read(randContent) 60 | 61 | // 1. get current directory 62 | dir, err := os.Getwd() 63 | if err != nil { 64 | panic(err) 65 | } 66 | 67 | // 2. init session manager 68 | conf := shmipc.DefaultSessionManagerConfig() 69 | conf.Address = filepath.Join(dir, "../ipc_test.sock") 70 | conf.Network = "unix" 71 | conf.MemMapType = shmipc.MemMapTypeMemFd 72 | conf.SessionNum = 1 73 | conf.InitializeTimeout = 100 * time.Second 74 | smgr, err := shmipc.NewSessionManager(conf) 75 | if err != nil { 76 | panic(err) 77 | } 78 | 79 | concurrency := 500 80 | qps := 50000000 81 | 82 | for i := 0; i < concurrency; i++ { 83 | go func() { 84 | req := &idl.Request{} 85 | resp := &idl.Response{} 86 | n := qps / concurrency 87 | 88 | for range time.Tick(time.Second) { 89 | // 3. get stream from SessionManager 90 | stream, err := smgr.GetStream() 91 | if err != nil { 92 | fmt.Println("get stream error:" + err.Error()) 93 | continue 94 | } 95 | 96 | for k := 0; k < n; k++ { 97 | // 4. 
set request object 98 | req.Reset() 99 | req.ID = uint64(time.Now().UnixNano()) 100 | req.Name = "xxx" 101 | req.Key = randContent 102 | 103 | // 5. write req to buffer of stream 104 | if err := req.WriteToShm(stream.BufferWriter()); err != nil { 105 | fmt.Println("write request to share memory failed, err=" + err.Error()) 106 | return 107 | } 108 | 109 | // 6. flush the buffered data of stream to peer 110 | if err := stream.Flush(false); err != nil { 111 | fmt.Println(" flush request to peer failed, err=" + err.Error()) 112 | return 113 | } 114 | 115 | // 7. wait and read response 116 | resp.Reset() 117 | if err := resp.ReadFromShm(stream.BufferReader()); err != nil { 118 | fmt.Println("read response from share memory failed, err=" + err.Error()) 119 | continue 120 | } 121 | 122 | atomic.AddUint64(&count, 1) 123 | } 124 | 125 | //call `PutBack` to return the stream to the stream pool for next time use, which will improve performance. 126 | smgr.PutBack(stream) 127 | //or call `stream.Close` to close the Stream, otherwise it will leak. 128 | } 129 | }() 130 | } 131 | 132 | time.Sleep(1200 * time.Second) 133 | } 134 | -------------------------------------------------------------------------------- /example/best_practice/shmipc_server/main.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package main 18 | 19 | import ( 20 | "fmt" 21 | "net" 22 | "net/http" 23 | _ "net/http/pprof" 24 | "os" 25 | "path/filepath" 26 | "runtime" 27 | "sync/atomic" 28 | "syscall" 29 | "time" 30 | 31 | "github.com/cloudwego/shmipc-go" 32 | "github.com/cloudwego/shmipc-go/example/best_practice/idl" 33 | ) 34 | 35 | var count uint64 36 | 37 | func handleStream(s *shmipc.Stream) { 38 | req := &idl.Request{} 39 | resp := &idl.Response{} 40 | for { 41 | // 1. deserialize Request 42 | if err := req.ReadFromShm(s.BufferReader()); err != nil { 43 | fmt.Println("stream read request, err=" + err.Error()) 44 | return 45 | } 46 | 47 | { 48 | // 2. handle request 49 | atomic.AddUint64(&count, 1) 50 | } 51 | 52 | // 3.serialize Response 53 | resp.ID = req.ID 54 | resp.Name = req.Name 55 | resp.Image = req.Key 56 | if err := resp.WriteToShm(s.BufferWriter()); err != nil { 57 | fmt.Println("stream write response failed, err=" + err.Error()) 58 | return 59 | } 60 | if err := s.Flush(false); err != nil { 61 | fmt.Println("stream write response failed, err=" + err.Error()) 62 | return 63 | } 64 | req.Reset() 65 | resp.Reset() 66 | } 67 | } 68 | 69 | func init() { 70 | go func() { 71 | lastCount := count 72 | for range time.Tick(time.Second) { 73 | curCount := atomic.LoadUint64(&count) 74 | fmt.Println("shmipc_server qps:", curCount-lastCount) 75 | lastCount = curCount 76 | } 77 | }() 78 | go func() { 79 | http.ListenAndServe(":20000", nil) //nolint:errcheck 80 | }() 81 | runtime.GOMAXPROCS(1) 82 | } 83 | 84 | func main() { 85 | dir, err := os.Getwd() 86 | if err != nil { 87 | panic(err) 88 | } 89 | udsPath := filepath.Join(dir, "../ipc_test.sock") 90 | 91 | // 1. listen unix domain socket 92 | _ = syscall.Unlink(udsPath) 93 | ln, err := net.ListenUnix("unix", &net.UnixAddr{Name: udsPath, Net: "unix"}) 94 | if err != nil { 95 | panic(err) 96 | } 97 | defer ln.Close() 98 | 99 | // 2. 
accept a unix domain socket 100 | for { 101 | conn, err := ln.Accept() 102 | if err != nil { 103 | fmt.Printf("accept error:%s now exit", err.Error()) 104 | return 105 | } 106 | go func() { 107 | defer conn.Close() 108 | 109 | // 3. create server session 110 | conf := shmipc.DefaultConfig() 111 | server, err := shmipc.Server(conn, conf) 112 | if err != nil { 113 | panic("new ipc server failed " + err.Error()) 114 | } 115 | defer server.Close() 116 | 117 | // 4. accept stream and handle 118 | for { 119 | stream, err := server.AcceptStream() 120 | if err != nil { 121 | fmt.Println("shmipc server accept stream failed, err=" + err.Error()) 122 | break 123 | } 124 | go handleStream(stream) 125 | } 126 | }() 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /example/helloworld/greeter_client/main.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package main 18 | 19 | import ( 20 | "fmt" 21 | "os" 22 | "path/filepath" 23 | "runtime" 24 | 25 | "github.com/cloudwego/shmipc-go" 26 | ) 27 | 28 | func main() { 29 | dir, err := os.Getwd() 30 | if err != nil { 31 | panic(err) 32 | } 33 | udsPath := filepath.Join(dir, "../ipc_test.sock") 34 | 35 | // 1.create client session manager 36 | conf := shmipc.DefaultSessionManagerConfig() 37 | conf.ShareMemoryPathPrefix = "/dev/shm/client.ipc.shm" 38 | conf.Network = "unix" 39 | conf.Address = udsPath 40 | if runtime.GOOS == "darwin" { 41 | conf.ShareMemoryPathPrefix = "/tmp/client.ipc.shm" 42 | conf.QueuePath = "/tmp/client.ipc.shm_queue" 43 | } 44 | 45 | s, err := shmipc.NewSessionManager(conf) 46 | if err != nil { 47 | panic("create client session failed, " + err.Error()) 48 | } 49 | defer s.Close() 50 | 51 | // 2.create stream 52 | stream, err := s.GetStream() 53 | if err != nil { 54 | panic("client open stream failed, " + err.Error()) 55 | } 56 | defer s.PutBack(stream) 57 | 58 | // 3. write message 59 | requestMsg := "client say hello world!!!" 60 | writer := stream.BufferWriter() 61 | err = writer.WriteString(requestMsg) 62 | if err != nil { 63 | panic("buffer writeString failed " + err.Error()) 64 | } 65 | 66 | // 4. 
flush the stream buffer data to peer 67 | fmt.Println("client stream send request:" + requestMsg) 68 | err = stream.Flush(true) 69 | if err != nil { 70 | panic("stream Flush failed," + err.Error()) 71 | } 72 | 73 | reader := stream.BufferReader() 74 | // 5.read response 75 | respData, err := reader.ReadBytes(len("server hello world!!!")) 76 | if err != nil { 77 | panic("respBuf ReadBytes failed," + err.Error()) 78 | } 79 | 80 | fmt.Println("client stream receive response " + string(respData)) 81 | } 82 | -------------------------------------------------------------------------------- /example/helloworld/greeter_server/main.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package main 18 | 19 | import ( 20 | "fmt" 21 | "net" 22 | "os" 23 | "path/filepath" 24 | "syscall" 25 | "time" 26 | 27 | "github.com/cloudwego/shmipc-go" 28 | ) 29 | 30 | func main() { 31 | // 1. listen unix domain socket 32 | dir, err := os.Getwd() 33 | if err != nil { 34 | panic(err) 35 | } 36 | udsPath := filepath.Join(dir, "../ipc_test.sock") 37 | 38 | _ = syscall.Unlink(udsPath) 39 | ln, err := net.ListenUnix("unix", &net.UnixAddr{Name: udsPath, Net: "unix"}) 40 | if err != nil { 41 | panic(err) 42 | } 43 | defer ln.Close() 44 | 45 | // 2. 
accept a unix domain socket 46 | conn, err := ln.Accept() 47 | if err != nil { 48 | panic(err) 49 | } 50 | defer conn.Close() 51 | 52 | // 3. create server session 53 | conf := shmipc.DefaultConfig() 54 | s, err := shmipc.Server(conn, conf) 55 | if err != nil { 56 | panic("new ipc server failed " + err.Error()) 57 | } 58 | defer s.Close() 59 | 60 | // 4. accept a stream 61 | stream, err := s.AcceptStream() 62 | if err != nil { 63 | panic("accept stream failed " + err.Error()) 64 | } 65 | defer stream.Close() 66 | 67 | // 5.read request data 68 | reader := stream.BufferReader() 69 | reqData, err := reader.ReadBytes(len("client say hello world!!!")) 70 | if err != nil { 71 | panic("reqBuf readData failed " + err.Error()) 72 | } 73 | fmt.Println("server receive request message:" + string(reqData)) 74 | 75 | // 6.write data to response buffer 76 | respMsg := "server hello world!!!" 77 | writer := stream.BufferWriter() 78 | err = writer.WriteString(respMsg) 79 | if err != nil { 80 | panic("respBuf WriteString failed " + err.Error()) 81 | } 82 | 83 | // 7.flush response buffer to peer. 
84 | err = stream.Flush(true) 85 | if err != nil { 86 | panic("stream write response failed," + err.Error()) 87 | } 88 | fmt.Println("server reply response: " + respMsg) 89 | 90 | time.Sleep(time.Second) 91 | } 92 | -------------------------------------------------------------------------------- /example/hot_restart_test/README.md: -------------------------------------------------------------------------------- 1 | ### client path 2 | - build.sh: build script 3 | - bootstrap.sh: start script 4 | 5 | The PATH_PREFIX environment variable sets the shared-memory buffer path; clients started with different paths behave as separate clients. 6 | 7 | ### server path 8 | - build.sh: build script 9 | - bootstrap.sh: script for the first server start 10 | - bootstrap_hot_restart.sh: script that hot restarts the server 11 | 12 | 13 | IS_HOT_RESTART: set to 1 to start the server in hot-restart mode. 14 | 15 | HOT_RESTART_EPOCH: epoch id of the new server; it must be different for every hot restart. 16 | 17 | `The hot restart test requires starting the server with bootstrap.sh first; bootstrap_hot_restart.sh then hot restarts a new server, and HOT_RESTART_EPOCH must be changed for each restart.` 18 | -------------------------------------------------------------------------------- /example/hot_restart_test/client/bootstrap.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | # adjust the log level 4 | export SHMIPC_LOG_LEVEL=0 5 | export SHMIPC_DEBUG_MODE=1 6 | 7 | # PATH_PREFIX 8 | #export PATH_PREFIX="/dev/shm/test_shm_buffer" 9 | ./client 10 | 11 | -------------------------------------------------------------------------------- /example/hot_restart_test/client/build.sh: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env bash 2 | 3 | go build -------------------------------------------------------------------------------- /example/hot_restart_test/client/client.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package main 18 | 19 | import ( 20 | "fmt" 21 | "net/http" 22 | _ "net/http/pprof" 23 | "os" 24 | "path/filepath" 25 | "runtime" 26 | "sync/atomic" 27 | "time" 28 | 29 | "github.com/cloudwego/shmipc-go" 30 | ) 31 | 32 | var ( 33 | count uint64 34 | errCount uint64 35 | 36 | sendStr = "hello world" 37 | ) 38 | 39 | func init() { 40 | go func() { 41 | lastCount := atomic.LoadUint64(&count) 42 | for range time.Tick(time.Second) { 43 | curCount := atomic.LoadUint64(&count) 44 | err := atomic.LoadUint64(&errCount) 45 | fmt.Println("qps:", curCount-lastCount, " errCount = ", err, " count ", atomic.LoadUint64(&count)) 46 | lastCount = curCount 47 | } 48 | }() 49 | runtime.GOMAXPROCS(4) 50 | 51 | go func() { 52 | http.ListenAndServe(":20001", nil)//nolint:errcheck 53 | }() 54 | } 55 | 56 | func addCount(isErr bool) { 57 | if isErr { 58 | atomic.AddUint64(&errCount, 1) 59 | } 60 | atomic.AddUint64(&count, 1) 61 | } 62 | 63 | func main() { 64 | // 1. dial unix domain socket 65 | dir, err := os.Getwd() 66 | if err != nil { 67 | panic(err) 68 | } 69 | 70 | // 2. 
init session manager 71 | conf := shmipc.DefaultSessionManagerConfig() 72 | conf.Address = filepath.Join(dir, "../ipc_test.sock") 73 | conf.Network = "unix" 74 | conf.SessionNum = 4 75 | conf.ShareMemoryBufferCap = 32 << 20 76 | //conf.MemMapType = shmipc.MemMapTypeMemFd 77 | conf.MemMapType = shmipc.MemMapTypeDevShmFile 78 | if os.Getenv("PATH_PREFIX") != "" { 79 | conf.ShareMemoryPathPrefix = os.Getenv("PATH_PREFIX") 80 | } 81 | smgr, err := shmipc.InitGlobalSessionManager(conf) 82 | if err != nil { 83 | panic(err) 84 | } 85 | 86 | for i := 0; i < 50; i++ { 87 | go func() { 88 | for { 89 | stream, err := smgr.GetStream() 90 | if err != nil { 91 | fmt.Printf("stream GetStream error %+v\n", err) 92 | time.Sleep(time.Second * 5) 93 | continue 94 | } 95 | 96 | err = stream.BufferWriter().WriteString(sendStr) 97 | if err != nil { 98 | fmt.Printf("stream WriteString error %+v\n", err) 99 | addCount(true) 100 | continue 101 | } 102 | 103 | err = stream.Flush(false) 104 | if err != nil { 105 | fmt.Printf("stream Flush error %+v\n", err) 106 | addCount(true) 107 | continue 108 | } 109 | 110 | ret, err := stream.BufferReader().ReadString(len("hello world")) 111 | if err != nil { 112 | fmt.Printf("stream ReadString error %+v\n", err) 113 | addCount(true) 114 | } else { 115 | addCount(false) 116 | } 117 | 118 | if ret != sendStr { 119 | fmt.Printf("stream ret %+v err\n", ret) 120 | } 121 | 122 | smgr.PutBack(stream) 123 | 124 | time.Sleep(time.Millisecond * 10) 125 | } 126 | }() 127 | } 128 | 129 | time.Sleep(100000 * time.Second) 130 | fmt.Println("smgr.Close():", smgr.Close()) 131 | } 132 | -------------------------------------------------------------------------------- /example/hot_restart_test/server/bootstrap.sh: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env bash 2 | 3 | export SHMIPC_LOG_LEVEL=0 4 | export SHMIPC_DEBUG_MODE=1 5 | export DEBUG_PORT=20000 6 | ./server 7 | 8 | -------------------------------------------------------------------------------- /example/hot_restart_test/server/bootstrap_hot_restart.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | export SHMIPC_LOG_LEVEL=0 4 | export SHMIPC_DEBUG_MODE=1 5 | 6 | ## hot restart 7 | export IS_HOT_RESTART=1 8 | export HOT_RESTART_EPOCH=1122 9 | export DEBUG_PORT=20002 10 | ./server 11 | 12 | -------------------------------------------------------------------------------- /example/hot_restart_test/server/build.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | go build -------------------------------------------------------------------------------- /example/hot_restart_test/server/callback_impl.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package main 18 | 19 | import ( 20 | "fmt" 21 | "sync/atomic" 22 | 23 | "github.com/cloudwego/shmipc-go" 24 | ) 25 | 26 | var ( 27 | _ shmipc.ListenCallback = &listenCbImpl{} 28 | _ shmipc.StreamCallbacks = &streamCbImpl{} 29 | ) 30 | 31 | type listenCbImpl struct{} 32 | 33 | func (l listenCbImpl) OnNewStream(s *shmipc.Stream) { 34 | //fmt.Printf("OnNewStream StreamID %d\n", s.StreamID()) 35 | if err := s.SetCallbacks(streamCbImpl{stream: s}); err != nil { 36 | fmt.Printf("OnNewStream SetCallbacks error %+v\n", err) 37 | } 38 | } 39 | 40 | func (l listenCbImpl) OnShutdown(reason string) { 41 | //fmt.Printf("OnShutdown reason %s\n", reason) 42 | } 43 | 44 | type streamCbImpl struct { 45 | stream *shmipc.Stream 46 | } 47 | 48 | func (s streamCbImpl) OnData(reader shmipc.BufferReader) { 49 | ret, err := reader.ReadString(len(sendStr)) 50 | if err != nil { 51 | fmt.Printf("OnData stream ReadString request error %+v\n", err) 52 | atomic.AddUint64(&errCount, 1) 53 | return 54 | } 55 | 56 | if ret != sendStr { 57 | fmt.Println("OnData ret ", ret) 58 | } 59 | 60 | atomic.AddUint64(&count, 1) 61 | 62 | err = s.stream.BufferWriter().WriteString(sendStr) 63 | if err != nil { 64 | fmt.Printf("OnData stream WriteString response failed error %+v\n", err) 65 | return 66 | } 67 | err = s.stream.Flush(false) 68 | if err != nil { 69 | fmt.Printf("OnData stream Flush response failed error %+v\n", err) 70 | return 71 | } 72 | 73 | s.stream.ReleaseReadAndReuse() 74 | } 75 | 76 | func (s streamCbImpl) OnLocalClose() { 77 | fmt.Printf("StreamCallbacks OnLocalClose StreamID %+v\n", s.stream.StreamID()) 78 | } 79 | 80 | func (s streamCbImpl) OnRemoteClose() { 81 | fmt.Printf("StreamCallbacks OnRemoteClose StreamID %+v\n", s.stream.StreamID()) 82 | } 83 | -------------------------------------------------------------------------------- /example/hot_restart_test/server/server.go: -------------------------------------------------------------------------------- 1 | /* 2 | * 
Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package main 18 | 19 | import ( 20 | "encoding/binary" 21 | "fmt" 22 | "net" 23 | "net/http" 24 | _ "net/http/pprof" 25 | "os" 26 | "path/filepath" 27 | "runtime" 28 | "strconv" 29 | "sync/atomic" 30 | "syscall" 31 | "time" 32 | 33 | "github.com/cloudwego/shmipc-go" 34 | ) 35 | 36 | var ( 37 | count uint64 38 | errCount uint64 39 | 40 | sendStr = "hello world" 41 | 42 | udsPath = "" 43 | adminPath = "" 44 | 45 | ENV_IS_HOT_RESTART_KEY = "IS_HOT_RESTART" 46 | ENV_HOT_RESTART_EPOCH_KEY = "HOT_RESTART_EPOCH" 47 | 48 | hotRestartEpochId = 0 49 | ) 50 | 51 | func main() { 52 | // Initialize some variables, start debug port, print qps info ... 
// listenFD extracts the file descriptor behind ln.
//
// A listener exposing `Fd() int` is asked directly and no *os.File is
// returned. For *net.TCPListener and *net.UnixListener the descriptor comes
// from File(), and that file is returned as f. If no non-negative descriptor
// and no file could be obtained, listenFD panics with an error naming the
// listener's type.
func listenFD(ln net.Listener) (fd int, f *os.File) {
	type fdGetter interface{ Fd() (fd int) }

	if direct, ok := ln.(fdGetter); ok {
		fd = direct.Fd()
	} else {
		// File() errors are deliberately dropped: a nil file reports an invalid
		// descriptor from Fd(), which the check below turns into a panic.
		if tcp, isTCP := ln.(*net.TCPListener); isTCP {
			f, _ = tcp.File()
		} else if unix, isUnix := ln.(*net.UnixListener); isUnix {
			f, _ = unix.File()
		}
		fd = int(f.Fd())
	}

	if fd < 0 && f == nil {
		panic(fmt.Errorf("hot-restart can't parse listener, which type %T", ln))
	}
	return fd, f
}
accept connection, begin hotrestart 117 | adminln, ok := admin.(*net.UnixListener) 118 | if !ok { 119 | panic("admin ln error") 120 | } 121 | adminln.SetUnlinkOnClose(false) 122 | conn, err := adminln.AcceptUnix() 123 | if err != nil { 124 | panic(fmt.Errorf("adminln AcceptUnix error %+v", err)) 125 | } 126 | 127 | // step 2. get admin listener fd 128 | adminFd, _ := listenFD(admin) 129 | fmt.Printf("dump listenAdmin adminFd %d\n", adminFd) 130 | syscall.SetNonblock(adminFd, true) 131 | 132 | // step 3. send admin listener fd 133 | var buf [8]byte // recv epoch id, 8 bytes 134 | rights := syscall.UnixRights(adminFd) 135 | var writeN, oobN int 136 | writeN, oobN, err = conn.WriteMsgUnix(buf[:], rights, nil) 137 | fmt.Println("WriteMsgUnix writeN ", writeN, " oobN ", oobN, " err ", err) 138 | if err != nil { 139 | panic(err) 140 | } 141 | 142 | // step 4. after recv epoch id, the new server is ready for hot restart 143 | _, err = conn.Read(buf[:]) 144 | if err != nil { 145 | panic(err) 146 | } 147 | epochId := binary.BigEndian.Uint64(buf[:]) 148 | fmt.Println("recv new server epoch id ", epochId) 149 | if epochId <= 0 { 150 | panic("epochId error") 151 | } 152 | 153 | // step 5. begin shmipc hot restart 154 | fmt.Println("old server begin shmipc hot restart epoch id = ", epochId) 155 | err = svr.HotRestart(uint64(epochId)) 156 | if err != nil { 157 | panic(err) 158 | } 159 | 160 | // step 6. wait for the hot restart done 161 | for { 162 | if svr.IsHotRestartDone() { 163 | break 164 | } 165 | time.Sleep(50 * time.Millisecond) 166 | } 167 | fmt.Println("old server finish shmipc hot restart") 168 | 169 | time.Sleep(1 * time.Second) 170 | // step 6. 
// rebuildListener turns an inherited listening-socket descriptor into a
// net.Listener.
//
// fd is wrapped in an *os.File and handed to net.FileListener. FileListener
// returns a copy of the listener backed by its own duplicate of the
// descriptor, and closing the file does not affect that listener, so the
// wrapper is closed before returning. rebuildListener therefore takes
// ownership of fd: it is closed on success and on failure.
//
// It returns an error when fd is not a valid descriptor or when it does not
// refer to a listening socket.
func rebuildListener(fd int) (net.Listener, error) {
	file := os.NewFile(uintptr(fd), "")
	if file == nil {
		return nil, fmt.Errorf("hot-restart failed to new file with fd %d", fd)
	}
	// Safe to close: the listener keeps its own duplicate. Leaving it open would
	// keep the received descriptor alive until a finalizer happened to run.
	defer file.Close()

	ln, err := net.FileListener(file)
	if err != nil {
		return nil, fmt.Errorf("hot-restart rebuild listener from fd %d: %w", fd, err)
	}
	return ln, nil
}
admin handle hot restart protocols 239 | go listenAdmin(shmipcListener, adminln) 240 | 241 | // step 4. server run 242 | err = shmipcListener.Run() 243 | fmt.Println("shmipcListener Run Ret err ", err) 244 | } 245 | 246 | func restart() { 247 | // step 1. send Epoch Id to old server 248 | epoch := os.Getenv(ENV_HOT_RESTART_EPOCH_KEY) 249 | var err error 250 | hotRestartEpochId, err = strconv.Atoi(epoch) 251 | fmt.Println("hot restart epoch id = ", hotRestartEpochId) 252 | if err != nil { 253 | panic(fmt.Errorf("%s parse hotRestartEpochId error %+v", ENV_HOT_RESTART_EPOCH_KEY, err)) 254 | } 255 | 256 | // step 2. connect admin socket 257 | conn, err := net.DialUnix("unix", nil, &net.UnixAddr{Name: adminPath, Net: "unix"}) 258 | if err != nil { 259 | panic(err) 260 | } 261 | 262 | // step 3. recv admin fds 263 | fds := recvFds(conn) 264 | fmt.Println("new server recv fds ", fds) 265 | if len(fds) != 1 { 266 | panic(fmt.Errorf("fds len is err")) 267 | } 268 | 269 | // step 3. use fd rebuild listener 270 | var adminln net.Listener 271 | adminln, err = rebuildListener(fds[0]) 272 | if err != nil { 273 | panic(err) 274 | } 275 | fmt.Println("rebuildListener adminln.addr ", adminln.Addr()) 276 | 277 | // step 4. create new shmipc listener 278 | config := shmipc.NewDefaultListenerConfig(udsPath, "unix") 279 | srvLn, err := shmipc.NewListener(&listenCbImpl{}, config) 280 | if err != nil { 281 | fmt.Println(err) 282 | return 283 | } 284 | srvLn.SetUnlinkOnClose(false) 285 | 286 | // step 5. 
send to old server 287 | go func() { 288 | // sleep make sure new shmipc listener run 289 | time.Sleep(500 * time.Millisecond) 290 | fmt.Println("start hot restart") 291 | 292 | epochId := make([]byte, 8) 293 | binary.BigEndian.PutUint64(epochId, uint64(hotRestartEpochId)) 294 | _, err := conn.Write(epochId) 295 | if err != nil { 296 | panic(err) 297 | } 298 | 299 | _ = conn.Close() 300 | // new server continues to perform admin duties 301 | go listenAdmin(srvLn, adminln) 302 | }() 303 | 304 | // step 6. new server run 305 | err = srvLn.Run() 306 | fmt.Println("server Run ret err ", err) 307 | } 308 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/cloudwego/shmipc-go 2 | 3 | go 1.15 4 | 5 | require ( 6 | github.com/bytedance/gopkg v0.0.0-20220817015305-b879a72dc90f 7 | github.com/kr/pretty v0.1.0 // indirect 8 | github.com/shirou/gopsutil/v3 v3.22.1 9 | github.com/stretchr/testify v1.8.2 10 | golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2 11 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect 12 | ) 13 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/bytedance/gopkg v0.0.0-20220817015305-b879a72dc90f h1:U3Bk6S9UyqFM5tU3bZ3pwqx5xyypHP7Bm2QCbOUwxSc= 2 | github.com/bytedance/gopkg v0.0.0-20220817015305-b879a72dc90f/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= 3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= 7 | github.com/go-ole/go-ole v1.2.6/go.mod 
h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= 8 | github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 9 | github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= 10 | github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= 11 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 12 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 13 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 14 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 15 | github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= 16 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 17 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 18 | github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= 19 | github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= 20 | github.com/shirou/gopsutil/v3 v3.22.1 h1:33y31Q8J32+KstqPfscvFwBlNJ6xLaBy4xqBXzlYV5w= 21 | github.com/shirou/gopsutil/v3 v3.22.1/go.mod h1:WapW1AOOPlHyXr+yOyw3uYx36enocrtSoSBy0L5vUHY= 22 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 23 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 24 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 25 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 26 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 27 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 28 | github.com/stretchr/testify v1.8.2 
h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= 29 | github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 30 | github.com/tklauser/go-sysconf v0.3.9/go.mod h1:11DU/5sG7UexIrp/O6g35hrWzu0JxlwQ3LSFUzyeuhs= 31 | github.com/tklauser/numcpus v0.3.0/go.mod h1:yFGUr7TUHQRAhyqBcEg0Ge34zDBAsIvJJcyE6boqnA8= 32 | github.com/yusufpapurcu/wmi v1.2.2 h1:KBNDSne4vP5mbSWnJbO+51IMOXJB67QiYCSBrubbPRg= 33 | github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= 34 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 35 | golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 36 | golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 37 | golang.org/x/sys v0.0.0-20210816074244-15123e1e1f71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 38 | golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 39 | golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 40 | golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2 h1:fqTvyMIIj+HRzMmnzr9NtpHP6uVpvB5fkHcgPDC4nu8= 41 | golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 42 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 43 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 44 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= 45 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 46 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 47 | gopkg.in/yaml.v3 v3.0.1 
// ListenCallback is the server's asynchronous API: the callbacks a Listener
// invokes on its owner.
type ListenCallback interface {
	// OnNewStream is called when a new stream has been accepted.
	OnNewStream(s *Stream)
	// OnShutdown is called when the listener is stopped; reason describes why.
	OnShutdown(reason string)
}

// ListenerConfig is the configuration of Listener. It embeds the general
// *Config and adds the address to listen on.
type ListenerConfig struct {
	*Config
	Network string // Only "unix" or "tcp" is supported.
	// If Network is "tcp", ListenPath is an IP address and port, such as 0.0.0.0:6666 (ipv4) or [::]:6666 (ipv6).
	// If Network is "unix", ListenPath is a file path, such as /your/socket/path/xx_shmipc.sock.
	ListenPath string
}
sessions *sessions 54 | ln net.Listener 55 | logger *logger 56 | callback ListenCallback 57 | shutdownErrStr string 58 | isClose bool 59 | state sessionSateType 60 | epoch uint64 61 | hotRestartAckCount int 62 | unlinkOnClose bool 63 | } 64 | 65 | // NewDefaultListenerConfig return the default Listener's config 66 | func NewDefaultListenerConfig(listenPath string, network string) *ListenerConfig { 67 | return &ListenerConfig{ 68 | Config: DefaultConfig(), 69 | Network: network, 70 | ListenPath: listenPath, 71 | } 72 | } 73 | 74 | // NewListener will try listen the ListenPath of the configuration, and return the Listener if no error happened. 75 | func NewListener(callback ListenCallback, config *ListenerConfig) (*Listener, error) { 76 | if callback == nil { 77 | return nil, errors.New("ListenCallback couldn't be nil") 78 | } 79 | 80 | if runtime.GOOS != "linux" { 81 | return nil, fmt.Errorf("only support linux OS") 82 | } 83 | 84 | if config.MemMapType == MemMapTypeMemFd && config.Network != "unix" { 85 | return nil, errors.New("config.Network must be unix when config.MemMapType is MemMapTypeMemFd") 86 | } 87 | 88 | if config.Network == "unix" { 89 | safeRemoveUdsFile(config.ListenPath) 90 | } 91 | 92 | ln, err := net.Listen(config.Network, config.ListenPath) 93 | 94 | if err != nil { 95 | return nil, fmt.Errorf("create listener failed, reason%s", err.Error()) 96 | } 97 | 98 | return &Listener{ 99 | config: config, 100 | ln: ln, 101 | dispatcher: defaultDispatcher, 102 | sessions: newSessions(), 103 | logger: newLogger("listener", nil), 104 | callback: callback, 105 | unlinkOnClose: true, 106 | }, nil 107 | } 108 | 109 | // Close closes the listener. 110 | // Any blocked Accept operations will be unblocked and return errors. 
111 | func (l *Listener) Close() error { 112 | l.mu.Lock() 113 | defer l.mu.Unlock() 114 | if l.isClose { 115 | return nil 116 | } 117 | l.isClose = true 118 | 119 | if l.shutdownErrStr != "" { 120 | l.callback.OnShutdown(l.shutdownErrStr) 121 | } else { 122 | l.callback.OnShutdown("close by Listener.Close()") 123 | } 124 | l.ln.Close() 125 | if l.config.Network == "unix" && l.unlinkOnClose { 126 | os.Remove(l.config.ListenPath) 127 | } 128 | l.sessions.closeAll() 129 | return nil 130 | } 131 | 132 | // Addr returns the listener's network address. 133 | func (l *Listener) Addr() net.Addr { 134 | return l.ln.Addr() 135 | } 136 | 137 | // Accept doesn't work, whose existence just adapt to the net.Listener interface. 138 | func (l *Listener) Accept() (net.Conn, error) { 139 | return nil, errors.New("not support now, just compact net.Listener interface") 140 | } 141 | 142 | // Run starting a loop to listen socket 143 | func (l *Listener) Run() error { 144 | for { 145 | conn, err := l.ln.Accept() 146 | if err != nil { 147 | if nerr, ok := err.(net.Error); ok && nerr.Temporary() { 148 | continue 149 | } 150 | if strings.Contains(err.Error(), "too many open file") { 151 | time.Sleep(10 * time.Millisecond) 152 | continue 153 | } 154 | l.shutdownErrStr = "accept failed,reason:" + err.Error() 155 | l.logger.errorf("run accept error %s", l.shutdownErrStr) 156 | l.Close() 157 | break 158 | } 159 | 160 | configCopy := *l.config.Config 161 | configCopy.listenCallback = &sessionCallback{l} 162 | session, err := newSession(&configCopy, conn, false) 163 | if err != nil { 164 | conn.Close() 165 | l.logger.warnf("new server session failed, reason" + err.Error()) 166 | continue 167 | } 168 | session.listener = l 169 | l.sessions.add(session) 170 | } 171 | return nil 172 | } 173 | 174 | // HotRestart will do shmipc server hot restart 175 | func (l *Listener) HotRestart(epoch uint64) error { 176 | l.logger.warnf("begin HotRestart epoch:%d", epoch) 177 | 178 | l.mu.Lock() 179 | defer 
l.mu.Unlock() 180 | 181 | if l.state == hotRestartState { 182 | return ErrHotRestartInProgress 183 | } 184 | 185 | l.state = hotRestartState 186 | l.epoch = epoch 187 | 188 | l.sessions.sessionMu.Lock() 189 | defer l.sessions.sessionMu.Unlock() 190 | 191 | for session := range l.sessions.data { 192 | if !session.handshakeDone { 193 | return ErrInHandshakeStage 194 | } 195 | if session.state != defaultState { 196 | continue 197 | } 198 | if err := session.hotRestart(epoch, typeHotRestart); err != nil { 199 | session.logger.warnf("%s hotRestart epoch %d error %+v", session.name, epoch, err) 200 | l.state = defaultState 201 | return err 202 | } 203 | session.state = hotRestartState 204 | l.hotRestartAckCount++ 205 | } 206 | 207 | go func() { 208 | l.checkHotRestart() 209 | }() 210 | 211 | return nil 212 | } 213 | 214 | // IsHotRestartDone return whether the Listener is under the hot restart state. 215 | func (l *Listener) IsHotRestartDone() bool { 216 | l.mu.Lock() 217 | defer l.mu.Unlock() 218 | 219 | return l.state != hotRestartState 220 | } 221 | 222 | // SetUnlinkOnClose sets whether unlink unix socket path when Listener was stopped 223 | func (l *Listener) SetUnlinkOnClose(unlink bool) { 224 | l.mu.Lock() 225 | defer l.mu.Unlock() 226 | 227 | l.unlinkOnClose = unlink 228 | if lnUnix, ok := l.ln.(*net.UnixListener); ok { 229 | lnUnix.SetUnlinkOnClose(unlink) 230 | } 231 | } 232 | 233 | // since the hot restart process is establish connection operation 234 | // so waiting timeout is set short 235 | func (l *Listener) checkHotRestart() { 236 | timeout := time.NewTimer(hotRestartCheckTimeout) 237 | defer timeout.Stop() 238 | ticker := time.NewTicker(hotRestartCheckInterval) 239 | defer ticker.Stop() 240 | 241 | for { 242 | select { 243 | case <-ticker.C: 244 | l.mu.Lock() 245 | if l.state != hotRestartState { 246 | l.mu.Unlock() 247 | return 248 | } 249 | 250 | if l.hotRestartAckCount == 0 { 251 | l.logger.warnf("[epoch:%d] checkHotRestart done", l.epoch) 252 | 
l.state = hotRestartDoneState 253 | l.sessions.onHotRestart(true) 254 | l.mu.Unlock() 255 | return 256 | } 257 | 258 | l.mu.Unlock() 259 | case <-timeout.C: 260 | l.logger.errorf("[epoch:%d] checkHotRestart timeout", l.epoch) 261 | l.resetState() 262 | l.sessions.onHotRestart(false) 263 | return 264 | } 265 | } 266 | } 267 | 268 | func (l *Listener) resetState() { 269 | l.mu.Lock() 270 | defer l.mu.Unlock() 271 | 272 | l.state = defaultState 273 | l.hotRestartAckCount = 0 274 | 275 | l.sessions.sessionMu.Lock() 276 | defer l.sessions.sessionMu.Unlock() 277 | 278 | for session := range l.sessions.data { 279 | session.state = defaultState 280 | } 281 | } 282 | 283 | type sessionCallback struct { 284 | listener *Listener 285 | } 286 | 287 | func (c *sessionCallback) OnNewStream(s *Stream) { 288 | c.listener.callback.OnNewStream(s) 289 | } 290 | 291 | func (c *sessionCallback) OnShutdown(reason string) { 292 | c.listener.logger.warnf("on session shutdown,reason:" + reason) 293 | c.listener.sessions.removeShutdownSession() 294 | } 295 | 296 | var _ ListenCallback = &sessionCallback{} 297 | 298 | type sessions struct { 299 | sessionMu sync.Mutex 300 | data map[*Session]struct{} 301 | } 302 | 303 | func newSessions() *sessions { return &sessions{data: make(map[*Session]struct{}, 8)} } 304 | 305 | func (s *sessions) add(session *Session) { 306 | s.sessionMu.Lock() 307 | if s.data != nil { 308 | s.data[session] = struct{}{} 309 | } else { 310 | session.logger.warnf("listener is closed, session %s will not be add", session.name) 311 | session.Close() 312 | } 313 | s.sessionMu.Unlock() 314 | } 315 | 316 | func (s *sessions) removeShutdownSession() { 317 | s.sessionMu.Lock() 318 | for session := range s.data { 319 | if session.IsClosed() { 320 | delete(s.data, session) 321 | } 322 | } 323 | s.sessionMu.Unlock() 324 | } 325 | 326 | func (s *sessions) closeAll() { 327 | s.sessionMu.Lock() 328 | toCloseSessions := s.data 329 | s.data = nil 330 | s.sessionMu.Unlock() 331 | for 
session := range toCloseSessions { 332 | session.Close() 333 | } 334 | } 335 | 336 | func (s *sessions) onHotRestart(success bool) { 337 | s.sessionMu.Lock() 338 | for session := range s.data { 339 | if success { 340 | atomic.AddUint64(&session.stats.hotRestartSuccessCount, 1) 341 | } else { 342 | atomic.AddUint64(&session.stats.hotRestartErrorCount, 1) 343 | } 344 | } 345 | s.sessionMu.Unlock() 346 | } 347 | -------------------------------------------------------------------------------- /listener_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "fmt" 21 | "testing" 22 | "time" 23 | 24 | "github.com/stretchr/testify/assert" 25 | ) 26 | 27 | var ( 28 | hotRestartUDSPath = "/tmp/hot_restart_test.sock" 29 | 30 | firstMsg = "hello start" 31 | secondMsg = "hello restart" 32 | 33 | oldListenerExit chan struct{} 34 | firstMsgDone chan struct{} 35 | secondMsgDone chan struct{} 36 | 37 | _ ListenCallback = &listenCbImpl{} 38 | _ StreamCallbacks = &streamCbImpl{} 39 | ) 40 | 41 | type listenCbImpl struct{} 42 | 43 | func (l listenCbImpl) OnNewStream(s *Stream) { 44 | if err := s.SetCallbacks(streamCbImpl{stream: s}); err != nil { 45 | fmt.Printf("OnNewStream SetCallbacks error %+v\n", err) 46 | } 47 | } 48 | 49 | func (l listenCbImpl) OnShutdown(reason string) { 50 | } 51 | 52 | type streamCbImpl struct { 53 | stream *Stream 54 | } 55 | 56 | func (s streamCbImpl) OnData(reader BufferReader) { 57 | _, _ = reader.Peek(1) 58 | ret, err := reader.ReadString(reader.Len()) 59 | if err != nil { 60 | fmt.Println("streamCbImpl OnData err ", err) 61 | return 62 | } 63 | 64 | if ret == firstMsg { 65 | fmt.Println("old server recv msg") 66 | _ = s.stream.BufferWriter().WriteString(ret) 67 | s.stream.Flush(false) 68 | s.stream.ReleaseReadAndReuse() 69 | 70 | firstMsgDone <- struct{}{} 71 | } 72 | 73 | if ret == secondMsg { 74 | fmt.Println("new server recv msg") 75 | _ = s.stream.BufferWriter().WriteString(ret) 76 | s.stream.Flush(false) 77 | s.stream.ReleaseReadAndReuse() 78 | 79 | secondMsgDone <- struct{}{} 80 | } 81 | } 82 | 83 | func (s streamCbImpl) OnLocalClose() { 84 | fmt.Println("streamCbImpl OnLocalClose") 85 | } 86 | 87 | func (s streamCbImpl) OnRemoteClose() { 88 | fmt.Println("streamCbImpl OnRemoteClose") 89 | } 90 | 91 | func getSessionManagerConfig() *SessionManagerConfig { 92 | conf := DefaultSessionManagerConfig() 93 | conf.Address = hotRestartUDSPath 94 | conf.Network = "unix" 95 | conf.SessionNum = 4 96 | conf.MemMapType = MemMapTypeMemFd 97 | 98 | 
return conf 99 | } 100 | 101 | func getListenerConfig() *ListenerConfig { 102 | return NewDefaultListenerConfig(hotRestartUDSPath, "unix") 103 | } 104 | 105 | func genListenerByConfig(config *ListenerConfig) *Listener { 106 | listener, err := NewListener(&listenCbImpl{}, config) 107 | if err != nil { 108 | fmt.Println("NewListener error ", err) 109 | return nil 110 | } 111 | return listener 112 | } 113 | 114 | // mock hot restart, more details in the directory example/hot_restart_test 115 | // step1. start one server and one client 116 | // step2. client send the message `hello start` 117 | // step3. after server receive message, which will start a new server and do hot restart. 118 | // step4. client send the message `hello restart` 119 | // step5. the new server will receive the message `hello restart` 120 | func TestHotRestart(t *testing.T) { 121 | firstMsgDone = make(chan struct{}) 122 | secondMsgDone = make(chan struct{}) 123 | oldListenerExit = make(chan struct{}) 124 | 125 | oldListener := genListenerByConfig(getListenerConfig()) 126 | assert.NotNil(t, oldListener) 127 | oldListener.SetUnlinkOnClose(false) 128 | go func() { 129 | runErr := oldListener.Run() 130 | fmt.Println("oldListener run exit ", runErr) 131 | assert.Nil(t, runErr) 132 | oldListenerExit <- struct{}{} 133 | }() 134 | 135 | sessionManager, err := InitGlobalSessionManager(getSessionManagerConfig()) 136 | assert.NotNil(t, sessionManager) 137 | assert.Nil(t, err) 138 | 139 | if sessionManager == nil || oldListener == nil { 140 | fmt.Println("create listener or session manager error") 141 | return 142 | } 143 | 144 | // send first message 145 | stream, err := sessionManager.GetStream() 146 | assert.NotNil(t, stream) 147 | assert.Nil(t, err) 148 | err = stream.BufferWriter().WriteString(firstMsg) 149 | assert.Nil(t, err) 150 | err = stream.Flush(false) 151 | assert.Nil(t, err) 152 | ret, err := stream.BufferReader().ReadString(len(firstMsg)) 153 | assert.Nil(t, err) 154 | assert.Equal(t, ret, 
firstMsg) 155 | sessionManager.PutBack(stream) 156 | 157 | <-firstMsgDone 158 | fmt.Println("begin hot restart") 159 | // after handle first message, and then do hot restart 160 | newListener := genListenerByConfig(getListenerConfig()) 161 | assert.NotNil(t, newListener) 162 | newListener.SetUnlinkOnClose(false) 163 | go func() { 164 | runErr := newListener.Run() 165 | fmt.Println("newListener run exit ", runErr) 166 | assert.Nil(t, runErr) 167 | }() 168 | err = oldListener.HotRestart(1024) 169 | assert.Nil(t, err) 170 | 171 | // wait old server reply 172 | for !oldListener.IsHotRestartDone() { 173 | time.Sleep(time.Millisecond * 100) 174 | } 175 | err = oldListener.Close() 176 | assert.Nil(t, err) 177 | 178 | <-oldListenerExit 179 | fmt.Println("hot restart done, old server exit") 180 | stream, err = sessionManager.GetStream() 181 | assert.NotNil(t, stream) 182 | assert.Nil(t, err) 183 | err = stream.BufferWriter().WriteString(secondMsg) 184 | assert.Nil(t, err) 185 | err = stream.Flush(false) 186 | assert.Nil(t, err) 187 | ret, err = stream.BufferReader().ReadString(len(secondMsg)) 188 | assert.Nil(t, err) 189 | assert.Equal(t, ret, secondMsg) 190 | sessionManager.PutBack(stream) 191 | 192 | <-secondMsgDone 193 | err = newListener.Close() 194 | assert.Nil(t, err) 195 | time.Sleep(1 * time.Second) 196 | } 197 | -------------------------------------------------------------------------------- /net_listener.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "errors" 21 | "net" 22 | "sync" 23 | "sync/atomic" 24 | "time" 25 | ) 26 | 27 | const defaultBacklog = 4096 // backlog number is the stream to accept channel size 28 | 29 | type listener struct { 30 | listener net.Listener // the raw listener 31 | 32 | sessions map[*Session]*sync.WaitGroup // all established sessions 33 | mu sync.Mutex // lock sessions 34 | 35 | closed uint32 // to mark if the raw listener is closed when accept returns error 36 | closeCh chan struct{} // to make select returns when closed 37 | backlog chan net.Conn // all accepted streams(will never be closed otherwise may send to closed channel) 38 | } 39 | 40 | // create listener and run background goroutines 41 | func newListener(rawListener net.Listener, backlog int) *listener { 42 | listener := &listener{ 43 | listener: rawListener, 44 | sessions: make(map[*Session]*sync.WaitGroup), 45 | backlog: make(chan net.Conn, backlog), 46 | closeCh: make(chan struct{}), 47 | } 48 | go listener.listenLoop() 49 | return listener 50 | } 51 | 52 | // accept connection from the raw listener in loop, 53 | // spawn another goroutine to create session with the connection, save it, and then accept streams from the session. 
54 | func (l *listener) listenLoop() { 55 | for { 56 | conn, err := l.listener.Accept() 57 | if err != nil { 58 | if atomic.LoadUint32(&l.closed) == 1 { 59 | internalLogger.infof("listener closed: %s", err) 60 | return 61 | } 62 | internalLogger.errorf("error when accept: %s", err) 63 | continue 64 | } 65 | internalLogger.info("receive a new incoming raw connection") 66 | go func() { 67 | session, err := Server(conn, DefaultConfig()) 68 | if err != nil { 69 | internalLogger.errorf("error when create session: %s", err) 70 | return 71 | } 72 | internalLogger.info("new session created") 73 | 74 | l.mu.Lock() 75 | if atomic.LoadUint32(&l.closed) == 1 { 76 | l.mu.Unlock() 77 | _ = session.Close() 78 | internalLogger.infof("listener is closed and the session should be closed") 79 | return 80 | } 81 | 82 | // Here we maintain a ref counter for every session. 83 | // The listener holds 1 ref, and every stream holds 1 ref. 84 | // Only when listener closed and all stream closed, the session be terminated. 
85 | wg := new(sync.WaitGroup) 86 | wg.Add(1) 87 | l.sessions[session] = wg 88 | l.mu.Unlock() 89 | 90 | go func() { 91 | wg.Wait() 92 | _ = session.Close() 93 | internalLogger.infof("wait group finished, session is closed") 94 | }() 95 | 96 | for { 97 | stream, err := session.AcceptStream() 98 | if err != nil { 99 | if err != ErrSessionShutdown { 100 | internalLogger.errorf("error when accept new stream: %s", err) 101 | } 102 | _ = session.Close() 103 | internalLogger.infof("session is closed early: %s", err) 104 | 105 | l.mu.Lock() 106 | if _, ok := l.sessions[session]; ok { 107 | delete(l.sessions, session) 108 | wg.Done() 109 | } 110 | l.mu.Unlock() 111 | return 112 | } 113 | internalLogger.info("accepted a new stream") 114 | conn := newStreamWrapper(stream, stream.LocalAddr(), stream.RemoteAddr(), wg) 115 | select { 116 | case <-l.closeCh: 117 | return 118 | case l.backlog <- conn: 119 | } 120 | } 121 | }() 122 | } 123 | } 124 | 125 | // accept gets connections from the backlog channel 126 | func (l *listener) Accept() (net.Conn, error) { 127 | select { 128 | case conn := <-l.backlog: 129 | return conn, nil 130 | case <-l.closeCh: 131 | return nil, errors.New("listener is closed") 132 | } 133 | } 134 | 135 | // When listen closed, all sessions should be closed which otherwise would leak. 136 | // Because the underlying connection is closed, all streams will be closed too. 137 | // Note: The close behaviour here may differs from a normal connection. 
138 | func (l *listener) Close() (err error) { 139 | // mark closed and close the listener 140 | swapped := atomic.CompareAndSwapUint32(&l.closed, 0, 1) 141 | err = l.listener.Close() 142 | // close the closeCh to make blocking call return 143 | if swapped { 144 | close(l.closeCh) 145 | } 146 | // closed and clear sessions to avoid leaking 147 | l.mu.Lock() 148 | for _, wg := range l.sessions { 149 | wg.Done() 150 | } 151 | l.sessions = map[*Session]*sync.WaitGroup{} 152 | l.mu.Unlock() 153 | return 154 | } 155 | 156 | // Addr is forwarded to the raw listener 157 | func (l *listener) Addr() net.Addr { 158 | return l.listener.Addr() 159 | } 160 | 161 | // Listen create listener with default backlog size(4096) 162 | // shmIPCAddress is uds address used as underlying connection, the returned value is net.Listener 163 | // Remember close the listener if it is created successfully, or goroutine may leak 164 | // Should I use Listen? 165 | // If you want the best performance, you should use low level API(not this one) to marshal and unmarshal manually, 166 | // which can achieve better batch results. 167 | // If you just care about the compatibility, you can use this high level API. For example, you can hardly change grpc 168 | // and protobuf, then you can use this listener to make it compatible with a little bit improved performance. 169 | func Listen(shmIPCAddress string) (net.Listener, error) { 170 | return ListenWithBacklog(shmIPCAddress, defaultBacklog) 171 | } 172 | 173 | // ListenWithBacklog create listener with given backlog size 174 | // shmIPCAddress is uds address used as underlying connection, the returned value is net.Listener 175 | // Remember close the listener if it is created successfully, or goroutine may leak 176 | // Should I use ListenWithBacklog? 177 | // If you want the best performance, you should use low level API(not this one) to marshal and unmarshal manually, 178 | // which can achieve better batch results. 
179 | // If you just care about the compatibility, you can use this high level API. For example, you can hardly change grpc 180 | // and protobuf, then you can use this listener to make it compatible with a little bit improved performance. 181 | func ListenWithBacklog(shmIPCAddress string, backlog int) (net.Listener, error) { 182 | rawListener, err := net.Listen("unix", shmIPCAddress) 183 | if err != nil { 184 | return nil, err 185 | } 186 | return newListener(rawListener, backlog), nil 187 | } 188 | 189 | // A wrapper around a stream to impl net.Conn 190 | func newStreamWrapper(stream *Stream, localAddr, remoteAddr net.Addr, wg *sync.WaitGroup) net.Conn { 191 | wg.Add(1) 192 | return &streamWrapper{stream: stream, localAddr: localAddr, remoteAddr: remoteAddr, wg: wg} 193 | } 194 | 195 | type streamWrapper struct { 196 | stream *Stream 197 | localAddr net.Addr 198 | remoteAddr net.Addr 199 | 200 | closed uint32 201 | wg *sync.WaitGroup 202 | } 203 | 204 | func (s *streamWrapper) Read(b []byte) (n int, err error) { 205 | return s.stream.copyRead(b) 206 | } 207 | 208 | func (s *streamWrapper) Write(b []byte) (n int, err error) { 209 | return s.stream.copyWriteAndFlush(b) 210 | } 211 | 212 | func (s *streamWrapper) Close() error { 213 | if atomic.CompareAndSwapUint32(&s.closed, 0, 1) { 214 | _ = s.stream.Close() 215 | s.wg.Done() 216 | } 217 | return nil 218 | } 219 | 220 | func (s *streamWrapper) LocalAddr() net.Addr { 221 | return s.localAddr 222 | } 223 | 224 | func (s *streamWrapper) RemoteAddr() net.Addr { 225 | return s.remoteAddr 226 | } 227 | 228 | func (s *streamWrapper) SetDeadline(t time.Time) error { 229 | return s.stream.SetDeadline(t) 230 | } 231 | 232 | func (s *streamWrapper) SetReadDeadline(t time.Time) error { 233 | return s.stream.SetReadDeadline(t) 234 | } 235 | 236 | func (s *streamWrapper) SetWriteDeadline(t time.Time) error { 237 | return s.stream.SetWriteDeadline(t) 238 | } 239 | 
-------------------------------------------------------------------------------- /protocol_event.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "encoding/binary" 21 | "fmt" 22 | "strconv" 23 | ) 24 | 25 | // event type for internal implements 26 | const ( 27 | typeShareMemoryByFilePath eventType = 0 28 | // notify peer start consume 29 | typePolling eventType = 1 30 | // stream level event notify peer stream close 31 | typeStreamClose eventType = 2 32 | // typePing // TODO 33 | typeFallbackData eventType = 3 34 | // exchange proto version 35 | typeExchangeProtoVersion eventType = 4 36 | //query the mem map type supported by the server 37 | typeShareMemoryByMemfd eventType = 5 38 | // when server mapping share memory success, give the ack to client. 
39 | typeAckShareMemory eventType = 6 40 | typeAckReadyRecvFD eventType = 7 41 | typeHotRestart eventType = 8 42 | typeHotRestartAck eventType = 9 43 | 44 | minEventType = typeShareMemoryByFilePath 45 | maxEventType = typeHotRestartAck 46 | ) 47 | 48 | func init() { 49 | for i := 0; i < int(maxSupportProtoVersion)+1; i++ { 50 | pollingEventWithVersion[i] = make([]byte, headerSize) 51 | pollingEventWithVersion[i].encode(headerSize, uint8(i), typePolling) 52 | } 53 | } 54 | 55 | type header []byte 56 | 57 | func (h header) Length() uint32 { 58 | return binary.BigEndian.Uint32(h[0:4]) 59 | } 60 | 61 | func (h header) Magic() uint16 { 62 | return binary.BigEndian.Uint16(h[4:6]) 63 | } 64 | 65 | func (h header) Version() uint8 { 66 | return h[6] 67 | } 68 | 69 | func (h header) MsgType() eventType { 70 | return eventType(h[7]) 71 | } 72 | 73 | func (h header) String() string { 74 | return fmt.Sprintf("Length:%d Magic:%d Version:%d Type:%s ", 75 | h.Length(), h.Magic(), h.Version(), h.MsgType().String()) 76 | } 77 | 78 | func (h header) encode(length uint32, version uint8, msgType eventType) { 79 | binary.BigEndian.PutUint32(h[0:4], length) 80 | binary.BigEndian.PutUint16(h[4:6], magicNumber) 81 | h[6] = version 82 | h[7] = uint8(msgType) 83 | } 84 | 85 | // header | seqID | status 86 | type fallbackDataEvent [headerSize + 8]byte 87 | 88 | func (f *fallbackDataEvent) encode(length int, version uint8, seqID uint32, status uint32) { 89 | binary.BigEndian.PutUint32(f[0:4], uint32(length)) 90 | binary.BigEndian.PutUint16(f[4:6], magicNumber) 91 | f[6] = version 92 | f[7] = uint8(typeFallbackData) 93 | binary.BigEndian.PutUint32(f[8:12], seqID) 94 | binary.BigEndian.PutUint32(f[12:16], status) 95 | } 96 | 97 | func (t eventType) String() string { 98 | switch t { 99 | case typeShareMemoryByFilePath: 100 | return "ShareMemoryByFilePath" 101 | case typePolling: 102 | return "Polling" 103 | case typeStreamClose: 104 | return "StreamClose" 105 | case typeFallbackData: 106 | return 
"FallbackData" 107 | case typeShareMemoryByMemfd: 108 | return "ShareMemoryByMemfd" 109 | case typeExchangeProtoVersion: 110 | return "ExchangeProtoVersion" 111 | case typeAckShareMemory: 112 | return "AckShareMemory" 113 | case typeAckReadyRecvFD: 114 | return "AckReadyRecvFD" 115 | case typeHotRestart: 116 | return "HotRestart" 117 | case typeHotRestartAck: 118 | return "HotRestartAck" 119 | } 120 | 121 | return "" + strconv.Itoa(int(t)) 122 | } 123 | 124 | func checkEventValid(hdr header) error { 125 | // Verify the magic&version 126 | if hdr.Magic() != magicNumber || hdr.Version() == 0 { 127 | internalLogger.errorf("shmipc: Invalid magic or version %d, %d", hdr.Magic(), hdr.Version()) 128 | return ErrInvalidVersion 129 | } 130 | mt := hdr.MsgType() 131 | if mt < minEventType || mt > maxEventType { 132 | internalLogger.errorf("shmipc, invalid protocol header: " + hdr.String()) 133 | return ErrInvalidMsgType 134 | } 135 | return nil 136 | } 137 | -------------------------------------------------------------------------------- /protocol_initializer.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "errors" 21 | "fmt" 22 | "net" 23 | ) 24 | 25 | var ( 26 | protocolVersionInitializersFactory = make(map[int]func(session *Session, firstEvent header) protocolInitializer) 27 | _ protocolInitializer = &protocolInitializerV2{} 28 | _ protocolInitializer = &protocolInitializerV3{} 29 | ) 30 | 31 | func init() { 32 | protocolVersionInitializersFactory[2] = func(session *Session, firstEvent header) protocolInitializer { 33 | return &protocolInitializerV2{session: session, firstEvent: firstEvent} 34 | } 35 | 36 | protocolVersionInitializersFactory[3] = func(session *Session, firstEvent header) protocolInitializer { 37 | return &protocolInitializerV3{session: session, firstEvent: firstEvent} 38 | } 39 | } 40 | 41 | type protocolInitializer interface { 42 | Version() uint8 43 | Init() error 44 | } 45 | 46 | type protocolInitializerV2 struct { 47 | session *Session 48 | netConn net.Conn 49 | firstEvent header 50 | } 51 | 52 | func (p *protocolInitializerV2) Init() error { 53 | if !p.session.isClient { 54 | if p.firstEvent.MsgType() != typeShareMemoryByFilePath { 55 | return fmt.Errorf("protocolInitializerV2 expect first event is:%d(%s),but:%d", 56 | typeShareMemoryByFilePath, typeShareMemoryByFilePath.String(), p.firstEvent.MsgType()) 57 | } 58 | return handleShareMemoryByFilePath(p.session, p.firstEvent) 59 | } 60 | return sendShareMemoryByFilePath(p.session) 61 | } 62 | 63 | func (p *protocolInitializerV2) Version() uint8 { 64 | return 2 65 | } 66 | 67 | type protocolInitializerV3 struct { 68 | session *Session 69 | firstEvent header 70 | } 71 | 72 | func (p *protocolInitializerV3) Init() error { 73 | if p.session.isClient { 74 | return p.clientInit() 75 | } 76 | return p.serverInit() 77 | } 78 | 79 | func (p *protocolInitializerV3) Version() uint8 { 80 | return 3 81 | } 82 | 83 | func (p *protocolInitializerV3) serverInit() error { 84 | if p.firstEvent.MsgType() != typeExchangeProtoVersion { 85 | return 
fmt.Errorf("protocolInitializerV3 expect firsts event is:%d(%s) but:%d", 86 | typeExchangeProtoVersion, typeExchangeProtoVersion.String(), p.firstEvent.MsgType()) 87 | } 88 | 89 | //1.exchange version 90 | if err := handleExchangeVersion(p.session, p.firstEvent); err != nil { 91 | return errors.New("protocolInitializerV3 exchangeVersion failed, reason:" + err.Error()) 92 | } 93 | 94 | //2.recv and mapping share memory 95 | h, err := blockReadEventHeader(p.session.connFd) 96 | if err != nil { 97 | return errors.New("protocolInitializerV3 blockReadEventHeader failed,reason:" + err.Error()) 98 | } 99 | switch h.MsgType() { 100 | case typeShareMemoryByFilePath: 101 | err = handleShareMemoryByFilePath(p.session, h) 102 | case typeShareMemoryByMemfd: 103 | err = handleShareMemoryByMemFd(p.session, h) 104 | default: 105 | return fmt.Errorf("expect event type is typeShareMemoryByFilePath or typeShareMemoryByMemfd but:%d %s", 106 | h.MsgType(), h.MsgType().String()) 107 | } 108 | 109 | if err != nil { 110 | return err 111 | } 112 | 113 | //3.ack share memory 114 | respHeader := header(make([]byte, headerSize)) 115 | respHeader.encode(headerSize, p.session.communicationVersion, typeAckShareMemory) 116 | protocolTrace(respHeader, nil, true) 117 | return blockWriteFull(p.session.connFd, respHeader) 118 | } 119 | 120 | func (p *protocolInitializerV3) clientInit() error { 121 | var err error 122 | memType := p.session.config.MemMapType 123 | switch memType { 124 | case MemMapTypeDevShmFile: 125 | err = sendShareMemoryByFilePath(p.session) 126 | case MemMapTypeMemFd: 127 | err = sendMemFdToPeer(p.session) 128 | default: 129 | err = fmt.Errorf("unknown memory type:%d", memType) 130 | } 131 | 132 | if err != nil { 133 | return err 134 | } 135 | _, err = waitEventHeader(p.session.connFd, typeAckShareMemory) 136 | 137 | return err 138 | } 139 | -------------------------------------------------------------------------------- /protocol_manager_test.go: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "fmt" 21 | "github.com/stretchr/testify/assert" 22 | _ "net/http/pprof" 23 | "testing" 24 | ) 25 | 26 | const udsPath = "/tmp/shmipc.sock" 27 | 28 | func TestProtocolCompatibilityForNetUnixConn(t *testing.T) { 29 | //testProtocolCompatibility(t, MemMapTypeDevShmFile) 30 | testProtocolCompatibility(t, MemMapTypeMemFd) 31 | } 32 | 33 | func testProtocolCompatibility(t *testing.T, memType MemMapType) { 34 | fmt.Println("----------bengin test protocolAdaptor MemMapType ----------", memType) 35 | clientConn, serverConn := testUdsConn() 36 | conf := testConf() 37 | conf.MemMapType = memType 38 | go func() { 39 | sconf := testConf() 40 | server, err := Server(serverConn, sconf) 41 | assert.Equal(t, true, err == nil, err) 42 | if err == nil { 43 | server.Close() 44 | } 45 | }() 46 | 47 | client, err := newSession(conf, clientConn, true) 48 | assert.Equal(t, true, err == nil, err) 49 | if err == nil { 50 | client.Close() 51 | } 52 | 53 | fmt.Println("----------end test protocolAdaptor client V2 to server V2 ----------") 54 | } 55 | -------------------------------------------------------------------------------- /queue.go: -------------------------------------------------------------------------------- 1 | /* 2 
| * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "errors" 21 | "fmt" 22 | "os" 23 | "path/filepath" 24 | "sync" 25 | "sync/atomic" 26 | "syscall" 27 | "unsafe" 28 | ) 29 | 30 | const ( 31 | queueHeaderLength = 24 32 | ) 33 | 34 | type queueManager struct { 35 | path string 36 | sendQueue *queue 37 | recvQueue *queue 38 | mem []byte 39 | mmapMapType MemMapType 40 | memFd int 41 | } 42 | 43 | //default cap is 16384, which mean that 16384 * 8 = 128 KB memory. 44 | type queue struct { 45 | sync.Mutex 46 | head *int64 // consumer write, producer read 47 | tail *int64 // producer write, consumer read 48 | workingFlag *uint32 //when peer is consuming the queue, the workingFlag is 1, otherwise 0. 49 | cap int64 50 | queueBytesOnMemory []byte // it could be from share memory or process memory. 
51 | } 52 | 53 | type queueElement struct { 54 | seqID uint32 55 | offsetInShmBuf uint32 56 | status uint32 57 | } 58 | 59 | func createQueueManagerWithMemFd(queuePathName string, queueCap uint32) (*queueManager, error) { 60 | memFd, err := MemfdCreate(queuePathName, 0) 61 | if err != nil { 62 | return nil, err 63 | } 64 | 65 | memSize := countQueueMemSize(queueCap) * queueCount 66 | if err := syscall.Ftruncate(memFd, int64(memSize)); err != nil { 67 | return nil, fmt.Errorf("createQueueManagerWithMemFd truncate share memory failed,%w", err) 68 | } 69 | 70 | mem, err := syscall.Mmap(memFd, 0, memSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) 71 | if err != nil { 72 | return nil, err 73 | } 74 | for i := 0; i < len(mem); i++ { 75 | mem[i] = 0 76 | } 77 | 78 | return &queueManager{ 79 | sendQueue: createQueueFromBytes(mem[:memSize/2], queueCap), 80 | recvQueue: createQueueFromBytes(mem[memSize/2:], queueCap), 81 | mem: mem, 82 | path: queuePathName, 83 | mmapMapType: MemMapTypeMemFd, 84 | memFd: memFd, 85 | }, nil 86 | } 87 | 88 | func createQueueManager(shmPath string, queueCap uint32) (*queueManager, error) { 89 | //ignore mkdir error 90 | _ = os.MkdirAll(filepath.Dir(shmPath), os.ModePerm) 91 | if pathExists(shmPath) { 92 | return nil, errors.New("queue was existed,path" + shmPath) 93 | } 94 | memSize := countQueueMemSize(queueCap) * queueCount 95 | if !canCreateOnDevShm(uint64(memSize), shmPath) { 96 | return nil, fmt.Errorf("err:%s path:%s, size:%d", ErrShareMemoryHadNotLeftSpace.Error(), shmPath, memSize) 97 | } 98 | f, err := os.OpenFile(shmPath, os.O_CREATE|os.O_RDWR, os.ModePerm) 99 | if err != nil { 100 | return nil, err 101 | } 102 | defer f.Close() 103 | 104 | if err := f.Truncate(int64(memSize)); err != nil { 105 | return nil, fmt.Errorf("truncate share memory failed,%s", err.Error()) 106 | } 107 | mem, err := syscall.Mmap(int(f.Fd()), 0, memSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) 108 | if err != nil { 109 | 
return nil, err 110 | } 111 | for i := 0; i < len(mem); i++ { 112 | mem[i] = 0 113 | } 114 | return &queueManager{ 115 | sendQueue: createQueueFromBytes(mem[:memSize/2], queueCap), 116 | recvQueue: createQueueFromBytes(mem[memSize/2:], queueCap), 117 | mem: mem, 118 | path: shmPath, 119 | }, nil 120 | } 121 | 122 | func mappingQueueManagerMemfd(queuePathName string, memFd int) (*queueManager, error) { 123 | var fileInfo syscall.Stat_t 124 | if err := syscall.Fstat(memFd, &fileInfo); err != nil { 125 | return nil, err 126 | } 127 | 128 | mappingSize := int(fileInfo.Size) 129 | //a queueManager have two queue, a queue's head and tail should align to 8 byte boundary 130 | if isArmArch() && mappingSize%16 != 0 { 131 | return nil, fmt.Errorf("the memory size of queue should be a multiple of 16") 132 | } 133 | mem, err := syscall.Mmap(memFd, 0, mappingSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) 134 | if err != nil { 135 | return nil, err 136 | } 137 | return &queueManager{ 138 | sendQueue: mappingQueueFromBytes(mem[mappingSize/2:]), 139 | recvQueue: mappingQueueFromBytes(mem[:mappingSize/2]), 140 | mem: mem, 141 | path: queuePathName, 142 | memFd: memFd, 143 | mmapMapType: MemMapTypeMemFd, 144 | }, nil 145 | } 146 | 147 | func mappingQueueManager(shmPath string) (*queueManager, error) { 148 | f, err := os.OpenFile(shmPath, os.O_RDWR, os.ModePerm) 149 | if err != nil { 150 | return nil, err 151 | } 152 | defer f.Close() 153 | fileInfo, err := f.Stat() 154 | if err != nil { 155 | return nil, err 156 | } 157 | mappingSize := int(fileInfo.Size()) 158 | 159 | //a queueManager have two queue, a queue's head and tail should align to 8 byte boundary 160 | if isArmArch() && mappingSize%16 != 0 { 161 | return nil, fmt.Errorf("the memory size of queue should be a multiple of 16") 162 | } 163 | mem, err := syscall.Mmap(int(f.Fd()), 0, mappingSize, syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) 164 | if err != nil { 165 | return nil, err 166 | } 167 | 
return &queueManager{ 168 | sendQueue: mappingQueueFromBytes(mem[mappingSize/2:]), 169 | recvQueue: mappingQueueFromBytes(mem[:mappingSize/2]), 170 | mem: mem, 171 | path: shmPath, 172 | }, nil 173 | } 174 | 175 | func countQueueMemSize(queueCap uint32) int { 176 | return queueHeaderLength + queueElementLen*int(queueCap) 177 | } 178 | 179 | func createQueueFromBytes(data []byte, cap uint32) *queue { 180 | *(*uint32)(unsafe.Pointer(&data[0])) = cap 181 | q := mappingQueueFromBytes(data) 182 | *q.head = 0 183 | *q.tail = 0 184 | *q.workingFlag = 0 185 | return q 186 | } 187 | 188 | func mappingQueueFromBytes(data []byte) *queue { 189 | cap := *(*uint32)(unsafe.Pointer(&data[0])) 190 | queueStartOffset := queueHeaderLength 191 | queueEndOffset := queueHeaderLength + cap*queueElementLen 192 | if isArmArch() { 193 | // align 8 byte boundary for head and tail 194 | return &queue{ 195 | cap: int64(cap), 196 | workingFlag: (*uint32)(unsafe.Pointer(&data[4])), 197 | head: (*int64)(unsafe.Pointer(&data[8])), 198 | tail: (*int64)(unsafe.Pointer(&data[16])), 199 | queueBytesOnMemory: data[queueStartOffset:queueEndOffset], 200 | } 201 | } 202 | return &queue{ 203 | cap: int64(cap), 204 | head: (*int64)(unsafe.Pointer(&data[4])), 205 | tail: (*int64)(unsafe.Pointer(&data[12])), 206 | workingFlag: (*uint32)(unsafe.Pointer(&data[20])), 207 | queueBytesOnMemory: data[queueStartOffset:queueEndOffset], 208 | } 209 | } 210 | 211 | //cap prefer equals 2^n 212 | func createQueue(cap uint32) *queue { 213 | return createQueueFromBytes(make([]byte, queueHeaderLength+int(cap*queueElementLen)), cap) 214 | } 215 | 216 | func (q *queueManager) unmap() { 217 | if err := syscall.Munmap(q.mem); err != nil { 218 | internalLogger.warnf("queueManager unmap error:" + err.Error()) 219 | } 220 | if q.mmapMapType == MemMapTypeDevShmFile { 221 | if err := os.Remove(q.path); err != nil { 222 | internalLogger.warnf("queueManager remove file:%s failed, error=%s", q.path, err.Error()) 223 | } else { 224 | 
internalLogger.infof("queueManager remove file:%s", q.path) 225 | } 226 | } else { 227 | if err := syscall.Close(q.memFd); err != nil { 228 | internalLogger.warnf("queueManager close queue fd:%d, error:%s", q.memFd, err.Error()) 229 | } else { 230 | internalLogger.infof("queueManager close queue fd:%d", q.memFd) 231 | } 232 | } 233 | } 234 | 235 | func (q *queue) isFull() bool { 236 | return q.size() == q.cap 237 | } 238 | 239 | func (q *queue) isEmpty() bool { 240 | return q.size() == 0 241 | } 242 | 243 | func (q *queue) size() int64 { 244 | return atomic.LoadInt64(q.tail) - atomic.LoadInt64(q.head) 245 | } 246 | 247 | func (q *queue) pop() (e queueElement, err error) { 248 | //atomic ensure the data that peer write to share memory could be see. 249 | head := atomic.LoadInt64(q.head) 250 | if head >= atomic.LoadInt64(q.tail) { 251 | err = errQueueEmpty 252 | return 253 | } 254 | queueOffset := (head % q.cap) * queueElementLen 255 | e.seqID = *(*uint32)(unsafe.Pointer(&q.queueBytesOnMemory[queueOffset])) 256 | e.offsetInShmBuf = *(*uint32)(unsafe.Pointer(&q.queueBytesOnMemory[queueOffset+4])) 257 | e.status = *(*uint32)(unsafe.Pointer(&q.queueBytesOnMemory[queueOffset+8])) 258 | atomic.AddInt64(q.head, 1) 259 | return 260 | } 261 | 262 | func (q *queue) put(e queueElement) error { 263 | //ensure that increasing q.tail and writing queueElement are both atomic. 264 | //because if firstly increase q.tail, the peer will think that the queue have new element and will consume it. 265 | //but at this moment, the new element hadn't finished writing, the peer will get a old element. 
266 | q.Lock() 267 | tail := atomic.LoadInt64(q.tail) 268 | if tail-atomic.LoadInt64(q.head) >= q.cap { 269 | q.Unlock() 270 | return ErrQueueFull 271 | } 272 | queueOffset := (tail % q.cap) * queueElementLen 273 | *(*uint32)(unsafe.Pointer(&q.queueBytesOnMemory[queueOffset])) = e.seqID 274 | *(*uint32)(unsafe.Pointer(&q.queueBytesOnMemory[queueOffset+4])) = e.offsetInShmBuf 275 | *(*uint32)(unsafe.Pointer(&q.queueBytesOnMemory[queueOffset+8])) = e.status 276 | atomic.AddInt64(q.tail, 1) 277 | q.Unlock() 278 | return nil 279 | } 280 | 281 | func (q *queue) consumerIsWorking() bool { 282 | return (atomic.LoadUint32(q.workingFlag)) > 0 283 | } 284 | 285 | func (q *queue) markWorking() bool { 286 | return atomic.CompareAndSwapUint32(q.workingFlag, 0, 1) 287 | } 288 | 289 | func (q *queue) markNotWorking() bool { 290 | atomic.StoreUint32(q.workingFlag, 0) 291 | if q.size() == 0 { 292 | return true 293 | } 294 | atomic.StoreUint32(q.workingFlag, 1) 295 | return false 296 | } 297 | -------------------------------------------------------------------------------- /queue_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "fmt" 21 | _ "net/http/pprof" 22 | "sync" 23 | "testing" 24 | "time" 25 | 26 | "github.com/stretchr/testify/assert" 27 | ) 28 | 29 | var ( 30 | queueCap = 1000000 31 | parallelism = 100 32 | ) 33 | 34 | func TestQueueManager_CreateMapping(t *testing.T) { 35 | path := "/tmp/ipc.queue" 36 | qm1, err := createQueueManager(path, 8192) 37 | assert.Equal(t, nil, err) 38 | qm2, err := mappingQueueManager(path) 39 | assert.Equal(t, nil, err) 40 | 41 | assert.Equal(t, nil, qm1.sendQueue.put(queueElement{})) 42 | _, err = qm2.recvQueue.pop() 43 | assert.Equal(t, nil, err) 44 | 45 | assert.Equal(t, nil, qm2.sendQueue.put(queueElement{})) 46 | _, err = qm1.recvQueue.pop() 47 | assert.Equal(t, nil, err) 48 | 49 | qm1.unmap() 50 | } 51 | 52 | func TestQueueOperate(t *testing.T) { 53 | q := createQueue(defaultQueueCap) 54 | 55 | fmt.Println("-----------test queue operate ----------------") 56 | assert.Equal(t, true, q.isEmpty(), "queue should be empty") 57 | assert.Equal(t, false, q.isFull(), "queue is not full") 58 | assert.Equal(t, int64(0), q.size(), "queue size should be 0") 59 | 60 | putCount, popCount := 0, 0 61 | var err error 62 | for i := 0; i < defaultQueueCap; i++ { 63 | err = q.put(queueElement{seqID: uint32(i), offsetInShmBuf: uint32(i), status: uint32(i)}) 64 | assert.Equal(t, nil, err) 65 | putCount++ 66 | } 67 | err = q.put(queueElement{1, 1, 1}) 68 | assert.Equal(t, ErrQueueFull, err) 69 | assert.Equal(t, true, q.isFull(), "queue should be full") 70 | assert.Equal(t, false, q.isEmpty(), "queue is not empty") 71 | assert.Equal(t, int64(putCount), q.size(), "queue size") 72 | 73 | for i := 0; i < defaultQueueCap; i++ { 74 | e, err := q.pop() 75 | assert.Equal(t, nil, err) 76 | popCount++ 77 | assert.Equal(t, i, int(e.seqID), "queue pop verify seqID") 78 | assert.Equal(t, i, int(e.offsetInShmBuf), "queue pop verify offset") 79 | assert.Equal(t, i, int(e.status), "queue pop verify offset") 80 | } 81 | _, err = 
q.pop() 82 | assert.Equal(t, errQueueEmpty, err) 83 | assert.Equal(t, false, q.isFull(), "queue is not full") 84 | assert.Equal(t, true, q.isEmpty(), "queue should be empty") 85 | assert.Equal(t, int64(0), q.size(), "queue size") 86 | 87 | fmt.Println("-----------test queue status ----------------") 88 | assert.Equal(t, false, q.consumerIsWorking(), "consumer should be not working") 89 | q.markWorking() 90 | assert.Equal(t, true, q.consumerIsWorking(), "consumer should be working") 91 | q.markNotWorking() 92 | assert.Equal(t, false, q.consumerIsWorking(), "consumer should be not working") 93 | 94 | _ = q.put(queueElement{1, 1, 1}) 95 | q.markNotWorking() 96 | assert.Equal(t, true, q.consumerIsWorking(), "consumer should be working") 97 | } 98 | 99 | func TestQueueMultiProducerAndSingleConsumer(t *testing.T) { 100 | fmt.Println("-----------test queue multi-producer single consumer ----------------") 101 | q := createQueue(uint32(queueCap)) 102 | var wg sync.WaitGroup 103 | popCount := 0 104 | for i := 0; i < parallelism; i++ { 105 | //producer 106 | go func() { 107 | for k := 0; k < queueCap/parallelism; k++ { 108 | wg.Add(1) 109 | if err := q.put(queueElement{seqID: 1, offsetInShmBuf: 1, status: 1}); err != nil { 110 | panic(err) 111 | } 112 | } 113 | }() 114 | } 115 | 116 | //consumer 117 | for popCount != queueCap { 118 | _, err := q.pop() 119 | if err == nil { 120 | wg.Done() 121 | popCount++ 122 | } else { 123 | time.Sleep(time.Microsecond) 124 | } 125 | } 126 | wg.Wait() 127 | } 128 | 129 | func BenchmarkQueuePut(b *testing.B) { 130 | q := createQueue(uint32(b.N)) 131 | b.ResetTimer() 132 | b.ReportAllocs() 133 | for i := 0; i < b.N; i++ { 134 | _ = q.put(queueElement{seqID: uint32(i), offsetInShmBuf: uint32(i)}) 135 | } 136 | } 137 | 138 | func BenchmarkQueuePop(b *testing.B) { 139 | q := createQueue(uint32(b.N)) 140 | for i := 0; i < b.N; i++ { 141 | _ = q.put(queueElement{seqID: uint32(i), offsetInShmBuf: uint32(i)}) 142 | } 143 | b.ResetTimer() 144 | 
b.ReportAllocs() 145 | for i := 0; i < b.N; i++ { 146 | q.pop() 147 | } 148 | } 149 | 150 | func BenchmarkQueueMultiPut(b *testing.B) { 151 | b.SetParallelism(50) 152 | q := createQueue(uint32(b.N)) 153 | b.ResetTimer() 154 | b.ReportAllocs() 155 | b.RunParallel(func(pb *testing.PB) { 156 | c := 0 157 | for pb.Next() { 158 | c++ 159 | _ = q.put(queueElement{seqID: uint32(c), offsetInShmBuf: uint32(c)}) 160 | 161 | } 162 | }) 163 | } 164 | 165 | func BenchmarkQueueMultiPop(b *testing.B) { 166 | b.SetParallelism(50) 167 | q := createQueue(uint32(b.N)) 168 | b.ResetTimer() 169 | b.ReportAllocs() 170 | for i := 0; i < b.N; i++ { 171 | _ = q.put(queueElement{seqID: uint32(i), offsetInShmBuf: uint32(i)}) 172 | } 173 | b.RunParallel(func(pb *testing.PB) { 174 | 175 | for pb.Next() { 176 | q.pop() 177 | } 178 | }) 179 | } 180 | -------------------------------------------------------------------------------- /stats.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package shmipc 18 | 19 | // Monitor could emit some metrics with periodically 20 | type Monitor interface { 21 | // OnEmitSessionMetrics was called by Session with periodically. 
22 | OnEmitSessionMetrics(PerformanceMetrics, StabilityMetrics, ShareMemoryMetrics, *Session) 23 | // flush metrics 24 | Flush() error 25 | } 26 | 27 | type stats struct { 28 | allocShmErrorCount uint64 29 | fallbackWriteCount uint64 30 | fallbackReadCount uint64 31 | eventConnErrorCount uint64 32 | queueFullErrorCount uint64 33 | recvPollingEventCount uint64 34 | sendPollingEventCount uint64 35 | outFlowBytes uint64 36 | inFlowBytes uint64 37 | hotRestartSuccessCount uint64 38 | hotRestartErrorCount uint64 39 | } 40 | 41 | //PerformanceMetrics is the metrics about performance 42 | type PerformanceMetrics struct { 43 | ReceiveSyncEventCount uint64 //the SyncEvent count that session had received 44 | SendSyncEventCount uint64 //the SyncEvent count that session had sent 45 | OutFlowBytes uint64 //the out flow in bytes that session had sent 46 | InFlowBytes uint64 //the in flow in bytes that session had receive 47 | SendQueueCount uint64 //the pending count of send queue 48 | ReceiveQueueCount uint64 //the pending count of receive queue 49 | } 50 | 51 | //StabilityMetrics is the metrics about stability 52 | type StabilityMetrics struct { 53 | AllocShmErrorCount uint64 //the error count of allocating share memory 54 | FallbackWriteCount uint64 //the count of the fallback data write to unix/tcp connection 55 | FallbackReadCount uint64 //the error count of receiving fallback data from unix/tcp connection every period 56 | 57 | //the error count of unix/tcp connection 58 | //which usually happened in that the peer's process exit(crashed or other reason) 59 | EventConnErrorCount uint64 60 | 61 | //the error count due to the IO-Queue(SendQueue or ReceiveQueue) is full 62 | //which usually happened in that the peer was busy 63 | QueueFullErrorCount uint64 64 | 65 | //current all active stream count 66 | ActiveStreamCount uint64 67 | 68 | //the successful count of hot restart 69 | HotRestartSuccessCount uint64 70 | //the failed count of hot restart 71 | HotRestartErrorCount 
uint64 72 | } 73 | 74 | //ShareMemoryMetrics is the metrics about share memory's status 75 | type ShareMemoryMetrics struct { 76 | CapacityOfShareMemoryInBytes uint64 //capacity of all share memory 77 | AllInUsedShareMemoryInBytes uint64 //current in-used share memory 78 | } 79 | -------------------------------------------------------------------------------- /sys_memfd_create_bsd.go: -------------------------------------------------------------------------------- 1 | // +build !linux 2 | 3 | /* 4 | * Copyright 2023 CloudWeGo Authors 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | 19 | package shmipc 20 | 21 | import ( 22 | "fmt" 23 | "runtime" 24 | ) 25 | 26 | // MemfdCreate used to create memfd (only linux support) 27 | func MemfdCreate(name string, flags int) (fd int, err error) { 28 | return 0, fmt.Errorf("%s unsupported MemfdCreate system call", runtime.GOOS) 29 | } 30 | -------------------------------------------------------------------------------- /sys_memfd_create_linux.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "golang.org/x/sys/unix" 21 | ) 22 | 23 | // linux 3.17+ provided 24 | func MemfdCreate(name string, flags int) (fd int, err error) { 25 | memFd, err := unix.MemfdCreate(memfdCreateName+name, 0) 26 | if err != nil { 27 | return 0, err 28 | } 29 | 30 | return memFd, nil 31 | } 32 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "os" 21 | "reflect" 22 | "runtime" 23 | "strings" 24 | "sync" 25 | "time" 26 | "unsafe" 27 | 28 | "github.com/shirou/gopsutil/v3/disk" 29 | ) 30 | 31 | var ( 32 | timerPool = &sync.Pool{ 33 | New: func() interface{} { 34 | timer := time.NewTimer(time.Hour * 1e6) 35 | timer.Stop() 36 | return timer 37 | }, 38 | } 39 | ) 40 | 41 | // asyncSendErr is used to try an async send of an error 42 | func asyncSendErr(ch chan error, err error) { 43 | if ch == nil { 44 | return 45 | } 46 | select { 47 | case ch <- err: 48 | default: 49 | } 50 | } 51 | 52 | // asyncNotify is used to signal a waiting goroutine 53 | func asyncNotify(ch chan struct{}) { 54 | select { 55 | case ch <- struct{}{}: 56 | default: 57 | } 58 | } 59 | 60 | // min computes the minimum of two values 61 | func min(a, b uint32) uint32 { 62 | if a < b { 63 | return a 64 | } 65 | return b 66 | } 67 | 68 | func minInt(a, b int) int { 69 | if a < b { 70 | return a 71 | } 72 | return b 73 | } 74 | 75 | func maxInt(a, b int) int { 76 | if a < b { 77 | return b 78 | } 79 | return a 80 | } 81 | 82 | func string2bytesZeroCopy(s string) []byte { 83 | stringHeader := (*reflect.StringHeader)(unsafe.Pointer(&s)) 84 | 85 | bh := reflect.SliceHeader{ 86 | Data: stringHeader.Data, 87 | Len: stringHeader.Len, 88 | Cap: stringHeader.Len, 89 | } 90 | 91 | return *(*[]byte)(unsafe.Pointer(&bh)) 92 | } 93 | 94 | func pathExists(path string) bool { 95 | _, err := os.Stat(path) 96 | if err != nil { 97 | return os.IsExist(err) 98 | } 99 | return true 100 | } 101 | 102 | //In Linux OS, there is a limitation which is the capacity of the tmpfs (which usually on the directory /dev/shm). 103 | //if we do mmap on /dev/shm/xxx and the free memory of the tmpfs is not enough, mmap have no any error. 104 | //but when program is running, it maybe crashed due to the bus error. 
105 | func canCreateOnDevShm(size uint64, path string) bool { 106 | if runtime.GOOS == "linux" && strings.Contains(path, "/dev/shm") { 107 | stat, err := disk.Usage("/dev/shm") 108 | if err != nil { 109 | internalLogger.warnf("could read /dev/shm free size, canCreateOnDevShm default return true") 110 | return false 111 | } 112 | return stat.Free >= size 113 | } 114 | return true 115 | } 116 | 117 | // delete only existing files 118 | func safeRemoveUdsFile(filename string) bool { 119 | fileInfo, err := os.Stat(filename) 120 | if err != nil { 121 | internalLogger.warnf("%s Stat error %+v", filename, err) 122 | return false 123 | } 124 | 125 | if fileInfo.IsDir() { 126 | return false 127 | } 128 | 129 | if err := os.Remove(filename); err != nil { 130 | internalLogger.warnf("%s Remove error %+v", filename, err) 131 | return false 132 | } 133 | 134 | return true 135 | } 136 | 137 | func isArmArch() bool { 138 | return runtime.GOARCH == "arm" || runtime.GOARCH == "arm64" 139 | } 140 | -------------------------------------------------------------------------------- /util_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2023 CloudWeGo Authors 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package shmipc 18 | 19 | import ( 20 | "math" 21 | "os" 22 | "runtime" 23 | "testing" 24 | 25 | "github.com/shirou/gopsutil/v3/disk" 26 | "github.com/stretchr/testify/assert" 27 | ) 28 | 29 | func TestAsyncSendErr(t *testing.T) { 30 | ch := make(chan error) 31 | asyncSendErr(ch, ErrTimeout) 32 | select { 33 | case <-ch: 34 | t.Fatalf("should not get") 35 | default: 36 | } 37 | 38 | ch = make(chan error, 1) 39 | asyncSendErr(ch, ErrTimeout) 40 | select { 41 | case <-ch: 42 | default: 43 | t.Fatalf("should get") 44 | } 45 | } 46 | 47 | func TestAsyncNotify(t *testing.T) { 48 | ch := make(chan struct{}) 49 | asyncNotify(ch) 50 | select { 51 | case <-ch: 52 | t.Fatalf("should not get") 53 | default: 54 | } 55 | 56 | ch = make(chan struct{}, 1) 57 | asyncNotify(ch) 58 | select { 59 | case <-ch: 60 | default: 61 | t.Fatalf("should get") 62 | } 63 | } 64 | 65 | func TestMin(t *testing.T) { 66 | if min(1, 2) != 1 { 67 | t.Fatalf("bad") 68 | } 69 | if min(2, 1) != 1 { 70 | t.Fatalf("bad") 71 | } 72 | } 73 | 74 | func TestMinInt(t *testing.T) { 75 | if minInt(1, 2) != 1 { 76 | t.Fatalf("bad") 77 | } 78 | if minInt(2, 1) != 1 { 79 | t.Fatalf("bad") 80 | } 81 | } 82 | 83 | func TestMaxInt(t *testing.T) { 84 | if maxInt(1, 2) != 2 { 85 | t.Fatalf("bad") 86 | } 87 | if maxInt(2, 1) != 2 { 88 | t.Fatalf("bad") 89 | } 90 | } 91 | 92 | func TestPathExists(t *testing.T) { 93 | path := "test_path_exists" 94 | f, err := os.OpenFile(path, os.O_CREATE, os.ModePerm) 95 | if err != nil { 96 | t.Fatal(err) 97 | } 98 | f.Close() 99 | assert.Equal(t, true, pathExists(path)) 100 | os.Remove(path) 101 | } 102 | 103 | func TestCanCreateOnDevShm(t *testing.T) { 104 | switch runtime.GOOS { 105 | case "linux": 106 | //just on /dev/shm, other always return true 107 | assert.Equal(t, true, canCreateOnDevShm(math.MaxUint64, "sdffafds")) 108 | stat, err := disk.Usage("/dev/shm") 109 | if err != nil { 110 | t.Fatal(err) 111 | } 112 | assert.Equal(t, true, canCreateOnDevShm(stat.Free, 
"/dev/shm/xxx")) 113 | assert.Equal(t, false, canCreateOnDevShm(stat.Free+1, "/dev/shm/yyy")) 114 | case "darwin": 115 | //always return true 116 | assert.Equal(t, true, canCreateOnDevShm(33333, "sdffafds")) 117 | } 118 | } 119 | 120 | func TestSafeRemoveUdsFile(t *testing.T) { 121 | path := "test_path_remove" 122 | f, err := os.OpenFile(path, os.O_CREATE, os.ModePerm) 123 | if err != nil { 124 | t.Fatal(err) 125 | } 126 | f.Close() 127 | 128 | assert.Equal(t, true, safeRemoveUdsFile(path)) 129 | assert.Equal(t, false, safeRemoveUdsFile("not_existing_file")) 130 | } 131 | --------------------------------------------------------------------------------