├── .gitattributes ├── .github └── workflows │ ├── autobahn.yml │ ├── build_bsd.yml │ ├── build_linux.yml │ ├── build_windows.yml │ ├── close_inactive_issues.yml │ ├── codeql-analysis.yml │ └── golangci-lint.yml ├── .gitignore ├── .golangci.yml ├── LICENSE ├── Makefile ├── README.md ├── autobahn ├── .gitignore ├── config │ └── fuzzingclient.json ├── reporter │ └── reporter.go ├── script │ └── run.sh └── server │ └── server.go ├── conn.go ├── conn_std.go ├── conn_unix.go ├── engine.go ├── engine_std.go ├── engine_unix.go ├── error.go ├── extension └── tls │ └── tls.go ├── go.mod ├── go.sum ├── lmux ├── lmux.go └── lmux_test.go ├── logging ├── log.go └── log_test.go ├── mempool ├── aligned_allocator.go ├── allocator.go ├── debugger.go ├── mempool.go ├── mempool_test.go ├── std_allocator.go └── trace_debugger.go ├── nbhttp ├── body.go ├── body_test.go ├── client.go ├── client_conn.go ├── convert.go ├── convert_test.go ├── engine.go ├── error.go ├── parser.go ├── parser_test.go ├── processor.go ├── response.go ├── server.go ├── state.go ├── table.go └── websocket │ ├── compression.go │ ├── conn.go │ ├── dialer.go │ ├── error.go │ ├── upgrader.go │ └── upgrader_test.go ├── nbio_test.go ├── net_unix.go ├── poller_epoll.go ├── poller_kqueue.go ├── poller_std.go ├── protocol_stack.go ├── sendfile_std.go ├── sendfile_unix.go ├── taskpool ├── iotaskpool.go ├── taskpool.go └── taskpool_test.go ├── timer ├── timer.go └── timer_test.go ├── tools └── norace │ └── norace.go ├── writev_bsd.go └── writev_linux.go /.gitattributes: -------------------------------------------------------------------------------- 1 | *.go text eol=lf -------------------------------------------------------------------------------- /.github/workflows/autobahn.yml: -------------------------------------------------------------------------------- 1 | name: Autobahn 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - dev 8 | pull_request: 9 | branches: 10 | - master 11 | - dev 12 | 13 | jobs: 14 | 
Autobahn: 15 | strategy: 16 | matrix: 17 | os: [ ubuntu-latest ] 18 | go: [ 1.18.x ] 19 | runs-on: ${{ matrix.os }} 20 | steps: 21 | - name: Checkout 22 | uses: actions/checkout@v2 23 | - name: Setup Go 24 | uses: actions/setup-go@v2 25 | with: 26 | go-version: ${{ matrix.go }} 27 | - name: Autobahn Test 28 | env: 29 | CRYPTOGRAPHY_ALLOW_OPENSSL_102: yes 30 | run: | 31 | chmod +x ./autobahn/script/run.sh & ./autobahn/script/run.sh 32 | - name: Autobahn Report Artifact 33 | if: >- 34 | startsWith(matrix.os, 'ubuntu') 35 | uses: actions/upload-artifact@v4 36 | 37 | with: 38 | name: autobahn report ${{ matrix.go }} ${{ matrix.os }} 39 | path: autobahn/report 40 | retention-days: 7 41 | -------------------------------------------------------------------------------- /.github/workflows/build_bsd.yml: -------------------------------------------------------------------------------- 1 | name: Build-BSD 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - dev 8 | pull_request: 9 | branches: 10 | - master 11 | - dev 12 | 13 | env: 14 | GO111MODULE: off 15 | 16 | jobs: 17 | Build-BSD: 18 | name: Build-BSD 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | go: [1.18.x] 23 | os: [macos-latest] 24 | runs-on: ${{ matrix.os}} 25 | steps: 26 | - name: install golang 27 | uses: actions/setup-go@v2 28 | with: 29 | go-version: ${{ matrix.go }} 30 | - name: checkout 31 | uses: actions/checkout@v2 32 | - name: go env 33 | run: | 34 | printf "$(go version)\n" 35 | printf "\n\ngo environment:\n\n" 36 | go get -u github.com/lesismal/nbio 37 | ulimit -n 30000 38 | go env 39 | echo "short_sha=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT 40 | - name: go test 41 | run: go test -covermode=atomic -timeout 60s -coverprofile="./coverage" 42 | -------------------------------------------------------------------------------- /.github/workflows/build_linux.yml: -------------------------------------------------------------------------------- 1 | name: Build-Linux 2 | 3 | on: 4 | push: 5 | 
branches: 6 | - master 7 | - dev 8 | pull_request: 9 | branches: 10 | - master 11 | - dev 12 | 13 | env: 14 | GO111MODULE: off 15 | 16 | jobs: 17 | Build-Linux: 18 | name: Build-Linux 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | go: [1.18.x] 23 | os: [ubuntu-latest] 24 | runs-on: ${{ matrix.os}} 25 | steps: 26 | - name: install golang 27 | uses: actions/setup-go@v2 28 | with: 29 | go-version: ${{ matrix.go }} 30 | - name: checkout 31 | uses: actions/checkout@v2 32 | - name: go env 33 | run: | 34 | printf "$(go version)\n" 35 | printf "\n\ngo environment:\n\n" 36 | go get -u github.com/lesismal/nbio 37 | ulimit -n 30000 38 | go env 39 | echo "short_sha=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT 40 | - name: go test 41 | run: go test -covermode=atomic -timeout 60s -coverprofile="./coverage" 42 | -------------------------------------------------------------------------------- /.github/workflows/build_windows.yml: -------------------------------------------------------------------------------- 1 | name: Build-Windows 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - dev 8 | pull_request: 9 | branches: 10 | - master 11 | - dev 12 | 13 | env: 14 | GO111MODULE: off 15 | 16 | jobs: 17 | Build-Windows: 18 | name: Build-Windows 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | go: [1.18.x] 23 | os: [windows-latest] 24 | runs-on: ${{ matrix.os}} 25 | steps: 26 | - name: install golang 27 | uses: actions/setup-go@v2 28 | with: 29 | go-version: ${{ matrix.go }} 30 | - name: checkout 31 | uses: actions/checkout@v2 32 | - name: go env 33 | run: | 34 | printf "$(go version)\n" 35 | printf "\n\ngo environment:\n\n" 36 | go get -u github.com/lesismal/nbio/... 
37 | go env 38 | echo "short_sha=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT 39 | - name: go test 40 | run: go test -covermode=atomic -timeout 60s -coverprofile="./coverage" 41 | - name: update code coverage report 42 | uses: codecov/codecov-action@v1.2.1 43 | with: 44 | files: "./coverage" 45 | flags: unittests 46 | verbose: true 47 | name: codecov-nbio -------------------------------------------------------------------------------- /.github/workflows/close_inactive_issues.yml: -------------------------------------------------------------------------------- 1 | name: Close-Inactive-Issues 2 | on: 3 | schedule: 4 | - cron: "30 1 * * *" 5 | 6 | jobs: 7 | Close-Inactive-Issues: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | issues: write 11 | pull-requests: write 12 | steps: 13 | - uses: actions/stale@v3 14 | with: 15 | days-before-issue-stale: 30 16 | days-before-issue-close: 14 17 | stale-issue-label: "stale" 18 | stale-issue-message: "This issue is stale because it has been open for 30 days with no activity." 19 | close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." 20 | days-before-pr-stale: -1 21 | days-before-pr-close: -1 22 | repo-token: ${{ secrets.GITHUB_TOKEN }} 23 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 
11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: 17 | - master 18 | - dev 19 | pull_request: 20 | # The branches below must be a subset of the branches above 21 | branches: 22 | - master 23 | - dev 24 | schedule: 25 | - cron: '41 10 * * 3' 26 | 27 | jobs: 28 | CodeQL: 29 | name: CodeQL 30 | runs-on: ubuntu-latest 31 | permissions: 32 | actions: read 33 | contents: read 34 | security-events: write 35 | 36 | strategy: 37 | fail-fast: false 38 | matrix: 39 | language: [ 'go' ] 40 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ] 41 | # Learn more: 42 | # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed 43 | 44 | steps: 45 | - name: Checkout 46 | uses: actions/checkout@v2 47 | 48 | # Initializes the CodeQL tools for scanning. 49 | - name: Init 50 | uses: github/codeql-action/init@v2 51 | with: 52 | languages: ${{ matrix.language }} 53 | # If you wish to specify custom queries, you can do so here or in a config file. 54 | # By default, queries listed here will override any specified in a config file. 55 | # Prefix the list here with "+" to use these queries and those in the config file. 56 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 57 | 58 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 59 | # If this step fails, then you should remove it and run the build manually (see below) 60 | - name: Autobuild 61 | uses: github/codeql-action/autobuild@v2 62 | 63 | # ℹ️ Command-line programs to run using the OS shell. 
64 | # 📚 https://git.io/JvXDl 65 | 66 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 67 | # and modify them (or add more) to build your code if your project 68 | # uses a compiled language 69 | 70 | #- run: | 71 | # make bootstrap 72 | # make release 73 | 74 | - name: Analysis 75 | uses: github/codeql-action/analyze@v2 76 | -------------------------------------------------------------------------------- /.github/workflows/golangci-lint.yml: -------------------------------------------------------------------------------- 1 | name: golangci-lint 2 | on: 3 | push: 4 | tags: 5 | - v* 6 | branches: 7 | - master 8 | - dev 9 | pull_request: 10 | jobs: 11 | golangci: 12 | strategy: 13 | matrix: 14 | go-version: [1.18.x] 15 | os: [ubuntu-latest] 16 | name: lint 17 | runs-on: ${{ matrix.os }} 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: golangci-lint 21 | uses: golangci/golangci-lint-action@v6.1.1 22 | with: 23 | # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version. 24 | version: v1.61.0 25 | # Optional: working directory, useful for monorepos 26 | # working-directory: somedir 27 | 28 | # Optional: golangci-lint command line arguments. 29 | # args: --issues-exit-code=0 30 | 31 | # Optional: show only new issues if it's a pull request. The default value is `false`. 
32 | # only-new-issues: true 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | bin -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 lesismal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | GO?=go 2 | PACKAGE_DIRS= $(shell $(GO) list -f '{{ .Dir }}' ./...|grep -v 'lesismal/nbio/examples') 3 | PACKAGES= $(shell $(GO) list ./...|grep -v 'lesismal/nbio/examples') 4 | .PHONY: all vet lint 5 | 6 | all: vet lint test 7 | 8 | vet: 9 | $(GO) vet $(PACKAGES) 10 | 11 | lint: 12 | golangci-lint run $(PACKAGE_DIRS) 13 | 14 | test: 15 | $(GO) test -v $(PACKAGES) 16 | 17 | clean: 18 | rm -rf ./autobahn/bin/* 19 | rm -rf ./autobahn/report/* 20 | 21 | autobahn: 22 | chmod +x ./autobahn/script/run.sh & ./autobahn/script/run.sh 23 | 24 | .PHONY: all vet lint test clean autobahn 25 | 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NBIO - NON-BLOCKING IO 2 | 3 | 4 | 5 | 6 | [![Mentioned in Awesome Go][3]][4] [![MIT licensed][5]][6] [![Go Version][7]][8] [![Build Status][9]][10] [![Go Report Card][11]][12] 7 | 8 | [1]: https://img.shields.io/badge/join-us%20on%20slack-gray.svg?longCache=true&logo=slack&colorB=green 9 | [2]: https://join.slack.com/t/arpcnbio/shared_invite/zt-vh3g1z2v-qqoDp1hQ45fJZqwPrSz4~Q 10 | [3]: https://awesome.re/mentioned-badge-flat.svg 11 | [4]: https://github.com/avelino/awesome-go#networking 12 | [5]: https://img.shields.io/badge/license-MIT-blue.svg 13 | [6]: LICENSE 14 | [7]: https://img.shields.io/badge/go-%3E%3D1.16-30dff3?style=flat-square&logo=go 15 | [8]: https://github.com/lesismal/nbio 16 | [9]: https://img.shields.io/github/actions/workflow/status/lesismal/nbio/autobahn.yml?branch=master&style=flat-square&logo=github-actions 17 | [10]: https://github.com/lesismal/nbio/actions?query=workflow%3autobahn 18 | [11]: https://goreportcard.com/badge/github.com/lesismal/nbio 19 | [12]: 
https://goreportcard.com/report/github.com/lesismal/nbio 20 | [13]: https://codecov.io/gh/lesismal/nbio/branch/master/graph/badge.svg 21 | [14]: https://codecov.io/gh/lesismal/nbio 22 | [15]: https://godoc.org/github.com/lesismal/nbio?status.svg 23 | [16]: https://godoc.org/github.com/lesismal/nbio 24 | 25 | 26 | ## Contents 27 | 28 | - [NBIO - NON-BLOCKING IO](#nbio---non-blocking-io) 29 | - [Contents](#contents) 30 | - [Features](#features) 31 | - [Cross Platform](#cross-platform) 32 | - [Protocols Supported](#protocols-supported) 33 | - [Interfaces](#interfaces) 34 | - [Quick Start](#quick-start) 35 | - [Examples](#examples) 36 | - [TCP Echo Examples](#tcp-echo-examples) 37 | - [UDP Echo Examples](#udp-echo-examples) 38 | - [TLS Examples](#tls-examples) 39 | - [HTTP Examples](#http-examples) 40 | - [HTTPS Examples](#https-examples) 41 | - [Websocket Examples](#websocket-examples) 42 | - [Websocket TLS Examples](#websocket-tls-examples) 43 | - [Use With Other STD Based Frameworkds](#use-with-other-std-based-frameworkds) 44 | - [More Examples](#more-examples) 45 | - [1M Websocket Connections Benchmark](#1m-websocket-connections-benchmark) 46 | - [Magics For HTTP and Websocket](#magics-for-http-and-websocket) 47 | - [Different IOMod](#different-iomod) 48 | - [Using Websocket With Std Server](#using-websocket-with-std-server) 49 | - [Credits](#credits) 50 | - [Contributors](#contributors) 51 | - [Star History](#star-history) 52 | 53 | ## Features 54 | ### Cross Platform 55 | - [x] Linux: Epoll with LT/ET/ET+ONESHOT supported, LT as default 56 | - [x] BSD(MacOS): Kqueue 57 | - [x] Windows: Based on std net, for debugging only 58 | 59 | ### Protocols Supported 60 | - [x] TCP/UDP/Unix Socket supported 61 | - [x] TLS supported 62 | - [x] HTTP/HTTPS 1.x supported 63 | - [x] Websocket supported, [Passes the Autobahn Test Suite](https://lesismal.github.io/nbio/websocket/autobahn), `OnOpen/OnMessage/OnClose` order guaranteed 64 | 65 | ### Interfaces 66 | - [x] Implements a 
non-blocking net.Conn(except windows) 67 | - [x] SetDeadline/SetReadDeadline/SetWriteDeadline supported 68 | - [x] Concurrent Write/Close supported(both nbio.Conn and nbio/nbhttp/websocket.Conn) 69 | 70 | 71 | ## Quick Start 72 | 73 | ```golang 74 | package main 75 | 76 | import ( 77 | "log" 78 | 79 | "github.com/lesismal/nbio" 80 | ) 81 | 82 | func main() { 83 | engine := nbio.NewEngine(nbio.Config{ 84 | Network: "tcp",//"udp", "unix" 85 | Addrs: []string{":8888"}, 86 | MaxWriteBufferSize: 6 * 1024 * 1024, 87 | }) 88 | 89 | // hanlde new connection 90 | engine.OnOpen(func(c *nbio.Conn) { 91 | log.Println("OnOpen:", c.RemoteAddr().String()) 92 | }) 93 | // hanlde connection closed 94 | engine.OnClose(func(c *nbio.Conn, err error) { 95 | log.Println("OnClose:", c.RemoteAddr().String(), err) 96 | }) 97 | // handle data 98 | engine.OnData(func(c *nbio.Conn, data []byte) { 99 | c.Write(append([]byte{}, data...)) 100 | }) 101 | 102 | err := engine.Start() 103 | if err != nil { 104 | log.Fatalf("nbio.Start failed: %v\n", err) 105 | return 106 | } 107 | defer engine.Stop() 108 | 109 | <-make(chan int) 110 | } 111 | ``` 112 | 113 | ## Examples 114 | ### TCP Echo Examples 115 | 116 | - [echo-server](https://github.com/lesismal/nbio_examples/blob/master/echo/server/server.go) 117 | - [echo-client](https://github.com/lesismal/nbio_examples/blob/master/echo/client/client.go) 118 | 119 | ### UDP Echo Examples 120 | 121 | - [udp-server](https://github.com/lesismal/nbio-examples/blob/master/udp/server/server.go) 122 | - [udp-client](https://github.com/lesismal/nbio-examples/blob/master/udp/client/client.go) 123 | 124 | ### TLS Examples 125 | 126 | - [tls-server](https://github.com/lesismal/nbio_examples/blob/master/tls/server/server.go) 127 | - [tls-client](https://github.com/lesismal/nbio_examples/blob/master/tls/client/client.go) 128 | 129 | ### HTTP Examples 130 | 131 | - [http-server](https://github.com/lesismal/nbio_examples/blob/master/http/server/server.go) 132 | - 
[http-client](https://github.com/lesismal/nbio_examples/blob/master/http/client/client.go) 133 | 134 | ### HTTPS Examples 135 | 136 | - [http-tls_server](https://github.com/lesismal/nbio_examples/blob/master/http/server_tls/server.go) 137 | - visit: https://localhost:8888/echo 138 | 139 | ### Websocket Examples 140 | 141 | - [websocket-server](https://github.com/lesismal/nbio_examples/blob/master/websocket/server/server.go) 142 | - [websocket-client](https://github.com/lesismal/nbio_examples/blob/master/websocket/client/client.go) 143 | 144 | ### Websocket TLS Examples 145 | 146 | - [websocket-tls-server](https://github.com/lesismal/nbio_examples/blob/master/websocket_tls/server/server.go) 147 | - [websocket-tls-client](https://github.com/lesismal/nbio_examples/blob/master/websocket_tls/client/client.go) 148 | 149 | ### Use With Other STD Based Frameworkds 150 | 151 | - [echo-http-and-websocket-server](https://github.com/lesismal/nbio_examples/blob/master/http_with_other_frameworks/echo_server/echo_server.go) 152 | - [gin-http-and-websocket-server](https://github.com/lesismal/nbio_examples/blob/master/http_with_other_frameworks/gin_server/gin_server.go) 153 | - [go-chi-http-and-websocket-server](https://github.com/lesismal/nbio_examples/blob/master/http_with_other_frameworks/go-chi_server/go-chi_server.go) 154 | 155 | ### More Examples 156 | 157 | - [nbio-examples](https://github.com/lesismal/nbio-examples) 158 | 159 | 160 | 161 | ## 1M Websocket Connections Benchmark 162 | 163 | For more details: [go-websocket-benchmark](https://github.com/lesismal/go-websocket-benchmark) 164 | 165 | ```sh 166 | # lsb_release -a 167 | LSB Version: core-11.1.0ubuntu2-noarch:security-11.1.0ubuntu2-noarch 168 | Distributor ID: Ubuntu 169 | Description: Ubuntu 20.04.6 LTS 170 | Release: 20.04 171 | Codename: focal 172 | 173 | # free 174 | total used free shared buff/cache available 175 | Mem: 24969564 15656352 3422212 1880 5891000 8899604 176 | Swap: 0 0 0 177 | 178 | # cat 
/proc/cpuinfo | grep processor 179 | processor : 0 180 | processor : 1 181 | processor : 2 182 | processor : 3 183 | processor : 4 184 | processor : 5 185 | processor : 6 186 | processor : 7 187 | processor : 8 188 | processor : 9 189 | processor : 10 190 | processor : 11 191 | processor : 12 192 | processor : 13 193 | processor : 14 194 | processor : 15 195 | 196 | 197 | # taskset 198 | run nbio_nonblocking server on cpu 0-7 199 | 200 | -------------------------------------------------------------- 201 | BenchType : BenchEcho 202 | Framework : nbio_nonblocking 203 | TPS : 104713 204 | EER : 280.33 205 | Min : 56.90us 206 | Avg : 95.36ms 207 | Max : 2.29s 208 | TP50 : 62.82ms 209 | TP75 : 65.38ms 210 | TP90 : 89.38ms 211 | TP95 : 409.55ms 212 | TP99 : 637.95ms 213 | Used : 47.75s 214 | Total : 5000000 215 | Success : 5000000 216 | Failed : 0 217 | Conns : 1000000 218 | Concurrency: 10000 219 | Payload : 1024 220 | CPU Min : 0.00% 221 | CPU Avg : 373.53% 222 | CPU Max : 602.33% 223 | MEM Min : 978.70M 224 | MEM Avg : 979.88M 225 | MEM Max : 981.14M 226 | -------------------------------------------------------------- 227 | ``` 228 | 229 | 230 | ## Magics For HTTP and Websocket 231 | 232 | ### Different IOMod 233 | 234 | | IOMod | Remarks | 235 | | ---------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | 236 | | IOModNonBlocking | There's no difference between this IOMod and the old version with no IOMod. All the connections will be handled by poller. | 237 | | IOModBlocking | All the connections will be handled by at least one goroutine, for websocket, we can set Upgrader.BlockingModAsyncWrite=true to handle writing with a separated goroutine and then avoid Head-of-line blocking on broadcasting scenarios. 
| 238 | | IOModMixed | We set the Engine.MaxBlockingOnline, if the online num is smaller than it, the new connection will be handled by single goroutine as IOModBlocking, else the new connection will be handled by poller. | 239 | 240 | The `IOModBlocking` aims to improve the performance for low online service, it runs faster than std. 241 | The `IOModMixed` aims to keep a balance between performance and cpu/mem cost in different scenarios: when there are not too many online connections, it performs better than std, or else it can serve lots of online connections and keep healthy. 242 | 243 | ### Using Websocket With Std Server 244 | 245 | ```golang 246 | package main 247 | 248 | import ( 249 | "fmt" 250 | "net/http" 251 | 252 | "github.com/lesismal/nbio/nbhttp/websocket" 253 | ) 254 | 255 | var ( 256 | upgrader = newUpgrader() 257 | ) 258 | 259 | func newUpgrader() *websocket.Upgrader { 260 | u := websocket.NewUpgrader() 261 | u.OnOpen(func(c *websocket.Conn) { 262 | // echo 263 | fmt.Println("OnOpen:", c.RemoteAddr().String()) 264 | }) 265 | u.OnMessage(func(c *websocket.Conn, messageType websocket.MessageType, data []byte) { 266 | // echo 267 | fmt.Println("OnMessage:", messageType, string(data)) 268 | c.WriteMessage(messageType, data) 269 | }) 270 | u.OnClose(func(c *websocket.Conn, err error) { 271 | fmt.Println("OnClose:", c.RemoteAddr().String(), err) 272 | }) 273 | return u 274 | } 275 | 276 | func onWebsocket(w http.ResponseWriter, r *http.Request) { 277 | conn, err := upgrader.Upgrade(w, r, nil) 278 | if err != nil { 279 | panic(err) 280 | } 281 | fmt.Println("Upgraded:", conn.RemoteAddr().String()) 282 | } 283 | 284 | func main() { 285 | mux := &http.ServeMux{} 286 | mux.HandleFunc("/ws", onWebsocket) 287 | server := http.Server{ 288 | Addr: "localhost:8080", 289 | Handler: mux, 290 | } 291 | fmt.Println("server exit:", server.ListenAndServe()) 292 | } 293 | ``` 294 | 295 | 296 | ## Credits 297 | - [xtaci/gaio](https://github.com/xtaci/gaio) 298 | - 
[gorilla/websocket](https://github.com/gorilla/websocket) 299 | - [crossbario/autobahn](https://github.com/crossbario) 300 | 301 | 302 | ## Contributors 303 | Thanks Everyone: 304 | - [acgreek](https://github.com/acgreek) 305 | - [acsecureworks](https://github.com/acsecureworks) 306 | - [arunsathiya](https://github.com/arunsathiya) 307 | - [guonaihong](https://github.com/guonaihong) 308 | - [isletnet](https://github.com/isletnet) 309 | - [liwnn](https://github.com/liwnn) 310 | - [manjun21](https://github.com/manjun21) 311 | - [om26er](https://github.com/om26er) 312 | - [rfyiamcool](https://github.com/rfyiamcool) 313 | - [sunny352](https://github.com/sunny352) 314 | - [sunvim](https://github.com/sunvim) 315 | - [wuqinqiang](https://github.com/wuqinqiang) 316 | - [wziww](https://github.com/wziww) 317 | - [youzhixiaomutou](https://github.com/youzhixiaomutou) 318 | - [zbh255](https://github.com/zbh255) 319 | - [IceflowRE](https://github.com/IceflowRE) 320 | - [YanKawaYu](https://github.com/YanKawaYu) 321 | 322 | 323 | ## Star History 324 | 325 | [![Star History Chart](https://api.star-history.com/svg?repos=lesismal/nbio&type=Date)](https://star-history.com/#lesismal/nbio&Date) 326 | -------------------------------------------------------------------------------- /autobahn/.gitignore: -------------------------------------------------------------------------------- 1 | report/ 2 | -------------------------------------------------------------------------------- /autobahn/config/fuzzingclient.json: -------------------------------------------------------------------------------- 1 | { 2 | "outdir": "./report", 3 | "servers": [ 4 | { 5 | "agent": "non-tls", 6 | "url": "ws://localhost:9998/echo/message", 7 | "options": { 8 | "version": 18 9 | } 10 | }, 11 | { 12 | "agent": "tls", 13 | "url": "wss://localhost:9999/echo/message", 14 | "options": { 15 | "version": 18 16 | } 17 | } 18 | ], 19 | "cases": [ 20 | "*" 21 | ], 22 | "exclude-cases": ["1[1-4].*"], 23 | 
"exclude-agent-cases": {} 24 | } 25 | -------------------------------------------------------------------------------- /autobahn/reporter/reporter.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "flag" 6 | "fmt" 7 | "html/template" 8 | "log" 9 | "net/http" 10 | "os" 11 | "path" 12 | "sort" 13 | "strconv" 14 | "strings" 15 | "text/tabwriter" 16 | ) 17 | 18 | var ( 19 | verbose = flag.Bool("verbose", false, "be verbose") 20 | web = flag.String("http", "", "open web browser instead") 21 | ) 22 | 23 | const ( 24 | statusOK = "OK" 25 | statusInformational = "INFORMATIONAL" 26 | statusUnimplemented = "UNIMPLEMENTED" 27 | statusNonStrict = "NON-STRICT" 28 | statusUnclean = "UNCLEAN" 29 | statusFailed = "FAILED" 30 | ) 31 | 32 | //go:norace 33 | func failing(behavior string) bool { 34 | switch behavior { 35 | // case statusUnclean, statusFailed, statusNonStrict: // we should probably fix the nonstrict as well at some point 36 | case statusUnclean, statusFailed: 37 | return true 38 | default: 39 | return false 40 | } 41 | } 42 | 43 | type statusCounter struct { 44 | Total int 45 | OK int 46 | Informational int 47 | Unimplemented int 48 | NonStrict int 49 | Unclean int 50 | Failed int 51 | } 52 | 53 | //go:norace 54 | func (c *statusCounter) Inc(s string) { 55 | c.Total++ 56 | switch s { 57 | case statusOK: 58 | c.OK++ 59 | case statusInformational: 60 | c.Informational++ 61 | case statusNonStrict: 62 | c.NonStrict++ 63 | case statusUnimplemented: 64 | c.Unimplemented++ 65 | case statusUnclean: 66 | c.Unclean++ 67 | case statusFailed: 68 | c.Failed++ 69 | default: 70 | panic(fmt.Sprintf("unexpected status %q", s)) 71 | } 72 | } 73 | 74 | //go:norace 75 | func main() { 76 | log.SetFlags(0) 77 | flag.Parse() 78 | 79 | if flag.NArg() < 1 { 80 | log.Fatalf("Usage: %s [options] ", os.Args[0]) 81 | } 82 | 83 | base := path.Dir(flag.Arg(0)) 84 | 85 | if addr := *web; addr != "" { 86 | 
http.HandleFunc("/", handlerIndex()) 87 | http.Handle("/report/", http.StripPrefix("/report/", 88 | http.FileServer(http.Dir(base)), 89 | )) 90 | log.Fatal(http.ListenAndServe(addr, nil)) 91 | return 92 | } 93 | 94 | var report report 95 | if err := decodeFile(os.Args[1], &report); err != nil { 96 | log.Fatal(err) 97 | } 98 | 99 | var servers []string 100 | for s := range report { 101 | servers = append(servers, s) 102 | } 103 | sort.Strings(servers) 104 | 105 | var ( 106 | failed bool 107 | ) 108 | tw := tabwriter.NewWriter(os.Stderr, 0, 4, 1, ' ', 0) 109 | for _, server := range servers { 110 | var ( 111 | srvFailed bool 112 | hdrWritten bool 113 | counter statusCounter 114 | ) 115 | 116 | var cases []string 117 | for id := range report[server] { 118 | cases = append(cases, id) 119 | } 120 | sortBySegment(cases) 121 | for _, id := range cases { 122 | c := report[server][id] 123 | 124 | var r entryReport 125 | err := decodeFile(path.Join(base, c.ReportFile), &r) 126 | if err != nil { 127 | log.Fatal(err) 128 | } 129 | counter.Inc(c.Behavior) 130 | bad := failing(c.Behavior) 131 | if bad { 132 | srvFailed = true 133 | failed = true 134 | } 135 | if *verbose || bad { 136 | if !hdrWritten { 137 | hdrWritten = true 138 | n, _ := fmt.Fprintf(os.Stderr, "AGENT %q\n", server) 139 | fmt.Fprintf(tw, "%s\n", strings.Repeat("=", n-1)) 140 | } 141 | fmt.Fprintf(tw, "%s\t%s\t%s\n", server, id, c.Behavior) 142 | } 143 | if bad { 144 | fmt.Fprintf(tw, "\tdesc:\t%s\n", r.Description) 145 | fmt.Fprintf(tw, "\texp: \t%s\n", r.Expectation) 146 | fmt.Fprintf(tw, "\tact: \t%s\n", r.Result) 147 | } 148 | } 149 | if hdrWritten { 150 | fmt.Fprint(tw, "\n") 151 | } 152 | var status string 153 | if srvFailed { 154 | status = statusFailed 155 | } else { 156 | status = statusOK 157 | } 158 | n, _ := fmt.Fprintf(tw, "AGENT %q SUMMARY (%s)\n", server, status) 159 | fmt.Fprintf(tw, "%s\n", strings.Repeat("=", n-1)) 160 | 161 | fmt.Fprintf(tw, "TOTAL:\t%d\n", counter.Total) 162 | fmt.Fprintf(tw, 
"%s:\t%d\n", statusOK, counter.OK) 163 | fmt.Fprintf(tw, "%s:\t%d\n", statusInformational, counter.Informational) 164 | fmt.Fprintf(tw, "%s:\t%d\n", statusUnimplemented, counter.Unimplemented) 165 | fmt.Fprintf(tw, "%s:\t%d\n", statusNonStrict, counter.NonStrict) 166 | fmt.Fprintf(tw, "%s:\t%d\n", statusUnclean, counter.Unclean) 167 | fmt.Fprintf(tw, "%s:\t%d\n", statusFailed, counter.Failed) 168 | fmt.Fprint(tw, "\n") 169 | tw.Flush() 170 | } 171 | var rc int 172 | if failed { 173 | rc = 1 174 | fmt.Fprintf(tw, "\n\nTEST %s\n\n", statusFailed) 175 | } else { 176 | fmt.Fprintf(tw, "\n\nTEST %s\n\n", statusOK) 177 | } 178 | 179 | tw.Flush() 180 | os.Exit(rc) 181 | } 182 | 183 | type report map[string]server 184 | 185 | type server map[string]entry 186 | 187 | type entry struct { 188 | Behavior string `json:"behavior"` 189 | BehaviorClose string `json:"behaviorClose"` 190 | Duration int `json:"duration"` 191 | RemoveCloseCode int `json:"removeCloseCode"` 192 | ReportFile string `json:"reportFile"` 193 | } 194 | 195 | type entryReport struct { 196 | Description string `json:"description"` 197 | Expectation string `json:"expectation"` 198 | Result string `json:"result"` 199 | Duration int `json:"duration"` 200 | } 201 | 202 | //go:norace 203 | func decodeFile(path string, x interface{}) error { 204 | f, err := os.Open(path) 205 | if err != nil { 206 | return err 207 | } 208 | defer f.Close() 209 | 210 | d := json.NewDecoder(f) 211 | return d.Decode(x) 212 | } 213 | 214 | //go:norace 215 | func compareBySegment(a, b string) int { 216 | as := strings.Split(a, ".") 217 | bs := strings.Split(b, ".") 218 | for i := 0; i < min(len(as), len(bs)); i++ { 219 | ax := mustInt(as[i]) 220 | bx := mustInt(bs[i]) 221 | if ax == bx { 222 | continue 223 | } 224 | return int(ax - bx) 225 | } 226 | return len(b) - len(a) 227 | } 228 | 229 | //go:norace 230 | func mustInt(s string) int64 { 231 | const bits = 32 << (^uint(0) >> 63) 232 | x, err := strconv.ParseInt(s, 10, bits) 233 | if err 
!= nil { 234 | panic(err) 235 | } 236 | return x 237 | } 238 | 239 | //go:norace 240 | func min(a, b int) int { 241 | if a < b { 242 | return a 243 | } 244 | return b 245 | } 246 | 247 | //go:norace 248 | func handlerIndex() func(w http.ResponseWriter, r *http.Request) { 249 | return func(w http.ResponseWriter, r *http.Request) { 250 | path := r.URL.Path 251 | if path != "/" { 252 | w.WriteHeader(http.StatusNotFound) 253 | return 254 | } 255 | if err := index.Execute(w, nil); err != nil { 256 | w.WriteHeader(http.StatusInternalServerError) 257 | log.Fatal(err) 258 | return 259 | } 260 | } 261 | } 262 | 263 | var index = template.Must(template.New("").Parse(` 264 | 265 | 266 |

