├── .ci └── code-coverage.sh ├── .codecov.yml ├── .github └── workflows │ └── ci.yml ├── LICENSE ├── README.md ├── conn.go ├── cxx_zmq4_compat.go ├── czmq4_test.go ├── dealer.go ├── example ├── hwclient.go ├── hwserver.go ├── psenvpub.go ├── psenvsub.go ├── rrclient.go ├── rrworker.go └── rtdealer.go ├── go.mod ├── go.sum ├── internal ├── errgroup │ ├── errgroup.go │ └── errgroup_test.go ├── inproc │ ├── conn.go │ ├── inproc.go │ ├── inproc_test.go │ └── transport.go └── leaks_test │ └── reaper_leak_test.go ├── msg.go ├── msgio.go ├── null_security.go ├── options.go ├── pair.go ├── protocol.go ├── protocol_test.go ├── proxy.go ├── proxy_test.go ├── pub.go ├── pull.go ├── push.go ├── queue.go ├── queue_test.go ├── reaper_test.go ├── rep.go ├── rep_test.go ├── req.go ├── router.go ├── security.go ├── security ├── null │ ├── null.go │ └── null_test.go └── plain │ ├── plain.go │ ├── plain_cxx_test.go │ ├── plain_test.go │ └── testdata │ └── password.txt ├── security_test.go ├── socket.go ├── socket_test.go ├── socket_types.go ├── sub.go ├── transport.go ├── transport └── transport.go ├── transport_test.go ├── utils.go ├── utils_test.go ├── xpub.go ├── xsub.go ├── zall_test.go ├── zmq4.go ├── zmq4_pair_test.go ├── zmq4_pubsub_test.go ├── zmq4_pushpull_test.go ├── zmq4_reqrep_test.go ├── zmq4_routerdealer_test.go ├── zmq4_test.go ├── zmq4_timeout_test.go └── zmq4_xpubsub_test.go /.ci/code-coverage.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2018 The go-zeromq Authors. All rights reserved. 4 | # Use of this source code is governed by a BSD-style 5 | # license that can be found in the LICENSE file. 6 | 7 | 8 | set -e 9 | 10 | echo "" > coverage.txt 11 | 12 | for d in $(go list ./... 
| grep -v vendor); do 13 | go test -v $TAGS -race -coverprofile=profile.out -covermode=atomic $d 14 | if [ -f profile.out ]; then 15 | cat profile.out >> coverage.txt 16 | rm profile.out 17 | fi 18 | done 19 | -------------------------------------------------------------------------------- /.codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | enabled: yes 6 | patch: 7 | default: 8 | enabled: no 9 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | env: 10 | TAGS: "-tags=czmq4" 11 | 12 | jobs: 13 | main: 14 | strategy: 15 | matrix: 16 | platform: [ubuntu-latest] 17 | go-version: [1.22.x, 1.21.x] 18 | 19 | runs-on: ${{ matrix.platform }} 20 | 21 | steps: 22 | - name: Checkout code 23 | uses: actions/checkout@v4 24 | 25 | - name: Install Go 26 | uses: actions/setup-go@v5 27 | with: 28 | go-version: ${{ matrix.go-version }} 29 | cache: true 30 | 31 | - name: Install Linux packages 32 | if: matrix.platform == 'ubuntu-latest' 33 | run: | 34 | sudo apt-get update -qq -y 35 | sudo apt-get install -y libsodium-dev libczmq-dev 36 | 37 | - name: go install 38 | run: GOARCH=amd64 go install -v $TAGS ./... 39 | 40 | - name: Test Linux 41 | run: ./.ci/code-coverage.sh 42 | 43 | - name: Verify code formatting 44 | run: | 45 | test -z $(gofmt -l .) 
46 | 47 | - name: Run linter 48 | uses: dominikh/staticcheck-action@v1 49 | with: 50 | install-go: false 51 | cache-key: ${{ matrix.platform }} 52 | version: "2023.1" 53 | 54 | - name: Upload code coverage 55 | uses: codecov/codecov-action@v3 56 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright ©2018 The go-zeromq Authors. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are met: 5 | * Redistributions of source code must retain the above copyright 6 | notice, this list of conditions and the following disclaimer. 7 | * Redistributions in binary form must reproduce the above copyright 8 | notice, this list of conditions and the following disclaimer in the 9 | documentation and/or other materials provided with the distribution. 10 | * Neither the name of the go-zeromq project nor the names of its authors and 11 | contributors may be used to endorse or promote products derived from this 12 | software without specific prior written permission. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 15 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 16 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # zmq4 2 | 3 | [![GitHub release](https://img.shields.io/github/release/go-zeromq/zmq4.svg)](https://github.com/go-zeromq/zmq4/releases) 4 | [![go.dev reference](https://pkg.go.dev/badge/github.com/go-zeromq/zmq4)](https://pkg.go.dev/github.com/go-zeromq/zmq4) 5 | [![CI](https://github.com/go-zeromq/zmq4/workflows/CI/badge.svg)](https://github.com/go-zeromq/zmq4/actions) 6 | [![codecov](https://codecov.io/gh/go-zeromq/zmq4/branch/main/graph/badge.svg)](https://codecov.io/gh/go-zeromq/zmq4) 7 | [![GoDoc](https://godoc.org/github.com/go-zeromq/zmq4?status.svg)](https://godoc.org/github.com/go-zeromq/zmq4) 8 | [![License](https://img.shields.io/badge/License-BSD--3-blue.svg)](https://github.com/go-zeromq/license) 9 | [![DOI](https://zenodo.org/badge/129430151.svg)](https://zenodo.org/badge/latestdoi/129430151) 10 | 11 | `zmq4` is a pure-Go implementation of ØMQ (ZeroMQ), version 4. 12 | 13 | See [zeromq.org](http://zeromq.org) for more information. 14 | 15 | ## Development 16 | 17 | `zmq4` needs a caring maintainer. 18 | I (`sbinet`) no longer have much time to dedicate to this project (as `$WORK` doesn't need it anymore). 19 | 20 | ## License 21 | 22 | `zmq4` is released under the `BSD-3` license. 
23 | 24 | ## Documentation 25 | 26 | Documentation for `zmq4` is served by [GoDoc](https://godoc.org/github.com/go-zeromq/zmq4). 27 | 28 | 29 | -------------------------------------------------------------------------------- /cxx_zmq4_compat.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build czmq4 6 | // +build czmq4 7 | 8 | package zmq4 9 | 10 | import ( 11 | "context" 12 | "fmt" 13 | "net" 14 | "strings" 15 | 16 | czmq4 "github.com/go-zeromq/goczmq/v4" 17 | ) 18 | 19 | func NewCPair(ctx context.Context, opts ...czmq4.SockOption) Socket { 20 | return newCSocket(czmq4.Pair, opts...) 21 | } 22 | 23 | func NewCPub(ctx context.Context, opts ...czmq4.SockOption) Socket { 24 | return newCSocket(czmq4.Pub, opts...) 25 | } 26 | 27 | func NewCSub(ctx context.Context, opts ...czmq4.SockOption) Socket { 28 | return newCSocket(czmq4.Sub, opts...) 29 | } 30 | 31 | func NewCReq(ctx context.Context, opts ...czmq4.SockOption) Socket { 32 | return newCSocket(czmq4.Req, opts...) 33 | } 34 | 35 | func NewCRep(ctx context.Context, opts ...czmq4.SockOption) Socket { 36 | return newCSocket(czmq4.Rep, opts...) 37 | } 38 | 39 | func NewCDealer(ctx context.Context, opts ...czmq4.SockOption) Socket { 40 | return newCSocket(czmq4.Dealer, opts...) 41 | } 42 | 43 | func NewCRouter(ctx context.Context, opts ...czmq4.SockOption) Socket { 44 | return newCSocket(czmq4.Router, opts...) 45 | } 46 | 47 | func NewCPull(ctx context.Context, opts ...czmq4.SockOption) Socket { 48 | return newCSocket(czmq4.Pull, opts...) 49 | } 50 | 51 | func NewCPush(ctx context.Context, opts ...czmq4.SockOption) Socket { 52 | return newCSocket(czmq4.Push, opts...) 53 | } 54 | 55 | func NewCXPub(ctx context.Context, opts ...czmq4.SockOption) Socket { 56 | return newCSocket(czmq4.XPub, opts...) 
57 | } 58 | 59 | func NewCXSub(ctx context.Context, opts ...czmq4.SockOption) Socket { 60 | return newCSocket(czmq4.XSub, opts...) 61 | } 62 | 63 | type csocket struct { 64 | sock *czmq4.Sock 65 | addr net.Addr 66 | } 67 | 68 | func newCSocket(ctyp int, opts ...czmq4.SockOption) *csocket { 69 | sck := &csocket{sock: czmq4.NewSock(ctyp)} 70 | for _, opt := range opts { 71 | opt(sck.sock) 72 | } 73 | return sck 74 | } 75 | 76 | func (sck *csocket) Close() error { 77 | sck.sock.Destroy() 78 | return nil 79 | } 80 | 81 | // Send puts the message on the outbound send queue. 82 | // Send blocks until the message can be queued or the send deadline expires. 83 | func (sck *csocket) Send(msg Msg) error { 84 | return sck.sock.SendMessage(msg.Frames) 85 | } 86 | 87 | // SendMulti puts the message on the outbound send queue. 88 | // SendMulti blocks until the message can be queued or the send deadline expires. 89 | // The message will be sent as a multipart message. 90 | func (sck *csocket) SendMulti(msg Msg) error { 91 | return sck.sock.SendMessage(msg.Frames) 92 | } 93 | 94 | // Recv receives a complete message. 95 | func (sck *csocket) Recv() (Msg, error) { 96 | frames, err := sck.sock.RecvMessage() 97 | return Msg{Frames: frames}, err 98 | } 99 | 100 | // Listen connects a local endpoint to the Socket. 101 | func (sck *csocket) Listen(addr string) error { 102 | port, err := sck.sock.Bind(addr) 103 | if err != nil { 104 | return err 105 | } 106 | sck.addr = netAddrFrom(port, addr) 107 | return nil 108 | } 109 | 110 | // Dial connects a remote endpoint to the Socket. 111 | func (sck *csocket) Dial(addr string) error { 112 | return sck.sock.Connect(addr) 113 | } 114 | 115 | // Type returns the type of this Socket (PUB, SUB, ...) 
116 | func (sck *csocket) Type() SocketType { 117 | switch sck.sock.GetType() { 118 | case czmq4.Pair: 119 | return Pair 120 | case czmq4.Pub: 121 | return Pub 122 | case czmq4.Sub: 123 | return Sub 124 | case czmq4.Req: 125 | return Req 126 | case czmq4.Rep: 127 | return Rep 128 | case czmq4.Dealer: 129 | return Dealer 130 | case czmq4.Router: 131 | return Router 132 | case czmq4.Pull: 133 | return Pull 134 | case czmq4.Push: 135 | return Push 136 | case czmq4.XPub: 137 | return XPub 138 | case czmq4.XSub: 139 | return XSub 140 | } 141 | panic("invalid C-socket type") 142 | } 143 | 144 | // Addr returns the listener's address. 145 | // Addr returns nil if the socket isn't a listener. 146 | func (sck *csocket) Addr() net.Addr { 147 | return sck.addr 148 | } 149 | 150 | // Conn returns the underlying net.Conn the socket is bound to. 151 | func (sck *csocket) Conn() net.Conn { 152 | panic("not implemented") 153 | } 154 | 155 | // GetOption is used to retrieve an option for a socket. 156 | func (sck *csocket) GetOption(name string) (interface{}, error) { 157 | panic("not implemented") 158 | } 159 | 160 | // SetOption is used to set an option for a socket. 161 | func (sck *csocket) SetOption(name string, value interface{}) error { 162 | switch name { 163 | case OptionSubscribe: 164 | topic := value.(string) 165 | sck.sock.SetOption(czmq4.SockSetSubscribe(topic)) 166 | return nil 167 | case OptionUnsubscribe: 168 | topic := value.(string) 169 | sck.sock.SetOption(czmq4.SockSetUnsubscribe(topic)) 170 | return nil 171 | default: 172 | panic("unknown set option name [" + name + "]") 173 | } 174 | panic("not implemented") 175 | } 176 | 177 | // CWithID configures a ZeroMQ socket identity. 
178 | func CWithID(id SocketIdentity) czmq4.SockOption { 179 | return czmq4.SockSetIdentity(string(id)) 180 | } 181 | 182 | func netAddrFrom(port int, ep string) net.Addr { 183 | network, addr, err := splitAddr(ep) 184 | if err != nil { 185 | panic(err) 186 | } 187 | switch network { 188 | case "ipc": 189 | network = "unix" 190 | case "tcp": 191 | network = "tcp" 192 | case "udp": 193 | network = "udp" 194 | case "inproc": 195 | network = "inproc" 196 | default: 197 | panic("zmq4: unknown protocol [" + network + "]") 198 | } 199 | if idx := strings.Index(addr, ":"); idx != -1 { 200 | addr = string(addr[:idx]) 201 | } 202 | return caddr{host: addr, port: fmt.Sprintf("%d", port), net: network} 203 | } 204 | 205 | type caddr struct { 206 | host string 207 | port string 208 | net string 209 | } 210 | 211 | func (addr caddr) Network() string { return addr.net } 212 | func (addr caddr) String() string { 213 | return addr.host + ":" + addr.port 214 | } 215 | 216 | var ( 217 | _ Socket = (*csocket)(nil) 218 | _ net.Addr = (*caddr)(nil) 219 | ) 220 | -------------------------------------------------------------------------------- /dealer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "context" 9 | "net" 10 | ) 11 | 12 | // NewDealer returns a new DEALER ZeroMQ socket. 13 | // The returned socket value is initially unbound. 14 | func NewDealer(ctx context.Context, opts ...Option) Socket { 15 | dealer := &dealerSocket{newSocket(ctx, Dealer, opts...)} 16 | return dealer 17 | } 18 | 19 | // dealerSocket is a DEALER ZeroMQ socket. 
20 | type dealerSocket struct { 21 | sck *socket 22 | } 23 | 24 | // Close closes the open Socket 25 | func (dealer *dealerSocket) Close() error { 26 | return dealer.sck.Close() 27 | } 28 | 29 | // Send puts the message on the outbound send queue. 30 | // Send blocks until the message can be queued or the send deadline expires. 31 | func (dealer *dealerSocket) Send(msg Msg) error { 32 | return dealer.sck.Send(msg) 33 | } 34 | 35 | // SendMulti puts the message on the outbound send queue. 36 | // SendMulti blocks until the message can be queued or the send deadline expires. 37 | // The message will be sent as a multipart message. 38 | func (dealer *dealerSocket) SendMulti(msg Msg) error { 39 | return dealer.sck.SendMulti(msg) 40 | } 41 | 42 | // Recv receives a complete message. 43 | func (dealer *dealerSocket) Recv() (Msg, error) { 44 | return dealer.sck.Recv() 45 | } 46 | 47 | // Listen connects a local endpoint to the Socket. 48 | func (dealer *dealerSocket) Listen(ep string) error { 49 | return dealer.sck.Listen(ep) 50 | } 51 | 52 | // Dial connects a remote endpoint to the Socket. 53 | func (dealer *dealerSocket) Dial(ep string) error { 54 | return dealer.sck.Dial(ep) 55 | } 56 | 57 | // Type returns the type of this Socket (PUB, SUB, ...) 58 | func (dealer *dealerSocket) Type() SocketType { 59 | return dealer.sck.Type() 60 | } 61 | 62 | // Addr returns the listener's address. 63 | // Addr returns nil if the socket isn't a listener. 64 | func (dealer *dealerSocket) Addr() net.Addr { 65 | return dealer.sck.Addr() 66 | } 67 | 68 | // GetOption is used to retrieve an option for a socket. 69 | func (dealer *dealerSocket) GetOption(name string) (interface{}, error) { 70 | return dealer.sck.GetOption(name) 71 | } 72 | 73 | // SetOption is used to set an option for a socket. 
74 | func (dealer *dealerSocket) SetOption(name string, value interface{}) error { 75 | return dealer.sck.SetOption(name, value) 76 | } 77 | 78 | var ( 79 | _ Socket = (*dealerSocket)(nil) 80 | ) 81 | -------------------------------------------------------------------------------- /example/hwclient.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build ignore 6 | // +build ignore 7 | 8 | // Hello client 9 | package main 10 | 11 | import ( 12 | "context" 13 | "fmt" 14 | "log" 15 | "time" 16 | 17 | zmq "github.com/go-zeromq/zmq4" 18 | ) 19 | 20 | func main() { 21 | if err := hwclient(); err != nil { 22 | log.Fatalf("hwclient: %v", err) 23 | } 24 | } 25 | 26 | func hwclient() error { 27 | ctx := context.Background() 28 | socket := zmq.NewReq(ctx, zmq.WithDialerRetry(time.Second)) 29 | defer socket.Close() 30 | 31 | fmt.Printf("Connecting to hello world server...") 32 | if err := socket.Dial("tcp://localhost:5555"); err != nil { 33 | return fmt.Errorf("dialing: %w", err) 34 | } 35 | 36 | for i := 0; i < 10; i++ { 37 | // Send hello. 38 | m := zmq.NewMsgString("hello") 39 | fmt.Println("sending ", m) 40 | if err := socket.Send(m); err != nil { 41 | return fmt.Errorf("sending: %w", err) 42 | } 43 | 44 | // Wait for reply. 45 | r, err := socket.Recv() 46 | if err != nil { 47 | return fmt.Errorf("receiving: %w", err) 48 | } 49 | fmt.Println("received ", r.String()) 50 | } 51 | return nil 52 | } 53 | -------------------------------------------------------------------------------- /example/hwserver.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 The go-zeromq Authors. All rights reserved. 
2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build ignore 6 | // +build ignore 7 | 8 | // Hello server 9 | package main 10 | 11 | import ( 12 | "context" 13 | "fmt" 14 | "log" 15 | "time" 16 | 17 | zmq "github.com/go-zeromq/zmq4" 18 | ) 19 | 20 | func main() { 21 | if err := hwserver(); err != nil { 22 | log.Fatalf("hwserver: %w", err) 23 | } 24 | } 25 | 26 | func hwserver() error { 27 | ctx := context.Background() 28 | // Socket to talk to clients 29 | socket := zmq.NewRep(ctx) 30 | defer socket.Close() 31 | if err := socket.Listen("tcp://*:5555"); err != nil { 32 | return fmt.Errorf("listening: %w", err) 33 | } 34 | 35 | for { 36 | msg, err := socket.Recv() 37 | if err != nil { 38 | return fmt.Errorf("receiving: %w", err) 39 | } 40 | fmt.Println("Received ", msg) 41 | 42 | // Do some 'work' 43 | time.Sleep(time.Second) 44 | 45 | reply := fmt.Sprintf("World") 46 | if err := socket.Send(zmq.NewMsgString(reply)); err != nil { 47 | return fmt.Errorf("sending reply: %w", err) 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /example/psenvpub.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | //go:build ignore 6 | // +build ignore 7 | 8 | // PubSub envelope publisher 9 | package main 10 | 11 | import ( 12 | "context" 13 | "log" 14 | "time" 15 | 16 | "github.com/go-zeromq/zmq4" 17 | ) 18 | 19 | func main() { 20 | log.SetPrefix("psenvpub: ") 21 | 22 | // prepare the publisher 23 | pub := zmq4.NewPub(context.Background()) 24 | defer pub.Close() 25 | 26 | err := pub.Listen("tcp://*:5563") 27 | if err != nil { 28 | log.Fatalf("could not listen: %v", err) 29 | } 30 | 31 | msgA := zmq4.NewMsgFrom( 32 | []byte("A"), 33 | []byte("We don't want to see this"), 34 | ) 35 | msgB := zmq4.NewMsgFrom( 36 | []byte("B"), 37 | []byte("We would like to see this"), 38 | ) 39 | for { 40 | // Write two messages, each with an envelope and content 41 | err = pub.Send(msgA) 42 | if err != nil { 43 | log.Fatal(err) 44 | } 45 | err = pub.Send(msgB) 46 | if err != nil { 47 | log.Fatal(err) 48 | } 49 | time.Sleep(time.Second) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /example/psenvsub.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | //go:build ignore 6 | // +build ignore 7 | 8 | // PubSub envelope subscriber 9 | package main 10 | 11 | import ( 12 | "context" 13 | "log" 14 | 15 | "github.com/go-zeromq/zmq4" 16 | ) 17 | 18 | func main() { 19 | log.SetPrefix("psenvsub: ") 20 | 21 | // Prepare our subscriber 22 | sub := zmq4.NewSub(context.Background()) 23 | defer sub.Close() 24 | 25 | err := sub.Dial("tcp://localhost:5563") 26 | if err != nil { 27 | log.Fatalf("could not dial: %v", err) 28 | } 29 | 30 | err = sub.SetOption(zmq4.OptionSubscribe, "B") 31 | if err != nil { 32 | log.Fatalf("could not subscribe: %v", err) 33 | } 34 | 35 | for { 36 | // Read envelope 37 | msg, err := sub.Recv() 38 | if err != nil { 39 | log.Fatalf("could not receive message: %v", err) 40 | } 41 | log.Printf("[%s] %s\n", msg.Frames[0], msg.Frames[1]) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /example/rrclient.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build ignore 6 | // +build ignore 7 | 8 | // Request-reply client. 
9 | // 10 | // Connects REQ socket to tcp://localhost:5559 11 | // Sends "Hello" to server, expects "World" back 12 | package main 13 | 14 | import ( 15 | "context" 16 | "log" 17 | 18 | "github.com/go-zeromq/zmq4" 19 | ) 20 | 21 | func main() { 22 | log.SetPrefix("rrclient: ") 23 | 24 | req := zmq4.NewReq(context.Background()) 25 | defer req.Close() 26 | 27 | err := req.Dial("tcp://localhost:5559") 28 | if err != nil { 29 | log.Fatalf("could not dial: %v", err) 30 | } 31 | 32 | for i := 0; i < 10; i++ { 33 | err := req.Send(zmq4.NewMsgString("Hello")) 34 | if err != nil { 35 | log.Fatalf("could not send greeting: %v", err) 36 | } 37 | 38 | msg, err := req.Recv() 39 | if err != nil { 40 | log.Fatalf("could not recv greeting: %v", err) 41 | } 42 | log.Printf("received reply %d [%s]\n", i, msg.Frames[0]) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /example/rrworker.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build ignore 6 | // +build ignore 7 | 8 | // Request-reply worker. 
9 | // 10 | // Connects REP socket to tcp://*:5559 11 | // Expects "Hello" from client, replies with "World" 12 | package main 13 | 14 | import ( 15 | "context" 16 | "log" 17 | "time" 18 | 19 | "github.com/go-zeromq/zmq4" 20 | ) 21 | 22 | func main() { 23 | log.SetPrefix("rrworker: ") 24 | 25 | // Socket to talk to clients 26 | rep := zmq4.NewRep(context.Background()) 27 | defer rep.Close() 28 | 29 | err := rep.Listen("tcp://*:5559") 30 | if err != nil { 31 | log.Fatalf("could not dial: %v", err) 32 | } 33 | 34 | for { 35 | // Wait for next request from client 36 | msg, err := rep.Recv() 37 | if err != nil { 38 | log.Fatalf("could not recv request: %v", err) 39 | } 40 | 41 | log.Printf("received request: [%s]\n", msg.Frames[0]) 42 | 43 | // Do some 'work' 44 | time.Sleep(time.Second) 45 | 46 | // Send reply back to client 47 | err = rep.Send(zmq4.NewMsgString("World")) 48 | if err != nil { 49 | log.Fatalf("could not send reply: %v", err) 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /example/rtdealer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build ignore 6 | // +build ignore 7 | 8 | // Router/Dealer example. 
9 | package main 10 | 11 | import ( 12 | "bytes" 13 | "context" 14 | "fmt" 15 | "log" 16 | "math/rand" 17 | "sync" 18 | "time" 19 | 20 | "github.com/go-zeromq/zmq4" 21 | ) 22 | 23 | const ( 24 | NWORKERS = 10 25 | endpoint = "tcp://localhost:5671" 26 | ) 27 | 28 | var ( 29 | Fired = []byte("Fired!") 30 | WorkHarder = []byte("Work Harder!") 31 | 32 | ready = zmq4.NewMsgFrom([]byte(""), []byte("ready")) 33 | ) 34 | 35 | func main() { 36 | rand.Seed(1234) 37 | bkg := context.Background() 38 | router := zmq4.NewCRouter(bkg, zmq4.CWithID(zmq4.SocketIdentity("router"))) 39 | 40 | err := router.Listen("tcp://*:5671") 41 | if err != nil { 42 | log.Fatalf("could not listen %q: %v", endpoint, err) 43 | } 44 | defer router.Close() 45 | 46 | var wg sync.WaitGroup 47 | wg.Add(NWORKERS) 48 | for i := 0; i < NWORKERS; i++ { 49 | go worker(i, &wg) 50 | } 51 | 52 | nfired := 0 53 | for { 54 | msg, err := router.Recv() 55 | if err != nil { 56 | log.Fatalf("router failed to recv message: %v", err) 57 | } 58 | 59 | id := msg.Frames[0] 60 | fire := rand.Float64() * 100 61 | switch { 62 | case fire < 30: 63 | msg = zmq4.NewMsgFrom(id, []byte(""), Fired) 64 | nfired++ 65 | default: 66 | msg = zmq4.NewMsgFrom(id, []byte(""), WorkHarder) 67 | } 68 | err = router.Send(msg) 69 | if err != nil { 70 | log.Fatalf("router failed to send message to %q: %v", id, err) 71 | } 72 | if nfired == NWORKERS { 73 | break 74 | } 75 | } 76 | wg.Wait() 77 | log.Printf("fired everybody.") 78 | } 79 | 80 | func worker(i int, wg *sync.WaitGroup) { 81 | id := zmq4.SocketIdentity(fmt.Sprintf("dealer-%d", i)) 82 | dealer := zmq4.NewDealer(context.Background(), zmq4.WithID(id)) 83 | defer dealer.Close() 84 | defer wg.Done() 85 | 86 | err := dealer.Dial(endpoint) 87 | if err != nil { 88 | log.Fatalf("dealer %d failed to dial: %v", i, err) 89 | } 90 | 91 | total := 0 92 | dloop: 93 | for { 94 | // ready to work 95 | err = dealer.Send(ready) 96 | if err != nil { 97 | log.Fatalf("dealer %d failed to send ready message: 
%v", i, err) 98 | } 99 | 100 | // get workload from broker 101 | msg, err := dealer.Recv() 102 | if err != nil { 103 | log.Fatalf("dealer %d failed to recv message: %v", i, err) 104 | } 105 | work := msg.Frames[1] 106 | if bytes.Equal(work, Fired) { 107 | break dloop 108 | } 109 | 110 | // do some random work 111 | time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) 112 | total++ 113 | } 114 | 115 | log.Printf("dealer %d completed %d tasks", i, total) 116 | } 117 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-zeromq/zmq4 2 | 3 | go 1.21 4 | 5 | require ( 6 | github.com/go-zeromq/goczmq/v4 v4.2.2 7 | go.uber.org/goleak v1.3.0 8 | golang.org/x/sync v0.7.0 9 | golang.org/x/text v0.15.0 10 | ) 11 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/go-zeromq/goczmq/v4 v4.2.2 h1:HAJN+i+3NW55ijMJJhk7oWxHKXgAuSBkoFfvr8bYj4U= 4 | github.com/go-zeromq/goczmq/v4 v4.2.2/go.mod h1:Sm/lxrfxP/Oxqs0tnHD6WAhwkWrx+S+1MRrKzcxoaYE= 5 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 6 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 7 | github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= 8 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 9 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 10 | go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= 11 | golang.org/x/sync v0.7.0 
h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= 12 | golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 13 | golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= 14 | golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 15 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 16 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 17 | -------------------------------------------------------------------------------- /internal/errgroup/errgroup.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package errgroup is bit more advanced than golang.org/x/sync/errgroup. 6 | // Major difference is that when error group is created with WithContext 7 | // the parent context would implicitly cancel all functions called by Go method. 8 | package errgroup 9 | 10 | import ( 11 | "context" 12 | 13 | "golang.org/x/sync/errgroup" 14 | ) 15 | 16 | // The Group is superior errgroup.Group which aborts whole group 17 | // execution when parent context is cancelled 18 | type Group struct { 19 | grp *errgroup.Group 20 | ctx context.Context 21 | } 22 | 23 | // WithContext creates Group and store inside parent context 24 | // so the Go method would respect parent context cancellation 25 | func WithContext(ctx context.Context) (*Group, context.Context) { 26 | grp, child_ctx := errgroup.WithContext(ctx) 27 | return &Group{grp: grp, ctx: ctx}, child_ctx 28 | } 29 | 30 | // Go runs the provided f function in a dedicated goroutine and waits for its 31 | // completion or for the parent context cancellation. 
32 | func (g *Group) Go(f func() error) { 33 | g.getErrGroup().Go(g.wrap(f)) 34 | } 35 | 36 | // Wait blocks until all function calls from the Go method have returned, then 37 | // returns the first non-nil error (if any) from them. 38 | // If the error group was created via WithContext then the Wait returns error 39 | // of cancelled parent context prior any functions calls complete. 40 | func (g *Group) Wait() error { 41 | return g.getErrGroup().Wait() 42 | } 43 | 44 | // SetLimit limits the number of active goroutines in this group to at most n. 45 | // A negative value indicates no limit. 46 | // 47 | // Any subsequent call to the Go method will block until it can add an active 48 | // goroutine without exceeding the configured limit. 49 | // 50 | // The limit must not be modified while any goroutines in the group are active. 51 | func (g *Group) SetLimit(n int) { 52 | g.getErrGroup().SetLimit(n) 53 | } 54 | 55 | // TryGo calls the given function in a new goroutine only if the number of 56 | // active goroutines in the group is currently below the configured limit. 57 | // 58 | // The return value reports whether the goroutine was started. 
59 | func (g *Group) TryGo(f func() error) bool { 60 | return g.getErrGroup().TryGo(g.wrap(f)) 61 | } 62 | 63 | func (g *Group) wrap(f func() error) func() error { 64 | if g.ctx == nil { 65 | return f 66 | } 67 | 68 | return func() error { 69 | // If parent context is canceled, 70 | // just return its error and do not call func f 71 | select { 72 | case <-g.ctx.Done(): 73 | return g.ctx.Err() 74 | default: 75 | } 76 | 77 | // Create return channel and call func f 78 | // Buffered channel is used as the following select 79 | // may be exiting by context cancellation 80 | // and in such case the write to channel can be block 81 | // and cause the go routine leak 82 | ch := make(chan error, 1) 83 | go func() { 84 | ch <- f() 85 | }() 86 | 87 | // Wait func f complete or 88 | // parent context to be cancelled, 89 | select { 90 | case err := <-ch: 91 | return err 92 | case <-g.ctx.Done(): 93 | return g.ctx.Err() 94 | } 95 | } 96 | } 97 | 98 | // The getErrGroup returns actual x/sync/errgroup.Group. 99 | // If the group is not allocated it would implicitly allocate it. 100 | // Thats allows the internal/errgroup.Group be fully 101 | // compatible to x/sync/errgroup.Group 102 | func (g *Group) getErrGroup() *errgroup.Group { 103 | if g.grp == nil { 104 | g.grp = &errgroup.Group{} 105 | } 106 | return g.grp 107 | } 108 | -------------------------------------------------------------------------------- /internal/errgroup/errgroup_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package errgroup 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "testing" 11 | 12 | "golang.org/x/sync/errgroup" 13 | ) 14 | 15 | // TestRegularErrGroupDoesNotRespectParentContext checks regular errgroup behavior 16 | // where errgroup.WithContext does not respect the parent context 17 | func TestRegularErrGroupDoesNotRespectParentContext(t *testing.T) { 18 | ctx, cancel := context.WithCancel(context.Background()) 19 | eg, _ := errgroup.WithContext(ctx) 20 | 21 | what := fmt.Errorf("func generated error") 22 | ch := make(chan error) 23 | eg.Go(func() error { return <-ch }) 24 | 25 | cancel() // abort parent context 26 | ch <- what // signal the func in regular errgroup to fail 27 | err := eg.Wait() 28 | 29 | // The error shall be one returned by the function 30 | // as regular errgroup.WithContext does not respect parent context 31 | if err != what { 32 | t.Errorf("invalid error. got=%+v, want=%+v", err, what) 33 | } 34 | } 35 | 36 | // TestErrGroupWithContextCanCallFunctions checks the errgroup operations 37 | // are fine working and errgroup called function can return error 38 | func TestErrGroupWithContextCanCallFunctions(t *testing.T) { 39 | ctx, cancel := context.WithCancel(context.Background()) 40 | defer cancel() 41 | eg, _ := WithContext(ctx) 42 | 43 | what := fmt.Errorf("func generated error") 44 | ch := make(chan error) 45 | eg.Go(func() error { return <-ch }) 46 | 47 | ch <- what // signal the func in errgroup to fail 48 | err := eg.Wait() // wait errgroup complete and read error 49 | 50 | // The error shall be one returned by the function 51 | if err != what { 52 | t.Errorf("invalid error. 
got=%+v, want=%+v", err, what) 53 | } 54 | } 55 | 56 | // TestErrGroupWithContextDoesRespectParentContext checks the errgroup operations 57 | // are cancellable by parent context 58 | func TestErrGroupWithContextDoesRespectParentContext(t *testing.T) { 59 | ctx, cancel := context.WithCancel(context.Background()) 60 | eg, _ := WithContext(ctx) 61 | 62 | s1 := make(chan struct{}) 63 | s2 := make(chan struct{}) 64 | eg.Go(func() error { 65 | s1 <- struct{}{} 66 | <-s2 67 | return fmt.Errorf("func generated error") 68 | }) 69 | 70 | // We have no set limit to errgroup so 71 | // shall be able to start function via TryGo 72 | if ok := eg.TryGo(func() error { return nil }); !ok { 73 | t.Errorf("Expected TryGo to be able start function!!!") 74 | } 75 | 76 | <-s1 // wait for function to start 77 | cancel() // abort parent context 78 | 79 | eg.Go(func() error { 80 | t.Errorf("The parent context was already cancelled and this function shall not be called!!!") 81 | return nil 82 | }) 83 | 84 | s2 <- struct{}{} // signal the func in regular errgroup to fail 85 | err := eg.Wait() // wait errgroup complete and read error 86 | 87 | // The error shall be one returned by the function 88 | // as regular errgroup.WithContext does not respect parent context 89 | if err != context.Canceled { 90 | t.Errorf("expected a context.Canceled error, got=%+v", err) 91 | } 92 | } 93 | 94 | // TestErrGroupFallback tests fallback logic to be compatible with x/sync/errgroup 95 | func TestErrGroupFallback(t *testing.T) { 96 | eg := Group{} 97 | eg.SetLimit(2) 98 | 99 | ch1 := make(chan error) 100 | eg.Go(func() error { return <-ch1 }) 101 | 102 | ch2 := make(chan error) 103 | ok := eg.TryGo(func() error { return <-ch2 }) 104 | if !ok { 105 | t.Errorf("Expected errgroup.TryGo to success!!!") 106 | } 107 | 108 | // The limit set to 2, so 3rd function shall not be possible to call 109 | ok = eg.TryGo(func() error { 110 | t.Errorf("This function is unexpected to be called!!!") 111 | return nil 112 | }) 
113 | if ok { 114 | t.Errorf("Expected errgroup.TryGo to fail!!!") 115 | } 116 | 117 | ch1 <- nil 118 | ch2 <- nil 119 | err := eg.Wait() 120 | 121 | if err != nil { 122 | t.Errorf("expected a nil error, got=%+v", err) 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /internal/inproc/conn.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Copyright 2010 The Go Authors. All rights reserved. 6 | // Use of this source code is governed by a BSD-style 7 | // license that can be found in the LICENSE file. 8 | 9 | package inproc 10 | 11 | import ( 12 | "io" 13 | "net" 14 | "sync" 15 | "time" 16 | ) 17 | 18 | type conn struct { 19 | addr Addr 20 | r <-chan []byte 21 | w chan<- []byte 22 | 23 | once sync.Once // Protects closing localDone 24 | localDone chan struct{} 25 | remoteDone <-chan struct{} 26 | 27 | rdeadline pipeDeadline 28 | wdeadline pipeDeadline 29 | } 30 | 31 | func (c *conn) Write(data []byte) (int, error) { 32 | n, err := c.write(data) 33 | if err != nil && err != io.ErrClosedPipe { 34 | err = &net.OpError{Op: "write", Net: "pipe", Err: err} 35 | } 36 | return n, err 37 | } 38 | 39 | func (c *conn) write(data []byte) (int, error) { 40 | switch { 41 | case isClosedChan(c.localDone): 42 | return 0, io.ErrClosedPipe 43 | case isClosedChan(c.remoteDone): 44 | return 0, io.ErrClosedPipe 45 | case isClosedChan(c.wdeadline.wait()): 46 | return 0, timeoutError{} 47 | } 48 | 49 | var n int 50 | select { 51 | case c.w <- data: 52 | n = len(data) 53 | return n, nil 54 | case <-c.localDone: 55 | return n, io.ErrClosedPipe 56 | case <-c.remoteDone: 57 | return n, io.ErrClosedPipe 58 | case <-c.wdeadline.wait(): 59 | return n, timeoutError{} 60 | } 61 | } 62 | 63 | func (c *conn) Read(data []byte) 
(int, error) { 64 | n, err := c.read(data) 65 | if err != nil && err != io.EOF && err != io.ErrClosedPipe { 66 | err = &net.OpError{Op: "read", Net: "pipe", Err: err} 67 | } 68 | return n, err 69 | } 70 | 71 | func (c *conn) read(data []byte) (int, error) { 72 | switch { 73 | case isClosedChan(c.localDone): 74 | return 0, io.ErrClosedPipe 75 | case isClosedChan(c.remoteDone): 76 | return 0, io.EOF 77 | case isClosedChan(c.rdeadline.wait()): 78 | return 0, timeoutError{} 79 | } 80 | 81 | select { 82 | case bw := <-c.r: 83 | nr := copy(data, bw) 84 | if len(data) < len(bw) { 85 | return nr, io.ErrShortBuffer 86 | } 87 | return nr, nil 88 | case <-c.rdeadline.wait(): 89 | return 0, timeoutError{} 90 | } 91 | } 92 | 93 | func (c *conn) LocalAddr() net.Addr { return c.addr } 94 | func (c *conn) RemoteAddr() net.Addr { return c.addr } 95 | 96 | func (c *conn) SetDeadline(t time.Time) error { 97 | if isClosedChan(c.localDone) || isClosedChan(c.remoteDone) { 98 | return io.ErrClosedPipe 99 | } 100 | c.rdeadline.set(t) 101 | c.wdeadline.set(t) 102 | return nil 103 | } 104 | 105 | func (c *conn) SetReadDeadline(t time.Time) error { 106 | if isClosedChan(c.localDone) || isClosedChan(c.remoteDone) { 107 | return io.ErrClosedPipe 108 | } 109 | c.rdeadline.set(t) 110 | return nil 111 | } 112 | 113 | func (c *conn) SetWriteDeadline(t time.Time) error { 114 | if isClosedChan(c.localDone) || isClosedChan(c.remoteDone) { 115 | return io.ErrClosedPipe 116 | } 117 | c.wdeadline.set(t) 118 | return nil 119 | } 120 | 121 | func (c *conn) Close() error { 122 | c.once.Do(func() { 123 | close(c.localDone) 124 | }) 125 | return nil 126 | } 127 | 128 | // pipeDeadline is an abstraction for handling timeouts. 
type pipeDeadline struct {
	mu     sync.Mutex // Guards timer and cancel
	timer  *time.Timer
	cancel chan struct{} // Must be non-nil
}

// makePipeDeadline returns a pipeDeadline with no deadline set.
func makePipeDeadline() pipeDeadline {
	return pipeDeadline{cancel: make(chan struct{})}
}

// set sets the point in time when the deadline will time out.
// A timeout event is signaled by closing the channel returned by wait.
// Once a timeout has occurred, the deadline can be refreshed by specifying a
// t value in the future.
//
// A zero value for t prevents timeout.
func (d *pipeDeadline) set(t time.Time) {
	d.mu.Lock()
	defer d.mu.Unlock()

	if d.timer != nil && !d.timer.Stop() {
		<-d.cancel // Wait for the timer callback to finish and close cancel
	}
	d.timer = nil

	// Time is zero, then there is no deadline.
	closed := isClosedChan(d.cancel)
	if t.IsZero() {
		if closed {
			// A previous deadline already fired: start over with a fresh channel.
			d.cancel = make(chan struct{})
		}
		return
	}

	// Time in the future, setup a timer to cancel in the future.
	if dur := time.Until(t); dur > 0 {
		if closed {
			d.cancel = make(chan struct{})
		}
		d.timer = time.AfterFunc(dur, func() {
			close(d.cancel)
		})
		return
	}

	// Time in the past, so close immediately.
	if !closed {
		close(d.cancel)
	}
}

// wait returns a channel that is closed when the deadline is exceeded.
func (d *pipeDeadline) wait() chan struct{} {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.cancel
}

// isClosedChan reports, without blocking, whether a receive from c
// would succeed right now.
func isClosedChan(c <-chan struct{}) bool {
	select {
	case <-c:
		return true
	default:
		return false
	}
}

// timeoutError is the error returned when a read or write deadline expires.
type timeoutError struct{}

func (timeoutError) Error() string   { return "deadline exceeded" }
func (timeoutError) Timeout() bool   { return true }
func (timeoutError) Temporary() bool { return true }

--------------------------------------------------------------------------------
/internal/inproc/inproc.go:
--------------------------------------------------------------------------------
// Copyright 2018 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package inproc provides tools to implement an in-process asynchronous pipe of net.Conns.
package inproc

import (
	"errors"
	"fmt"
	"net"
	"strings"
	"sync"
)

var (
	// mgr is the package-wide registry of listeners, keyed by address.
	mgr = contextType{db: make(map[string]*Listener)}

	ErrClosed      = errors.New("inproc: connection closed")
	ErrConnRefused = errors.New("inproc: connection refused")
)

func init() {
	// The condition variable shares the registry mutex.
	mgr.cv.L = &mgr.mu
}

// contextType holds the registered listeners together with the mutex and
// the condition variable used to coordinate Listen, Accept and Dial.
type contextType struct {
	mu sync.Mutex
	cv sync.Cond
	db map[string]*Listener
}

// A Listener is an in-process listener for stream-oriented protocols.
// Listener implements net.Listener.
//
// Multiple goroutines may invoke methods on a Listener simultaneously.
type Listener struct {
	addr Addr

	pipes  []*pipe // pipes created by Accept and not yet taken by Dial
	closed bool
}

// pipe groups the two connected ends of an in-process connection.
type pipe struct {
	p1 *conn
	p2 *conn
}

// newPipe creates a pair of connected conns sharing the given address.
// What one end writes the other end reads, and each end observes the
// other's done channel as its remoteDone.
func newPipe(addr Addr) *pipe {
	const sz = 8 // capacity of each direction's chunk queue
	ch1 := make(chan []byte, sz)
	ch2 := make(chan []byte, sz)
	done1 := make(chan struct{})
	done2 := make(chan struct{})

	p1 := &conn{
		addr:       addr,
		r:          ch1,
		w:          ch2,
		localDone:  done1,
		remoteDone: done2,
		rdeadline:  makePipeDeadline(),
		wdeadline:  makePipeDeadline(),
	}
	p2 := &conn{
		addr:       addr,
		r:          ch2,
		w:          ch1,
		localDone:  done2,
		remoteDone: done1,
		rdeadline:  makePipeDeadline(),
		wdeadline:  makePipeDeadline(),
	}
	return &pipe{p1, p2}
}

// Close closes both ends of the pipe and returns the first error, if any.
func (p *pipe) Close() error {
	e1 := p.p1.Close()
	e2 := p.p2.Close()
	if e1 != nil {
		return e1
	}
	if e2 != nil {
		return e2
	}
	return nil
}

// Listen announces on the given address.
// It fails when a listener is already registered for that address.
func Listen(addr string) (*Listener, error) {
	mgr.mu.Lock()
	_, dup := mgr.db[addr]
	if dup {
		mgr.mu.Unlock()
		return nil, fmt.Errorf("inproc: address %q already in use", addr)
	}

	l := &Listener{
		addr: Addr(addr),
	}
	mgr.db[addr] = l
	mgr.cv.Broadcast()
	mgr.mu.Unlock()

	return l, nil
}

// Addr returns the listener's network address.
func (l *Listener) Addr() net.Addr {
	return l.addr
}

// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
119 | func (l *Listener) Close() error { 120 | mgr.mu.Lock() 121 | defer mgr.mu.Unlock() 122 | if l.closed { 123 | return nil 124 | } 125 | var err error 126 | for i := range l.pipes { 127 | p := l.pipes[i] 128 | e := p.Close() 129 | if e != nil && err == nil { 130 | err = e 131 | } 132 | } 133 | l.closed = true 134 | delete(mgr.db, string(l.addr)) 135 | return err 136 | } 137 | 138 | // Accept waits for and returns the next connection to the listener. 139 | func (l *Listener) Accept() (net.Conn, error) { 140 | mgr.mu.Lock() 141 | p := newPipe(l.addr) 142 | l.pipes = append(l.pipes, p) 143 | closed := l.closed 144 | mgr.cv.Broadcast() 145 | mgr.mu.Unlock() 146 | 147 | if closed { 148 | return nil, ErrClosed 149 | } 150 | return p.p1, nil 151 | } 152 | 153 | // Dial connects to the given address. 154 | func Dial(addr string) (net.Conn, error) { 155 | mgr.mu.Lock() 156 | 157 | for { 158 | var ( 159 | l *Listener 160 | ok bool 161 | ) 162 | if l, ok = mgr.db[addr]; !ok || l == nil { 163 | mgr.mu.Unlock() 164 | return nil, ErrConnRefused 165 | } 166 | if n := len(l.pipes); n != 0 { 167 | p := l.pipes[n-1] 168 | l.pipes = l.pipes[:n-1] 169 | mgr.mu.Unlock() 170 | return p.p2, nil 171 | } 172 | mgr.cv.Wait() 173 | } 174 | } 175 | 176 | // Addr represents an in-process "network" end-point address. 177 | type Addr string 178 | 179 | // String implements net.Addr.String 180 | func (a Addr) String() string { 181 | return strings.TrimPrefix(string(a), "inproc://") 182 | } 183 | 184 | // Network returns the name of the network. 185 | func (a Addr) Network() string { 186 | return "inproc" 187 | } 188 | 189 | var ( 190 | _ net.Addr = (*Addr)(nil) 191 | _ net.Listener = (*Listener)(nil) 192 | ) 193 | -------------------------------------------------------------------------------- /internal/inproc/inproc_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package inproc

import (
	"bytes"
	"fmt"
	"io"
	"math/rand"
	"reflect"
	"testing"

	"golang.org/x/sync/errgroup"
)

// TestBasicIO sends random bytes from one end of a pipe to the other and
// compares what was received. The test is currently skipped.
func TestBasicIO(t *testing.T) {
	t.Skip()

	want := make([]byte, 1<<2)
	rand.New(rand.NewSource(0)).Read(want)

	pipe := newPipe(Addr("basic-io"))
	c1 := pipe.p1
	c2 := pipe.p2

	dataCh := make(chan []byte)
	go func() {
		rd := bytes.NewReader(want)
		if err := chunkedCopy(c1, rd); err != nil {
			t.Errorf("unexpected c1.Write error: %v", err)
		}
	}()

	go func() {
		wr := new(bytes.Buffer)
		if err := chunkedCopy(wr, c2); err != nil {
			t.Errorf("unexpected c2.Read error: %v", err)
		}
		dataCh <- wr.Bytes()
	}()

	if got := <-dataCh; !bytes.Equal(got, want) {
		// t.Errorf("transmitted data differs")
		t.Errorf("transmitted data differs:\ngot= %q\nwnt= %q\n", got, want)
	}

	if err := c1.Close(); err != nil {
		t.Errorf("unexpected c1.Close error: %v", err)
	}
	if err := c2.Close(); err != nil {
		t.Errorf("unexpected c2.Close error: %v", err)
	}
}

// chunkedCopy copies from r to w in fixed-width chunks to avoid
// causing a Write that exceeds the maximum packet size for packet-based
// connections like "unixpacket".
// We assume that the maximum packet size is at least 1024.
61 | func chunkedCopy(w io.Writer, r io.Reader) error { 62 | // b := make([]byte, 1024) 63 | // _, err := io.CopyBuffer(struct{ io.Writer }{w}, struct{ io.Reader }{r}, b) 64 | _, err := io.Copy(w, r) 65 | return err 66 | } 67 | 68 | func TestRW(t *testing.T) { 69 | const ep = "inproc://rw-srv" 70 | lst, err := Listen(ep) 71 | if err != nil { 72 | t.Fatalf("could not create server: %+v", err) 73 | } 74 | defer lst.Close() 75 | 76 | if addr := lst.Addr(); addr == nil { 77 | t.Fatalf("listener with nil address") 78 | } 79 | if got, want := lst.Addr().String(), ep[len("inproc://"):]; got != want { 80 | t.Fatalf("invalid listener address: got=%q, want=%q", got, want) 81 | } 82 | 83 | var grp errgroup.Group 84 | grp.Go(func() error { 85 | conn, err := lst.Accept() 86 | if err != nil { 87 | return fmt.Errorf("could not accept connection: %w", err) 88 | } 89 | defer conn.Close() 90 | 91 | if addr := conn.LocalAddr(); addr == nil { 92 | t.Fatalf("accept-conn with nil address") 93 | } 94 | if got, want := conn.LocalAddr().String(), ep[len("inproc://"):]; got != want { 95 | t.Fatalf("invalid accept-con address: got=%q, want=%q", got, want) 96 | } 97 | if got, want := conn.RemoteAddr().String(), ep[len("inproc://"):]; got != want { 98 | t.Fatalf("invalid accept-con address: got=%q, want=%q", got, want) 99 | } 100 | 101 | raw := make([]byte, len("HELLO")) 102 | _, err = io.ReadFull(conn, raw) 103 | if err != nil { 104 | return fmt.Errorf("could not read request: %w", err) 105 | } 106 | 107 | if got, want := raw, []byte("HELLO"); !reflect.DeepEqual(got, want) { 108 | return fmt.Errorf("invalid request: got=%v, want=%v", got, want) 109 | } 110 | 111 | _, err = conn.Write([]byte("HELLO")) 112 | if err != nil { 113 | return fmt.Errorf("could not write reply: %w", err) 114 | } 115 | 116 | raw = make([]byte, len("QUIT")) 117 | _, err = io.ReadFull(conn, raw) 118 | if err != nil { 119 | return fmt.Errorf("could not read final request: %w", err) 120 | } 121 | 122 | if got, want := raw, 
[]byte("QUIT"); !reflect.DeepEqual(got, want) { 123 | return fmt.Errorf("invalid request: got=%v, want=%v", got, want) 124 | } 125 | 126 | return nil 127 | }) 128 | 129 | grp.Go(func() error { 130 | conn, err := Dial("inproc://rw-srv") 131 | if err != nil { 132 | return fmt.Errorf("could not dial server: %w", err) 133 | } 134 | defer conn.Close() 135 | 136 | if addr := conn.LocalAddr(); addr == nil { 137 | t.Fatalf("dial-conn with nil address") 138 | } 139 | if got, want := conn.LocalAddr().String(), ep[len("inproc://"):]; got != want { 140 | t.Fatalf("invalid dial-con address: got=%q, want=%q", got, want) 141 | } 142 | if got, want := conn.RemoteAddr().String(), ep[len("inproc://"):]; got != want { 143 | t.Fatalf("invalid dial-con address: got=%q, want=%q", got, want) 144 | } 145 | 146 | _, err = conn.Write([]byte("HELLO")) 147 | if err != nil { 148 | return fmt.Errorf("could not send request: %w", err) 149 | } 150 | 151 | raw := make([]byte, len("HELLO")) 152 | _, err = io.ReadFull(conn, raw) 153 | if err != nil { 154 | return fmt.Errorf("could not read reply: %w", err) 155 | } 156 | 157 | if got, want := raw, []byte("HELLO"); !reflect.DeepEqual(got, want) { 158 | return fmt.Errorf("invalid reply: got=%v, want=%v", got, want) 159 | } 160 | 161 | _, err = conn.Write([]byte("QUIT")) 162 | if err != nil { 163 | return fmt.Errorf("could not write final request: %w", err) 164 | } 165 | 166 | return nil 167 | }) 168 | 169 | err = grp.Wait() 170 | if err != nil { 171 | t.Fatalf("error: %+v", err) 172 | } 173 | } 174 | -------------------------------------------------------------------------------- /internal/inproc/transport.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
package inproc

import (
	"context"
	"net"

	"github.com/go-zeromq/zmq4/transport"
)

// Transport implements the zmq4 Transport interface for the inproc transport.
type Transport struct{}

// Dial connects to the address on the named network using the provided
// context.
// The context and the dialer are not used by the inproc transport.
func (Transport) Dial(ctx context.Context, dialer transport.Dialer, addr string) (net.Conn, error) {
	return Dial(addr)
}

// Listen announces on the provided network address.
// The context is not used by the inproc transport.
func (Transport) Listen(ctx context.Context, addr string) (net.Listener, error) {
	return Listen(addr)
}

// Addr returns the end-point address.
// The end-point is returned unchanged.
func (Transport) Addr(ep string) (addr string, err error) {
	return ep, nil
}

var (
	_ transport.Transport = (*Transport)(nil)
)

--------------------------------------------------------------------------------
/internal/leaks_test/reaper_leak_test.go:
--------------------------------------------------------------------------------
// Copyright 2024 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package leaks_test

import (
	"context"
	"errors"
	"sync"
	"testing"
	"time"

	"github.com/go-zeromq/zmq4"
	"go.uber.org/goleak"
)

// TestReaperLeak1 does multiple rapid Dial/Close to check that connection reaper goroutines are not leaking.
// TestReaperLeak1 is in a dedicated package as goleak detects also goroutines from values created during init().
20 | func TestReaperLeak1(t *testing.T) { 21 | defer goleak.VerifyNone(t) 22 | 23 | mu := &sync.Mutex{} 24 | errs := []error{} 25 | 26 | ctx, cancel := context.WithCancel(context.Background()) 27 | rep := zmq4.NewRep(ctx) 28 | ep := "ipc://@test.rep.socket" 29 | err := rep.Listen(ep) 30 | if err != nil { 31 | t.Fatal(err) 32 | } 33 | 34 | maxClients := 100 35 | maxMsgs := 100 36 | wgClients := &sync.WaitGroup{} 37 | wgServer := &sync.WaitGroup{} 38 | client := func() { 39 | defer wgClients.Done() 40 | for n := 0; n < maxMsgs; n++ { 41 | func() { 42 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) 43 | defer cancel() 44 | req := zmq4.NewReq(ctx) 45 | err := req.Dial(ep) 46 | if err != nil { 47 | mu.Lock() 48 | defer mu.Unlock() 49 | errs = append(errs, err) 50 | return 51 | } 52 | 53 | err = req.Close() 54 | if err != nil { 55 | mu.Lock() 56 | defer mu.Unlock() 57 | errs = append(errs, err) 58 | } 59 | }() 60 | } 61 | } 62 | server := func() { 63 | defer wgServer.Done() 64 | pong := zmq4.NewMsgString("pong") 65 | for { 66 | msg, err := rep.Recv() 67 | if errors.Is(err, context.Canceled) { 68 | break 69 | } 70 | if err != nil { 71 | break 72 | } 73 | if string(msg.Frames[0]) != "ping" { 74 | mu.Lock() 75 | defer mu.Unlock() 76 | errs = append(errs, errors.New("unexpected message")) 77 | return 78 | } 79 | err = rep.Send(pong) 80 | if err != nil { 81 | mu.Lock() 82 | defer mu.Unlock() 83 | errs = append(errs, err) 84 | } 85 | } 86 | } 87 | 88 | wgServer.Add(1) 89 | go server() 90 | wgClients.Add(maxClients) 91 | for n := 0; n < maxClients; n++ { 92 | go client() 93 | } 94 | wgClients.Wait() 95 | cancel() 96 | wgServer.Wait() 97 | rep.Close() 98 | for _, err := range errs { 99 | t.Fatal(err) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /msg.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. 
All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | "io" 11 | ) 12 | 13 | type MsgType byte 14 | 15 | const ( 16 | UsrMsg MsgType = 0 17 | CmdMsg MsgType = 1 18 | ) 19 | 20 | // Msg is a ZMTP message, possibly composed of multiple frames. 21 | type Msg struct { 22 | Frames [][]byte 23 | Type MsgType 24 | multipart bool 25 | err error 26 | } 27 | 28 | func NewMsg(frame []byte) Msg { 29 | return Msg{Frames: [][]byte{frame}} 30 | } 31 | 32 | func NewMsgFrom(frames ...[]byte) Msg { 33 | return Msg{Frames: frames} 34 | } 35 | 36 | func NewMsgString(frame string) Msg { 37 | return NewMsg([]byte(frame)) 38 | } 39 | 40 | func NewMsgFromString(frames []string) Msg { 41 | msg := Msg{Frames: make([][]byte, len(frames))} 42 | for i, frame := range frames { 43 | msg.Frames[i] = append(msg.Frames[i], []byte(frame)...) 44 | } 45 | return msg 46 | } 47 | 48 | func (msg Msg) isCmd() bool { 49 | return msg.Type == CmdMsg 50 | } 51 | 52 | func (msg Msg) Err() error { 53 | return msg.err 54 | } 55 | 56 | // Bytes returns the concatenated content of all its frames. 57 | func (msg Msg) Bytes() []byte { 58 | buf := make([]byte, 0, msg.size()) 59 | for _, frame := range msg.Frames { 60 | buf = append(buf, frame...) 
61 | } 62 | return buf 63 | } 64 | 65 | func (msg Msg) size() int { 66 | n := 0 67 | for _, frame := range msg.Frames { 68 | n += len(frame) 69 | } 70 | return n 71 | } 72 | 73 | func (msg Msg) String() string { 74 | buf := new(bytes.Buffer) 75 | buf.WriteString("Msg{Frames:{") 76 | for i, frame := range msg.Frames { 77 | if i > 0 { 78 | buf.WriteString(", ") 79 | } 80 | fmt.Fprintf(buf, "%q", frame) 81 | } 82 | buf.WriteString("}}") 83 | return buf.String() 84 | } 85 | 86 | func (msg Msg) Clone() Msg { 87 | o := Msg{Frames: make([][]byte, len(msg.Frames))} 88 | for i, frame := range msg.Frames { 89 | o.Frames[i] = make([]byte, len(frame)) 90 | copy(o.Frames[i], frame) 91 | } 92 | return o 93 | } 94 | 95 | // Cmd is a ZMTP Cmd as per: 96 | // 97 | // https://rfc.zeromq.org/spec:23/ZMTP/#formal-grammar 98 | type Cmd struct { 99 | Name string 100 | Body []byte 101 | } 102 | 103 | func (cmd *Cmd) unmarshalZMTP(data []byte) error { 104 | if len(data) == 0 { 105 | return io.ErrUnexpectedEOF 106 | } 107 | n := int(data[0]) 108 | if n > len(data)-1 { 109 | return ErrBadCmd 110 | } 111 | cmd.Name = string(data[1 : n+1]) 112 | cmd.Body = data[n+1:] 113 | return nil 114 | } 115 | 116 | func (cmd *Cmd) marshalZMTP() ([]byte, error) { 117 | n := len(cmd.Name) 118 | if n > 255 { 119 | return nil, ErrBadCmd 120 | } 121 | 122 | buf := make([]byte, 0, 1+n+len(cmd.Body)) 123 | buf = append(buf, byte(n)) 124 | buf = append(buf, []byte(cmd.Name)...) 125 | buf = append(buf, cmd.Body...) 
126 | return buf, nil 127 | } 128 | 129 | // ZMTP commands as per: 130 | // 131 | // https://rfc.zeromq.org/spec:23/ZMTP/#commands 132 | const ( 133 | CmdCancel = "CANCEL" 134 | CmdError = "ERROR" 135 | CmdHello = "HELLO" 136 | CmdInitiate = "INITIATE" 137 | CmdPing = "PING" 138 | CmdPong = "PONG" 139 | CmdReady = "READY" 140 | CmdSubscribe = "SUBSCRIBE" 141 | CmdUnsubscribe = "UNSUBSCRIBE" 142 | CmdWelcome = "WELCOME" 143 | ) 144 | -------------------------------------------------------------------------------- /msgio.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "context" 9 | "io" 10 | "sync" 11 | 12 | errgrp "github.com/go-zeromq/zmq4/internal/errgroup" 13 | "golang.org/x/sync/errgroup" 14 | ) 15 | 16 | // rpool is the interface that reads ZMQ messages from a pool of connections. 17 | type rpool interface { 18 | io.Closer 19 | 20 | addConn(r *Conn) 21 | rmConn(r *Conn) 22 | read(ctx context.Context, msg *Msg) error 23 | } 24 | 25 | // wpool is the interface that writes ZMQ messages to a pool of connections. 26 | type wpool interface { 27 | io.Closer 28 | 29 | addConn(w *Conn) 30 | rmConn(r *Conn) 31 | write(ctx context.Context, msg Msg) error 32 | } 33 | 34 | // qreader is a queued-message reader. 35 | type qreader struct { 36 | ctx context.Context 37 | mu sync.RWMutex 38 | rs []*Conn 39 | c chan Msg 40 | 41 | sem *semaphore // ready when a connection is live. 
42 | } 43 | 44 | func newQReader(ctx context.Context) *qreader { 45 | const qrsize = 10 46 | return &qreader{ 47 | ctx: ctx, 48 | c: make(chan Msg, qrsize), 49 | sem: newSemaphore(), 50 | } 51 | } 52 | 53 | func (q *qreader) Close() error { 54 | q.mu.RLock() 55 | var err error 56 | var grp errgroup.Group 57 | for i := range q.rs { 58 | grp.Go(q.rs[i].Close) 59 | } 60 | err = grp.Wait() 61 | q.rs = nil 62 | q.mu.RUnlock() 63 | return err 64 | } 65 | 66 | func (q *qreader) addConn(r *Conn) { 67 | q.mu.Lock() 68 | q.sem.enable() 69 | q.rs = append(q.rs, r) 70 | q.mu.Unlock() 71 | go q.listen(q.ctx, r) 72 | } 73 | 74 | func (q *qreader) rmConn(r *Conn) { 75 | q.mu.Lock() 76 | defer q.mu.Unlock() 77 | 78 | cur := -1 79 | for i := range q.rs { 80 | if q.rs[i] == r { 81 | cur = i 82 | break 83 | } 84 | } 85 | if cur >= 0 { 86 | q.rs = append(q.rs[:cur], q.rs[cur+1:]...) 87 | } 88 | } 89 | 90 | func (q *qreader) read(ctx context.Context, msg *Msg) error { 91 | q.sem.lock(ctx) 92 | select { 93 | case <-ctx.Done(): 94 | return ctx.Err() 95 | case *msg = <-q.c: 96 | } 97 | return msg.err 98 | } 99 | 100 | func (q *qreader) listen(ctx context.Context, r *Conn) { 101 | defer q.rmConn(r) 102 | defer r.Close() 103 | 104 | for { 105 | msg := r.read() 106 | select { 107 | case <-ctx.Done(): 108 | return 109 | default: 110 | q.c <- msg 111 | if msg.err != nil { 112 | return 113 | } 114 | } 115 | } 116 | } 117 | 118 | type mwriter struct { 119 | ctx context.Context 120 | mu sync.Mutex 121 | ws []*Conn 122 | sem *semaphore 123 | } 124 | 125 | func newMWriter(ctx context.Context) *mwriter { 126 | return &mwriter{ 127 | ctx: ctx, 128 | sem: newSemaphore(), 129 | } 130 | } 131 | 132 | func (w *mwriter) Close() error { 133 | w.mu.Lock() 134 | var err error 135 | for _, ww := range w.ws { 136 | e := ww.Close() 137 | if e != nil && err == nil { 138 | err = e 139 | } 140 | } 141 | w.ws = nil 142 | w.mu.Unlock() 143 | return err 144 | } 145 | 146 | func (mw *mwriter) addConn(w *Conn) { 147 | 
mw.mu.Lock() 148 | mw.sem.enable() 149 | mw.ws = append(mw.ws, w) 150 | mw.mu.Unlock() 151 | } 152 | 153 | func (mw *mwriter) rmConn(w *Conn) { 154 | mw.mu.Lock() 155 | defer mw.mu.Unlock() 156 | 157 | cur := -1 158 | for i := range mw.ws { 159 | if mw.ws[i] == w { 160 | cur = i 161 | break 162 | } 163 | } 164 | if cur >= 0 { 165 | mw.ws = append(mw.ws[:cur], mw.ws[cur+1:]...) 166 | } 167 | } 168 | 169 | func (w *mwriter) write(ctx context.Context, msg Msg) error { 170 | w.sem.lock(ctx) 171 | grp, _ := errgrp.WithContext(ctx) 172 | w.mu.Lock() 173 | for i := range w.ws { 174 | ww := w.ws[i] 175 | grp.Go(func() error { 176 | return ww.SendMsg(msg) 177 | }) 178 | } 179 | err := grp.Wait() 180 | w.mu.Unlock() 181 | return err 182 | } 183 | 184 | type semaphore struct { 185 | ready chan struct{} 186 | } 187 | 188 | func newSemaphore() *semaphore { 189 | return &semaphore{ready: make(chan struct{})} 190 | } 191 | 192 | func (sem *semaphore) enable() { 193 | select { 194 | case _, ok := <-sem.ready: 195 | if ok { 196 | close(sem.ready) 197 | } 198 | default: 199 | close(sem.ready) 200 | } 201 | } 202 | 203 | func (sem *semaphore) lock(ctx context.Context) { 204 | select { 205 | case <-ctx.Done(): 206 | case <-sem.ready: 207 | } 208 | } 209 | 210 | var ( 211 | _ rpool = (*qreader)(nil) 212 | _ wpool = (*mwriter)(nil) 213 | ) 214 | -------------------------------------------------------------------------------- /null_security.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | -------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package zmq4