Welcome to WebSocket test server!

267 |

Ready to Autobahn!

268 | Reports 269 | 270 | 271 | `)) 272 | 273 | //go:norace 274 | func sortBySegment(s []string) { 275 | sort.Slice(s, func(i, j int) bool { 276 | return compareBySegment(s[i], s[j]) < 0 277 | }) 278 | } 279 | -------------------------------------------------------------------------------- /autobahn/script/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p ./autobahn/bin 4 | go build -o ./autobahn/bin/autobahn_server ./autobahn/server/ 5 | go build -o ./autobahn/bin/autobahn_reporter ./autobahn/reporter/ 6 | 7 | echo "pwd:" $(pwd) 8 | ./autobahn/bin/autobahn_server & 9 | 10 | rm -rf ${PWD}/autobahn/report 11 | mkdir -p ${PWD}/autobahn/report/ 12 | 13 | docker pull crossbario/autobahn-testsuite 14 | 15 | docker run -i --rm \ 16 | -v ${PWD}/autobahn/config:/config \ 17 | -v ${PWD}/autobahn/report:/report \ 18 | --network host \ 19 | --name=autobahn \ 20 | crossbario/autobahn-testsuite \ 21 | wstest -m fuzzingclient -s /config/fuzzingclient.json 22 | 23 | trap ctrl_c INT 24 | ctrl_c() { 25 | echo "SIGINT received; cleaning up" 26 | docker kill --signal INT "autobahn" >/dev/null 27 | rm -rf ${PWD}/autobahn/bin 28 | rm -rf ${PWD}/autobahn/report 29 | cleanup 30 | exit 130 31 | } 32 | 33 | cleanup() { 34 | killall autobahn_server 35 | } 36 | 37 | ./autobahn/bin/autobahn_reporter ${PWD}/autobahn/report/index.json 38 | 39 | cleanup 40 | -------------------------------------------------------------------------------- /autobahn/server/server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net/http" 7 | "os" 8 | "os/signal" 9 | "time" 10 | 11 | "github.com/lesismal/llib/std/crypto/tls" 12 | "github.com/lesismal/nbio/nbhttp" 13 | "github.com/lesismal/nbio/nbhttp/websocket" 14 | "github.com/lesismal/nbio/taskpool" 15 | ) 16 | 17 | //go:norace 18 | func newUpgrader(isDataFrame bool) *websocket.Upgrader { 19 | u := 
websocket.NewUpgrader() 20 | u.EnableCompression(true) 21 | if isDataFrame { 22 | isFirst := true 23 | u.OnDataFrame(func(c *websocket.Conn, messageType websocket.MessageType, fin bool, data []byte) { 24 | err := c.WriteFrame(messageType, isFirst, fin, data) 25 | if err != nil { 26 | c.Close() 27 | return 28 | } 29 | if fin { 30 | isFirst = true 31 | } else { 32 | isFirst = false 33 | } 34 | }) 35 | } else { 36 | u.OnMessage(func(c *websocket.Conn, messageType websocket.MessageType, data []byte) { 37 | c.WriteMessage(messageType, data) 38 | }) 39 | } 40 | 41 | return u 42 | } 43 | 44 | //go:norace 45 | func onWebsocketFrame(w http.ResponseWriter, r *http.Request) { 46 | upgrader := newUpgrader(true) 47 | conn, err := upgrader.Upgrade(w, r, nil) 48 | if err != nil { 49 | panic(err) 50 | } 51 | conn.SetDeadline(time.Time{}) 52 | } 53 | 54 | //go:norace 55 | func onWebsocketMessage(w http.ResponseWriter, r *http.Request) { 56 | upgrader := newUpgrader(false) 57 | conn, err := upgrader.Upgrade(w, r, nil) 58 | if err != nil { 59 | panic(err) 60 | } 61 | conn.SetDeadline(time.Time{}) 62 | } 63 | 64 | //go:norace 65 | func main() { 66 | cert, err := tls.X509KeyPair(rsaCertPEM, rsaKeyPEM) 67 | if err != nil { 68 | log.Fatalf("tls.X509KeyPair failed: %v", err) 69 | } 70 | tlsConfig := &tls.Config{ 71 | Certificates: []tls.Certificate{cert}, 72 | InsecureSkipVerify: true, 73 | } 74 | 75 | mux := &http.ServeMux{} 76 | mux.HandleFunc("/echo/message", onWebsocketMessage) 77 | mux.HandleFunc("/echo/frame", onWebsocketFrame) 78 | 79 | log.Printf("calling new server tls\n") 80 | 81 | messageHandlerExecutePool := taskpool.New(100, 1000) 82 | svrTLS := nbhttp.NewServer(nbhttp.Config{ 83 | Network: "tcp", 84 | AddrsTLS: []string{"localhost:9999"}, 85 | TLSConfig: tlsConfig, 86 | ReadBufferSize: 1024 * 1024, 87 | Handler: mux, 88 | ServerExecutor: messageHandlerExecutePool.Go, 89 | }) 90 | svr := nbhttp.NewServer(nbhttp.Config{ 91 | Network: "tcp", 92 | Addrs: 
[]string{"localhost:9998"}, 93 | ReadBufferSize: 1024 * 1024, 94 | Handler: mux, 95 | ServerExecutor: messageHandlerExecutePool.Go, 96 | }) 97 | 98 | log.Printf("calling start non-tls\n") 99 | err = svr.Start() 100 | if err != nil { 101 | fmt.Printf("nbio.Start non-tls failed: %v\n", err) 102 | return 103 | } 104 | defer svr.Stop() 105 | 106 | log.Printf("calling start tls\n") 107 | err = svrTLS.Start() 108 | if err != nil { 109 | fmt.Printf("nbio.Start tls failed: %v\n", err) 110 | return 111 | } 112 | defer svrTLS.Stop() 113 | 114 | interrupt := make(chan os.Signal, 1) 115 | signal.Notify(interrupt, os.Interrupt) 116 | <-interrupt 117 | log.Println("exit") 118 | } 119 | 120 | var rsaCertPEM = []byte(`-----BEGIN CERTIFICATE----- 121 | MIIDazCCAlOgAwIBAgIUJeohtgk8nnt8ofratXJg7kUJsI4wDQYJKoZIhvcNAQEL 122 | BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM 123 | GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMDEyMDcwODIyNThaFw0zMDEy 124 | MDUwODIyNThaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw 125 | HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB 126 | AQUAA4IBDwAwggEKAoIBAQCy+ZrIvwwiZv4bPmvKx/637ltZLwfgh3ouiEaTchGu 127 | IQltthkqINHxFBqqJg44TUGHWthlrq6moQuKnWNjIsEc6wSD1df43NWBLgdxbPP0 128 | x4tAH9pIJU7TQqbznjDBhzRbUjVXBIcn7bNknY2+5t784pPF9H1v7h8GqTWpNH9l 129 | cz/v+snoqm9HC+qlsFLa4A3X9l5v05F1uoBfUALlP6bWyjHAfctpiJkoB9Yw1TJa 130 | gpq7E50kfttwfKNkkAZIbib10HugkMoQJAs2EsGkje98druIl8IXmuvBIF6nZHuM 131 | lt3UIZjS9RwPPLXhRHt1P0mR7BoBcOjiHgtSEs7Wk+j7AgMBAAGjUzBRMB0GA1Ud 132 | DgQWBBQdheJv73XSOhgMQtkwdYPnfO02+TAfBgNVHSMEGDAWgBQdheJv73XSOhgM 133 | QtkwdYPnfO02+TAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQBf 134 | SKVNMdmBpD9m53kCrguo9iKQqmhnI0WLkpdWszc/vBgtpOE5ENOfHGAufHZve871 135 | 2fzTXrgR0TF6UZWsQOqCm5Oh3URsCdXWewVMKgJ3DCii6QJ0MnhSFt6+xZE9C6Hi 136 | WhcywgdR8t/JXKDam6miohW8Rum/IZo5HK9Jz/R9icKDGumcqoaPj/ONvY4EUwgB 137 | irKKB7YgFogBmCtgi30beLVkXgk0GEcAf19lHHtX2Pv/lh3m34li1C9eBm1ca3kk 138 | 
M2tcQtm1G89NROEjcG92cg+GX3GiWIjbI0jD1wnVy2LCOXMgOVbKfGfVKISFt0b1 139 | DNn00G8C6ttLoGU2snyk 140 | -----END CERTIFICATE----- 141 | `) 142 | 143 | var rsaKeyPEM = []byte(`-----BEGIN RSA PRIVATE KEY----- 144 | MIIEogIBAAKCAQEAsvmayL8MImb+Gz5rysf+t+5bWS8H4Id6LohGk3IRriEJbbYZ 145 | KiDR8RQaqiYOOE1Bh1rYZa6upqELip1jYyLBHOsEg9XX+NzVgS4HcWzz9MeLQB/a 146 | SCVO00Km854wwYc0W1I1VwSHJ+2zZJ2Nvube/OKTxfR9b+4fBqk1qTR/ZXM/7/rJ 147 | 6KpvRwvqpbBS2uAN1/Zeb9ORdbqAX1AC5T+m1soxwH3LaYiZKAfWMNUyWoKauxOd 148 | JH7bcHyjZJAGSG4m9dB7oJDKECQLNhLBpI3vfHa7iJfCF5rrwSBep2R7jJbd1CGY 149 | 0vUcDzy14UR7dT9JkewaAXDo4h4LUhLO1pPo+wIDAQABAoIBAF6yWwekrlL1k7Xu 150 | jTI6J7hCUesaS1yt0iQUzuLtFBXCPS7jjuUPgIXCUWl9wUBhAC8SDjWe+6IGzAiH 151 | xjKKDQuz/iuTVjbDAeTb6exF7b6yZieDswdBVjfJqHR2Wu3LEBTRpo9oQesKhkTS 152 | aFF97rZ3XCD9f/FdWOU5Wr8wm8edFK0zGsZ2N6r57yf1N6ocKlGBLBZ0v1Sc5ShV 153 | 1PVAxeephQvwL5DrOgkArnuAzwRXwJQG78L0aldWY2q6xABQZQb5+ml7H/kyytef 154 | i+uGo3jHKepVALHmdpCGr9Yv+yCElup+ekv6cPy8qcmMBqGMISL1i1FEONxLcKWp 155 | GEJi6QECgYEA3ZPGMdUm3f2spdHn3C+/+xskQpz6efiPYpnqFys2TZD7j5OOnpcP 156 | ftNokA5oEgETg9ExJQ8aOCykseDc/abHerYyGw6SQxmDbyBLmkZmp9O3iMv2N8Pb 157 | Nrn9kQKSr6LXZ3gXzlrDvvRoYUlfWuLSxF4b4PYifkA5AfsdiKkj+5sCgYEAzseF 158 | XDTRKHHJnzxZDDdHQcwA0G9agsNj64BGUEjsAGmDiDyqOZnIjDLRt0O2X3oiIE5S 159 | TXySSEiIkxjfErVJMumLaIwqVvlS4pYKdQo1dkM7Jbt8wKRQdleRXOPPN7msoEUk 160 | Ta9ZsftHVUknPqblz9Uthb5h+sRaxIaE1llqDiECgYATS4oHzuL6k9uT+Qpyzymt 161 | qThoIJljQ7TgxjxvVhD9gjGV2CikQM1Vov1JBigj4Toc0XuxGXaUC7cv0kAMSpi2 162 | Y+VLG+K6ux8J70sGHTlVRgeGfxRq2MBfLKUbGplBeDG/zeJs0tSW7VullSkblgL6 163 | nKNa3LQ2QEt2k7KHswryHwKBgENDxk8bY1q7wTHKiNEffk+aFD25q4DUHMH0JWti 164 | fVsY98+upFU+gG2S7oOmREJE0aser0lDl7Zp2fu34IEOdfRY4p+s0O0gB+Vrl5VB 165 | L+j7r9bzaX6lNQN6MvA7ryHahZxRQaD/xLbQHgFRXbHUyvdTyo4yQ1821qwNclLk 166 | HUrhAoGAUtjR3nPFR4TEHlpTSQQovS8QtGTnOi7s7EzzdPWmjHPATrdLhMA0ezPj 167 | Mr+u5TRncZBIzAZtButlh1AHnpN/qO3P0c0Rbdep3XBc/82JWO8qdb5QvAkxga3X 168 | BpA7MNLxiqss+rCbwf3NbWxEMiDQ2zRwVoafVFys7tjmv6t2Xck= 169 | -----END RSA PRIVATE KEY----- 
170 | `) 171 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 lesismal. All rights reserved. 2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package nbio 6 | 7 | import ( 8 | "net" 9 | "runtime" 10 | "time" 11 | "unsafe" 12 | 13 | "github.com/lesismal/nbio/logging" 14 | ) 15 | 16 | // ConnType is used to identify different types of Conn. 17 | type ConnType = int8 18 | 19 | const ( 20 | // ConnTypeTCP represents TCP Conn. 21 | ConnTypeTCP ConnType = iota + 1 22 | // ConnTypeUDPServer represents UDP Conn used as a listener. 23 | ConnTypeUDPServer 24 | // ConnTypeUDPClientFromRead represents UDP connection that 25 | // is sending data to our UDP Server from peer. 26 | ConnTypeUDPClientFromRead 27 | // ConnTypeUDPClientFromDial represents UDP Conn that is sending 28 | // data to other UDP Server from ourselves. 29 | ConnTypeUDPClientFromDial 30 | // ConnTypeUnix represents Unix Conn. 31 | ConnTypeUnix 32 | ) 33 | 34 | // Type . 35 | // 36 | //go:norace 37 | func (c *Conn) Type() ConnType { 38 | return c.typ 39 | } 40 | 41 | // IsTCP returns whether this Conn is a TCP Conn. 42 | // 43 | //go:norace 44 | func (c *Conn) IsTCP() bool { 45 | return c.typ == ConnTypeTCP 46 | } 47 | 48 | // IsUDP returns whether this Conn is a UDP Conn. 49 | // 50 | //go:norace 51 | func (c *Conn) IsUDP() bool { 52 | switch c.typ { 53 | case ConnTypeUDPServer, ConnTypeUDPClientFromDial, ConnTypeUDPClientFromRead: 54 | return true 55 | } 56 | return false 57 | } 58 | 59 | // IsUnix returns whether this Conn is a Unix Conn. 60 | // 61 | //go:norace 62 | func (c *Conn) IsUnix() bool { 63 | return c.typ == ConnTypeUnix 64 | } 65 | 66 | // Session returns user session. 
67 | // 68 | //go:norace 69 | func (c *Conn) Session() interface{} { 70 | return c.session 71 | } 72 | 73 | // SetSession sets user session. 74 | // 75 | //go:norace 76 | func (c *Conn) SetSession(session interface{}) { 77 | c.session = session 78 | } 79 | 80 | // OnData registers Conn's data handler. 81 | // Notice: 82 | // 1. The data readed by the poller is not handled by this Conn's data 83 | // handler by default. 84 | // 2. The data readed by the poller is handled by nbio.Engine's data 85 | // handler which is registered by nbio.Engine.OnData by default. 86 | // 3. This Conn's data handler is used to customize your implementation, 87 | // you can set different data handler for different Conns, 88 | // and call Conn's data handler in nbio.Engine's data handler. 89 | // For example: 90 | // engine.OnData(func(c *nbio.Conn, data byte){ 91 | // c.DataHandler()(c, data) 92 | // }) 93 | // conn1.OnData(yourDatahandler1) 94 | // conn2.OnData(yourDatahandler2) 95 | // 96 | //go:norace 97 | func (c *Conn) OnData(h func(conn *Conn, data []byte)) { 98 | c.dataHandler = h 99 | } 100 | 101 | // DataHandler returns Conn's data handler. 102 | // 103 | //go:norace 104 | func (c *Conn) DataHandler() func(conn *Conn, data []byte) { 105 | return c.dataHandler 106 | } 107 | 108 | // Dial calls net.Dial to make a net.Conn and convert it to *nbio.Conn. 109 | // 110 | //go:norace 111 | func Dial(network string, address string) (*Conn, error) { 112 | conn, err := net.Dial(network, address) 113 | if err != nil { 114 | return nil, err 115 | } 116 | return NBConn(conn) 117 | } 118 | 119 | // Dial calls net.DialTimeout to make a net.Conn and convert it to *nbio.Conn. 120 | // 121 | //go:norace 122 | func DialTimeout(network string, address string, timeout time.Duration) (*Conn, error) { 123 | conn, err := net.DialTimeout(network, address, timeout) 124 | if err != nil { 125 | return nil, err 126 | } 127 | return NBConn(conn) 128 | } 129 | 130 | // Lock . 
131 | // 132 | //go:norace 133 | func (c *Conn) Lock() { 134 | c.mux.Lock() 135 | } 136 | 137 | // Unlock . 138 | // 139 | //go:norace 140 | func (c *Conn) Unlock() { 141 | c.mux.Unlock() 142 | } 143 | 144 | // IsClosed returns whether the Conn is closed. 145 | // 146 | //go:norace 147 | func (c *Conn) IsClosed() (bool, error) { 148 | return c.closed, c.closeErr 149 | } 150 | 151 | // ExecuteLen returns the length of the Conn's job list. 152 | // 153 | //go:norace 154 | func (c *Conn) ExecuteLen() int { 155 | c.mux.Lock() 156 | n := len(c.jobList) 157 | c.mux.Unlock() 158 | return n 159 | } 160 | 161 | // Execute is used to run the job. 162 | // 163 | // How it works: 164 | // If the job is the head/first of the Conn's job list, it will call the 165 | // nbio.Engine.Execute to run all the jobs in the job list that include: 166 | // 1. This job 167 | // 2. New jobs that are pushed to the back of the list before this job 168 | // is done. 169 | // 3. nbio.Engine.Execute returns until there's no more jobs in the job 170 | // list. 171 | // 172 | // Else if the job is not the head/first of the job list, it will push the 173 | // job to the back of the job list and wait to be called. 174 | // This guarantees there's at most one flow or goroutine running job/jobs 175 | // for each Conn. 176 | // This guarantees all the jobs are executed in order. 177 | // 178 | // Notice: 179 | // 1. The job wouldn't run or pushed to the back of the job list if the 180 | // connection is closed. 181 | // 2. nbio.Engine.Execute is handled by a goroutine pool by default, users 182 | // can customize it. 
183 | // 184 | //go:norace 185 | func (c *Conn) Execute(job func()) bool { 186 | c.mux.Lock() 187 | if c.closed { 188 | c.mux.Unlock() 189 | return false 190 | } 191 | 192 | isHead := (len(c.jobList) == 0) 193 | c.jobList = append(c.jobList, job) 194 | c.mux.Unlock() 195 | 196 | // If there's no job running, run Engine.Execute to run this job 197 | // and new jobs appended before this head job is done. 198 | if isHead { 199 | c.execute(job) 200 | } 201 | return true 202 | } 203 | 204 | // MustExecute implements a similar function as Execute did, 205 | // but will still execute or push the job to the 206 | // back of the job list no matter whether Conn has been closed, 207 | // it guarantees the job to be executed. 208 | // This is used to handle the close event in nbio/nbhttp. 209 | // 210 | //go:norace 211 | func (c *Conn) MustExecute(job func()) { 212 | c.mux.Lock() 213 | isHead := (len(c.jobList) == 0) 214 | c.jobList = append(c.jobList, job) 215 | c.mux.Unlock() 216 | 217 | // If there's no job running, run Engine.Execute to run this job 218 | // and new jobs appended before this head job is done. 
219 | if isHead { 220 | c.execute(job) 221 | } 222 | } 223 | 224 | //go:norace 225 | func (c *Conn) execute(job func()) { 226 | c.p.g.Execute(func() { 227 | i := 0 228 | for { 229 | func() { 230 | defer func() { 231 | if err := recover(); err != nil { 232 | const size = 64 << 10 233 | buf := make([]byte, size) 234 | buf = buf[:runtime.Stack(buf, false)] 235 | logging.Error("conn execute failed: %v\n%v\n", 236 | err, 237 | *(*string)(unsafe.Pointer(&buf)), 238 | ) 239 | } 240 | }() 241 | job() 242 | }() 243 | 244 | c.mux.Lock() 245 | i++ 246 | if len(c.jobList) == i { 247 | // set nil to release the job and gc 248 | c.jobList[i-1] = nil 249 | // reuse the slice 250 | c.jobList = c.jobList[0:0] 251 | c.mux.Unlock() 252 | return 253 | } 254 | // get next job 255 | job = c.jobList[i] 256 | // set nil to release the job and gc 257 | c.jobList[i] = nil 258 | c.mux.Unlock() 259 | } 260 | }) 261 | } 262 | -------------------------------------------------------------------------------- /conn_std.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 lesismal. All rights reserved. 2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build windows 6 | // +build windows 7 | 8 | package nbio 9 | 10 | import ( 11 | "bytes" 12 | "errors" 13 | "io" 14 | "net" 15 | "sync" 16 | "syscall" 17 | "time" 18 | 19 | "github.com/lesismal/nbio/timer" 20 | ) 21 | 22 | // Conn wraps net.Conn. 23 | type Conn struct { 24 | p *poller 25 | 26 | hash int 27 | 28 | mux sync.Mutex 29 | 30 | conn net.Conn 31 | connUDP *udpConn 32 | 33 | rTimer *time.Timer 34 | 35 | typ ConnType 36 | closed bool 37 | closeErr error 38 | 39 | ReadBuffer []byte 40 | 41 | // user session. 
42 | session interface{} 43 | 44 | jobList []func() 45 | 46 | cache *bytes.Buffer 47 | 48 | dataHandler func(c *Conn, data []byte) 49 | 50 | onConnected func(c *Conn, err error) 51 | } 52 | 53 | // Hash returns a hashcode. 54 | // 55 | //go:norace 56 | func (c *Conn) Hash() int { 57 | return c.hash 58 | } 59 | 60 | // Read wraps net.Conn.Read. 61 | // 62 | //go:norace 63 | func (c *Conn) Read(b []byte) (int, error) { 64 | if c.closeErr != nil { 65 | return 0, c.closeErr 66 | } 67 | 68 | var reader io.Reader = c.conn 69 | if c.cache != nil { 70 | reader = c.cache 71 | } 72 | nread, err := reader.Read(b) 73 | if c.closeErr == nil { 74 | c.closeErr = err 75 | } 76 | return nread, err 77 | } 78 | 79 | //go:norace 80 | func (c *Conn) read(b []byte) (int, error) { 81 | var err error 82 | var nread int 83 | switch c.typ { 84 | case ConnTypeTCP: 85 | nread, err = c.readTCP(b) 86 | case ConnTypeUDPServer, ConnTypeUDPClientFromDial: 87 | nread, err = c.readUDP(b) 88 | case ConnTypeUDPClientFromRead: 89 | err = errors.New("invalid udp conn for reading") 90 | default: 91 | } 92 | return nread, err 93 | } 94 | 95 | //go:norace 96 | func (c *Conn) readTCP(b []byte) (int, error) { 97 | g := c.p.g 98 | // g.beforeRead(c) 99 | nread, err := c.conn.Read(b) 100 | if c.closeErr == nil { 101 | c.closeErr = err 102 | } 103 | if g.onRead != nil { 104 | if nread > 0 { 105 | if c.cache == nil { 106 | c.cache = bytes.NewBuffer(nil) 107 | } 108 | c.cache.Write(b[:nread]) 109 | } 110 | g.onRead(c) 111 | return nread, nil 112 | } else if nread > 0 { 113 | b = b[:nread] 114 | g.onDataPtr(c, &b) 115 | } 116 | return nread, err 117 | } 118 | 119 | //go:norace 120 | func (c *Conn) readUDP(b []byte) (int, error) { 121 | if c.connUDP == nil { 122 | return 0, errors.New("invalid conn") 123 | } 124 | nread, rAddr, err := c.connUDP.ReadFromUDP(b) 125 | if c.closeErr == nil { 126 | c.closeErr = err 127 | } 128 | if err != nil { 129 | return 0, err 130 | } 131 | 132 | var g = c.p.g 133 | var dstConn = c 
134 | if c.typ == ConnTypeUDPServer { 135 | uc, ok := c.connUDP.getConn(c.p, rAddr) 136 | if g.UDPReadTimeout > 0 { 137 | uc.SetReadDeadline(time.Now().Add(g.UDPReadTimeout)) 138 | } 139 | if !ok { 140 | p := g.pollers[c.Hash()%len(g.pollers)] 141 | p.addConn(uc) 142 | } 143 | dstConn = uc 144 | } 145 | 146 | if g.onRead != nil { 147 | if nread > 0 { 148 | if dstConn.cache == nil { 149 | dstConn.cache = bytes.NewBuffer(nil) 150 | } 151 | dstConn.cache.Write(b[:nread]) 152 | } 153 | g.onRead(dstConn) 154 | return nread, nil 155 | } else if nread > 0 { 156 | buf := b[:nread] 157 | g.onDataPtr(dstConn, &buf) 158 | } 159 | 160 | return nread, err 161 | } 162 | 163 | // Write wraps net.Conn.Write. 164 | // 165 | //go:norace 166 | func (c *Conn) Write(b []byte) (int, error) { 167 | var n int 168 | var err error 169 | switch c.typ { 170 | case ConnTypeTCP: 171 | n, err = c.writeTCP(b) 172 | case ConnTypeUDPServer: 173 | case ConnTypeUDPClientFromDial: 174 | n, err = c.writeUDPClientFromDial(b) 175 | case ConnTypeUDPClientFromRead: 176 | n, err = c.writeUDPClientFromRead(b) 177 | default: 178 | } 179 | if c.p.g.onWrittenSize != nil && n > 0 { 180 | c.p.g.onWrittenSize(c, b[:n], n) 181 | } 182 | return n, err 183 | } 184 | 185 | //go:norace 186 | func (c *Conn) writeTCP(b []byte) (int, error) { 187 | // c.p.g.beforeWrite(c) 188 | nwrite, err := c.conn.Write(b) 189 | if err != nil { 190 | if c.closeErr == nil { 191 | c.closeErr = err 192 | } 193 | c.Close() 194 | } 195 | 196 | return nwrite, err 197 | } 198 | 199 | //go:norace 200 | func (c *Conn) writeUDPClientFromDial(b []byte) (int, error) { 201 | nwrite, err := c.connUDP.Write(b) 202 | if err != nil { 203 | if c.closeErr == nil { 204 | c.closeErr = err 205 | } 206 | c.Close() 207 | } 208 | return nwrite, err 209 | } 210 | 211 | //go:norace 212 | func (c *Conn) writeUDPClientFromRead(b []byte) (int, error) { 213 | nwrite, err := c.connUDP.WriteToUDP(b, c.connUDP.rAddr) 214 | if err != nil { 215 | if c.closeErr == nil { 
216 | c.closeErr = err 217 | } 218 | c.Close() 219 | } 220 | return nwrite, err 221 | } 222 | 223 | // Writev wraps buffers.WriteTo/syscall.Writev. 224 | // 225 | //go:norace 226 | func (c *Conn) Writev(in [][]byte) (int, error) { 227 | if c.connUDP == nil { 228 | buffers := net.Buffers(in) 229 | nwrite, err := buffers.WriteTo(c.conn) 230 | if err != nil { 231 | if c.closeErr == nil { 232 | c.closeErr = err 233 | } 234 | c.Close() 235 | } 236 | if c.p.g.onWrittenSize != nil && nwrite > 0 { 237 | total := int(nwrite) 238 | for i := 0; total > 0; i++ { 239 | if total <= len(in[i]) { 240 | c.p.g.onWrittenSize(c, in[i][:total], total) 241 | total = 0 242 | } else { 243 | c.p.g.onWrittenSize(c, in[i], len(in[i])) 244 | total -= len(in[i]) 245 | } 246 | } 247 | } 248 | return int(nwrite), err 249 | } 250 | 251 | var total = 0 252 | for _, b := range in { 253 | nwrite, err := c.Write(b) 254 | if nwrite > 0 { 255 | total += nwrite 256 | } 257 | if c.p.g.onWrittenSize != nil && nwrite > 0 { 258 | c.p.g.onWrittenSize(c, b[:nwrite], nwrite) 259 | } 260 | if err != nil { 261 | if c.closeErr == nil { 262 | c.closeErr = err 263 | } 264 | c.Close() 265 | return total, err 266 | } 267 | } 268 | return total, nil 269 | } 270 | 271 | // Close wraps net.Conn.Close. 272 | // 273 | //go:norace 274 | func (c *Conn) Close() error { 275 | var err error 276 | c.mux.Lock() 277 | if !c.closed { 278 | c.closed = true 279 | 280 | if c.rTimer != nil { 281 | c.rTimer.Stop() 282 | c.rTimer = nil 283 | } 284 | 285 | switch c.typ { 286 | case ConnTypeTCP: 287 | err = c.conn.Close() 288 | case ConnTypeUDPServer, ConnTypeUDPClientFromDial, ConnTypeUDPClientFromRead: 289 | err = c.connUDP.Close() 290 | default: 291 | } 292 | 293 | c.mux.Unlock() 294 | if c.p.g != nil { 295 | c.p.deleteConn(c) 296 | } 297 | return err 298 | } 299 | c.mux.Unlock() 300 | return err 301 | } 302 | 303 | // CloseWithError . 
304 | // 305 | //go:norace 306 | func (c *Conn) CloseWithError(err error) error { 307 | if c.closeErr == nil { 308 | c.closeErr = err 309 | } 310 | return c.Close() 311 | } 312 | 313 | // LocalAddr wraps net.Conn.LocalAddr. 314 | // 315 | //go:norace 316 | func (c *Conn) LocalAddr() net.Addr { 317 | switch c.typ { 318 | case ConnTypeTCP: 319 | return c.conn.LocalAddr() 320 | case ConnTypeUDPServer, ConnTypeUDPClientFromDial, ConnTypeUDPClientFromRead: 321 | return c.connUDP.LocalAddr() 322 | default: 323 | } 324 | return nil 325 | } 326 | 327 | // RemoteAddr wraps net.Conn.RemoteAddr. 328 | // 329 | //go:norace 330 | func (c *Conn) RemoteAddr() net.Addr { 331 | switch c.typ { 332 | case ConnTypeTCP: 333 | return c.conn.RemoteAddr() 334 | case ConnTypeUDPClientFromDial: 335 | return c.connUDP.RemoteAddr() 336 | case ConnTypeUDPClientFromRead: 337 | return c.connUDP.rAddr 338 | default: 339 | } 340 | return nil 341 | } 342 | 343 | // SetDeadline wraps net.Conn.SetDeadline. 344 | // 345 | //go:norace 346 | func (c *Conn) SetDeadline(t time.Time) error { 347 | if c.typ == ConnTypeTCP { 348 | return c.conn.SetDeadline(t) 349 | } 350 | return c.SetReadDeadline(t) 351 | } 352 | 353 | // SetReadDeadline wraps net.Conn.SetReadDeadline. 354 | // 355 | //go:norace 356 | func (c *Conn) SetReadDeadline(t time.Time) error { 357 | if t.IsZero() { 358 | t = time.Now().Add(timer.TimeForever) 359 | } 360 | 361 | if c.typ == ConnTypeTCP { 362 | return c.conn.SetReadDeadline(t) 363 | } 364 | 365 | timeout := time.Until(t) 366 | if c.rTimer == nil { 367 | c.rTimer = c.p.g.AfterFunc(timeout, func() { 368 | c.CloseWithError(errReadTimeout) 369 | }) 370 | } else { 371 | c.rTimer.Reset(timeout) 372 | } 373 | 374 | return nil 375 | } 376 | 377 | // SetWriteDeadline wraps net.Conn.SetWriteDeadline. 
378 | // 379 | //go:norace 380 | func (c *Conn) SetWriteDeadline(t time.Time) error { 381 | if c.typ != ConnTypeTCP { 382 | return nil 383 | } 384 | 385 | if t.IsZero() { 386 | t = time.Now().Add(timer.TimeForever) 387 | } 388 | 389 | return c.conn.SetWriteDeadline(t) 390 | } 391 | 392 | // SetNoDelay wraps net.Conn.SetNoDelay. 393 | // 394 | //go:norace 395 | func (c *Conn) SetNoDelay(nodelay bool) error { 396 | if c.typ != ConnTypeTCP { 397 | return nil 398 | } 399 | 400 | conn, ok := c.conn.(*net.TCPConn) 401 | if ok { 402 | return conn.SetNoDelay(nodelay) 403 | } 404 | return nil 405 | } 406 | 407 | // SetReadBuffer wraps net.Conn.SetReadBuffer. 408 | // 409 | //go:norace 410 | func (c *Conn) SetReadBuffer(bytes int) error { 411 | if c.typ != ConnTypeTCP { 412 | return nil 413 | } 414 | 415 | conn, ok := c.conn.(*net.TCPConn) 416 | if ok { 417 | return conn.SetReadBuffer(bytes) 418 | } 419 | return nil 420 | } 421 | 422 | // SetWriteBuffer wraps net.Conn.SetWriteBuffer. 423 | // 424 | //go:norace 425 | func (c *Conn) SetWriteBuffer(bytes int) error { 426 | if c.typ != ConnTypeTCP { 427 | return nil 428 | } 429 | 430 | conn, ok := c.conn.(*net.TCPConn) 431 | if ok { 432 | return conn.SetWriteBuffer(bytes) 433 | } 434 | return nil 435 | } 436 | 437 | // SetKeepAlive wraps net.Conn.SetKeepAlive. 438 | // 439 | //go:norace 440 | func (c *Conn) SetKeepAlive(keepalive bool) error { 441 | if c.typ != ConnTypeTCP { 442 | return nil 443 | } 444 | 445 | conn, ok := c.conn.(*net.TCPConn) 446 | if ok { 447 | return conn.SetKeepAlive(keepalive) 448 | } 449 | return nil 450 | } 451 | 452 | // SetKeepAlivePeriod wraps net.Conn.SetKeepAlivePeriod. 453 | // 454 | //go:norace 455 | func (c *Conn) SetKeepAlivePeriod(d time.Duration) error { 456 | if c.typ != ConnTypeTCP { 457 | return nil 458 | } 459 | 460 | conn, ok := c.conn.(*net.TCPConn) 461 | if ok { 462 | return conn.SetKeepAlivePeriod(d) 463 | } 464 | return nil 465 | } 466 | 467 | // SetLinger wraps net.Conn.SetLinger. 
468 | // 469 | //go:norace 470 | func (c *Conn) SetLinger(onoff int32, linger int32) error { 471 | if c.typ != ConnTypeTCP { 472 | return nil 473 | } 474 | 475 | conn, ok := c.conn.(*net.TCPConn) 476 | if ok { 477 | return conn.SetLinger(int(linger)) 478 | } 479 | return nil 480 | } 481 | 482 | //go:norace 483 | func newConn(conn net.Conn) *Conn { 484 | c := &Conn{} 485 | addr := conn.LocalAddr().String() 486 | 487 | uc, ok := conn.(*net.UDPConn) 488 | if ok { 489 | rAddr := uc.RemoteAddr() 490 | if rAddr == nil { 491 | c.typ = ConnTypeUDPServer 492 | c.connUDP = &udpConn{ 493 | UDPConn: uc, 494 | conns: map[string]*Conn{}, 495 | } 496 | } else { 497 | c.typ = ConnTypeUDPClientFromDial 498 | addr += rAddr.String() 499 | c.connUDP = &udpConn{ 500 | UDPConn: uc, 501 | } 502 | } 503 | } else { 504 | c.conn = conn 505 | c.typ = ConnTypeTCP 506 | } 507 | 508 | for _, ch := range addr { 509 | c.hash = 31*c.hash + int(ch) 510 | } 511 | if c.hash < 0 { 512 | c.hash = -c.hash 513 | } 514 | 515 | return c 516 | } 517 | 518 | // NBConn converts net.Conn to *Conn. 
519 | // 520 | //go:norace 521 | func NBConn(conn net.Conn) (*Conn, error) { 522 | if conn == nil { 523 | return nil, errors.New("invalid conn: nil") 524 | } 525 | c, ok := conn.(*Conn) 526 | if !ok { 527 | c = newConn(conn) 528 | } 529 | return c, nil 530 | } 531 | 532 | type udpConn struct { 533 | *net.UDPConn 534 | rAddr *net.UDPAddr 535 | 536 | mux sync.RWMutex 537 | parent *udpConn 538 | conns map[string]*Conn 539 | } 540 | 541 | //go:norace 542 | func (u *udpConn) Close() error { 543 | parent := u.parent 544 | if parent != nil { 545 | parent.mux.Lock() 546 | delete(parent.conns, u.rAddr.String()) 547 | parent.mux.Unlock() 548 | } else { 549 | u.UDPConn.Close() 550 | for _, c := range u.conns { 551 | c.Close() 552 | } 553 | u.conns = nil 554 | } 555 | return nil 556 | } 557 | 558 | //go:norace 559 | func (u *udpConn) getConn(p *poller, rAddr *net.UDPAddr) (*Conn, bool) { 560 | u.mux.RLock() 561 | addr := rAddr.String() 562 | c, ok := u.conns[addr] 563 | u.mux.RUnlock() 564 | 565 | if !ok { 566 | c = &Conn{ 567 | p: p, 568 | typ: ConnTypeUDPClientFromRead, 569 | connUDP: &udpConn{ 570 | parent: u, 571 | rAddr: rAddr, 572 | UDPConn: u.UDPConn, 573 | }, 574 | } 575 | hashAddr := u.LocalAddr().String() + addr 576 | for _, ch := range hashAddr { 577 | c.hash = 31*c.hash + int(ch) 578 | } 579 | if c.hash < 0 { 580 | c.hash = -c.hash 581 | } 582 | u.mux.Lock() 583 | u.conns[addr] = c 584 | u.mux.Unlock() 585 | } 586 | 587 | return c, ok 588 | } 589 | 590 | //go:norace 591 | func (c *Conn) SyscallConn() (syscall.RawConn, error) { 592 | if rc, ok := c.conn.(interface { 593 | SyscallConn() (syscall.RawConn, error) 594 | }); ok { 595 | return rc.SyscallConn() 596 | } 597 | return nil, ErrUnsupported 598 | } 599 | -------------------------------------------------------------------------------- /engine_std.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 lesismal. All rights reserved. 
2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build windows 6 | // +build windows 7 | 8 | package nbio 9 | 10 | import ( 11 | "net" 12 | "runtime" 13 | "strings" 14 | "time" 15 | 16 | "github.com/lesismal/nbio/logging" 17 | "github.com/lesismal/nbio/mempool" 18 | "github.com/lesismal/nbio/timer" 19 | ) 20 | 21 | // Start inits and starts pollers. 22 | // 23 | //go:norace 24 | func (g *Engine) Start() error { 25 | // Create listener pollers. 26 | udpListeners := make([]*net.UDPConn, len(g.Addrs))[0:0] 27 | switch g.Network { 28 | case NETWORK_TCP, NETWORK_TCP4, NETWORK_TCP6: 29 | for i := range g.Addrs { 30 | ln, err := newPoller(g, true, i) 31 | if err != nil { 32 | for j := 0; j < i; j++ { 33 | g.listeners[j].stop() 34 | } 35 | return err 36 | } 37 | g.Addrs[i] = ln.listener.Addr().String() 38 | g.listeners = append(g.listeners, ln) 39 | } 40 | case NETWORK_UDP, NETWORK_UDP4, NETWORK_UDP6: 41 | for i, addrStr := range g.Addrs { 42 | addr, err := net.ResolveUDPAddr(g.Network, addrStr) 43 | if err != nil { 44 | for j := 0; j < i; j++ { 45 | udpListeners[j].Close() 46 | } 47 | return err 48 | } 49 | ln, err := g.ListenUDP("udp", addr) 50 | if err != nil { 51 | for j := 0; j < i; j++ { 52 | udpListeners[j].Close() 53 | } 54 | return err 55 | } 56 | g.Addrs[i] = ln.LocalAddr().String() 57 | udpListeners = append(udpListeners, ln) 58 | } 59 | } 60 | 61 | // Create IO pollers. 62 | for i := 0; i < g.NPoller; i++ { 63 | p, err := newPoller(g, false, i) 64 | if err != nil { 65 | for j := 0; j < len(g.listeners); j++ { 66 | g.listeners[j].stop() 67 | } 68 | 69 | for j := 0; j < i; j++ { 70 | g.pollers[j].stop() 71 | } 72 | return err 73 | } 74 | g.pollers[i] = p 75 | } 76 | 77 | // Start IO pollers. 78 | for i := 0; i < g.NPoller; i++ { 79 | g.Add(1) 80 | go g.pollers[i].start() 81 | } 82 | 83 | // Start TCP/Unix listener pollers. 
84 | for _, l := range g.listeners { 85 | g.Add(1) 86 | go l.start() 87 | } 88 | 89 | // Start UDP listener pollers. 90 | for _, ln := range udpListeners { 91 | _, err := g.AddConn(ln) 92 | if err != nil { 93 | for j := 0; j < len(g.listeners); j++ { 94 | g.listeners[j].stop() 95 | } 96 | 97 | for j := 0; j < len(g.pollers); j++ { 98 | g.pollers[j].stop() 99 | } 100 | 101 | for j := 0; j < len(udpListeners); j++ { 102 | udpListeners[j].Close() 103 | } 104 | 105 | return err 106 | } 107 | } 108 | 109 | // g.Timer.Start() 110 | 111 | if len(g.Addrs) == 0 { 112 | logging.Info("NBIO Engine[%v] start with [%v eventloop, MaxOpenFiles: %v]", 113 | g.Name, 114 | g.NPoller, 115 | MaxOpenFiles, 116 | ) 117 | } else { 118 | logging.Info("NBIO Engine[%v] start with [%v eventloop], listen on: [\"%v@%v\"], MaxOpenFiles: %v", 119 | g.Name, 120 | g.NPoller, 121 | g.Network, 122 | strings.Join(g.Addrs, `", "`), 123 | MaxOpenFiles, 124 | ) 125 | } 126 | 127 | return nil 128 | } 129 | 130 | // NewEngine creates an Engine and init default configurations. 
131 | // 132 | //go:norace 133 | func NewEngine(conf Config) *Engine { 134 | cpuNum := runtime.NumCPU() 135 | if conf.Name == "" { 136 | conf.Name = "NB" 137 | } 138 | if conf.NPoller <= 0 { 139 | conf.NPoller = cpuNum 140 | } 141 | if conf.ReadBufferSize <= 0 { 142 | conf.ReadBufferSize = DefaultReadBufferSize 143 | } 144 | if conf.MaxWriteBufferSize <= 0 { 145 | conf.MaxWriteBufferSize = DefaultMaxWriteBufferSize 146 | } 147 | if conf.Listen == nil { 148 | conf.Listen = net.Listen 149 | } 150 | if conf.ListenUDP == nil { 151 | conf.ListenUDP = net.ListenUDP 152 | } 153 | if conf.BodyAllocator == nil { 154 | conf.BodyAllocator = mempool.DefaultMemPool 155 | } 156 | 157 | g := &Engine{ 158 | Config: conf, 159 | Timer: timer.New(conf.Name), 160 | listeners: make([]*poller, len(conf.Addrs))[0:0], 161 | pollers: make([]*poller, conf.NPoller), 162 | connsStd: map[*Conn]struct{}{}, 163 | } 164 | 165 | g.initHandlers() 166 | 167 | g.OnReadBufferAlloc(func(c *Conn) *[]byte { 168 | if c.ReadBuffer == nil { 169 | c.ReadBuffer = make([]byte, g.ReadBufferSize) 170 | } 171 | return &c.ReadBuffer 172 | }) 173 | 174 | return g 175 | } 176 | 177 | // DialAsync connects asynchrony to the address on the named network. 178 | // 179 | //go:norace 180 | func (engine *Engine) DialAsync(network, addr string, onConnected func(*Conn, error)) error { 181 | return engine.DialAsyncTimeout(network, addr, 0, onConnected) 182 | } 183 | 184 | // DialAsync connects asynchrony to the address on the named network with timeout. 
185 | // 186 | //go:norace 187 | func (engine *Engine) DialAsyncTimeout(network, addr string, timeout time.Duration, onConnected func(*Conn, error)) error { 188 | go func() { 189 | var err error 190 | var conn net.Conn 191 | if timeout > 0 { 192 | conn, err = net.DialTimeout(network, addr, timeout) 193 | } else { 194 | conn, err = net.Dial(network, addr) 195 | } 196 | if err != nil { 197 | onConnected(nil, err) 198 | return 199 | } 200 | nbc, err := NBConn(conn) 201 | if err != nil { 202 | onConnected(nil, err) 203 | return 204 | } 205 | engine.wgConn.Add(1) 206 | nbc, err = engine.addDialer(nbc) 207 | if err == nil { 208 | nbc.SetWriteDeadline(time.Time{}) 209 | } else { 210 | engine.wgConn.Done() 211 | } 212 | onConnected(nbc, err) 213 | }() 214 | return nil 215 | } 216 | -------------------------------------------------------------------------------- /engine_unix.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 lesismal. All rights reserved. 2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build linux || darwin || netbsd || freebsd || openbsd || dragonfly 6 | // +build linux darwin netbsd freebsd openbsd dragonfly 7 | 8 | package nbio 9 | 10 | import ( 11 | "errors" 12 | "net" 13 | "runtime" 14 | "strings" 15 | "syscall" 16 | "time" 17 | 18 | "github.com/lesismal/nbio/logging" 19 | "github.com/lesismal/nbio/mempool" 20 | "github.com/lesismal/nbio/taskpool" 21 | "github.com/lesismal/nbio/timer" 22 | ) 23 | 24 | // Start inits and starts pollers. 25 | // 26 | //go:norace 27 | func (g *Engine) Start() error { 28 | g.connsUnix = make([]*Conn, MaxOpenFiles) 29 | 30 | // Create pollers and listeners. 
31 | g.pollers = make([]*poller, g.NPoller) 32 | g.listeners = make([]*poller, len(g.Addrs))[0:0] 33 | udpListeners := make([]*net.UDPConn, len(g.Addrs))[0:0] 34 | 35 | switch g.Network { 36 | case NETWORK_UNIX, NETWORK_TCP, NETWORK_TCP4, NETWORK_TCP6: 37 | for i := range g.Addrs { 38 | ln, err := newPoller(g, true, i) 39 | if err != nil { 40 | for j := 0; j < i; j++ { 41 | g.listeners[j].stop() 42 | } 43 | return err 44 | } 45 | g.Addrs[i] = ln.listener.Addr().String() 46 | g.listeners = append(g.listeners, ln) 47 | } 48 | case NETWORK_UDP, NETWORK_UDP4, NETWORK_UDP6: 49 | for i, addrStr := range g.Addrs { 50 | addr, err := net.ResolveUDPAddr(g.Network, addrStr) 51 | if err != nil { 52 | for j := 0; j < i; j++ { 53 | udpListeners[j].Close() 54 | } 55 | return err 56 | } 57 | ln, err := g.ListenUDP("udp", addr) 58 | if err != nil { 59 | for j := 0; j < i; j++ { 60 | udpListeners[j].Close() 61 | } 62 | return err 63 | } 64 | g.Addrs[i] = ln.LocalAddr().String() 65 | udpListeners = append(udpListeners, ln) 66 | } 67 | } 68 | 69 | // Create IO pollers. 70 | for i := 0; i < g.NPoller; i++ { 71 | p, err := newPoller(g, false, i) 72 | if err != nil { 73 | for j := 0; j < len(g.listeners); j++ { 74 | g.listeners[j].stop() 75 | } 76 | 77 | for j := 0; j < i; j++ { 78 | g.pollers[j].stop() 79 | } 80 | return err 81 | } 82 | g.pollers[i] = p 83 | } 84 | 85 | // Start IO pollers. 86 | for i := 0; i < g.NPoller; i++ { 87 | g.pollers[i].ReadBuffer = make([]byte, g.ReadBufferSize) 88 | g.Add(1) 89 | go g.pollers[i].start() 90 | } 91 | 92 | // Start TCP/Unix listener pollers. 93 | for _, l := range g.listeners { 94 | g.Add(1) 95 | go l.start() 96 | } 97 | 98 | // Start UDP listener pollers. 
99 | for _, ln := range udpListeners { 100 | _, err := g.AddConn(ln) 101 | if err != nil { 102 | for j := 0; j < len(g.listeners); j++ { 103 | g.listeners[j].stop() 104 | } 105 | 106 | for j := 0; j < len(g.pollers); j++ { 107 | g.pollers[j].stop() 108 | } 109 | 110 | for j := 0; j < len(udpListeners); j++ { 111 | udpListeners[j].Close() 112 | } 113 | 114 | return err 115 | } 116 | } 117 | 118 | g.Timer.Start() 119 | g.isOneshot = (g.EpollMod == EPOLLET && g.EPOLLONESHOT == EPOLLONESHOT) 120 | 121 | if g.AsyncReadInPoller { 122 | if g.IOExecute == nil { 123 | g.ioTaskPool = taskpool.NewIO(0, 0, 0) 124 | g.IOExecute = g.ioTaskPool.Go 125 | } 126 | } 127 | 128 | if len(g.Addrs) == 0 { 129 | logging.Info("NBIO Engine[%v] start with [%v eventloop, MaxOpenFiles: %v]", 130 | g.Name, 131 | g.NPoller, 132 | MaxOpenFiles, 133 | ) 134 | } else { 135 | logging.Info("NBIO Engine[%v] start with [%v eventloop], listen on: [\"%v@%v\"], MaxOpenFiles: %v", 136 | g.Name, 137 | g.NPoller, 138 | g.Network, 139 | strings.Join(g.Addrs, `", "`), 140 | MaxOpenFiles, 141 | ) 142 | } 143 | 144 | return nil 145 | } 146 | 147 | // DialAsync connects asynchrony to the address on the named network. 148 | // 149 | //go:norace 150 | func (engine *Engine) DialAsync(network, addr string, onConnected func(*Conn, error)) error { 151 | return engine.DialAsyncTimeout(network, addr, 0, onConnected) 152 | } 153 | 154 | // DialAsync connects asynchrony to the address on the named network with timeout. 
155 | // 156 | //go:norace 157 | func (engine *Engine) DialAsyncTimeout(network, addr string, timeout time.Duration, onConnected func(*Conn, error)) error { 158 | h := func(c *Conn, err error) { 159 | if err == nil { 160 | c.SetWriteDeadline(time.Time{}) 161 | } 162 | onConnected(c, err) 163 | } 164 | domain, typ, dialaddr, raddr, connType, err := parseDomainAndType(network, addr) 165 | if err != nil { 166 | return err 167 | } 168 | fd, err := syscall.Socket(domain, typ, 0) 169 | if err != nil { 170 | return err 171 | } 172 | err = syscall.SetNonblock(fd, true) 173 | if err != nil { 174 | syscall.Close(fd) 175 | return err 176 | } 177 | err = syscall.Connect(fd, dialaddr) 178 | inprogress := false 179 | if err != nil { 180 | if errors.Is(err, syscall.EINPROGRESS) { 181 | inprogress = true 182 | } else { 183 | syscall.Close(fd) 184 | return err 185 | } 186 | } 187 | sa, _ := syscall.Getsockname(fd) 188 | c := &Conn{ 189 | fd: fd, 190 | rAddr: raddr, 191 | typ: connType, 192 | } 193 | if inprogress { 194 | c.onConnected = h 195 | } 196 | switch vt := sa.(type) { 197 | case *syscall.SockaddrInet4: 198 | switch connType { 199 | case ConnTypeTCP: 200 | c.lAddr = &net.TCPAddr{ 201 | IP: []byte{vt.Addr[0], vt.Addr[1], vt.Addr[2], vt.Addr[3]}, 202 | Port: vt.Port, 203 | } 204 | case ConnTypeUDPClientFromDial: 205 | c.lAddr = &net.TCPAddr{ 206 | IP: []byte{vt.Addr[0], vt.Addr[1], vt.Addr[2], vt.Addr[3]}, 207 | Port: vt.Port, 208 | } 209 | c.connUDP = &udpConn{ 210 | parent: c, 211 | } 212 | } 213 | case *syscall.SockaddrInet6: 214 | var iface *net.Interface 215 | iface, err = net.InterfaceByIndex(int(vt.ZoneId)) 216 | if err != nil { 217 | syscall.Close(fd) 218 | return err 219 | } 220 | switch connType { 221 | case ConnTypeTCP: 222 | c.lAddr = &net.TCPAddr{ 223 | IP: make([]byte, len(vt.Addr)), 224 | Port: vt.Port, 225 | Zone: iface.Name, 226 | } 227 | case ConnTypeUDPClientFromDial: 228 | c.lAddr = &net.UDPAddr{ 229 | IP: make([]byte, len(vt.Addr)), 230 | Port: vt.Port, 
231 | Zone: iface.Name, 232 | } 233 | c.connUDP = &udpConn{ 234 | parent: c, 235 | } 236 | } 237 | case *syscall.SockaddrUnix: 238 | c.lAddr = &net.UnixAddr{ 239 | Net: network, 240 | Name: vt.Name, 241 | } 242 | } 243 | 244 | engine.wgConn.Add(1) 245 | _, err = engine.addDialer(c) 246 | if err != nil { 247 | engine.wgConn.Done() 248 | return err 249 | } 250 | 251 | if !inprogress { 252 | engine.Async(func() { 253 | h(c, nil) 254 | }) 255 | } else if timeout > 0 { 256 | c.setDeadline(&c.wTimer, ErrDialTimeout, time.Now().Add(timeout)) 257 | } 258 | 259 | return nil 260 | } 261 | 262 | // NewEngine creates an Engine and init default configurations. 263 | // 264 | //go:norace 265 | func NewEngine(conf Config) *Engine { 266 | if conf.Name == "" { 267 | conf.Name = "NB" 268 | } 269 | if conf.NPoller <= 0 { 270 | conf.NPoller = runtime.NumCPU() / 4 271 | if conf.AsyncReadInPoller && conf.EpollMod == EPOLLET { 272 | conf.NPoller = 1 273 | } 274 | if conf.NPoller == 0 { 275 | conf.NPoller = 1 276 | } 277 | } 278 | if conf.ReadBufferSize <= 0 { 279 | conf.ReadBufferSize = DefaultReadBufferSize 280 | } 281 | if conf.MaxWriteBufferSize <= 0 { 282 | conf.MaxWriteBufferSize = DefaultMaxWriteBufferSize 283 | } 284 | if conf.MaxConnReadTimesPerEventLoop <= 0 { 285 | conf.MaxConnReadTimesPerEventLoop = DefaultMaxConnReadTimesPerEventLoop 286 | } 287 | if conf.Listen == nil { 288 | conf.Listen = net.Listen 289 | } 290 | if conf.ListenUDP == nil { 291 | conf.ListenUDP = net.ListenUDP 292 | } 293 | if conf.BodyAllocator == nil { 294 | conf.BodyAllocator = mempool.DefaultMemPool 295 | } 296 | 297 | g := &Engine{ 298 | Config: conf, 299 | Timer: timer.New(conf.Name), 300 | } 301 | 302 | g.initHandlers() 303 | 304 | return g 305 | } 306 | -------------------------------------------------------------------------------- /error.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 lesismal. All rights reserved. 
2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package nbio 6 | 7 | import ( 8 | "errors" 9 | ) 10 | 11 | var ( 12 | ErrReadTimeout = errors.New("read timeout") 13 | errReadTimeout = ErrReadTimeout 14 | 15 | ErrWriteTimeout = errors.New("write timeout") 16 | errWriteTimeout = ErrWriteTimeout 17 | 18 | ErrOverflow = errors.New("write overflow") 19 | errOverflow = ErrOverflow 20 | 21 | ErrDialTimeout = errors.New("dial timeout") 22 | 23 | ErrUnsupported = errors.New("unsupported operation") 24 | ) 25 | -------------------------------------------------------------------------------- /extension/tls/tls.go: -------------------------------------------------------------------------------- 1 | package tls 2 | 3 | // deprecated. 4 | 5 | import ( 6 | "github.com/lesismal/llib/std/crypto/tls" 7 | "github.com/lesismal/nbio" 8 | "github.com/lesismal/nbio/mempool" 9 | ) 10 | 11 | // Conn . 12 | type Conn = tls.Conn 13 | 14 | // Config . 15 | type Config = tls.Config 16 | 17 | // Dial returns a net.Conn to be added to a Engine. 18 | // 19 | //go:norace 20 | func Dial(network, addr string, config *Config) (*tls.Conn, error) { 21 | tlsConn, err := tls.Dial(network, addr, config, mempool.DefaultMemPool) 22 | if err != nil { 23 | return nil, err 24 | } 25 | 26 | return tlsConn, nil 27 | } 28 | 29 | // WrapOpen returns an opening handler of nbio.Engine. 
30 | // 31 | //go:norace 32 | func WrapOpen(tlsConfig *Config, isClient bool, h func(c *nbio.Conn, tlsConn *Conn)) func(c *nbio.Conn) { 33 | return func(c *nbio.Conn) { 34 | var tlsConn *tls.Conn 35 | sesseion := c.Session() 36 | if sesseion != nil { 37 | tlsConn = sesseion.(*tls.Conn) 38 | } 39 | if tlsConn == nil && !isClient { 40 | tlsConn = tls.NewConn(c, tlsConfig, isClient, true, mempool.DefaultMemPool) 41 | c.SetSession(tlsConn) 42 | } 43 | if h != nil { 44 | h(c, tlsConn) 45 | } 46 | } 47 | } 48 | 49 | // WrapClose returns an closing handler of nbio.Engine. 50 | // 51 | //go:norace 52 | func WrapClose(h func(c *nbio.Conn, tlsConn *Conn, err error)) func(c *nbio.Conn, err error) { 53 | return func(c *nbio.Conn, err error) { 54 | if h != nil && c != nil { 55 | if session := c.Session(); session != nil { 56 | if tlsConn, ok := session.(*Conn); ok { 57 | h(c, tlsConn, err) 58 | } 59 | } 60 | } 61 | } 62 | } 63 | 64 | // WrapData returns a data handler of nbio.Engine. 65 | // 66 | //go:norace 67 | func WrapData(h func(c *nbio.Conn, tlsConn *Conn, data []byte), args ...interface{}) func(c *nbio.Conn, data []byte) { 68 | getBuffer := func() []byte { 69 | return make([]byte, 2048) 70 | } 71 | if len(args) > 0 { 72 | if bh, ok := args[0].(func() []byte); ok { 73 | getBuffer = bh 74 | } 75 | } 76 | return func(c *nbio.Conn, data []byte) { 77 | if session := c.Session(); session != nil { 78 | if tlsConn, ok := session.(*Conn); ok { 79 | tlsConn.Append(data) 80 | buffer := getBuffer() 81 | for { 82 | n, err := tlsConn.Read(buffer) 83 | if err != nil { 84 | c.Close() 85 | return 86 | } 87 | if h != nil && n > 0 { 88 | h(c, tlsConn, buffer[:n]) 89 | } 90 | if n == 0 { 91 | return 92 | } 93 | } 94 | } 95 | } 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/lesismal/nbio 2 | 3 | go 1.16 4 | 5 | require 
github.com/lesismal/llib v1.2.2 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/lesismal/llib v1.2.2 h1:ZoVgP9J58Ju3Yue5jtj8ybWl+BKqoVmdRaN1mNwG5Gc= 2 | github.com/lesismal/llib v1.2.2/go.mod h1:70tFXXe7P1FZ02AU9l8LgSOK7d7sRrpnkUr3rd3gKSg= 3 | golang.org/x/crypto v0.0.0-20210513122933-cd7d49e622d5 h1:N6Jp/LCiEoIBX56BZSR2bepK5GtbSC2DDOYT742mMfE= 4 | golang.org/x/crypto v0.0.0-20210513122933-cd7d49e622d5/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= 5 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 6 | golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 7 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 8 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da h1:b3NXsE2LusjYGGjL5bxEVZZORm/YEFFrWFjR8eFrw/c= 9 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 10 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 11 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 12 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 13 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 14 | -------------------------------------------------------------------------------- /lmux/lmux.go: -------------------------------------------------------------------------------- 1 | package lmux 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | "sync/atomic" 7 | "time" 8 | 9 | "github.com/lesismal/nbio/logging" 10 | ) 11 | 12 | type event struct { 13 | err error 14 | conn net.Conn 15 | } 16 | 17 | type listenerAB struct { 18 | a, b *ChanListener 19 
| } 20 | 21 | // New returns a ListenerMux. 22 | // 23 | //go:norace 24 | func New(maxOnlineA int) *ListenerMux { 25 | return &ListenerMux{ 26 | listeners: map[net.Listener]listenerAB{}, 27 | chClose: make(chan struct{}), 28 | maxOnlineA: int32(maxOnlineA), 29 | } 30 | } 31 | 32 | // ListenerMux manages listeners and handle the connection dispatching logic. 33 | type ListenerMux struct { 34 | shutdown bool 35 | listeners map[net.Listener]listenerAB 36 | chClose chan struct{} 37 | onlineA int32 38 | maxOnlineA int32 39 | } 40 | 41 | // Mux creates and returns ChanListener A and B: 42 | // If the online num of A is less than ListenerMux. maxOnlineA, the new connection will be dispatched to A; 43 | // Else the new connection will be dispatched to B. 44 | // 45 | //go:norace 46 | func (lm *ListenerMux) Mux(l net.Listener) (*ChanListener, *ChanListener) { 47 | if l == nil || lm == nil { 48 | return nil, nil 49 | } 50 | if lm.listeners == nil { 51 | lm.listeners = map[net.Listener]listenerAB{} 52 | } 53 | ab := listenerAB{ 54 | a: &ChanListener{ 55 | addr: l.Addr(), 56 | chClose: lm.chClose, 57 | chEvent: make(chan event, 1024*64), 58 | decrease: lm.DecreaseOnlineA, 59 | }, 60 | b: &ChanListener{ 61 | addr: l.Addr(), 62 | chClose: lm.chClose, 63 | chEvent: make(chan event, 1024*64), 64 | }, 65 | } 66 | lm.listeners[l] = ab 67 | return ab.a, ab.b 68 | } 69 | 70 | // Start starts to accept and dispatch the connections to ChanListener A or B. 
71 | // 72 | //go:norace 73 | func (lm *ListenerMux) Start() { 74 | if lm == nil { 75 | return 76 | } 77 | lm.shutdown = false 78 | for k, v := range lm.listeners { 79 | go func(l net.Listener, listenerA *ChanListener, listenerB *ChanListener) { 80 | for !lm.shutdown { 81 | c, err := l.Accept() 82 | if err != nil { 83 | var ne net.Error 84 | if ok := errors.As(err, &ne); ok && ne.Timeout() { 85 | logging.Error("Accept failed: timeout error, retrying...") 86 | time.Sleep(time.Second / 20) 87 | } else { 88 | if !lm.shutdown { 89 | logging.Error("Accept failed: %v, exit...", err) 90 | } 91 | listenerA.chEvent <- event{err: err, conn: c} 92 | listenerB.chEvent <- event{err: err, conn: c} 93 | // return 94 | } 95 | continue 96 | } 97 | if atomic.AddInt32(&lm.onlineA, 1) <= lm.maxOnlineA { 98 | listenerA.chEvent <- event{err: nil, conn: c} 99 | } else { 100 | atomic.AddInt32(&lm.onlineA, -1) 101 | listenerB.chEvent <- event{err: nil, conn: c} 102 | } 103 | } 104 | }(k, v.a, v.b) 105 | } 106 | } 107 | 108 | // Stop stops all the listeners. 109 | // 110 | //go:norace 111 | func (lm *ListenerMux) Stop() { 112 | if lm == nil { 113 | return 114 | } 115 | lm.shutdown = true 116 | for l, ab := range lm.listeners { 117 | l.Close() 118 | ab.a.Close() 119 | ab.b.Close() 120 | } 121 | close(lm.chClose) 122 | } 123 | 124 | // DecreaseOnlineA decreases the online num of ChanListener A. 125 | // 126 | //go:norace 127 | func (lm *ListenerMux) DecreaseOnlineA() { 128 | atomic.AddInt32(&lm.onlineA, -1) 129 | } 130 | 131 | // ChanListener . 132 | type ChanListener struct { 133 | addr net.Addr 134 | chEvent chan event 135 | chClose chan struct{} 136 | decrease func() 137 | } 138 | 139 | // Accept accepts a connection. 
140 | // 141 | //go:norace 142 | func (l *ChanListener) Accept() (net.Conn, error) { 143 | select { 144 | case e := <-l.chEvent: 145 | return e.conn, e.err 146 | case <-l.chClose: 147 | return nil, net.ErrClosed 148 | } 149 | } 150 | 151 | // Close does nothing but implementing net.Conn.Close. 152 | // User should call ListenerMux.Close to close it automatically. 153 | // 154 | //go:norace 155 | func (l *ChanListener) Close() error { 156 | return nil 157 | } 158 | 159 | // Addr returns the listener's network address. 160 | // 161 | //go:norace 162 | func (l *ChanListener) Addr() net.Addr { 163 | return l.addr 164 | } 165 | 166 | // Decrease decreases the online num if it's A. 167 | // 168 | //go:norace 169 | func (l *ChanListener) Decrease() { 170 | if l.decrease != nil { 171 | l.decrease() 172 | } 173 | } 174 | -------------------------------------------------------------------------------- /lmux/lmux_test.go: -------------------------------------------------------------------------------- 1 | package lmux 2 | 3 | import ( 4 | "net" 5 | "sync" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestListenerMux(t *testing.T) { 11 | maxOnlineA := 3 12 | totalConn := 5 13 | network := "tcp" 14 | addr1 := "localhost:8001" 15 | addr2 := "localhost:8002" 16 | lm := New(maxOnlineA) 17 | 18 | listen := func(addr string) net.Listener { 19 | l, err := net.Listen(network, addr) 20 | if err != nil { 21 | t.Fatal(err) 22 | } 23 | return l 24 | } 25 | l1 := listen(addr1) 26 | listenerA, listenerB := lm.Mux(l1) 27 | l2 := listen(addr2) 28 | listenerC, listenerD := lm.Mux(l2) 29 | lm.Start() 30 | 31 | time.Sleep(time.Second / 5) 32 | wg := sync.WaitGroup{} 33 | chA := make(chan net.Conn, totalConn) 34 | chB := make(chan net.Conn, totalConn) 35 | chC := make(chan net.Conn, totalConn) 36 | chD := make(chan net.Conn, totalConn) 37 | chErr := make(chan error, totalConn) 38 | conns := make([]net.Conn, totalConn)[:0] 39 | 40 | accept := func(ln net.Listener, chConn chan net.Conn) { 41 | for { 
42 | conn, err := ln.Accept() 43 | if err != nil { 44 | chErr <- err 45 | break 46 | } 47 | chConn <- conn 48 | } 49 | } 50 | go accept(listenerA, chA) 51 | go accept(listenerB, chB) 52 | go accept(listenerC, chC) 53 | go accept(listenerD, chD) 54 | 55 | dialN := func(n int, addr string) { 56 | wg.Add(1) 57 | go func() { 58 | defer wg.Done() 59 | for i := 0; i < n; i++ { 60 | conn, err := net.Dial(network, addr) 61 | if err != nil { 62 | chErr <- err 63 | break 64 | } 65 | conns = append(conns, conn) 66 | } 67 | }() 68 | } 69 | closeConns := func() { 70 | for _, v := range conns { 71 | v.Close() 72 | } 73 | } 74 | dialN(totalConn, addr1) 75 | dialN(totalConn, addr2) 76 | 77 | wg.Wait() 78 | time.Sleep(time.Second / 5) 79 | 80 | if len(chErr) != 0 { 81 | t.Fatalf("len(chA) != maxOnlineA, want %v, got %v", 0, len(chErr)) 82 | } 83 | 84 | if len(chA)+len(chC) != maxOnlineA { 85 | t.Fatalf("len(chA)+len(chC) != maxOnlineA, want %v, got %v[A=%v, C=%v]", maxOnlineA, len(chA)+len(chC), len(chA), len(chC)) 86 | } 87 | 88 | if len(chB)+len(chD) != totalConn*2-maxOnlineA { 89 | t.Fatalf("len(chB)+len(chD) != maxOnlineA, want %v, got %v[B=%v, D=%v]", totalConn*2-maxOnlineA, len(chB)+len(chD), len(chB), len(chD)) 90 | } 91 | 92 | closeConns() 93 | 94 | clean := func(ln *ChanListener, chConn chan net.Conn) { 95 | n := len(chConn) 96 | for i := 0; i < n; i++ { 97 | <-chConn 98 | ln.Decrease() 99 | } 100 | } 101 | clean(listenerA, chA) 102 | clean(listenerB, chB) 103 | clean(listenerC, chC) 104 | clean(listenerD, chD) 105 | 106 | conns = conns[:0] 107 | dialN(totalConn, addr1) 108 | dialN(totalConn, addr2) 109 | defer closeConns() 110 | 111 | wg.Wait() 112 | time.Sleep(time.Second / 5) 113 | 114 | if len(chA)+len(chC) != maxOnlineA { 115 | t.Fatalf("len(chA)+len(chC) != maxOnlineA, want %v, got %v[A=%v, C=%v]", maxOnlineA, len(chA)+len(chC), len(chA), len(chC)) 116 | } 117 | 118 | if len(chB)+len(chD) != totalConn*2-maxOnlineA { 119 | t.Fatalf("len(chB)+len(chD) != maxOnlineA, 
want %v, got %v[B=%v, D=%v]", totalConn*2-maxOnlineA, len(chB)+len(chD), len(chB), len(chD)) 120 | } 121 | 122 | lm.Stop() 123 | } 124 | -------------------------------------------------------------------------------- /logging/log.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 lesismal. All rights reserved. 2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package logging 6 | 7 | import ( 8 | "fmt" 9 | "io" 10 | "os" 11 | "time" 12 | ) 13 | 14 | var ( 15 | // TimeFormat is used to format time parameters. 16 | TimeFormat = "2006/01/02 15:04:05.000" 17 | 18 | // Output is used to receive log output. 19 | Output io.Writer = os.Stdout 20 | 21 | // DefaultLogger is the default logger and is used by arpc. 22 | DefaultLogger Logger = &logger{level: LevelInfo} 23 | ) 24 | 25 | const ( 26 | // LevelAll enables all logs. 27 | LevelAll = iota 28 | // LevelDebug logs are usually disabled in production. 29 | LevelDebug 30 | // LevelInfo is the default logging priority. 31 | LevelInfo 32 | // LevelWarn . 33 | LevelWarn 34 | // LevelError . 35 | LevelError 36 | // LevelNone disables all logs. 37 | LevelNone 38 | ) 39 | 40 | // Logger defines log interface. 41 | type Logger interface { 42 | Debug(format string, v ...interface{}) 43 | Info(format string, v ...interface{}) 44 | Warn(format string, v ...interface{}) 45 | Error(format string, v ...interface{}) 46 | } 47 | 48 | // SetLogger sets default logger. 49 | // 50 | //go:norace 51 | func SetLogger(l Logger) { 52 | DefaultLogger = l 53 | } 54 | 55 | // SetLevel sets default logger's priority. 56 | // 57 | //go:norace 58 | func SetLevel(lvl int) { 59 | if l, ok := DefaultLogger.(interface { 60 | SetLevel(lvl int) 61 | }); ok { 62 | l.SetLevel(lvl) 63 | } 64 | } 65 | 66 | // logger implements Logger and is used in arpc by default. 
67 | type logger struct { 68 | level int 69 | } 70 | 71 | // SetLevel sets logs priority. 72 | // 73 | //go:norace 74 | func (l *logger) SetLevel(lvl int) { 75 | switch lvl { 76 | case LevelAll, LevelDebug, LevelInfo, LevelWarn, LevelError, LevelNone: 77 | l.level = lvl 78 | default: 79 | fmt.Fprintf(Output, "invalid log level: %v", lvl) 80 | } 81 | } 82 | 83 | // Debug uses fmt.Printf to log a message at LevelDebug. 84 | // 85 | //go:norace 86 | func (l *logger) Debug(format string, v ...interface{}) { 87 | if LevelDebug >= l.level { 88 | fmt.Fprintf(Output, time.Now().Format(TimeFormat)+" [DBG] "+format+"\n", v...) 89 | } 90 | } 91 | 92 | // Info uses fmt.Printf to log a message at LevelInfo. 93 | // 94 | //go:norace 95 | func (l *logger) Info(format string, v ...interface{}) { 96 | if LevelInfo >= l.level { 97 | fmt.Fprintf(Output, time.Now().Format(TimeFormat)+" [INF] "+format+"\n", v...) 98 | } 99 | } 100 | 101 | // Warn uses fmt.Printf to log a message at LevelWarn. 102 | // 103 | //go:norace 104 | func (l *logger) Warn(format string, v ...interface{}) { 105 | if LevelWarn >= l.level { 106 | fmt.Fprintf(Output, time.Now().Format(TimeFormat)+" [WRN] "+format+"\n", v...) 107 | } 108 | } 109 | 110 | // Error uses fmt.Printf to log a message at LevelError. 111 | // 112 | //go:norace 113 | func (l *logger) Error(format string, v ...interface{}) { 114 | if LevelError >= l.level { 115 | fmt.Fprintf(Output, time.Now().Format(TimeFormat)+" [ERR] "+format+"\n", v...) 116 | } 117 | } 118 | 119 | // Debug uses DefaultLogger to log a message at LevelDebug. 120 | // 121 | //go:norace 122 | func Debug(format string, v ...interface{}) { 123 | if DefaultLogger != nil { 124 | DefaultLogger.Debug(format, v...) 125 | } 126 | } 127 | 128 | // Info uses DefaultLogger to log a message at LevelInfo. 129 | // 130 | //go:norace 131 | func Info(format string, v ...interface{}) { 132 | if DefaultLogger != nil { 133 | DefaultLogger.Info(format, v...) 
134 | } 135 | } 136 | 137 | // Warn uses DefaultLogger to log a message at LevelWarn. 138 | // 139 | //go:norace 140 | func Warn(format string, v ...interface{}) { 141 | if DefaultLogger != nil { 142 | DefaultLogger.Warn(format, v...) 143 | } 144 | } 145 | 146 | // Error uses DefaultLogger to log a message at LevelError. 147 | // 148 | //go:norace 149 | func Error(format string, v ...interface{}) { 150 | if DefaultLogger != nil { 151 | DefaultLogger.Error(format, v...) 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /logging/log_test.go: -------------------------------------------------------------------------------- 1 | package logging 2 | 3 | import "testing" 4 | 5 | func TestSetLogger(t *testing.T) { 6 | l := &logger{level: LevelDebug} 7 | SetLogger(l) 8 | } 9 | 10 | func TestSetLevel(t *testing.T) { 11 | SetLevel(LevelAll) 12 | func() { 13 | defer func() { 14 | err := recover() 15 | if err != nil { 16 | t.Errorf("recorver returned err: %s", err) 17 | } 18 | }() 19 | SetLevel(1000) 20 | }() 21 | } 22 | 23 | func Test_logger_SetLevel(t *testing.T) { 24 | l := &logger{level: LevelDebug} 25 | l.SetLevel(LevelAll) 26 | } 27 | 28 | func Test_logger_Debug(t *testing.T) { 29 | l := &logger{level: LevelDebug} 30 | l.Debug("logger debug test") 31 | } 32 | 33 | func Test_logger_Info(t *testing.T) { 34 | l := &logger{level: LevelDebug} 35 | l.Info("logger info test") 36 | } 37 | 38 | func Test_logger_Warn(t *testing.T) { 39 | l := &logger{level: LevelDebug} 40 | l.Warn("logger warn test") 41 | } 42 | 43 | func Test_logger_Error(t *testing.T) { 44 | l := &logger{level: LevelDebug} 45 | l.Error("logger error test") 46 | } 47 | 48 | func Test_Debug(t *testing.T) { 49 | Debug("log.Debug") 50 | } 51 | 52 | func Test_Info(t *testing.T) { 53 | Info("log.Info") 54 | } 55 | 56 | func Test_Warn(t *testing.T) { 57 | Warn("log.Warn") 58 | } 59 | 60 | func Test_Error(t *testing.T) { 61 | Error("log.Error") 62 | } 63 | 
-------------------------------------------------------------------------------- /mempool/aligned_allocator.go: -------------------------------------------------------------------------------- 1 | package mempool 2 | 3 | import ( 4 | "sync" 5 | "unsafe" 6 | ) 7 | 8 | var ( 9 | alignedPools [alignedPoolBucketNum]sync.Pool 10 | alignedIndexes [maxAlignedBufferSize + 1]byte 11 | ) 12 | 13 | const ( 14 | minAlignedBufferSizeBits = 5 15 | maxAlignedBufferSizeBits = 15 16 | minAlignedBufferSize = 1 << minAlignedBufferSizeBits // 32 17 | minAlignedBufferSizeMask = minAlignedBufferSize - 1 // 31 18 | maxAlignedBufferSize = 1 << maxAlignedBufferSizeBits // 32k 19 | alignedPoolBucketNum = maxAlignedBufferSizeBits - minAlignedBufferSizeBits + 1 // 12 20 | ) 21 | 22 | //go:norace 23 | func init() { 24 | var poolSizes [alignedPoolBucketNum]int 25 | for i := range alignedPools { 26 | size := 1 << (i + minAlignedBufferSizeBits) 27 | poolSizes[i] = size 28 | alignedPools[i].New = func() interface{} { 29 | b := make([]byte, size) 30 | return &b 31 | } 32 | } 33 | 34 | getPoolBySize := func(size int) byte { 35 | for i, n := range poolSizes { 36 | if size <= n { 37 | return byte(i) 38 | } 39 | } 40 | return 0xFF 41 | } 42 | 43 | for i := range alignedIndexes { 44 | alignedIndexes[i] = getPoolBySize(i) 45 | } 46 | } 47 | 48 | // NewAligned . 49 | // 50 | //go:norace 51 | func NewAligned() Allocator { 52 | amp := &AlignedAllocator{ 53 | debugger: &debugger{}, 54 | } 55 | return amp 56 | } 57 | 58 | // AlignedAllocator . 59 | type AlignedAllocator struct { 60 | *debugger 61 | } 62 | 63 | // Malloc . 
64 | // 65 | //go:norace 66 | func (amp *AlignedAllocator) Malloc(size int) *[]byte { 67 | if size < 0 { 68 | return nil 69 | } 70 | var ret []byte 71 | if size <= maxAlignedBufferSize { 72 | idx := alignedIndexes[size] 73 | ret = (*(alignedPools[idx].Get().(*[]byte)))[:size] 74 | } else { 75 | ret = make([]byte, size) 76 | } 77 | amp.incrMalloc(&ret) 78 | return &ret 79 | } 80 | 81 | // Realloc . 82 | // 83 | //go:norace 84 | func (amp *AlignedAllocator) Realloc(pbuf *[]byte, size int) *[]byte { 85 | if size <= cap(*pbuf) { 86 | *pbuf = (*pbuf)[:size] 87 | return pbuf 88 | } 89 | newBufPtr := amp.Malloc(size) 90 | copy(*newBufPtr, *pbuf) 91 | amp.Free(pbuf) 92 | return newBufPtr 93 | } 94 | 95 | // Append . 96 | // 97 | //go:norace 98 | func (amp *AlignedAllocator) Append(pbuf *[]byte, more ...byte) *[]byte { 99 | if cap(*pbuf)-len(*pbuf) >= len(more) { 100 | *pbuf = append(*pbuf, more...) 101 | return pbuf 102 | } 103 | newBufPtr := amp.Malloc(len(*pbuf) + len(more)) 104 | copy(*newBufPtr, *pbuf) 105 | copy((*newBufPtr)[len(*pbuf):], more) 106 | amp.Free(pbuf) 107 | return newBufPtr 108 | } 109 | 110 | // AppendString . 111 | // 112 | //go:norace 113 | func (amp *AlignedAllocator) AppendString(pbuf *[]byte, s string) *[]byte { 114 | x := (*[2]uintptr)(unsafe.Pointer(&s)) 115 | h := [3]uintptr{x[0], x[1], x[1]} 116 | more := *(*[]byte)(unsafe.Pointer(&h)) 117 | return amp.Append(pbuf, more...) 118 | } 119 | 120 | // Free . 121 | // 122 | //go:norace 123 | func (amp *AlignedAllocator) Free(pbuf *[]byte) { 124 | size := cap(*pbuf) 125 | if (size&minAlignedBufferSizeMask) != 0 || size > maxAlignedBufferSize { 126 | return 127 | } 128 | amp.incrFree(pbuf) 129 | idx := alignedIndexes[size] 130 | alignedPools[idx].Put(pbuf) 131 | } 132 | -------------------------------------------------------------------------------- /mempool/allocator.go: -------------------------------------------------------------------------------- 1 | package mempool 2 | 3 | // DefaultMemPool . 
4 | var DefaultMemPool = New(1024, 1024*1024*1024) 5 | 6 | type Allocator interface { 7 | Malloc(size int) *[]byte 8 | Realloc(buf *[]byte, size int) *[]byte // deprecated. 9 | Append(buf *[]byte, more ...byte) *[]byte 10 | AppendString(buf *[]byte, more string) *[]byte 11 | Free(buf *[]byte) 12 | } 13 | 14 | type DebugAllocator interface { 15 | Allocator 16 | String() string 17 | SetDebug(bool) 18 | } 19 | 20 | //go:norace 21 | func Malloc(size int) *[]byte { 22 | return DefaultMemPool.Malloc(size) 23 | } 24 | 25 | //go:norace 26 | func Realloc(pbuf *[]byte, size int) *[]byte { 27 | return DefaultMemPool.Realloc(pbuf, size) 28 | } 29 | 30 | //go:norace 31 | func Append(pbuf *[]byte, more ...byte) *[]byte { 32 | return DefaultMemPool.Append(pbuf, more...) 33 | } 34 | 35 | //go:norace 36 | func AppendString(pbuf *[]byte, more string) *[]byte { 37 | return DefaultMemPool.AppendString(pbuf, more) 38 | } 39 | 40 | //go:norace 41 | func Free(pbuf *[]byte) { 42 | DefaultMemPool.Free(pbuf) 43 | } 44 | 45 | // func Init(bufSize, freeSize int) { 46 | // DefaultMemPool = New(bufSize, freeSize) 47 | // } 48 | -------------------------------------------------------------------------------- /mempool/debugger.go: -------------------------------------------------------------------------------- 1 | package mempool 2 | 3 | import ( 4 | "encoding/json" 5 | "sync" 6 | "sync/atomic" 7 | ) 8 | 9 | type sizeMap struct { 10 | MallocCount int64 `json:"MallocCount"` 11 | FreeCount int64 `json:"FreeCount"` 12 | NeedFree int64 `json:"NeedFree"` 13 | } 14 | 15 | type debugger struct { 16 | mux sync.Mutex 17 | on bool 18 | MallocCount int64 `json:"MallocCount"` 19 | FreeCount int64 `json:"FreeCount"` 20 | NeedFree int64 `json:"NeedFree"` 21 | SizeMap map[int]*sizeMap `json:"SizeMap"` 22 | } 23 | 24 | //go:norace 25 | func (d *debugger) SetDebug(dbg bool) { 26 | d.on = dbg 27 | } 28 | 29 | //go:norace 30 | func (d *debugger) incrMalloc(pbuf *[]byte) { 31 | if d.on { 32 | d.incrMallocSlow(pbuf) 
33 | } 34 | } 35 | 36 | //go:norace 37 | func (d *debugger) incrMallocSlow(pbuf *[]byte) { 38 | atomic.AddInt64(&d.MallocCount, 1) 39 | atomic.AddInt64(&d.NeedFree, 1) 40 | size := cap(*pbuf) 41 | d.mux.Lock() 42 | defer d.mux.Unlock() 43 | if d.SizeMap == nil { 44 | d.SizeMap = map[int]*sizeMap{} 45 | } 46 | if v, ok := d.SizeMap[size]; ok { 47 | v.MallocCount++ 48 | v.NeedFree++ 49 | } else { 50 | d.SizeMap[size] = &sizeMap{ 51 | MallocCount: 1, 52 | NeedFree: 1, 53 | } 54 | } 55 | } 56 | 57 | //go:norace 58 | func (d *debugger) incrFree(pbuf *[]byte) { 59 | if d.on { 60 | d.incrFreeSlow(pbuf) 61 | } 62 | } 63 | 64 | //go:norace 65 | func (d *debugger) incrFreeSlow(pbuf *[]byte) { 66 | atomic.AddInt64(&d.FreeCount, 1) 67 | atomic.AddInt64(&d.NeedFree, -1) 68 | size := cap(*pbuf) 69 | d.mux.Lock() 70 | defer d.mux.Unlock() 71 | if v, ok := d.SizeMap[size]; ok { 72 | v.FreeCount++ 73 | v.NeedFree-- 74 | } else { 75 | d.SizeMap[size] = &sizeMap{ 76 | MallocCount: 1, 77 | NeedFree: -1, 78 | } 79 | } 80 | } 81 | 82 | //go:norace 83 | func (d *debugger) String() string { 84 | if d.on { 85 | b, err := json.Marshal(d) 86 | if err == nil { 87 | return string(b) 88 | } 89 | } 90 | return "" 91 | } 92 | -------------------------------------------------------------------------------- /mempool/mempool.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 lesismal. All rights reserved. 2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package mempool 6 | 7 | import ( 8 | "sync" 9 | ) 10 | 11 | // MemPool . 12 | type MemPool struct { 13 | *debugger 14 | bufSize int 15 | freeSize int 16 | pool *sync.Pool 17 | } 18 | 19 | // New . 
20 | func New(bufSize, freeSize int) Allocator { 21 | if bufSize <= 0 { 22 | bufSize = 64 23 | } 24 | if freeSize <= 0 { 25 | freeSize = 64 * 1024 26 | } 27 | if freeSize < bufSize { 28 | freeSize = bufSize 29 | } 30 | 31 | mp := &MemPool{ 32 | debugger: &debugger{}, 33 | bufSize: bufSize, 34 | freeSize: freeSize, 35 | pool: &sync.Pool{}, 36 | // Debug: true, 37 | } 38 | mp.pool.New = func() interface{} { 39 | buf := make([]byte, bufSize) 40 | return &buf 41 | } 42 | return mp 43 | } 44 | 45 | // Malloc . 46 | func (mp *MemPool) Malloc(size int) *[]byte { 47 | var ret []byte 48 | if size > mp.freeSize { 49 | ret = make([]byte, size) 50 | mp.incrMalloc(&ret) 51 | return &ret 52 | } 53 | pbuf := mp.pool.Get().(*[]byte) 54 | n := cap(*pbuf) 55 | if n < size { 56 | *pbuf = append((*pbuf)[:n], make([]byte, size-n)...) 57 | } 58 | (*pbuf) = (*pbuf)[:size] 59 | mp.incrMalloc(pbuf) 60 | return pbuf 61 | } 62 | 63 | // Realloc . 64 | func (mp *MemPool) Realloc(pbuf *[]byte, size int) *[]byte { 65 | if size <= cap(*pbuf) { 66 | *pbuf = (*pbuf)[:size] 67 | return pbuf 68 | } 69 | 70 | if cap(*pbuf) < mp.freeSize { 71 | newBufPtr := mp.pool.Get().(*[]byte) 72 | n := cap(*newBufPtr) 73 | if n < size { 74 | *newBufPtr = append((*newBufPtr)[:n], make([]byte, size-n)...) 75 | } 76 | *newBufPtr = (*newBufPtr)[:size] 77 | copy(*newBufPtr, *pbuf) 78 | mp.Free(pbuf) 79 | return newBufPtr 80 | } 81 | *pbuf = append((*pbuf)[:cap(*pbuf)], make([]byte, size-cap(*pbuf))...)[:size] 82 | return pbuf 83 | } 84 | 85 | // Append . 86 | func (mp *MemPool) Append(pbuf *[]byte, more ...byte) *[]byte { 87 | *pbuf = append(*pbuf, more...) 88 | return pbuf 89 | } 90 | 91 | // AppendString . 92 | func (mp *MemPool) AppendString(pbuf *[]byte, more string) *[]byte { 93 | *pbuf = append(*pbuf, more...) 94 | return pbuf 95 | } 96 | 97 | // Free . 
98 | func (mp *MemPool) Free(pbuf *[]byte) { 99 | if pbuf != nil && cap(*pbuf) > 0 { 100 | mp.incrFree(pbuf) 101 | if cap(*pbuf) > mp.freeSize { 102 | return 103 | } 104 | mp.pool.Put(pbuf) 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /mempool/mempool_test.go: -------------------------------------------------------------------------------- 1 | package mempool 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestMemPool(t *testing.T) { 8 | pool := New(1024*1024*1024, 1024*1024*1024) 9 | for i := 0; i < 1024*1024; i++ { 10 | pbuf := pool.Malloc(i) 11 | if len(*pbuf) != i { 12 | t.Fatalf("invalid len: %v != %v", len(*pbuf), i) 13 | } 14 | pool.Free(pbuf) 15 | } 16 | for i := 1024 * 1024; i < 1024*1024*1024; i += 1024 * 1024 { 17 | pbuf := pool.Malloc(i) 18 | if len(*pbuf) != i { 19 | t.Fatalf("invalid len: %v != %v", len(*pbuf), i) 20 | } 21 | pool.Free(pbuf) 22 | } 23 | 24 | pbuf := pool.Malloc(0) 25 | for i := 1; i < 1024*1024; i++ { 26 | pbuf = pool.Realloc(pbuf, i) 27 | if len(*pbuf) != i { 28 | t.Fatalf("invalid len: %v != %v", len(*pbuf), i) 29 | } 30 | } 31 | pool.Free(pbuf) 32 | } 33 | 34 | func TestAlignedMemPool(t *testing.T) { 35 | pool := NewAligned() 36 | b := pool.Malloc(32769) 37 | pool.Free(b) 38 | tmpBuf := make([]byte, 60001) 39 | pool.Free(&tmpBuf) 40 | for i := 0; i < 1024*64+1024; i += 1 { 41 | pbuf := pool.Malloc(i) 42 | if len(*pbuf) != i { 43 | t.Fatalf("invalid length: %v != %v", len(*pbuf), i) 44 | } 45 | pool.Free(pbuf) 46 | } 47 | for i := minAlignedBufferSizeBits; i < maxAlignedBufferSizeBits; i++ { 48 | size := 1 << i 49 | pbuf := pool.Malloc(size) 50 | if len(*pbuf) != size || cap(*pbuf) > size*2 { 51 | t.Fatalf("invalid len or cap: %v, %v %v, %v ", i, len(*pbuf), cap(*pbuf), size) 52 | } 53 | pbuf = pool.Malloc(size + 1) 54 | if i != maxAlignedBufferSizeBits { 55 | if len(*pbuf) != size+1 || cap(*pbuf) != size*2 || cap(*pbuf) > (size+1)*2 { 56 | t.Fatalf("invalid len or cap: %v, 
%v %v, %v ", i, len(*pbuf), cap(*pbuf), size) 57 | } 58 | } else { 59 | if len(*pbuf) != size+1 || cap(*pbuf) != size+1 { 60 | t.Fatalf("invalid len or cap: %v, %v %v, %v ", i, len(*pbuf), cap(*pbuf), size) 61 | } 62 | } 63 | pool.Free(pbuf) 64 | } 65 | for i := -10; i < 0; i++ { 66 | pbuf := pool.Malloc(i) 67 | if pbuf != nil { 68 | t.Fatalf("invalid malloc, should be nil but got: %v, %v", len(*pbuf), cap(*pbuf)) 69 | } 70 | } 71 | for i := 1 << maxAlignedBufferSizeBits; i < 1< 50 { 152 | break 153 | } 154 | n, err := fmt.Fprintf(stackWriter, "\t%d [file: %s] [func: %s] [line: %d]\n", i-1, file, runtime.FuncForPC(pc).Name(), line) 155 | if n > 0 { 156 | nwrite += n 157 | } 158 | if err != nil { 159 | break 160 | } 161 | } 162 | 163 | buf := stackBuf[:nwrite] 164 | stack := *(*string)(unsafe.Pointer(&buf)) 165 | if ptr, ok := stackMap[stack]; ok { 166 | return ptr2StackString(ptr), ptr 167 | } 168 | stack = string(buf) 169 | ptr := *(*[2]uintptr)(unsafe.Pointer(&stack)) 170 | ptrCopy := [2]uintptr{ptr[0], ptr[1]} 171 | stackMap[stack] = ptrCopy 172 | return stack, ptrCopy 173 | } 174 | 175 | //go:norace 176 | func ptr2StackString(ptr [2]uintptr) string { 177 | if ptr[0] == 0 && ptr[1] == 0 { 178 | return "nil" 179 | } 180 | return *((*string)(unsafe.Pointer(&ptr))) 181 | } 182 | 183 | // func bytesToStr(b []byte) string { 184 | // return *(*string)(unsafe.Pointer(&b)) 185 | // } 186 | 187 | // func strToBytes(s string) []byte { 188 | // x := (*[2]uintptr)(unsafe.Pointer(&s)) 189 | // h := [3]uintptr{x[0], x[1], x[1]} 190 | // return *(*[]byte)(unsafe.Pointer(&h)) 191 | // } 192 | 193 | //go:norace 194 | func printStack(info string, preStackPtr [2]uintptr) { 195 | var ( 196 | currStack, _ = getStackAndPtr() 197 | preStack = ptr2StackString(preStackPtr) 198 | ) 199 | fmt.Printf(` 200 | ------------------------------------------- 201 | [mempool trace] %v -> 202 | 203 | previous stack: 204 | %v 205 | 206 | ------------------------------------------- 207 | 208 | current 
stack : 209 | %v 210 | ------------------------------------------- 211 | 212 | `, info, preStack, currStack) 213 | // os.Exit(-1) 214 | } 215 | 216 | //go:norace 217 | func bytesPointer(pbuf *[]byte) uintptr { 218 | return (uintptr)(unsafe.Pointer(&((*pbuf)[:1][0]))) 219 | } 220 | 221 | // func stringPointer(s *string) uintptr { 222 | // ptr := (*uintptr)(unsafe.Pointer(s)) 223 | // return (uintptr)(unsafe.Pointer(&ptr)) 224 | // } 225 | -------------------------------------------------------------------------------- /nbhttp/body.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 lesismal. All rights reserved. 2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package nbhttp 6 | 7 | import ( 8 | "io" 9 | "sync" 10 | ) 11 | 12 | var ( 13 | emptyBodyReader = BodyReader{} 14 | bodyReaderPool = sync.Pool{ 15 | New: func() interface{} { 16 | return &BodyReader{} 17 | }, 18 | } 19 | ) 20 | 21 | // BodyReader implements io.ReadCloser and is to be used as HTTP body. 22 | type BodyReader struct { 23 | index int // first buffer read index 24 | left int // num of byte left 25 | buffers []*[]byte // buffers that storage HTTP body 26 | engine *Engine // allocator that manages buffers 27 | closed bool 28 | } 29 | 30 | // Read reads body bytes to p, returns the num of bytes read and error. 
31 | // 32 | //go:norace 33 | func (br *BodyReader) Read(p []byte) (int, error) { 34 | need := len(p) 35 | if br.left <= 0 { 36 | return 0, io.EOF 37 | } 38 | ncopy := 0 39 | for ncopy < need && br.left > 0 { 40 | pbuf := br.buffers[0] 41 | nc := copy(p[ncopy:], (*pbuf)[br.index:]) 42 | if nc+br.index >= len(*pbuf) { 43 | br.engine.BodyAllocator.Free(pbuf) 44 | br.buffers[0] = nil 45 | br.buffers = br.buffers[1:] 46 | br.index = 0 47 | } else { 48 | br.index += nc 49 | } 50 | ncopy += nc 51 | br.left -= nc 52 | } 53 | return ncopy, nil 54 | } 55 | 56 | // Close frees buffers and resets itself to empty value. 57 | // 58 | //go:norace 59 | func (br *BodyReader) Close() error { 60 | if br.closed { 61 | return nil 62 | } 63 | br.closed = true 64 | if br.buffers != nil { 65 | for _, b := range br.buffers { 66 | br.engine.BodyAllocator.Free(b) 67 | } 68 | } 69 | // *br = emptyBodyReader 70 | // bodyReaderPool.Put(br) 71 | return nil 72 | } 73 | 74 | // Index returns current head buffer's reading index. 75 | // 76 | //go:norace 77 | func (br *BodyReader) Index() int { 78 | return br.index 79 | } 80 | 81 | // Left returns how many bytes are left for reading. 82 | // 83 | //go:norace 84 | func (br *BodyReader) Left() int { 85 | return br.left 86 | } 87 | 88 | // Buffers returns the underlayer buffers that store the HTTP Body. 89 | // 90 | //go:norace 91 | func (br *BodyReader) Buffers() []*[]byte { 92 | return br.buffers 93 | } 94 | 95 | // RawBodyBuffers returns a reference of BodyReader's current buffers. 96 | // The buffers returned will be closed(released automatically when closed) 97 | // HTTP Handler is called, users should not free the buffers and should 98 | // not hold it any longer after the HTTP Handler is called. 
99 | // 100 | //go:norace 101 | func (br *BodyReader) RawBodyBuffers() [][]byte { 102 | buffers := make([][]byte, len(br.buffers)) 103 | for i, pbuf := range br.buffers { 104 | if i == 0 { 105 | buffers[i] = (*pbuf)[br.index:] 106 | } else { 107 | buffers[i] = *pbuf 108 | } 109 | } 110 | return buffers 111 | } 112 | 113 | // Engine returns Engine that creates this HTTP Body. 114 | // 115 | //go:norace 116 | func (br *BodyReader) Engine() *Engine { 117 | return br.engine 118 | } 119 | 120 | // append appends data to buffers. 121 | // 122 | //go:norace 123 | func (br *BodyReader) append(data []byte) error { 124 | if len(data) == 0 { 125 | return nil 126 | } 127 | 128 | if br.engine.MaxHTTPBodySize > 0 && len(data)+br.left > br.engine.MaxHTTPBodySize { 129 | return ErrTooLong 130 | } 131 | 132 | br.left += (len(data)) 133 | if len(br.buffers) == 0 { 134 | pbuf := br.engine.BodyAllocator.Malloc(len(data)) 135 | copy(*pbuf, data) 136 | br.buffers = append(br.buffers, pbuf) 137 | } else { 138 | i := len(br.buffers) - 1 139 | pbuf := br.buffers[i] 140 | l := len(*pbuf) 141 | bLeft := cap(*pbuf) - len(*pbuf) 142 | if bLeft > 0 { 143 | if bLeft > len(data) { 144 | *pbuf = (*pbuf)[:l+len(data)] 145 | } else { 146 | *pbuf = (*pbuf)[:cap(*pbuf)] 147 | } 148 | nc := copy((*pbuf)[l:], data) 149 | data = data[nc:] 150 | br.buffers[i] = pbuf 151 | } 152 | if len(data) > 0 { 153 | pbuf = br.engine.BodyAllocator.Malloc(len(data)) 154 | copy(*pbuf, data) 155 | br.buffers = append(br.buffers, pbuf) 156 | } 157 | } 158 | return nil 159 | } 160 | 161 | // NewBodyReader creates a BodyReader. 
162 | // 163 | //go:norace 164 | func NewBodyReader(engine *Engine) *BodyReader { 165 | br := bodyReaderPool.Get().(*BodyReader) 166 | br.engine = engine 167 | return br 168 | } 169 | -------------------------------------------------------------------------------- /nbhttp/body_test.go: -------------------------------------------------------------------------------- 1 | package nbhttp 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "io" 7 | "testing" 8 | 9 | "github.com/lesismal/nbio/mempool" 10 | ) 11 | 12 | func TestBodyReaderPool(t *testing.T) { 13 | br := bodyReaderPool.Get().(*BodyReader) 14 | buf := make([]byte, 10) 15 | pbuf := &buf 16 | br.buffers = append(br.buffers, pbuf) 17 | *br = emptyBodyReader 18 | bodyReaderPool.Put(br) 19 | 20 | for i := 0; i < 1000; i++ { 21 | br2 := bodyReaderPool.Get().(*BodyReader) 22 | if br2.buffers != nil { 23 | t.Fatal("len>0") 24 | } 25 | buf = make([]byte, 10) 26 | pbuf = &buf 27 | br2.buffers = append(br.buffers, pbuf) 28 | *br2 = emptyBodyReader 29 | bodyReaderPool.Put(br) 30 | } 31 | } 32 | 33 | func TestBodyReader(t *testing.T) { 34 | engine := NewEngine(Config{ 35 | BodyAllocator: mempool.NewAligned(), 36 | }) 37 | var ( 38 | b0 []byte 39 | b1 = make([]byte, 2049) 40 | b2 = make([]byte, 1132) 41 | b3 = make([]byte, 11111) 42 | ) 43 | rand.Read(b1) 44 | rand.Read(b2) 45 | rand.Read(b3) 46 | 47 | allBytes := append(b0, b1...) 48 | allBytes = append(allBytes, b2...) 49 | allBytes = append(allBytes, b3...) 
50 | 51 | newBR := func() *BodyReader { 52 | br := NewBodyReader(engine) 53 | br.append(b1) 54 | br.append(b2) 55 | br.append(b3) 56 | return br 57 | } 58 | 59 | br1 := newBR() 60 | body1, err := io.ReadAll(br1) 61 | if err != nil { 62 | t.Fatalf("io.ReadAll(br1) failed: %v", err) 63 | } 64 | if !bytes.Equal(allBytes, body1) { 65 | t.Fatalf("!bytes.Equal(allBytes, body1)") 66 | } 67 | br1.Close() 68 | 69 | br2 := newBR() 70 | body2 := make([]byte, len(allBytes)) 71 | for i := range body2 { 72 | _, err := br2.Read(body2[i : i+1]) 73 | if err != nil { 74 | t.Fatalf("br2.Readbody2[%d:%d] failed: %v", i, i+1, err) 75 | } 76 | } 77 | if !bytes.Equal(allBytes, body2) { 78 | t.Fatalf("!bytes.Equal(allBytes, body2)") 79 | } 80 | br2.Close() 81 | } 82 | -------------------------------------------------------------------------------- /nbhttp/client_conn.go: -------------------------------------------------------------------------------- 1 | package nbhttp 2 | 3 | import ( 4 | "io" 5 | "net" 6 | "net/http" 7 | "net/url" 8 | "runtime" 9 | "strings" 10 | "sync" 11 | "time" 12 | "unsafe" 13 | 14 | "github.com/lesismal/llib/std/crypto/tls" 15 | "github.com/lesismal/nbio" 16 | "github.com/lesismal/nbio/logging" 17 | "github.com/lesismal/nbio/mempool" 18 | ) 19 | 20 | type resHandler struct { 21 | c net.Conn 22 | t time.Time 23 | h func(res *http.Response, conn net.Conn, err error) 24 | } 25 | 26 | // ClientConn . 27 | type ClientConn struct { 28 | mux sync.Mutex 29 | conn net.Conn 30 | handlers []resHandler 31 | 32 | closed bool 33 | 34 | onClose func() 35 | 36 | Engine *Engine 37 | 38 | Jar http.CookieJar 39 | 40 | Timeout time.Duration 41 | 42 | IdleConnTimeout time.Duration 43 | 44 | TLSClientConfig *tls.Config 45 | 46 | Dial func(network, addr string) (net.Conn, error) 47 | 48 | Proxy func(*http.Request) (*url.URL, error) 49 | 50 | CheckRedirect func(req *http.Request, via []*http.Request) error 51 | } 52 | 53 | // Reset resets itself as new created. 
54 | // 55 | //go:norace 56 | func (c *ClientConn) Reset() { 57 | c.mux.Lock() 58 | if c.closed { 59 | c.conn = nil 60 | c.handlers = nil 61 | c.closed = false 62 | } 63 | c.mux.Unlock() 64 | } 65 | 66 | // OnClose registers a callback for closing. 67 | // 68 | //go:norace 69 | func (c *ClientConn) OnClose(h func()) { 70 | c.onClose = h 71 | } 72 | 73 | // Close closes underlayer connection with EOF. 74 | // 75 | //go:norace 76 | func (c *ClientConn) Close() { 77 | c.CloseWithError(io.EOF) 78 | } 79 | 80 | // CloseWithError closes underlayer connection with error. 81 | // 82 | //go:norace 83 | func (c *ClientConn) CloseWithError(err error) { 84 | c.mux.Lock() 85 | defer c.mux.Unlock() 86 | if !c.closed { 87 | c.closed = true 88 | c.closeWithErrorWithoutLock(err) 89 | } 90 | } 91 | 92 | //go:norace 93 | func (c *ClientConn) closeWithErrorWithoutLock(err error) { 94 | if err == nil { 95 | err = io.EOF 96 | } 97 | for _, h := range c.handlers { 98 | h.h(nil, c.conn, err) 99 | } 100 | c.handlers = nil 101 | if c.conn != nil { 102 | nbc, ok := c.conn.(*nbio.Conn) 103 | if !ok { 104 | if tlsConn, ok2 := c.conn.(*tls.Conn); ok2 { 105 | nbc, ok = tlsConn.Conn().(*nbio.Conn) 106 | } 107 | } 108 | if ok { 109 | key, _ := conn2Array(nbc) 110 | c.Engine.mux.Lock() 111 | delete(c.Engine.dialerConns, key) 112 | c.Engine.mux.Unlock() 113 | } 114 | c.conn.Close() 115 | c.conn = nil 116 | } 117 | if c.onClose != nil { 118 | c.onClose() 119 | } 120 | } 121 | 122 | //go:norace 123 | func (c *ClientConn) onResponse(res *http.Response, err error) { 124 | c.mux.Lock() 125 | defer c.mux.Unlock() 126 | 127 | if !c.closed && len(c.handlers) > 0 { 128 | head := c.handlers[0] 129 | head.h(res, c.conn, err) 130 | 131 | c.handlers = c.handlers[1:] 132 | if len(c.handlers) > 0 { 133 | next := c.handlers[0] 134 | timeout := c.Timeout 135 | deadline := next.t.Add(timeout) 136 | if timeout > 0 { 137 | if time.Now().After(deadline) { 138 | c.closeWithErrorWithoutLock(ErrClientTimeout) 139 | } 140 | 
} else { 141 | c.conn.SetReadDeadline(deadline) 142 | } 143 | } else { 144 | if c.IdleConnTimeout > 0 { 145 | c.conn.SetReadDeadline(time.Now().Add(c.IdleConnTimeout)) 146 | } else { 147 | c.conn.SetReadDeadline(time.Time{}) 148 | } 149 | } 150 | if len(c.handlers) == 0 { 151 | c.handlers = nil 152 | } 153 | } 154 | } 155 | 156 | // Do sends an HTTP request and returns an HTTP response. 157 | // Notice: 158 | // 1. It's blocking when Dial to the server; 159 | // 2. It's non-blocking for waiting for the response; 160 | // 3. It calls the handler when the response is received 161 | // or other errors occur, such as timeout. 162 | // 163 | //go:norace 164 | func (c *ClientConn) Do(req *http.Request, handler func(res *http.Response, conn net.Conn, err error)) { 165 | c.mux.Lock() 166 | defer func() { 167 | c.mux.Unlock() 168 | if err := recover(); err != nil { 169 | const size = 64 << 10 170 | buf := make([]byte, size) 171 | buf = buf[:runtime.Stack(buf, false)] 172 | logging.Error("ClientConn Do failed: %v\n%v\n", err, *(*string)(unsafe.Pointer(&buf))) 173 | } 174 | }() 175 | 176 | if c.closed { 177 | handler(nil, nil, ErrClientClosed) 178 | return 179 | } 180 | 181 | var engine = c.Engine 182 | var confTimeout = c.Timeout 183 | 184 | c.handlers = append(c.handlers, resHandler{c: c.conn, t: time.Now(), h: handler}) 185 | 186 | var deadline time.Time 187 | if confTimeout > 0 { 188 | deadline = time.Now().Add(confTimeout) 189 | } 190 | 191 | sendRequest := func() { 192 | if c.Engine.WriteTimeout > 0 { 193 | c.conn.SetWriteDeadline(time.Now().Add(c.Engine.WriteTimeout)) 194 | } 195 | err := req.Write(c.conn) 196 | if err != nil { 197 | c.closeWithErrorWithoutLock(err) 198 | return 199 | } 200 | } 201 | 202 | if c.conn != nil { 203 | if confTimeout > 0 && len(c.handlers) == 1 { 204 | c.conn.SetReadDeadline(deadline) 205 | } 206 | sendRequest() 207 | } else { 208 | var timeout time.Duration 209 | if confTimeout > 0 { 210 | timeout = time.Until(deadline) 211 | if timeout <= 
0 { 212 | c.closeWithErrorWithoutLock(ErrClientTimeout) 213 | return 214 | } 215 | } 216 | 217 | strs := strings.Split(req.URL.Host, ":") 218 | host := strs[0] 219 | port := req.URL.Scheme 220 | if len(strs) >= 2 { 221 | port = strs[1] 222 | } 223 | addr := host + ":" + port 224 | 225 | var dialer = c.Dial 226 | var netDial netDialerFunc 227 | if confTimeout <= 0 { 228 | if dialer == nil { 229 | dialer = net.Dial 230 | } 231 | netDial = func(network, addr string) (net.Conn, error) { 232 | return dialer(network, addr) 233 | } 234 | } else { 235 | if dialer == nil { 236 | dialer = func(network, addr string) (net.Conn, error) { 237 | return net.DialTimeout(network, addr, timeout) 238 | } 239 | } 240 | netDial = func(network, addr string) (net.Conn, error) { 241 | conn, err := dialer(network, addr) 242 | if err == nil { 243 | conn.SetReadDeadline(deadline) 244 | } 245 | return conn, err 246 | } 247 | } 248 | 249 | if c.Proxy != nil { 250 | proxyURL, err := c.Proxy(req) 251 | if err != nil { 252 | c.closeWithErrorWithoutLock(err) 253 | return 254 | } 255 | if proxyURL != nil { 256 | dialer, err := proxyFromURL(proxyURL, netDial) 257 | if err != nil { 258 | c.closeWithErrorWithoutLock(err) 259 | return 260 | } 261 | netDial = dialer.Dial 262 | } 263 | } 264 | 265 | netConn, err := netDial(defaultNetwork, addr) 266 | if err != nil { 267 | c.closeWithErrorWithoutLock(err) 268 | return 269 | } 270 | 271 | switch req.URL.Scheme { 272 | case "http": 273 | var nbc *nbio.Conn 274 | nbc, err = nbio.NBConn(netConn) 275 | if err != nil { 276 | c.closeWithErrorWithoutLock(err) 277 | return 278 | } 279 | 280 | key, _ := conn2Array(nbc) 281 | engine.mux.Lock() 282 | engine.dialerConns[key] = struct{}{} 283 | engine.mux.Unlock() 284 | 285 | c.conn = nbc 286 | processor := NewClientProcessor(c, c.onResponse) 287 | parser := NewParser(nbc, engine, processor, true, nbc.Execute) 288 | parser.OnClose(func(p *Parser, err error) { 289 | c.CloseWithError(err) 290 | }) 291 | 
nbc.SetSession(parser) 292 | 293 | nbc.OnData(engine.DataHandler) 294 | engine.AddConn(nbc) 295 | case "https": 296 | tlsConfig := c.TLSClientConfig 297 | if tlsConfig == nil { 298 | tlsConfig = &tls.Config{} 299 | } else { 300 | tlsConfig = tlsConfig.Clone() 301 | } 302 | if tlsConfig.ServerName == "" { 303 | tlsConfig.ServerName = host 304 | } 305 | tlsConn := tls.NewConn(netConn, tlsConfig, true, false, mempool.DefaultMemPool) 306 | err = tlsConn.Handshake() 307 | if err != nil { 308 | c.closeWithErrorWithoutLock(err) 309 | return 310 | } 311 | if !tlsConfig.InsecureSkipVerify { 312 | if err := tlsConn.VerifyHostname(tlsConfig.ServerName); err != nil { 313 | c.closeWithErrorWithoutLock(err) 314 | return 315 | } 316 | } 317 | 318 | nbc, err := nbio.NBConn(tlsConn.Conn()) 319 | if err != nil { 320 | c.closeWithErrorWithoutLock(err) 321 | return 322 | } 323 | 324 | key, err := conn2Array(nbc) 325 | if err != nil { 326 | logging.Error("add dialer conn failed: %v", err) 327 | c.closeWithErrorWithoutLock(err) 328 | return 329 | } 330 | engine.mux.Lock() 331 | engine.dialerConns[key] = struct{}{} 332 | engine.mux.Unlock() 333 | 334 | isNonblock := true 335 | tlsConn.ResetConn(nbc, isNonblock) 336 | 337 | nbhttpConn := &Conn{Conn: tlsConn} 338 | c.conn = nbhttpConn 339 | processor := NewClientProcessor(c, c.onResponse) 340 | parser := NewParser(nbhttpConn, engine, processor, true, nbc.Execute) 341 | parser.Conn = nbhttpConn 342 | parser.Engine = engine 343 | parser.OnClose(func(p *Parser, err error) { 344 | c.CloseWithError(err) 345 | }) 346 | nbc.SetSession(parser) 347 | 348 | nbc.OnData(engine.TLSDataHandler) 349 | _, err = engine.AddConn(nbc) 350 | if err != nil { 351 | c.closeWithErrorWithoutLock(err) 352 | return 353 | } 354 | default: 355 | c.closeWithErrorWithoutLock(ErrClientUnsupportedSchema) 356 | return 357 | } 358 | 359 | sendRequest() 360 | } 361 | } 362 | -------------------------------------------------------------------------------- /nbhttp/convert.go: 
-------------------------------------------------------------------------------- 1 | package nbhttp 2 | 3 | import ( 4 | "crypto/tls" 5 | "encoding/binary" 6 | "fmt" 7 | "net" 8 | "unsafe" 9 | 10 | ltls "github.com/lesismal/llib/std/crypto/tls" 11 | "github.com/lesismal/nbio" 12 | ) 13 | 14 | const ( 15 | uintptrSize = int(unsafe.Sizeof(uintptr(0))) 16 | connValueSize = uintptrSize + 1 17 | ) 18 | 19 | const ( 20 | connTypNONE byte = 0 21 | connTypNBIO byte = 1 22 | connTypTCP byte = 2 23 | connTypUNIX byte = 3 24 | connTypTLS byte = 4 25 | connTypLTLS byte = 5 26 | ) 27 | 28 | // We can use this array-value as map key to reduce gc cost. 29 | // Ref: https://github.com/lesismal/nbio/pull/304#issuecomment-1583880587 30 | type connValue [connValueSize]byte 31 | 32 | // Convert net.Conn to array value. 33 | // 34 | //go:norace 35 | func conn2Array(conn net.Conn) (connValue, error) { 36 | var p uintptr 37 | var b connValue 38 | switch vt := conn.(type) { 39 | case *nbio.Conn: 40 | p = uintptr(unsafe.Pointer(vt)) 41 | b[uintptrSize] = connTypNBIO 42 | case *net.TCPConn: 43 | p = uintptr(unsafe.Pointer(vt)) 44 | b[uintptrSize] = connTypTCP 45 | case *net.UnixConn: 46 | p = uintptr(unsafe.Pointer(vt)) 47 | b[uintptrSize] = connTypUNIX 48 | case *tls.Conn: 49 | p = uintptr(unsafe.Pointer(vt)) 50 | b[uintptrSize] = connTypTLS 51 | case *ltls.Conn: 52 | p = uintptr(unsafe.Pointer(vt)) 53 | b[uintptrSize] = connTypLTLS 54 | default: 55 | return b, fmt.Errorf("invalid conn type: %v", vt) 56 | } 57 | switch uintptrSize { 58 | case 4: 59 | binary.LittleEndian.PutUint32(b[:uintptrSize], uint32(p)) 60 | case 8: 61 | binary.LittleEndian.PutUint64(b[:uintptrSize], uint64(p)) 62 | default: 63 | return b, fmt.Errorf("unsupported platform: invalid uintptr size %v", uintptrSize) 64 | } 65 | return b, nil 66 | } 67 | 68 | // Convert array value to net.Conn. 
69 | // 70 | //go:norace 71 | func array2Conn(b connValue) (net.Conn, error) { 72 | var p uintptr 73 | switch uintptrSize { 74 | case 4: 75 | p = uintptr(binary.LittleEndian.Uint32(b[:uintptrSize])) 76 | case 8: 77 | p = uintptr(binary.LittleEndian.Uint64(b[:uintptrSize])) 78 | default: 79 | return nil, fmt.Errorf("unsupported platform: invalid uintptr size %v", uintptrSize) 80 | } 81 | 82 | switch b[uintptrSize] { 83 | case connTypNBIO: 84 | conn := *((**nbio.Conn)(unsafe.Pointer(&p))) 85 | return conn, nil 86 | case connTypTCP: 87 | conn := *((**net.TCPConn)(unsafe.Pointer(&p))) 88 | return conn, nil 89 | case connTypUNIX: 90 | conn := *((**net.UnixConn)(unsafe.Pointer(&p))) 91 | return conn, nil 92 | case connTypTLS: 93 | conn := *((**tls.Conn)(unsafe.Pointer(&p))) 94 | return conn, nil 95 | case connTypLTLS: 96 | conn := *((**ltls.Conn)(unsafe.Pointer(&p))) 97 | return conn, nil 98 | default: 99 | } 100 | 101 | return nil, fmt.Errorf("invalid conn type: %v", b[uintptrSize]) 102 | } 103 | -------------------------------------------------------------------------------- /nbhttp/convert_test.go: -------------------------------------------------------------------------------- 1 | package nbhttp 2 | 3 | import ( 4 | "crypto/tls" 5 | "net" 6 | "testing" 7 | 8 | ltls "github.com/lesismal/llib/std/crypto/tls" 9 | "github.com/lesismal/nbio" 10 | ) 11 | 12 | func TestConn2String(t *testing.T) { 13 | var nbc = &nbio.Conn{} 14 | snbc, err := conn2Array(nbc) 15 | if err != nil { 16 | t.Fatal(err) 17 | } 18 | nbc2, err := array2Conn(snbc) 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | if nbc2 != nbc { 23 | t.Fatalf("nbc2 != nbc") 24 | } 25 | 26 | var tcp = &net.TCPConn{} 27 | stcp, err := conn2Array(tcp) 28 | if err != nil { 29 | t.Fatal(err) 30 | } 31 | tcp2, err := array2Conn(stcp) 32 | if err != nil { 33 | t.Fatal(err) 34 | } 35 | if tcp2 != tcp { 36 | t.Fatalf("tcp2 != tcp") 37 | } 38 | 39 | var unix = &net.UnixConn{} 40 | sunix, err := conn2Array(unix) 41 | if err != 
nil { 42 | t.Fatal(err) 43 | } 44 | unix2, err := array2Conn(sunix) 45 | if err != nil { 46 | t.Fatal(err) 47 | } 48 | if unix2 != unix { 49 | t.Fatalf("unix2 != unix") 50 | } 51 | 52 | var tls = &tls.Conn{} 53 | stls, err := conn2Array(tls) 54 | if err != nil { 55 | t.Fatal(err) 56 | } 57 | tls2, err := array2Conn(stls) 58 | if err != nil { 59 | t.Fatal(err) 60 | } 61 | if tls2 != tls { 62 | t.Fatalf("tls2 != tls") 63 | } 64 | 65 | var ltls = <ls.Conn{} 66 | sltls, err := conn2Array(ltls) 67 | if err != nil { 68 | t.Fatal(err) 69 | } 70 | ltls2, err := array2Conn(sltls) 71 | if err != nil { 72 | t.Fatal(err) 73 | } 74 | if ltls2 != ltls { 75 | t.Fatalf("ltls2 != ltls") 76 | } 77 | 78 | var udp = &net.UDPConn{} 79 | _, err = conn2Array(udp) 80 | if err == nil { 81 | t.Fatal("err is nil") 82 | } 83 | _, err = array2Conn(connValue{'a', 'a', 'a'}) 84 | if err == nil { 85 | t.Fatal("err is nil") 86 | } 87 | _, err = array2Conn(connValue{}) 88 | if err == nil { 89 | t.Fatal("err is nil") 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /nbhttp/error.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 lesismal. All rights reserved. 2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package nbhttp 6 | 7 | import ( 8 | "errors" 9 | ) 10 | 11 | var ( 12 | // ErrInvalidCRLF . 13 | ErrInvalidCRLF = errors.New("invalid cr/lf at the end of line") 14 | 15 | // ErrInvalidHTTPVersion . 16 | ErrInvalidHTTPVersion = errors.New("invalid HTTP version") 17 | 18 | // ErrInvalidHTTPStatusCode . 19 | ErrInvalidHTTPStatusCode = errors.New("invalid HTTP status code") 20 | // ErrInvalidHTTPStatus . 21 | ErrInvalidHTTPStatus = errors.New("invalid HTTP status") 22 | 23 | // ErrInvalidMethod . 24 | ErrInvalidMethod = errors.New("invalid HTTP method") 25 | 26 | // ErrInvalidRequestURI . 
27 | ErrInvalidRequestURI = errors.New("invalid URL") 28 | 29 | // ErrInvalidHost . 30 | ErrInvalidHost = errors.New("invalid host") 31 | 32 | // ErrInvalidPort . 33 | ErrInvalidPort = errors.New("invalid port") 34 | 35 | // ErrInvalidPath . 36 | ErrInvalidPath = errors.New("invalid path") 37 | 38 | // ErrInvalidQueryString . 39 | ErrInvalidQueryString = errors.New("invalid query string") 40 | 41 | // ErrInvalidFragment . 42 | ErrInvalidFragment = errors.New("invalid fragment") 43 | 44 | // ErrCRExpected . 45 | ErrCRExpected = errors.New("CR character expected") 46 | 47 | // ErrLFExpected . 48 | ErrLFExpected = errors.New("LF character expected") 49 | 50 | // ErrInvalidCharInHeader . 51 | ErrInvalidCharInHeader = errors.New("invalid character in header") 52 | 53 | // ErrUnexpectedContentLength . 54 | ErrUnexpectedContentLength = errors.New("unexpected content-length header") 55 | 56 | // ErrInvalidContentLength . 57 | ErrInvalidContentLength = errors.New("invalid ContentLength") 58 | 59 | // ErrInvalidChunkSize . 60 | ErrInvalidChunkSize = errors.New("invalid chunk size") 61 | 62 | // ErrTrailerExpected . 63 | ErrTrailerExpected = errors.New("trailer expected") 64 | 65 | // ErrTooLong . 66 | ErrTooLong = errors.New("invalid http message: too long") 67 | ) 68 | 69 | var ( 70 | // ErrInvalidH2SM . 71 | ErrInvalidH2SM = errors.New("invalid http2 SM characters") 72 | 73 | // ErrInvalidH2HeaderR . 74 | ErrInvalidH2HeaderR = errors.New("invalid http2 SM characters") 75 | ) 76 | 77 | var ( 78 | // ErrNilConn . 79 | ErrNilConn = errors.New("nil Conn") 80 | ) 81 | 82 | var ( 83 | // ErrClientUnsupportedSchema . 84 | ErrClientUnsupportedSchema = errors.New("unsupported schema") 85 | 86 | // ErrClientTimeout . 87 | ErrClientTimeout = errors.New("timeout") 88 | 89 | // ErrClientClosed . 90 | ErrClientClosed = errors.New("http client closed") 91 | ) 92 | 93 | var ( 94 | // ErrServiceOverload . 
// TestServerParserContentLength feeds the server-side parser complete
// requests framed by Content-Length, including messy-but-tolerated
// whitespace around header names and values and empty header values.
func TestServerParserContentLength(t *testing.T) {
	data := []byte("POST /echo HTTP/1.1\r\nEmpty:\r\n Empty2:\r\nHost : localhost:8080 \r\nConnection: close \r\n Accept-Encoding : gzip , deflate ,br \r\n\r\n")
	err := testParser(t, false, data)
	if err != nil {
		t.Fatalf("test failed: %v", err)
	}

	data = []byte("POST /echo HTTP/1.1\r\nHost: localhost:8080\r\n Connection: close \r\nContent-Length : 0\r\nAccept-Encoding : gzip \r\n\r\n")
	err = testParser(t, false, data)
	if err != nil {
		t.Fatalf("test failed: %v", err)
	}

	data = []byte("POST /echo HTTP/1.1\r\nHost: localhost:8080\r\n Connection: close \r\nContent-Length : 5\r\nAccept-Encoding : gzip \r\n\r\nhello")
	err = testParser(t, false, data)
	if err != nil {
		t.Fatalf("test failed: %v", err)
	}
}

// TestServerParserChunks feeds the server-side parser chunk-encoded
// requests, both with a non-empty chunk and with only the terminating
// zero-size chunk.
func TestServerParserChunks(t *testing.T) {
	data := []byte("POST / HTTP/1.1\r\nHost: localhost:1235\r\n User-Agent: Go-http-client/1.1\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n4 \r\nbody\r\n0 \r\n\r\n")
	err := testParser(t, false, data)
	if err != nil {
		t.Fatalf("test failed: %v", err)
	}

	data = []byte("POST / HTTP/1.1\r\nHost: localhost:1235\r\n User-Agent: Go-http-client/1.1\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n0\r\n\r\n")
	err = testParser(t, false, data)
	if err != nil {
		t.Fatalf("test failed: %v", err)
	}
}

// TestServerParserTrailer exercises chunked requests that declare and then
// send trailer headers (Md5, Size) after the last chunk.
func TestServerParserTrailer(t *testing.T) {
	data := []byte("POST / HTTP/1.1\r\nHost : localhost:1235\r\n User-Agent : Go-http-client/1.1 \r\nTransfer-Encoding: chunked\r\nTrailer: Md5,Size\r\nAccept-Encoding: gzip \r\n\r\n4\r\nbody\r\n0\r\n Md5 : 841a2d689ad86bd1611447453c22c6fc \r\n Size : 4 \r\n\r\n")
	err := testParser(t, false, data)
	if err != nil {
		t.Fatalf("test failed: %v", err)
	}
	data = []byte("POST / HTTP/1.1\r\nHost: localhost:1235\r\nUser-Agent: Go-http-client/1.1\r\nTransfer-Encoding: chunked\r\nTrailer: Md5,Size\r\nAccept-Encoding: gzip \r\n\r\n0\r\nMd5 : 841a2d689ad86bd1611447453c22c6fc \r\n Size: 4 \r\n\r\n")
	err = testParser(t, false, data)
	if err != nil {
		t.Fatalf("test failed: %v", err)
	}
}

// TestClientParserContentLength mirrors the Content-Length cases above for
// the client-side (response) parser.
func TestClientParserContentLength(t *testing.T) {
	data := []byte("HTTP/1.1 200 OK\r\nHost: localhost:8080\r\n Connection: close \r\n Accept-Encoding : gzip \r\n\r\n")
	err := testParser(t, true, data)
	if err != nil {
		t.Fatalf("test failed: %v", err)
	}

	data = []byte("HTTP/1.1 200 OK\r\nHost: localhost:8080\r\n Connection: close \r\n Content-Length : 0\r\nAccept-Encoding : gzip \r\n\r\n")
	err = testParser(t, true, data)
	if err != nil {
		t.Fatalf("test failed: %v", err)
	}

	data = []byte("HTTP/1.1 200 OK\r\nHost: localhost:8080\r\n Connection: close \r\n Content-Length : 5\r\nAccept-Encoding : gzip \r\n\r\nhello")
	err = testParser(t, true, data)
	if err != nil {
		t.Fatalf("test failed: %v", err)
	}
}

// TestClientParserChunks mirrors the chunked cases above for the
// client-side (response) parser.
func TestClientParserChunks(t *testing.T) {
	data := []byte("HTTP/1.1 200 OK\r\nHost: localhost:1235\r\n User-Agent: Go-http-client/1.1\r\n Transfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n4\r\nbody\r\n0\r\n\r\n")
	err := testParser(t, true, data)
	if err != nil {
		t.Fatalf("test failed: %v", err)
	}
	data = []byte("HTTP/1.1 200 OK\r\nHost: localhost:1235\r\nUser-Agent: Go-http-client/1.1\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n0\r\n\r\n")
	err = testParser(t, true, data)
	if err != nil {
		t.Fatalf("test failed: %v", err)
	}
}

// TestClientParserTrailer mirrors the trailer cases above for the
// client-side (response) parser.
func TestClientParserTrailer(t *testing.T) {
	data := []byte("HTTP/1.1 200 OK\r\nHost: localhost:1235\r\n User-Agent: Go-http-client/1.1\r\n Transfer-Encoding: chunked\r\nTrailer: Md5,Size\r\nAccept-Encoding: gzip\r\n\r\n4\r\nbody\r\n0\r\nMd5: 841a2d689ad86bd1611447453c22c6fc\r\nSize: 4\r\n\r\n")
	err := testParser(t, true, data)
	if err != nil {
		t.Fatalf("test failed: %v", err)
	}
	data = []byte("HTTP/1.1 200 OK\r\nHost: localhost:1235\r\nUser-Agent: Go-http-client/1.1\r\nTransfer-Encoding : chunked\r\nTrailer: Md5,Size\r\nAccept-Encoding: gzip\r\n\r\n0\r\nMd5: 841a2d689ad86bd1611447453c22c6fc\r\nSize: 4\r\n\r\n")
	err = testParser(t, true, data)
	if err != nil {
		t.Fatalf("test failed: %v", err)
	}
}

// testParser drives a Parser over data three ways: (1) the whole buffer at
// once, (2) one byte at a time, then (3) 10000 iterations of random-sized
// fragments, counting completed messages via the handler/processor.
//
// NOTE(review): data is doubled before step (3), so each loop iteration is
// expected to complete exactly 2 messages — hence the (i+1)*2 check.
func testParser(t *testing.T, isClient bool, data []byte) error {
	parser := newParser(isClient)
	defer func() {
		if parser.Conn != nil {
			parser.Conn.Close()
		}
	}()
	// (1) Whole buffer in a single Parse call.
	err := parser.Parse(data)
	if err != nil {
		t.Fatal(err)
	}

	// (2) Byte-at-a-time; each slice is copied so the parser may retain it.
	for i := 0; i < len(data)-1; i++ {
		err = parser.Parse(append([]byte{}, data[i:i+1]...))
		if err != nil {
			t.Fatal(err)
		}
	}
	err = parser.Parse(append([]byte{}, data[len(data)-1:]...))
	if err != nil {
		t.Fatal(err)
	}

	nRequest := 0
	data = append(data, data...)

	mux := &http.ServeMux{}
	mux.HandleFunc("/", func(w http.ResponseWriter, request *http.Request) {
		nRequest++
	})
	conn := newConn()
	defer conn.Close()
	processor := NewServerProcessor()
	if isClient {
		// Client mode counts completed responses instead of requests.
		processor = NewClientProcessor(nil, func(*http.Response, error) {
			nRequest++
		})
	}
	engine := NewEngine(Config{
		Handler: mux,
	})
	parser = NewParser(conn, engine, processor, isClient, nil)
	parser.Engine = engine
	tBegin := time.Now()
	loop := 10000
	// (3) Random fragmentation; reads records the split points for the
	// failure message.
	for i := 0; i < loop; i++ {
		tmp := data
		reads := [][]byte{}
		for len(tmp) > 0 {
			nRead := rand.Intn(len(tmp)) + 1
			readBuf := append([]byte{}, tmp[:nRead]...)
			reads = append(reads, readBuf)
			tmp = tmp[nRead:]
			err = parser.Parse(readBuf)
			if err != nil {
				t.Fatalf("nRead: %v, numOne: %v, reads: %v, error: %v", len(data)-len(tmp), len(data), reads, err)
			}

		}
		if nRequest != (i+1)*2 {
			return fmt.Errorf("nRequest: %v, %v", i, nRequest)
		}
	}
	tUsed := time.Since(tBegin)
	fmt.Printf("%v loops, %v s used, %v ns/op, %v req/s\n", loop, tUsed.Seconds(), tUsed.Nanoseconds()/int64(loop), float64(loop)/tUsed.Seconds())

	return nil
}

// newParser builds a Parser wired to a fresh Engine and a throwaway TCP
// connection, either client- or server-side.
func newParser(isClient bool) *Parser {
	mux := &http.ServeMux{}
	engine := NewEngine(Config{
		Handler: mux,
	})
	conn := newConn()
	if isClient {
		processor := NewClientProcessor(nil, func(*http.Response, error) {})
		parser := NewParser(conn, engine, processor, isClient, nil)
		parser.Engine = engine
		return parser
	}
	processor := NewServerProcessor()
	parser := NewParser(conn, engine, processor, isClient, nil)
	parser.Conn = conn
	return parser
}

// newConn dials a loopback TCP connection to a locally listened port,
// probing ports 8000..8999 until one can be listened on.
//
// NOTE(review): if all 1000 ports are busy this returns nil and callers
// will panic on use — acceptable for a test helper, but worth confirming.
func newConn() net.Conn {
	var conn net.Conn
	for i := 0; i < 1000; i++ {
		addr := fmt.Sprintf("127.0.0.1:%d", 8000+i)
		ln, err := net.Listen("tcp", addr)
		if err != nil {
			continue
		}
		go func() {
			defer ln.Close()
			ln.Accept()
		}()
		conn, err = net.Dial("tcp", addr)
		if err != nil {
			panic(err)
		}
		break
	}
	return conn
}

// printMessage was a commented-out debugging helper that dumped a parsed
// *http.Request (method, URL, proto, headers, trailers and body) to stdout;
// it is intentionally kept out of the build.

// benchData is a realistic chunk-encoded request used by the benchmark.
var benchData = []byte("POST /joyent/http-parser HTTP/1.1\r\n" +
	"Host: github.com\r\n" +
	"DNT: 1\r\n" +
	"Accept-Encoding: gzip, deflate, sdch\r\n" +
	"Accept-Language: ru-RU,ru;q=0.8,en-US;q=0.6,en;q=0.4\r\n" +
	"User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) " +
	"AppleWebKit/537.36 (KHTML, like Gecko) " +
	"Chrome/39.0.2171.65 Safari/537.36\r\n" +
	"Accept: text/html,application/xhtml+xml,application/xml;q=0.9," +
	"image/webp,*/*;q=0.8\r\n" +
	"Referer: https://github.com/joyent/http-parser\r\n" +
	"Connection: keep-alive\r\n" +
	"Transfer-Encoding: chunked\r\n" +
	"Cache-Control: max-age=0\r\n\r\nb\r\nhello world\r\n0\r\n\r\n")

// BenchmarkServerProcessor measures server-side parsing throughput over
// benchData (5 messages per benchmark iteration).
func BenchmarkServerProcessor(b *testing.B) {
	isClient := false
	processor := NewServerProcessor()
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(http.ResponseWriter, *http.Request) {})
	engine := NewEngine(Config{
		Handler: mux,
	})
	parser := NewParser(newConn(), engine, processor, isClient, nil)
	defer parser.Conn.Close()
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for j := 0; j < 5; j++ {
			err := parser.Parse(benchData)
			if err != nil {
				b.Fatal(err)
			}
		}
	}
}

// Server is a non-blocking HTTP server; it is a thin wrapper that embeds
// and exposes an *Engine.
type Server struct {
	*Engine
}

// NewServer creates a Server from conf. Optional trailing arguments:
// v[0] may be an http.Handler (overrides conf.Handler) and v[1] may be a
// func(f func()) used as conf.ServerExecutor; values of other types are
// silently ignored.
//
//go:norace
func NewServer(conf Config, v ...interface{}) *Server {
	if len(v) > 0 {
		if handler, ok := v[0].(http.Handler); ok {
			conf.Handler = handler
		}
	}
	if len(v) > 1 {
		if messageHandlerExecutor, ok := v[1].(func(f func())); ok {
			conf.ServerExecutor = messageHandlerExecutor
		}
	}
	return &Server{Engine: NewEngine(conf)}
}
// NewServerTLS creates a Server that serves TLS on all configured
// addresses. Optional trailing arguments: v[0] may be an http.Handler,
// v[1] a func(f func()) used as conf.ServerExecutor, and v[2] a
// *tls.Config; values of other types are silently ignored.
//
//go:norace
func NewServerTLS(conf Config, v ...interface{}) *Server {
	if len(v) > 0 {
		if handler, ok := v[0].(http.Handler); ok {
			conf.Handler = handler
		}
	}
	if len(v) > 1 {
		if messageHandlerExecutor, ok := v[1].(func(f func())); ok {
			conf.ServerExecutor = messageHandlerExecutor
		}
	}
	if len(v) > 2 {
		if tlsConfig, ok := v[2].(*tls.Config); ok {
			conf.TLSConfig = tlsConfig
		}
	}
	// Promote every plain address to a TLS address: this constructor serves
	// TLS only, so conf.Addrs is merged into conf.AddrsTLS and cleared.
	conf.AddrsTLS = append(conf.AddrsTLS, conf.Addrs...)
	conf.Addrs = nil
	return &Server{Engine: NewEngine(conf)}
}

// Parser state machine states. The numeric values are significant only in
// that stateClose is the zero value; do not reorder this iota block.
const (
	// state: RequestLine
	stateClose int8 = iota
	stateMethodBefore
	stateMethod

	statePathBefore
	statePath
	stateProtoBefore
	stateProto
	stateProtoLF
	stateClientProtoBefore
	stateClientProto
	stateStatusCodeBefore
	stateStatusCode
	stateStatusBefore
	stateStatus
	stateStatusLF

	// state: Header
	stateHeaderKeyBefore
	stateHeaderValueLF
	stateHeaderKey

	stateHeaderValueBefore
	stateHeaderValue

	// state: Body ContentLength
	stateBodyContentLength

	// state: Body Chunk
	stateHeaderOverLF
	stateBodyChunkSizeBefore
	stateBodyChunkSize
	stateBodyChunkSizeLF
	stateBodyChunkData
	stateBodyChunkDataCR
	stateBodyChunkDataLF

	// state: Body Trailer
	stateBodyTrailerHeaderValueLF
	stateBodyTrailerHeaderKeyBefore
	stateBodyTrailerHeaderKey
	stateBodyTrailerHeaderValueBefore
	stateBodyTrailerHeaderValue

	// state: Body CRLF
	stateTailCR
	stateTailLF
)
var (
	// validMethods is the set of request methods the parser accepts,
	// compared case-insensitively via isValidMethod.
	validMethods = map[string]bool{
		"OPTIONS": true,
		"GET":     true,
		"HEAD":    true,
		"POST":    true,
		"PUT":     true,
		"DELETE":  true,
		"TRACE":   true,
		"CONNECT": true,
		"PATCH":   true, // RFC 5789

		// http 2.0
		"PRI": true,
	}

	// tokenCharMap marks the RFC 7230 "tchar" set: the bytes legal in an
	// HTTP token (header names, methods). Built from explicit ranges
	// instead of a literal so the set is auditable at a glance.
	tokenCharMap = func() (m [256]bool) {
		for _, c := range "!#$%&'*+-.^_`|~" {
			m[c] = true
		}
		for c := byte('0'); c <= '9'; c++ {
			m[c] = true
		}
		for c := byte('A'); c <= 'Z'; c++ {
			m[c] = true
		}
		for c := byte('a'); c <= 'z'; c++ {
			m[c] = true
		}
		return
	}()

	numCharMap      [256]bool // '0'..'9'
	hexCharMap      [256]bool // '0'..'9', 'a'..'f', 'A'..'F'
	alphaCharMap    [256]bool // ASCII letters
	alphaNumCharMap [256]bool // ASCII letters and digits

	// validMethodCharMap marks every byte (upper and lower case) that
	// occurs in some valid method name.
	validMethodCharMap [256]bool
)

//go:norace
func init() {
	const toLower = 'a' - 'A'

	for method := range validMethods {
		for _, r := range method {
			// Method names in validMethods are uppercase; also admit the
			// lowercase form of each character.
			validMethodCharMap[r] = true
			validMethodCharMap[byte(r)+toLower] = true
		}
	}

	for c := byte('0'); c <= '9'; c++ {
		numCharMap[c] = true
		hexCharMap[c] = true
		alphaNumCharMap[c] = true
	}

	for i := byte(0); i < 6; i++ {
		hexCharMap['A'+i] = true
		hexCharMap['a'+i] = true
	}

	for i := byte(0); i < 26; i++ {
		alphaCharMap['A'+i] = true
		alphaCharMap['a'+i] = true
		alphaNumCharMap['A'+i] = true
		alphaNumCharMap['a'+i] = true
	}
}

// isAlpha reports whether c is an ASCII letter.
//
//go:norace
func isAlpha(c byte) bool {
	return alphaCharMap[c]
}

// isNum reports whether c is an ASCII digit.
//
//go:norace
func isNum(c byte) bool {
	return numCharMap[c]
}

// isHex reports whether c is a hexadecimal digit.
//
//go:norace
func isHex(c byte) bool {
	return hexCharMap[c]
}

// isToken reports whether c is a legal HTTP token byte (RFC 7230 tchar).
//
//go:norace
func isToken(c byte) bool {
	return tokenCharMap[c]
}