import (
	"log"
	"time"
)

// Option configures some aspect of a ZeroMQ socket.
// (e.g. SocketIdentity, Security, ...)
type Option func(s *socket)

// WithID configures a ZeroMQ socket identity.
func WithID(id SocketIdentity) Option {
	return func(s *socket) {
		s.id = id
	}
}

// WithSecurity configures a ZeroMQ socket to use the given security mechanism.
// If the security mechanism is nil, the NULL mechanism is used.
func WithSecurity(sec Security) Option {
	return func(s *socket) {
		s.sec = sec
	}
}

// WithDialerRetry configures the time to wait between two failed attempts
// at dialing an endpoint.
func WithDialerRetry(retry time.Duration) Option {
	return func(s *socket) {
		s.retry = retry
	}
}

// WithDialerTimeout sets the maximum amount of time a dial will wait
// for a connect to complete.
func WithDialerTimeout(timeout time.Duration) Option {
	return func(s *socket) {
		s.dialer.Timeout = timeout
	}
}

// WithTimeout sets the timeout value for socket operations
func WithTimeout(timeout time.Duration) Option {
	return func(s *socket) {
		s.timeout = timeout
	}
}

// WithLogger sets a dedicated log.Logger for the socket.
func WithLogger(msg *log.Logger) Option {
	return func(s *socket) {
		s.log = msg
	}
}