// isValidMethod reports whether m (any case) is an accepted request method.
//
//go:norace
func isValidMethod(m string) bool {
	return validMethods[strings.ToUpper(m)]
}

// isValidMethodChar reports whether c can appear in a valid method name.
//
//go:norace
func isValidMethodChar(c byte) bool {
	return validMethodCharMap[c]
}
-------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "compress/flate" 5 | "errors" 6 | "io" 7 | "sync" 8 | ) 9 | 10 | const ( 11 | minCompressionLevel = -2 12 | maxCompressionLevel = flate.BestCompression 13 | defaultCompressionLevel = 1 14 | 15 | flateReaderTail = "\x00\x00\xff\xff" + "\x01\x00\x00\xff\xff" 16 | ) 17 | 18 | var ( 19 | flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool 20 | flateReaderPool = sync.Pool{New: func() interface{} { 21 | return flate.NewReader(nil) 22 | }} 23 | ) 24 | 25 | //go:norace 26 | func isValidCompressionLevel(level int) bool { 27 | return minCompressionLevel <= level && level <= maxCompressionLevel 28 | } 29 | 30 | //go:norace 31 | func decompressReader(r io.Reader) io.ReadCloser { 32 | fr, _ := flateReaderPool.Get().(io.ReadCloser) 33 | fr.(flate.Resetter).Reset(r, nil) 34 | return &flateReadWrapper{fr} 35 | } 36 | 37 | type flateReadWrapper struct { 38 | fr io.ReadCloser 39 | } 40 | 41 | //go:norace 42 | func (r *flateReadWrapper) Read(p []byte) (int, error) { 43 | if r.fr == nil { 44 | return 0, io.ErrClosedPipe 45 | } 46 | n, err := r.fr.Read(p) 47 | if errors.Is(err, io.EOF) { 48 | // Preemptively place the reader back in the pool. This helps with 49 | // scenarios where the application does not call NextReader() soon after 50 | // this final read. 
51 | r.Close() 52 | } 53 | return n, err 54 | } 55 | 56 | //go:norace 57 | func (r *flateReadWrapper) Close() error { 58 | if r.fr == nil { 59 | return io.ErrClosedPipe 60 | } 61 | err := r.fr.Close() 62 | flateReaderPool.Put(r.fr) 63 | r.fr = nil 64 | return err 65 | } 66 | 67 | //go:norace 68 | func compressWriter(w io.WriteCloser, level int) io.WriteCloser { 69 | p := &flateWriterPools[level-minCompressionLevel] 70 | fw, _ := p.Get().(*flate.Writer) 71 | tw := &truncWriter{w: w} 72 | if fw == nil { 73 | fw, _ = flate.NewWriter(tw, level) 74 | } else { 75 | fw.Reset(tw) 76 | } 77 | return &flateWriteWrapper{fw: fw, p: p} 78 | } 79 | 80 | type truncWriter struct { 81 | w io.WriteCloser 82 | n int 83 | p [4]byte 84 | } 85 | 86 | //go:norace 87 | func (w *truncWriter) Write(p []byte) (int, error) { 88 | n := 0 89 | 90 | if w.n < len(w.p) { 91 | n = copy(w.p[w.n:], p) 92 | p = p[n:] 93 | w.n += n 94 | if len(p) == 0 { 95 | return n, nil 96 | } 97 | } 98 | 99 | m := len(p) 100 | if m > len(w.p) { 101 | m = len(w.p) 102 | } 103 | 104 | if nn, err := w.w.Write(w.p[:m]); err != nil { 105 | return n + nn, err 106 | } 107 | 108 | copy(w.p[:], w.p[m:]) 109 | copy(w.p[len(w.p)-m:], p[len(p)-m:]) 110 | nn, err := w.w.Write(p[:len(p)-m]) 111 | return n + nn, err 112 | } 113 | 114 | type flateWriteWrapper struct { 115 | fw *flate.Writer 116 | p *sync.Pool 117 | } 118 | 119 | //go:norace 120 | func (w *flateWriteWrapper) Write(p []byte) (int, error) { 121 | return w.fw.Write(p) 122 | } 123 | 124 | //go:norace 125 | func (w *flateWriteWrapper) Close() error { 126 | err := w.fw.Flush() 127 | w.p.Put(w.fw) 128 | w.fw = nil 129 | return err 130 | } 131 | -------------------------------------------------------------------------------- /nbhttp/websocket/dialer.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net" 7 | "net/http" 8 | "net/url" 9 | "strings" 10 | "time" 11 | 12 | 
// Canonical websocket handshake header field names.
const (
	hostHeaderField                = "Host"
	upgradeHeaderField             = "Upgrade"
	connectionHeaderField          = "Connection"
	secWebsocketKeyHeaderField     = "Sec-Websocket-Key"
	secWebsocketVersionHeaderField = "Sec-Websocket-Version"
	secWebsocketExtHeaderField     = "Sec-Websocket-Extensions"
	secWebsocketProtoHeaderField   = "Sec-Websocket-Protocol"
)

// Dialer contains options for connecting to a websocket server.
type Dialer struct {
	// Engine is the nbhttp engine that will own the connection.
	Engine *nbhttp.Engine

	// Options configures the resulting Conn; if nil, Upgrader is used
	// instead (see DialContext).
	Options  *Options
	Upgrader *Upgrader

	// Jar, when non-nil, supplies cookies for the handshake request and
	// stores cookies from the response.
	Jar http.CookieJar

	// DialTimeout bounds the whole handshake (0 means no timeout).
	DialTimeout time.Duration

	// TLSClientConfig is used for wss:// connections.
	TLSClientConfig *tls.Config

	// Proxy, when non-nil, returns the proxy URL for a request.
	Proxy func(*http.Request) (*url.URL, error)

	// CheckRedirect mirrors http.Client.CheckRedirect.
	CheckRedirect func(req *http.Request, via []*http.Request) error

	// Subprotocols are offered via Sec-Websocket-Protocol.
	Subprotocols []string

	EnableCompression bool

	// Cancel is set by Dial when DialTimeout > 0 and invoked by
	// DialContext on return.
	Cancel context.CancelFunc
}

// Dial opens a client websocket connection to urlStr, applying
// DialTimeout if set.
//
// NOTE(review): the timeout's CancelFunc is stored on the Dialer itself,
// so concurrent Dial calls on one Dialer race on d.Cancel — confirm a
// Dialer is used for one dial at a time.
//
//go:norace
func (d *Dialer) Dial(urlStr string, requestHeader http.Header, v ...interface{}) (*Conn, *http.Response, error) {
	ctx := context.Background()
	if d.DialTimeout > 0 {
		ctx, d.Cancel = context.WithTimeout(ctx, d.DialTimeout)
	}
	return d.DialContext(ctx, urlStr, requestHeader, v...)
}
// DialContext opens a client websocket connection to urlStr under ctx.
// If v[0] is a func(*Conn, *http.Response, error), the handshake completes
// asynchronously through that callback (and the immediate return values
// are nil, nil, nil); otherwise the call blocks until the handshake
// finishes or ctx is done.
//
//go:norace
func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header, v ...interface{}) (*Conn, *http.Response, error) {
	if d.Cancel != nil {
		defer d.Cancel()
	}

	// Options falls back to Upgrader; at least one must be set.
	options := d.Options
	if options == nil {
		options = d.Upgrader
	}
	if options == nil {
		return nil, nil, errors.New("invalid Options: nil")
	}

	challengeKey, err := challengeKey()
	if err != nil {
		return nil, nil, err
	}

	u, err := url.Parse(urlStr)
	if err != nil {
		return nil, nil, err
	}

	// Map ws/wss onto the http/https transports.
	switch u.Scheme {
	case "ws":
		u.Scheme = "http"
	case "wss":
		u.Scheme = "https"
	default:
		return nil, nil, ErrMalformedURL
	}

	// Userinfo is not permitted in websocket URLs.
	if u.User != nil {
		return nil, nil, ErrMalformedURL
	}

	req := &http.Request{
		Method:     "GET",
		URL:        u,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header:     make(http.Header),
		Host:       u.Host,
	}

	if d.Jar != nil {
		for _, cookie := range d.Jar.Cookies(u) {
			req.AddCookie(cookie)
		}
	}

	// Standard handshake headers; the caller may not override these.
	req.Header[upgradeHeaderField] = []string{"websocket"}
	req.Header[connectionHeaderField] = []string{"Upgrade"}
	req.Header[secWebsocketKeyHeaderField] = []string{challengeKey}
	req.Header[secWebsocketVersionHeaderField] = []string{"13"}
	if len(d.Subprotocols) > 0 {
		req.Header[secWebsocketProtoHeaderField] = []string{strings.Join(d.Subprotocols, ", ")}
	}
	for k, vs := range requestHeader {
		switch {
		case k == hostHeaderField:
			if len(vs) > 0 {
				req.Host = vs[0]
			}
		case k == upgradeHeaderField ||
			k == connectionHeaderField ||
			k == secWebsocketKeyHeaderField ||
			k == secWebsocketVersionHeaderField ||
			k == secWebsocketExtHeaderField ||
			(k == secWebsocketProtoHeaderField && len(d.Subprotocols) > 0):
			return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
		case k == secWebsocketProtoHeaderField:
			req.Header[secWebsocketProtoHeaderField] = vs
		default:
			req.Header[k] = vs
		}
	}

	if options.enableCompression {
		req.Header[secWebsocketExtHeaderField] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
	}

	var asyncHandler func(*Conn, *http.Response, error)
	if len(v) > 0 {
		if h, ok := v[0].(func(*Conn, *http.Response, error)); ok {
			asyncHandler = h
		}
	}

	var wsConn *Conn
	var res *http.Response
	var errCh chan error
	if asyncHandler == nil {
		// Synchronous mode: the Do callback reports through errCh.
		errCh = make(chan error, 1)
	}

	cliConn := &nbhttp.ClientConn{
		Engine:          d.Engine,
		Jar:             d.Jar,
		Timeout:         d.DialTimeout,
		TLSClientConfig: d.TLSClientConfig,
		Proxy:           d.Proxy,
		CheckRedirect:   d.CheckRedirect,
	}
	cliConn.Do(req, func(resp *http.Response, conn net.Conn, err error) {
		res = resp

		// notifyResult delivers e to the waiting caller (sync) or schedules
		// the async handler on the engine. In sync mode, if the caller has
		// already given up (ctx done) the connection is closed here.
		notifyResult := func(e error) {
			if asyncHandler == nil {
				select {
				case errCh <- e:
				case <-ctx.Done():
					if conn != nil {
						conn.Close()
					}
				}
			} else {
				d.Engine.Execute(func() {
					asyncHandler(wsConn, res, e)
				})
			}
		}

		if err != nil {
			notifyResult(err)
			return
		}

		// Unwrap conn down to the underlying *nbio.Conn: either directly,
		// or via nbhttp.Conn -> tls.Conn -> nbio.Conn for TLS connections.
		nbc, ok := conn.(*nbio.Conn)
		if !ok {
			nbhttpConn, ok2 := conn.(*nbhttp.Conn)
			if !ok2 {
				err = ErrBadHandshake
				notifyResult(err)
				return
			}
			tlsConn, tlsOk := nbhttpConn.Conn.(*tls.Conn)
			if !tlsOk {
				err = ErrBadHandshake
				notifyResult(err)
				return
			}
			nbc, tlsOk = tlsConn.Conn().(*nbio.Conn)
			if !tlsOk {
				err = errors.New(http.StatusText(http.StatusInternalServerError))
				notifyResult(err)
				return
			}
		}

		parser, ok := nbc.Session().(*nbhttp.Parser)
		if !ok {
			err = errors.New(http.StatusText(http.StatusInternalServerError))
			notifyResult(err)
			return
		}

		if d.Jar != nil {
			if rc := resp.Cookies(); len(rc) > 0 {
				d.Jar.SetCookies(req.URL, rc)
			}
		}

		// Validate the 101 response per RFC 6455.
		remoteCompressionEnabled := false
		if resp.StatusCode != 101 ||
			!headerContains(resp.Header, "Upgrade", "websocket") ||
			!headerContains(resp.Header, "Connection", "upgrade") ||
			resp.Header.Get("Sec-Websocket-Accept") != acceptKeyString(challengeKey) {
			err = ErrBadHandshake
			notifyResult(err)
			return
		}

		// permessage-deflate is accepted only with both no-context-takeover
		// parameters, since this implementation does not retain context.
		for _, ext := range parseExtensions(resp.Header) {
			if ext[""] != "permessage-deflate" {
				continue
			}
			_, snct := ext["server_no_context_takeover"]
			_, cnct := ext["client_no_context_takeover"]
			if !snct || !cnct {
				err = ErrInvalidCompression
				notifyResult(err)
				return
			}

			remoteCompressionEnabled = true
			break
		}

		// Hand the parsed connection over to the websocket layer.
		wsConn = NewClientConn(options, conn, resp.Header.Get(secWebsocketProtoHeaderField), remoteCompressionEnabled, false)
		parser.ParserCloser = wsConn
		wsConn.Engine = parser.Engine
		wsConn.Execute = parser.Execute
		nbc.SetSession(wsConn)

		if wsConn.openHandler != nil {
			wsConn.openHandler(wsConn)
		}

		notifyResult(err)
	})

	if asyncHandler == nil {
		select {
		case err = <-errCh:
		case <-ctx.Done():
			err = nbhttp.ErrClientTimeout
		}
		if err != nil {
			cliConn.CloseWithError(err)
		}
		return wsConn, res, err
	}

	// Async mode: results are delivered via asyncHandler.
	return nil, nil, nil
}
2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | ) 11 | 12 | var ( 13 | // ErrUpgradeTokenNotFound . 14 | ErrUpgradeTokenNotFound = errors.New("websocket: the client is not using the websocket protocol: 'upgrade' token not found in 'Connection' header") 15 | 16 | // ErrUpgradeMethodIsGet . 17 | ErrUpgradeMethodIsGet = errors.New("websocket: the client is not using the websocket protocol: request method is not GET") 18 | 19 | // ErrUpgradeInvalidWebsocketVersion . 20 | ErrUpgradeInvalidWebsocketVersion = errors.New("websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header") 21 | 22 | // ErrUpgradeUnsupportedExtensions . 23 | ErrUpgradeUnsupportedExtensions = errors.New("websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported") 24 | 25 | // ErrUpgradeOriginNotAllowed . 26 | ErrUpgradeOriginNotAllowed = errors.New("websocket: request origin not allowed by Upgrader.CheckOrigin") 27 | 28 | // ErrUpgradeMissingWebsocketKey . 29 | ErrUpgradeMissingWebsocketKey = errors.New("websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank") 30 | 31 | // ErrUpgradeNotHijacker . 32 | ErrUpgradeNotHijacker = errors.New("websocket: response does not implement http.Hijacker") 33 | 34 | // ErrInvalidControlFrame . 35 | ErrInvalidControlFrame = errors.New("websocket: invalid control frame") 36 | 37 | // ErrInvalidWriteCalling . 38 | ErrInvalidWriteCalling = errors.New("websocket: invalid write calling, should call WriteMessage instead") 39 | 40 | // ErrReserveBitSet . 41 | ErrReserveBitSet = errors.New("websocket: reserved bit set it frame") 42 | 43 | // ErrReservedMessageType . 44 | ErrReservedMessageType = errors.New("websocket: reserved message type received") 45 | 46 | // ErrControlMessageFragmented . 
47 | ErrControlMessageFragmented = errors.New("websocket: control messages must not be fragmented") 48 | 49 | // ErrControlMessageTooBig . 50 | ErrControlMessageTooBig = errors.New("websocket: control frame length > 125") 51 | 52 | // ErrFragmentsShouldNotHaveBinaryOrTextMessage . 53 | ErrFragmentsShouldNotHaveBinaryOrTextMessage = errors.New("websocket: fragments should not have message type of text or binary") 54 | 55 | // ErrInvalidCloseCode . 56 | ErrInvalidCloseCode = errors.New("websocket: invalid close code") 57 | 58 | // ErrBadHandshake . 59 | ErrBadHandshake = errors.New("websocket: bad handshake") 60 | 61 | // ErrInvalidCompression . 62 | ErrInvalidCompression = errors.New("websocket: invalid compression negotiation") 63 | 64 | // ErrInvalidUtf8 . 65 | ErrInvalidUtf8 = errors.New("websocket: invalid UTF-8 bytes") 66 | 67 | // ErrInvalidFragmentMessage . 68 | ErrInvalidFragmentMessage = errors.New("invalid fragment message") 69 | 70 | // ErrMalformedURL . 71 | ErrMalformedURL = errors.New("websocket: malformed ws or wss URL") 72 | 73 | // ErrMessageTooLarge. 74 | ErrMessageTooLarge = errors.New("message exceeds the configured limit") 75 | 76 | // ErrMessageSendQuqueIsFull . 77 | ErrMessageSendQuqueIsFull = errors.New("message send queue is full") 78 | ) 79 | 80 | // CloseError . 81 | type CloseError struct { 82 | Code int 83 | Reason string 84 | } 85 | 86 | // Error . 87 | // 88 | //go:norace 89 | func (ce CloseError) Error() string { 90 | return fmt.Sprintf("websocket: close code=%d and reason=%q", ce.Code, ce.Reason) 91 | } 92 | 93 | // CloseCode . 94 | // 95 | //go:norace 96 | func CloseCode(err error) int { 97 | var ce CloseError 98 | if errors.As(err, &ce) { 99 | return ce.Code 100 | } 101 | return -1 102 | } 103 | 104 | // CloseReason . 
// CloseReason extracts the close reason from err if it is (or wraps) a
// CloseError; it returns "" otherwise.
//
//go:norace
func CloseReason(err error) string {
	var ce CloseError
	if errors.As(err, &ce) {
		return ce.Reason
	}
	return ""
}

// Test_validFrame table-tests Conn.validFrame's frame-header validation:
// reserved opcodes, reserved RSV bits, fragmented control frames, and
// continuation expectations.
//
// NOTE(review): the "validbinaryFragmented" case uses the same args as
// "validbinary" (fin=true) — presumably it was meant to use fin=false;
// confirm intent.
func Test_validFrame(t *testing.T) {
	type args struct {
		opcode             MessageType
		fin                bool
		res1               bool
		res2               bool
		res3               bool
		expectingFragments bool
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{"validtext", args{TextMessage, true, false, false, false, false}, false},
		{"validbinary", args{BinaryMessage, true, false, false, false, false}, false},
		{"validbinaryFragmented", args{BinaryMessage, true, false, false, false, false}, false},
		{"reservedOpcode", args{MessageType(3), true, false, false, false, false}, true},
		{"reservedBit1", args{BinaryMessage, true, true, false, false, false}, true},
		{"reservedBit2", args{BinaryMessage, true, false, true, false, false}, true},
		{"reservedBit3", args{BinaryMessage, true, false, false, true, false}, true},
		{"CloseFragmented", args{CloseMessage, false, false, false, false, false}, true},
		{"ExpectingFragmentButGotText", args{TextMessage, false, false, false, false, true}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			u := NewUpgrader()
			wsc := NewServerConn(u, nil, "", true, true)
			if err := wsc.validFrame(tt.args.opcode, tt.args.fin, tt.args.res1, tt.args.res2, tt.args.res3, tt.args.expectingFragments); (err != nil) != tt.wantErr {
				t.Errorf("validFrame() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}
--------------------------------------------------------------------------------
// Copyright 2020 lesismal. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.

//go:build linux || darwin || netbsd || freebsd || openbsd || dragonfly
// +build linux darwin netbsd freebsd openbsd dragonfly

package nbio

import (
	"errors"
	"net"
	"strings"
	"syscall"
)

// init clamps MaxOpenFiles down to the process RLIMIT_NOFILE hard limit.
//
//go:norace
func init() {
	var limit syscall.Rlimit
	if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &limit); err == nil {
		if n := int(limit.Max); n > 0 && n < MaxOpenFiles {
			MaxOpenFiles = n
		}
	}
}

// dupStdConn duplicates the underlying fd of a std net.Conn, closes the
// original conn, and wraps the duplicated fd in an nbio *Conn so it can be
// managed by the poller.
//
//go:norace
func dupStdConn(conn net.Conn) (*Conn, error) {
	sc, ok := conn.(interface {
		SyscallConn() (syscall.RawConn, error)
	})
	if !ok {
		return nil, errors.New("RawConn Unsupported")
	}
	rc, err := sc.SyscallConn()
	if err != nil {
		return nil, errors.New("RawConn Unsupported")
	}

	var newFd int
	errCtrl := rc.Control(func(fd uintptr) {
		// dup so the fd survives conn.Close() below.
		newFd, err = syscall.Dup(int(fd))
	})

	if errCtrl != nil {
		return nil, errCtrl
	}

	if err != nil {
		return nil, err
	}

	// Capture addresses before closing the original conn.
	lAddr := conn.LocalAddr()
	rAddr := conn.RemoteAddr()

	conn.Close()

	// err = syscall.SetNonblock(newFd, true)
	// if err != nil {
	// 	syscall.Close(newFd)
	// 	return nil, err
	// }

	c := &Conn{
		fd:    newFd,
		lAddr: lAddr,
		rAddr: rAddr,
	}

	switch conn.(type) {
	case *net.TCPConn:
		c.typ = ConnTypeTCP
	case *net.UnixConn:
		c.typ = ConnTypeUnix
	case *net.UDPConn:
		// Deep-copy the local UDP addr so it does not alias the closed conn's addr.
		lAddrUDP := lAddr.(*net.UDPAddr)
		newLAddr := net.UDPAddr{
			IP:   make([]byte, len(lAddrUDP.IP)),
			Port: lAddrUDP.Port,
			Zone: lAddrUDP.Zone,
		}

		copy(newLAddr.IP, lAddrUDP.IP)

		c.lAddr = &newLAddr

		// no remote addr, this is a listener
		if rAddr == nil {
			c.typ = ConnTypeUDPServer
			c.connUDP = &udpConn{
				parent: c,
				conns:  map[udpAddrKey]*Conn{},
			}
		} else {
			// has remote addr, this is a dialer
			c.typ = ConnTypeUDPClientFromDial
			c.connUDP = &udpConn{
				parent: c,
			}
		}
	default:
	}

	return c, nil
}

// parseDomainAndType resolves network/addr into the socket domain
// (AF_INET/AF_INET6/AF_UNIX), socket type, syscall sockaddr, resolved
// net.Addr, and nbio ConnType used to create the socket.
//
//go:norace
func parseDomainAndType(network, addr string) (int, int, syscall.Sockaddr, net.Addr, ConnType, error) {
	var (
		// NOTE(review): "host:port" splits into exactly 2 parts; IPv6
		// literals like "[::1]:80" split into more. This assumes a
		// resolved IPv4 addr yields a 4-byte-indexable IP — confirm for
		// hostnames that resolve to IPv6 on "tcp"/"udp" networks.
		isIPv4 = len(strings.Split(addr, ":")) == 2
	)

	socketResult := func(sockType int, connType ConnType) (int, int, syscall.Sockaddr, net.Addr, ConnType, error) {
		var (
			ip      net.IP
			port    int
			zone    string
			retAddr net.Addr
		)
		if connType == ConnTypeTCP {
			dstAddr, err := net.ResolveTCPAddr(network, addr)
			if err != nil {
				return 0, 0, nil, nil, 0, err
			}
			ip, port, zone, retAddr = dstAddr.IP, dstAddr.Port, dstAddr.Zone, dstAddr
		} else {
			dstAddr, err := net.ResolveUDPAddr(network, addr)
			if err != nil {
				return 0, 0, nil, nil, 0, err
			}
			ip, port, zone, retAddr = dstAddr.IP, dstAddr.Port, dstAddr.Zone, dstAddr
		}

		if isIPv4 {
			return syscall.AF_INET, sockType, &syscall.SockaddrInet4{
				Addr: [4]byte{ip[0], ip[1], ip[2], ip[3]},
				Port: port,
			}, retAddr, connType, nil
		}

		// IPv6: the zone (scope) must name a local interface.
		iface, err := net.InterfaceByName(zone)
		if err != nil {
			return 0, 0, nil, nil, 0, err
		}
		addr6 := &syscall.SockaddrInet6{
			Port:   port,
			ZoneId: uint32(iface.Index),
		}
		copy(addr6.Addr[:], ip)
		return syscall.AF_INET6, sockType, addr6, retAddr, connType, nil
	}

	switch network {
	case NETWORK_TCP, NETWORK_TCP4, NETWORK_TCP6:
		return socketResult(syscall.SOCK_STREAM, ConnTypeTCP)
	case NETWORK_UDP, NETWORK_UDP4, NETWORK_UDP6:
		return socketResult(syscall.SOCK_DGRAM, ConnTypeUDPClientFromDial)
	case NETWORK_UNIX, NETWORK_UNIXGRAM, NETWORK_UNIXPACKET:
		sotype := syscall.SOCK_STREAM
		switch network {
		case NETWORK_UNIX:
			sotype = syscall.SOCK_STREAM
		case NETWORK_UNIXGRAM:
			sotype = syscall.SOCK_DGRAM
		case NETWORK_UNIXPACKET:
			sotype = syscall.SOCK_SEQPACKET
		default:
		}
		dstAddr := &net.UnixAddr{
			Net:  network,
			Name: addr,
		}
		return syscall.AF_UNIX, sotype, &syscall.SockaddrUnix{Name: addr}, dstAddr, ConnTypeUnix, nil
	default:
	}
	return 0, 0, nil, nil, 0, net.UnknownNetworkError(network)
}
--------------------------------------------------------------------------------
/poller_epoll.go:
--------------------------------------------------------------------------------
// Copyright 2020 lesismal. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.

//go:build linux
// +build linux

package nbio

import (
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"runtime"
	"syscall"
	"time"
	"unsafe"

	"github.com/lesismal/nbio/logging"
)

const (
	// EPOLLLT enables level-triggered mode (the default, value 0).
	EPOLLLT = 0

	// EPOLLET enables edge-triggered mode.
	EPOLLET = 0x80000000

	// EPOLLONESHOT enables one-shot event delivery.
	EPOLLONESHOT = syscall.EPOLLONESHOT
)

const (
	epollEventsRead  = syscall.EPOLLPRI | syscall.EPOLLIN
	epollEventsWrite = syscall.EPOLLOUT
	epollEventsError = syscall.EPOLLERR | syscall.EPOLLHUP | syscall.EPOLLRDHUP
)

const (
	IPPROTO_TCP   = syscall.IPPROTO_TCP
	TCP_KEEPINTVL = syscall.TCP_KEEPINTVL
	TCP_KEEPIDLE  = syscall.TCP_KEEPIDLE
)

// poller is the linux epoll-based event loop; an instance serves either as
// an acceptor (listener) or as an io poller (read/write loop).
type poller struct {
	g *Engine // parent engine

	epfd  int // epoll fd
	evtfd int // event fd for trigger

	index int // poller index in engine

	pollType string // listener or io poller

	shutdown bool // state

	// whether poller is used for listener.
	isListener bool
	// listener.
	listener net.Listener
	// if poller is used as UnixConn listener,
	// store the addr and remove it when exit.
	unixSockAddr string

	ReadBuffer []byte // default reading buffer
}

// add the connection to poller and handle its io events.
//
//go:norace
func (p *poller) addConn(c *Conn) error {
	fd := c.fd
	if fd >= len(p.g.connsUnix) {
		err := fmt.Errorf("too many open files, fd[%d] >= MaxOpenFiles[%d]",
			fd,
			len(p.g.connsUnix),
		)
		c.closeWithError(err)
		return err
	}
	c.p = p
	// onOpen must not be called for a UDP server conn.
	if c.typ != ConnTypeUDPServer {
		p.g.onOpen(c)
	} else {
		p.g.onUDPListen(c)
	}
	p.g.connsUnix[fd] = c
	err := p.addRead(fd)
	if err != nil {
		p.g.connsUnix[fd] = nil
		c.closeWithError(err)
	}
	return err
}

// add the dialer connection to poller; registers for both read and write so
// connect completion is observed as writability.
//
//go:norace
func (p *poller) addDialer(c *Conn) error {
	fd := c.fd
	if fd >= len(p.g.connsUnix) {
		err := fmt.Errorf("too many open files, fd[%d] >= MaxOpenFiles[%d]",
			fd,
			len(p.g.connsUnix),
		)
		c.closeWithError(err)
		return err
	}
	c.p = p
	p.g.connsUnix[fd] = c
	c.isWAdded = true
	err := p.addReadWrite(fd)
	if err != nil {
		p.g.connsUnix[fd] = nil
		c.closeWithError(err)
	}
	return err
}

// getConn returns the conn registered for fd, or nil.
//
//go:norace
func (p *poller) getConn(fd int) *Conn {
	return p.g.connsUnix[fd]
}

// deleteConn unregisters c from the fd table and fires onClose.
//
//go:norace
func (p *poller) deleteConn(c *Conn) {
	if c == nil {
		return
	}
	fd := c.fd

	if c.typ != ConnTypeUDPClientFromRead {
		if c == p.g.connsUnix[fd] {
			p.g.connsUnix[fd] = nil
		}
		// p.deleteEvent(fd)
	}

	if c.typ != ConnTypeUDPServer {
		p.g.onClose(c, c.closeErr)
	}
}

// start runs the acceptor loop or the epoll read/write loop until shutdown.
//
//go:norace
func (p *poller) start() {
	defer p.g.Done()

	logging.Debug("NBIO[%v][%v_%v] start", p.g.Name, p.pollType, p.index)
	defer logging.Debug("NBIO[%v][%v_%v] stopped", p.g.Name, p.pollType, p.index)

	if p.isListener {
		p.acceptorLoop()
	} else {
		defer func() {
			syscall.Close(p.epfd)
			syscall.Close(p.evtfd)
		}()
		p.readWriteLoop()
	}
}

// acceptorLoop accepts connections and hands each off to an io poller
// chosen by the conn's hash.
//
//go:norace
func (p *poller) acceptorLoop() {
	if p.g.LockListener {
		runtime.LockOSThread()
		defer runtime.UnlockOSThread()
	}

	p.shutdown = false
	for !p.shutdown {
		conn, err := p.listener.Accept()
		if err == nil {
			var c *Conn
			c, err = NBConn(conn)
			if err != nil {
				conn.Close()
				continue
			}
			err = p.g.pollers[c.Hash()%len(p.g.pollers)].addConn(c)
			if err != nil {
				logging.Error("NBIO[%v][%v_%v] addConn [fd: %v] failed: %v",
					p.g.Name,
					p.pollType,
					p.index,
					c.fd,
					err,
				)
			}
		} else {
			var ne net.Error
			if ok := errors.As(err, &ne); ok && ne.Timeout() {
				// Transient (e.g. fd exhaustion): back off briefly and retry.
				logging.Error("NBIO[%v][%v_%v] Accept failed: timeout error, retrying...",
					p.g.Name,
					p.pollType,
					p.index,
				)
				time.Sleep(time.Second / 20)
			} else {
				if !p.shutdown {
					logging.Error("NBIO[%v][%v_%v] Accept failed: %v, exit...",
						p.g.Name,
						p.pollType,
						p.index,
						err,
					)
				}
				if p.g.onAcceptError != nil {
					p.g.onAcceptError(err)
				}
			}
		}
	}
}

// readWriteLoop waits on epoll and dispatches read/write/error events to
// the registered conns until shutdown.
//
//go:norace
func (p *poller) readWriteLoop() {
	if p.g.LockPoller {
		runtime.LockOSThread()
		defer runtime.UnlockOSThread()
	}

	msec := -1 // block indefinitely in EpollWait
	events := make([]syscall.EpollEvent, 1024)

	if p.g.onRead == nil && p.g.EpollMod == EPOLLET {
		// Edge-triggered without a user onRead: drain fully on each event.
		p.g.MaxConnReadTimesPerEventLoop = 1<<31 - 1
	}

	g := p.g
	p.shutdown = false
	isOneshot := g.isOneshot
	asyncReadEnabled := g.AsyncReadInPoller && (g.EpollMod == EPOLLET)
	for !p.shutdown {
		n, err := syscall.EpollWait(p.epfd, events, msec)
		if err != nil && !errors.Is(err, syscall.EINTR) {
			logging.Error("NBIO[%v][%v_%v] EpollWait failed: %v, exit...",
				p.g.Name,
				p.pollType,
				p.index,
				err,
			)
			return
		}

		if n <= 0 {
			continue
		}

		for _, ev := range events[:n] {
			fd := int(ev.Fd)
			switch fd {
			case p.evtfd: // triggered by stop, exit event loop

			default: // for socket connections
				c := p.getConn(fd)
				if c != nil {
					if ev.Events&epollEventsWrite != 0 {
						if c.onConnected == nil {
							c.flush()
						} else {
							// Writability after connect() means the dial completed.
							c.onConnected(c, nil)
							c.onConnected = nil
							c.resetRead()
						}
					}

					if ev.Events&epollEventsRead != 0 {
						if g.onRead == nil {
							if asyncReadEnabled {
								c.AsyncRead()
							} else {
								for i := 0; i < g.MaxConnReadTimesPerEventLoop; i++ {
									pbuf := g.borrow(c)
									bufLen := len(*pbuf)
									rc, n, err := c.ReadAndGetConn(pbuf)
									if n > 0 {
										*pbuf = (*pbuf)[:n]
										g.onDataPtr(rc, pbuf)
									}
									g.payback(c, pbuf)
									if errors.Is(err, syscall.EINTR) {
										continue
									}
									if errors.Is(err, syscall.EAGAIN) {
										break
									}
									if err != nil {
										c.closeWithError(err)
										break
									}
									if n < bufLen {
										// Short read: kernel buffer drained.
										break
									}
								}
								if isOneshot {
									c.ResetPollerEvent()
								}
							}
						} else {
							g.onRead(c)
						}
					}

					if ev.Events&epollEventsError != 0 {
						c.closeWithError(io.EOF)
						continue
					}
				}
			}
		}
	}
}

// stop ends the loop: closes the listener (and unix socket file) for an
// acceptor, or writes the eventfd to wake the epoll loop.
//
//go:norace
func (p *poller) stop() {
	logging.Debug("NBIO[%v][%v_%v] stop...", p.g.Name, p.pollType, p.index)
	p.shutdown = true
	if p.listener != nil {
		p.listener.Close()
		if p.unixSockAddr != "" {
			os.Remove(p.unixSockAddr)
		}
	} else {
		n := uint64(1)
		syscall.Write(p.evtfd, (*(*[8]byte)(unsafe.Pointer(&n)))[:])
	}
}

// addRead registers fd for read events.
//
//go:norace
func (p *poller) addRead(fd int) error {
	return p.setRead(syscall.EPOLL_CTL_ADD, fd)
}

// resetRead re-arms fd for read events (used with oneshot mode).
//
//go:norace
func (p *poller) resetRead(fd int) error {
	return p.setRead(syscall.EPOLL_CTL_MOD, fd)
}

// setRead issues the epoll_ctl for read interest according to the engine's
// epoll mode (LT, ET, ET+oneshot).
//
//go:norace
func (p *poller) setRead(op int, fd int) error {
	switch p.g.EpollMod {
	case EPOLLET:
		events := syscall.EPOLLERR |
			syscall.EPOLLHUP |
			syscall.EPOLLRDHUP |
			syscall.EPOLLPRI |
			syscall.EPOLLIN |
			EPOLLET |
			p.g.EPOLLONESHOT
		if p.g.EPOLLONESHOT != EPOLLONESHOT {
			// Non-oneshot ET: only the initial ADD is needed; MOD is a no-op.
			if op == syscall.EPOLL_CTL_ADD {
				return syscall.EpollCtl(p.epfd, op, fd, &syscall.EpollEvent{
					Fd:     int32(fd),
					Events: events | syscall.EPOLLOUT,
				})
			}
			return nil
		}
		return syscall.EpollCtl(p.epfd, op, fd, &syscall.EpollEvent{
			Fd:     int32(fd),
			Events: events,
		})
	default:
		return syscall.EpollCtl(
			p.epfd,
			op,
			fd,
			&syscall.EpollEvent{
				Fd: int32(fd),
				Events: syscall.EPOLLERR |
					syscall.EPOLLHUP |
					syscall.EPOLLRDHUP |
					syscall.EPOLLPRI |
					syscall.EPOLLIN,
			},
		)
	}
}

// modWrite adds write interest to an already-registered fd.
//
//go:norace
func (p *poller) modWrite(fd int) error {
	return p.setReadWrite(syscall.EPOLL_CTL_MOD, fd)
}

// addReadWrite registers fd for both read and write events.
//
//go:norace
func (p *poller) addReadWrite(fd int) error {
	return p.setReadWrite(syscall.EPOLL_CTL_ADD, fd)
}

// setReadWrite issues the epoll_ctl for read+write interest according to
// the engine's epoll mode.
//
//go:norace
func (p *poller) setReadWrite(op int, fd int) error {
	switch p.g.EpollMod {
	case EPOLLET:
		events := syscall.EPOLLERR |
			syscall.EPOLLHUP |
			syscall.EPOLLRDHUP |
			syscall.EPOLLPRI |
			syscall.EPOLLIN |
			syscall.EPOLLOUT |
			EPOLLET |
			p.g.EPOLLONESHOT
		if p.g.EPOLLONESHOT != EPOLLONESHOT {
			if op == syscall.EPOLL_CTL_ADD {
				return syscall.EpollCtl(p.epfd, op, fd, &syscall.EpollEvent{
					Fd:     int32(fd),
					Events: events,
				})
			}
			return nil
		}
		return syscall.EpollCtl(p.epfd, op, fd, &syscall.EpollEvent{
			Fd:     int32(fd),
			Events: events,
		})
	default:
		return syscall.EpollCtl(
			p.epfd, op, fd,
			&syscall.EpollEvent{
				Fd: int32(fd),
				Events: syscall.EPOLLERR |
					syscall.EPOLLHUP |
					syscall.EPOLLRDHUP |
					syscall.EPOLLPRI |
					syscall.EPOLLIN |
					syscall.EPOLLOUT,
			},
		)
	}
}

// func (p *poller) deleteEvent(fd int) error {
// 	return syscall.EpollCtl(
// 		p.epfd,
// 		syscall.EPOLL_CTL_DEL,
// 		fd,
// 		&syscall.EpollEvent{Fd: int32(fd)},
// 	)
// }

// newPoller creates either a listener poller (with its net.Listener) or an
// io poller (with epoll + eventfd pair).
//
//go:norace
func newPoller(g *Engine, isListener bool, index int) (*poller, error) {
	if isListener {
		if len(g.Addrs) == 0 {
			panic("invalid listener num")
		}

		addr := g.Addrs[index%len(g.Addrs)]
		ln, err := g.Listen(g.Network, addr)
		if err != nil {
			return nil, err
		}

		p := &poller{
			g:          g,
			index:      index,
			listener:   ln,
			isListener: isListener,
			pollType:   "LISTENER",
		}
		if g.Network == "unix" {
			p.unixSockAddr = addr
		}

		return p, nil
	}

	fd, err := syscall.EpollCreate1(0)
	if err != nil {
		return nil, err
	}

	// eventfd used by stop() to wake the epoll loop.
	r0, _, e0 := syscall.Syscall(syscall.SYS_EVENTFD2, 0, syscall.O_NONBLOCK, 0)
	if e0 != 0 {
		syscall.Close(fd)
		return nil, e0
	}

	err = syscall.EpollCtl(fd, syscall.EPOLL_CTL_ADD, int(r0),
		&syscall.EpollEvent{Fd: int32(r0),
			Events: syscall.EPOLLIN,
		},
	)
	if err != nil {
		syscall.Close(fd)
		syscall.Close(int(r0))
		return nil, err
	}

	p := &poller{
		g:          g,
		epfd:       fd,
		evtfd:      int(r0),
		index:      index,
		isListener: isListener,
		pollType:   "POLLER",
	}

	return p, nil
}

// ResetPollerEvent re-arms the oneshot epoll registration for c: write
// interest if there is pending outbound data, read interest otherwise.
//
//go:norace
func (c *Conn) ResetPollerEvent() {
	p := c.p
	g := p.g
	fd := c.fd
	if g.isOneshot && !c.closed {
		if len(c.writeList) == 0 {
			p.resetRead(fd)
		} else {
			p.modWrite(fd)
		}
	}
}
--------------------------------------------------------------------------------
/poller_kqueue.go:
--------------------------------------------------------------------------------
// Copyright 2020 lesismal. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.

//go:build darwin || netbsd || freebsd || openbsd || dragonfly
// +build darwin netbsd freebsd openbsd dragonfly

package nbio

import (
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"runtime"
	"sync"
	"syscall"
	"time"

	"github.com/lesismal/nbio/logging"
)

const (
	// EPOLLLT . (epoll names kept for cross-platform API parity; kqueue
	// has no LT/ET/ONESHOT distinction here.)
	EPOLLLT = 0

	// EPOLLET .
	EPOLLET = 1

	// EPOLLONESHOT .
	EPOLLONESHOT = 0
)

const (
	// TCP keepalive constants are unused placeholders on BSD in this build.
	IPPROTO_TCP   = 0
	TCP_KEEPINTVL = 0
	TCP_KEEPIDLE  = 0
)

// poller is the kqueue-based event loop for BSD/darwin; pending kevent
// changes are queued in eventList under mux and applied by readWriteLoop.
type poller struct {
	mux sync.Mutex

	g *Engine

	kfd   int // kqueue fd
	evtfd int // user-event ident used to wake the loop

	index int

	shutdown bool

	listener     net.Listener
	isListener   bool
	unixSockAddr string

	ReadBuffer []byte

	pollType string

	// pending kevent change list, flushed on each loop iteration.
	eventList []syscall.Kevent_t
}

// addConn registers c with this poller and fires onOpen/onUDPListen.
//
//go:norace
func (p *poller) addConn(c *Conn) error {
	fd := c.fd
	if fd >= len(p.g.connsUnix) {
		err := fmt.Errorf("too many open files, fd[%d] >= MaxOpenFiles[%d]",
			fd,
			len(p.g.connsUnix))
		c.closeWithError(err)
		return err
	}
	c.p = p
	if c.typ != ConnTypeUDPServer {
		p.g.onOpen(c)
	} else {
		p.g.onUDPListen(c)
	}
	p.g.connsUnix[fd] = c
	p.addRead(fd)
	return nil
}

// addDialer registers a dialing conn for read+write so connect completion
// is observed as writability.
//
//go:norace
func (p *poller) addDialer(c *Conn) error {
	fd := c.fd
	if fd >= len(p.g.connsUnix) {
		err := fmt.Errorf("too many open files, fd[%d] >= MaxOpenFiles[%d]",
			fd,
			len(p.g.connsUnix),
		)
		c.closeWithError(err)
		return err
	}
	c.p = p
	p.g.connsUnix[fd] = c
	c.isWAdded = true
	p.addReadWrite(fd)
	return nil
}

// getConn returns the conn registered for fd, or nil.
//
//go:norace
func (p *poller) getConn(fd int) *Conn {
	return p.g.connsUnix[fd]
}

// deleteConn unregisters c from the fd table and fires onClose.
//
//go:norace
func (p *poller) deleteConn(c *Conn) {
	if c == nil {
		return
	}
	fd := c.fd

	if c.typ != ConnTypeUDPClientFromRead {
		if c == p.g.connsUnix[fd] {
			p.g.connsUnix[fd] = nil
		}
		// p.deleteEvent(fd)
	}

	if c.typ != ConnTypeUDPServer {
		p.g.onClose(c, c.closeErr)
	}
}

// trigger wakes the kevent loop via the registered EVFILT_USER event.
//
//go:norace
func (p *poller) trigger() {
	syscall.Kevent(p.kfd, []syscall.Kevent_t{{Ident: 0, Filter: syscall.EVFILT_USER, Fflags: syscall.NOTE_TRIGGER}}, nil, nil)
}

// addRead queues an EVFILT_READ registration for fd and wakes the loop.
//
//go:norace
func (p *poller) addRead(fd int) {
	p.mux.Lock()
	p.eventList = append(p.eventList, syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_ADD, Filter: syscall.EVFILT_READ})
	// p.eventList = append(p.eventList, syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_ADD, Filter: syscall.EVFILT_WRITE})
	p.mux.Unlock()
	p.trigger()
}

// resetRead queues removal of the write filter, leaving only read interest.
//
//go:norace
func (p *poller) resetRead(fd int) {
	p.mux.Lock()
	p.eventList = append(p.eventList, syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_DELETE, Filter: syscall.EVFILT_WRITE})
	p.mux.Unlock()
	p.trigger()
}

// modWrite queues an EVFILT_WRITE registration for fd.
//
//go:norace
func (p *poller) modWrite(fd int) {
	p.mux.Lock()
	p.eventList = append(p.eventList, syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_ADD, Filter: syscall.EVFILT_WRITE})
	p.mux.Unlock()
	p.trigger()
}

// addReadWrite queues registrations for both read and write filters.
//
//go:norace
func (p *poller) addReadWrite(fd int) {
	p.mux.Lock()
	p.eventList = append(p.eventList, syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_ADD, Filter: syscall.EVFILT_READ})
	p.eventList = append(p.eventList, syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_ADD, Filter: syscall.EVFILT_WRITE})
	p.mux.Unlock()
	p.trigger()
}