// WithDialerMaxRetries configures the maximum number of retries
// when dialing an endpoint (-1 means infinite retries).
func WithDialerMaxRetries(maxRetries int) Option {
	return func(s *socket) {
		s.maxRetries = maxRetries
	}
}

// WithAutomaticReconnect allows to configure a socket to automatically
// reconnect on connection loss.
71 | func WithAutomaticReconnect(automaticReconnect bool) Option { 72 | return func(s *socket) { 73 | s.autoReconnect = automaticReconnect 74 | } 75 | } 76 | 77 | /* 78 | // TODO(sbinet) 79 | 80 | func WithIOThreads(threads int) Option { 81 | return nil 82 | } 83 | 84 | func WithSendBufferSize(size int) Option { 85 | return nil 86 | } 87 | 88 | func WithRecvBufferSize(size int) Option { 89 | return nil 90 | } 91 | */ 92 | 93 | const ( 94 | OptionSubscribe = "SUBSCRIBE" 95 | OptionUnsubscribe = "UNSUBSCRIBE" 96 | OptionHWM = "HWM" 97 | ) 98 | -------------------------------------------------------------------------------- /pair.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "context" 9 | "net" 10 | ) 11 | 12 | // NewPair returns a new PAIR ZeroMQ socket. 13 | // The returned socket value is initially unbound. 14 | func NewPair(ctx context.Context, opts ...Option) Socket { 15 | pair := &pairSocket{newSocket(ctx, Pair, opts...)} 16 | return pair 17 | } 18 | 19 | // pairSocket is a PAIR ZeroMQ socket. 20 | type pairSocket struct { 21 | sck *socket 22 | } 23 | 24 | // Close closes the open Socket 25 | func (pair *pairSocket) Close() error { 26 | return pair.sck.Close() 27 | } 28 | 29 | // Send puts the message on the outbound send queue. 30 | // Send blocks until the message can be queued or the send deadline expires. 31 | func (pair *pairSocket) Send(msg Msg) error { 32 | return pair.sck.Send(msg) 33 | } 34 | 35 | // SendMulti puts the message on the outbound send queue. 36 | // SendMulti blocks until the message can be queued or the send deadline expires. 37 | // The message will be sent as a multipart message. 
38 | func (pair *pairSocket) SendMulti(msg Msg) error { 39 | return pair.sck.SendMulti(msg) 40 | } 41 | 42 | // Recv receives a complete message. 43 | func (pair *pairSocket) Recv() (Msg, error) { 44 | return pair.sck.Recv() 45 | } 46 | 47 | // Listen connects a local endpoint to the Socket. 48 | func (pair *pairSocket) Listen(ep string) error { 49 | return pair.sck.Listen(ep) 50 | } 51 | 52 | // Dial connects a remote endpoint to the Socket. 53 | func (pair *pairSocket) Dial(ep string) error { 54 | return pair.sck.Dial(ep) 55 | } 56 | 57 | // Type returns the type of this Socket (PUB, SUB, ...) 58 | func (pair *pairSocket) Type() SocketType { 59 | return pair.sck.Type() 60 | } 61 | 62 | // Addr returns the listener's address. 63 | // Addr returns nil if the socket isn't a listener. 64 | func (pair *pairSocket) Addr() net.Addr { 65 | return pair.sck.Addr() 66 | } 67 | 68 | // GetOption is used to retrieve an option for a socket. 69 | func (pair *pairSocket) GetOption(name string) (interface{}, error) { 70 | return pair.sck.GetOption(name) 71 | } 72 | 73 | // SetOption is used to set an option for a socket. 74 | func (pair *pairSocket) SetOption(name string, value interface{}) error { 75 | return pair.sck.SetOption(name, value) 76 | } 77 | 78 | var ( 79 | _ Socket = (*pairSocket)(nil) 80 | ) 81 | -------------------------------------------------------------------------------- /protocol.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package zmq4 6 | 7 | import ( 8 | "bytes" 9 | "encoding/binary" 10 | "errors" 11 | "fmt" 12 | "io" 13 | "strings" 14 | 15 | "golang.org/x/text/cases" 16 | "golang.org/x/text/language" 17 | ) 18 | 19 | var ( 20 | errGreeting = errors.New("zmq4: invalid greeting received") 21 | errSecMech = errors.New("zmq4: invalid security mechanism") 22 | errBadSec = errors.New("zmq4: invalid or unsupported security mechanism") 23 | ErrBadCmd = errors.New("zmq4: invalid command name") 24 | ErrBadFrame = errors.New("zmq4: invalid frame") 25 | errOverflow = errors.New("zmq4: overflow") 26 | errEmptyAppMDKey = errors.New("zmq4: empty application metadata key") 27 | errDupAppMDKey = errors.New("zmq4: duplicate application metadata key") 28 | errBoolCnv = errors.New("zmq4: invalid byte to bool conversion") 29 | ) 30 | 31 | const ( 32 | sigHeader = 0xFF 33 | sigFooter = 0x7F 34 | 35 | majorVersion uint8 = 3 36 | minorVersion uint8 = 0 37 | 38 | hasMoreBitFlag = 0x1 39 | isLongBitFlag = 0x2 40 | isCommandBitFlag = 0x4 41 | 42 | zmtpMsgLen = 64 43 | ) 44 | 45 | var ( 46 | defaultVersion = [2]uint8{ 47 | majorVersion, 48 | minorVersion, 49 | } 50 | ) 51 | 52 | const ( 53 | maxUint = ^uint(0) 54 | maxInt = int(maxUint >> 1) 55 | maxUint64 = ^uint64(0) 56 | maxInt64 = int64(maxUint64 >> 1) 57 | ) 58 | 59 | func asString(slice []byte) string { 60 | i := bytes.IndexByte(slice, 0) 61 | if i < 0 { 62 | i = len(slice) 63 | } 64 | return string(slice[:i]) 65 | } 66 | 67 | func asBool(b byte) (bool, error) { 68 | switch b { 69 | case 0x00: 70 | return false, nil 71 | case 0x01: 72 | return true, nil 73 | } 74 | 75 | return false, errBoolCnv 76 | } 77 | 78 | type greeting struct { 79 | Sig struct { 80 | Header byte 81 | _ [8]byte 82 | Footer byte 83 | } 84 | Version [2]uint8 85 | Mechanism [20]byte 86 | Server byte 87 | _ [31]byte 88 | } 89 | 90 | func (g *greeting) read(r io.Reader) error { 91 | var data [zmtpMsgLen]byte 92 | _, err := io.ReadFull(r, data[:]) 93 | if err != nil { 94 | 
return fmt.Errorf("could not read ZMTP greeting: %w", err) 95 | } 96 | 97 | g.unmarshal(data[:]) 98 | 99 | if g.Sig.Header != sigHeader { 100 | return fmt.Errorf("invalid ZMTP signature header: %w", errGreeting) 101 | } 102 | 103 | if g.Sig.Footer != sigFooter { 104 | return fmt.Errorf("invalid ZMTP signature footer: %w", errGreeting) 105 | } 106 | 107 | if !g.validate(defaultVersion) { 108 | return fmt.Errorf( 109 | "invalid ZMTP version (got=%v, want=%v): %w", 110 | g.Version, defaultVersion, errGreeting, 111 | ) 112 | } 113 | 114 | return nil 115 | } 116 | 117 | func (g *greeting) unmarshal(data []byte) { 118 | _ = data[:zmtpMsgLen] 119 | g.Sig.Header = data[0] 120 | g.Sig.Footer = data[9] 121 | g.Version[0] = data[10] 122 | g.Version[1] = data[11] 123 | copy(g.Mechanism[:], data[12:32]) 124 | g.Server = data[32] 125 | } 126 | 127 | func (g *greeting) write(w io.Writer) error { 128 | _, err := w.Write(g.marshal()) 129 | return err 130 | } 131 | 132 | func (g *greeting) marshal() []byte { 133 | var buf [zmtpMsgLen]byte 134 | buf[0] = g.Sig.Header 135 | // padding 1 ignored 136 | buf[9] = g.Sig.Footer 137 | buf[10] = g.Version[0] 138 | buf[11] = g.Version[1] 139 | copy(buf[12:32], g.Mechanism[:]) 140 | buf[32] = g.Server 141 | // padding 2 ignored 142 | return buf[:] 143 | } 144 | 145 | func (g *greeting) validate(ref [2]uint8) bool { 146 | switch { 147 | case g.Version == ref: 148 | return true 149 | case g.Version[0] > ref[0] || 150 | g.Version[0] == ref[0] && g.Version[1] > ref[1]: 151 | // accept higher protocol values 152 | return true 153 | case g.Version[0] < ref[0] || 154 | g.Version[0] == ref[0] && g.Version[1] < ref[1]: 155 | // FIXME(sbinet): handle version negotiations as per 156 | // https://rfc.zeromq.org/spec:23/ZMTP/#version-negotiation 157 | return false 158 | default: 159 | return false 160 | } 161 | } 162 | 163 | const ( 164 | sysSockType = "Socket-Type" 165 | sysSockID = "Identity" 166 | ) 167 | 168 | // Metadata is describing a Conn's metadata 
information. 169 | type Metadata map[string]string 170 | 171 | // MarshalZMTP marshals MetaData to ZMTP encoded data. 172 | func (md Metadata) MarshalZMTP() ([]byte, error) { 173 | buf := new(bytes.Buffer) 174 | keys := make(map[string]struct{}) 175 | 176 | for k, v := range md { 177 | if len(k) == 0 { 178 | return nil, errEmptyAppMDKey 179 | } 180 | 181 | key := strings.ToLower(k) 182 | if _, dup := keys[key]; dup { 183 | return nil, errDupAppMDKey 184 | } 185 | 186 | keys[key] = struct{}{} 187 | switch k { 188 | case sysSockID, sysSockType: 189 | if _, err := io.Copy(buf, Property{K: k, V: v}); err != nil { 190 | return nil, err 191 | } 192 | default: 193 | if _, err := io.Copy(buf, Property{K: "X-" + key, V: v}); err != nil { 194 | return nil, err 195 | } 196 | } 197 | } 198 | return buf.Bytes(), nil 199 | } 200 | 201 | // UnmarshalZMTP unmarshals MetaData from a ZMTP encoded data. 202 | func (md *Metadata) UnmarshalZMTP(p []byte) error { 203 | i := 0 204 | for i < len(p) { 205 | var kv Property 206 | n, err := kv.Write(p[i:]) 207 | if err != nil { 208 | return err 209 | } 210 | i += n 211 | 212 | name := toTitle(kv.K) 213 | (*md)[name] = kv.V 214 | } 215 | return nil 216 | } 217 | 218 | // Property describes a Conn metadata's entry. 
219 | // The on-wire respresentation of Property is specified by: 220 | // 221 | // https://rfc.zeromq.org/spec:23/ZMTP/ 222 | type Property struct { 223 | K string 224 | V string 225 | } 226 | 227 | func (prop Property) Read(data []byte) (n int, err error) { 228 | klen := len(prop.K) 229 | vlen := len(prop.V) 230 | size := 1 + klen + 4 + vlen 231 | _ = data[:size] // help with bound check elision 232 | 233 | data[n] = byte(klen) 234 | n++ 235 | n += copy(data[n:n+klen], toTitle(prop.K)) 236 | binary.BigEndian.PutUint32(data[n:n+4], uint32(vlen)) 237 | n += 4 238 | n += copy(data[n:n+vlen], prop.V) 239 | return n, io.EOF 240 | } 241 | 242 | func (prop *Property) Write(data []byte) (n int, err error) { 243 | klen := int(data[n]) 244 | n++ 245 | if klen > len(data) { 246 | return n, io.ErrUnexpectedEOF 247 | } 248 | 249 | prop.K = toTitle(string(data[n : n+klen])) 250 | n += klen 251 | 252 | v := binary.BigEndian.Uint32(data[n : n+4]) 253 | n += 4 254 | if uint64(v) > uint64(maxInt) { 255 | return n, errOverflow 256 | } 257 | 258 | vlen := int(v) 259 | if n+vlen > len(data) { 260 | return n, io.ErrUnexpectedEOF 261 | } 262 | 263 | prop.V = string(data[n : n+vlen]) 264 | n += vlen 265 | return n, nil 266 | } 267 | 268 | type flag byte 269 | 270 | func (fl flag) hasMore() bool { return fl&hasMoreBitFlag == hasMoreBitFlag } 271 | func (fl flag) isLong() bool { return fl&isLongBitFlag == isLongBitFlag } 272 | func (fl flag) isCommand() bool { return fl&isCommandBitFlag == isCommandBitFlag } 273 | 274 | func toTitle(s string) string { 275 | return cases.Title(language.Und, cases.NoLower).String(s) 276 | } 277 | -------------------------------------------------------------------------------- /protocol_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package zmq4 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | "io" 11 | "testing" 12 | ) 13 | 14 | func TestGreeting(t *testing.T) { 15 | for _, tc := range []struct { 16 | name string 17 | data []byte 18 | want error 19 | }{ 20 | { 21 | name: "valid", 22 | data: func() []byte { 23 | w := new(bytes.Buffer) 24 | g := greeting{ 25 | Version: defaultVersion, 26 | } 27 | g.Sig.Header = sigHeader 28 | g.Sig.Footer = sigFooter 29 | err := g.write(w) 30 | if err != nil { 31 | t.Fatalf("could not marshal greeting: %+v", err) 32 | } 33 | return w.Bytes() 34 | }(), 35 | }, 36 | { 37 | name: "empty-buffer", 38 | data: nil, 39 | want: fmt.Errorf("could not read ZMTP greeting: %w", io.EOF), 40 | }, 41 | { 42 | name: "unexpected-EOF", 43 | data: make([]byte, 1), 44 | want: fmt.Errorf("could not read ZMTP greeting: %w", io.ErrUnexpectedEOF), 45 | }, 46 | { 47 | name: "invalid-header", 48 | data: func() []byte { 49 | w := new(bytes.Buffer) 50 | g := greeting{ 51 | Version: defaultVersion, 52 | } 53 | g.Sig.Header = sigFooter // err 54 | g.Sig.Footer = sigFooter 55 | err := g.write(w) 56 | if err != nil { 57 | t.Fatalf("could not marshal greeting: %+v", err) 58 | } 59 | return w.Bytes() 60 | }(), 61 | want: fmt.Errorf("invalid ZMTP signature header: %w", errGreeting), 62 | }, 63 | { 64 | name: "invalid-footer", 65 | data: func() []byte { 66 | w := new(bytes.Buffer) 67 | g := greeting{ 68 | Version: defaultVersion, 69 | } 70 | g.Sig.Header = sigHeader 71 | g.Sig.Footer = sigHeader // err 72 | err := g.write(w) 73 | if err != nil { 74 | t.Fatalf("could not marshal greeting: %+v", err) 75 | } 76 | return w.Bytes() 77 | }(), 78 | want: fmt.Errorf("invalid ZMTP signature footer: %w", errGreeting), 79 | }, 80 | { 81 | name: "higher-major-version", 82 | data: func() []byte { 83 | w := new(bytes.Buffer) 84 | g := greeting{ 85 | Version: [2]uint8{defaultVersion[0] + 1, defaultVersion[1]}, 86 | } 87 | g.Sig.Header = sigHeader 88 | g.Sig.Footer = sigFooter 89 | err := g.write(w) 90 | if 
err != nil { 91 | t.Fatalf("could not marshal greeting: %+v", err) 92 | } 93 | return w.Bytes() 94 | }(), 95 | }, 96 | { 97 | name: "higher-minor-version", 98 | data: func() []byte { 99 | w := new(bytes.Buffer) 100 | g := greeting{ 101 | Version: [2]uint8{defaultVersion[0], defaultVersion[1] + 1}, 102 | } 103 | g.Sig.Header = sigHeader 104 | g.Sig.Footer = sigFooter 105 | err := g.write(w) 106 | if err != nil { 107 | t.Fatalf("could not marshal greeting: %+v", err) 108 | } 109 | return w.Bytes() 110 | }(), 111 | }, 112 | { 113 | name: "smaller-major-version", // FIXME(sbinet): adapt for when/if we support multiple ZMTP versions 114 | data: func() []byte { 115 | w := new(bytes.Buffer) 116 | g := greeting{ 117 | Version: [2]uint8{defaultVersion[0] - 1, defaultVersion[1]}, 118 | } 119 | g.Sig.Header = sigHeader 120 | g.Sig.Footer = sigFooter 121 | err := g.write(w) 122 | if err != nil { 123 | t.Fatalf("could not marshal greeting: %+v", err) 124 | } 125 | return w.Bytes() 126 | }(), 127 | want: fmt.Errorf("invalid ZMTP version (got=%v, want=%v): %w", 128 | [2]uint8{defaultVersion[0] - 1, defaultVersion[1]}, 129 | defaultVersion, 130 | errGreeting, 131 | ), 132 | }, 133 | } { 134 | t.Run(tc.name, func(t *testing.T) { 135 | var ( 136 | g greeting 137 | r = bytes.NewReader(tc.data) 138 | ) 139 | 140 | err := g.read(r) 141 | switch { 142 | case err == nil && tc.want == nil: 143 | // ok 144 | case err == nil && tc.want != nil: 145 | t.Fatalf("expected an error (%s)", tc.want) 146 | case err != nil && tc.want == nil: 147 | t.Fatalf("could not read ZMTP greeting: %+v", err) 148 | case err != nil && tc.want != nil: 149 | if got, want := err.Error(), tc.want.Error(); got != want { 150 | t.Fatalf("invalid ZMTP greeting error:\ngot= %+v\nwant=%+v\n", 151 | got, want, 152 | ) 153 | } 154 | } 155 | 156 | }) 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /proxy.go: 
-------------------------------------------------------------------------------- 1 | // Copyright 2020 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "log" 11 | "sync" 12 | 13 | "golang.org/x/sync/errgroup" 14 | ) 15 | 16 | // Proxy connects a frontend socket to a backend socket. 17 | type Proxy struct { 18 | ctx context.Context // life-line of proxy 19 | grp *errgroup.Group 20 | cmds chan proxyCmd 21 | } 22 | 23 | type proxyCmd byte 24 | 25 | const ( 26 | proxyStats proxyCmd = iota 27 | proxyPause 28 | proxyResume 29 | proxyKill 30 | ) 31 | 32 | // NewProxy creates a new Proxy value. 33 | // It proxies messages received on the frontend to the backend (and vice versa) 34 | // If capture is not nil, messages proxied are also sent on that socket. 35 | // 36 | // Conceptually, data flows from frontend to backend. Depending on the 37 | // socket types, replies may flow in the opposite direction. 38 | // The direction is conceptual only; the proxy is fully symmetric and 39 | // there is no technical difference between frontend and backend. 40 | // 41 | // Before creating a Proxy, users must set any socket options, 42 | // and Listen or Dial both frontend and backend sockets. 43 | func NewProxy(ctx context.Context, front, back, capture Socket) *Proxy { 44 | grp, ctx := errgroup.WithContext(ctx) 45 | proxy := Proxy{ 46 | ctx: ctx, 47 | grp: grp, 48 | cmds: make(chan proxyCmd), 49 | } 50 | proxy.init(front, back, capture) 51 | return &proxy 52 | } 53 | 54 | func (p *Proxy) Pause() { p.cmds <- proxyPause } 55 | func (p *Proxy) Stats() { p.cmds <- proxyStats } 56 | func (p *Proxy) Resume() { p.cmds <- proxyResume } 57 | func (p *Proxy) Kill() { p.cmds <- proxyKill } 58 | 59 | // Run runs the proxy loop. 
60 | func (p *Proxy) Run() error { 61 | return p.grp.Wait() 62 | } 63 | 64 | func (p *Proxy) init(front, back, capture Socket) { 65 | canRecv := func(sck Socket) bool { 66 | switch sck.Type() { 67 | case Push: 68 | return false 69 | default: 70 | return true 71 | } 72 | } 73 | 74 | canSend := func(sck Socket) bool { 75 | switch sck.Type() { 76 | case Pull: 77 | return false 78 | default: 79 | return true 80 | } 81 | } 82 | 83 | type Pipe struct { 84 | name string 85 | dst Socket 86 | src Socket 87 | } 88 | 89 | var ( 90 | quit = make(chan struct{}) 91 | pipes = []Pipe{ 92 | { 93 | name: "backend", 94 | dst: back, 95 | src: front, 96 | }, 97 | { 98 | name: "frontend", 99 | dst: front, 100 | src: back, 101 | }, 102 | } 103 | ) 104 | 105 | // workers makes sure all goroutines are launched and scheduled. 106 | var workers sync.WaitGroup 107 | workers.Add(len(pipes) + 1) 108 | for i := range pipes { 109 | pipe := pipes[i] 110 | if pipe.src == nil || !canRecv(pipe.src) { 111 | workers.Done() 112 | continue 113 | } 114 | p.grp.Go(func() error { 115 | workers.Done() 116 | canSend := canSend(pipe.dst) 117 | for { 118 | msg, err := pipe.src.Recv() 119 | select { 120 | case <-p.ctx.Done(): 121 | return p.ctx.Err() 122 | case <-quit: 123 | return nil 124 | default: 125 | if canSend { 126 | err = pipe.dst.Send(msg) 127 | if err != nil { 128 | log.Printf("could not forward to %s: %+v", pipe.name, err) 129 | continue 130 | } 131 | } 132 | if err == nil && capture != nil && len(msg.Frames) != 0 { 133 | _ = capture.Send(msg) 134 | } 135 | } 136 | } 137 | }) 138 | } 139 | 140 | p.grp.Go(func() error { 141 | workers.Done() 142 | for { 143 | select { 144 | case <-p.ctx.Done(): 145 | return p.ctx.Err() 146 | case cmd := <-p.cmds: 147 | switch cmd { 148 | case proxyPause, proxyResume, proxyStats: 149 | // TODO 150 | case proxyKill: 151 | close(quit) 152 | return nil 153 | default: 154 | // API error. panic. 
155 | panic(fmt.Errorf("invalid control socket command: %v", cmd)) 156 | } 157 | } 158 | } 159 | }) 160 | 161 | // wait for all worker routines to be scheduled. 162 | workers.Wait() 163 | } 164 | -------------------------------------------------------------------------------- /proxy_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4_test 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "sync" 11 | "testing" 12 | "time" 13 | 14 | "github.com/go-zeromq/zmq4" 15 | "golang.org/x/sync/errgroup" 16 | ) 17 | 18 | func TestProxy(t *testing.T) { 19 | bkg := context.Background() 20 | ctx, timeout := context.WithTimeout(bkg, 20*time.Second) 21 | defer timeout() 22 | 23 | var ( 24 | frontIn = zmq4.NewPush(ctx, zmq4.WithLogger(zmq4.Devnull)) 25 | front = zmq4.NewPull(ctx, zmq4.WithLogger(zmq4.Devnull)) 26 | back = zmq4.NewPush(ctx, zmq4.WithLogger(zmq4.Devnull)) 27 | backOut = zmq4.NewPull(ctx, zmq4.WithLogger(zmq4.Devnull)) 28 | capt = zmq4.NewPush(ctx, zmq4.WithLogger(zmq4.Devnull)) 29 | captOut = zmq4.NewPull(ctx, zmq4.WithLogger(zmq4.Devnull)) 30 | 31 | proxy *zmq4.Proxy 32 | 33 | epFront = "ipc://proxy-front" 34 | epBack = "ipc://proxy-back" 35 | epCapt = "ipc://proxy-capt" 36 | 37 | wg1 sync.WaitGroup // all sockets ready 38 | wg2 sync.WaitGroup // proxy setup 39 | wg3 sync.WaitGroup // all messages received 40 | wg4 sync.WaitGroup // all capture messages received 41 | wg5 sync.WaitGroup // terminate sent 42 | wg6 sync.WaitGroup // all sockets done 43 | ) 44 | 45 | wg1.Add(6) // number of sockets 46 | wg2.Add(1) // proxy ready 47 | wg3.Add(1) // messages received at backout 48 | wg4.Add(1) // capture messages received at capt-out 49 | wg5.Add(1) // terminate 50 | wg6.Add(6) // number of sockets 51 | 52 | cleanUp(epFront) 53 | cleanUp(epBack) 
54 | cleanUp(epCapt) 55 | 56 | var ( 57 | msgs = []zmq4.Msg{ 58 | zmq4.NewMsgFrom([]byte("msg1")), 59 | zmq4.NewMsgFrom([]byte("msg2")), 60 | zmq4.NewMsgFrom([]byte("msg3")), 61 | zmq4.NewMsgFrom([]byte("msg4")), 62 | } 63 | ) 64 | 65 | grp, ctx := errgroup.WithContext(ctx) 66 | grp.Go(func() error { 67 | defer frontIn.Close() 68 | err := frontIn.Dial(epFront) 69 | if err != nil { 70 | return fmt.Errorf("front-in could not dial %q: %w", epFront, err) 71 | } 72 | 73 | wg1.Done() 74 | t.Logf("front-in ready") 75 | wg1.Wait() // sockets 76 | wg2.Wait() // proxy 77 | 78 | for _, msg := range msgs { 79 | t.Logf("front-in sending %v...", msg) 80 | err = frontIn.Send(msg) 81 | if err != nil { 82 | return fmt.Errorf("could not send front-in %q: %w", msg, err) 83 | } 84 | t.Logf("front-in sending %v... [done]", msg) 85 | } 86 | 87 | wg3.Wait() // all messages received 88 | wg4.Wait() // all capture messages received 89 | t.Logf("front-in waiting for terminate signal") 90 | wg5.Wait() // terminate 91 | 92 | wg6.Done() // all sockets done 93 | wg6.Wait() 94 | return nil 95 | }) 96 | 97 | grp.Go(func() error { 98 | defer front.Close() 99 | err := front.Listen(epFront) 100 | if err != nil { 101 | return fmt.Errorf("front could not listen %q: %w", epFront, err) 102 | } 103 | 104 | wg1.Done() 105 | t.Logf("front ready") 106 | wg1.Wait() // sockets 107 | wg2.Wait() // proxy 108 | wg3.Wait() // all messages received 109 | wg4.Wait() // all capture messages received 110 | t.Logf("front waiting for terminate signal") 111 | wg5.Wait() // terminate 112 | 113 | wg6.Done() // all sockets done 114 | wg6.Wait() 115 | return nil 116 | }) 117 | 118 | grp.Go(func() error { 119 | defer back.Close() 120 | err := back.Listen(epBack) 121 | if err != nil { 122 | return fmt.Errorf("back could not listen %q: %w", epBack, err) 123 | } 124 | 125 | wg1.Done() 126 | t.Logf("back ready") 127 | wg1.Wait() // sockets 128 | wg2.Wait() // proxy 129 | wg3.Wait() // all messages received 130 | wg4.Wait() // 
all capture messages received 131 | t.Logf("back waiting for terminate signal") 132 | wg5.Wait() // terminate 133 | 134 | wg6.Done() // all sockets done 135 | wg6.Wait() 136 | return nil 137 | }) 138 | 139 | grp.Go(func() error { 140 | defer backOut.Close() 141 | err := backOut.Dial(epBack) 142 | if err != nil { 143 | return fmt.Errorf("back-out could not dial %q: %w", epBack, err) 144 | } 145 | 146 | wg1.Done() 147 | t.Logf("back-out ready") 148 | wg1.Wait() // sockets 149 | wg2.Wait() // proxy 150 | 151 | for _, want := range msgs { 152 | t.Logf("back-out recving %v...", want) 153 | msg, err := backOut.Recv() 154 | if err != nil { 155 | return fmt.Errorf("back-out could not recv: %w", err) 156 | } 157 | if msg.String() != want.String() { 158 | return fmt.Errorf("invalid message: got=%v, want=%v", msg, want) 159 | } 160 | t.Logf("back-out recving %v... [done]", msg) 161 | } 162 | 163 | wg3.Done() // all messages received 164 | wg3.Wait() // all messages received 165 | wg4.Wait() // all capture messages received 166 | t.Logf("back-out waiting for terminate signal") 167 | wg5.Wait() // terminate 168 | 169 | wg6.Done() // all sockets done 170 | wg6.Wait() 171 | return nil 172 | }) 173 | 174 | grp.Go(func() error { 175 | defer captOut.Close() 176 | err := captOut.Listen(epCapt) 177 | if err != nil { 178 | return fmt.Errorf("capt-out could not listen %q: %w", epCapt, err) 179 | } 180 | 181 | wg1.Done() 182 | t.Logf("capt-out ready") 183 | wg1.Wait() // sockets 184 | wg2.Wait() // proxy 185 | wg3.Wait() // all messages received 186 | 187 | for _, want := range msgs { 188 | t.Logf("capt-out recving %v...", want) 189 | msg, err := captOut.Recv() 190 | if err != nil { 191 | return fmt.Errorf("capt-out could not recv msg: %w", err) 192 | } 193 | if msg.String() != want.String() { 194 | return fmt.Errorf("capt-out: invalid message: got=%v, want=%v", msg, want) 195 | } 196 | t.Logf("capt-out recving %v... 
[done]", msg) 197 | } 198 | 199 | wg4.Done() // all capture messages received 200 | wg4.Wait() // all capture messages received 201 | t.Logf("capt-out waiting for terminate signal") 202 | wg5.Wait() // terminate 203 | 204 | wg6.Done() // all sockets done 205 | wg6.Wait() 206 | return nil 207 | }) 208 | 209 | grp.Go(func() error { 210 | defer capt.Close() 211 | err := capt.Dial(epCapt) 212 | if err != nil { 213 | return fmt.Errorf("capt could not dial %q: %w", epCapt, err) 214 | } 215 | 216 | wg1.Done() 217 | t.Logf("capt ready") 218 | wg1.Wait() // sockets 219 | wg2.Wait() // proxy 220 | wg3.Wait() // all messages received 221 | wg4.Wait() // all capture messages received 222 | t.Logf("capt waiting for terminate signal") 223 | wg5.Wait() // terminate 224 | 225 | wg6.Done() // all sockets done 226 | wg6.Wait() 227 | return nil 228 | }) 229 | 230 | grp.Go(func() error { 231 | t.Logf("ctrl ready") 232 | wg1.Wait() // sockets 233 | wg2.Wait() // proxy 234 | for _, cmd := range []struct { 235 | name string 236 | fct func() 237 | }{ 238 | {"pause", proxy.Pause}, 239 | {"resume", proxy.Resume}, 240 | {"stats", proxy.Stats}, 241 | } { 242 | t.Logf("ctrl sending %v...", cmd.name) 243 | cmd.fct() 244 | t.Logf("ctrl sending %v... [done]", cmd.name) 245 | } 246 | wg3.Wait() // all messages received 247 | wg4.Wait() // all capture messages received 248 | 249 | t.Logf("ctrl sending kill...") 250 | proxy.Kill() 251 | t.Logf("ctrl sending kill... 
[done]") 252 | 253 | wg5.Done() 254 | t.Logf("ctrl waiting for terminate signal") 255 | wg5.Wait() // terminate 256 | 257 | wg6.Wait() 258 | return nil 259 | }) 260 | 261 | grp.Go(func() error { 262 | wg1.Wait() // sockets ready 263 | proxy = zmq4.NewProxy(ctx, front, back, capt) 264 | t.Logf("proxy ready") 265 | wg2.Done() 266 | err := proxy.Run() 267 | t.Logf("proxy done: err=%+v", err) 268 | return err 269 | }) 270 | 271 | if err := grp.Wait(); err != nil { 272 | t.Fatalf("error: %+v", err) 273 | } 274 | 275 | if err := ctx.Err(); err != nil && err != context.Canceled { 276 | t.Fatalf("error: %+v", err) 277 | } 278 | } 279 | 280 | func TestProxyStop(t *testing.T) { 281 | ctx, cancel := context.WithCancel(context.Background()) 282 | 283 | var ( 284 | epFront = "ipc://proxy-stop-front" 285 | epBack = "ipc://proxy-stop-back" 286 | 287 | frontIn = zmq4.NewPush(ctx, zmq4.WithLogger(zmq4.Devnull)) 288 | front = zmq4.NewPull(ctx, zmq4.WithLogger(zmq4.Devnull)) 289 | back = zmq4.NewPush(ctx, zmq4.WithLogger(zmq4.Devnull)) 290 | backOut = zmq4.NewPull(ctx, zmq4.WithLogger(zmq4.Devnull)) 291 | ) 292 | 293 | cleanUp(epFront) 294 | cleanUp(epBack) 295 | 296 | defer front.Close() 297 | defer back.Close() 298 | 299 | if err := front.Listen(epFront); err != nil { 300 | t.Fatalf("could not listen: %+v", err) 301 | } 302 | 303 | if err := frontIn.Dial(epFront); err != nil { 304 | t.Fatalf("could not dial: %+v", err) 305 | } 306 | 307 | if err := back.Listen(epBack); err != nil { 308 | t.Fatalf("could not listen: %+v", err) 309 | } 310 | 311 | if err := backOut.Dial(epBack); err != nil { 312 | t.Fatalf("could not dial: %+v", err) 313 | } 314 | 315 | var errc = make(chan error) 316 | go func() { 317 | errc <- zmq4.NewProxy(ctx, front, back, nil).Run() 318 | }() 319 | 320 | go func() { 321 | _ = frontIn.Send(zmq4.NewMsgString("msg1")) 322 | }() 323 | go func() { 324 | _, _ = backOut.Recv() 325 | }() 326 | cancel() 327 | 328 | err := <-errc 329 | if err != context.Canceled { 330 | 
t.Fatalf("error: %+v", err) 331 | } 332 | 333 | if err := ctx.Err(); err != nil && err != context.Canceled { 334 | t.Fatalf("error: %+v", err) 335 | } 336 | } 337 | -------------------------------------------------------------------------------- /pub.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "net" 11 | "sync" 12 | "sync/atomic" 13 | ) 14 | 15 | const ( 16 | DefaultSendHwm = 1000 17 | ) 18 | 19 | // Topics is an interface that wraps the basic Topics method. 20 | type Topics interface { 21 | // Topics returns the sorted list of topics a socket is subscribed to. 22 | Topics() []string 23 | } 24 | 25 | // NewPub returns a new PUB ZeroMQ socket. 26 | // The returned socket value is initially unbound. 27 | func NewPub(ctx context.Context, opts ...Option) Socket { 28 | pub := &pubSocket{sck: newSocket(ctx, Pub, opts...)} 29 | pub.sck.w = newPubMWriter(pub.sck.ctx) 30 | pub.sck.r = newPubQReader(pub.sck.ctx) 31 | return pub 32 | } 33 | 34 | // pubSocket is a PUB ZeroMQ socket. 35 | type pubSocket struct { 36 | sck *socket 37 | } 38 | 39 | // Close closes the open Socket 40 | func (pub *pubSocket) Close() error { 41 | return pub.sck.Close() 42 | } 43 | 44 | // Send puts the message on the outbound send queue. 45 | // Send blocks until the message can be queued or the send deadline expires. 46 | func (pub *pubSocket) Send(msg Msg) error { 47 | ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.Timeout()) 48 | defer cancel() 49 | return pub.sck.w.write(ctx, msg) 50 | } 51 | 52 | // SendMulti puts the message on the outbound send queue. 53 | // SendMulti blocks until the message can be queued or the send deadline expires. 54 | // The message will be sent as a multipart message. 
55 | func (pub *pubSocket) SendMulti(msg Msg) error { 56 | msg.multipart = true 57 | ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.Timeout()) 58 | defer cancel() 59 | return pub.sck.w.write(ctx, msg) 60 | } 61 | 62 | // Recv receives a complete message. 63 | func (*pubSocket) Recv() (Msg, error) { 64 | msg := Msg{err: fmt.Errorf("zmq4: PUB sockets can't recv messages")} 65 | return msg, msg.err 66 | } 67 | 68 | // Listen connects a local endpoint to the Socket. 69 | func (pub *pubSocket) Listen(ep string) error { 70 | return pub.sck.Listen(ep) 71 | } 72 | 73 | // Dial connects a remote endpoint to the Socket. 74 | func (pub *pubSocket) Dial(ep string) error { 75 | return pub.sck.Dial(ep) 76 | } 77 | 78 | // Type returns the type of this Socket (PUB, SUB, ...) 79 | func (pub *pubSocket) Type() SocketType { 80 | return pub.sck.Type() 81 | } 82 | 83 | // Addr returns the listener's address. 84 | // Addr returns nil if the socket isn't a listener. 85 | func (pub *pubSocket) Addr() net.Addr { 86 | return pub.sck.Addr() 87 | } 88 | 89 | // GetOption is used to retrieve an option for a socket. 90 | func (pub *pubSocket) GetOption(name string) (interface{}, error) { 91 | return pub.sck.GetOption(name) 92 | } 93 | 94 | // SetOption is used to set an option for a socket. 95 | func (pub *pubSocket) SetOption(name string, value interface{}) error { 96 | err := pub.sck.SetOption(name, value) 97 | if err != nil { 98 | return err 99 | } 100 | 101 | if name != OptionHWM { 102 | return ErrBadProperty 103 | } 104 | 105 | hwm, ok := value.(int) 106 | if !ok { 107 | return ErrBadProperty 108 | } 109 | 110 | w := pub.sck.w.(*pubMWriter) 111 | w.hwm.Store(int64(hwm)) 112 | return nil 113 | } 114 | 115 | // Topics returns the sorted list of topics a socket is subscribed to. 116 | func (pub *pubSocket) Topics() []string { 117 | return pub.sck.topics() 118 | } 119 | 120 | // pubQReader is a queued-message reader. 
121 | type pubQReader struct { 122 | ctx context.Context 123 | 124 | mu sync.RWMutex 125 | rs []*Conn 126 | c chan Msg 127 | 128 | sem *semaphore // ready when a connection is live. 129 | } 130 | 131 | func newPubQReader(ctx context.Context) *pubQReader { 132 | const qrsize = 10 133 | return &pubQReader{ 134 | ctx: ctx, 135 | c: make(chan Msg, qrsize), 136 | sem: newSemaphore(), 137 | } 138 | } 139 | 140 | func (q *pubQReader) Close() error { 141 | q.mu.RLock() 142 | var err error 143 | for _, r := range q.rs { 144 | e := r.Close() 145 | if e != nil && err == nil { 146 | err = e 147 | } 148 | } 149 | q.rs = nil 150 | q.mu.RUnlock() 151 | return err 152 | } 153 | 154 | func (q *pubQReader) addConn(r *Conn) { 155 | q.mu.Lock() 156 | q.sem.enable() 157 | q.rs = append(q.rs, r) 158 | q.mu.Unlock() 159 | go q.listen(q.ctx, r) 160 | } 161 | 162 | func (q *pubQReader) rmConn(r *Conn) { 163 | q.mu.Lock() 164 | defer q.mu.Unlock() 165 | 166 | cur := -1 167 | for i := range q.rs { 168 | if q.rs[i] == r { 169 | cur = i 170 | break 171 | } 172 | } 173 | if cur >= 0 { 174 | q.rs = append(q.rs[:cur], q.rs[cur+1:]...) 
175 | } 176 | } 177 | 178 | func (q *pubQReader) read(ctx context.Context, msg *Msg) error { 179 | q.sem.lock(ctx) 180 | select { 181 | case <-ctx.Done(): 182 | case *msg = <-q.c: 183 | } 184 | return msg.err 185 | } 186 | 187 | func (q *pubQReader) listen(ctx context.Context, r *Conn) { 188 | defer q.rmConn(r) 189 | defer r.Close() 190 | 191 | for { 192 | msg := r.read() 193 | select { 194 | case <-ctx.Done(): 195 | return 196 | default: 197 | if msg.err != nil { 198 | return 199 | } 200 | switch { 201 | case q.topic(msg): 202 | r.subscribe(msg) 203 | default: 204 | q.c <- msg 205 | } 206 | } 207 | } 208 | } 209 | 210 | func (q *pubQReader) topic(msg Msg) bool { 211 | if len(msg.Frames) != 1 { 212 | return false 213 | } 214 | frame := msg.Frames[0] 215 | if len(frame) == 0 { 216 | return false 217 | } 218 | topic := frame[0] 219 | return topic == 0 || topic == 1 220 | } 221 | 222 | type pubMWriter struct { 223 | ctx context.Context 224 | mu sync.RWMutex 225 | subscribers map[*Conn]chan Msg 226 | 227 | hwm atomic.Int64 228 | } 229 | 230 | func newPubMWriter(ctx context.Context) *pubMWriter { 231 | p := &pubMWriter{ 232 | ctx: ctx, 233 | subscribers: map[*Conn]chan Msg{}, 234 | } 235 | p.hwm.Store(DefaultSendHwm) 236 | return p 237 | } 238 | 239 | func (w *pubMWriter) Close() error { 240 | w.mu.Lock() 241 | defer w.mu.Unlock() 242 | 243 | for conn, channel := range w.subscribers { 244 | _ = conn.Close() 245 | close(channel) 246 | } 247 | w.subscribers = nil 248 | return nil 249 | } 250 | 251 | func (mw *pubMWriter) addConn(w *Conn) { 252 | mw.mu.Lock() 253 | defer mw.mu.Unlock() 254 | 255 | c := make(chan Msg, mw.hwm.Load()) 256 | mw.subscribers[w] = c 257 | go func() { 258 | for { 259 | msg, ok := <-c 260 | if !ok { 261 | break 262 | } 263 | topic := string(msg.Frames[0]) 264 | if w.subscribed(topic) { 265 | _ = w.SendMsg(msg) 266 | } 267 | } 268 | }() 269 | } 270 | 271 | func (mw *pubMWriter) rmConn(w *Conn) { 272 | mw.mu.Lock() 273 | defer mw.mu.Unlock() 274 | 
275 | if channel, ok := mw.subscribers[w]; ok { 276 | _ = w.Close() 277 | delete(mw.subscribers, w) 278 | close(channel) 279 | } 280 | } 281 | 282 | func (w *pubMWriter) write(ctx context.Context, msg Msg) error { 283 | w.mu.RLock() 284 | defer w.mu.RUnlock() 285 | 286 | for _, channel := range w.subscribers { 287 | select { 288 | case <-ctx.Done(): 289 | return ctx.Err() 290 | case channel <- msg: // proceeds to default case if the channel is full (msg will be discarded) 291 | default: 292 | } 293 | } 294 | return nil 295 | } 296 | 297 | var ( 298 | _ rpool = (*pubQReader)(nil) 299 | _ wpool = (*pubMWriter)(nil) 300 | _ Socket = (*pubSocket)(nil) 301 | _ Topics = (*pubSocket)(nil) 302 | ) 303 | -------------------------------------------------------------------------------- /pull.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "net" 11 | ) 12 | 13 | // NewPull returns a new PULL ZeroMQ socket. 14 | // The returned socket value is initially unbound. 15 | func NewPull(ctx context.Context, opts ...Option) Socket { 16 | pull := &pullSocket{newSocket(ctx, Pull, opts...)} 17 | pull.sck.w = nil 18 | return pull 19 | } 20 | 21 | // pullSocket is a PULL ZeroMQ socket. 22 | type pullSocket struct { 23 | sck *socket 24 | } 25 | 26 | // Close closes the open Socket 27 | func (pull *pullSocket) Close() error { 28 | return pull.sck.Close() 29 | } 30 | 31 | // Send puts the message on the outbound send queue. 32 | // Send blocks until the message can be queued or the send deadline expires. 33 | func (*pullSocket) Send(msg Msg) error { 34 | return fmt.Errorf("zmq4: PULL sockets can't send messages") 35 | } 36 | 37 | // SendMulti puts the message on the outbound send queue. 
38 | // SendMulti blocks until the message can be queued or the send deadline expires. 39 | // The message will be sent as a multipart message. 40 | func (pull *pullSocket) SendMulti(msg Msg) error { 41 | return fmt.Errorf("zmq4: PULL sockets can't send messages") 42 | } 43 | 44 | // Recv receives a complete message. 45 | func (pull *pullSocket) Recv() (Msg, error) { 46 | return pull.sck.Recv() 47 | } 48 | 49 | // Listen connects a local endpoint to the Socket. 50 | func (pull *pullSocket) Listen(ep string) error { 51 | return pull.sck.Listen(ep) 52 | } 53 | 54 | // Dial connects a remote endpoint to the Socket. 55 | func (pull *pullSocket) Dial(ep string) error { 56 | return pull.sck.Dial(ep) 57 | } 58 | 59 | // Type returns the type of this Socket (PUB, SUB, ...) 60 | func (pull *pullSocket) Type() SocketType { 61 | return pull.sck.Type() 62 | } 63 | 64 | // Addr returns the listener's address. 65 | // Addr returns nil if the socket isn't a listener. 66 | func (pull *pullSocket) Addr() net.Addr { 67 | return pull.sck.Addr() 68 | } 69 | 70 | // GetOption is used to retrieve an option for a socket. 71 | func (pull *pullSocket) GetOption(name string) (interface{}, error) { 72 | return pull.sck.GetOption(name) 73 | } 74 | 75 | // SetOption is used to set an option for a socket. 76 | func (pull *pullSocket) SetOption(name string, value interface{}) error { 77 | return pull.sck.SetOption(name, value) 78 | } 79 | 80 | var ( 81 | _ Socket = (*pullSocket)(nil) 82 | ) 83 | -------------------------------------------------------------------------------- /push.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "net" 11 | ) 12 | 13 | // NewPush returns a new PUSH ZeroMQ socket. 
14 | // The returned socket value is initially unbound. 15 | func NewPush(ctx context.Context, opts ...Option) Socket { 16 | push := &pushSocket{newSocket(ctx, Push, opts...)} 17 | push.sck.r = nil 18 | return push 19 | } 20 | 21 | // pushSocket is a PUSH ZeroMQ socket. 22 | type pushSocket struct { 23 | sck *socket 24 | } 25 | 26 | // Close closes the open Socket 27 | func (push *pushSocket) Close() error { 28 | return push.sck.Close() 29 | } 30 | 31 | // Send puts the message on the outbound send queue. 32 | // Send blocks until the message can be queued or the send deadline expires. 33 | func (push *pushSocket) Send(msg Msg) error { 34 | return push.sck.Send(msg) 35 | } 36 | 37 | // SendMulti puts the message on the outbound send queue. 38 | // SendMulti blocks until the message can be queued or the send deadline expires. 39 | // The message will be sent as a multipart message. 40 | func (push *pushSocket) SendMulti(msg Msg) error { 41 | return push.sck.SendMulti(msg) 42 | } 43 | 44 | // Recv receives a complete message. 45 | func (*pushSocket) Recv() (Msg, error) { 46 | return Msg{}, fmt.Errorf("zmq4: PUSH sockets can't recv messages") 47 | } 48 | 49 | // Listen connects a local endpoint to the Socket. 50 | func (push *pushSocket) Listen(ep string) error { 51 | return push.sck.Listen(ep) 52 | } 53 | 54 | // Dial connects a remote endpoint to the Socket. 55 | func (push *pushSocket) Dial(ep string) error { 56 | return push.sck.Dial(ep) 57 | } 58 | 59 | // Type returns the type of this Socket (PUB, SUB, ...) 60 | func (push *pushSocket) Type() SocketType { 61 | return push.sck.Type() 62 | } 63 | 64 | // Addr returns the listener's address. 65 | // Addr returns nil if the socket isn't a listener. 66 | func (push *pushSocket) Addr() net.Addr { 67 | return push.sck.Addr() 68 | } 69 | 70 | // GetOption is used to retrieve an option for a socket. 
71 | func (push *pushSocket) GetOption(name string) (interface{}, error) { 72 | return push.sck.GetOption(name) 73 | } 74 | 75 | // SetOption is used to set an option for a socket. 76 | func (push *pushSocket) SetOption(name string, value interface{}) error { 77 | return push.sck.SetOption(name, value) 78 | } 79 | 80 | var ( 81 | _ Socket = (*pushSocket)(nil) 82 | ) 83 | -------------------------------------------------------------------------------- /queue.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "container/list" 9 | ) 10 | 11 | const innerCap = 512 12 | 13 | type Queue struct { 14 | rep *list.List 15 | len int 16 | } 17 | 18 | func NewQueue() *Queue { 19 | q := &Queue{list.New(), 0} 20 | return q 21 | } 22 | 23 | func (q *Queue) Len() int { 24 | return q.len 25 | } 26 | 27 | func (q *Queue) Init() { 28 | q.rep.Init() 29 | q.len = 0 30 | } 31 | 32 | func (q *Queue) Push(val Msg) { 33 | q.len++ 34 | 35 | var i []interface{} 36 | elem := q.rep.Back() 37 | if elem != nil { 38 | i = elem.Value.([]interface{}) 39 | } 40 | if i == nil || len(i) == innerCap { 41 | elem = q.rep.PushBack(make([]interface{}, 0, innerCap)) 42 | i = elem.Value.([]interface{}) 43 | } 44 | 45 | elem.Value = append(i, val) 46 | } 47 | 48 | func (q *Queue) Peek() (Msg, bool) { 49 | i := q.front() 50 | if i == nil { 51 | return Msg{}, false 52 | } 53 | return i[0].(Msg), true 54 | } 55 | 56 | func (q *Queue) Pop() { 57 | elem := q.rep.Front() 58 | if elem == nil { 59 | panic("attempting to Pop on an empty Queue") 60 | } 61 | 62 | q.len-- 63 | i := elem.Value.([]interface{}) 64 | i[0] = nil // remove ref to poped element 65 | i = i[1:] 66 | if len(i) == 0 { 67 | q.rep.Remove(elem) 68 | } else { 69 | elem.Value = i 70 | } 71 | } 72 | 73 | 
func (q *Queue) front() []interface{} { 74 | elem := q.rep.Front() 75 | if elem == nil { 76 | return nil 77 | } 78 | return elem.Value.([]interface{}) 79 | } 80 | -------------------------------------------------------------------------------- /queue_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "reflect" 9 | "testing" 10 | ) 11 | 12 | func makeMsg(i int) Msg { 13 | return NewMsgString(string(rune(i))) 14 | } 15 | 16 | func TestQueue(t *testing.T) { 17 | q := NewQueue() 18 | if q.Len() != 0 { 19 | t.Fatal("queue should be empty") 20 | } 21 | if _, exists := q.Peek(); exists { 22 | t.Fatal("Queue should be empty") 23 | } 24 | 25 | q.Push(makeMsg(1)) 26 | if q.Len() != 1 { 27 | t.Fatal("queue should contain 1 element") 28 | } 29 | msg, ok := q.Peek() 30 | if !ok || !reflect.DeepEqual(msg, makeMsg(1)) { 31 | t.Fatal("unexpected value in queue") 32 | } 33 | 34 | q.Push(makeMsg(2)) 35 | if q.Len() != 2 { 36 | t.Fatal("queue should contain 2 elements") 37 | } 38 | msg, ok = q.Peek() 39 | if !ok || !reflect.DeepEqual(msg, makeMsg(1)) { 40 | t.Fatal("unexpected value in queue") 41 | } 42 | 43 | q.Pop() 44 | if q.Len() != 1 { 45 | t.Fatal("queue should contain 1 element") 46 | } 47 | msg, ok = q.Peek() 48 | if !ok || !reflect.DeepEqual(msg, makeMsg(2)) { 49 | t.Fatal("unexpected value in queue") 50 | } 51 | 52 | q.Pop() 53 | if q.Len() != 0 { 54 | t.Fatal("queue should be empty") 55 | } 56 | 57 | q.Push(makeMsg(1)) 58 | q.Push(makeMsg(2)) 59 | q.Init() 60 | if q.Len() != 0 { 61 | t.Fatal("queue should be empty") 62 | } 63 | } 64 | 65 | func TestQueueNewInnerList(t *testing.T) { 66 | q := NewQueue() 67 | 68 | for i := 1; i <= innerCap; i++ { 69 | q.Push(makeMsg(i)) 70 | } 71 | 72 | if q.Len() != innerCap { 73 | 
t.Fatalf("queue should contain %d elements", innerCap) 74 | } 75 | 76 | // next push will create a new inner slice 77 | q.Push(makeMsg(innerCap + 1)) 78 | if q.Len() != innerCap+1 { 79 | t.Fatalf("queue should contain %d elements", innerCap+1) 80 | } 81 | msg, ok := q.Peek() 82 | if !ok || !reflect.DeepEqual(msg, makeMsg(1)) { 83 | t.Fatal("unexpected value in queue") 84 | } 85 | 86 | q.Pop() 87 | if q.Len() != innerCap { 88 | t.Fatalf("queue should contain %d elements", innerCap) 89 | } 90 | msg, ok = q.Peek() 91 | if !ok || !reflect.DeepEqual(msg, makeMsg(2)) { 92 | t.Fatal("unexpected value in queue") 93 | } 94 | 95 | q.Push(makeMsg(innerCap + 1)) 96 | q.Init() 97 | if q.Len() != 0 { 98 | t.Fatal("queue should be empty") 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /reaper_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "context" 9 | "io" 10 | "net" 11 | "sync/atomic" 12 | "testing" 13 | "time" 14 | ) 15 | 16 | func TestConnReaperDeadlock2(t *testing.T) { 17 | ep := must(EndPoint("tcp")) 18 | defer cleanUp(ep) 19 | 20 | ctx, cancel := context.WithCancel(context.Background()) 21 | defer cancel() 22 | 23 | // Bind the server. 24 | srv := NewRouter(ctx, WithLogger(Devnull)).(*routerSocket) 25 | if err := srv.Listen(ep); err != nil { 26 | t.Fatalf("could not listen on %q: %+v", ep, err) 27 | } 28 | defer srv.Close() 29 | 30 | // Add modified clients connection to server 31 | // so any send to client will trigger context switch 32 | // and be failing. 
33 | // Idea is that while srv.Send is progressing, 34 | // the connection will be closed and assigned 35 | // for connection reaper, and reaper will try to remove those 36 | id := "client-x" 37 | srv.sck.mu.Lock() 38 | rmw := srv.sck.w.(*routerMWriter) 39 | for i := 0; i < 2; i++ { 40 | w := &Conn{} 41 | w.Peer.Meta = make(Metadata) 42 | w.Peer.Meta[sysSockID] = id 43 | w.rw = &sockSendEOF{} 44 | w.onCloseErrorCB = srv.sck.scheduleRmConn 45 | // Do not call srv.addConn as we don't want to have a listener on this fake socket 46 | rmw.addConn(w) 47 | srv.sck.conns = append(srv.sck.conns, w) 48 | } 49 | srv.sck.mu.Unlock() 50 | 51 | // Now try to send a message from the server to all clients. 52 | msg := NewMsgFrom(nil, nil, []byte("payload")) 53 | msg.Frames[0] = []byte(id) 54 | if err := srv.Send(msg); err != nil { 55 | t.Logf("Send to %s failed: %+v\n", id, err) 56 | } 57 | } 58 | 59 | type sockSendEOF struct { 60 | } 61 | 62 | var a atomic.Int32 63 | 64 | func (r *sockSendEOF) Write(b []byte) (n int, err error) { 65 | // Each odd write fails asap. 66 | // Each even write fails after sleep. 67 | // This way we ensure the short write failure 68 | // will cause the socket to be assigned to the connection reaper 69 | // while srv.Send is still in progress due to long writes. 
70 | if x := a.Add(1); x&1 == 0 { 71 | time.Sleep(1 * time.Second) 72 | } 73 | return 0, io.EOF 74 | } 75 | 76 | func (r *sockSendEOF) Read(b []byte) (int, error) { 77 | return 0, nil 78 | } 79 | 80 | func (r *sockSendEOF) Close() error { 81 | return nil 82 | } 83 | 84 | func (r *sockSendEOF) LocalAddr() net.Addr { 85 | return nil 86 | } 87 | 88 | func (r *sockSendEOF) RemoteAddr() net.Addr { 89 | return nil 90 | } 91 | 92 | func (r *sockSendEOF) SetDeadline(t time.Time) error { 93 | return nil 94 | } 95 | 96 | func (r *sockSendEOF) SetReadDeadline(t time.Time) error { 97 | return nil 98 | } 99 | 100 | func (r *sockSendEOF) SetWriteDeadline(t time.Time) error { 101 | return nil 102 | } 103 | -------------------------------------------------------------------------------- /rep.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "net" 11 | "sync" 12 | ) 13 | 14 | // NewRep returns a new REP ZeroMQ socket. 15 | // The returned socket value is initially unbound. 16 | func NewRep(ctx context.Context, opts ...Option) Socket { 17 | rep := &repSocket{newSocket(ctx, Rep, opts...)} 18 | sharedState := newRepState() 19 | rep.sck.w = newRepWriter(rep.sck.ctx, sharedState) 20 | rep.sck.r = newRepReader(rep.sck.ctx, sharedState) 21 | return rep 22 | } 23 | 24 | // repSocket is a REP ZeroMQ socket. 25 | type repSocket struct { 26 | sck *socket 27 | } 28 | 29 | // Close closes the open Socket 30 | func (rep *repSocket) Close() error { 31 | return rep.sck.Close() 32 | } 33 | 34 | // Send puts the message on the outbound send queue. 35 | // Send blocks until the message can be queued or the send deadline expires. 
36 | func (rep *repSocket) Send(msg Msg) error { 37 | ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.Timeout()) 38 | defer cancel() 39 | return rep.sck.w.write(ctx, msg) 40 | } 41 | 42 | // SendMulti puts the message on the outbound send queue. 43 | // SendMulti blocks until the message can be queued or the send deadline expires. 44 | // The message will be sent as a multipart message. 45 | func (rep *repSocket) SendMulti(msg Msg) error { 46 | msg.multipart = true 47 | ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.Timeout()) 48 | defer cancel() 49 | return rep.sck.w.write(ctx, msg) 50 | } 51 | 52 | // Recv receives a complete message. 53 | func (rep *repSocket) Recv() (Msg, error) { 54 | ctx, cancel := context.WithCancel(rep.sck.ctx) 55 | defer cancel() 56 | var msg Msg 57 | err := rep.sck.r.read(ctx, &msg) 58 | return msg, err 59 | } 60 | 61 | // Listen connects a local endpoint to the Socket. 62 | func (rep *repSocket) Listen(ep string) error { 63 | return rep.sck.Listen(ep) 64 | } 65 | 66 | // Dial connects a remote endpoint to the Socket. 67 | func (rep *repSocket) Dial(ep string) error { 68 | return rep.sck.Dial(ep) 69 | } 70 | 71 | // Type returns the type of this Socket (PUB, SUB, ...) 72 | func (rep *repSocket) Type() SocketType { 73 | return rep.sck.Type() 74 | } 75 | 76 | // Addr returns the listener's address. 77 | // Addr returns nil if the socket isn't a listener. 78 | func (rep *repSocket) Addr() net.Addr { 79 | return rep.sck.Addr() 80 | } 81 | 82 | // GetOption is used to retrieve an option for a socket. 83 | func (rep *repSocket) GetOption(name string) (interface{}, error) { 84 | return rep.sck.GetOption(name) 85 | } 86 | 87 | // SetOption is used to set an option for a socket. 
88 | func (rep *repSocket) SetOption(name string, value interface{}) error { 89 | return rep.sck.SetOption(name, value) 90 | } 91 | 92 | type repMsg struct { 93 | conn *Conn 94 | msg Msg 95 | } 96 | 97 | type repReader struct { 98 | ctx context.Context 99 | state *repState 100 | 101 | mu sync.Mutex 102 | conns []*Conn 103 | 104 | msgCh chan repMsg 105 | } 106 | 107 | func newRepReader(ctx context.Context, state *repState) *repReader { 108 | const qsize = 10 109 | return &repReader{ 110 | ctx: ctx, 111 | msgCh: make(chan repMsg, qsize), 112 | state: state, 113 | } 114 | } 115 | 116 | func (r *repReader) addConn(c *Conn) { 117 | r.mu.Lock() 118 | r.conns = append(r.conns, c) 119 | r.mu.Unlock() 120 | go r.listen(r.ctx, c) 121 | } 122 | 123 | func (r *repReader) rmConn(conn *Conn) { 124 | r.mu.Lock() 125 | defer r.mu.Unlock() 126 | 127 | cur := -1 128 | for i := range r.conns { 129 | if r.conns[i] == conn { 130 | cur = i 131 | break 132 | } 133 | } 134 | if cur >= 0 { 135 | r.conns = append(r.conns[:cur], r.conns[cur+1:]...) 
136 | } 137 | } 138 | 139 | func (r *repReader) read(ctx context.Context, msg *Msg) error { 140 | select { 141 | case <-ctx.Done(): 142 | return ctx.Err() 143 | case repMsg := <-r.msgCh: 144 | if repMsg.msg.err != nil { 145 | return repMsg.msg.err 146 | } 147 | pre, innerMsg := splitReq(repMsg.msg) 148 | if pre == nil { 149 | return fmt.Errorf("zmq4: invalid REP message") 150 | } 151 | *msg = innerMsg 152 | r.state.Set(repMsg.conn, pre) 153 | } 154 | return nil 155 | } 156 | 157 | func (r *repReader) listen(ctx context.Context, conn *Conn) { 158 | defer r.rmConn(conn) 159 | defer conn.Close() 160 | 161 | for { 162 | msg := conn.read() 163 | select { 164 | case <-ctx.Done(): 165 | return 166 | default: 167 | if msg.err != nil { 168 | return 169 | } 170 | r.msgCh <- repMsg{conn, msg} 171 | } 172 | } 173 | } 174 | 175 | func (r *repReader) Close() error { 176 | r.mu.Lock() 177 | defer r.mu.Unlock() 178 | 179 | var err error 180 | for _, conn := range r.conns { 181 | e := conn.Close() 182 | if e != nil && err == nil { 183 | err = e 184 | } 185 | } 186 | r.conns = nil 187 | return err 188 | } 189 | // splitReq splits a request envelope at its first empty (delimiter) frame: preamble is every frame up to and including that delimiter (nil when there is none) and msg holds the frames after it. 190 | func splitReq(envelope Msg) (preamble [][]byte, msg Msg) { 191 | for i, frame := range envelope.Frames { 192 | if len(frame) == 0 { // the first empty frame is the envelope delimiter 193 | preamble = envelope.Frames[:i+1] 194 | if i+1 < len(envelope.Frames) { 195 | msg = NewMsgFrom(envelope.Frames[i+1:]...) 196 | } 197 | break // later empty frames belong to the payload, not to the routing preamble 198 | } 199 | } 200 | return 201 | } 202 | 203 | type repSendPayload struct { 204 | conn *Conn 205 | preamble [][]byte 206 | msg Msg 207 | } 208 | 209 | type repWriter struct { 210 | ctx context.Context 211 | state *repState 212 | 213 | mu sync.Mutex 214 | conns []*Conn 215 | 216 | sendCh chan repSendPayload 217 | } 218 | 219 | func (r repSendPayload) buildReplyMsg() Msg { 220 | var frames = make([][]byte, 0, len(r.preamble)+len(r.msg.Frames)) 221 | frames = append(frames, r.preamble...) 222 | frames = append(frames, r.msg.Frames...) 223 | return NewMsgFrom(frames...) 
224 | } 225 | 226 | func newRepWriter(ctx context.Context, state *repState) *repWriter { 227 | r := &repWriter{ 228 | ctx: ctx, 229 | state: state, 230 | sendCh: make(chan repSendPayload), 231 | } 232 | go r.run() 233 | return r 234 | } 235 | 236 | func (r *repWriter) addConn(w *Conn) { 237 | r.mu.Lock() 238 | r.conns = append(r.conns, w) 239 | r.mu.Unlock() 240 | } 241 | 242 | func (r *repWriter) rmConn(conn *Conn) { 243 | r.mu.Lock() 244 | defer r.mu.Unlock() 245 | 246 | cur := -1 247 | for i := range r.conns { 248 | if r.conns[i] == conn { 249 | cur = i 250 | break 251 | } 252 | } 253 | if cur >= 0 { 254 | r.conns = append(r.conns[:cur], r.conns[cur+1:]...) 255 | } 256 | } 257 | 258 | func (r *repWriter) write(ctx context.Context, msg Msg) error { 259 | conn, preamble := r.state.Get() 260 | select { 261 | case <-ctx.Done(): 262 | return ctx.Err() 263 | case <-r.ctx.Done(): // repWriter.run() terminates on this, sendCh <- will not complete 264 | return r.ctx.Err() 265 | case r.sendCh <- repSendPayload{conn, preamble, msg}: 266 | return nil 267 | } 268 | } 269 | 270 | func (r *repWriter) run() { 271 | for { 272 | select { 273 | case <-r.ctx.Done(): 274 | return 275 | case payload, ok := <-r.sendCh: 276 | if !ok { 277 | return 278 | } 279 | r.sendPayload(payload) 280 | } 281 | } 282 | } 283 | 284 | func (r *repWriter) sendPayload(payload repSendPayload) { 285 | r.mu.Lock() 286 | defer r.mu.Unlock() 287 | for _, conn := range r.conns { 288 | if conn == payload.conn { 289 | reply := payload.buildReplyMsg() 290 | // not much we can do at this point. Perhaps log the error? 
291 | _ = conn.SendMsg(reply) 292 | return 293 | } 294 | } 295 | } 296 | 297 | func (r *repWriter) Close() error { 298 | close(r.sendCh) 299 | r.mu.Lock() 300 | defer r.mu.Unlock() 301 | 302 | var err error 303 | for _, conn := range r.conns { 304 | e := conn.Close() 305 | if e != nil && err == nil { 306 | err = e 307 | } 308 | } 309 | r.conns = nil 310 | return err 311 | } 312 | 313 | type repState struct { 314 | mu sync.Mutex 315 | conn *Conn 316 | preamble [][]byte // includes delimiter 317 | } 318 | 319 | func newRepState() *repState { 320 | return &repState{} 321 | } 322 | 323 | func (r *repState) Get() (conn *Conn, preamble [][]byte) { 324 | r.mu.Lock() 325 | conn = r.conn 326 | preamble = r.preamble 327 | r.mu.Unlock() 328 | return 329 | } 330 | 331 | func (r *repState) Set(conn *Conn, pre [][]byte) { 332 | r.mu.Lock() 333 | r.conn = conn 334 | r.preamble = pre 335 | r.mu.Unlock() 336 | } 337 | 338 | var ( 339 | _ Socket = (*repSocket)(nil) 340 | ) 341 | -------------------------------------------------------------------------------- /rep_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package zmq4_test 6 | 7 | import ( 8 | "context" 9 | "errors" 10 | "sync" 11 | "testing" 12 | 13 | "github.com/go-zeromq/zmq4" 14 | ) 15 | 16 | func TestIssue99(t *testing.T) { 17 | var ( 18 | wg sync.WaitGroup 19 | outMsg zmq4.Msg 20 | inMsg zmq4.Msg 21 | ok = make(chan int) 22 | ) 23 | 24 | ep, err := EndPoint("tcp") 25 | if err != nil { 26 | t.Fatalf("could not find endpoint: %+v", err) 27 | } 28 | 29 | requester := func() { 30 | defer wg.Done() 31 | defer close(ok) 32 | 33 | req := zmq4.NewReq(context.Background()) 34 | defer req.Close() 35 | 36 | err := req.Dial(ep) 37 | if err != nil { 38 | t.Errorf("could not dial: %+v", err) 39 | return 40 | } 41 | 42 | // Test message w/ 3 frames 43 | outMsg = zmq4.NewMsgFromString([]string{"ZERO", "Hello!", "World!"}) 44 | err = req.Send(outMsg) 45 | if err != nil { 46 | t.Errorf("failed to send: %+v", err) 47 | return 48 | } 49 | 50 | inMsg, err = req.Recv() 51 | if err != nil { 52 | t.Errorf("failed to recv: %+v", err) 53 | return 54 | } 55 | } 56 | 57 | responder := func() { 58 | defer wg.Done() 59 | 60 | rep := zmq4.NewRep(context.Background()) 61 | defer rep.Close() 62 | 63 | err := rep.Listen(ep) 64 | if err != nil { 65 | t.Errorf("could not dial: %+v", err) 66 | return 67 | } 68 | 69 | // Wait for next request from client 70 | msg, err := rep.Recv() 71 | if err != nil { 72 | t.Errorf("could not recv request: %+v", err) 73 | return 74 | } 75 | 76 | // Send reply back to client 77 | err = rep.Send(msg) 78 | if err != nil { 79 | t.Errorf("could not send reply: %+v", err) 80 | return 81 | } 82 | <-ok 83 | } 84 | 85 | wg.Add(2) 86 | 87 | go requester() 88 | go responder() 89 | 90 | wg.Wait() 91 | 92 | if want, got := len(outMsg.Frames), len(inMsg.Frames); want != got { 93 | t.Fatalf("message length mismatch: got=%d, want=%d", got, want) 94 | } 95 | } 96 | 97 | func TestCancellation(t *testing.T) { 98 | // if the context is cancelled during a rep.Send both the requester and the responder should get an error 99 | 
var wg sync.WaitGroup 100 | 101 | ep, err := EndPoint("tcp") 102 | if err != nil { 103 | t.Fatalf("could not find endpoint: %+v", err) 104 | } 105 | 106 | responderStarted := make(chan bool) 107 | 108 | requester := func() { 109 | defer wg.Done() 110 | <-responderStarted 111 | 112 | req := zmq4.NewReq(context.Background()) 113 | defer req.Close() 114 | 115 | err := req.Dial(ep) 116 | if err != nil { 117 | t.Errorf("could not dial: %+v", err) 118 | return 119 | } 120 | 121 | err = req.Send(zmq4.NewMsgString("ping")) 122 | if err != nil { 123 | t.Errorf("could not send: %+v", err) 124 | return 125 | } 126 | 127 | msg, err := req.Recv() 128 | if err == nil { 129 | t.Errorf("requester should have gotten an error, but got: %+v", msg) 130 | } 131 | } 132 | 133 | responder := func() { 134 | 135 | defer wg.Done() 136 | repCtx, cancel := context.WithCancel(context.Background()) 137 | defer cancel() 138 | rep := zmq4.NewRep(repCtx) 139 | defer rep.Close() 140 | 141 | err := rep.Listen(ep) 142 | if err != nil { 143 | t.Errorf("could not dial: %+v", err) 144 | return 145 | } 146 | 147 | responderStarted <- true 148 | 149 | _, err = rep.Recv() 150 | if err != nil { 151 | t.Errorf("could not recv: %+v", err) 152 | return 153 | } 154 | 155 | // cancel the context right before sending the response 156 | cancel() 157 | err = rep.Send(zmq4.NewMsgString("pong")) 158 | 159 | if !errors.Is(err, context.Canceled) { 160 | t.Errorf("context should be cancelled: %+v", err) 161 | } 162 | } 163 | 164 | wg.Add(2) 165 | 166 | go requester() 167 | go responder() 168 | 169 | wg.Wait() 170 | } 171 | -------------------------------------------------------------------------------- /req.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package zmq4 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "net" 11 | "sync" 12 | ) 13 | 14 | // NewReq returns a new REQ ZeroMQ socket. 15 | // The returned socket value is initially unbound. 16 | func NewReq(ctx context.Context, opts ...Option) Socket { 17 | state := &reqState{} 18 | req := &reqSocket{newSocket(ctx, Req, opts...), state} 19 | req.sck.r = newReqReader(req.sck.ctx, state) 20 | req.sck.w = newReqWriter(req.sck.ctx, state) 21 | return req 22 | } 23 | 24 | // reqSocket is a REQ ZeroMQ socket. 25 | type reqSocket struct { 26 | sck *socket 27 | state *reqState 28 | } 29 | 30 | // Close closes the open Socket 31 | func (req *reqSocket) Close() error { 32 | return req.sck.Close() 33 | } 34 | 35 | // Send puts the message on the outbound send queue. 36 | // Send blocks until the message can be queued or the send deadline expires. 37 | func (req *reqSocket) Send(msg Msg) error { 38 | ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.Timeout()) 39 | defer cancel() 40 | return req.sck.w.write(ctx, msg) 41 | } 42 | 43 | // SendMulti puts the message on the outbound send queue. 44 | // SendMulti blocks until the message can be queued or the send deadline expires. 45 | // The message will be sent as a multipart message. 46 | func (req *reqSocket) SendMulti(msg Msg) error { 47 | msg.multipart = true 48 | ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.Timeout()) 49 | defer cancel() 50 | return req.sck.w.write(ctx, msg) 51 | } 52 | 53 | // Recv receives a complete message. 54 | func (req *reqSocket) Recv() (Msg, error) { 55 | ctx, cancel := context.WithCancel(req.sck.ctx) 56 | defer cancel() 57 | var msg Msg 58 | err := req.sck.r.read(ctx, &msg) 59 | return msg, err 60 | } 61 | 62 | // Listen connects a local endpoint to the Socket. 63 | func (req *reqSocket) Listen(ep string) error { 64 | return req.sck.Listen(ep) 65 | } 66 | 67 | // Dial connects a remote endpoint to the Socket. 
68 | func (req *reqSocket) Dial(ep string) error { 69 | return req.sck.Dial(ep) 70 | } 71 | 72 | // Type returns the type of this Socket (PUB, SUB, ...) 73 | func (req *reqSocket) Type() SocketType { 74 | return req.sck.Type() 75 | } 76 | 77 | // Addr returns the listener's address. 78 | // Addr returns nil if the socket isn't a listener. 79 | func (req *reqSocket) Addr() net.Addr { 80 | return req.sck.Addr() 81 | } 82 | 83 | // GetOption is used to retrieve an option for a socket. 84 | func (req *reqSocket) GetOption(name string) (interface{}, error) { 85 | return req.sck.GetOption(name) 86 | } 87 | 88 | // SetOption is used to set an option for a socket. 89 | func (req *reqSocket) SetOption(name string, value interface{}) error { 90 | return req.sck.SetOption(name, value) 91 | } 92 | 93 | type reqWriter struct { 94 | mu sync.Mutex 95 | conns []*Conn 96 | nextConn int 97 | state *reqState 98 | } 99 | 100 | func newReqWriter(ctx context.Context, state *reqState) *reqWriter { 101 | return &reqWriter{ 102 | state: state, 103 | } 104 | } 105 | 106 | func (r *reqWriter) write(ctx context.Context, msg Msg) error { 107 | msg.Frames = append([][]byte{nil}, msg.Frames...) 
108 | 109 | r.mu.Lock() 110 | defer r.mu.Unlock() 111 | var err error 112 | for i := 0; i < len(r.conns); i++ { 113 | cur := (i + r.nextConn) % len(r.conns) // round-robin: start at nextConn and wrap, so cur is always a valid index 114 | conn := r.conns[cur] 115 | err = conn.SendMsg(msg) 116 | if err == nil { 117 | r.nextConn = (cur + 1) % len(r.conns) // the next request starts with the following connection 118 | r.state.Set(conn) 119 | return nil 120 | } 121 | } 122 | return fmt.Errorf("zmq4: no connections available: %w", err) 123 | } 124 | 125 | func (r *reqWriter) addConn(c *Conn) { 126 | r.mu.Lock() 127 | r.conns = append(r.conns, c) 128 | r.mu.Unlock() 129 | } 130 | 131 | func (r *reqWriter) rmConn(conn *Conn) { 132 | r.mu.Lock() 133 | defer r.mu.Unlock() 134 | 135 | cur := -1 136 | for i := range r.conns { 137 | if r.conns[i] == conn { 138 | cur = i 139 | break 140 | } 141 | } 142 | if cur >= 0 { 143 | r.conns = append(r.conns[:cur], r.conns[cur+1:]...) 144 | } 145 | 146 | r.state.Reset(conn) 147 | } 148 | 149 | func (r *reqWriter) Close() error { 150 | r.mu.Lock() 151 | defer r.mu.Unlock() 152 | 153 | var err error 154 | for _, conn := range r.conns { 155 | e := conn.Close() 156 | if e != nil && err == nil { 157 | err = e 158 | } 159 | } 160 | r.conns = nil 161 | return err 162 | } 163 | 164 | type reqReader struct { 165 | state *reqState 166 | } 167 | 168 | func newReqReader(ctx context.Context, state *reqState) *reqReader { 169 | return &reqReader{ 170 | state: state, 171 | } 172 | } 173 | 174 | func (r *reqReader) addConn(c *Conn) {} 175 | func (r *reqReader) rmConn(c *Conn) {} 176 | 177 | func (r *reqReader) Close() error { 178 | return nil 179 | } 180 | 181 | func (r *reqReader) read(ctx context.Context, msg *Msg) error { 182 | curConn := r.state.Get() 183 | if curConn == nil { 184 | return fmt.Errorf("zmq4: no connections available") 185 | } 186 | *msg = curConn.read() 187 | if msg.err != nil { 188 | return msg.err 189 | } 190 | if len(msg.Frames) > 1 { 191 | msg.Frames = msg.Frames[1:] 192 | } 193 | return nil 194 | } 195 | 196 | type reqState struct { 197 | mu sync.Mutex 198 | lastConn 
*Conn 199 | } 200 | 201 | func (r *reqState) Set(conn *Conn) { 202 | r.mu.Lock() 203 | defer r.mu.Unlock() 204 | r.lastConn = conn 205 | } 206 | 207 | // Reset resets the state iff c matches the resident connection 208 | func (r *reqState) Reset(c *Conn) { 209 | r.mu.Lock() 210 | defer r.mu.Unlock() 211 | if r.lastConn == c { 212 | r.lastConn = nil 213 | } 214 | } 215 | 216 | func (r *reqState) Get() *Conn { 217 | r.mu.Lock() 218 | defer r.mu.Unlock() 219 | return r.lastConn 220 | } 221 | 222 | var ( 223 | _ Socket = (*reqSocket)(nil) 224 | ) 225 | -------------------------------------------------------------------------------- /router.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "bytes" 9 | "context" 10 | "net" 11 | "sync" 12 | 13 | "golang.org/x/sync/errgroup" 14 | ) 15 | 16 | // NewRouter returns a new ROUTER ZeroMQ socket. 17 | // The returned socket value is initially unbound. 18 | func NewRouter(ctx context.Context, opts ...Option) Socket { 19 | router := &routerSocket{newSocket(ctx, Router, opts...)} 20 | router.sck.r = newRouterQReader(router.sck.ctx) 21 | router.sck.w = newRouterMWriter(router.sck.ctx) 22 | return router 23 | } 24 | 25 | // routerSocket is a ROUTER ZeroMQ socket. 26 | type routerSocket struct { 27 | sck *socket 28 | } 29 | 30 | // Close closes the open Socket 31 | func (router *routerSocket) Close() error { 32 | return router.sck.Close() 33 | } 34 | 35 | // Send puts the message on the outbound send queue. 36 | // Send blocks until the message can be queued or the send deadline expires. 
37 | func (router *routerSocket) Send(msg Msg) error { 38 | ctx, cancel := context.WithTimeout(router.sck.ctx, router.sck.Timeout()) 39 | defer cancel() 40 | return router.sck.w.write(ctx, msg) 41 | } 42 | 43 | // SendMulti puts the message on the outbound send queue. 44 | // SendMulti blocks until the message can be queued or the send deadline expires. 45 | // The message will be sent as a multipart message. 46 | func (router *routerSocket) SendMulti(msg Msg) error { 47 | msg.multipart = true 48 | return router.Send(msg) 49 | } 50 | 51 | // Recv receives a complete message. 52 | func (router *routerSocket) Recv() (Msg, error) { 53 | return router.sck.Recv() 54 | } 55 | 56 | // Listen connects a local endpoint to the Socket. 57 | func (router *routerSocket) Listen(ep string) error { 58 | return router.sck.Listen(ep) 59 | } 60 | 61 | // Dial connects a remote endpoint to the Socket. 62 | func (router *routerSocket) Dial(ep string) error { 63 | return router.sck.Dial(ep) 64 | } 65 | 66 | // Type returns the type of this Socket (PUB, SUB, ...) 67 | func (router *routerSocket) Type() SocketType { 68 | return router.sck.Type() 69 | } 70 | 71 | // Addr returns the listener's address. 72 | // Addr returns nil if the socket isn't a listener. 73 | func (router *routerSocket) Addr() net.Addr { 74 | return router.sck.Addr() 75 | } 76 | 77 | // GetOption is used to retrieve an option for a socket. 78 | func (router *routerSocket) GetOption(name string) (interface{}, error) { 79 | return router.sck.GetOption(name) 80 | } 81 | 82 | // SetOption is used to set an option for a socket. 83 | func (router *routerSocket) SetOption(name string, value interface{}) error { 84 | return router.sck.SetOption(name, value) 85 | } 86 | 87 | // routerQReader is a queued-message reader. 88 | type routerQReader struct { 89 | ctx context.Context 90 | 91 | mu sync.RWMutex 92 | rs []*Conn 93 | c chan Msg 94 | 95 | sem *semaphore // ready when a connection is live. 
96 | } 97 | 98 | func newRouterQReader(ctx context.Context) *routerQReader { 99 | const qrsize = 10 100 | return &routerQReader{ 101 | ctx: ctx, 102 | c: make(chan Msg, qrsize), 103 | sem: newSemaphore(), 104 | } 105 | } 106 | 107 | func (q *routerQReader) Close() error { 108 | q.mu.RLock() 109 | var err error 110 | for _, r := range q.rs { 111 | e := r.Close() 112 | if e != nil && err == nil { 113 | err = e 114 | } 115 | } 116 | q.rs = nil 117 | q.mu.RUnlock() 118 | return err 119 | } 120 | 121 | func (q *routerQReader) addConn(r *Conn) { 122 | q.mu.Lock() 123 | q.sem.enable() 124 | q.rs = append(q.rs, r) 125 | q.mu.Unlock() 126 | go q.listen(q.ctx, r) 127 | } 128 | 129 | func (q *routerQReader) rmConn(r *Conn) { 130 | q.mu.Lock() 131 | defer q.mu.Unlock() 132 | 133 | cur := -1 134 | for i := range q.rs { 135 | if q.rs[i] == r { 136 | cur = i 137 | break 138 | } 139 | } 140 | if cur >= 0 { 141 | q.rs = append(q.rs[:cur], q.rs[cur+1:]...) 142 | } 143 | } 144 | 145 | func (q *routerQReader) read(ctx context.Context, msg *Msg) error { 146 | q.sem.lock(ctx) 147 | select { 148 | case <-ctx.Done(): 149 | return ctx.Err() 150 | case *msg = <-q.c: 151 | } 152 | return msg.err 153 | } 154 | 155 | func (q *routerQReader) listen(ctx context.Context, r *Conn) { 156 | defer q.rmConn(r) 157 | defer r.Close() 158 | 159 | id := []byte(r.Peer.Meta[sysSockID]) 160 | for { 161 | msg := r.read() 162 | select { 163 | case <-ctx.Done(): 164 | return 165 | default: 166 | if msg.err != nil { 167 | return 168 | } 169 | msg.Frames = append([][]byte{id}, msg.Frames...) 
170 | q.c <- msg 171 | } 172 | } 173 | } 174 | 175 | type routerMWriter struct { 176 | ctx context.Context 177 | mu sync.Mutex 178 | ws []*Conn 179 | sem *semaphore 180 | } 181 | 182 | func newRouterMWriter(ctx context.Context) *routerMWriter { 183 | return &routerMWriter{ 184 | ctx: ctx, 185 | sem: newSemaphore(), 186 | } 187 | } 188 | 189 | func (w *routerMWriter) Close() error { 190 | w.mu.Lock() 191 | var err error 192 | for _, ww := range w.ws { 193 | e := ww.Close() 194 | if e != nil && err == nil { 195 | err = e 196 | } 197 | } 198 | w.ws = nil 199 | w.mu.Unlock() 200 | return err 201 | } 202 | 203 | func (mw *routerMWriter) addConn(w *Conn) { 204 | mw.mu.Lock() 205 | mw.sem.enable() 206 | mw.ws = append(mw.ws, w) 207 | mw.mu.Unlock() 208 | } 209 | 210 | func (mw *routerMWriter) rmConn(w *Conn) { 211 | mw.mu.Lock() 212 | defer mw.mu.Unlock() 213 | 214 | cur := -1 215 | for i := range mw.ws { 216 | if mw.ws[i] == w { 217 | cur = i 218 | break 219 | } 220 | } 221 | if cur >= 0 { 222 | mw.ws = append(mw.ws[:cur], mw.ws[cur+1:]...) 223 | } 224 | } 225 | 226 | func (w *routerMWriter) write(ctx context.Context, msg Msg) error { 227 | w.sem.lock(ctx) 228 | grp, _ := errgroup.WithContext(ctx) 229 | w.mu.Lock() 230 | id := msg.Frames[0] 231 | dmsg := NewMsgFrom(msg.Frames[1:]...) 232 | for i := range w.ws { 233 | ww := w.ws[i] 234 | pid := []byte(ww.Peer.Meta[sysSockID]) 235 | if !bytes.Equal(pid, id) { 236 | continue 237 | } 238 | grp.Go(func() error { 239 | return ww.SendMsg(dmsg) 240 | }) 241 | } 242 | err := grp.Wait() 243 | w.mu.Unlock() 244 | return err 245 | } 246 | 247 | var ( 248 | _ rpool = (*routerQReader)(nil) 249 | _ wpool = (*routerMWriter)(nil) 250 | _ Socket = (*routerSocket)(nil) 251 | ) 252 | -------------------------------------------------------------------------------- /security.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package zmq4

import (
	"fmt"
	"io"
)

// Security is an interface for ZMTP security mechanisms
type Security interface {
	// Type returns the security mechanism type.
	Type() SecurityType

	// Handshake implements the ZMTP security handshake according to
	// this security mechanism.
	// see:
	//  https://rfc.zeromq.org/spec:23/ZMTP/
	//  https://rfc.zeromq.org/spec:24/ZMTP-PLAIN/
	//  https://rfc.zeromq.org/spec:25/ZMTP-CURVE/
	Handshake(conn *Conn, server bool) error

	// Encrypt writes the encrypted form of data to w.
	Encrypt(w io.Writer, data []byte) (int, error)

	// Decrypt writes the decrypted form of data to w.
	Decrypt(w io.Writer, data []byte) (int, error)
}

// SecurityType denotes types of ZMTP security mechanisms
type SecurityType string

const (
	// NullSecurity is an empty security mechanism
	// that does no authentication nor encryption.
	NullSecurity SecurityType = "NULL"

	// PlainSecurity is a security mechanism that uses
	// plaintext passwords. It is a reference implementation and
	// should not be used for anything important.
	PlainSecurity SecurityType = "PLAIN"

	// CurveSecurity uses ZMQ_CURVE for authentication
	// and encryption.
	CurveSecurity SecurityType = "CURVE"
)

// nullSecurity implements the NULL security mechanism.
type nullSecurity struct{}

// Type returns the security mechanism type.
func (nullSecurity) Type() SecurityType {
	return NullSecurity
}

// Handshake implements the ZMTP security handshake according to
// this security mechanism.
60 | // see: 61 | // 62 | // https://rfc.zeromq.org/spec:23/ZMTP/ 63 | // https://rfc.zeromq.org/spec:24/ZMTP-PLAIN/ 64 | // https://rfc.zeromq.org/spec:25/ZMTP-CURVE/ 65 | func (nullSecurity) Handshake(conn *Conn, server bool) error { 66 | raw, err := conn.Meta.MarshalZMTP() 67 | if err != nil { 68 | return fmt.Errorf("zmq4: could not marshal metadata: %w", err) 69 | } 70 | 71 | err = conn.SendCmd(CmdReady, raw) 72 | if err != nil { 73 | return fmt.Errorf("zmq4: could not send metadata to peer: %w", err) 74 | } 75 | 76 | cmd, err := conn.RecvCmd() 77 | if err != nil { 78 | return fmt.Errorf("zmq4: could not recv metadata from peer: %w", err) 79 | } 80 | 81 | if cmd.Name != CmdReady { 82 | return ErrBadCmd 83 | } 84 | 85 | err = conn.Peer.Meta.UnmarshalZMTP(cmd.Body) 86 | if err != nil { 87 | return fmt.Errorf("zmq4: could not unmarshal peer metadata: %w", err) 88 | } 89 | 90 | return nil 91 | } 92 | 93 | // Encrypt writes the encrypted form of data to w. 94 | func (nullSecurity) Encrypt(w io.Writer, data []byte) (int, error) { 95 | return w.Write(data) 96 | } 97 | 98 | // Decrypt writes the decrypted form of data to w. 99 | func (nullSecurity) Decrypt(w io.Writer, data []byte) (int, error) { 100 | return w.Write(data) 101 | } 102 | 103 | var ( 104 | _ Security = (*nullSecurity)(nil) 105 | ) 106 | -------------------------------------------------------------------------------- /security/null/null.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package null provides the ZeroMQ NULL security mechanism 6 | package null 7 | 8 | import ( 9 | "fmt" 10 | "io" 11 | 12 | "github.com/go-zeromq/zmq4" 13 | ) 14 | 15 | // security implements the NULL security mechanism. 
16 | type security struct{} 17 | 18 | // Security returns a value that implements the NULL security mechanism 19 | func Security() zmq4.Security { 20 | return security{} 21 | } 22 | 23 | // Type returns the security mechanism type. 24 | func (security) Type() zmq4.SecurityType { 25 | return zmq4.NullSecurity 26 | } 27 | 28 | // Handshake implements the ZMTP security handshake according to 29 | // this security mechanism. 30 | // see: 31 | // 32 | // https://rfc.zeromq.org/spec:23/ZMTP/ 33 | // https://rfc.zeromq.org/spec:24/ZMTP-PLAIN/ 34 | // https://rfc.zeromq.org/spec:25/ZMTP-CURVE/ 35 | func (security) Handshake(conn *zmq4.Conn, server bool) error { 36 | raw, err := conn.Meta.MarshalZMTP() 37 | if err != nil { 38 | return fmt.Errorf("security/null: could not marshal metadata: %w", err) 39 | } 40 | 41 | err = conn.SendCmd(zmq4.CmdReady, raw) 42 | if err != nil { 43 | return fmt.Errorf("security/null: could not send metadata to peer: %w", err) 44 | } 45 | 46 | cmd, err := conn.RecvCmd() 47 | if err != nil { 48 | return fmt.Errorf("security/null: could not recv metadata from peer: %w", err) 49 | } 50 | 51 | if cmd.Name != zmq4.CmdReady { 52 | return zmq4.ErrBadCmd 53 | } 54 | 55 | err = conn.Peer.Meta.UnmarshalZMTP(cmd.Body) 56 | if err != nil { 57 | return fmt.Errorf("security/null: could not unmarshal peer metadata: %w", err) 58 | } 59 | 60 | return nil 61 | } 62 | 63 | // Encrypt writes the encrypted form of data to w. 64 | func (security) Encrypt(w io.Writer, data []byte) (int, error) { 65 | return w.Write(data) 66 | } 67 | 68 | // Decrypt writes the decrypted form of data to w. 
69 | func (security) Decrypt(w io.Writer, data []byte) (int, error) { 70 | return w.Write(data) 71 | } 72 | 73 | var ( 74 | _ zmq4.Security = (*security)(nil) 75 | ) 76 | -------------------------------------------------------------------------------- /security/null/null_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package null_test 6 | 7 | import ( 8 | "bytes" 9 | "context" 10 | "fmt" 11 | "os" 12 | "reflect" 13 | "strings" 14 | "testing" 15 | "time" 16 | 17 | "github.com/go-zeromq/zmq4" 18 | "github.com/go-zeromq/zmq4/security/null" 19 | "golang.org/x/sync/errgroup" 20 | ) 21 | 22 | func TestSecurity(t *testing.T) { 23 | sec := null.Security() 24 | if got, want := sec.Type(), zmq4.NullSecurity; got != want { 25 | t.Fatalf("got=%v, want=%v", got, want) 26 | } 27 | 28 | data := []byte("hello world") 29 | wenc := new(bytes.Buffer) 30 | if _, err := sec.Encrypt(wenc, data); err != nil { 31 | t.Fatalf("error encrypting data: %+v", err) 32 | } 33 | 34 | if !bytes.Equal(wenc.Bytes(), data) { 35 | t.Fatalf("error encrypted data.\ngot = %q\nwant= %q\n", wenc.Bytes(), data) 36 | } 37 | 38 | wdec := new(bytes.Buffer) 39 | if _, err := sec.Decrypt(wdec, wenc.Bytes()); err != nil { 40 | t.Fatalf("error decrypting data: %+v", err) 41 | } 42 | 43 | if !bytes.Equal(wdec.Bytes(), data) { 44 | t.Fatalf("error decrypted data.\ngot = %q\nwant= %q\n", wdec.Bytes(), data) 45 | } 46 | } 47 | 48 | func TestHandshakeReqRep(t *testing.T) { 49 | var ( 50 | reqQuit = zmq4.NewMsgString("QUIT") 51 | repQuit = zmq4.NewMsgString("bye") 52 | ) 53 | 54 | sec := null.Security() 55 | ctx, timeout := context.WithTimeout(context.Background(), 10*time.Second) 56 | defer timeout() 57 | 58 | ep := "ipc://ipc-req-rep-null-sec" 59 | cleanUp(ep) 60 | 61 | req := zmq4.NewReq(ctx, 
zmq4.WithSecurity(sec)) 62 | defer req.Close() 63 | 64 | rep := zmq4.NewRep(ctx, zmq4.WithSecurity(sec)) 65 | defer rep.Close() 66 | 67 | grp, _ := errgroup.WithContext(ctx) 68 | grp.Go(func() error { 69 | err := rep.Listen(ep) 70 | if err != nil { 71 | return fmt.Errorf("could not listen: %w", err) 72 | } 73 | 74 | msg, err := rep.Recv() 75 | if err != nil { 76 | return fmt.Errorf("could not recv REQ message: %w", err) 77 | } 78 | 79 | if !reflect.DeepEqual(msg, reqQuit) { 80 | return fmt.Errorf("got = %v, want = %v", msg, repQuit) 81 | } 82 | 83 | err = rep.Send(repQuit) 84 | if err != nil { 85 | return fmt.Errorf("could not send REP message: %w", err) 86 | } 87 | 88 | return nil 89 | }) 90 | 91 | grp.Go(func() error { 92 | err := req.Dial(ep) 93 | if err != nil { 94 | return fmt.Errorf("could not dial: %w", err) 95 | } 96 | 97 | err = req.Send(reqQuit) 98 | if err != nil { 99 | return fmt.Errorf("could not send REQ message: %w", err) 100 | } 101 | return nil 102 | }) 103 | 104 | if err := grp.Wait(); err != nil { 105 | t.Fatalf("error: %+v", err) 106 | } 107 | } 108 | 109 | func cleanUp(ep string) { 110 | if strings.HasPrefix(ep, "ipc://") { 111 | os.Remove(ep[len("ipc://"):]) 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /security/plain/plain.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package plain provides the ZeroMQ PLAIN security mechanism as specified by: 6 | // https://rfc.zeromq.org/spec:24/ZMTP-PLAIN/ 7 | package plain 8 | 9 | import ( 10 | "fmt" 11 | "io" 12 | 13 | "github.com/go-zeromq/zmq4" 14 | ) 15 | 16 | // security implements the PLAIN security mechanism. 
17 | type security struct { 18 | user []byte 19 | pass []byte 20 | } 21 | 22 | // Security returns a value that implements the PLAIN security mechanism 23 | func Security(user, pass string) zmq4.Security { 24 | return &security{[]byte(user), []byte(pass)} 25 | } 26 | 27 | // Type returns the security mechanism type. 28 | func (security) Type() zmq4.SecurityType { 29 | return zmq4.PlainSecurity 30 | } 31 | 32 | // Handshake implements the ZMTP security handshake according to 33 | // this security mechanism. 34 | // see: 35 | // 36 | // https://rfc.zeromq.org/spec:23/ZMTP/ 37 | // https://rfc.zeromq.org/spec:24/ZMTP-PLAIN/ 38 | // https://rfc.zeromq.org/spec:25/ZMTP-CURVE/ 39 | func (sec *security) Handshake(conn *zmq4.Conn, server bool) error { 40 | switch { 41 | case server: 42 | cmd, err := conn.RecvCmd() 43 | if err != nil { 44 | return fmt.Errorf("security/plain: could not receive HELLO from client: %w", err) 45 | } 46 | 47 | if cmd.Name != zmq4.CmdHello { 48 | return fmt.Errorf("security/plain: expected HELLO command") 49 | } 50 | 51 | // FIXME(sbinet): perform a real authentication 52 | err = validateHello(cmd.Body) 53 | if err != nil { 54 | _ = conn.SendCmd(zmq4.CmdError, []byte("invalid")) // FIXME(sbinet) correct ERROR reason 55 | return fmt.Errorf("security/plain: could not authenticate client: %w", err) 56 | } 57 | 58 | err = conn.SendCmd(zmq4.CmdWelcome, nil) 59 | if err != nil { 60 | return fmt.Errorf("security/plain: could not send WELCOME to client: %w", err) 61 | } 62 | 63 | cmd, err = conn.RecvCmd() 64 | if err != nil { 65 | return fmt.Errorf("security/plain: could not receive INITIATE from client: %w", err) 66 | } 67 | 68 | err = conn.Peer.Meta.UnmarshalZMTP(cmd.Body) 69 | if err != nil { 70 | return fmt.Errorf("security/plain: could not unmarshal peer metadata: %w", err) 71 | } 72 | 73 | raw, err := conn.Meta.MarshalZMTP() 74 | if err != nil { 75 | _ = conn.SendCmd(zmq4.CmdError, []byte("invalid")) // FIXME(sbinet) correct ERROR reason 76 | return 
fmt.Errorf("security/plain: could not serialize metadata: %w", err) 77 | } 78 | 79 | err = conn.SendCmd(zmq4.CmdReady, raw) 80 | if err != nil { 81 | return fmt.Errorf("security/plain: could not send READY to client: %w", err) 82 | } 83 | 84 | case !server: 85 | hello := make([]byte, 0, len(sec.user)+len(sec.pass)+2) 86 | hello = append(hello, byte(len(sec.user))) 87 | hello = append(hello, sec.user...) 88 | hello = append(hello, byte(len(sec.pass))) 89 | hello = append(hello, sec.pass...) 90 | 91 | err := conn.SendCmd(zmq4.CmdHello, hello) 92 | if err != nil { 93 | return fmt.Errorf("security/plain: could not send HELLO to server: %w", err) 94 | } 95 | 96 | cmd, err := conn.RecvCmd() 97 | if err != nil { 98 | return fmt.Errorf("security/plain: could not receive WELCOME from server: %w", err) 99 | } 100 | if cmd.Name != zmq4.CmdWelcome { 101 | _ = conn.SendCmd(zmq4.CmdError, []byte("invalid command")) // FIXME(sbinet) correct ERROR reason 102 | return fmt.Errorf("security/plain: expected a WELCOME command from server: %w", err) 103 | } 104 | 105 | raw, err := conn.Meta.MarshalZMTP() 106 | if err != nil { 107 | _ = conn.SendCmd(zmq4.CmdError, []byte("internal error")) // FIXME(sbinet) correct ERROR reason 108 | return fmt.Errorf("security/plain: could not serialize metadata: %w", err) 109 | } 110 | 111 | err = conn.SendCmd(zmq4.CmdInitiate, raw) 112 | if err != nil { 113 | return fmt.Errorf("security/plain: could not send INITIATE to server: %w", err) 114 | } 115 | 116 | cmd, err = conn.RecvCmd() 117 | if err != nil { 118 | return fmt.Errorf("security/plain: could not receive READY from server: %w", err) 119 | } 120 | if cmd.Name != zmq4.CmdReady { 121 | _ = conn.SendCmd(zmq4.CmdError, []byte("invalid command")) // FIXME(sbinet) correct ERROR reason 122 | return fmt.Errorf("security/plain: expected a READY command from server: %w", err) 123 | } 124 | 125 | err = conn.Peer.Meta.UnmarshalZMTP(cmd.Body) 126 | if err != nil { 127 | return fmt.Errorf("security/plain: 
could not unmarshal peer metadata: %w", err) 128 | } 129 | 130 | sec.user = nil 131 | sec.pass = nil 132 | } 133 | return nil 134 | } 135 | 136 | // Encrypt writes the encrypted form of data to w. 137 | func (security) Encrypt(w io.Writer, data []byte) (int, error) { 138 | return w.Write(data) 139 | } 140 | 141 | // Decrypt writes the decrypted form of data to w. 142 | func (security) Decrypt(w io.Writer, data []byte) (int, error) { 143 | return w.Write(data) 144 | } 145 | 146 | // validateHello validates the user/passwd credentials. 147 | func validateHello(body []byte) error { 148 | // n := int(body[0]) 149 | // user := body[1 : 1+n] 150 | // body = body[1+n:] 151 | // n = int(body[0]) 152 | // pass := body[1 : 1+n] 153 | // body = body[1+n:] 154 | // log.Printf("user=%q, pass=%q, body=%q", user, pass, body) 155 | return nil 156 | } 157 | 158 | var ( 159 | _ zmq4.Security = (*security)(nil) 160 | ) 161 | -------------------------------------------------------------------------------- /security/plain/plain_cxx_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | //go:build czmq4 6 | // +build czmq4 7 | 8 | package plain_test 9 | 10 | import ( 11 | "context" 12 | "fmt" 13 | "os" 14 | "reflect" 15 | "testing" 16 | "time" 17 | 18 | czmq4 "github.com/go-zeromq/goczmq/v4" 19 | "github.com/go-zeromq/zmq4" 20 | "github.com/go-zeromq/zmq4/security/plain" 21 | "golang.org/x/sync/errgroup" 22 | ) 23 | 24 | var ( 25 | repQuit = zmq4.NewMsgString("bye") 26 | ) 27 | 28 | func TestMain(m *testing.M) { 29 | auth := czmq4.NewAuth() 30 | 31 | err := auth.Allow("127.0.0.1") 32 | if err != nil { 33 | auth.Destroy() 34 | panic(err) 35 | } 36 | 37 | err = auth.Plain("./testdata/password.txt") 38 | if err != nil { 39 | auth.Destroy() 40 | panic(err) 41 | } 42 | 43 | // call flag.Parse() here if TestMain uses flags 44 | 45 | exit := m.Run() 46 | 47 | auth.Destroy() 48 | os.Exit(exit) 49 | } 50 | 51 | func TestHandshakeReqCRep(t *testing.T) { 52 | t.Skipf("REQ-CREP") 53 | 54 | sec := plain.Security("user", "secret") 55 | if got, want := sec.Type(), zmq4.PlainSecurity; got != want { 56 | t.Fatalf("got=%v, want=%v", got, want) 57 | } 58 | 59 | ctx, timeout := context.WithTimeout(context.Background(), 10*time.Second) 60 | defer timeout() 61 | 62 | ep := must(EndPoint("tcp")) 63 | 64 | req := zmq4.NewReq(ctx, zmq4.WithSecurity(sec)) 65 | defer req.Close() 66 | 67 | rep := zmq4.NewCRep(ctx, czmq4.SockSetZapDomain("global"), czmq4.SockSetPlainServer(1)) 68 | defer rep.Close() 69 | 70 | grp, ctx := errgroup.WithContext(ctx) 71 | grp.Go(func() error { 72 | err := rep.Listen(ep) 73 | if err != nil { 74 | return fmt.Errorf("could not listen: %w", err) 75 | } 76 | 77 | msg, err := rep.Recv() 78 | if err != nil { 79 | return fmt.Errorf("could not recv REQ message: %w", err) 80 | } 81 | 82 | if !reflect.DeepEqual(msg, reqQuit) { 83 | return fmt.Errorf("got = %v, want = %v", msg, repQuit) 84 | } 85 | 86 | err = rep.Send(repQuit) 87 | if err != nil { 88 | return fmt.Errorf("could not send REP message: %w", err) 89 | } 90 | 91 | return nil 92 | }) 93 | 
94 | grp.Go(func() error { 95 | err := req.Dial(ep) 96 | if err != nil { 97 | return fmt.Errorf("could not dial: %w", err) 98 | } 99 | 100 | err = req.Send(reqQuit) 101 | if err != nil { 102 | return fmt.Errorf("could not send REQ message: %w", err) 103 | } 104 | msg, err := req.Recv() 105 | if err != nil { 106 | return fmt.Errorf("could not recv REQ message: %w", err) 107 | } 108 | 109 | if !reflect.DeepEqual(msg, repQuit) { 110 | return fmt.Errorf("got = %v, want = %v", msg, repQuit) 111 | } 112 | return nil 113 | }) 114 | 115 | if err := grp.Wait(); err != nil { 116 | t.Fatalf("error: %+v", err) 117 | } 118 | } 119 | 120 | func TestHandshakeCReqRep(t *testing.T) { 121 | t.Skipf("CREQ-REP") 122 | 123 | sec := plain.Security("user", "secret") 124 | if got, want := sec.Type(), zmq4.PlainSecurity; got != want { 125 | t.Fatalf("got=%v, want=%v", got, want) 126 | } 127 | 128 | ctx, timeout := context.WithTimeout(context.Background(), 10*time.Second) 129 | defer timeout() 130 | 131 | ep := must(EndPoint("tcp")) 132 | 133 | req := zmq4.NewCReq(ctx, czmq4.SockSetPlainUsername("user"), czmq4.SockSetPlainPassword("secret")) 134 | defer req.Close() 135 | 136 | rep := zmq4.NewRep(ctx, zmq4.WithSecurity(sec)) 137 | defer rep.Close() 138 | 139 | grp, ctx := errgroup.WithContext(ctx) 140 | grp.Go(func() error { 141 | err := rep.Listen(ep) 142 | if err != nil { 143 | return fmt.Errorf("could not listen: %w", err) 144 | } 145 | 146 | msg, err := rep.Recv() 147 | if err != nil { 148 | return fmt.Errorf("could not recv REQ message: %w", err) 149 | } 150 | 151 | if !reflect.DeepEqual(msg, reqQuit) { 152 | return fmt.Errorf("got = %v, want = %v", msg, repQuit) 153 | } 154 | 155 | err = rep.Send(repQuit) 156 | if err != nil { 157 | return fmt.Errorf("could not send REP message: %w", err) 158 | } 159 | 160 | return nil 161 | }) 162 | 163 | grp.Go(func() error { 164 | err := req.Dial(ep) 165 | if err != nil { 166 | return fmt.Errorf("could not dial: %w", err) 167 | } 168 | 169 | err = 
req.Send(reqQuit) 170 | if err != nil { 171 | return fmt.Errorf("could not send REQ message: %w", err) 172 | } 173 | msg, err := req.Recv() 174 | if err != nil { 175 | return fmt.Errorf("could not recv REQ message: %w", err) 176 | } 177 | 178 | if !reflect.DeepEqual(msg, repQuit) { 179 | return fmt.Errorf("got = %v, want = %v", msg, repQuit) 180 | } 181 | return nil 182 | }) 183 | 184 | if err := grp.Wait(); err != nil { 185 | t.Fatalf("error: %+v", err) 186 | } 187 | } 188 | 189 | func TestHandshakeCReqCRep(t *testing.T) { 190 | t.Skipf("CREQ-CREP") 191 | 192 | sec := plain.Security("user", "secret") 193 | if got, want := sec.Type(), zmq4.PlainSecurity; got != want { 194 | t.Fatalf("got=%v, want=%v", got, want) 195 | } 196 | 197 | ctx, timeout := context.WithTimeout(context.Background(), 10*time.Second) 198 | defer timeout() 199 | 200 | ep := must(EndPoint("tcp")) 201 | 202 | req := zmq4.NewCReq(ctx, czmq4.SockSetPlainUsername("user"), czmq4.SockSetPlainPassword("secret")) 203 | defer req.Close() 204 | 205 | rep := zmq4.NewCRep(ctx, czmq4.SockSetZapDomain("global"), czmq4.SockSetPlainServer(1)) 206 | defer rep.Close() 207 | 208 | grp, ctx := errgroup.WithContext(ctx) 209 | grp.Go(func() error { 210 | err := rep.Listen(ep) 211 | if err != nil { 212 | return fmt.Errorf("could not listen: %w", err) 213 | } 214 | 215 | msg, err := rep.Recv() 216 | if err != nil { 217 | return fmt.Errorf("could not recv REQ message: %w", err) 218 | } 219 | 220 | if !reflect.DeepEqual(msg, reqQuit) { 221 | return fmt.Errorf("got = %v, want = %v", msg, repQuit) 222 | } 223 | 224 | err = rep.Send(repQuit) 225 | if err != nil { 226 | return fmt.Errorf("could not send REP message: %w", err) 227 | } 228 | 229 | return nil 230 | }) 231 | 232 | grp.Go(func() error { 233 | err := req.Dial(ep) 234 | if err != nil { 235 | return fmt.Errorf("could not dial: %w", err) 236 | } 237 | 238 | err = req.Send(reqQuit) 239 | if err != nil { 240 | return fmt.Errorf("could not send REQ message: %w", err) 241 
| } 242 | msg, err := req.Recv() 243 | if err != nil { 244 | return fmt.Errorf("could not recv REQ message: %w", err) 245 | } 246 | 247 | if !reflect.DeepEqual(msg, repQuit) { 248 | return fmt.Errorf("got = %v, want = %v", msg, repQuit) 249 | } 250 | return nil 251 | }) 252 | 253 | if err := grp.Wait(); err != nil { 254 | t.Fatalf("error: %+v", err) 255 | } 256 | } 257 | -------------------------------------------------------------------------------- /security/plain/plain_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package plain_test 6 | 7 | import ( 8 | "bytes" 9 | "context" 10 | "crypto/rand" 11 | "fmt" 12 | "io" 13 | "log" 14 | "net" 15 | "testing" 16 | "time" 17 | 18 | "github.com/go-zeromq/zmq4" 19 | "github.com/go-zeromq/zmq4/security/plain" 20 | "golang.org/x/sync/errgroup" 21 | ) 22 | 23 | var ( 24 | reqQuit = zmq4.NewMsgString("QUIT") 25 | ) 26 | 27 | func TestSecurity(t *testing.T) { 28 | sec := plain.Security("user", "secret") 29 | if got, want := sec.Type(), zmq4.PlainSecurity; got != want { 30 | t.Fatalf("got=%v, want=%v", got, want) 31 | } 32 | 33 | data := []byte("hello world") 34 | wenc := new(bytes.Buffer) 35 | if _, err := sec.Encrypt(wenc, data); err != nil { 36 | t.Fatalf("error encrypting data: %+v", err) 37 | } 38 | 39 | if !bytes.Equal(wenc.Bytes(), data) { 40 | t.Fatalf("error encrypted data.\ngot = %q\nwant= %q\n", wenc.Bytes(), data) 41 | } 42 | 43 | wdec := new(bytes.Buffer) 44 | if _, err := sec.Decrypt(wdec, wenc.Bytes()); err != nil { 45 | t.Fatalf("error decrypting data: %+v", err) 46 | } 47 | 48 | if !bytes.Equal(wdec.Bytes(), data) { 49 | t.Fatalf("error decrypted data.\ngot = %q\nwant= %q\n", wdec.Bytes(), data) 50 | } 51 | } 52 | 53 | func TestHandshakeReqRep(t *testing.T) { 54 | sec := 
plain.Security("user", "secret") 55 | if got, want := sec.Type(), zmq4.PlainSecurity; got != want { 56 | t.Fatalf("got=%v, want=%v", got, want) 57 | } 58 | 59 | ctx, timeout := context.WithTimeout(context.Background(), 10*time.Second) 60 | defer timeout() 61 | 62 | ep := must(EndPoint("tcp")) 63 | 64 | req := zmq4.NewReq(ctx, zmq4.WithSecurity(sec)) 65 | defer req.Close() 66 | 67 | rep := zmq4.NewRep(ctx, zmq4.WithSecurity(sec)) 68 | defer rep.Close() 69 | 70 | grp, _ := errgroup.WithContext(ctx) 71 | grp.Go(func() error { 72 | err := rep.Listen(ep) 73 | if err != nil { 74 | return fmt.Errorf("could not listen: %w", err) 75 | } 76 | 77 | msg, err := rep.Recv() 78 | if err != nil { 79 | return fmt.Errorf("could not recv REQ message: %w", err) 80 | } 81 | if string(msg.Frames[0]) != "QUIT" { 82 | return fmt.Errorf("received wrong REQ message: %#v", msg) 83 | } 84 | return nil 85 | }) 86 | 87 | grp.Go(func() error { 88 | err := req.Dial(ep) 89 | if err != nil { 90 | return fmt.Errorf("could not dial: %w", err) 91 | } 92 | 93 | err = req.Send(reqQuit) 94 | if err != nil { 95 | return fmt.Errorf("could not send REQ message: %w", err) 96 | } 97 | return nil 98 | }) 99 | 100 | if err := grp.Wait(); err != nil { 101 | t.Fatalf("error: %+v", err) 102 | } 103 | } 104 | 105 | func must(str string, err error) string { 106 | if err != nil { 107 | panic(err) 108 | } 109 | return str 110 | } 111 | 112 | func EndPoint(transport string) (string, error) { 113 | switch transport { 114 | case "tcp": 115 | addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0") 116 | if err != nil { 117 | return "", fmt.Errorf("could not resolve TCP address: %w", err) 118 | } 119 | l, err := net.ListenTCP("tcp", addr) 120 | if err != nil { 121 | return "", fmt.Errorf("could not listen to TCP addr=%q: %w", addr, err) 122 | } 123 | defer l.Close() 124 | return fmt.Sprintf("tcp://%s", l.Addr()), nil 125 | case "ipc": 126 | return "ipc://tmp-" + newUUID(), nil 127 | case "inproc": 128 | return 
"inproc://tmp-" + newUUID(), nil 129 | default: 130 | panic("invalid transport: [" + transport + "]") 131 | } 132 | } 133 | 134 | func newUUID() string { 135 | var uuid [16]byte 136 | if _, err := io.ReadFull(rand.Reader, uuid[:]); err != nil { 137 | log.Fatalf("cannot generate random data for UUID: %v", err) 138 | } 139 | uuid[8] = uuid[8]&^0xc0 | 0x80 140 | uuid[6] = uuid[6]&^0xf0 | 0x40 141 | return fmt.Sprintf("%x-%x-%x-%x-%x", uuid[:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:]) 142 | } 143 | -------------------------------------------------------------------------------- /security/plain/testdata/password.txt: -------------------------------------------------------------------------------- 1 | user=secret 2 | 3 | -------------------------------------------------------------------------------- /security_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package zmq4 6 | 7 | import ( 8 | "bytes" 9 | "context" 10 | "fmt" 11 | "os" 12 | "reflect" 13 | "strings" 14 | "testing" 15 | "time" 16 | 17 | "golang.org/x/sync/errgroup" 18 | ) 19 | 20 | func TestNullSecurity(t *testing.T) { 21 | sec := nullSecurity{} 22 | if got, want := sec.Type(), NullSecurity; got != want { 23 | t.Fatalf("got=%v, want=%v", got, want) 24 | } 25 | 26 | data := []byte("hello world") 27 | wenc := new(bytes.Buffer) 28 | if _, err := sec.Encrypt(wenc, data); err != nil { 29 | t.Fatalf("error encrypting data: %+v", err) 30 | } 31 | 32 | if !bytes.Equal(wenc.Bytes(), data) { 33 | t.Fatalf("error encrypted data.\ngot = %q\nwant= %q\n", wenc.Bytes(), data) 34 | } 35 | 36 | wdec := new(bytes.Buffer) 37 | if _, err := sec.Decrypt(wdec, wenc.Bytes()); err != nil { 38 | t.Fatalf("error decrypting data: %+v", err) 39 | } 40 | 41 | if !bytes.Equal(wdec.Bytes(), data) { 42 | t.Fatalf("error decrypted data.\ngot = %q\nwant= %q\n", wdec.Bytes(), data) 43 | } 44 | } 45 | 46 | func TestNullHandshakeReqRep(t *testing.T) { 47 | var ( 48 | reqQuit = NewMsgString("QUIT") 49 | repQuit = NewMsgString("bye") 50 | ) 51 | 52 | sec := nullSecurity{} 53 | ctx, timeout := context.WithTimeout(context.Background(), 10*time.Second) 54 | defer timeout() 55 | 56 | ep := "ipc://ipc-req-rep-null-sec" 57 | cleanUp(ep) 58 | 59 | req := NewReq(ctx, WithSecurity(sec), WithLogger(Devnull)) 60 | defer req.Close() 61 | 62 | rep := NewRep(ctx, WithSecurity(sec), WithLogger(Devnull)) 63 | defer rep.Close() 64 | 65 | grp, _ := errgroup.WithContext(ctx) 66 | grp.Go(func() error { 67 | err := rep.Listen(ep) 68 | if err != nil { 69 | return fmt.Errorf("could not listen: %w", err) 70 | } 71 | 72 | msg, err := rep.Recv() 73 | if err != nil { 74 | return fmt.Errorf("could not recv REQ message: %w", err) 75 | } 76 | 77 | if !reflect.DeepEqual(msg, reqQuit) { 78 | return fmt.Errorf("got = %v, want = %v", msg, repQuit) 79 | } 80 | 81 | err = rep.Send(repQuit) 82 | if err != nil { 83 | return 
fmt.Errorf("could not send REP message: %w", err) 84 | } 85 | 86 | return nil 87 | }) 88 | 89 | grp.Go(func() error { 90 | err := req.Dial(ep) 91 | if err != nil { 92 | return fmt.Errorf("could not dial: %w", err) 93 | } 94 | 95 | err = req.Send(reqQuit) 96 | if err != nil { 97 | return fmt.Errorf("could not send REQ message: %w", err) 98 | } 99 | return nil 100 | }) 101 | 102 | if err := grp.Wait(); err != nil { 103 | t.Fatalf("error: %+v", err) 104 | } 105 | } 106 | 107 | func cleanUp(ep string) { 108 | if strings.HasPrefix(ep, "ipc://") { 109 | os.Remove(ep[len("ipc://"):]) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /socket_types.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | // SocketType is a ZeroMQ socket type. 8 | type SocketType string 9 | 10 | const ( 11 | Pair SocketType = "PAIR" // a ZMQ_PAIR socket 12 | Pub SocketType = "PUB" // a ZMQ_PUB socket 13 | Sub SocketType = "SUB" // a ZMQ_SUB socket 14 | Req SocketType = "REQ" // a ZMQ_REQ socket 15 | Rep SocketType = "REP" // a ZMQ_REP socket 16 | Dealer SocketType = "DEALER" // a ZMQ_DEALER socket 17 | Router SocketType = "ROUTER" // a ZMQ_ROUTER socket 18 | Pull SocketType = "PULL" // a ZMQ_PULL socket 19 | Push SocketType = "PUSH" // a ZMQ_PUSH socket 20 | XPub SocketType = "XPUB" // a ZMQ_XPUB socket 21 | XSub SocketType = "XSUB" // a ZMQ_XSUB socket 22 | ) 23 | 24 | // IsCompatible checks whether two sockets are compatible and thus 25 | // can be connected together. 26 | // See https://rfc.zeromq.org/spec:23/ZMTP/ for more informations. 
27 | func (sck SocketType) IsCompatible(peer SocketType) bool { 28 | switch sck { 29 | case Pair: 30 | if peer == Pair { 31 | return true 32 | } 33 | case Pub: 34 | switch peer { 35 | case Sub, XSub: 36 | return true 37 | } 38 | case Sub: 39 | switch peer { 40 | case Pub, XPub: 41 | return true 42 | } 43 | case Req: 44 | switch peer { 45 | case Rep, Router: 46 | return true 47 | } 48 | case Rep: 49 | switch peer { 50 | case Req, Dealer: 51 | return true 52 | } 53 | case Dealer: 54 | switch peer { 55 | case Rep, Dealer, Router: 56 | return true 57 | } 58 | case Router: 59 | switch peer { 60 | case Req, Dealer, Router: 61 | return true 62 | } 63 | case Pull: 64 | switch peer { 65 | case Push: 66 | return true 67 | } 68 | case Push: 69 | switch peer { 70 | case Pull: 71 | return true 72 | } 73 | case XPub: 74 | switch peer { 75 | case Sub, XSub: 76 | return true 77 | } 78 | case XSub: 79 | switch peer { 80 | case Pub, XPub: 81 | return true 82 | } 83 | default: 84 | panic("unknown socket-type: \"" + string(sck) + "\"") 85 | } 86 | 87 | return false 88 | } 89 | 90 | // SocketIdentity is the ZMTP metadata socket identity. 91 | // See: 92 | // 93 | // https://rfc.zeromq.org/spec:23/ZMTP/. 94 | type SocketIdentity []byte 95 | 96 | func (id SocketIdentity) String() string { 97 | n := len(id) 98 | if n > 255 { // ZMTP identities are: 0*255OCTET 99 | n = 255 100 | } 101 | return string(id[:n]) 102 | } 103 | -------------------------------------------------------------------------------- /sub.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "context" 9 | "net" 10 | "sort" 11 | "sync" 12 | ) 13 | 14 | // NewSub returns a new SUB ZeroMQ socket. 15 | // The returned socket value is initially unbound. 
16 | func NewSub(ctx context.Context, opts ...Option) Socket { 17 | sub := &subSocket{sck: newSocket(ctx, Sub, opts...)} 18 | sub.sck.r = newQReader(sub.sck.ctx) 19 | sub.sck.subTopics = sub.Topics 20 | sub.topics = make(map[string]struct{}) 21 | return sub 22 | } 23 | 24 | // subSocket is a SUB ZeroMQ socket. 25 | type subSocket struct { 26 | sck *socket 27 | 28 | mu sync.RWMutex 29 | topics map[string]struct{} 30 | } 31 | 32 | // Close closes the open Socket 33 | func (sub *subSocket) Close() error { 34 | return sub.sck.Close() 35 | } 36 | 37 | // Send puts the message on the outbound send queue. 38 | // Send blocks until the message can be queued or the send deadline expires. 39 | func (sub *subSocket) Send(msg Msg) error { 40 | return sub.sck.Send(msg) 41 | } 42 | 43 | // SendMulti puts the message on the outbound send queue. 44 | // SendMulti blocks until the message can be queued or the send deadline expires. 45 | // The message will be sent as a multipart message. 46 | func (sub *subSocket) SendMulti(msg Msg) error { 47 | return sub.sck.SendMulti(msg) 48 | } 49 | 50 | // Recv receives a complete message. 51 | func (sub *subSocket) Recv() (Msg, error) { 52 | return sub.sck.Recv() 53 | } 54 | 55 | // Listen connects a local endpoint to the Socket. 56 | func (sub *subSocket) Listen(ep string) error { 57 | return sub.sck.Listen(ep) 58 | } 59 | 60 | // Dial connects a remote endpoint to the Socket. 61 | func (sub *subSocket) Dial(ep string) error { 62 | err := sub.sck.Dial(ep) 63 | if err != nil { 64 | return err 65 | } 66 | return nil 67 | } 68 | 69 | // Type returns the type of this Socket (PUB, SUB, ...) 70 | func (sub *subSocket) Type() SocketType { 71 | return sub.sck.Type() 72 | } 73 | 74 | // Addr returns the listener's address. 75 | // Addr returns nil if the socket isn't a listener. 76 | func (sub *subSocket) Addr() net.Addr { 77 | return sub.sck.Addr() 78 | } 79 | 80 | // GetOption is used to retrieve an option for a socket. 
81 | func (sub *subSocket) GetOption(name string) (interface{}, error) { 82 | return sub.sck.GetOption(name) 83 | } 84 | 85 | // SetOption is used to set an option for a socket. 86 | func (sub *subSocket) SetOption(name string, value interface{}) error { 87 | err := sub.sck.SetOption(name, value) 88 | if err != nil { 89 | return err 90 | } 91 | 92 | var ( 93 | topic []byte 94 | ) 95 | 96 | switch name { 97 | case OptionSubscribe: 98 | k := value.(string) 99 | sub.subscribe(k, 1) 100 | topic = append([]byte{1}, k...) 101 | 102 | case OptionUnsubscribe: 103 | k := value.(string) 104 | topic = append([]byte{0}, k...) 105 | sub.subscribe(k, 0) 106 | 107 | default: 108 | return ErrBadProperty 109 | } 110 | 111 | sub.sck.mu.RLock() 112 | if len(sub.sck.conns) > 0 { 113 | err = sub.Send(NewMsg(topic)) 114 | } 115 | sub.sck.mu.RUnlock() 116 | return err 117 | } 118 | 119 | // Topics returns the sorted list of topics a socket is subscribed to. 120 | func (sub *subSocket) Topics() []string { 121 | sub.mu.RLock() 122 | var topics = make([]string, 0, len(sub.topics)) 123 | for topic := range sub.topics { 124 | topics = append(topics, topic) 125 | } 126 | sub.mu.RUnlock() 127 | sort.Strings(topics) 128 | return topics 129 | } 130 | 131 | func (sub *subSocket) subscribe(topic string, v int) { 132 | sub.mu.Lock() 133 | switch v { 134 | case 0: 135 | delete(sub.topics, topic) 136 | case 1: 137 | sub.topics[topic] = struct{}{} 138 | } 139 | sub.mu.Unlock() 140 | } 141 | 142 | var ( 143 | _ Socket = (*subSocket)(nil) 144 | _ Topics = (*subSocket)(nil) 145 | ) 146 | -------------------------------------------------------------------------------- /transport.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package zmq4 6 | 7 | import ( 8 | "fmt" 9 | "sort" 10 | "sync" 11 | 12 | "github.com/go-zeromq/zmq4/internal/inproc" 13 | "github.com/go-zeromq/zmq4/transport" 14 | ) 15 | 16 | // UnknownTransportError records an error when trying to 17 | // use an unknown transport. 18 | type UnknownTransportError struct { 19 | Name string 20 | } 21 | 22 | func (ute UnknownTransportError) Error() string { 23 | return fmt.Sprintf("zmq4: unknown transport %q", ute.Name) 24 | } 25 | 26 | var _ error = (*UnknownTransportError)(nil) 27 | 28 | // Transports returns the sorted list of currently registered transports. 29 | func Transports() []string { 30 | return drivers.names() 31 | } 32 | 33 | // RegisterTransport registers a new transport with the zmq4 package. 34 | func RegisterTransport(name string, trans transport.Transport) error { 35 | return drivers.add(name, trans) 36 | } 37 | 38 | type transports struct { 39 | sync.RWMutex 40 | db map[string]transport.Transport 41 | } 42 | 43 | func (ts *transports) get(name string) (transport.Transport, bool) { 44 | ts.RLock() 45 | defer ts.RUnlock() 46 | 47 | v, ok := ts.db[name] 48 | return v, ok 49 | } 50 | 51 | func (ts *transports) add(name string, trans transport.Transport) error { 52 | ts.Lock() 53 | defer ts.Unlock() 54 | 55 | if old, dup := ts.db[name]; dup { 56 | return fmt.Errorf("zmq4: duplicate transport %q (%T)", name, old) 57 | } 58 | 59 | ts.db[name] = trans 60 | return nil 61 | } 62 | 63 | func (ts *transports) names() []string { 64 | ts.RLock() 65 | defer ts.RUnlock() 66 | 67 | o := make([]string, 0, len(ts.db)) 68 | for k := range ts.db { 69 | o = append(o, k) 70 | } 71 | sort.Strings(o) 72 | return o 73 | } 74 | 75 | var drivers = transports{ 76 | db: make(map[string]transport.Transport), 77 | } 78 | 79 | func init() { 80 | must := func(err error) { 81 | if err != nil { 82 | panic(fmt.Errorf("%+v", err)) 83 | } 84 | } 85 | 86 | must(RegisterTransport("ipc", transport.New("unix"))) 87 | must(RegisterTransport("tcp", 
transport.New("tcp"))) 88 | must(RegisterTransport("udp", transport.New("udp"))) 89 | must(RegisterTransport("inproc", inproc.Transport{})) 90 | } 91 | -------------------------------------------------------------------------------- /transport/transport.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package transport defines the Transport interface and provides a net-based 6 | // implementation that can be used by zmq4 sockets to exchange messages. 7 | package transport // import "github.com/go-zeromq/zmq4/transport" 8 | 9 | import ( 10 | "context" 11 | "fmt" 12 | "net" 13 | ) 14 | 15 | // Dialer is the interface that wraps the DialContext method. 16 | type Dialer interface { 17 | DialContext(ctx context.Context, network, address string) (net.Conn, error) 18 | } 19 | 20 | // Transport is the zmq4 transport interface that wraps 21 | // the Dial and Listen methods. 22 | type Transport interface { 23 | // Dial connects to the address on the named network using the provided 24 | // context. 25 | Dial(ctx context.Context, dialer Dialer, addr string) (net.Conn, error) 26 | 27 | // Listen announces on the provided network address. 28 | Listen(ctx context.Context, addr string) (net.Listener, error) 29 | 30 | // Addr returns the end-point address. 31 | Addr(ep string) (addr string, err error) 32 | } 33 | 34 | // netTransport implements the Transport interface using the net package. 35 | type netTransport struct { 36 | prot string 37 | } 38 | 39 | // New returns a new net-based transport with the given network (e.g "tcp"). 40 | func New(network string) Transport { 41 | return netTransport{prot: network} 42 | } 43 | 44 | // Dial connects to the address on the named network using the provided 45 | // context. 
46 | func (trans netTransport) Dial(ctx context.Context, dialer Dialer, addr string) (net.Conn, error) { 47 | return dialer.DialContext(ctx, trans.prot, addr) 48 | } 49 | 50 | // Listen announces on the provided network address. 51 | func (trans netTransport) Listen(ctx context.Context, addr string) (net.Listener, error) { 52 | return net.Listen(trans.prot, addr) 53 | } 54 | 55 | // Addr returns the end-point address. 56 | func (trans netTransport) Addr(ep string) (addr string, err error) { 57 | switch trans.prot { 58 | case "tcp", "udp": 59 | host, port, err := net.SplitHostPort(ep) 60 | if err != nil { 61 | return addr, err 62 | } 63 | switch port { 64 | case "0", "*", "": 65 | port = "0" 66 | } 67 | switch host { 68 | case "", "*": 69 | host = "0.0.0.0" 70 | } 71 | addr = net.JoinHostPort(host, port) 72 | return addr, err 73 | 74 | case "unix": 75 | return ep, nil 76 | 77 | default: 78 | err = fmt.Errorf("zmq4: unknown protocol %q", trans.prot) 79 | } 80 | 81 | return addr, err 82 | } 83 | 84 | var ( 85 | _ Transport = (*netTransport)(nil) 86 | ) 87 | -------------------------------------------------------------------------------- /transport_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package zmq4 6 | 7 | import ( 8 | "reflect" 9 | "testing" 10 | 11 | "github.com/go-zeromq/zmq4/internal/inproc" 12 | ) 13 | 14 | func TestTransport(t *testing.T) { 15 | if got, want := Transports(), []string{"inproc", "ipc", "tcp", "udp"}; !reflect.DeepEqual(got, want) { 16 | t.Fatalf("invalid list of transports.\ngot= %q\nwant=%q", got, want) 17 | } 18 | 19 | err := RegisterTransport("tcp", inproc.Transport{}) 20 | if err == nil { 21 | t.Fatalf("expected a duplicate-registration error") 22 | } 23 | if got, want := err.Error(), "zmq4: duplicate transport \"tcp\" (transport.netTransport)"; got != want { 24 | t.Fatalf("invalid duplicate registration error:\ngot= %s\nwant=%s", got, want) 25 | } 26 | 27 | err = RegisterTransport("inproc2", inproc.Transport{}) 28 | if err != nil { 29 | t.Fatalf("could not register 'inproc2': %+v", err) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /utils.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package zmq4 6 | 7 | import ( 8 | "crypto/rand" 9 | "fmt" 10 | "io" 11 | "log" 12 | "strings" 13 | ) 14 | 15 | // splitAddr returns the triplet (network, addr, error) 16 | func splitAddr(v string) (network, addr string, err error) { 17 | ep := strings.Split(v, "://") 18 | if len(ep) != 2 { 19 | err = errInvalidAddress 20 | return network, addr, err 21 | } 22 | network = ep[0] 23 | 24 | trans, ok := drivers.get(network) 25 | if !ok { 26 | err = fmt.Errorf("zmq4: unknown transport %q", network) 27 | return network, addr, err 28 | } 29 | 30 | addr, err = trans.Addr(ep[1]) 31 | return network, addr, err 32 | } 33 | 34 | func newUUID() string { 35 | var uuid [16]byte 36 | if _, err := io.ReadFull(rand.Reader, uuid[:]); err != nil { 37 | log.Fatalf("cannot generate random data for UUID: %v", err) 38 | } 39 | uuid[8] = uuid[8]&^0xc0 | 0x80 40 | uuid[6] = uuid[6]&^0xf0 | 0x40 41 | return fmt.Sprintf("%x-%x-%x-%x-%x", uuid[:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:]) 42 | } 43 | -------------------------------------------------------------------------------- /utils_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package zmq4 6 | 7 | import ( 8 | "fmt" 9 | "testing" 10 | ) 11 | 12 | func TestSplitAddr(t *testing.T) { 13 | testCases := []struct { 14 | desc string 15 | v string 16 | network string 17 | addr string 18 | err error 19 | }{ 20 | { 21 | desc: "tcp wild", 22 | v: "tcp://*:5000", 23 | network: "tcp", 24 | addr: "0.0.0.0:5000", 25 | err: nil, 26 | }, 27 | { 28 | desc: "tcp ipv4", 29 | v: "tcp://127.0.0.1:6000", 30 | network: "tcp", 31 | addr: "127.0.0.1:6000", 32 | err: nil, 33 | }, 34 | { 35 | desc: "tcp ipv6", 36 | v: "tcp://[::1]:7000", 37 | network: "tcp", 38 | addr: "[::1]:7000", 39 | err: nil, 40 | }, 41 | { 42 | desc: "ipc", 43 | v: "ipc://some-ep", 44 | network: "ipc", 45 | addr: "some-ep", 46 | err: nil, 47 | }, 48 | { 49 | desc: "inproc", 50 | v: "inproc://some-ep", 51 | network: "inproc", 52 | addr: "some-ep", 53 | err: nil, 54 | }, 55 | } 56 | for _, tc := range testCases { 57 | t.Run(tc.desc, func(t *testing.T) { 58 | network, addr, err := splitAddr(tc.v) 59 | if network != tc.network { 60 | t.Fatalf("unexpected network: got=%v, want=%v", network, tc.network) 61 | } 62 | if addr != tc.addr { 63 | t.Fatalf("unexpected address: got=%q, want=%q", addr, tc.addr) 64 | } 65 | if fmt.Sprintf("%+v", err) != fmt.Sprintf("%+v", tc.err) { // nil-safe comparison errors by value 66 | t.Fatalf("unexpected error: got=%+v, want=%+v", err, tc.err) 67 | } 68 | }) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /xpub.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "context" 9 | "net" 10 | ) 11 | 12 | // NewXPub returns a new XPUB ZeroMQ socket. 13 | // The returned socket value is initially unbound. 
14 | func NewXPub(ctx context.Context, opts ...Option) Socket { 15 | xpub := &xpubSocket{newSocket(ctx, XPub, opts...)} 16 | xpub.sck.w = newPubMWriter(xpub.sck.ctx) 17 | xpub.sck.r = newPubQReader(xpub.sck.ctx) 18 | return xpub 19 | } 20 | 21 | // xpubSocket is a XPUB ZeroMQ socket. 22 | type xpubSocket struct { 23 | sck *socket 24 | } 25 | 26 | // Close closes the open Socket 27 | func (xpub *xpubSocket) Close() error { 28 | return xpub.sck.Close() 29 | } 30 | 31 | // Send puts the message on the outbound send queue. 32 | // Send blocks until the message can be queued or the send deadline expires. 33 | func (xpub *xpubSocket) Send(msg Msg) error { 34 | return xpub.sck.Send(msg) 35 | } 36 | 37 | // SendMulti puts the message on the outbound send queue. 38 | // SendMulti blocks until the message can be queued or the send deadline expires. 39 | // The message will be sent as a multipart message. 40 | func (xpub *xpubSocket) SendMulti(msg Msg) error { 41 | return xpub.sck.SendMulti(msg) 42 | } 43 | 44 | // Recv receives a complete message. 45 | func (xpub *xpubSocket) Recv() (Msg, error) { 46 | return xpub.sck.Recv() 47 | } 48 | 49 | // Listen connects a local endpoint to the Socket. 50 | func (xpub *xpubSocket) Listen(ep string) error { 51 | return xpub.sck.Listen(ep) 52 | } 53 | 54 | // Dial connects a remote endpoint to the Socket. 55 | func (xpub *xpubSocket) Dial(ep string) error { 56 | return xpub.sck.Dial(ep) 57 | } 58 | 59 | // Type returns the type of this Socket (PUB, SUB, ...) 60 | func (xpub *xpubSocket) Type() SocketType { 61 | return xpub.sck.Type() 62 | } 63 | 64 | // Addr returns the listener's address. 65 | // Addr returns nil if the socket isn't a listener. 66 | func (xpub *xpubSocket) Addr() net.Addr { 67 | return xpub.sck.Addr() 68 | } 69 | 70 | // GetOption is used to retrieve an option for a socket. 
71 | func (xpub *xpubSocket) GetOption(name string) (interface{}, error) { 72 | return xpub.sck.GetOption(name) 73 | } 74 | 75 | // SetOption is used to set an option for a socket. 76 | func (xpub *xpubSocket) SetOption(name string, value interface{}) error { 77 | return xpub.sck.SetOption(name, value) 78 | } 79 | 80 | func (xpub *xpubSocket) Topics() []string { 81 | return xpub.sck.topics() 82 | } 83 | 84 | var ( 85 | _ Socket = (*xpubSocket)(nil) 86 | ) 87 | -------------------------------------------------------------------------------- /xsub.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "context" 9 | "net" 10 | ) 11 | 12 | // NewXSub returns a new XSUB ZeroMQ socket. 13 | // The returned socket value is initially unbound. 14 | func NewXSub(ctx context.Context, opts ...Option) Socket { 15 | xsub := &xsubSocket{newSocket(ctx, XSub, opts...)} 16 | return xsub 17 | } 18 | 19 | // xsubSocket is a XSUB ZeroMQ socket. 20 | type xsubSocket struct { 21 | sck *socket 22 | } 23 | 24 | // Close closes the open Socket 25 | func (xsub *xsubSocket) Close() error { 26 | return xsub.sck.Close() 27 | } 28 | 29 | // Send puts the message on the outbound send queue. 30 | // Send blocks until the message can be queued or the send deadline expires. 31 | func (xsub *xsubSocket) Send(msg Msg) error { 32 | return xsub.sck.Send(msg) 33 | } 34 | 35 | // SendMulti puts the message on the outbound send queue. 36 | // SendMulti blocks until the message can be queued or the send deadline expires. 37 | // The message will be sent as a multipart message. 38 | func (xsub *xsubSocket) SendMulti(msg Msg) error { 39 | return xsub.sck.SendMulti(msg) 40 | } 41 | 42 | // Recv receives a complete message. 
43 | func (xsub *xsubSocket) Recv() (Msg, error) { 44 | return xsub.sck.Recv() 45 | } 46 | 47 | // Listen connects a local endpoint to the Socket. 48 | func (xsub *xsubSocket) Listen(ep string) error { 49 | return xsub.sck.Listen(ep) 50 | } 51 | 52 | // Dial connects a remote endpoint to the Socket. 53 | func (xsub *xsubSocket) Dial(ep string) error { 54 | return xsub.sck.Dial(ep) 55 | } 56 | 57 | // Type returns the type of this Socket (PUB, SUB, ...) 58 | func (xsub *xsubSocket) Type() SocketType { 59 | return xsub.sck.Type() 60 | } 61 | 62 | // Addr returns the listener's address. 63 | // Addr returns nil if the socket isn't a listener. 64 | func (xsub *xsubSocket) Addr() net.Addr { 65 | return xsub.sck.Addr() 66 | } 67 | 68 | // GetOption is used to retrieve an option for a socket. 69 | func (xsub *xsubSocket) GetOption(name string) (interface{}, error) { 70 | return xsub.sck.GetOption(name) 71 | } 72 | 73 | // SetOption is used to set an option for a socket. 74 | func (xsub *xsubSocket) SetOption(name string, value interface{}) error { 75 | return xsub.sck.SetOption(name, value) 76 | } 77 | 78 | var ( 79 | _ Socket = (*xsubSocket)(nil) 80 | ) 81 | -------------------------------------------------------------------------------- /zall_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package zmq4 6 | 7 | import ( 8 | "fmt" 9 | "io" 10 | "log" 11 | "net" 12 | ) 13 | 14 | var ( 15 | Devnull = log.New(io.Discard, "zmq4: ", 0) 16 | ) 17 | 18 | func must(str string, err error) string { 19 | if err != nil { 20 | panic(err) 21 | } 22 | return str 23 | } 24 | 25 | func EndPoint(transport string) (string, error) { 26 | switch transport { 27 | case "tcp": 28 | addr, err := net.ResolveTCPAddr("tcp", "localhost:0") 29 | if err != nil { 30 | return "", err 31 | } 32 | l, err := net.ListenTCP("tcp", addr) 33 | if err != nil { 34 | return "", err 35 | } 36 | defer l.Close() 37 | return fmt.Sprintf("tcp://%s", l.Addr()), nil 38 | case "ipc": 39 | return "ipc://tmp-" + newUUID(), nil 40 | case "inproc": 41 | return "inproc://tmp-" + newUUID(), nil 42 | default: 43 | panic("invalid transport: [" + transport + "]") 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /zmq4.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package zmq4 implements the ØMQ sockets and protocol for ZeroMQ-4. 6 | // 7 | // For more informations, see http://zeromq.org. 8 | package zmq4 9 | 10 | import "net" 11 | 12 | // Socket represents a ZeroMQ socket. 13 | type Socket interface { 14 | // Close closes the open Socket. 15 | Close() error 16 | 17 | // Send puts the message on the outbound send queue. 18 | // 19 | // Send blocks until the message can be queued or the send deadline expires. 20 | Send(msg Msg) error 21 | 22 | // SendMulti puts the message on the outbound send queue. 23 | // 24 | // SendMulti blocks until the message can be queued or the send deadline 25 | // expires. The message will be sent as a multipart message. 26 | SendMulti(msg Msg) error 27 | 28 | // Recv receives a complete message. 
29 | Recv() (Msg, error) 30 | 31 | // Listen connects a local endpoint to the Socket. 32 | // 33 | // In ZeroMQ's terminology, it binds. 34 | Listen(ep string) error 35 | 36 | // Dial connects a remote endpoint to the Socket. 37 | // 38 | // In ZeroMQ's terminology, it connects. 39 | Dial(ep string) error 40 | 41 | // Type returns the type of this Socket (for example PUB, SUB, etc.) 42 | Type() SocketType 43 | 44 | // Addr returns the listener's address. It returns nil if the socket isn't a 45 | // listener. 46 | Addr() net.Addr 47 | 48 | // GetOption retrieves an option for a socket. 49 | GetOption(name string) (interface{}, error) 50 | 51 | // SetOption sets an option for a socket. 52 | SetOption(name string, value interface{}) error 53 | } 54 | -------------------------------------------------------------------------------- /zmq4_pair_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package zmq4_test 6 | 7 | import ( 8 | "bytes" 9 | "context" 10 | "fmt" 11 | "sync" 12 | "testing" 13 | "time" 14 | 15 | "github.com/go-zeromq/zmq4" 16 | "golang.org/x/sync/errgroup" 17 | ) 18 | 19 | var ( 20 | pairs = []testCasePair{ 21 | { 22 | name: "tcp-pair-pair", 23 | endpoint: must(EndPoint("tcp")), 24 | srv: zmq4.NewPair(bkg), 25 | cli: zmq4.NewPair(bkg), 26 | }, 27 | { 28 | name: "ipc-pair-pair", 29 | endpoint: "ipc://ipc-pair-pair", 30 | srv: zmq4.NewPair(bkg), 31 | cli: zmq4.NewPair(bkg), 32 | }, 33 | { 34 | name: "inproc-pair-pair", 35 | endpoint: "inproc://inproc-pair-pair", 36 | srv: zmq4.NewPair(bkg), 37 | cli: zmq4.NewPair(bkg), 38 | }, 39 | } 40 | ) 41 | 42 | type testCasePair struct { 43 | name string 44 | skip bool 45 | endpoint string 46 | srv zmq4.Socket 47 | cli zmq4.Socket 48 | } 49 | 50 | func TestPair(t *testing.T) { 51 | var ( 52 | msg0 = zmq4.NewMsgString("") 53 | msg1 = zmq4.NewMsgString("MSG 1") 54 | msg2 = zmq4.NewMsgString("msg 2") 55 | msgs = []zmq4.Msg{ 56 | msg0, 57 | msg1, 58 | msg2, 59 | } 60 | ) 61 | 62 | for i := range pairs { 63 | tc := pairs[i] 64 | t.Run(tc.name, func(t *testing.T) { 65 | defer tc.srv.Close() 66 | defer tc.cli.Close() 67 | 68 | ep := tc.endpoint 69 | cleanUp(ep) 70 | 71 | if tc.skip { 72 | t.Skipf(tc.name) 73 | } 74 | // t.Parallel() 75 | 76 | var ( 77 | wg1 sync.WaitGroup 78 | wg2 sync.WaitGroup 79 | ) 80 | 81 | wg1.Add(1) 82 | wg2.Add(1) 83 | 84 | ctx, timeout := context.WithTimeout(context.Background(), 20*time.Second) 85 | defer timeout() 86 | 87 | grp, ctx := errgroup.WithContext(ctx) 88 | grp.Go(func() error { 89 | 90 | err := tc.srv.Listen(ep) 91 | if err != nil { 92 | return fmt.Errorf("could not listen: %w", err) 93 | } 94 | 95 | if addr := tc.srv.Addr(); addr == nil { 96 | return fmt.Errorf("listener with nil Addr") 97 | } 98 | 99 | wg1.Wait() 100 | wg2.Done() 101 | 102 | for _, msg := range msgs { 103 | err = tc.srv.Send(msg) 104 | if err != nil { 105 | return fmt.Errorf("could not send 
message %v: %w", msg, err) 106 | } 107 | reply, err := tc.srv.Recv() 108 | if err != nil { 109 | return fmt.Errorf("could not recv reply to %v: %w", msg, err) 110 | } 111 | 112 | if got, want := reply, zmq4.NewMsgString("reply: "+string(msg.Bytes())); !bytes.Equal(got.Bytes(), want.Bytes()) { 113 | return fmt.Errorf("invalid cli reply for msg #%d: got=%v, want=%v", i, got, want) 114 | } 115 | } 116 | 117 | quit, err := tc.srv.Recv() 118 | if err != nil { 119 | return fmt.Errorf("could not recv QUIT message: %w", err) 120 | } 121 | 122 | if got, want := quit, zmq4.NewMsgString("QUIT"); !bytes.Equal(got.Bytes(), want.Bytes()) { 123 | return fmt.Errorf("invalid QUIT message from cli: got=%v, want=%v", got, want) 124 | } 125 | 126 | return err 127 | }) 128 | 129 | grp.Go(func() error { 130 | 131 | err := tc.cli.Dial(ep) 132 | if err != nil { 133 | return fmt.Errorf("could not dial: %w", err) 134 | } 135 | 136 | wg1.Done() 137 | wg2.Wait() 138 | 139 | for i := range msgs { 140 | msg, err := tc.cli.Recv() 141 | if err != nil { 142 | return fmt.Errorf("could not recv #%d msg from srv: %w", i, err) 143 | } 144 | if !bytes.Equal(msg.Bytes(), msgs[i].Bytes()) { 145 | return fmt.Errorf("invalid #%d msg from srv: got=%v, want=%v", 146 | i, msg, msgs[i], 147 | ) 148 | } 149 | 150 | err = tc.cli.Send(zmq4.NewMsgString("reply: " + string(msg.Bytes()))) 151 | if err != nil { 152 | return fmt.Errorf("could not send message %v: %w", msg, err) 153 | } 154 | } 155 | 156 | err = tc.cli.Send(zmq4.NewMsgString("QUIT")) 157 | if err != nil { 158 | return fmt.Errorf("could not send QUIT message: %w", err) 159 | } 160 | 161 | return err 162 | }) 163 | 164 | if err := grp.Wait(); err != nil { 165 | t.Fatalf("error: %+v", err) 166 | } 167 | 168 | if err := ctx.Err(); err != nil && err != context.Canceled { 169 | t.Fatalf("error: %+v", err) 170 | } 171 | }) 172 | } 173 | } 174 | -------------------------------------------------------------------------------- /zmq4_pushpull_test.go: 
-------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4_test 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "reflect" 11 | "testing" 12 | "time" 13 | 14 | "github.com/go-zeromq/zmq4" 15 | "golang.org/x/sync/errgroup" 16 | ) 17 | 18 | var ( 19 | pushpulls = []testCasePushPull{ 20 | { 21 | name: "tcp-push-pull", 22 | endpoint: must(EndPoint("tcp")), 23 | push: zmq4.NewPush(bkg), 24 | pull: zmq4.NewPull(bkg), 25 | }, 26 | { 27 | name: "ipc-push-pull", 28 | endpoint: "ipc://ipc-push-pull", 29 | push: zmq4.NewPush(bkg), 30 | pull: zmq4.NewPull(bkg), 31 | }, 32 | { 33 | name: "inproc-push-pull", 34 | endpoint: "inproc://push-pull", 35 | push: zmq4.NewPush(bkg), 36 | pull: zmq4.NewPull(bkg), 37 | }, 38 | } 39 | ) 40 | 41 | type testCasePushPull struct { 42 | name string 43 | skip bool 44 | endpoint string 45 | push zmq4.Socket 46 | pull zmq4.Socket 47 | } 48 | 49 | func TestPushPull(t *testing.T) { 50 | var ( 51 | hello = zmq4.NewMsg([]byte("HELLO WORLD")) 52 | bye = zmq4.NewMsgFrom([]byte("GOOD"), []byte("BYE")) 53 | ) 54 | 55 | for i := range pushpulls { 56 | tc := pushpulls[i] 57 | t.Run(tc.name, func(t *testing.T) { 58 | defer tc.pull.Close() 59 | defer tc.push.Close() 60 | 61 | ep := tc.endpoint 62 | cleanUp(ep) 63 | 64 | if tc.skip { 65 | t.Skipf(tc.name) 66 | } 67 | // t.Parallel() 68 | 69 | ctx, timeout := context.WithTimeout(context.Background(), 20*time.Second) 70 | defer timeout() 71 | 72 | grp, _ := errgroup.WithContext(ctx) 73 | grp.Go(func() error { 74 | 75 | err := tc.push.Listen(ep) 76 | if err != nil { 77 | return fmt.Errorf("could not listen: %w", err) 78 | } 79 | 80 | if addr := tc.push.Addr(); addr == nil { 81 | return fmt.Errorf("listener with nil Addr") 82 | } 83 | 84 | err = tc.push.Send(hello) 85 | if err != nil { 86 | return 
fmt.Errorf("could not send %v: %w", hello, err) 87 | } 88 | 89 | err = tc.push.Send(bye) 90 | if err != nil { 91 | return fmt.Errorf("could not send %v: %w", bye, err) 92 | } 93 | return err 94 | }) 95 | grp.Go(func() error { 96 | 97 | err := tc.pull.Dial(ep) 98 | if err != nil { 99 | return fmt.Errorf("could not dial: %w", err) 100 | } 101 | 102 | if addr := tc.pull.Addr(); addr != nil { 103 | return fmt.Errorf("dialer with non-nil Addr") 104 | } 105 | 106 | msg, err := tc.pull.Recv() 107 | if err != nil { 108 | return fmt.Errorf("could not recv %v: %w", hello, err) 109 | } 110 | 111 | if got, want := msg, hello; !reflect.DeepEqual(got, want) { 112 | return fmt.Errorf("recv1: got = %v, want= %v", got, want) 113 | } 114 | 115 | msg, err = tc.pull.Recv() 116 | if err != nil { 117 | return fmt.Errorf("could not recv %v: %w", bye, err) 118 | } 119 | 120 | if got, want := msg, bye; !reflect.DeepEqual(got, want) { 121 | return fmt.Errorf("recv2: got = %v, want= %v", got, want) 122 | } 123 | 124 | return err 125 | }) 126 | if err := grp.Wait(); err != nil { 127 | t.Fatalf("error: %+v", err) 128 | } 129 | }) 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /zmq4_reqrep_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package zmq4_test 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "reflect" 11 | "testing" 12 | "time" 13 | 14 | "github.com/go-zeromq/zmq4" 15 | "golang.org/x/sync/errgroup" 16 | ) 17 | 18 | var ( 19 | reqreps = []testCaseReqRep{ 20 | { 21 | name: "tcp-req-rep", 22 | endpoint: must(EndPoint("tcp")), 23 | req1: zmq4.NewReq(bkg), 24 | rep: zmq4.NewRep(bkg), 25 | }, 26 | { 27 | name: "ipc-req-rep", 28 | endpoint: "ipc://ipc-req-rep", 29 | req1: zmq4.NewReq(bkg), 30 | rep: zmq4.NewRep(bkg), 31 | }, 32 | { 33 | name: "inproc-req-rep", 34 | endpoint: "inproc://inproc-req-rep", 35 | req1: zmq4.NewReq(bkg), 36 | rep: zmq4.NewRep(bkg), 37 | }, 38 | } 39 | ) 40 | 41 | type testCaseReqRep struct { 42 | name string 43 | skip bool 44 | endpoint string 45 | req1 zmq4.Socket 46 | req2 zmq4.Socket 47 | rep zmq4.Socket 48 | } 49 | 50 | func TestReqRep(t *testing.T) { 51 | var ( 52 | reqName = zmq4.NewMsgString("NAME") 53 | reqLang = zmq4.NewMsgString("LANG") 54 | reqQuit = zmq4.NewMsgString("QUIT") 55 | repName = zmq4.NewMsgString("zmq4") 56 | repLang = zmq4.NewMsgString("Go") 57 | repQuit = zmq4.NewMsgString("bye") 58 | ) 59 | 60 | for i := range reqreps { 61 | tc := reqreps[i] 62 | t.Run(tc.name, func(t *testing.T) { 63 | defer tc.req1.Close() 64 | defer tc.rep.Close() 65 | 66 | ep := tc.endpoint 67 | cleanUp(ep) 68 | 69 | if tc.skip { 70 | t.Skipf(tc.name) 71 | } 72 | // t.Parallel() 73 | 74 | ctx, timeout := context.WithTimeout(context.Background(), 20*time.Second) 75 | defer timeout() 76 | 77 | grp, _ := errgroup.WithContext(ctx) 78 | grp.Go(func() error { 79 | 80 | err := tc.rep.Listen(ep) 81 | if err != nil { 82 | return fmt.Errorf("could not listen: %w", err) 83 | } 84 | 85 | if addr := tc.rep.Addr(); addr == nil { 86 | return fmt.Errorf("listener with nil Addr") 87 | } 88 | 89 | loop := true 90 | for loop { 91 | msg, err := tc.rep.Recv() 92 | if err != nil { 93 | return fmt.Errorf("could not recv REQ message: %w", err) 94 | } 95 | var rep zmq4.Msg 96 | switch 
string(msg.Frames[0]) { 97 | case "NAME": 98 | rep = repName 99 | case "LANG": 100 | rep = repLang 101 | case "QUIT": 102 | rep = repQuit 103 | loop = false 104 | } 105 | 106 | err = tc.rep.Send(rep) 107 | if err != nil { 108 | return fmt.Errorf("could not send REP message to %v: %w", msg, err) 109 | } 110 | } 111 | 112 | return err 113 | }) 114 | grp.Go(func() error { 115 | 116 | err := tc.req1.Dial(ep) 117 | if err != nil { 118 | return fmt.Errorf("could not dial: %w", err) 119 | } 120 | 121 | if addr := tc.req1.Addr(); addr != nil { 122 | return fmt.Errorf("dialer with non-nil Addr") 123 | } 124 | 125 | for _, msg := range []struct { 126 | req zmq4.Msg 127 | rep zmq4.Msg 128 | }{ 129 | {reqName, repName}, 130 | {reqLang, repLang}, 131 | {reqQuit, repQuit}, 132 | } { 133 | err = tc.req1.Send(msg.req) 134 | if err != nil { 135 | return fmt.Errorf("could not send REQ message %v: %w", msg.req, err) 136 | } 137 | rep, err := tc.req1.Recv() 138 | if err != nil { 139 | return fmt.Errorf("could not recv REP message %v: %w", msg.req, err) 140 | } 141 | 142 | if got, want := rep, msg.rep; !reflect.DeepEqual(got, want) { 143 | return fmt.Errorf("got = %v, want= %v", got, want) 144 | } 145 | } 146 | 147 | return err 148 | }) 149 | if err := grp.Wait(); err != nil { 150 | t.Fatalf("error: %+v", err) 151 | } 152 | }) 153 | } 154 | } 155 | 156 | func TestMultiReqRepIssue70(t *testing.T) { 157 | var ( 158 | reqName1 = zmq4.NewMsgString("NAME") 159 | reqLang1 = zmq4.NewMsgString("LANG") 160 | reqQuit1 = zmq4.NewMsgString("QUIT") 161 | reqName2 = zmq4.NewMsgString("NAME2") 162 | reqLang2 = zmq4.NewMsgString("LANG2") 163 | reqQuit2 = zmq4.NewMsgString("QUIT2") 164 | repName1 = zmq4.NewMsgString("zmq4") 165 | repLang1 = zmq4.NewMsgString("Go") 166 | repQuit1 = zmq4.NewMsgString("bye") 167 | repName2 = zmq4.NewMsgString("zmq42") 168 | repLang2 = zmq4.NewMsgString("Go2") 169 | repQuit2 = zmq4.NewMsgString("bye2") 170 | ) 171 | 172 | reqreps := []testCaseReqRep{ 173 | { 174 | name: 
"tcp-req-rep", 175 | endpoint: must(EndPoint("tcp")), 176 | req1: zmq4.NewReq(bkg), 177 | req2: zmq4.NewReq(bkg), 178 | rep: zmq4.NewRep(bkg), 179 | }, 180 | { 181 | name: "ipc-req-rep", 182 | endpoint: "ipc://ipc-req-rep", 183 | req1: zmq4.NewReq(bkg), 184 | req2: zmq4.NewReq(bkg), 185 | rep: zmq4.NewRep(bkg), 186 | }, 187 | { 188 | name: "inproc-req-rep", 189 | endpoint: "inproc://inproc-req-rep", 190 | req1: zmq4.NewReq(bkg), 191 | req2: zmq4.NewReq(bkg), 192 | rep: zmq4.NewRep(bkg), 193 | }, 194 | } 195 | 196 | for i := range reqreps { 197 | tc := reqreps[i] 198 | t.Run(tc.name, func(t *testing.T) { 199 | defer tc.req1.Close() 200 | defer tc.req2.Close() 201 | defer tc.rep.Close() 202 | 203 | if tc.skip { 204 | t.Skipf(tc.name) 205 | } 206 | // t.Parallel() 207 | 208 | ep := tc.endpoint 209 | cleanUp(ep) 210 | 211 | ctx, timeout := context.WithTimeout(context.Background(), 20*time.Second) 212 | defer timeout() 213 | 214 | grp, _ := errgroup.WithContext(ctx) 215 | grp.Go(func() error { 216 | err := tc.rep.Listen(ep) 217 | if err != nil { 218 | return fmt.Errorf("could not listen: %w", err) 219 | } 220 | 221 | if addr := tc.rep.Addr(); addr == nil { 222 | return fmt.Errorf("listener with nil Addr") 223 | } 224 | 225 | loop1, loop2 := true, true 226 | for loop1 || loop2 { 227 | msg, err := tc.rep.Recv() 228 | if err != nil { 229 | return fmt.Errorf("could not recv REQ message: %w", err) 230 | } 231 | var rep zmq4.Msg 232 | switch string(msg.Frames[0]) { 233 | case "NAME": 234 | rep = repName1 235 | case "LANG": 236 | rep = repLang1 237 | case "QUIT": 238 | rep = repQuit1 239 | loop1 = false 240 | case "NAME2": 241 | rep = repName2 242 | case "LANG2": 243 | rep = repLang2 244 | case "QUIT2": 245 | rep = repQuit2 246 | loop2 = false 247 | } 248 | 249 | err = tc.rep.Send(rep) 250 | if err != nil { 251 | return fmt.Errorf("could not send REP message to %v: %w", msg, err) 252 | } 253 | } 254 | return err 255 | }) 256 | grp.Go(func() error { 257 | 258 | err := 
tc.req2.Dial(ep) 259 | if err != nil { 260 | return fmt.Errorf("could not dial: %w", err) 261 | } 262 | 263 | if addr := tc.req2.Addr(); addr != nil { 264 | return fmt.Errorf("dialer with non-nil Addr") 265 | } 266 | 267 | for _, msg := range []struct { 268 | req zmq4.Msg 269 | rep zmq4.Msg 270 | }{ 271 | {reqName2, repName2}, 272 | {reqLang2, repLang2}, 273 | {reqQuit2, repQuit2}, 274 | } { 275 | err = tc.req2.Send(msg.req) 276 | if err != nil { 277 | return fmt.Errorf("could not send REQ message %v: %w", msg.req, err) 278 | } 279 | rep, err := tc.req2.Recv() 280 | if err != nil { 281 | return fmt.Errorf("could not recv REP message %v: %w", msg.req, err) 282 | } 283 | 284 | if got, want := rep, msg.rep; !reflect.DeepEqual(got, want) { 285 | return fmt.Errorf("got = %v, want= %v", got, want) 286 | } 287 | } 288 | return err 289 | }) 290 | grp.Go(func() error { 291 | 292 | err := tc.req1.Dial(ep) 293 | if err != nil { 294 | return fmt.Errorf("could not dial: %w", err) 295 | } 296 | 297 | if addr := tc.req1.Addr(); addr != nil { 298 | return fmt.Errorf("dialer with non-nil Addr") 299 | } 300 | 301 | for _, msg := range []struct { 302 | req zmq4.Msg 303 | rep zmq4.Msg 304 | }{ 305 | {reqName1, repName1}, 306 | {reqLang1, repLang1}, 307 | {reqQuit1, repQuit1}, 308 | } { 309 | err = tc.req1.Send(msg.req) 310 | if err != nil { 311 | return fmt.Errorf("could not send REQ message %v: %w", msg.req, err) 312 | } 313 | rep, err := tc.req1.Recv() 314 | if err != nil { 315 | return fmt.Errorf("could not recv REP message %v: %w", msg.req, err) 316 | } 317 | 318 | if got, want := rep, msg.rep; !reflect.DeepEqual(got, want) { 319 | return fmt.Errorf("got = %v, want= %v", got, want) 320 | } 321 | } 322 | return err 323 | }) 324 | if err := grp.Wait(); err != nil { 325 | t.Fatalf("error: %+v", err) 326 | } 327 | }) 328 | } 329 | } 330 | -------------------------------------------------------------------------------- /zmq4_routerdealer_test.go: 
-------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4_test 6 | 7 | import ( 8 | "bytes" 9 | "context" 10 | "fmt" 11 | "net" 12 | "reflect" 13 | "sync" 14 | "testing" 15 | "time" 16 | 17 | "github.com/go-zeromq/zmq4" 18 | "golang.org/x/sync/errgroup" 19 | ) 20 | 21 | var ( 22 | routerdealers = []testCaseRouterDealer{ 23 | { 24 | name: "tcp-router-dealer", 25 | endpoint: func() string { return must(EndPoint("tcp")) }, 26 | router: func(ctx context.Context) zmq4.Socket { 27 | return zmq4.NewRouter(ctx, zmq4.WithID(zmq4.SocketIdentity("router"))) 28 | }, 29 | dealer0: func(ctx context.Context) zmq4.Socket { 30 | return zmq4.NewDealer(ctx, zmq4.WithID(zmq4.SocketIdentity("dealer-0"))) 31 | }, 32 | dealer1: func(ctx context.Context) zmq4.Socket { 33 | return zmq4.NewDealer(ctx, zmq4.WithID(zmq4.SocketIdentity("dealer-1"))) 34 | }, 35 | dealer2: func(ctx context.Context) zmq4.Socket { 36 | return zmq4.NewDealer(ctx, zmq4.WithID(zmq4.SocketIdentity("dealer-2"))) 37 | }, 38 | }, 39 | { 40 | name: "ipc-router-dealer", 41 | skip: true, 42 | endpoint: func() string { return must(EndPoint("ipc")) }, 43 | router: func(ctx context.Context) zmq4.Socket { 44 | return zmq4.NewRouter(ctx, zmq4.WithID(zmq4.SocketIdentity("router"))) 45 | }, 46 | dealer0: func(ctx context.Context) zmq4.Socket { 47 | return zmq4.NewDealer(ctx, zmq4.WithID(zmq4.SocketIdentity("dealer-0"))) 48 | }, 49 | dealer1: func(ctx context.Context) zmq4.Socket { 50 | return zmq4.NewDealer(ctx, zmq4.WithID(zmq4.SocketIdentity("dealer-1"))) 51 | }, 52 | dealer2: func(ctx context.Context) zmq4.Socket { 53 | return zmq4.NewDealer(ctx, zmq4.WithID(zmq4.SocketIdentity("dealer-2"))) 54 | }, 55 | }, 56 | { 57 | name: "inproc-router-dealer", 58 | skip: true, 59 | endpoint: func() string { return 
must(EndPoint("inproc")) }, 60 | router: func(ctx context.Context) zmq4.Socket { 61 | return zmq4.NewRouter(ctx, zmq4.WithID(zmq4.SocketIdentity("router"))) 62 | }, 63 | dealer0: func(ctx context.Context) zmq4.Socket { 64 | return zmq4.NewDealer(ctx, zmq4.WithID(zmq4.SocketIdentity("dealer-0"))) 65 | }, 66 | dealer1: func(ctx context.Context) zmq4.Socket { 67 | return zmq4.NewDealer(ctx, zmq4.WithID(zmq4.SocketIdentity("dealer-1"))) 68 | }, 69 | dealer2: func(ctx context.Context) zmq4.Socket { 70 | return zmq4.NewDealer(ctx, zmq4.WithID(zmq4.SocketIdentity("dealer-2"))) 71 | }, 72 | }, 73 | } 74 | ) 75 | 76 | type testCaseRouterDealer struct { 77 | name string 78 | skip bool 79 | endpoint func() string 80 | router func(context.Context) zmq4.Socket 81 | dealer0 func(context.Context) zmq4.Socket 82 | dealer1 func(context.Context) zmq4.Socket 83 | dealer2 func(context.Context) zmq4.Socket 84 | } 85 | 86 | func TestRouterDealer(t *testing.T) { 87 | var ( 88 | Fired = []byte("Fired!") 89 | WorkHarder = []byte("Work Harder!") 90 | 91 | ready = zmq4.NewMsgFrom([]byte(""), []byte("ready")) 92 | ) 93 | 94 | for i := range routerdealers { 95 | tc := routerdealers[i] 96 | t.Run(tc.name, func(t *testing.T) { 97 | // t.Parallel() 98 | ep := tc.endpoint() 99 | cleanUp(ep) 100 | 101 | if tc.skip { 102 | t.Skipf(tc.name) 103 | } 104 | 105 | ctx, timeout := context.WithTimeout(context.Background(), 10*time.Second) 106 | defer timeout() 107 | 108 | router := tc.router(ctx) 109 | defer router.Close() 110 | 111 | dealer0 := tc.dealer0(ctx) 112 | defer dealer0.Close() 113 | dealer1 := tc.dealer1(ctx) 114 | defer dealer1.Close() 115 | dealer2 := tc.dealer2(ctx) 116 | defer dealer2.Close() 117 | 118 | dealers := []zmq4.Socket{dealer0, dealer1, dealer2} 119 | fired := make([]int, len(dealers)) 120 | 121 | var wgd sync.WaitGroup 122 | wgd.Add(len(dealers)) 123 | var wgr sync.WaitGroup 124 | wgr.Add(1) 125 | 126 | var seenMu sync.RWMutex 127 | seen := make(map[string]int) 128 | grp, _ := 
errgroup.WithContext(ctx) 129 | grp.Go(func() error { 130 | 131 | err := router.Listen(ep) 132 | if err != nil { 133 | return fmt.Errorf("could not listen: %w", err) 134 | } 135 | 136 | if addr := router.Addr(); addr == nil { 137 | return fmt.Errorf("listener with nil Addr") 138 | } 139 | 140 | wgd.Wait() 141 | wgr.Done() 142 | 143 | fired := 0 144 | const N = 3 145 | for i := 0; i < len(dealers)*N+1 && fired < N; i++ { 146 | msg, err := router.Recv() 147 | if err != nil { 148 | return fmt.Errorf("could not recv message: %w", err) 149 | } 150 | 151 | if len(msg.Frames) == 0 { 152 | seenMu.RLock() 153 | str := fmt.Sprintf("%v", seen) 154 | seenMu.RUnlock() 155 | return fmt.Errorf("router received empty message (test=%q, iter=%d, seen=%v)", tc.name, i, str) 156 | } 157 | id := string(msg.Frames[0]) 158 | seenMu.Lock() 159 | seen[id]++ 160 | n := seen[id] 161 | seenMu.Unlock() 162 | switch { 163 | case n >= N: 164 | msg = zmq4.NewMsgFrom([]byte(id), []byte(""), Fired) 165 | fired++ 166 | default: 167 | msg = zmq4.NewMsgFrom([]byte(id), []byte(""), WorkHarder) 168 | } 169 | err = router.Send(msg) 170 | if err != nil { 171 | return fmt.Errorf("could not send %v: %w", msg, err) 172 | } 173 | } 174 | if fired != N { 175 | return fmt.Errorf("did not fire everybody (fired=%d, want=%d)", fired, N) 176 | } 177 | return nil 178 | }) 179 | for idealer := range dealers { 180 | func(idealer int, dealer zmq4.Socket) { 181 | grp.Go(func() error { 182 | 183 | err := dealer.Dial(ep) 184 | if err != nil { 185 | return fmt.Errorf("could not dial: %w", err) 186 | } 187 | 188 | if addr := dealer.Addr(); addr != nil { 189 | return fmt.Errorf("dialer with non-nil Addr") 190 | } 191 | 192 | wgd.Done() 193 | wgd.Wait() 194 | wgr.Wait() 195 | 196 | n := 0 197 | loop: 198 | for { 199 | // tell the broker we are ready for work 200 | err = dealer.Send(ready) 201 | if err != nil { 202 | return fmt.Errorf("could not send %v: %w", ready, err) 203 | } 204 | 205 | // get workload from broker 206 | 
msg, err := dealer.Recv() 207 | if err != nil { 208 | return fmt.Errorf("could not recv msg: %w", err) 209 | } 210 | if len(msg.Frames) < 2 { 211 | seenMu.RLock() 212 | str := fmt.Sprintf("%v", seen) 213 | seenMu.RUnlock() 214 | return fmt.Errorf("dealer-%d received invalid msg %v (test=%q, iter=%d, seen=%v)", idealer, msg, tc.name, n, str) 215 | } 216 | work := msg.Frames[1] 217 | fired[idealer]++ 218 | if bytes.Equal(work, Fired) { 219 | break loop 220 | } 221 | 222 | // do some random work 223 | time.Sleep(50 * time.Millisecond) 224 | n++ 225 | } 226 | 227 | return err 228 | }) 229 | }(idealer, dealers[idealer]) 230 | } 231 | 232 | if err := grp.Wait(); err != nil { 233 | t.Errorf("workers: %v", fired) 234 | t.Fatalf("error: %+v", err) 235 | } 236 | 237 | if !reflect.DeepEqual(fired, []int{3, 3, 3}) { 238 | t.Fatalf("some workers did not get fired: %v", fired) 239 | } 240 | }) 241 | } 242 | } 243 | 244 | func TestRouterWithNoDealer(t *testing.T) { 245 | router := zmq4.NewRouter(context.Background()) 246 | err := router.Listen("tcp://*:*") 247 | if err != nil { 248 | t.Fatalf("could not listen: %+v", err) 249 | } 250 | 251 | err = router.Close() 252 | if err != nil { 253 | t.Fatalf("could not close router: %+v", err) 254 | } 255 | } 256 | 257 | func TestRouterDealerClose(t *testing.T) { 258 | tests := []struct { 259 | name string 260 | }{ 261 | {name: "router"}, 262 | {name: "dealer"}, 263 | } 264 | for _, tt := range tests { 265 | t.Run(tt.name, func(t *testing.T) { 266 | ctx, cancel := context.WithCancel(context.Background()) 267 | defer cancel() 268 | socks := map[string]zmq4.Socket{ 269 | "router": zmq4.NewRouter(ctx), 270 | "dealer": zmq4.NewDealer(ctx), 271 | } 272 | router := socks["router"] 273 | dealer := socks["dealer"] 274 | 275 | err := router.Listen("tcp://*:*") 276 | if err != nil { 277 | t.Fatalf("router could not listen: %+v", err) 278 | } 279 | _, port, _ := net.SplitHostPort(router.Addr().String()) 280 | err = dealer.Dial("tcp://*:" + port) 281 
| if err != nil { 282 | t.Fatalf("dealer could not dial: %+v", err) 283 | } 284 | start := make(chan bool) 285 | var wg sync.WaitGroup 286 | wg.Add(1) 287 | go func(sock zmq4.Socket, start <-chan bool) { 288 | defer wg.Done() 289 | <-start 290 | _, err := sock.Recv() 291 | if err == nil { 292 | t.Error("expected error: context canceled") 293 | } 294 | }(socks[tt.name], start) 295 | 296 | err = socks[tt.name].Close() 297 | if err != nil { 298 | t.Fatalf("could not close %s: %+v", tt.name, err) 299 | } 300 | start <- true 301 | wg.Wait() 302 | }) 303 | } 304 | } 305 | -------------------------------------------------------------------------------- /zmq4_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2018 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4_test 6 | 7 | import ( 8 | "context" 9 | "crypto/rand" 10 | "fmt" 11 | "io" 12 | "log" 13 | "net" 14 | "os" 15 | "strings" 16 | ) 17 | 18 | var ( 19 | bkg = context.Background() 20 | ) 21 | 22 | func must(str string, err error) string { 23 | if err != nil { 24 | panic(err) 25 | } 26 | return str 27 | } 28 | 29 | func EndPoint(transport string) (string, error) { 30 | switch transport { 31 | case "tcp": 32 | addr, err := net.ResolveTCPAddr("tcp", "localhost:0") 33 | if err != nil { 34 | return "", err 35 | } 36 | l, err := net.ListenTCP("tcp", addr) 37 | if err != nil { 38 | return "", err 39 | } 40 | defer l.Close() 41 | return fmt.Sprintf("tcp://%s", l.Addr()), nil 42 | case "ipc": 43 | return "ipc://tmp-" + newUUID(), nil 44 | case "inproc": 45 | return "inproc://tmp-" + newUUID(), nil 46 | default: 47 | panic("invalid transport: [" + transport + "]") 48 | } 49 | } 50 | 51 | func cleanUp(ep string) { 52 | switch { 53 | case strings.HasPrefix(ep, "ipc://"): 54 | os.Remove(ep[len("ipc://"):]) 55 | case strings.HasPrefix(ep, 
"inproc://"): 56 | os.Remove(ep[len("inproc://"):]) 57 | } 58 | } 59 | 60 | func newUUID() string { 61 | var uuid [16]byte 62 | if _, err := io.ReadFull(rand.Reader, uuid[:]); err != nil { 63 | log.Fatalf("cannot generate random data for UUID: %v", err) 64 | } 65 | uuid[8] = uuid[8]&^0xc0 | 0x80 66 | uuid[6] = uuid[6]&^0xf0 | 0x40 67 | return fmt.Sprintf("%x-%x-%x-%x-%x", uuid[:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:]) 68 | } 69 | -------------------------------------------------------------------------------- /zmq4_timeout_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4 6 | 7 | import ( 8 | "context" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | func TestPushTimeout(t *testing.T) { 14 | ep := "ipc://@push_timeout_test" 15 | push := NewPush(context.Background(), WithTimeout(1*time.Second)) 16 | defer push.Close() 17 | if err := push.Listen(ep); err != nil { 18 | t.FailNow() 19 | } 20 | 21 | pull := NewPull(context.Background()) 22 | defer pull.Close() 23 | if err := pull.Dial(ep); err != nil { 24 | t.FailNow() 25 | } 26 | 27 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 28 | defer cancel() 29 | for { 30 | select { 31 | case <-ctx.Done(): 32 | // The ctx limits overall time of execution 33 | // If it gets canceled, that meain tests failed 34 | // as write to socket did not genereate timeout error 35 | t.Fatalf("test failed before being able to generate timeout error: %+v", ctx.Err()) 36 | default: 37 | } 38 | 39 | err := push.Send(NewMsgString("test string")) 40 | if err == nil { 41 | continue 42 | } 43 | if err != context.DeadlineExceeded { 44 | t.Fatalf("expected a context.DeadlineExceeded error, got=%+v", err) 45 | } 46 | break 47 | } 48 | 49 | } 50 | 
-------------------------------------------------------------------------------- /zmq4_xpubsub_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The go-zeromq Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package zmq4_test 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "reflect" 11 | "sync" 12 | "testing" 13 | "time" 14 | 15 | "github.com/go-zeromq/zmq4" 16 | "golang.org/x/sync/errgroup" 17 | ) 18 | 19 | var ( 20 | xpubsubs = []testCaseXPubSub{ 21 | { 22 | name: "tcp-xpub-sub", 23 | endpoint: must(EndPoint("tcp")), 24 | xpub: zmq4.NewXPub(bkg), 25 | sub0: zmq4.NewSub(bkg, zmq4.WithID(zmq4.SocketIdentity("sub0"))), 26 | sub1: zmq4.NewSub(bkg, zmq4.WithID(zmq4.SocketIdentity("sub1"))), 27 | sub2: zmq4.NewSub(bkg, zmq4.WithID(zmq4.SocketIdentity("sub2"))), 28 | }, 29 | { 30 | name: "ipc-xpub-sub", 31 | endpoint: "ipc://ipc-xpub-sub", 32 | xpub: zmq4.NewXPub(bkg), 33 | sub0: zmq4.NewSub(bkg, zmq4.WithID(zmq4.SocketIdentity("sub0"))), 34 | sub1: zmq4.NewSub(bkg, zmq4.WithID(zmq4.SocketIdentity("sub1"))), 35 | sub2: zmq4.NewSub(bkg, zmq4.WithID(zmq4.SocketIdentity("sub2"))), 36 | }, 37 | { 38 | name: "inproc-xpub-sub", 39 | endpoint: "inproc://inproc-xpub-sub", 40 | xpub: zmq4.NewXPub(bkg), 41 | sub0: zmq4.NewSub(bkg, zmq4.WithID(zmq4.SocketIdentity("sub0"))), 42 | sub1: zmq4.NewSub(bkg, zmq4.WithID(zmq4.SocketIdentity("sub1"))), 43 | sub2: zmq4.NewSub(bkg, zmq4.WithID(zmq4.SocketIdentity("sub2"))), 44 | }, 45 | } 46 | ) 47 | 48 | type testCaseXPubSub struct { 49 | name string 50 | skip bool 51 | endpoint string 52 | xpub zmq4.Socket 53 | sub0 zmq4.Socket 54 | sub1 zmq4.Socket 55 | sub2 zmq4.Socket 56 | } 57 | 58 | func TestXPubSub(t *testing.T) { 59 | var ( 60 | topics = []string{"", "MSG", "msg"} 61 | wantNumMsgs = []int{3, 1, 1} 62 | msg0 = zmq4.NewMsgString("anything") 63 | msg1 = zmq4.NewMsgString("MSG 
1") 64 | msg2 = zmq4.NewMsgString("msg 2") 65 | msgs = [][]zmq4.Msg{ 66 | 0: {msg0, msg1, msg2}, 67 | 1: {msg1}, 68 | 2: {msg2}, 69 | } 70 | ) 71 | 72 | for i := range xpubsubs { 73 | tc := xpubsubs[i] 74 | t.Run(tc.name, func(t *testing.T) { 75 | defer tc.xpub.Close() 76 | defer tc.sub0.Close() 77 | defer tc.sub1.Close() 78 | defer tc.sub2.Close() 79 | 80 | ep := tc.endpoint 81 | cleanUp(ep) 82 | 83 | if tc.skip { 84 | t.Skipf(tc.name) 85 | } 86 | // t.Parallel() 87 | 88 | ctx, timeout := context.WithTimeout(context.Background(), 20*time.Second) 89 | defer timeout() 90 | 91 | nmsgs := []int{0, 0, 0} 92 | subs := []zmq4.Socket{tc.sub0, tc.sub1, tc.sub2} 93 | 94 | var wg1 sync.WaitGroup 95 | var wg2 sync.WaitGroup 96 | wg1.Add(len(subs)) 97 | wg2.Add(len(subs)) 98 | 99 | grp, ctx := errgroup.WithContext(ctx) 100 | grp.Go(func() error { 101 | 102 | err := tc.xpub.Listen(ep) 103 | if err != nil { 104 | return fmt.Errorf("could not listen: %w", err) 105 | } 106 | 107 | if addr := tc.xpub.Addr(); addr == nil { 108 | return fmt.Errorf("listener with nil Addr") 109 | } 110 | 111 | wg1.Wait() 112 | wg2.Wait() 113 | 114 | time.Sleep(1 * time.Second) 115 | 116 | if sck, ok := tc.xpub.(zmq4.Topics); ok { 117 | got := sck.Topics() 118 | if !reflect.DeepEqual(got, topics) { 119 | t.Fatalf("invalid topics.\ngot= %q\nwant=%q", got, topics) 120 | } 121 | } 122 | 123 | for _, msg := range msgs[0] { 124 | err = tc.xpub.Send(msg) 125 | if err != nil { 126 | return fmt.Errorf("could not send message %v: %w", msg, err) 127 | } 128 | } 129 | 130 | return err 131 | }) 132 | 133 | for isub := range subs { 134 | func(isub int, sub zmq4.Socket) { 135 | grp.Go(func() error { 136 | var err error 137 | err = sub.Dial(ep) 138 | if err != nil { 139 | return fmt.Errorf("could not dial: %w", err) 140 | } 141 | 142 | if addr := sub.Addr(); addr != nil { 143 | return fmt.Errorf("dialer with non-nil Addr") 144 | } 145 | 146 | wg1.Done() 147 | wg1.Wait() 148 | 149 | err = 
sub.SetOption(zmq4.OptionSubscribe, topics[isub]) 150 | if err != nil { 151 | return fmt.Errorf("could not subscribe to topic %q: %w", topics[isub], err) 152 | } 153 | 154 | wg2.Done() 155 | wg2.Wait() 156 | 157 | msgs := msgs[isub] 158 | for imsg, want := range msgs { 159 | msg, err := sub.Recv() 160 | if err != nil { 161 | return fmt.Errorf("could not recv message %v: %w", want, err) 162 | } 163 | if !reflect.DeepEqual(msg, want) { 164 | return fmt.Errorf("sub[%d][msg=%d]: got = %v, want= %v", isub, imsg, msg, want) 165 | } 166 | nmsgs[isub]++ 167 | } 168 | 169 | return err 170 | }) 171 | }(isub, subs[isub]) 172 | } 173 | 174 | if err := grp.Wait(); err != nil { 175 | t.Fatalf("error: %+v", err) 176 | } 177 | 178 | if err := ctx.Err(); err != nil && err != context.Canceled { 179 | t.Fatalf("error: %+v", err) 180 | } 181 | 182 | for i, want := range wantNumMsgs { 183 | if want != nmsgs[i] { 184 | t.Errorf("xsub[%d]: got %d messages, want %d msgs=%v", i, nmsgs[i], want, nmsgs) 185 | } 186 | } 187 | }) 188 | } 189 | } 190 | --------------------------------------------------------------------------------