// func (p *poller) deleteEvent(fd int) {
// 	p.mux.Lock()
// 	p.eventList = append(p.eventList,
// 		syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_DELETE, Filter: syscall.EVFILT_READ},
// 		syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_DELETE, Filter: syscall.EVFILT_WRITE})
// 	p.mux.Unlock()
// 	p.trigger()
// }

// readWrite dispatches a single kevent to the owning conn: reads (or
// delegates to onRead), and flushes/completes the dial on writability.
//
//go:norace
func (p *poller) readWrite(ev *syscall.Kevent_t) {
	if ev.Flags&syscall.EV_DELETE > 0 {
		return
	}
	fd := int(ev.Ident)
	c := p.getConn(fd)
	if c != nil {
		if ev.Filter == syscall.EVFILT_READ {
			if p.g.onRead == nil {
				for {
					pbuf := p.g.borrow(c)
					bufLen := len(*pbuf)
					rc, n, err := c.ReadAndGetConn(pbuf)
					if n > 0 {
						*pbuf = (*pbuf)[:n]
						p.g.onDataPtr(rc, pbuf)
					}
					p.g.payback(c, pbuf)
					if errors.Is(err, syscall.EINTR) {
						continue
					}
					if errors.Is(err, syscall.EAGAIN) {
						return
					}
					if (err != nil || n == 0) && ev.Flags&syscall.EV_DELETE == 0 {
						if err == nil {
							err = io.EOF
						}
						c.closeWithError(err)
					}
					if n < bufLen {
						break
					}
				}
			} else {
				p.g.onRead(c)
			}

			// NOTE(review): flush/onConnected handling under EVFILT_READ
			// keyed on EV_EOF mirrors the EVFILT_WRITE branch below —
			// confirm this EOF-path behavior is intentional.
			if ev.Flags&syscall.EV_EOF != 0 {
				if c.onConnected == nil {
					c.flush()
				} else {
					c.onConnected(c, nil)
					c.onConnected = nil
					c.resetRead()
				}
			}
		}

		if ev.Filter == syscall.EVFILT_WRITE {
			if c.onConnected == nil {
				c.flush()
			} else {
				c.resetRead()
				c.onConnected(c, nil)
				c.onConnected = nil
			}
		}
	}
}

// start runs the acceptor loop or the kevent read/write loop until shutdown.
//
//go:norace
func (p *poller) start() {
	if p.g.LockPoller {
		runtime.LockOSThread()
		defer runtime.UnlockOSThread()
	}
	defer p.g.Done()

	logging.Debug("NBIO[%v][%v_%v] start", p.g.Name, p.pollType, p.index)
	defer logging.Debug("NBIO[%v][%v_%v] stopped", p.g.Name, p.pollType, p.index)

	if p.isListener {
		p.acceptorLoop()
	} else {
		defer syscall.Close(p.kfd)
		p.readWriteLoop()
	}
}

// acceptorLoop accepts connections and hands each off to an io poller
// chosen by the conn's hash.
//
//go:norace
func (p *poller) acceptorLoop() {
	if p.g.LockListener {
		runtime.LockOSThread()
		defer runtime.UnlockOSThread()
	}

	p.shutdown = false
	for !p.shutdown {
		conn, err := p.listener.Accept()
		if err == nil {
			var c *Conn
			c, err = NBConn(conn)
			if err != nil {
				conn.Close()
				continue
			}
			p.g.pollers[c.Hash()%len(p.g.pollers)].addConn(c)
		} else {
			var ne net.Error
			if ok := errors.As(err, &ne); ok && ne.Timeout() {
				logging.Error("NBIO[%v][%v_%v] Accept failed: timeout error, retrying...", p.g.Name, p.pollType, p.index)
				time.Sleep(time.Second / 20)
			} else {
				if !p.shutdown {
					logging.Error("NBIO[%v][%v_%v] Accept failed: %v, exit...", p.g.Name, p.pollType, p.index, err)
				}
				if p.g.onAcceptError != nil {
					p.g.onAcceptError(err)
				}
			}
		}
	}
}

// readWriteLoop applies queued kevent changes and dispatches returned
// events until shutdown.
//
//go:norace
func (p *poller) readWriteLoop() {
	if p.g.LockPoller {
		runtime.LockOSThread()
		defer runtime.UnlockOSThread()
	}

	events := make([]syscall.Kevent_t, 1024)
	var changes []syscall.Kevent_t

	p.shutdown = false
	for !p.shutdown {
		// Swap out the pending change list under the lock.
		p.mux.Lock()
		changes = p.eventList
		p.eventList = nil
		p.mux.Unlock()
		n, err := syscall.Kevent(p.kfd, changes, events, nil)
		if err != nil && !errors.Is(err, syscall.EINTR) && !errors.Is(err, syscall.EBADF) && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EINVAL) {
			logging.Error("NBIO[%v][%v_%v] Kevent failed: %v, exit...", p.g.Name, p.pollType, p.index, err)
			return
		}

		for i := 0; i < n; i++ {
			switch int(events[i].Ident) {
			case p.evtfd: // user event from trigger(); nothing to do
			default:
				p.readWrite(&events[i])
			}
		}
	}
}

// stop ends the loop: closes the listener (and unix socket file) and wakes
// the kevent loop so it observes shutdown.
//
//go:norace
func (p *poller) stop() {
	logging.Debug("NBIO[%v][%v_%v] stop...", p.g.Name, p.pollType, p.index)
	p.shutdown = true
	if p.listener != nil {
		p.listener.Close()
		if p.unixSockAddr != "" {
			os.Remove(p.unixSockAddr)
		}
	}
	p.trigger()
}

// newPoller creates either a listener poller or an io poller with a kqueue
// fd and a registered EVFILT_USER wakeup event.
//
//go:norace
func newPoller(g *Engine, isListener bool, index int) (*poller, error) {
	if isListener {
		if len(g.Addrs) == 0 {
			panic("invalid listener num")
		}

		addr := g.Addrs[index%len(g.Addrs)]
		ln, err := g.Listen(g.Network, addr)
		if err != nil {
			return nil, err
		}

		p := &poller{
			g:          g,
			index:      index,
			listener:   ln,
			isListener: isListener,
			pollType:   "LISTENER",
		}
		if g.Network == "unix" {
			p.unixSockAddr = addr
		}

		return p, nil
	}

	fd, err := syscall.Kqueue()
	if err != nil {
		return nil, err
	}

	_, err = syscall.Kevent(fd, []syscall.Kevent_t{{
		Ident:  0,
		Filter: syscall.EVFILT_USER,
		Flags:  syscall.EV_ADD | syscall.EV_CLEAR,
	}}, nil, nil)

	if err != nil {
		syscall.Close(fd)
		return nil, err
	}

	p := &poller{
		g:          g,
		kfd:        fd,
		index:      index,
		isListener: isListener,
		pollType:   "POLLER",
	}

	return p, nil
}

// ResetPollerEvent is a no-op on kqueue (no oneshot re-arm needed).
//
//go:norace
func (c *Conn) ResetPollerEvent() {
}
--------------------------------------------------------------------------------
/poller_std.go:
--------------------------------------------------------------------------------
// Copyright 2020 lesismal. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.

//go:build windows
// +build windows

package nbio

import (
	"errors"
	"net"
	"runtime"
	"time"

	"github.com/lesismal/nbio/logging"
)

const (
	// EPOLLLT . (epoll names kept for cross-platform API parity.)
	EPOLLLT = 0

	// EPOLLET .
	EPOLLET = 1

	// EPOLLONESHOT .
27 | EPOLLONESHOT = 0 28 | ) 29 | 30 | type poller struct { 31 | g *Engine 32 | 33 | index int 34 | 35 | ReadBuffer []byte 36 | 37 | pollType string 38 | isListener bool 39 | listener net.Listener 40 | shutdown bool 41 | 42 | chStop chan struct{} 43 | } 44 | 45 | //go:norace 46 | func (p *poller) accept() error { 47 | conn, err := p.listener.Accept() 48 | if err != nil { 49 | return err 50 | } 51 | 52 | c := newConn(conn) 53 | o := p.g.pollers[c.Hash()%len(p.g.pollers)] 54 | o.addConn(c) 55 | 56 | return nil 57 | } 58 | 59 | //go:norace 60 | func (p *poller) readConn(c *Conn) { 61 | for { 62 | pbuf := p.g.borrow(c) 63 | _, err := c.read(*pbuf) 64 | p.g.payback(c, pbuf) 65 | if err != nil { 66 | c.Close() 67 | return 68 | } 69 | } 70 | } 71 | 72 | //go:norace 73 | func (p *poller) addConn(c *Conn) error { 74 | c.p = p 75 | p.g.mux.Lock() 76 | p.g.connsStd[c] = struct{}{} 77 | p.g.mux.Unlock() 78 | // should not call onOpen for udp server conn 79 | if c.typ != ConnTypeUDPServer { 80 | p.g.onOpen(c) 81 | } else { 82 | p.g.onUDPListen(c) 83 | } 84 | // should not read udp client from reading udp server conn 85 | if c.typ != ConnTypeUDPClientFromRead { 86 | go p.readConn(c) 87 | } 88 | 89 | return nil 90 | } 91 | 92 | //go:norace 93 | func (p *poller) addDialer(c *Conn) error { 94 | c.p = p 95 | p.g.mux.Lock() 96 | p.g.connsStd[c] = struct{}{} 97 | p.g.mux.Unlock() 98 | go p.readConn(c) 99 | return nil 100 | } 101 | 102 | //go:norace 103 | func (p *poller) deleteConn(c *Conn) { 104 | p.g.mux.Lock() 105 | delete(p.g.connsStd, c) 106 | p.g.mux.Unlock() 107 | // should not call onClose for udp server conn 108 | if c.typ != ConnTypeUDPServer { 109 | p.g.onClose(c, c.closeErr) 110 | } 111 | } 112 | 113 | //go:norace 114 | func (p *poller) start() { 115 | if p.g.LockListener { 116 | runtime.LockOSThread() 117 | defer runtime.UnlockOSThread() 118 | } 119 | defer p.g.Done() 120 | 121 | logging.Debug("NBIO[%v][%v_%v] start", p.g.Name, p.pollType, p.index) 122 | defer 
logging.Debug("NBIO[%v][%v_%v] stopped", p.g.Name, p.pollType, p.index) 123 | 124 | if p.isListener { 125 | var err error 126 | p.shutdown = false 127 | for !p.shutdown { 128 | err = p.accept() 129 | if err != nil { 130 | var ne net.Error 131 | if ok := errors.As(err, &ne); ok && ne.Timeout() { 132 | logging.Error("NBIO[%v][%v_%v] Accept failed: timeout error, retrying...", p.g.Name, p.pollType, p.index) 133 | time.Sleep(time.Second / 20) 134 | } else { 135 | if !p.shutdown { 136 | logging.Error("NBIO[%v][%v_%v] Accept failed: %v, exit...", p.g.Name, p.pollType, p.index, err) 137 | } 138 | if p.g.onAcceptError != nil { 139 | p.g.onAcceptError(err) 140 | } 141 | } 142 | } 143 | 144 | } 145 | } 146 | <-p.chStop 147 | } 148 | 149 | //go:norace 150 | func (p *poller) stop() { 151 | logging.Debug("NBIO[%v][%v_%v] stop...", p.g.Name, p.pollType, p.index) 152 | p.shutdown = true 153 | if p.isListener { 154 | p.listener.Close() 155 | } 156 | close(p.chStop) 157 | } 158 | 159 | //go:norace 160 | func newPoller(g *Engine, isListener bool, index int) (*poller, error) { 161 | p := &poller{ 162 | g: g, 163 | index: index, 164 | isListener: isListener, 165 | chStop: make(chan struct{}), 166 | } 167 | 168 | if isListener { 169 | var err error 170 | var addr = g.Addrs[index%len(g.Addrs)] 171 | p.listener, err = g.Listen(g.Network, addr) 172 | if err != nil { 173 | return nil, err 174 | } 175 | p.pollType = "LISTENER" 176 | } else { 177 | p.pollType = "POLLER" 178 | } 179 | 180 | return p, nil 181 | } 182 | 183 | //go:norace 184 | func (c *Conn) ResetPollerEvent() { 185 | 186 | } 187 | -------------------------------------------------------------------------------- /protocol_stack.go: -------------------------------------------------------------------------------- 1 | package nbio 2 | 3 | import ( 4 | "net" 5 | ) 6 | 7 | type Protocol interface { 8 | Parse(c net.Conn, b []byte, ps *ProtocolStack) (net.Conn, []byte, error) 9 | Write(b []byte) (int, error) 10 | } 11 | 12 | type 
ProtocolStack struct { 13 | stack []Protocol 14 | } 15 | 16 | //go:norace 17 | func (ps *ProtocolStack) Add(p Protocol) { 18 | ps.stack = append(ps.stack, p) 19 | } 20 | 21 | //go:norace 22 | func (ps *ProtocolStack) Delete(p Protocol) { 23 | i := len(ps.stack) - 1 24 | for i >= 0 { 25 | if ps.stack[i] == p { 26 | ps.stack[i] = nil 27 | if i+1 > len(ps.stack)-1 { 28 | ps.stack = ps.stack[:i] 29 | } else { 30 | ps.stack = append(ps.stack[:i], ps.stack[i+1:]...) 31 | } 32 | return 33 | } 34 | i-- 35 | } 36 | } 37 | 38 | //go:norace 39 | func (ps *ProtocolStack) Parse(c net.Conn, b []byte, ps_ ProtocolStack) (net.Conn, []byte, error) { 40 | var err error 41 | for _, p := range ps.stack { 42 | if p == nil { 43 | continue 44 | } 45 | c, b, err = p.Parse(c, b, ps) 46 | if err != nil { 47 | break 48 | } 49 | } 50 | return c, b, err 51 | } 52 | 53 | //go:norace 54 | func (ps *ProtocolStack) Write(b []byte) (int, error) { 55 | return -1, ErrUnsupported 56 | } 57 | -------------------------------------------------------------------------------- /sendfile_std.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 lesismal. All rights reserved. 2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build windows 6 | // +build windows 7 | 8 | package nbio 9 | 10 | import ( 11 | "io" 12 | "os" 13 | ) 14 | 15 | // Sendfile . 
16 | // 17 | //go:norace 18 | func (c *Conn) Sendfile(f *os.File, remain int64) (written int64, err error) { 19 | if f == nil { 20 | return 0, nil 21 | } 22 | 23 | if remain <= 0 { 24 | stat, e := f.Stat() 25 | if e != nil { 26 | return 0, e 27 | } 28 | remain = stat.Size() 29 | } 30 | 31 | for remain > 0 { 32 | bufLen := 1024 * 32 33 | if bufLen > int(remain) { 34 | bufLen = int(remain) 35 | } 36 | pbuf := c.p.g.BodyAllocator.Malloc(bufLen) 37 | nr, er := f.Read(*pbuf) 38 | if nr > 0 { 39 | nw, ew := c.Write((*pbuf)[0:nr]) 40 | c.p.g.BodyAllocator.Free(pbuf) 41 | if nw < 0 { 42 | nw = 0 43 | } 44 | remain -= int64(nw) 45 | written += int64(nw) 46 | if ew != nil { 47 | err = ew 48 | break 49 | } 50 | if nr != nw { 51 | err = io.ErrShortWrite 52 | break 53 | } 54 | } 55 | if er != nil { 56 | if er != io.EOF { 57 | err = er 58 | } 59 | break 60 | } 61 | } 62 | 63 | if c.p.g.onWrittenSize != nil && written > 0 { 64 | c.p.g.onWrittenSize(c, nil, int(written)) 65 | } 66 | 67 | return written, err 68 | } 69 | -------------------------------------------------------------------------------- /sendfile_unix.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 lesismal. All rights reserved. 2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build linux || darwin || netbsd || freebsd || openbsd || dragonfly 6 | // +build linux darwin netbsd freebsd openbsd dragonfly 7 | 8 | package nbio 9 | 10 | import ( 11 | "errors" 12 | "io" 13 | "net" 14 | "os" 15 | "syscall" 16 | ) 17 | 18 | const maxSendfileSize = 4 << 20 19 | 20 | // Sendfile . 
21 | // 22 | //go:norace 23 | func (c *Conn) Sendfile(f *os.File, remain int64) (int64, error) { 24 | if f == nil { 25 | return 0, nil 26 | } 27 | 28 | c.mux.Lock() 29 | defer c.mux.Unlock() 30 | if c.closed { 31 | return 0, net.ErrClosed 32 | } 33 | 34 | offset, err := f.Seek(0, io.SeekCurrent) 35 | if err != nil { 36 | return 0, err 37 | } 38 | stat, err := f.Stat() 39 | if err != nil { 40 | return 0, err 41 | } 42 | size := stat.Size() 43 | if (remain <= 0) || (remain > size-offset) { 44 | remain = size - offset 45 | } 46 | 47 | // f.Fd() will set the fd to blocking mod. 48 | // We need to set the fd to non-blocking mod again. 49 | src := int(f.Fd()) 50 | err = syscall.SetNonblock(src, true) 51 | if err != nil { 52 | return 0, err 53 | } 54 | 55 | // If c.writeList is not empty, the socket is not writable now. 56 | // We push this File to writeList and wait to send it when writable. 57 | if len(c.writeList) > 0 { 58 | // After this Sendfile func returns, fs will be closed by the caller. 59 | // So we need to dup the fd and close it when we don't need it any more. 60 | src, err = syscall.Dup(src) 61 | if err != nil { 62 | return 0, err 63 | } 64 | c.newToWriteFile(src, offset, remain) 65 | // c.appendWrite(t) 66 | return remain, nil 67 | } 68 | 69 | // c.p.g.beforeWrite(c) 70 | 71 | var ( 72 | n int 73 | dst = c.fd 74 | total = remain 75 | ) 76 | 77 | for remain > 0 { 78 | n = maxSendfileSize 79 | if int64(n) > remain { 80 | n = int(remain) 81 | } 82 | var tmpOffset = offset 83 | n, err = syscall.Sendfile(dst, src, &tmpOffset, n) 84 | if n > 0 { 85 | remain -= int64(n) 86 | offset += int64(n) 87 | } else if n == 0 && err == nil { 88 | break 89 | } 90 | if errors.Is(err, syscall.EINTR) { 91 | continue 92 | } 93 | if errors.Is(err, syscall.EAGAIN) { 94 | // After this Sendfile func returns, fs will be closed by the caller. 95 | // So we need to dup the fd and close it when we don't need it any more. 
96 | src, err = syscall.Dup(src) 97 | if err == nil { 98 | c.newToWriteFile(src, offset, remain) 99 | // c.appendWrite(t) 100 | c.modWrite() 101 | } 102 | break 103 | } 104 | if err != nil { 105 | c.closed = true 106 | c.closeWithErrorWithoutLock(err) 107 | return 0, err 108 | } 109 | } 110 | 111 | return total, nil 112 | } 113 | -------------------------------------------------------------------------------- /taskpool/iotaskpool.go: -------------------------------------------------------------------------------- 1 | package taskpool 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | // IOTaskPool . 8 | type IOTaskPool struct { 9 | task *TaskPool 10 | pool sync.Pool 11 | } 12 | 13 | // Call . 14 | // 15 | //go:norace 16 | func (tp *IOTaskPool) Call(f func(*[]byte)) { 17 | tp.task.Call(func() { 18 | pbuf := tp.pool.Get().(*[]byte) 19 | f(pbuf) 20 | tp.pool.Put(pbuf) 21 | }) 22 | } 23 | 24 | // Go . 25 | // 26 | //go:norace 27 | func (tp *IOTaskPool) Go(f func(*[]byte)) { 28 | tp.task.Go(func() { 29 | pbuf := tp.pool.Get().(*[]byte) 30 | f(pbuf) 31 | tp.pool.Put(pbuf) 32 | }) 33 | } 34 | 35 | // Stop . 36 | // 37 | //go:norace 38 | func (tp *IOTaskPool) Stop() { 39 | tp.task.Stop() 40 | } 41 | 42 | // NewIO creates and returns a IOTaskPool. 43 | // 44 | //go:norace 45 | func NewIO(concurrent, queueSize, bufSize int, v ...interface{}) *IOTaskPool { 46 | task := New(concurrent, queueSize, v...) 47 | 48 | tp := &IOTaskPool{ 49 | task: task, 50 | pool: sync.Pool{ 51 | New: func() interface{} { 52 | buf := make([]byte, bufSize) 53 | return &buf 54 | }, 55 | }, 56 | } 57 | 58 | return tp 59 | } 60 | -------------------------------------------------------------------------------- /taskpool/taskpool.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 lesismal. All rights reserved. 2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 
package taskpool

import (
	"runtime"
	"sync/atomic"
	"unsafe"

	"github.com/lesismal/nbio/logging"
)

// TaskPool runs tasks on a bounded set of worker goroutines; overflow
// tasks are buffered in chQqueue until a worker drains them.
type TaskPool struct {
	concurrent    int64       // current number of worker goroutines (atomic)
	maxConcurrent int64       // worker cap minus one (see New)
	chQqueue      chan func() // overflow task queue
	chClose       chan struct{}
	caller        func(f func()) // task invoker; default recovers panics
}

// fork tries to reserve a worker slot and, on success, spawns a goroutine
// that runs f and then drains any queued tasks before exiting. The atomic
// add both reserves and tests the slot; callers that see false must undo
// the increment themselves (see Go).
//
//go:norace
func (tp *TaskPool) fork(f func()) bool {
	if atomic.AddInt64(&tp.concurrent, 1) < tp.maxConcurrent {
		go func() {
			defer atomic.AddInt64(&tp.concurrent, -1)
			tp.caller(f)
			for {
				select {
				case f = <-tp.chQqueue:
					if f != nil {
						tp.caller(f)
					}
				default:
					return
				}
			}
		}()
		return true
	}
	return false
}

// Call runs f synchronously on the calling goroutine via the pool's caller.
//
//go:norace
func (tp *TaskPool) Call(f func()) {
	tp.caller(f)
}

// Go schedules f: on a new worker if under the concurrency cap, otherwise
// onto the queue (or dropped if the pool is stopped).
//
//go:norace
func (tp *TaskPool) Go(f func()) {
	// If current goroutine num is less than maxConcurrent,
	// create a new goroutine to exec new task.
	if tp.fork(f) {
		return
	}

	// Else push the new task into chan/queue.
	// Undo the slot reservation made by the failed fork above.
	atomic.AddInt64(&tp.concurrent, -1)
	select {
	case tp.chQqueue <- f:
	case <-tp.chClose:
	}
}

// Stop shuts the pool down: inflating the counter by maxConcurrent makes
// every future fork fail, and closing chClose unblocks pending Go calls.
//
//go:norace
func (tp *TaskPool) Stop() {
	atomic.AddInt64(&tp.concurrent, tp.maxConcurrent)
	close(tp.chClose)
}

// New creates and returns a TaskPool.
//
// maxConcurrent bounds the worker goroutines, chQqueueSize sizes the
// overflow queue, and v[0] may supply a custom task invoker.
//
//go:norace
func New(maxConcurrent int, chQqueueSize int, v ...interface{}) *TaskPool {
	tp := &TaskPool{
		// Stored as maxConcurrent-1 because fork compares with '<' on the
		// post-increment value.
		maxConcurrent: int64(maxConcurrent - 1),
		chQqueue:      make(chan func(), chQqueueSize),
		chClose:       make(chan struct{}),
	}
	// Default caller: run the task and recover/log any panic with a stack.
	tp.caller = func(f func()) {
		defer func() {
			if err := recover(); err != nil {
				const size = 64 << 10
				buf := make([]byte, size)
				buf = buf[:runtime.Stack(buf, false)]
				logging.Error("taskpool call failed: %v\n%v\n", err, *(*string)(unsafe.Pointer(&buf)))
			}
		}()
		f()
	}
	if len(v) > 0 {
		if caller, ok := v[0].(func(f func())); ok {
			// NOTE(review): this wrapper decrements tp.concurrent once per
			// task in addition to fork's per-goroutine decrement, so with a
			// custom caller the counter can drift negative — confirm this
			// is intended before relying on the concurrency cap here.
			tp.caller = func(f func()) {
				defer atomic.AddInt64(&tp.concurrent, -1)
				caller(f)
			}
		}
	}
	// Dispatcher: hand queued tasks to a fresh worker when possible,
	// otherwise run them inline on this goroutine.
	go func() {
		for {
			select {
			case f := <-tp.chQqueue:
				if tp.fork(f) {
					continue
				}

				if f != nil {
					tp.caller(f)
				}
			case <-tp.chClose:
				return
			}
		}
	}()
	return tp
}
-------------------------------------------------------------------------------- /taskpool/taskpool_test.go: --------------------------------------------------------------------------------
package taskpool

import (
	"runtime"
	"sync"
	"testing"
	"time"
	"unsafe"

	"github.com/lesismal/nbio/logging"
)

const testLoopNum = 1024 * 8
const sleepTime = time.Nanosecond * 0

// BenchmarkGo measures raw `go` statements as the baseline the pool
// benchmarks below are compared against.
func BenchmarkGo(b *testing.B) {
	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		wg := sync.WaitGroup{}
		wg.Add(testLoopNum)
		for j := 0; j < testLoopNum; j++ {
			go func() {
				defer func() {
					if err := recover(); err != nil {
						const size = 64 << 10
						buf := make([]byte, size)
						buf = buf[:runtime.Stack(buf, false)]
						logging.Error("taskpool call failed: %v\n%v\n", err, *(*string)(unsafe.Pointer(&buf)))
					}
				}()
				if sleepTime > 0 {
time.Sleep(sleepTime)
				}
				wg.Done()
			}()
		}
		wg.Wait()
	}
}

// BenchmarkTaskPool exercises TaskPool.Go under the same load as
// BenchmarkGo above.
func BenchmarkTaskPool(b *testing.B) {
	p := New(32, 1024)
	defer p.Stop()

	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		wg := sync.WaitGroup{}
		wg.Add(testLoopNum)
		for j := 0; j < testLoopNum; j++ {
			p.Go(func() {
				if sleepTime > 0 {
					time.Sleep(sleepTime)
				}
				wg.Done()
			})
		}
		wg.Wait()
	}
}

// BenchmarkIOTaskPool exercises IOTaskPool.Go (pooled-buffer variant).
func BenchmarkIOTaskPool(b *testing.B) {
	p := NewIO(32, 1024, 1024)
	defer p.Stop()

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		wg := sync.WaitGroup{}
		wg.Add(testLoopNum)
		for j := 0; j < testLoopNum; j++ {
			p.Go(func(pbuf *[]byte) {
				if sleepTime > 0 {
					time.Sleep(sleepTime)
				}
				wg.Done()
			})
		}
		wg.Wait()
	}
}
-------------------------------------------------------------------------------- /timer/timer.go: --------------------------------------------------------------------------------
// Copyright 2020 lesismal. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.

package timer

import (
	"math"
	"runtime"
	"sync"
	"time"
	"unsafe"

	"github.com/lesismal/nbio/logging"
)

const (
	TimeForever = time.Duration(math.MaxInt64)
)

// Timer is a thin wrapper over the standard time package that adds panic
// recovery and a serialized Async queue; name is used only in log output.
type Timer struct {
	name      string
	asyncMux  sync.Mutex
	asyncList []func() // pending Async callbacks, run FIFO by one goroutine
}

//go:norace
func New(name string) *Timer {
	return &Timer{name: name, asyncList: make([]func(), 8)[0:0]}
}

// IsTimerRunning always reports true: this implementation delegates to the
// runtime's timers, so there is no own scheduler to be stopped.
//
//go:norace
func (t *Timer) IsTimerRunning() bool {
	return true
}

// Start is a no-op; kept for interface compatibility.
//
//go:norace
func (t *Timer) Start() {}

// Stop is a no-op; kept for interface compatibility.
//
//go:norace
func (t *Timer) Stop() {}

// After used as time.After.
//
//go:norace
func (t *Timer) After(d time.Duration) <-chan time.Time {
	return time.After(d)
}

// AfterFunc used as time.AfterFunc, with panic recovery around f.
//
//go:norace
func (t *Timer) AfterFunc(timeout time.Duration, f func()) *time.Timer {
	return time.AfterFunc(timeout, func() {
		defer func() {
			err := recover()
			if err != nil {
				const size = 64 << 10
				buf := make([]byte, size)
				buf = buf[:runtime.Stack(buf, false)]
				logging.Error("Timer[%v] exec call failed: %v\n%v\n", t.name, err, *(*string)(unsafe.Pointer(&buf)))
			}
		}()
		f()
	})
}

// Async executes f in another goroutine. Callbacks are appended to
// asyncList; whichever caller appends to an empty list becomes the "head"
// and spawns the single drainer goroutine, which runs entries in order
// (releasing the lock while each callback executes) and exits once it
// catches up with the list.
//
//go:norace
func (t *Timer) Async(f func()) {
	t.asyncMux.Lock()
	isHead := (len(t.asyncList) == 0)
	t.asyncList = append(t.asyncList, f)
	t.asyncMux.Unlock()
	if isHead {
		go func() {
			i := 0
			for {
				t.asyncMux.Lock()
				if i == len(t.asyncList) {
					// Drained: reset the list, shrinking it if it has
					// grown large, and let the drainer exit.
					if cap(t.asyncList) > 1024 {
						t.asyncList = make([]func(), 0, 8)
					} else {
						t.asyncList = t.asyncList[0:0]
					}
					t.asyncMux.Unlock()
					return
				}
				f := t.asyncList[i]
				i++
				t.asyncMux.Unlock()
				func() {
					defer func() {
						err := recover()
						if err != nil {
							const size = 64 << 10
							buf := make([]byte, size)
							buf = buf[:runtime.Stack(buf, false)]
							logging.Error("Timer[%v] async call failed: %v\n%v\n", t.name, err, *(*string)(unsafe.Pointer(&buf)))
						}
					}()
					f()
				}()
			}
		}()
	}
}

// Earlier goroutine-per-call implementation, kept for reference:
// func (t *Timer) Async(f func()) {

// 	go func() {
// 		defer func() {
// 			err := recover()
// 			if err != nil {
// 				const size = 64 << 10
// 				buf := make([]byte, size)
// 				buf = buf[:runtime.Stack(buf, false)]
// 				logging.Error("Timer[%v] exec call failed: %v\n%v\n", t.name, err, *(*string)(unsafe.Pointer(&buf)))
// 			}
// 		}()
// 		f()
// 	}()
// }
-------------------------------------------------------------------------------- /timer/timer_test.go: --------------------------------------------------------------------------------
package timer

import (
	"log"
	"math/rand"
	"sync"
	"testing"
	"time"
)

func TestTimer(t *testing.T) {
	tg := New("nbio")
	tg.Start()
	defer tg.Stop()

	timeout := time.Second / 50

	testAsync(tg)
	testTimerNormal(tg, timeout)
	testTimerExecPanic(tg, timeout)
	testTimerNormalExecMany(tg, timeout)
	testTimerExecManyRandtime(tg)
}

// testAsync verifies every Async callback runs exactly once.
func testAsync(tg *Timer) {
	loops := 3
	wg := sync.WaitGroup{}
	for i := 0; i < loops; i++ {
		wg.Add(1)
		tg.Async(func() {
			defer wg.Done()
		})
	}
	wg.Wait()
}

// testTimerNormal checks fire time, Reset, and Stop of AfterFunc timers,
// allowing generous (4x-10x) tolerance for scheduler jitter.
func testTimerNormal(tg *Timer, timeout time.Duration) {
	t1 := time.Now()
	ch1 := make(chan int)
	tg.AfterFunc(timeout*5, func() {
		close(ch1)
	})
	<-ch1
	to1 := time.Since(t1)
	if to1 < timeout*4 || to1 > timeout*10 {
		log.Panicf("invalid to1: %v", to1)
	}

	t2 := time.Now()
	ch2 := make(chan int)
	it2 := tg.AfterFunc(timeout, func() {
		close(ch2)
	})
	it2.Reset(timeout * 5)
	<-ch2
	to2 := time.Since(t2)
	if to2 < timeout*4 || to2 > timeout*10 {
		log.Panicf("invalid to2: %v", to2)
	}

	ch3 := make(chan int)
	it3 := tg.AfterFunc(timeout, func() {
		close(ch3)
	})
	it3.Stop()
	<-tg.After(timeout * 2)
	select {
	case <-ch3:
		log.Panicf("stop failed")
	default:
	}
}

// testTimerExecPanic checks that a panicking callback is recovered and
// does not crash the process.
func testTimerExecPanic(tg *Timer, timeout time.Duration) {
	tg.AfterFunc(timeout, func() {
		panic("test")
	})
}

// testTimerNormalExecMany schedules 5 timers with shuffled delays (3 and 5
// swapped) and checks they still fire in delay order.
func testTimerNormalExecMany(tg *Timer, timeout time.Duration) {
	ch4 :=
make(chan int, 5)
	for i := 0; i < 5; i++ {
		n := i + 1
		// Swap delays 3 and 5 so insertion order differs from fire order.
		if n == 3 {
			n = 5
		} else if n == 5 {
			n = 3
		}

		tg.AfterFunc(timeout*time.Duration(n), func() {
			ch4 <- n
		})
	}

	for i := 0; i < 5; i++ {
		n := <-ch4
		if n != i+1 {
			log.Panicf("invalid n: %v, %v", i, n)
		}
	}
}

// testTimerExecManyRandtime schedules 100 random-delay timers, stops the
// first 50, and verifies exactly the remaining 50 fire.
func testTimerExecManyRandtime(tg *Timer) {
	its := make([]*time.Timer, 100)[0:0]
	ch5 := make(chan int, 100)
	for i := 0; i < 100; i++ {
		n := 500 + rand.Int()%200
		to := time.Duration(n) * time.Second / 1000
		its = append(its, tg.AfterFunc(to, func() {
			ch5 <- n
		}))
	}
	for i := 0; i < 50; i++ {
		if its[0] == nil {
			log.Panicf("invalid its[0]")
		}
		its[0].Stop()
		its = its[1:]
	}
	recved := 0
LOOP_RECV:
	for {
		select {
		case <-ch5:
			recved++
		case <-time.After(time.Second):
			break LOOP_RECV
		}
	}
	if recved != 50 {
		log.Panicf("invalid recved num: %v", recved)
	}
}
-------------------------------------------------------------------------------- /tools/norace/norace.go: --------------------------------------------------------------------------------
// Command norace rewrites every .go file under the working directory,
// inserting a //go:norace directive before each top-level func.
package main

import (
	"fmt"
	"log"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"sync"
	"time"
)

var (
	root, _ = filepath.Abs("./")

	// skipPaths: filename substrings that must not be rewritten.
	skipPaths = []string{
		"_test.go",
		"norace.go",
	}

	chTask = make(chan func(), 32)
)

func main() {
	defer close(chTask)

	// One worker per CPU consuming rewrite tasks from chTask.
	for i := 0; i < runtime.NumCPU(); i++ {
		go func() {
			for f := range chTask {
				f()
			}
		}()
	}

	wg := &sync.WaitGroup{}

	walk(wg, root)

	// wg.Done()
	fmt.Println("wait")
	wg.Wait()

	fmt.Println("exit")
}

// run enqueues f for the worker goroutines above.
func run(f func()) {
	chTask <- f
}
51 | func walk(wg *sync.WaitGroup, currRoot string) { 52 | err := filepath.Walk(currRoot, func(path string, info os.FileInfo, err error) error { 53 | if info.IsDir() { 54 | 55 | } else { 56 | if shouldSkip(path) { 57 | return nil 58 | } 59 | wg.Add(1) 60 | run(func() { 61 | defer wg.Done() 62 | addNorace(path, info) 63 | }) 64 | } 65 | return nil 66 | }) 67 | if err != nil { 68 | panic(err) 69 | } 70 | } 71 | 72 | func shouldSkip(path string) bool { 73 | path, _ = filepath.Abs(path) 74 | if !strings.HasSuffix(path, ".go") { 75 | return true 76 | } 77 | for _, v := range skipPaths { 78 | if strings.Contains(path, v) { 79 | return true 80 | } 81 | } 82 | return path == root || path == "." || path == "./" || path == "\\." 83 | } 84 | 85 | func addNorace(path string, info os.FileInfo) { 86 | data, err := os.ReadFile(path) 87 | if err != nil { 88 | panic(err) 89 | } 90 | s := string(data) 91 | s = strings.Replace(s, "\nfunc", "\n//go:norace\nfunc", -1) 92 | tag := "//go:norace\n" 93 | tag2 := tag + tag 94 | for strings.Contains(s, tag2) { 95 | s = strings.Replace(s, tag2, tag, -1) 96 | } 97 | data = []byte(s) 98 | 99 | tmpFile := path + time.Now().Format(".20060102150405.dec") 100 | err = os.WriteFile(tmpFile, data, info.Mode().Perm()) 101 | if err != nil { 102 | log.Printf("xxx WriteFile origin [%v] failed: %v", path, err) 103 | panic(err) 104 | } 105 | 106 | err = os.Remove(path) 107 | if err != nil { 108 | log.Printf("xxx Remove origin [%v] failed: %v", path, err) 109 | panic(err) 110 | } 111 | 112 | err = os.Rename(tmpFile, path) 113 | if err != nil { 114 | log.Printf("xxx Rename tmp file[%v] -> origin file[%v] failed: %v", tmpFile, path, err) 115 | panic(err) 116 | } 117 | log.Printf("+++ add norace for file[%v]", path) 118 | } 119 | -------------------------------------------------------------------------------- /writev_bsd.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 lesismal. All rights reserved. 
2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build darwin || netbsd || freebsd || openbsd || dragonfly 6 | // +build darwin netbsd freebsd openbsd dragonfly 7 | 8 | package nbio 9 | 10 | import ( 11 | "syscall" 12 | ) 13 | 14 | //go:norace 15 | func writev(c *Conn, iovs [][]byte) (int, error) { 16 | size := 0 17 | for _, v := range iovs { 18 | size += len(v) 19 | } 20 | pbuf := c.p.g.BodyAllocator.Malloc(size) 21 | *pbuf = (*pbuf)[0:0] 22 | for _, v := range iovs { 23 | pbuf = c.p.g.BodyAllocator.Append(pbuf, v...) 24 | } 25 | n, err := syscall.Write(c.fd, *pbuf) 26 | c.p.g.BodyAllocator.Free(pbuf) 27 | return n, err 28 | } 29 | -------------------------------------------------------------------------------- /writev_linux.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 lesismal. All rights reserved. 2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build linux 6 | // +build linux 7 | 8 | package nbio 9 | 10 | import ( 11 | "syscall" 12 | "unsafe" 13 | ) 14 | 15 | //go:norace 16 | func writev(c *Conn, bs [][]byte) (int, error) { 17 | iovs := make([]syscall.Iovec, len(bs))[0:0] 18 | for _, b := range bs { 19 | if len(b) > 0 { 20 | v := syscall.Iovec{} 21 | v.SetLen(len(b)) 22 | v.Base = &b[0] 23 | iovs = append(iovs, v) 24 | } 25 | } 26 | 27 | if len(iovs) > 0 { 28 | var _p0 = unsafe.Pointer(&iovs[0]) 29 | var n, _, err = syscall.Syscall(syscall.SYS_WRITEV, uintptr(c.fd), uintptr(_p0), uintptr(len(iovs))) 30 | if err == 0 { 31 | return int(n), nil 32 | } 33 | return int(n), err 34 | } 35 | return 0, nil 36 | } 37 | --------------------------------------------------------------------------------