├── .gitignore ├── .python-version ├── .travis.yml ├── LICENSE ├── Makefile ├── README.md ├── Vagrantfile ├── bin └── nsqlookupd-statsd ├── nsq ├── __init__.py ├── backoff.py ├── checker.py ├── client.py ├── connection.py ├── constants.py ├── exceptions.py ├── gevent.py ├── http │ ├── __init__.py │ ├── nsqd.py │ └── nsqlookupd.py ├── reader.py ├── response.py ├── sockets │ ├── __init__.py │ ├── base.py │ ├── deflate.py │ ├── snappy.py │ └── tls.py ├── stats.py └── util.py ├── requirements.txt ├── scripts ├── travis │ └── before-install.sh └── vagrant │ ├── files │ └── etc │ │ └── init │ │ ├── nsqadmin.conf │ │ ├── nsqd.conf │ │ └── nsqlookupd.conf │ └── provision.sh ├── setup.py ├── shovel └── profile.py └── test ├── common ├── __init__.py ├── clienttest.py ├── httpclientintegrationtest.py ├── integrationtest.py ├── mockedconnectiontest.py └── mockedsockettest.py ├── fixtures ├── certificates │ ├── cert.pem │ └── key.pem └── test_stats │ └── TestStats │ └── stats ├── test_backoff.py ├── test_checker.py ├── test_client.py ├── test_clients ├── test_clients.py ├── test_nsqd.py └── test_nsqlookupd.py ├── test_connection.py ├── test_reader.py ├── test_response.py ├── test_sockets ├── test_base.py └── test_tls.py ├── test_stats.py └── test_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .coverage 3 | build/ 4 | dist/ 5 | *.egg-info/ 6 | venv 7 | .vagrant 8 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 2.7.11 2 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "2.7" 4 | - "3.6" 5 | env: 6 | - NSQ_VERSION=0.2.26 7 | # Get some pre-requisites 8 | before_install: "bash scripts/travis/before-install.sh" 9 | # command to install dependencies 10 | install: "pip install -r requirements.txt" 11 | # command to run tests 12 | script: make test 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Dan Lecocq 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | clean:
2 | 	# Remove the build
3 | 	sudo rm -rf build dist
4 | 	# And all of our pyc files
5 | 	find . -name '*.pyc' | xargs -n 100 rm
6 | 	# And lastly, .coverage files
7 | 	find . -name .coverage | xargs rm
8 | 
9 | .PHONY: test
10 | test:
11 | 	rm -f .coverage
12 | 	nosetests --exe --cover-package=nsq --with-coverage --cover-branches -v --logging-clear-handlers
13 | 
14 | requirements:
15 | 	pip freeze | grep -v -e nsq-py > requirements.txt
16 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | NSQ for Python
2 | ==============
3 | [![Build Status](https://travis-ci.org/dlecocq/nsq-py.png)](https://travis-ci.org/dlecocq/nsq-py)
4 | 
5 | Clients
6 | =======
7 | 
8 | TCP Clients
9 | -----------
10 | This library provides bindings for the TCP interface of `nsqd`, compatible
11 | with three frameworks:
12 | 
13 | 1. `threading` / `select`, which should be sufficient for most cases, except for
14 |    those using a large number of `nsqd` instances
15 | 2. `gevent`, which simply wraps the above with monkey-patched `threading` and
16 |    `select`, and
17 | 3. `tornado`, for those used to the original official Python client.
18 | 
19 | It also provides the building blocks for extending this client to work with
20 | other frameworks.
21 | 
22 | HTTP Clients
23 | ------------
24 | This also provides bindings for the HTTP interfaces of `nsqlookupd` and `nsqd`
25 | for convenience in `nsq.http`.
26 | 
27 | Primitives
28 | ==========
29 | There are a few primitives you should use when building event-mechanism-specific
30 | bindings:
31 | 
32 | - `connection.Connection` simply wraps a `socket` and knows how to send commands
33 |   and read as many responses as are available on the wire
34 | - `response` has the `Response`, `Error` and `Message` classes, which all know
35 |   how to unpack and pack themselves
36 | - `util` holds some utility methods for packing data and other miscellany
37 | 
38 | Usage
39 | =====
40 | Both the `threading` and `gevent` clients keep the same interface; it's just the
41 | internals that differ. In these cases, the `Reader` might be used like so:
42 | 
43 | ```python
44 | # For the threaded version:
45 | from nsq.reader import Reader
46 | # For the gevent version:
47 | from nsq.gevent import Reader
48 | 
49 | reader = Reader('topic', 'channel', ...)
50 | 
51 | for message in reader:
52 |     print(message)
53 |     message.fin()
54 | ```
55 | 
56 | If you're using `gevent`, you might want to have a pool of coroutines running
57 | code to consume messages. That would look something like this:
58 | 
59 | ```python
60 | from gevent.pool import Pool
61 | pool = Pool(50)
62 | 
63 | def consume_message(message):
64 |     print(message)
65 |     message.fin()
66 | 
67 | pool.map(consume_message, reader)
68 | ```
69 | 
70 | Closing
71 | -------
72 | You really ought to close your reader when you're done with it. Fortunately,
73 | this is quite easily done with `contextlib`:
74 | 
75 | ```python
76 | from contextlib import closing
77 | 
78 | with closing(Reader('topic', 'channel', ...)) as reader:
79 |     for message in reader:
80 |         ....
81 | ```
82 | 
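Publishing
----------
The `Client` class can publish as well. Here is a minimal sketch, assuming an
`nsqd` running locally on its default ports (4150 for TCP, 4151 for HTTP);
topics and bodies are given as `bytes` to match the byte-string commands used
by `nsq.connection`:

```python
from nsq.client import Client

# Assumes a local nsqd listening on the default TCP port.
client = Client(nsqd_tcp_addresses=['localhost:4150'])

# Publish a single message, then a batch; each call waits for the response.
client.pub(b'topic', b'hello')
client.mpub(b'topic', b'first', b'second')

# The HTTP interface in nsq.http works as well, on the default HTTP port.
from nsq.http import nsqd
nsqd.Client('http://localhost:4151/').pub('topic', b'hello, too')
```

When consuming, `Message.handle` can take care of the `fin` / `req` bookkeeping
for you: it requeues the message if the block raises, and finishes it otherwise:

```python
for message in reader:
    with message.handle():
        process(message.body)  # `process` is a hypothetical handler
```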
83 | Benchmarks
84 | ==========
85 | There is a `shovel` task included in `shovel/profile.py` that runs a basic
86 | consumer benchmark against a local `nsqd` instance. The most recent benchmark on
87 | a 2011 MacBook Pro shows the `select`-based `Reader` consuming about 105k
88 | messages / second. With `gevent` enabled, throughput does not appear to be
89 | significantly different.
90 | 
91 | Running Tests
92 | =============
93 | You'll need to install a few dependencies before invoking the tests:
94 | 
95 | ```bash
96 | pip install -r requirements.txt
97 | make test
98 | ```
99 | 
100 | This should run the tests and provide coverage information.
101 | 
102 | Contributing
103 | ============
104 | Help is always appreciated. If you add functionality, please:
105 | 
106 | - include a failing test in one commit
107 | - include the fix for that failing test in a subsequent commit
108 | - don't decrease the code coverage
109 | 
--------------------------------------------------------------------------------
/Vagrantfile:
--------------------------------------------------------------------------------
1 | # Encoding: utf-8
2 | # -*- mode: ruby -*-
3 | # vi: set ft=ruby :
4 | 
5 | ENV['VAGRANT_DEFAULT_PROVIDER'] = 'virtualbox'
6 | 
7 | # http://docs.vagrantup.com/v2/
8 | Vagrant.configure('2') do |config|
9 |   config.vm.box = 'ubuntu/trusty64'
10 |   config.vm.hostname = 'nsq-py'
11 |   config.ssh.forward_agent = true
12 | 
13 |   config.vm.network 'forwarded_port', guest: 4150, host: 4150
14 |   config.vm.network 'forwarded_port', guest: 4151, host: 4151
15 |   config.vm.network 'forwarded_port', guest: 4161, host: 4161
16 |   config.vm.network 'forwarded_port', guest: 4171, host: 4171
17 | 
18 |   config.vm.provider :virtualbox do |vb|
19 |     vb.customize ["modifyvm", :id, "--memory", "1024"]
20 |   end
21 | 
22 |   config.vm.provision :shell, path: 'scripts/vagrant/provision.sh', privileged: false
23 | end
24 | 
--------------------------------------------------------------------------------
/bin/nsqlookupd-statsd:
--------------------------------------------------------------------------------
1 | #!
/usr/bin/env python 2 | '''Send nsqlookupd metrics to statsd.''' 3 | 4 | import argparse 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--nsqlookupd', dest='nsqlookupd', type=str, 7 | default='http://localhost:4161/', help='Full URL to nsqlookupd HTTP interface') 8 | parser.add_argument('--statsd-port', dest='statsd_port', type=int, 9 | default=8125, help='The statsd port') 10 | parser.add_argument('--statsd-host', dest='statsd_host', type=str, 11 | default='localhost', help='The statsd host') 12 | parser.add_argument('--statsd-prefix', dest='statsd_prefix', type=str, 13 | default=None, help='Prefix for statsd metrics') 14 | parser.add_argument('--interval', dest='interval', type=int, 15 | default=60, help='Interval (in seconds) between stats collections') 16 | parser.add_argument('--verbose', dest='verbose', action='store_true', 17 | help='Print all the metrics emittied') 18 | 19 | args = parser.parse_args() 20 | 21 | from nsq.stats import Nsqlookupd 22 | from nsq.checker import PeriodicThread 23 | from statsd import StatsClient 24 | 25 | statsd = StatsClient( 26 | host=args.statsd_host, 27 | port=args.statsd_port, 28 | prefix=args.statsd_prefix) 29 | 30 | nsqlookupd = Nsqlookupd(args.nsqlookupd) 31 | 32 | def report(): 33 | '''Report metrics to statsd.''' 34 | for name, value in nsqlookupd.stats: 35 | if args.verbose: 36 | print '%s => %s' % (name, value) 37 | statsd.gauge(name, value) 38 | 39 | thread = PeriodicThread(args.interval, report) 40 | thread.start() 41 | try: 42 | while True: 43 | thread.join(1) 44 | except KeyboardInterrupt: 45 | thread.stop() 46 | -------------------------------------------------------------------------------- /nsq/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | # Logging, obviously 4 | logger = logging.getLogger('nsq') 5 | if not logger.handlers: 6 | handler = logging.StreamHandler() 7 | handler.setLevel(logging.DEBUG) 8 | formatter = logging.Formatter( 9 | '%(asctime)s [%(levelname)s] %(filename)s@%(lineno)d: %(message)s') 10 | handler.setFormatter(formatter) 11 | logger.addHandler(handler) 12 | logger.setLevel(logging.INFO) 13 | 14 | # Our underlying json implmentation 15 | try: 16 | import simplejson as json 17 | except ImportError: # pragma: no cover 18 | import json 19 | 20 | # The current version 21 | __version__ = '0.1.12' 22 | -------------------------------------------------------------------------------- /nsq/backoff.py: -------------------------------------------------------------------------------- 1 | '''Classes that know about backoffs''' 2 | 3 | import sys 4 | import time 5 | 6 | 7 | class Backoff(object): 8 | '''Backoff base class''' 9 | def sleep(self, attempt): 10 | '''Sleep for the duration of this backoff''' 11 | time.sleep(self.backoff(attempt)) 12 | 13 | def backoff(self, attempt): 14 | '''The amount of time this attempt requires''' 15 | raise NotImplementedError() 16 | 17 | 18 | class Linear(Backoff): 19 | '''Linear backoff''' 20 | def __init__(self, a, b): 21 | Backoff.__init__(self) 22 | self._a = a 23 | self._b = b 24 | 25 | def backoff(self, attempt): 26 | return self._a * attempt + self._b 27 | 28 | 29 | class Constant(Linear): 30 | '''Always the same backoff''' 31 | def __init__(self, constant): 32 | Linear.__init__(self, 0, constant) 33 | 34 | 35 | class Exponential(Backoff): 36 | '''Exponential backoff of the form a * base ^ attempt + c''' 37 | def __init__(self, base, a=1, c=0): 38 | Backoff.__init__(self) 39 | self._base = base 40 | 
self._a = a 41 | self._c = c 42 | 43 | def backoff(self, attempt): 44 | return self._a * (self._base ** attempt) + self._c 45 | 46 | 47 | class Clamped(Backoff): 48 | '''Backoff clamped to min / max bounds''' 49 | def __init__(self, backoff, minimum=0, maximum=sys.maxsize): 50 | Backoff.__init__(self) 51 | self._min = minimum 52 | self._max = maximum 53 | self._backoff = backoff 54 | 55 | def backoff(self, attempt): 56 | return max(self._min, min(self._max, self._backoff.backoff(attempt))) 57 | 58 | 59 | class AttemptCounter(object): 60 | '''Count the number of attempts we've used''' 61 | def __init__(self, backoff): 62 | self.attempts = 0 63 | self._backoff = backoff 64 | self._last_failed = None 65 | 66 | def sleep(self): 67 | '''Sleep for the duration of this backoff''' 68 | time.sleep(self.backoff()) 69 | 70 | def backoff(self): 71 | '''Get the current backoff''' 72 | return self._backoff.backoff(self.attempts) 73 | 74 | def success(self): 75 | '''Update the attempts count correspondingly''' 76 | self._last_failed = None 77 | 78 | def failed(self): 79 | '''Update the attempts count correspondingly''' 80 | self._last_failed = time.time() 81 | self.attempts += 1 82 | 83 | def ready(self): 84 | '''Whether or not enough time has passed since the last failure''' 85 | if self._last_failed: 86 | delta = time.time() - self._last_failed 87 | return delta >= self.backoff() 88 | return True 89 | 90 | 91 | class ResettingAttemptCounter(AttemptCounter): 92 | '''A counter that resets on success''' 93 | def success(self): 94 | AttemptCounter.success(self) 95 | self.attempts = 0 96 | 97 | 98 | class DecrementingAttemptCounter(AttemptCounter): 99 | '''A counter that decrements attempts on success''' 100 | def success(self): 101 | AttemptCounter.success(self) 102 | self.attempts = max(0, self.attempts - 1) 103 | -------------------------------------------------------------------------------- /nsq/checker.py: -------------------------------------------------------------------------------- 1 | '''A class that checks connections''' 2 | 3 | import time 4 | import threading 5 | from . 
import logger 6 | 7 | 8 | class StoppableThread(threading.Thread): 9 | '''A thread that may be stopped''' 10 | def __init__(self): 11 | threading.Thread.__init__(self) 12 | self._event = threading.Event() 13 | 14 | def wait(self, timeout): 15 | '''Wait for the provided time to elapse''' 16 | logger.debug('Waiting for %fs', timeout) 17 | return self._event.wait(timeout) 18 | 19 | def stop(self): 20 | '''Set the stop condition''' 21 | self._event.set() 22 | 23 | 24 | class PeriodicThread(StoppableThread): 25 | '''A thread that periodically invokes a callback every interval seconds''' 26 | def __init__(self, interval, callback, *args, **kwargs): 27 | StoppableThread.__init__(self) 28 | self._interval = interval 29 | self._callback = callback 30 | self._args = args 31 | self._kwargs = kwargs 32 | self._last_checked = None 33 | 34 | def delay(self): 35 | '''How long to wait before the next check''' 36 | if self._last_checked: 37 | return self._interval - (time.time() - self._last_checked) 38 | return self._interval 39 | 40 | def callback(self): 41 | '''Run the callback''' 42 | self._callback(*self._args, **self._kwargs) 43 | self._last_checked = time.time() 44 | 45 | def run(self): 46 | '''Run the callback periodically''' 47 | while not self.wait(self.delay()): 48 | try: 49 | logger.info('Invoking callback %s', self.callback) 50 | self.callback() 51 | except Exception: 52 | logger.exception('Callback failed') 53 | 54 | 55 | class ConnectionChecker(PeriodicThread): 56 | '''A thread that checks the connections on an object''' 57 | def __init__(self, client, interval=60): 58 | PeriodicThread.__init__(self, interval, client.check_connections) 59 | -------------------------------------------------------------------------------- /nsq/client.py: -------------------------------------------------------------------------------- 1 | '''A client for talking to NSQ''' 2 | 3 | from . import connection 4 | from . import logger 5 | from . import exceptions 6 | from .constants import HEARTBEAT 7 | from .response import Response, Error 8 | from .http import nsqlookupd, ClientException 9 | from .checker import ConnectionChecker 10 | 11 | from contextlib import contextmanager 12 | import random 13 | import select 14 | import socket 15 | import time 16 | import threading 17 | import math 18 | 19 | 20 | class Client(object): 21 | '''A client for talking to NSQ over a connection''' 22 | def __init__(self, 23 | lookupd_http_addresses=None, nsqd_tcp_addresses=None, topic=None, 24 | timeout=0.1, reconnection_backoff=None, auth_secret=None, connect_timeout=None, **identify): 25 | # If lookupd_http_addresses are provided, so must a topic be. 
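        # (discovery via nsqlookupd needs a topic name to look up on each instance)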
26 | if lookupd_http_addresses: 27 | assert topic 28 | 29 | # Create clients for each of lookupd instances 30 | lookupd_http_addresses = lookupd_http_addresses or [] 31 | params = {} 32 | if auth_secret: 33 | params['access_token'] = auth_secret 34 | self._lookupd = [ 35 | nsqlookupd.Client(host, **params) for host in lookupd_http_addresses] 36 | self._topic = topic 37 | 38 | # The select timeout 39 | self._timeout = timeout 40 | # Our reconnection backoff policy 41 | self._reconnection_backoff = reconnection_backoff 42 | # The connection timeout to pass to the `Connection` class 43 | self._connect_timeout = connect_timeout 44 | 45 | # The options to send along with identify when establishing connections 46 | self._identify_options = identify 47 | self._auth_secret = auth_secret 48 | # A mapping of (host, port) to our nsqd connection objects 49 | self._connections = {} 50 | 51 | self._nsqd_tcp_addresses = nsqd_tcp_addresses or [] 52 | self.heartbeat_interval = 30 * 1000 53 | self.last_recv_timestamp = time.time() 54 | # A lock for manipulating our connections 55 | self._lock = threading.RLock() 56 | # And lastly, instantiate our connections 57 | self.check_connections() 58 | 59 | def discover(self, topic): 60 | '''Run the discovery mechanism''' 61 | logger.info('Discovering on topic %s', topic) 62 | producers = [] 63 | for lookupd in self._lookupd: 64 | logger.info('Discovering on %s', lookupd) 65 | try: 66 | # Find all the current producers on this instance 67 | for producer in lookupd.lookup(topic)['producers']: 68 | logger.info('Found producer %s on %s', producer, lookupd) 69 | producers.append( 70 | (producer['broadcast_address'], producer['tcp_port'])) 71 | except ClientException: 72 | logger.exception('Failed to query %s', lookupd) 73 | 74 | new = [] 75 | for host, port in producers: 76 | conn = self._connections.get((host, port)) 77 | if not conn: 78 | logger.info('Discovered %s:%s', host, port) 79 | new.append(self.connect(host, port)) 80 | elif not conn.alive(): 81 | logger.info('Reconnecting to %s:%s', host, port) 82 | if conn.connect(): 83 | conn.setblocking(0) 84 | self.reconnected(conn) 85 | else: 86 | logger.debug('Connection to %s:%s still alive', host, port) 87 | 88 | # And return all the new connections 89 | return [conn for conn in new if conn] 90 | 91 | def check_connections(self): 92 | '''Connect to all the appropriate instances''' 93 | logger.info('Checking connections') 94 | if self._lookupd: 95 | self.discover(self._topic) 96 | 97 | # Make sure we're connected to all the prescribed hosts 98 | for hostspec in self._nsqd_tcp_addresses: 99 | logger.debug('Checking nsqd instance %s', hostspec) 100 | host, port = hostspec.split(':') 101 | port = int(port) 102 | conn = self._connections.get((host, port), None) 103 | # If there is no connection to it, we have to try to connect 104 | if not conn: 105 | logger.info('Connecting to %s:%s', host, port) 106 | self.connect(host, port) 107 | elif not conn.alive(): 108 | # If we've connected to it before, but it's no longer alive, 109 | # we'll have to make a decision about when to try to reconnect 110 | # to it, if we need to reconnect to it at all 111 | if conn.ready_to_reconnect(): 112 | logger.info('Reconnecting to %s:%s', host, port) 113 | if conn.connect(): 114 | conn.setblocking(0) 115 | self.reconnected(conn) 116 | else: 117 | logger.debug('Checking freshness') 118 | now = time.time() 119 | time_check = math.ceil(now - self.last_recv_timestamp) 120 | if time_check >= ((self.heartbeat_interval * 2) / 1000.0): 121 | if 
conn.ready_to_reconnect(): 122 | logger.info('Reconnecting to %s:%s', host, port) 123 | if conn.connect(): 124 | conn.setblocking(0) 125 | self.reconnected(conn) 126 | 127 | @contextmanager 128 | def connection_checker(self): 129 | '''Run periodic reconnection checks''' 130 | thread = ConnectionChecker(self) 131 | logger.info('Starting connection-checker thread') 132 | thread.start() 133 | try: 134 | yield thread 135 | finally: 136 | logger.info('Stopping connection-checker') 137 | thread.stop() 138 | logger.info('Joining connection-checker') 139 | thread.join() 140 | 141 | def connect(self, host, port): 142 | '''Connect to the provided host, port''' 143 | conn = connection.Connection(host, port, 144 | reconnection_backoff=self._reconnection_backoff, 145 | auth_secret=self._auth_secret, 146 | timeout=self._connect_timeout, 147 | **self._identify_options) 148 | if conn.alive(): 149 | conn.setblocking(0) 150 | self.add(conn) 151 | return conn 152 | 153 | def reconnected(self, conn): 154 | '''Hook into when a connection has been reestablished''' 155 | 156 | def connections(self): 157 | '''Safely return a list of all our connections''' 158 | with self._lock: 159 | return list(self._connections.values()) 160 | 161 | def added(self, conn): 162 | '''Hook into when a connection has been added''' 163 | 164 | def add(self, connection): 165 | '''Add a connection''' 166 | key = (connection.host, connection.port) 167 | with self._lock: 168 | if key not in self._connections: 169 | self._connections[key] = connection 170 | self.added(connection) 171 | return connection 172 | else: 173 | return None 174 | 175 | def remove(self, connection): 176 | '''Remove a connection''' 177 | key = (connection.host, connection.port) 178 | with self._lock: 179 | found = self._connections.pop(key, None) 180 | try: 181 | self.close_connection(found) 182 | except Exception as exc: 183 | logger.warning('Failed to close %s: %s', connection, exc) 184 | return found 185 | 186 | def close_connection(self, connection): 187 | '''A hook for subclasses when connections are closed''' 188 | connection.close() 189 | 190 | def close(self): 191 | '''Close this client down''' 192 | map(self.remove, self.connections()) 193 | 194 | def read(self): 195 | '''Read from any of the connections that need it''' 196 | # We'll check all living connections 197 | connections = [c for c in self.connections() if c.alive()] 198 | 199 | if not connections: 200 | # If there are no connections, obviously we return no messages, but 201 | # we should wait the duration of the timeout 202 | time.sleep(self._timeout) 203 | return [] 204 | 205 | # Not all connections need to be written to, so we'll only concern 206 | # ourselves with those that require writes 207 | writes = [c for c in connections if c.pending()] 208 | try: 209 | readable, writable, exceptable = select.select( 210 | connections, writes, connections, self._timeout) 211 | except exceptions.ConnectionClosedException: 212 | logger.exception('Tried selecting on closed client') 213 | return [] 214 | except select.error: 215 | logger.exception('Error running select') 216 | return [] 217 | 218 | # If we returned because the timeout interval passed, log it and return 219 | if not (readable or writable or exceptable): 220 | logger.debug('Timed out...') 221 | return [] 222 | 223 | responses = [] 224 | # For each readable socket, we'll try to read some responses 225 | for conn in readable: 226 | try: 227 | for res in conn.read(): 228 | # We'll capture heartbeats and respond to them automatically 229 | if 
(isinstance(res, Response) and res.data == HEARTBEAT): 230 | logger.info('Sending heartbeat to %s', conn) 231 | conn.nop() 232 | logger.debug('Setting last_recv_timestamp') 233 | self.last_recv_timestamp = time.time() 234 | continue 235 | elif isinstance(res, Error): 236 | nonfatal = ( 237 | exceptions.FinFailedException, 238 | exceptions.ReqFailedException, 239 | exceptions.TouchFailedException 240 | ) 241 | if not isinstance(res.exception(), nonfatal): 242 | # If it's not any of the non-fatal exceptions, then 243 | # we have to close this connection 244 | logger.error( 245 | 'Closing %s: %s', conn, res.exception()) 246 | self.close_connection(conn) 247 | responses.append(res) 248 | logger.debug('Setting last_recv_timestamp') 249 | self.last_recv_timestamp = time.time() 250 | except exceptions.NSQException: 251 | logger.exception('Failed to read from %s', conn) 252 | self.close_connection(conn) 253 | except socket.error: 254 | logger.exception('Failed to read from %s', conn) 255 | self.close_connection(conn) 256 | 257 | # For each writable socket, flush some data out 258 | for conn in writable: 259 | try: 260 | conn.flush() 261 | except socket.error: 262 | logger.exception('Failed to flush %s', conn) 263 | self.close_connection(conn) 264 | 265 | # For each connection with an exception, try to close it and remove it 266 | # from our connections 267 | for conn in exceptable: 268 | self.close_connection(conn) 269 | 270 | return responses 271 | 272 | @contextmanager 273 | def random_connection(self): 274 | '''Pick a random living connection''' 275 | # While at the moment there's no need for this to be a context manager 276 | # per se, I would like to use that interface since I anticipate 277 | # adding some wrapping around it at some point. 278 | yield random.choice( 279 | [conn for conn in self.connections() if conn.alive()]) 280 | 281 | def wait_response(self): 282 | '''Wait for a response''' 283 | responses = self.read() 284 | while not responses: 285 | responses = self.read() 286 | return responses 287 | 288 | def wait_write(self, client): 289 | '''Wait until the specific client has written the message''' 290 | while client.pending(): 291 | self.read() 292 | 293 | def pub(self, topic, message): 294 | '''Publish the provided message to the provided topic''' 295 | with self.random_connection() as client: 296 | client.pub(topic, message) 297 | return self.wait_response() 298 | 299 | def mpub(self, topic, *messages): 300 | '''Publish messages to a topic''' 301 | with self.random_connection() as client: 302 | client.mpub(topic, *messages) 303 | return self.wait_response() 304 | -------------------------------------------------------------------------------- /nsq/connection.py: -------------------------------------------------------------------------------- 1 | from . import backoff 2 | from . import constants 3 | from . import logger 4 | from . import util 5 | from . import json 6 | from . 
import __version__ 7 | from .exceptions import ( 8 | UnsupportedException, ConnectionClosedException, ConnectionTimeoutException) 9 | from .sockets import TLSSocket, SnappySocket, DeflateSocket 10 | from .response import Response, Message 11 | 12 | import errno 13 | import socket 14 | import ssl 15 | import struct 16 | import sys 17 | import time 18 | import threading 19 | from collections import deque 20 | import six 21 | 22 | 23 | class Connection(object): 24 | '''A socket-based connection to a NSQ server''' 25 | # Default user agent 26 | USER_AGENT = 'nsq-py/%s' % __version__ 27 | # Errors that would block 28 | WOULD_BLOCK_ERRS = ( 29 | errno.EAGAIN, ssl.SSL_ERROR_WANT_WRITE, ssl.SSL_ERROR_WANT_READ) 30 | 31 | def __init__(self, host, port, timeout=None, reconnection_backoff=None, 32 | auth_secret=None, **identify): 33 | assert isinstance(host, six.string_types), host 34 | assert isinstance(port, int), port 35 | 36 | self._reset() 37 | 38 | # Our host and port 39 | self.host = host 40 | self.port = port 41 | # Whether or not our socket is set to block 42 | self._blocking = 1 43 | self._timeout = timeout if timeout is not None else 1.0 44 | # The options to use when identifying 45 | self._identify_options = dict(identify) 46 | self._identify_options.setdefault('hostname', socket.gethostname()) 47 | self._identify_options.setdefault('client_id', socket.getfqdn().split('.')[0]) 48 | self._identify_options.setdefault('feature_negotiation', True) 49 | self._identify_options.setdefault('user_agent', self.USER_AGENT) 50 | 51 | # In support of auth 52 | self._auth_secret = auth_secret 53 | 54 | # Some settings that may be determined by an identify response 55 | self.max_rdy_count = sys.maxsize 56 | 57 | # Check for any options we don't support 58 | disallowed = [] 59 | if not SnappySocket: # pragma: no branch 60 | disallowed.append('snappy') 61 | if not DeflateSocket: # pragma: no branch 62 | disallowed.extend(['deflate', 'deflate_level']) 63 | if not TLSSocket: # pragma: no branch 64 | disallowed.append('tls_v1') 65 | for key in disallowed: 66 | if self._identify_options.get(key, False): 67 | raise UnsupportedException('Option %s is not supported' % key) 68 | 69 | # Our backoff policy for reconnection. 
The default is to use an 70 | # exponential backoff 8 * (2 ** attempt) clamped to [0, 60] 71 | self._reconnection_backoff = ( 72 | reconnection_backoff or 73 | backoff.Clamped(backoff.Exponential(2, 8), maximum=60)) 74 | self._reconnnection_counter = backoff.ResettingAttemptCounter( 75 | self._reconnection_backoff) 76 | 77 | # A lock around our socket 78 | self._socket_lock = threading.RLock() 79 | 80 | # Establish our connection 81 | self.connect() 82 | 83 | def __str__(self): 84 | state = 'alive' if self.alive() else 'dead' 85 | return '' % ( 86 | self.host, self.port, state, self.fileno()) 87 | 88 | def ready_to_reconnect(self): 89 | '''Returns True if enough time has passed to attempt a reconnection''' 90 | return self._reconnnection_counter.ready() 91 | 92 | def _reset(self): 93 | '''Reset all of our stateful variables''' 94 | self._socket = None 95 | # The pending messages we have to send, and the current buffer we're 96 | # sending 97 | self._pending = deque() 98 | self._out_buffer = b'' 99 | # Our read buffer 100 | self._buffer = b'' 101 | # The identify response we last received from the server 102 | self._identify_response = {} 103 | # Our ready state 104 | self.last_ready_sent = 0 105 | self.ready = 0 106 | 107 | def connect(self, force=False): 108 | '''Establish a connection''' 109 | # Don't re-establish existing connections 110 | if not force and self.alive(): 111 | return True 112 | 113 | self._reset() 114 | 115 | # Otherwise, try to connect 116 | with self._socket_lock: 117 | try: 118 | logger.info('Creating socket...') 119 | self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 120 | self._socket.settimeout(self._timeout) 121 | logger.info('Connecting to %s, %s', self.host, self.port) 122 | self._socket.connect((self.host, self.port)) 123 | # Set our socket's blocking state to whatever ours is 124 | self._socket.setblocking(self._blocking) 125 | # Safely write our magic 126 | self._pending.append(constants.MAGIC_V2) 127 | while self.pending(): 128 | self.flush() 129 | # And send our identify command 130 | self.identify(self._identify_options) 131 | while self.pending(): 132 | self.flush() 133 | self._reconnnection_counter.success() 134 | # Wait until we've gotten a response to IDENTIFY, try to read 135 | # one. Also, only spend up to the provided timeout waiting to 136 | # establish the connection. 137 | limit = time.time() + self._timeout 138 | responses = self._read(1) 139 | while (not responses) and (time.time() < limit): 140 | responses = self._read(1) 141 | if not responses: 142 | raise ConnectionTimeoutException( 143 | 'Read identify response timed out (%ss)' % self._timeout) 144 | self.identified(responses[0]) 145 | return True 146 | except: 147 | logger.exception('Failed to connect') 148 | if self._socket: 149 | self._socket.close() 150 | self._reconnnection_counter.failed() 151 | self._reset() 152 | return False 153 | 154 | def close(self): 155 | '''Close our connection''' 156 | # Flush any unsent message 157 | try: 158 | while self.pending(): 159 | self.flush() 160 | except socket.error: 161 | pass 162 | with self._socket_lock: 163 | try: 164 | if self._socket: 165 | self._socket.close() 166 | finally: 167 | self._reset() 168 | 169 | def socket(self, blocking=True): 170 | '''Blockingly yield the socket''' 171 | # If the socket is available, then yield it. 
Otherwise, yield nothing 172 | if self._socket_lock.acquire(blocking): 173 | try: 174 | yield self._socket 175 | finally: 176 | self._socket_lock.release() 177 | 178 | def identified(self, res): 179 | '''Handle a response to our 'identify' command. Returns response''' 180 | # If they support it, they should give us a JSON blob which we should 181 | # inspect. 182 | try: 183 | res.data = json.loads(res.data) 184 | self._identify_response = res.data 185 | logger.info('Got identify response: %s', res.data) 186 | except: 187 | logger.warning('Server does not support feature negotiation') 188 | self._identify_response = {} 189 | 190 | # Save our max ready count unless it's not provided 191 | self.max_rdy_count = self._identify_response.get( 192 | 'max_rdy_count', self.max_rdy_count) 193 | if self._identify_options.get('tls_v1', False): 194 | if not self._identify_response.get('tls_v1', False): 195 | raise UnsupportedException( 196 | 'NSQd instance does not support TLS') 197 | else: 198 | self._socket = TLSSocket.wrap_socket(self._socket) 199 | 200 | # Now is the appropriate time to send auth 201 | if self._identify_response.get('auth_required', False): 202 | if not self._auth_secret: 203 | raise UnsupportedException( 204 | 'Auth required but not provided') 205 | else: 206 | self.auth(self._auth_secret) 207 | # If we're not talking over TLS, warn the user 208 | if not self._identify_response.get('tls_v1', False): 209 | logger.warning('Using AUTH without TLS') 210 | elif self._auth_secret: 211 | logger.warning('Authentication secret provided but not required') 212 | return res 213 | 214 | def alive(self): 215 | '''Returns True if this connection is alive''' 216 | return bool(self._socket) 217 | 218 | def setblocking(self, blocking): 219 | '''Set whether or not this message is blocking''' 220 | for sock in self.socket(): 221 | sock.setblocking(blocking) 222 | self._blocking = blocking 223 | 224 | def fileno(self): 225 | '''Returns the socket's fileno. This allows us to select on this''' 226 | for sock in self.socket(): 227 | if sock: 228 | return sock.fileno() 229 | raise ConnectionClosedException() 230 | 231 | def pending(self): 232 | '''All of the messages waiting to be sent''' 233 | return self._pending 234 | 235 | def flush(self): 236 | '''Flush some of the waiting messages, returns count written''' 237 | # When profiling, we found that while there was some efficiency to be 238 | # gained elsewhere, the big performance hit is sending lots of small 239 | # messages at a time. In particular, consumers send many 'FIN' messages 240 | # which are very small indeed and the cost of dispatching so many system 241 | # calls is very high. Instead, we prefer to glom together many messages 242 | # into a single string to send at once. 243 | total = 0 244 | for sock in self.socket(blocking=False): 245 | # If there's nothing left in the out buffer, take whatever's in the 246 | # pending queue. 247 | # 248 | # When using SSL, if the socket throws 'SSL_WANT_WRITE', then the 249 | # subsequent send requests have to send the same buffer. 
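            # That partially-sent buffer is kept in self._out_buffer until the send succeeds.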
250 | pending = self._pending 251 | data = self._out_buffer or b''.join( 252 | pending.popleft() for _ in range(len(pending))) 253 | try: 254 | # Try to send as much of the first message as possible 255 | total = sock.send(data) 256 | except socket.error as exc: 257 | # Catch (errno, message)-type socket.errors 258 | if exc.args[0] not in self.WOULD_BLOCK_ERRS: 259 | raise 260 | self._out_buffer = data 261 | else: 262 | self._out_buffer = None 263 | finally: 264 | if total < len(data): 265 | # Save the rest of the message that could not be sent 266 | self._pending.appendleft(data[total:]) 267 | return total 268 | 269 | def send(self, command, message=None): 270 | '''Send a command over the socket with length endcoded''' 271 | if message: 272 | joined = command + constants.NL + util.pack(message) 273 | else: 274 | joined = command + constants.NL 275 | if self._blocking: 276 | for sock in self.socket(): 277 | sock.sendall(joined) 278 | else: 279 | self._pending.append(joined) 280 | 281 | def identify(self, data): 282 | '''Send an identification message''' 283 | return self.send(constants.IDENTIFY, json.dumps(data).encode('UTF-8')) 284 | 285 | def auth(self, secret): 286 | '''Send an auth secret''' 287 | return self.send(constants.AUTH, secret) 288 | 289 | def sub(self, topic, channel): 290 | '''Subscribe to a topic/channel''' 291 | return self.send(b' '.join((constants.SUB, topic, channel))) 292 | 293 | def pub(self, topic, message): 294 | '''Publish to a topic''' 295 | return self.send(b' '.join((constants.PUB, topic)), message) 296 | 297 | def mpub(self, topic, *messages): 298 | '''Publish multiple messages to a topic''' 299 | return self.send(constants.MPUB + b' ' + topic, messages) 300 | 301 | def rdy(self, count): 302 | '''Indicate that you're ready to receive''' 303 | self.ready = count 304 | self.last_ready_sent = count 305 | return self.send(constants.RDY + b' ' + six.text_type(count).encode()) 306 | 307 | def fin(self, message_id): 308 | '''Indicate that you've finished a message ID''' 309 | return self.send(constants.FIN + b' ' + message_id) 310 | 311 | def req(self, message_id, timeout): 312 | '''Re-queue a message''' 313 | return self.send(constants.REQ + b' ' + message_id + b' ' + six.text_type(timeout).encode()) 314 | 315 | def touch(self, message_id): 316 | '''Reset the timeout for an in-flight message''' 317 | return self.send(constants.TOUCH + b' ' + message_id) 318 | 319 | def cls(self): 320 | '''Close the connection cleanly''' 321 | return self.send(constants.CLS) 322 | 323 | def nop(self): 324 | '''Send a no-op''' 325 | return self.send(constants.NOP) 326 | 327 | # These are the various incarnations of our read method. In some instances, 328 | # we want to return responses in the typical way. But while establishing 329 | # connections or negotiating a TLS connection, we need to do different 330 | # things 331 | def _read(self, limit=1000): 332 | '''Return all the responses read''' 333 | # It's important to know that it may return no responses or multiple 334 | # responses. It depends on how the buffering works out. First, read from 335 | # the socket 336 | for sock in self.socket(): 337 | if sock is None: 338 | # Race condition. Connection has been closed. 
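                # (close() resets self._socket under the lock, so it can be None by the time we read)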
339 | return [] 340 | try: 341 | packet = sock.recv(4096) 342 | except socket.timeout: 343 | # If the socket times out, return nothing 344 | return [] 345 | except socket.error as exc: 346 | # Catch (errno, message)-type socket.errors 347 | if exc.args[0] in self.WOULD_BLOCK_ERRS: 348 | return [] 349 | else: 350 | raise 351 | 352 | # Append our newly-read data to our buffer 353 | self._buffer += packet 354 | 355 | responses = [] 356 | total = 0 357 | buf = self._buffer 358 | remaining = len(buf) 359 | while limit and (remaining >= 4): 360 | size = struct.unpack('>l', buf[total:(total + 4)])[0] 361 | # Now check to see if there's enough left in the buffer to read 362 | # the message. 363 | if (remaining - 4) >= size: 364 | responses.append(Response.from_raw( 365 | self, buf[(total + 4):(total + size + 4)])) 366 | total += (size + 4) 367 | remaining -= (size + 4) 368 | limit -= 1 369 | else: 370 | break 371 | self._buffer = self._buffer[total:] 372 | return responses 373 | 374 | def read(self): 375 | '''Responses from an established socket''' 376 | responses = self._read() 377 | # Determine the number of messages in here and decrement our ready 378 | # count appropriately 379 | self.ready -= sum( 380 | map(int, (r.frame_type == Message.FRAME_TYPE for r in responses))) 381 | return responses 382 | -------------------------------------------------------------------------------- /nsq/constants.py: -------------------------------------------------------------------------------- 1 | '''Contstants for NSQ''' 2 | 3 | # NSQ Magic 4 | MAGIC_V2 = b' V2' 5 | 6 | # The newline character 7 | NL = b'\n' 8 | 9 | # Response 10 | FRAME_TYPE_RESPONSE = 0 11 | FRAME_TYPE_ERROR = 1 12 | FRAME_TYPE_MESSAGE = 2 13 | 14 | # Command names 15 | IDENTIFY = b'IDENTIFY' 16 | AUTH = b'AUTH' 17 | SUB = b'SUB' 18 | PUB = b'PUB' 19 | MPUB = b'MPUB' 20 | RDY = b'RDY' 21 | FIN = b'FIN' 22 | REQ = b'REQ' 23 | TOUCH = b'TOUCH' 24 | CLS = b'CLS' 25 | NOP = b'NOP' 26 | 27 | # Heartbeat text 28 | HEARTBEAT = b'_heartbeat_' 29 | -------------------------------------------------------------------------------- /nsq/exceptions.py: -------------------------------------------------------------------------------- 1 | '''Exception classes''' 2 | 3 | 4 | class NSQException(Exception): 5 | '''Base class for all exceptions in this library''' 6 | 7 | 8 | class ConnectionTimeoutException(NSQException): 9 | '''Connection instantiation timed out''' 10 | 11 | 12 | class ConnectionClosedException(NSQException): 13 | '''Trying to use a closed connection as if it's alive''' 14 | 15 | 16 | class UnsupportedException(NSQException): 17 | '''When a requested feature cannot be used''' 18 | 19 | 20 | class TimeoutException(NSQException): 21 | '''Exception for failing a timeout''' 22 | 23 | 24 | class InvalidException(NSQException): 25 | '''Exception for E_INVALID''' 26 | name = b'E_INVALID' 27 | 28 | 29 | class BadBodyException(NSQException): 30 | '''Exception for E_BAD_BODY''' 31 | name = b'E_BAD_BODY' 32 | 33 | 34 | class BadTopicException(NSQException): 35 | '''Exception for E_BAD_TOPIC''' 36 | name = b'E_BAD_TOPIC' 37 | 38 | 39 | class BadChannelException(NSQException): 40 | '''Exception for E_BAD_CHANNEL''' 41 | name = b'E_BAD_CHANNEL' 42 | 43 | 44 | class BadMessageException(NSQException): 45 | '''Exception for E_BAD_MESSAGE''' 46 | name = b'E_BAD_MESSAGE' 47 | 48 | 49 | class PubFailedException(NSQException): 50 | '''Exception for E_PUB_FAILED''' 51 | name = b'E_PUB_FAILED' 52 | 53 | 54 | class MpubFailedException(NSQException): 55 | '''Exception 
for E_MPUB_FAILED''' 56 | name = b'E_MPUB_FAILED' 57 | 58 | 59 | class FinFailedException(NSQException): 60 | '''Exception for E_FIN_FAILED''' 61 | name = b'E_FIN_FAILED' 62 | 63 | 64 | class ReqFailedException(NSQException): 65 | '''Exception for E_REQ_FAILED''' 66 | name = b'E_REQ_FAILED' 67 | 68 | 69 | class TouchFailedException(NSQException): 70 | '''Exception for E_TOUCH_FAILED''' 71 | name = b'E_TOUCH_FAILED' 72 | -------------------------------------------------------------------------------- /nsq/gevent.py: -------------------------------------------------------------------------------- 1 | '''A Gevent-based client''' 2 | 3 | from __future__ import absolute_import 4 | import gevent.monkey 5 | gevent.monkey.patch_all() 6 | 7 | # The gevent client is actually the same as the synchronous client, just with 8 | # socket, thread, ssl, select and time patched. By importing this here, it makes 9 | # a gevent-compatible client available from nsq.gevent.Client 10 | from .client import Client 11 | from .reader import Reader 12 | -------------------------------------------------------------------------------- /nsq/http/__init__.py: -------------------------------------------------------------------------------- 1 | '''Our clients for interacting with various clients''' 2 | 3 | from decorator import decorator 4 | import requests 5 | import six 6 | from six.moves.urllib_parse import urlsplit, urlunsplit, urljoin 7 | 8 | from .. import json, logger 9 | from ..exceptions import NSQException 10 | 11 | 12 | @decorator 13 | def wrap(function, *args, **kwargs): 14 | '''Wrap a function that returns a request with some exception handling''' 15 | try: 16 | req = function(*args, **kwargs) 17 | logger.debug('Got %s: %s', req.status_code, req.content) 18 | if req.status_code == 200: 19 | return req 20 | else: 21 | raise ClientException(req.reason, req.content) 22 | except ClientException: 23 | raise 24 | except Exception as exc: 25 | raise ClientException(exc) 26 | 27 | 28 | @decorator 29 | def json_wrap(function, *args, **kwargs): 30 | '''Return the json content of a function that returns a request''' 31 | try: 32 | # Some responses have data = None, but they generally signal a 33 | # successful API call as well. 
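        # (hence the `or True` below, so an empty 'data' payload still reads as success)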
34 | response = json.loads(function(*args, **kwargs).content) 35 | if 'data' in response: 36 | return response['data'] or True 37 | else: 38 | return response 39 | except Exception as exc: 40 | raise ClientException(exc) 41 | 42 | 43 | @decorator 44 | def ok_check(function, *args, **kwargs): 45 | '''Ensure that the response body is OK''' 46 | req = function(*args, **kwargs) 47 | if req.content.lower() != b'ok': 48 | raise ClientException(req.content) 49 | return req.content 50 | 51 | 52 | class ClientException(NSQException): 53 | '''An exception class for all client errors''' 54 | 55 | 56 | def _relative(split_result, path): 57 | new_split = split_result._replace(path=urljoin(split_result.path, path)) 58 | return urlunsplit(new_split) 59 | 60 | 61 | class BaseClient(object): 62 | '''Base client class''' 63 | def __init__(self, target, **params): 64 | if isinstance(target, six.string_types): 65 | self._host = urlsplit(target) 66 | elif isinstance(target, (tuple, list)): 67 | self._host = urlsplit('http://%s:%s/' % target) 68 | else: 69 | raise TypeError('Host must be a string or tuple') 70 | self._params = params 71 | 72 | @wrap 73 | def get(self, path, *args, **kwargs): 74 | '''GET the provided endpoint''' 75 | target = _relative(self._host, path) 76 | params = kwargs.get('params', {}) 77 | params.update(self._params) 78 | kwargs['params'] = params 79 | logger.debug('GET %s with %s, %s', target, args, kwargs) 80 | return requests.get(target, *args, **kwargs) 81 | 82 | @wrap 83 | def post(self, path, *args, **kwargs): 84 | '''POST to the provided endpoint''' 85 | target = _relative(self._host, path) 86 | params = kwargs.get('params', {}) 87 | params.update(self._params) 88 | kwargs['params'] = params 89 | logger.debug('POST %s with %s, %s', target, args, kwargs) 90 | return requests.post(target, *args, **kwargs) 91 | -------------------------------------------------------------------------------- /nsq/http/nsqd.py: -------------------------------------------------------------------------------- 1 | '''A class for interacting with a nsqd instance over http''' 2 | 3 | from . import BaseClient, json_wrap, ok_check, ClientException 4 | from ..util import pack 5 | 6 | 7 | class Client(BaseClient): 8 | @ok_check 9 | def ping(self): 10 | '''Ping the client''' 11 | return self.get('ping') 12 | 13 | @json_wrap 14 | def info(self): 15 | '''Get information about the client''' 16 | return self.get('info') 17 | 18 | @ok_check 19 | def pub(self, topic, message): 20 | '''Publish a message to a topic''' 21 | return self.post('pub', params={'topic': topic}, data=message) 22 | 23 | @ok_check 24 | def mpub(self, topic, messages, binary=True): 25 | '''Send multiple messages to a topic. 
Optionally pack the messages''' 26 | if binary: 27 | # Pack and ship the data 28 | return self.post('mpub', data=pack(messages)[4:], 29 | params={'topic': topic, 'binary': True}) 30 | elif any(b'\n' in m for m in messages): 31 | # If any of the messages has a newline, then you must use the binary 32 | # calling format 33 | raise ClientException( 34 | 'Use `binary` flag in mpub for messages with newlines') 35 | else: 36 | return self.post( 37 | '/mpub', params={'topic': topic}, data=b'\n'.join(messages)) 38 | 39 | @json_wrap 40 | def create_topic(self, topic): 41 | '''Create the provided topic''' 42 | return self.get('create_topic', params={'topic': topic}) 43 | 44 | @json_wrap 45 | def empty_topic(self, topic): 46 | '''Empty the provided topic''' 47 | return self.get('empty_topic', params={'topic': topic}) 48 | 49 | @json_wrap 50 | def delete_topic(self, topic): 51 | '''Delete the provided topic''' 52 | return self.get('delete_topic', params={'topic': topic}) 53 | 54 | @json_wrap 55 | def pause_topic(self, topic): 56 | '''Pause the provided topic''' 57 | return self.get('pause_topic', params={'topic': topic}) 58 | 59 | @json_wrap 60 | def unpause_topic(self, topic): 61 | '''Unpause the provided topic''' 62 | return self.get('unpause_topic', params={'topic': topic}) 63 | 64 | @json_wrap 65 | def create_channel(self, topic, channel): 66 | '''Create the channel in the provided topic''' 67 | return self.get( 68 | '/create_channel', params={'topic': topic, 'channel': channel}) 69 | 70 | @json_wrap 71 | def empty_channel(self, topic, channel): 72 | '''Empty the channel in the provided topic''' 73 | return self.get( 74 | '/empty_channel', params={'topic': topic, 'channel': channel}) 75 | 76 | @json_wrap 77 | def delete_channel(self, topic, channel): 78 | '''Delete the channel in the provided topic''' 79 | return self.get( 80 | '/delete_channel', params={'topic': topic, 'channel': channel}) 81 | 82 | @json_wrap 83 | def pause_channel(self, topic, channel): 84 | '''Pause the channel in the provided topic''' 85 | return self.get( 86 | '/pause_channel', params={'topic': topic, 'channel': channel}) 87 | 88 | @json_wrap 89 | def unpause_channel(self, topic, channel): 90 | '''Unpause the channel in the provided topic''' 91 | return self.get( 92 | '/unpause_channel', params={'topic': topic, 'channel': channel}) 93 | 94 | @json_wrap 95 | def stats(self): 96 | '''Get stats about the server''' 97 | return self.get('stats', params={'format': 'json'}) 98 | 99 | def clean_stats(self): 100 | '''Stats with topics and channels keyed on topic and channel names''' 101 | stats = self.stats() 102 | if 'topics' in stats: # pragma: no branch 103 | topics = stats['topics'] 104 | topics = dict((t.pop('topic_name'), t) for t in topics) 105 | for topic, data in topics.items(): 106 | if 'channels' in data: # pragma: no branch 107 | channels = data['channels'] 108 | channels = dict( 109 | (c.pop('channel_name'), c) for c in channels) 110 | data['channels'] = channels 111 | stats['topics'] = topics 112 | return stats 113 | -------------------------------------------------------------------------------- /nsq/http/nsqlookupd.py: -------------------------------------------------------------------------------- 1 | '''A class for interacting with a nsqlookupd instance over http''' 2 | 3 | from . 
import BaseClient, json_wrap, ok_check 4 | 5 | 6 | class Client(BaseClient): 7 | '''A client for talking to nsqlookupd over http''' 8 | @ok_check 9 | def ping(self): 10 | '''Ping the client''' 11 | return self.get('ping') 12 | 13 | @json_wrap 14 | def info(self): 15 | '''Get info about this instance''' 16 | return self.get('info') 17 | 18 | @json_wrap 19 | def lookup(self, topic): 20 | '''Look up which hosts serve a particular topic''' 21 | return self.get('lookup', params={'topic': topic}) 22 | 23 | @json_wrap 24 | def topics(self): 25 | '''Get a list of topics''' 26 | return self.get('topics') 27 | 28 | @json_wrap 29 | def channels(self, topic): 30 | '''Get a list of channels for a given topic''' 31 | return self.get('channels', params={'topic': topic}) 32 | 33 | @json_wrap 34 | def nodes(self): 35 | '''Get information about nodes''' 36 | return self.get('nodes') 37 | 38 | @json_wrap 39 | def delete_topic(self, topic): 40 | '''Delete a topic''' 41 | return self.get('delete_topic', params={'topic': topic}) 42 | 43 | @json_wrap 44 | def delete_channel(self, topic, channel): 45 | '''Delete a channel in the provided topic''' 46 | return self.get('delete_channel', 47 | params={'topic': topic, 'channel': channel}) 48 | 49 | @json_wrap 50 | def tombstone_topic_producer(self, topic, node): 51 | '''It's not clear what this endpoint does''' 52 | return self.get('tombstone_topic_producer', 53 | params={'topic': topic, 'node': node}) 54 | 55 | @json_wrap 56 | def create_topic(self, topic): 57 | '''Create a topic''' 58 | return self.get('create_topic', params={'topic': topic}) 59 | 60 | @json_wrap 61 | def create_channel(self, topic, channel): 62 | '''Create a channel in the provided topic''' 63 | return self.get('create_channel', 64 | params={'topic': topic, 'channel': channel}) 65 | 66 | @json_wrap 67 | def debug(self): 68 | '''Get debugging information''' 69 | return self.get('debug') 70 | -------------------------------------------------------------------------------- /nsq/reader.py: -------------------------------------------------------------------------------- 1 | from .client import Client 2 | from .response import Message 3 | from .util import distribute 4 | from . 
import logger 5 | 6 | 7 | class Reader(Client): 8 | '''A client meant exclusively for reading''' 9 | def __init__(self, topic, channel, lookupd_http_addresses=None, 10 | nsqd_tcp_addresses=None, max_in_flight=200, **identify): 11 | self._channel = channel 12 | self._max_in_flight = max_in_flight 13 | Client.__init__( 14 | self, lookupd_http_addresses, nsqd_tcp_addresses, topic, **identify) 15 | 16 | def reconnected(self, conn): 17 | '''Subscribe connection and manipulate its RDY state''' 18 | conn.sub(self._topic, self._channel) 19 | conn.rdy(1) 20 | 21 | def added(self, conn): 22 | '''Subscribe connection and manipulate its RDY state''' 23 | if conn.alive(): 24 | self.reconnected(conn) 25 | 26 | def distribute_ready(self): 27 | '''Distribute the ready state across all of the connections''' 28 | connections = [c for c in self.connections() if c.alive()] 29 | if len(connections) > self._max_in_flight: 30 | raise NotImplementedError( 31 | 'Max in flight must be greater than number of connections') 32 | else: 33 | # Distribute the ready count evenly among the connections 34 | for count, conn in distribute(self._max_in_flight, connections): 35 | # We cannot exceed the maximum RDY count for a connection 36 | if count > conn.max_rdy_count: 37 | logger.info( 38 | 'Using max_rdy_count (%i) instead of %i for %s RDY', 39 | conn.max_rdy_count, count, conn) 40 | count = conn.max_rdy_count 41 | logger.info('Sending RDY %i to %s', count, conn) 42 | conn.rdy(count) 43 | 44 | def needs_distribute_ready(self): 45 | '''Determine whether or not we need to redistribute the ready state''' 46 | # Try to pre-empty starvation by comparing current RDY against 47 | # the last value sent. 48 | alive = [c for c in self.connections() if c.alive()] 49 | if any(c.ready <= (c.last_ready_sent * 0.25) for c in alive): 50 | return True 51 | 52 | def close_connection(self, connection): 53 | '''A hook into when connections are closed''' 54 | Client.close_connection(self, connection) 55 | self.distribute_ready() 56 | 57 | def read(self): 58 | '''Read some number of messages''' 59 | found = Client.read(self) 60 | 61 | # Redistribute our ready state if necessary 62 | if self.needs_distribute_ready(): 63 | self.distribute_ready() 64 | 65 | # Finally, return all the results we've read 66 | return found 67 | 68 | def __iter__(self): 69 | with self.connection_checker(): 70 | while True: 71 | for message in self.read(): 72 | # A reader's only interested in actual messages 73 | if isinstance(message, Message): 74 | # We'll probably add a hook in here to track the RDY 75 | # states of our connections 76 | yield message 77 | -------------------------------------------------------------------------------- /nsq/response.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import struct 3 | 4 | import six 5 | from .constants import FRAME_TYPE_RESPONSE, FRAME_TYPE_MESSAGE, FRAME_TYPE_ERROR 6 | from . 
import exceptions 7 | 8 | from contextlib import contextmanager 9 | import socket 10 | import sys 11 | 12 | 13 | class Response(object): 14 | '''A response from NSQ''' 15 | FRAME_TYPE = FRAME_TYPE_RESPONSE 16 | 17 | __slots__ = ('connection', 'frame_type', 'data') 18 | 19 | @staticmethod 20 | def from_raw(conn, raw): 21 | '''Return a new response from a raw buffer''' 22 | frame_type = struct.unpack('>l', raw[0:4])[0] 23 | message = raw[4:] 24 | if frame_type == FRAME_TYPE_MESSAGE: 25 | return Message(conn, frame_type, message) 26 | elif frame_type == FRAME_TYPE_RESPONSE: 27 | return Response(conn, frame_type, message) 28 | elif frame_type == FRAME_TYPE_ERROR: 29 | return Error(conn, frame_type, message) 30 | else: 31 | raise TypeError('Unknown frame type: %s' % frame_type) 32 | 33 | @classmethod 34 | def pack(cls, data): 35 | '''Pack the provided data into a Response''' 36 | return struct.pack('>ll', len(data) + 4, cls.FRAME_TYPE) + data 37 | 38 | def __init__(self, conn, frame_type, data): 39 | self.connection = conn 40 | self.data = data 41 | self.frame_type = frame_type 42 | 43 | def __str__(self): 44 | return '%s - %s' % (self.__class__.__name__, self.data) 45 | 46 | def __eq__(self, other): 47 | return ( 48 | (self.frame_type == other.frame_type) and 49 | (self.connection == other.connection) and 50 | (self.data == other.data)) 51 | 52 | 53 | class Message(Response): 54 | '''A message''' 55 | FRAME_TYPE = FRAME_TYPE_MESSAGE 56 | 57 | format = '>qH16s' 58 | size = struct.calcsize(format) 59 | 60 | __slots__ = ('timestamp', 'attempts', 'id', 'body', 'processed') 61 | 62 | @classmethod 63 | def pack(cls, timestamp, attempts, _id, data): 64 | return struct.pack( 65 | '>llqH16s', 66 | len(data) + cls.size + 4, 67 | cls.FRAME_TYPE, 68 | timestamp, 69 | attempts, 70 | _id) + data 71 | 72 | def __init__(self, conn, frame_type, data): 73 | Response.__init__(self, conn, frame_type, data) 74 | self.timestamp, self.attempts, self.id = struct.unpack( 75 | self.format, data[:self.size]) 76 | self.body = data[self.size:] 77 | self.processed = False 78 | 79 | def __str__(self): 80 | return '%s - %i %i %s %s' % ( 81 | self.__class__.__name__, 82 | self.timestamp, 83 | self.attempts, 84 | self.id, 85 | self.body) 86 | 87 | def fin(self): 88 | '''Indicate that this message is finished processing''' 89 | self.connection.fin(self.id) 90 | self.processed = True 91 | 92 | def req(self, timeout): 93 | '''Re-queue a message''' 94 | self.connection.req(self.id, timeout) 95 | self.processed = True 96 | 97 | def touch(self): 98 | '''Reset the timeout for an in-flight message''' 99 | self.connection.touch(self.id) 100 | 101 | def delay(self): 102 | '''How long to delay its requeueing''' 103 | return 60 104 | 105 | @contextmanager 106 | def handle(self): 107 | '''Make sure this message gets either 'fin' or 'req'd''' 108 | try: 109 | yield self 110 | except: 111 | # Requeue the message and raise the original exception 112 | typ, value, trace = sys.exc_info() 113 | if not self.processed: 114 | try: 115 | self.req(self.delay()) 116 | except socket.error: 117 | self.connection.close() 118 | six.reraise(typ, value, trace) 119 | else: 120 | if not self.processed: 121 | try: 122 | self.fin() 123 | except socket.error: 124 | self.connection.close() 125 | 126 | 127 | class Error(Response): 128 | '''An error''' 129 | FRAME_TYPE = FRAME_TYPE_ERROR 130 | 131 | # A mapping of the response string to the appropriate exception 132 | mapping = {} 133 | 134 | @classmethod 135 | def find(cls, name): 136 | '''Find the exception class 
by name''' 137 | if not cls.mapping: # pragma: no branch 138 | for _, obj in inspect.getmembers(exceptions): 139 | if inspect.isclass(obj): 140 | if issubclass(obj, exceptions.NSQException): # pragma: no branch 141 | if hasattr(obj, 'name'): 142 | cls.mapping[obj.name] = obj 143 | klass = cls.mapping.get(name) 144 | if klass is None: 145 | raise TypeError('No matching exception for %s' % name) 146 | return klass 147 | 148 | def exception(self): 149 | '''Return an instance of the corresponding exception''' 150 | code, _, message = self.data.partition(b' ') 151 | return self.find(code)(message) 152 | -------------------------------------------------------------------------------- /nsq/sockets/__init__.py: -------------------------------------------------------------------------------- 1 | '''Sockets that wrap different connection types''' 2 | 3 | # Not all platforms support all types of sockets provided here. For those that 4 | # are not available, the corresponding socket wrapper is imported as None. 5 | 6 | from .. import logger 7 | 8 | # Snappy support 9 | try: 10 | from .snappy import SnappySocket 11 | except ImportError: # pragma: no cover 12 | logger.debug('Snappy compression not supported') 13 | SnappySocket = None 14 | 15 | 16 | # Deflate support 17 | try: 18 | from .deflate import DeflateSocket 19 | except ImportError: # pragma: no cover 20 | logger.debug('Deflate compression not supported') 21 | DeflateSocket = None 22 | 23 | 24 | # The TLS socket 25 | try: 26 | from .tls import TLSSocket 27 | except ImportError: # pragma: no cover 28 | logger.warning('TLS not supported') 29 | TLSSocket = None 30 | -------------------------------------------------------------------------------- /nsq/sockets/base.py: -------------------------------------------------------------------------------- 1 | '''Base socket wrapper''' 2 | 3 | 4 | class SocketWrapper(object): 5 | '''Wraps a socket in another layer''' 6 | # Methods for which the default should be to simply pass through to the 7 | # underlying socket 8 | METHODS = ( 9 | 'accept', 'bind', 'close', 'connect', 'fileno', 'getpeername', 10 | 'getsockname', 'getsockopt', 'setsockopt', 'gettimeout', 'settimeout', 11 | 'setblocking', 'listen', 'makefile', 'shutdown' 12 | ) 13 | 14 | @classmethod 15 | def wrap_socket(cls, socket, **options): 16 | '''Returns a socket-like object that transparently wraps the provided socket''' 17 | return cls(socket, **options) 18 | 19 | def __init__(self, socket): 20 | self._socket = socket 21 | for method in self.METHODS: 22 | # Check to see if this class overrides this method, and if not, then 23 | # we should have it simply map through to the underlying socket 24 | if not hasattr(self, method): 25 | setattr(self, method, getattr(self._socket, method)) 26 | 27 | def send(self, data, flags=0): 28 | '''Same as socket.send''' 29 | raise NotImplementedError() 30 | 31 | def sendall(self, data, flags=0): 32 | '''Same as socket.sendall''' 33 | count = len(data) 34 | while count: 35 | sent = self.send(data, flags) 36 | # This could probably be a buffer object 37 | data = data[sent:] 38 | count -= sent 39 | 40 | def recv(self, nbytes, flags=0): 41 | '''Same as socket.recv''' 42 | raise NotImplementedError() 43 | 44 | def recv_into(self, buff, nbytes, flags=0): 45 | '''Same as socket.recv_into''' 46 | raise NotImplementedError('Wrapped sockets do not implement recv_into') 47 | 
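Because each optional wrapper in `nsq.sockets` is imported as `None` when its dependency is missing, calling code is expected to check availability before wrapping a connection. A minimal sketch of that pattern, assuming a plain TCP socket to a local nsqd (the address is illustrative, not taken from the library):

```python
import socket

from nsq.sockets import TLSSocket

# Plain TCP connection; 4150 is the nsqd TCP port used elsewhere in this repo
sock = socket.create_connection(('localhost', 4150))

# The wrapper is either a class or None, depending on platform support
if TLSSocket is None:
    raise RuntimeError('TLS is not available on this platform')
sock = TLSSocket.wrap_socket(sock)
```

A `SocketWrapper` subclass only has to implement `send` and `recv`; the constructor patches every name listed in `METHODS` straight onto the underlying socket, which is why the wrapped object can be used wherever a plain socket is expected.
-------------------------------------------------------------------------------- /nsq/sockets/deflate.py: 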
-------------------------------------------------------------------------------- 1 | '''Wraps a socket in Deflate compression''' 2 | 3 | raise ImportError('Deflate not supported') 4 | -------------------------------------------------------------------------------- /nsq/sockets/snappy.py: -------------------------------------------------------------------------------- 1 | '''A socket wrapping snappy compression''' 2 | 3 | raise ImportError('Snappy not supported') 4 | -------------------------------------------------------------------------------- /nsq/sockets/tls.py: -------------------------------------------------------------------------------- 1 | '''Wraps a socket in TLS''' 2 | 3 | import ssl 4 | 5 | from .. import logger 6 | 7 | 8 | class TLSSocket(object): 9 | '''Provide a way to return a TLS socket''' 10 | @classmethod 11 | def wrap_socket(cls, socket): 12 | sock = ssl.wrap_socket(socket, ssl_version=ssl.PROTOCOL_TLSv1) 13 | while True: 14 | try: 15 | logger.info('Performing TLS handshake...') 16 | sock.do_handshake() 17 | break 18 | except ssl.SSLError as err: 19 | errs = ( 20 | ssl.SSL_ERROR_WANT_READ, 21 | ssl.SSL_ERROR_WANT_WRITE) 22 | if err.args[0] not in errs: 23 | raise 24 | else: 25 | logger.info('Continuing TLS handshake...') 26 | logger.info('Socket wrapped') 27 | return sock 28 | -------------------------------------------------------------------------------- /nsq/stats.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | from .http import nsqlookupd, nsqd 4 | 5 | class Nsqlookupd(object): 6 | '''A class for grabbing stats about all hosts and topics reported to nsqlookupd.''' 7 | def __init__(self, *args, **kwargs): 8 | self.client = nsqlookupd.Client(*args, **kwargs) 9 | 10 | @property 11 | def merged(self): 12 | '''The clean stats from all the hosts reporting to this host.''' 13 | stats = {} 14 | for topic in self.client.topics()['topics']: 15 | for producer in self.client.lookup(topic)['producers']: 16 | hostname = producer['broadcast_address'] 17 | port = producer['http_port'] 18 | host = '%s_%s' % (hostname, port) 19 | stats[host] = nsqd.Client( 20 | 'http://%s:%s/' % (hostname, port)).clean_stats() 21 | return stats 22 | 23 | @property 24 | def raw(self): 25 | '''All the raw, unaggregated stats (with duplicates).''' 26 | topic_keys = ( 27 | 'message_count', 28 | 'depth', 29 | 'backend_depth', 30 | 'paused' 31 | ) 32 | 33 | channel_keys = ( 34 | 'in_flight_count', 35 | 'timeout_count', 36 | 'paused', 37 | 'deferred_count', 38 | 'message_count', 39 | 'depth', 40 | 'backend_depth', 41 | 'requeue_count' 42 | ) 43 | 44 | for host, stats in self.merged.items(): 45 | for topic, stats in stats.get('topics', {}).items(): 46 | for key in topic_keys: 47 | value = int(stats.get(key, -1)) 48 | yield ( 49 | 'host.%s.topic.%s.%s' % (host, topic, key), 50 | value, 51 | False 52 | ) 53 | yield ( 54 | 'topic.%s.%s' % (topic, key), 55 | value, 56 | True 57 | ) 58 | yield ( 59 | 'topics.%s' % key, 60 | value, 61 | True 62 | ) 63 | 64 | for chan, stats in stats.get('channels', {}).items(): 65 | data = { 66 | key: int(stats.get(key, -1)) for key in channel_keys 67 | } 68 | data['clients'] = len(stats.get('clients', [])) 69 | 70 | for key, value in data.items(): 71 | yield ( 72 | 'host.%s.topic.%s.channel.%s.%s' % (host, topic, chan, key), 73 | value, 74 | False 75 | ) 76 | yield ( 77 | 'host.%s.topic.%s.channels.%s' % (host, topic, key), 78 | value, 79 | True 80 | ) 81 | yield ( 82 | 'topic.%s.channels.%s' % 
(topic, key), 83 | value, 84 | True 85 | ) 86 | yield ( 87 | 'channels.%s' % key, 88 | value, 89 | True 90 | ) 91 | 92 | @property 93 | def stats(self): 94 | '''Stats that have been aggregated appropriately.''' 95 | data = Counter() 96 | for name, value, aggregated in self.raw: 97 | if aggregated: 98 | data['%s.max' % name] = max(data['%s.max' % name], value) 99 | data['%s.total' % name] += value 100 | else: 101 | data[name] = value 102 | 103 | return sorted(data.items()) 104 | -------------------------------------------------------------------------------- /nsq/util.py: -------------------------------------------------------------------------------- 1 | '''Some utilities used around town''' 2 | 3 | import struct 4 | 5 | 6 | def pack_string(message): 7 | '''Pack a single message in the TCP protocol format''' 8 | # [ 4-byte message size ][ N-byte binary data ] 9 | return struct.pack('>l', len(message)) + message 10 | 11 | 12 | def pack_iterable(messages): 13 | '''Pack an iterable of messages in the TCP protocol format''' 14 | # [ 4-byte body size ] 15 | # [ 4-byte num messages ] 16 | # [ 4-byte message #1 size ][ N-byte binary data ] 17 | # ... (repeated times) 18 | return pack_string( 19 | struct.pack('>l', len(messages)) + 20 | b''.join(map(pack_string, messages))) 21 | 22 | 23 | def pack(message): 24 | '''Pack the provided message''' 25 | if isinstance(message, bytes): 26 | return pack_string(message) 27 | else: 28 | return pack_iterable(message) 29 | 30 | 31 | def hexify(message): 32 | '''Print out printable characters, but others in hex''' 33 | import string 34 | hexified = [] 35 | for char in message: 36 | if (char in '\n\r \t') or (char not in string.printable): 37 | hexified.append('\\x%02x' % ord(char)) 38 | else: 39 | hexified.append(char) 40 | return ''.join(hexified) 41 | 42 | 43 | def distribute(total, objects): 44 | '''Generator for (count, object) tuples that distributes count evenly among 45 | the provided objects''' 46 | for index, obj in enumerate(objects): 47 | start = (index * total) // len(objects) 48 | stop = ((index + 1) * total) // len(objects) 49 | yield (stop - start, obj) 50 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2019.3.9 2 | chardet==3.0.4 3 | contextlib2==0.5.5 4 | decorator==4.4.0 5 | funcsigs==1.0.2 6 | idna==2.8 7 | mock==3.0.5 8 | nose==1.3.7 9 | requests==2.22.0 10 | simplejson==3.16.0 11 | six==1.12.0 12 | statsd==3.3.0 13 | urllib3==1.25.3 14 | -------------------------------------------------------------------------------- /scripts/travis/before-install.sh: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env bash 2 | 3 | # Download and install the NSQ binaries 4 | export NSQ_DIST="nsq-${NSQ_VERSION}.linux-$(go env GOARCH).go1.2" 5 | wget "https://s3.amazonaws.com/bitly-downloads/nsq/${NSQ_DIST}.tar.gz" 6 | tar xf "${NSQ_DIST}.tar.gz" 7 | ( 8 | cd "${NSQ_DIST}" 9 | sudo cp bin/* /usr/local/bin/ 10 | ) 11 | 12 | # Start nsqlookupd and disown it 13 | nsqlookupd > /dev/null 2> /dev/null & 14 | disown 15 | 16 | # And an instance of nsqd 17 | nsqd --lookupd-tcp-address=127.0.0.1:4160 > /dev/null 2> /dev/null & 18 | disown 19 | -------------------------------------------------------------------------------- /scripts/vagrant/files/etc/init/nsqadmin.conf: -------------------------------------------------------------------------------- 1 | # nsqadmin 2 | 3 | description "nsqadmin" 4 | 5 | start on (local-filesystems 6 | and net-device-up IFACE!=lo) 7 | stop on runlevel[!2345] 8 | 9 | respawn 10 | console log 11 | setuid nsq 12 | 13 | exec nsqadmin --lookupd-http-address=localhost:4161 14 | -------------------------------------------------------------------------------- /scripts/vagrant/files/etc/init/nsqd.conf: -------------------------------------------------------------------------------- 1 | # nsqd 2 | 3 | description "nsqd" 4 | 5 | start on (local-filesystems 6 | and net-device-up IFACE!=lo) 7 | stop on runlevel[!2345] 8 | 9 | respawn 10 | console log 11 | setuid nsq 12 | 13 | exec nsqd --lookupd-tcp-address=localhost:4160 --data-path=/var/lib/nsqd 14 | -------------------------------------------------------------------------------- /scripts/vagrant/files/etc/init/nsqlookupd.conf: -------------------------------------------------------------------------------- 1 | # nsqlookupd 2 | 3 | description "nsqlookupd" 4 | 5 | start on (local-filesystems 6 | and net-device-up IFACE!=lo) 7 | stop on runlevel[!2345] 8 | 9 | respawn 10 | console log 11 | setuid nsq 12 | 13 | exec nsqlookupd 14 | 
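With these Upstart jobs in place, the simplest smoke test is to hit the daemons' HTTP ports with the package's own clients. A small sketch, assuming the default ports used elsewhere in this repository (4151 for nsqd, 4161 for nsqlookupd):

```python
from nsq.http import nsqd, nsqlookupd

# ping() raises a ClientException when a daemon is not reachable; this is the
# same check the integration tests use to wait for the daemons to come up.
nsqd.Client('http://localhost:4151').ping()
nsqlookupd.Client('http://localhost:4161').ping()
```
-------------------------------------------------------------------------------- /scripts/vagrant/provision.sh: -------------------------------------------------------------------------------- 1 | #!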
/usr/bin/env bash 2 | 3 | set -e 4 | 5 | sudo apt-get update 6 | sudo apt-get install -y tar curl git 7 | 8 | # Libraries required to build a complete python with pyenv: 9 | # https://github.com/yyuu/pyenv/wiki 10 | sudo apt-get install -y make build-essential libssl-dev zlib1g-dev libbz2-dev \ 11 | libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev 12 | 13 | # Download and install the NSQ binaries 14 | export NSQ_VERSION=0.2.26 15 | export NSQ_DIST="nsq-${NSQ_VERSION}.linux-amd64.go1.2" 16 | pushd /tmp 17 | wget "https://s3.amazonaws.com/bitly-downloads/nsq/${NSQ_DIST}.tar.gz" 18 | tar xf "${NSQ_DIST}.tar.gz" 19 | pushd ${NSQ_DIST} 20 | sudo cp bin/* /usr/local/bin/ 21 | popd 22 | popd 23 | 24 | # Create an NSQ user 25 | sudo useradd -U -s /bin/false -d /dev/null nsq 26 | sudo mkdir /var/lib/nsqd 27 | sudo chown nsq:nsq /var/lib/nsqd 28 | 29 | # Copy all relevant files 30 | sudo rsync -r /vagrant/scripts/vagrant/files/ / 31 | 32 | # And start the relevant services 33 | sudo service nsqlookupd start 34 | sudo service nsqadmin start 35 | sudo service nsqd start 36 | 37 | # Install pyenv 38 | git clone https://github.com/yyuu/pyenv.git ~/.pyenv 39 | echo ' 40 | # Pyenv 41 | export PYENV_ROOT="$HOME/.pyenv" 42 | export PATH="$PYENV_ROOT/bin:$PATH" 43 | eval "$(pyenv init -)" 44 | ' >> ~/.bash_profile 45 | source ~/.bash_profile 46 | hash 47 | 48 | pushd /vagrant 49 | # Install our python version 50 | pyenv install 51 | pyenv rehash 52 | 53 | # Install a virtualenv 54 | pip install virtualenv 55 | virtualenv venv 56 | source venv/bin/activate 57 | 58 | # Install our dependencies 59 | pip install -r requirements.txt 60 | popd 61 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | 3 | from nsq import __version__ 4 | 5 | extra = {} 6 | 7 | from setuptools import setup 8 | 9 | 10 | setup(name = 'nsq-py', 11 | version = __version__, 12 | description = 'NSQ for Python With Pure Sockets', 13 | url = 'http://github.com/dlecocq/nsq-py', 14 | author = 'Dan Lecocq', 15 | author_email = 'dan@moz.com', 16 | license = "MIT License", 17 | keywords = 'nsq, queue', 18 | packages = ['nsq', 'nsq.http', 'nsq.sockets'], 19 | package_dir = {'nsq': 'nsq', 'nsq.http': 'nsq/http', 'nsq.sockets': 'nsq/sockets'}, 20 | classifiers = [ 21 | 'License :: OSI Approved :: MIT License', 22 | 'Programming Language :: Python', 23 | 'Intended Audience :: Developers', 24 | 'Operating System :: OS Independent' 25 | ], 26 | install_requires=[ 27 | 'requests', 28 | 'decorator', 29 | 'six', 30 | 'statsd' 31 | ], 32 | tests_require=[ 33 | 'nose', 34 | 'coverage' 35 | ], 36 | scripts = [ 37 | 'bin/nsqlookupd-statsd' 38 | ], 39 | **extra 40 | ) 41 | -------------------------------------------------------------------------------- /shovel/profile.py: -------------------------------------------------------------------------------- 1 | from shovel import task 2 | 3 | import time 4 | from itertools import islice, cycle, permutations, izip_longest, chain 5 | from contextlib import contextmanager, closing 6 | 7 | 8 | @contextmanager 9 | def profiler(): 10 | '''Profile the block''' 11 | import cProfile 12 | import pstats 13 | pr = cProfile.Profile() 14 | pr.enable() 15 | yield 16 | pr.disable() 17 | ps = pstats.Stats(pr).sort_stats('tottime') 18 | ps.print_stats() 19 | 20 | 21 | def messages(count, size): 22 | '''Generator for count messages of the provided size''' 23 | import string 24 | # Make sure we have at least 'size' letters 25 | letters = islice(cycle(chain(string.lowercase, string.uppercase)), size) 26 | return islice(cycle(''.join(l) for l in permutations(letters, size)), count) 27 | 28 | 29 | def grouper(iterable, n): 30 | '''Collect data into fixed-length chunks or blocks''' 31 | args = [iter(iterable)] * n 32 | for group in izip_longest(fillvalue=None, *args): 33 | group = [g for g in group if g is not None] 34 | yield group 35 | 36 | 37 | @task 38 | def basic(topic='topic', channel='channel', count=1e6, size=10, gevent=False, 39 | max_in_flight=2500, profile=False): 40 | '''Basic benchmark''' 41 | if gevent: 42 | from gevent import monkey 43 | monkey.patch_all() 44 | 45 | # Check the types of the arguments 46 | count = int(count) 47 | size = int(size) 48 | max_in_flight = int(max_in_flight) 49 | 50 | from nsq.http import nsqd 51 | from nsq.reader import Reader 52 | 53 | print('Publishing messages...') 54 | for batch in grouper(messages(count, size), 1000): 55 | nsqd.Client('http://localhost:4151').mpub(topic, batch) 56 | 57 | print('Consuming messages') 58 | client = Reader(topic, channel, nsqd_tcp_addresses=['localhost:4150'], 59 | max_in_flight=max_in_flight) 60 | with closing(client): 61 | start = -time.time() 62 | if profile: 63 | with profiler(): 64 | for message in islice(client, count): 65 | message.fin() 66 | else: 67 | for message in islice(client, count): 68 | message.fin() 69 | start += time.time() 70 | print('Finished %i messages in %fs (%5.2f messages / second)' % ( 71 | count, start, count / start)) 72 | 73 | 74 | @task 75 | def stats(): 76 | '''Read a stream of floats and give summary statistics''' 77 | import re 78 | import sys 79 | import math 80 | values = [] 81 | for line in sys.stdin: 82 | values.extend(map(float, re.findall(r'\d+\.?\d+', line))) 83 | 84 | mean 
= sum(values) / len(values) 85 | variance = sum((val - mean) ** 2 for val in values) / len(values) 86 | print('%3i items; mean: %10.5f; std-dev: %10.5f' % ( 87 | len(values), mean, math.sqrt(variance))) 88 | -------------------------------------------------------------------------------- /test/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Import some of the tests here to make them available 2 | 3 | from .clienttest import ClientTest 4 | from .httpclientintegrationtest import HttpClientIntegrationTest 5 | from .integrationtest import IntegrationTest 6 | from .mockedconnectiontest import MockedConnectionTest 7 | from .mockedsockettest import MockedSocketTest 8 | -------------------------------------------------------------------------------- /test/common/clienttest.py: -------------------------------------------------------------------------------- 1 | import mock 2 | import unittest 3 | 4 | from contextlib import contextmanager 5 | 6 | 7 | class ClientTest(unittest.TestCase): 8 | @contextmanager 9 | def patched_get(self): 10 | with mock.patch.object(self.client, 'get') as get: 11 | yield get 12 | 13 | @contextmanager 14 | def patched_post(self): 15 | with mock.patch.object(self.client, 'post') as post: 16 | yield post 17 | -------------------------------------------------------------------------------- /test/common/httpclientintegrationtest.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | from nsq.http import ClientException, nsqd, nsqlookupd 4 | from .integrationtest import IntegrationTest 5 | 6 | 7 | class HttpClientIntegrationTest(IntegrationTest): 8 | '''For integration tests, includes a topic, a channel and clients''' 9 | nsqd_ports = (14150,) 10 | nsqlookup_port = 14160 11 | 12 | def setUp(self): 13 | self.topic = b'test-topic' 14 | self.channel = b'test-channel' 15 | self.nsqd = nsqd.Client('http://localhost:14151') 16 | self.nsqlookupd = nsqlookupd.Client('http://localhost:14161') 17 | 18 | # Create this topic 19 | self.nsqlookupd.create_topic(self.topic) 20 | self.nsqlookupd.create_channel(self.topic, self.channel) 21 | self.nsqd.create_topic(self.topic) 22 | self.nsqd.create_channel(self.topic, self.channel) 23 | 24 | def tearDown(self): 25 | with self.delete_topic(self.topic): 26 | pass 27 | 28 | @contextmanager 29 | def delete_topic(self, topic): 30 | '''Delete a topic after running''' 31 | try: 32 | yield 33 | finally: 34 | # Delete the topic from our nsqd instance 35 | try: 36 | self.nsqd.delete_topic(topic) 37 | except ClientException: 38 | pass 39 | # Delete the topic from our nsqlookupd instance 40 | try: 41 | self.nsqlookupd.delete_topic(topic) 42 | except ClientException: 43 | pass 44 | -------------------------------------------------------------------------------- /test/common/integrationtest.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from contextlib import contextmanager 4 | import os 5 | import subprocess 6 | import time 7 | 8 | import contextlib2 9 | from nsq import logger 10 | from nsq.http import nsqd, nsqlookupd, ClientException 11 | 12 | 13 | class ProcessWrapper(object): 14 | '''Wraps a subprocess''' 15 | def __init__(self, path, *args): 16 | self._path = path 17 | self._args = [path] + list(args) 18 | self._process = None 19 | 20 | def start(self): 21 | '''Start the process''' 22 | logger.info('Spawning %s' % ' '.join(self._args)) 23 | 
self._process = subprocess.Popen( 24 | self._args, 25 | bufsize=0, 26 | executable=self._path, 27 | stdin=None, 28 | stdout=open(os.devnull), 29 | stderr=open(os.devnull)) 30 | # Wait until the process is 'live' 31 | while not self.ready(): 32 | time.sleep(0.1) 33 | 34 | def stop(self): 35 | '''Stop the process''' 36 | logger.info('Stopping %s' % ' '.join(self._args)) 37 | if not self._process: 38 | return 39 | 40 | self._process.terminate() 41 | self._process.wait() 42 | self._process = None 43 | 44 | def ready(self): 45 | '''By default, run 'ping' on the client''' 46 | try: 47 | self._client.ping() 48 | return True 49 | except ClientException: 50 | return False 51 | 52 | @contextmanager 53 | def run(self): 54 | '''Start and yield this process, and stop it afterwards''' 55 | try: 56 | self.start() 57 | yield self 58 | finally: 59 | self.stop() 60 | 61 | 62 | class Nsqd(ProcessWrapper): 63 | '''Wraps an instance of nsqd''' 64 | def __init__(self, port, nsqlookupd): 65 | self._client = nsqd.Client('http://localhost:%s' % (port + 1)) 66 | options = { 67 | 'data-path': 'test/tmp', 68 | 'deflate': 'true', 69 | 'snappy': 'true', 70 | 'tls-cert': 'test/fixtures/certificates/cert.pem', 71 | 'tls-key': 'test/fixtures/certificates/key.pem', 72 | 'broadcast-address': 'localhost', 73 | 'tcp-address': '0.0.0.0:%s' % (port), 74 | 'http-address': '0.0.0.0:%s' % (port + 1), 75 | 'lookupd-tcp-address': '127.0.0.1:%s' % nsqlookupd 76 | } 77 | args = ['--%s=%s' % (k, v) for k, v in options.items()] 78 | ProcessWrapper.__init__(self, 'nsqd', *args) 79 | 80 | 81 | class Nsqlookupd(ProcessWrapper): 82 | '''Wraps an instance of nsqlookupd''' 83 | def __init__(self, port): 84 | self._client = nsqlookupd.Client('http://localhost:%s' % (port + 1)) 85 | options = { 86 | 'tcp-address': 'localhost:%s' % port, 87 | 'http-address': 'localhost:%s' % (port + 1) 88 | } 89 | args = ['--%s=%s' % (k, v) for k, v in options.items()] 90 | ProcessWrapper.__init__(self, 'nsqlookupd', *args) 91 | 92 | 93 | class IntegrationTest(unittest.TestCase): 94 | '''Spawn a temporary real server with all the bells and whistles''' 95 | host = 'localhost' 96 | nsqd_ports = (14150,) 97 | nsqlookupd_port = 14160 98 | 99 | @classmethod 100 | def setUpClass(cls): 101 | if not os.path.exists('test/tmp'): 102 | os.mkdir('test/tmp') 103 | # TODO(dan): Ensure that test/tmp exists and is empty 104 | instances = ( 105 | [Nsqlookupd(cls.nsqlookupd_port)] + 106 | [Nsqd(p, cls.nsqlookupd_port) for p in cls.nsqd_ports]) 107 | cls._context = contextlib2.ExitStack() 108 | cls._context.__enter__() 109 | for i in instances: 110 | cls._context.enter_context(i.run()) 111 | 112 | @classmethod 113 | def tearDownClass(cls): 114 | cls._context.__exit__(None, None, None) 115 | # Also remove the tmp directory 116 | for path in os.listdir('test/tmp'): 117 | os.remove(os.path.join('test/tmp', path)) 118 | -------------------------------------------------------------------------------- /test/common/mockedconnectiontest.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import mock 3 | 4 | from nsq.client import Client 5 | from nsq import response 6 | 7 | 8 | class MockConnection(mock.Mock): 9 | def __init__(self, host=None, port=None, *args, **kwargs): 10 | mock.Mock.__init__(self) 11 | self.host = host 12 | self.port = port 13 | self._responses = [] 14 | self._alive = True 15 | 16 | def read(self): 17 | '''Return all of our responses''' 18 | found = list(self._responses) 19 | self._responses = [] 20 | return 
found 21 | 22 | def response(self, message): 23 | self._responses.append( 24 | response.Response(self, response.Response.FRAME_TYPE, message)) 25 | 26 | def error(self, exception): 27 | '''Send an error''' 28 | self._responses.append( 29 | response.Error(self, response.Error.FRAME_TYPE, exception.name)) 30 | 31 | def alive(self): 32 | return self._alive 33 | 34 | def close(self): 35 | self._alive = False 36 | 37 | 38 | class MockedConnectionTest(unittest.TestCase): 39 | '''Create a client with mocked connection objects''' 40 | nsqd_ports = (12345, 12346) 41 | 42 | def setUp(self): 43 | with mock.patch('nsq.client.connection.Connection', MockConnection): 44 | hosts = ['localhost:%s' % port for port in self.nsqd_ports] 45 | self.client = Client(nsqd_tcp_addresses=hosts) 46 | self.connections = self.client.connections() 47 | -------------------------------------------------------------------------------- /test/common/mockedsockettest.py: -------------------------------------------------------------------------------- 1 | import mock 2 | import unittest 3 | 4 | from nsq import connection 5 | from nsq import json 6 | from nsq import response 7 | 8 | 9 | class MockSocket(mock.Mock): 10 | '''The server-side socket. Read/write are from the server's perspective''' 11 | def __init__(self, *_, **__): 12 | mock.Mock.__init__(self) 13 | self._to_client_buffer = b'' 14 | self._to_server_buffer = b'' 15 | 16 | # From the server's perspective 17 | def write(self, message): 18 | self._to_client_buffer += message 19 | 20 | def read(self): 21 | data, self._to_server_buffer = self._to_server_buffer, b'' 22 | return data 23 | 24 | # From the client's perspective 25 | def send(self, message): 26 | self._to_server_buffer += message 27 | return len(message) 28 | 29 | def sendall(self, message): 30 | return self.send(message) 31 | 32 | def recv(self, limit): 33 | data, self._to_client_buffer = ( 34 | self._to_client_buffer[:limit], self._to_client_buffer[limit:]) 35 | return data 36 | 37 | def response(self, message): 38 | '''Send the provided message as a response''' 39 | self.write(response.Response.pack(message)) 40 | 41 | def error(self, exception): 42 | '''Send an error''' 43 | self.write(response.Error.pack(exception.name)) 44 | 45 | def identify(self, spec=None): 46 | '''Write out the identify response''' 47 | if spec: 48 | self.response(json.dumps(spec).encode('UTF-8')) 49 | else: 50 | self.response(b'OK') 51 | 52 | 53 | class MockedSocketTest(unittest.TestCase): 54 | '''A test where socket is patched''' 55 | def connect(self, identify_response=None): 56 | sock = MockSocket() 57 | sock.identify(identify_response) 58 | with mock.patch('nsq.connection.socket.socket', return_value=sock): 59 | return connection.Connection('localhost', 1234, 0.01) 60 | 61 | def setUp(self): 62 | self.connection = self.connect() 63 | self.connection.setblocking(0) 64 | self.socket = self.connection._socket 65 | -------------------------------------------------------------------------------- /test/fixtures/certificates/cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDtTCCAp2gAwIBAgIJAOma51eu0ih2MA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV 3 | BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX 4 | aWRnaXRzIFB0eSBMdGQwHhcNMTQwNzEwMTYyNTU1WhcNMTUwNzEwMTYyNTU1WjBF 5 | MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50 6 | ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB 7 | 
CgKCAQEAt/5QMJCODTXdrijm5WeDyFZXnPTzFHNRgCyxFgrYpGF5AAWKurO5Bnr0 8 | 8QLxq7jq7WsgCEaKX6nT5D4Rmgbkk4yA99QoPJFJMHNVqXXcBZWB+OSEOVijaLNr 9 | G5q37cDFeEERvkZrxVMVlsXnzOaVU8TurIDwJNZJjBckTmcXdNr4N2iCLWPI0E/W 10 | AreDPUSfda3jCjY+MOFAIwEyQYR96Rqj2RLCqS7FLwGtnSOqfUTf0qJBlGZy0cLf 11 | cHfIdAL2+MC4bSQ6MYsA48JbRM9G+q6Lr/1q/EPdxJD8iIe6ytsZfY1c6KgpTqZB 12 | GjYyJScGC47xEaxFQ0Rd2sUbciu+VwIDAQABo4GnMIGkMB0GA1UdDgQWBBSI/4uT 13 | 32pgcu7rHiJP56odrPxazDB1BgNVHSMEbjBsgBSI/4uT32pgcu7rHiJP56odrPxa 14 | zKFJpEcwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNV 15 | BAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZIIJAOma51eu0ih2MAwGA1UdEwQF 16 | MAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAF4YZXtj4P2htvmPGaGuwBm/su28FWmW 17 | CtXFhIUJeQkz4UKLYTpNPOP5xaKg2Zb9lMfHomZB2naTcGzSiwR3ftcKKwn0pXjA 18 | HlGyDiIYzj2Lz1asJBMZyEGppZvln1nZI0Ik/M9ncMAzQmsxrGORSlJtEa1yK/hB 19 | w52hoeU3vCXdVtbkAGQWQk3KhLfLsQWQFGts0uQt8MJA9zPGWtEaFDSKoOUnvTCa 20 | NcS8DkU2VAdz1tVobhXNwV17jCroEvwLDMytrrewRYka5oRUzhVhAPuW28eKsRGN 21 | +qR7e0PwumB+Bvffs8DjHDwa7Ogh+qruvbioXBeTqRfS/W0q9+VTnJA= 22 | -----END CERTIFICATE----- 23 | -------------------------------------------------------------------------------- /test/fixtures/certificates/key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEogIBAAKCAQEAt/5QMJCODTXdrijm5WeDyFZXnPTzFHNRgCyxFgrYpGF5AAWK 3 | urO5Bnr08QLxq7jq7WsgCEaKX6nT5D4Rmgbkk4yA99QoPJFJMHNVqXXcBZWB+OSE 4 | OVijaLNrG5q37cDFeEERvkZrxVMVlsXnzOaVU8TurIDwJNZJjBckTmcXdNr4N2iC 5 | LWPI0E/WAreDPUSfda3jCjY+MOFAIwEyQYR96Rqj2RLCqS7FLwGtnSOqfUTf0qJB 6 | lGZy0cLfcHfIdAL2+MC4bSQ6MYsA48JbRM9G+q6Lr/1q/EPdxJD8iIe6ytsZfY1c 7 | 6KgpTqZBGjYyJScGC47xEaxFQ0Rd2sUbciu+VwIDAQABAoIBABh8ngt4kY8shg4x 8 | n1kUh7NX2l0nNFqaZlRank7CrsZhuorIMgha9trn7kVNEQC7oXhrc13mlW/Z2Dte 9 | D1WiaTVB08An2hsFcuohz1q4Nsn/dca8EuTW6Rh8GFsaIjRgHWe9sTDTinA+eHcS 10 | a6EXZvQ5F1KZ7lvYsP0V710H11VTvPrKEG2I2ipHGrOm/VHmkfpznjpMb9Z5b+Pk 11 | 54+SOVKSDPhYhbMxXlRKinKEPF9mdroJkfRow4ac5dKe01sook9UYq1MiLtOjUjH 12 | j6jVgQjHdfTAam8VxzeJOihL2ixddaRMGIkHUZTF4Lb7l5dqi92Ov/e3DhTKvxdZ 13 | d3L59+ECgYEA782YHeQx+l8VSJL1HFzHO0pLg13OhyT5xhbxh4vaGjd3gFau8ce9 14 | KqfZcsWJDU0t4lzuoi5MzuJT3dmme8YtcYgY/OlsqDGeRqJJt/e9SCZJqCXwtUVD 15 | lgSl4WW7+m56elux6PeJj17v5axKA1aa3tJ+AdNEu5xGpOhHzVEqDkkCgYEAxGu4 16 | dtj6KmU3WpQL3k9Dkr+8hS/C0RVgZHbgMVu9gAYRvzwFXOTCKC+qWc5PGNNB1sKm 17 | xGG8MEBuHOZXBg+SyuLAciVn6R5LdJiPdcuZEqmk+FBGnTEnbPypfGtZ+Q+MD5zY 18 | jVOvSL+mst7MO2XvrKZZ+VOu7OwPKRDDbkjh558CgYBORF4Xs3kUbKA3ta9GeImW 19 | MmN/FsjnlwvmuWpPgTfIQr5AJwqmYzi8iVgRe6OFseD99rL0QARVqc0RpY4O69m9 20 | Klxtf4o1QyyThThmUPd4avazaN6ta1PpzM6PSHMYA6L5+J+Sl+hP4P6PibIGcOfP 21 | PghedCQEz7bG8AEvZARD8QKBgCW0F8CYfczNiQaWDIEr7eipbWKTfG3uEIa4Wuie 22 | l42PnLB8sPrX3n0gSS7b70rwol67Fo/zws/wTjK19FZxftf7Fr3SeFPDQPCsqD0Q 23 | S93NOqF/p05dNRgyl8YORUMNvPDyRo86VRc90p3bLpDoTE1z0SmO6rEHzxEu6pSs 24 | 4NA5AoGAWSNPSUSNwYmDoiHn2K8vVFxbNrN1iCO1+EAMk7F1i4vUGmhc7cObO579 25 | l/PNzzWgvp4xIXlPEEZmrf3Lad5dkLmtEQrVfOLLmb5K7e2JD7kF5Sdz6e5BMCzi 26 | noBZrXmjcOJN9fg0/7koXMEjAjrRVmNXchsx1XW/BoUTsp6nDfQ= 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /test/fixtures/test_stats/TestStats/stats: -------------------------------------------------------------------------------- 1 | [ 2 | [ 3 | "channels.backend_depth.max", 4 | 0 5 | ], 6 | [ 7 | "channels.backend_depth.total", 8 | 0 9 | ], 10 | [ 11 | "channels.clients.max", 12 | 0 13 | ], 14 | [ 15 | "channels.clients.total", 16 | 0 17 | ], 18 | [ 19 | "channels.deferred_count.max", 20 | 0 21 | 
], 22 | [ 23 | "channels.deferred_count.total", 24 | 0 25 | ], 26 | [ 27 | "channels.depth.max", 28 | 19 29 | ], 30 | [ 31 | "channels.depth.total", 32 | 19 33 | ], 34 | [ 35 | "channels.in_flight_count.max", 36 | 0 37 | ], 38 | [ 39 | "channels.in_flight_count.total", 40 | 0 41 | ], 42 | [ 43 | "channels.message_count.max", 44 | 19 45 | ], 46 | [ 47 | "channels.message_count.total", 48 | 19 49 | ], 50 | [ 51 | "channels.paused.max", 52 | 0 53 | ], 54 | [ 55 | "channels.paused.total", 56 | 0 57 | ], 58 | [ 59 | "channels.requeue_count.max", 60 | 0 61 | ], 62 | [ 63 | "channels.requeue_count.total", 64 | 0 65 | ], 66 | [ 67 | "channels.timeout_count.max", 68 | 0 69 | ], 70 | [ 71 | "channels.timeout_count.total", 72 | 0 73 | ], 74 | [ 75 | "host.localhost_14151.topic.topic-on-both-instances.backend_depth", 76 | 0 77 | ], 78 | [ 79 | "host.localhost_14151.topic.topic-on-both-instances.depth", 80 | 23 81 | ], 82 | [ 83 | "host.localhost_14151.topic.topic-on-both-instances.message_count", 84 | 23 85 | ], 86 | [ 87 | "host.localhost_14151.topic.topic-on-both-instances.paused", 88 | 0 89 | ], 90 | [ 91 | "host.localhost_14151.topic.topic-with-channels.backend_depth", 92 | 0 93 | ], 94 | [ 95 | "host.localhost_14151.topic.topic-with-channels.channel.channel.backend_depth", 96 | 0 97 | ], 98 | [ 99 | "host.localhost_14151.topic.topic-with-channels.channel.channel.clients", 100 | 0 101 | ], 102 | [ 103 | "host.localhost_14151.topic.topic-with-channels.channel.channel.deferred_count", 104 | 0 105 | ], 106 | [ 107 | "host.localhost_14151.topic.topic-with-channels.channel.channel.depth", 108 | 19 109 | ], 110 | [ 111 | "host.localhost_14151.topic.topic-with-channels.channel.channel.in_flight_count", 112 | 0 113 | ], 114 | [ 115 | "host.localhost_14151.topic.topic-with-channels.channel.channel.message_count", 116 | 19 117 | ], 118 | [ 119 | "host.localhost_14151.topic.topic-with-channels.channel.channel.paused", 120 | 0 121 | ], 122 | [ 123 | "host.localhost_14151.topic.topic-with-channels.channel.channel.requeue_count", 124 | 0 125 | ], 126 | [ 127 | "host.localhost_14151.topic.topic-with-channels.channel.channel.timeout_count", 128 | 0 129 | ], 130 | [ 131 | "host.localhost_14151.topic.topic-with-channels.channels.backend_depth.max", 132 | 0 133 | ], 134 | [ 135 | "host.localhost_14151.topic.topic-with-channels.channels.backend_depth.total", 136 | 0 137 | ], 138 | [ 139 | "host.localhost_14151.topic.topic-with-channels.channels.clients.max", 140 | 0 141 | ], 142 | [ 143 | "host.localhost_14151.topic.topic-with-channels.channels.clients.total", 144 | 0 145 | ], 146 | [ 147 | "host.localhost_14151.topic.topic-with-channels.channels.deferred_count.max", 148 | 0 149 | ], 150 | [ 151 | "host.localhost_14151.topic.topic-with-channels.channels.deferred_count.total", 152 | 0 153 | ], 154 | [ 155 | "host.localhost_14151.topic.topic-with-channels.channels.depth.max", 156 | 19 157 | ], 158 | [ 159 | "host.localhost_14151.topic.topic-with-channels.channels.depth.total", 160 | 19 161 | ], 162 | [ 163 | "host.localhost_14151.topic.topic-with-channels.channels.in_flight_count.max", 164 | 0 165 | ], 166 | [ 167 | "host.localhost_14151.topic.topic-with-channels.channels.in_flight_count.total", 168 | 0 169 | ], 170 | [ 171 | "host.localhost_14151.topic.topic-with-channels.channels.message_count.max", 172 | 19 173 | ], 174 | [ 175 | "host.localhost_14151.topic.topic-with-channels.channels.message_count.total", 176 | 19 177 | ], 178 | [ 179 | "host.localhost_14151.topic.topic-with-channels.channels.paused.max", 180 | 0 
181 | ], 182 | [ 183 | "host.localhost_14151.topic.topic-with-channels.channels.paused.total", 184 | 0 185 | ], 186 | [ 187 | "host.localhost_14151.topic.topic-with-channels.channels.requeue_count.max", 188 | 0 189 | ], 190 | [ 191 | "host.localhost_14151.topic.topic-with-channels.channels.requeue_count.total", 192 | 0 193 | ], 194 | [ 195 | "host.localhost_14151.topic.topic-with-channels.channels.timeout_count.max", 196 | 0 197 | ], 198 | [ 199 | "host.localhost_14151.topic.topic-with-channels.channels.timeout_count.total", 200 | 0 201 | ], 202 | [ 203 | "host.localhost_14151.topic.topic-with-channels.depth", 204 | 0 205 | ], 206 | [ 207 | "host.localhost_14151.topic.topic-with-channels.message_count", 208 | 19 209 | ], 210 | [ 211 | "host.localhost_14151.topic.topic-with-channels.paused", 212 | 0 213 | ], 214 | [ 215 | "host.localhost_14151.topic.topic-without-channels.backend_depth", 216 | 0 217 | ], 218 | [ 219 | "host.localhost_14151.topic.topic-without-channels.depth", 220 | 22 221 | ], 222 | [ 223 | "host.localhost_14151.topic.topic-without-channels.message_count", 224 | 22 225 | ], 226 | [ 227 | "host.localhost_14151.topic.topic-without-channels.paused", 228 | 0 229 | ], 230 | [ 231 | "host.localhost_14153.topic.topic-on-both-instances.backend_depth", 232 | 0 233 | ], 234 | [ 235 | "host.localhost_14153.topic.topic-on-both-instances.depth", 236 | 0 237 | ], 238 | [ 239 | "host.localhost_14153.topic.topic-on-both-instances.message_count", 240 | 0 241 | ], 242 | [ 243 | "host.localhost_14153.topic.topic-on-both-instances.paused", 244 | 0 245 | ], 246 | [ 247 | "host.localhost_14153.topic.topic-without-channels.backend_depth", 248 | 0 249 | ], 250 | [ 251 | "host.localhost_14153.topic.topic-without-channels.depth", 252 | 10 253 | ], 254 | [ 255 | "host.localhost_14153.topic.topic-without-channels.message_count", 256 | 10 257 | ], 258 | [ 259 | "host.localhost_14153.topic.topic-without-channels.paused", 260 | 0 261 | ], 262 | [ 263 | "topic.topic-on-both-instances.backend_depth.max", 264 | 0 265 | ], 266 | [ 267 | "topic.topic-on-both-instances.backend_depth.total", 268 | 0 269 | ], 270 | [ 271 | "topic.topic-on-both-instances.depth.max", 272 | 23 273 | ], 274 | [ 275 | "topic.topic-on-both-instances.depth.total", 276 | 23 277 | ], 278 | [ 279 | "topic.topic-on-both-instances.message_count.max", 280 | 23 281 | ], 282 | [ 283 | "topic.topic-on-both-instances.message_count.total", 284 | 23 285 | ], 286 | [ 287 | "topic.topic-on-both-instances.paused.max", 288 | 0 289 | ], 290 | [ 291 | "topic.topic-on-both-instances.paused.total", 292 | 0 293 | ], 294 | [ 295 | "topic.topic-with-channels.backend_depth.max", 296 | 0 297 | ], 298 | [ 299 | "topic.topic-with-channels.backend_depth.total", 300 | 0 301 | ], 302 | [ 303 | "topic.topic-with-channels.channels.backend_depth.max", 304 | 0 305 | ], 306 | [ 307 | "topic.topic-with-channels.channels.backend_depth.total", 308 | 0 309 | ], 310 | [ 311 | "topic.topic-with-channels.channels.clients.max", 312 | 0 313 | ], 314 | [ 315 | "topic.topic-with-channels.channels.clients.total", 316 | 0 317 | ], 318 | [ 319 | "topic.topic-with-channels.channels.deferred_count.max", 320 | 0 321 | ], 322 | [ 323 | "topic.topic-with-channels.channels.deferred_count.total", 324 | 0 325 | ], 326 | [ 327 | "topic.topic-with-channels.channels.depth.max", 328 | 19 329 | ], 330 | [ 331 | "topic.topic-with-channels.channels.depth.total", 332 | 19 333 | ], 334 | [ 335 | "topic.topic-with-channels.channels.in_flight_count.max", 336 | 0 337 | ], 338 | [ 339 | 
"topic.topic-with-channels.channels.in_flight_count.total", 340 | 0 341 | ], 342 | [ 343 | "topic.topic-with-channels.channels.message_count.max", 344 | 19 345 | ], 346 | [ 347 | "topic.topic-with-channels.channels.message_count.total", 348 | 19 349 | ], 350 | [ 351 | "topic.topic-with-channels.channels.paused.max", 352 | 0 353 | ], 354 | [ 355 | "topic.topic-with-channels.channels.paused.total", 356 | 0 357 | ], 358 | [ 359 | "topic.topic-with-channels.channels.requeue_count.max", 360 | 0 361 | ], 362 | [ 363 | "topic.topic-with-channels.channels.requeue_count.total", 364 | 0 365 | ], 366 | [ 367 | "topic.topic-with-channels.channels.timeout_count.max", 368 | 0 369 | ], 370 | [ 371 | "topic.topic-with-channels.channels.timeout_count.total", 372 | 0 373 | ], 374 | [ 375 | "topic.topic-with-channels.depth.max", 376 | 0 377 | ], 378 | [ 379 | "topic.topic-with-channels.depth.total", 380 | 0 381 | ], 382 | [ 383 | "topic.topic-with-channels.message_count.max", 384 | 19 385 | ], 386 | [ 387 | "topic.topic-with-channels.message_count.total", 388 | 19 389 | ], 390 | [ 391 | "topic.topic-with-channels.paused.max", 392 | 0 393 | ], 394 | [ 395 | "topic.topic-with-channels.paused.total", 396 | 0 397 | ], 398 | [ 399 | "topic.topic-without-channels.backend_depth.max", 400 | 0 401 | ], 402 | [ 403 | "topic.topic-without-channels.backend_depth.total", 404 | 0 405 | ], 406 | [ 407 | "topic.topic-without-channels.depth.max", 408 | 22 409 | ], 410 | [ 411 | "topic.topic-without-channels.depth.total", 412 | 32 413 | ], 414 | [ 415 | "topic.topic-without-channels.message_count.max", 416 | 22 417 | ], 418 | [ 419 | "topic.topic-without-channels.message_count.total", 420 | 32 421 | ], 422 | [ 423 | "topic.topic-without-channels.paused.max", 424 | 0 425 | ], 426 | [ 427 | "topic.topic-without-channels.paused.total", 428 | 0 429 | ], 430 | [ 431 | "topics.backend_depth.max", 432 | 0 433 | ], 434 | [ 435 | "topics.backend_depth.total", 436 | 0 437 | ], 438 | [ 439 | "topics.depth.max", 440 | 23 441 | ], 442 | [ 443 | "topics.depth.total", 444 | 55 445 | ], 446 | [ 447 | "topics.message_count.max", 448 | 23 449 | ], 450 | [ 451 | "topics.message_count.total", 452 | 74 453 | ], 454 | [ 455 | "topics.paused.max", 456 | 0 457 | ], 458 | [ 459 | "topics.paused.total", 460 | 0 461 | ] 462 | ] 463 | -------------------------------------------------------------------------------- /test/test_backoff.py: -------------------------------------------------------------------------------- 1 | import mock 2 | import unittest 3 | 4 | from nsq import backoff 5 | 6 | 7 | class TestBackoff(unittest.TestCase): 8 | '''Test our backoff class''' 9 | def setUp(self): 10 | self.backoff = backoff.Backoff() 11 | 12 | def test_sleep(self): 13 | '''Just calls time.sleep with whatever backoff returns''' 14 | with mock.patch.object(self.backoff, 'backoff', return_value=5): 15 | with mock.patch('nsq.backoff.time') as MockTime: 16 | self.backoff.sleep(10) 17 | self.backoff.backoff.assert_called_with(10) 18 | MockTime.sleep.assert_called_with(5) 19 | 20 | def test_backoff(self): 21 | '''Not implemented on the base class''' 22 | self.assertRaises(NotImplementedError, self.backoff.backoff, 1) 23 | 24 | 25 | class TestLinear(unittest.TestCase): 26 | '''Test our linear backoff class''' 27 | def setUp(self): 28 | self.backoff = backoff.Linear(1, 2) 29 | 30 | def test_constant(self): 31 | '''The constant is added to each time''' 32 | self.assertEqual(self.backoff.backoff(0), 2) 33 | 34 | def test_affine(self): 35 | '''The affine factor works as 
advertised''' 36 | self.assertEqual(self.backoff.backoff(0) + 1, self.backoff.backoff(1)) 37 | 38 | 39 | class TestConstant(unittest.TestCase): 40 | '''Test our constant backoff class''' 41 | def setUp(self): 42 | self.backoff = backoff.Constant(10) 43 | 44 | def test_constant(self): 45 | '''Always gives the same result''' 46 | for i in range(100): 47 | self.assertEqual(self.backoff.backoff(0), self.backoff.backoff(i)) 48 | 49 | 50 | class TestExponential(unittest.TestCase): 51 | '''Test our exponential backoff class''' 52 | def test_factor(self): 53 | '''We make use of the constant factor''' 54 | base = 5 55 | one = backoff.Exponential(base, 1) 56 | two = backoff.Exponential(base, 2) 57 | for i in range(10): 58 | self.assertEqual(one.backoff(i) * 2, two.backoff(i)) 59 | 60 | def test_constant(self): 61 | '''We add the constant value''' 62 | base = 5 63 | four = backoff.Exponential(base, c=4) 64 | zero = backoff.Exponential(base) 65 | for i in range(10): 66 | self.assertEqual(zero.backoff(i) + 4, four.backoff(i)) 67 | 68 | def test_base(self): 69 | '''We honor the base''' 70 | one = backoff.Exponential(1) 71 | two = backoff.Exponential(2) 72 | self.assertEqual(one.backoff(1) * 2, two.backoff(1)) 73 | self.assertEqual(one.backoff(2) * 4, two.backoff(2)) 74 | 75 | 76 | class TestClamped(unittest.TestCase): 77 | '''Does in fact keep our backoff clamped''' 78 | def setUp(self): 79 | self.linear = backoff.Linear(1, 2) 80 | self.backoff = backoff.Clamped(self.linear, minimum=5, maximum=10) 81 | 82 | def test_min(self): 83 | '''Asserts a minimum''' 84 | self.assertLess(self.linear.backoff(0), 5) 85 | self.assertEqual(self.backoff.backoff(0), 5) 86 | 87 | def test_max(self): 88 | '''Asserts a maximum''' 89 | self.assertGreater(self.linear.backoff(100), 10) 90 | self.assertEqual(self.backoff.backoff(100), 10) 91 | 92 | 93 | class TestAttemptCounter(unittest.TestCase): 94 | '''Test the attempt counter''' 95 | def setUp(self): 96 | self.backoff = mock.Mock() 97 | self.backoff.backoff.return_value = 9001 98 | self.counter = backoff.AttemptCounter(self.backoff) 99 | 100 | def test_sleep(self): 101 | '''Just calls time.sleep with whatever backoff returns''' 102 | with mock.patch.object(self.counter, 'backoff', return_value=5): 103 | with mock.patch('nsq.backoff.time') as MockTime: 104 | self.counter.sleep() 105 | self.counter.backoff.assert_called_with() 106 | MockTime.sleep.assert_called_with(5) 107 | 108 | def test_backoff(self): 109 | '''Not implemented on the base class''' 110 | with mock.patch.object(self.counter, 'attempts', 5): 111 | self.assertEqual( 112 | self.counter.backoff(), self.backoff.backoff.return_value) 113 | self.backoff.backoff.assert_called_with(5) 114 | 115 | def test_failed(self): 116 | '''Failed increments the number of attempts''' 117 | for attempts in range(10): 118 | self.assertEqual(self.counter.attempts, attempts) 119 | self.counter.failed() 120 | 121 | def test_ready_false(self): 122 | '''Ready returns false if not enough time has elapsed''' 123 | with mock.patch('nsq.backoff.time') as mock_time: 124 | mock_time.time = mock.Mock(return_value=10) 125 | with mock.patch.object(self.counter, '_last_failed', 10): 126 | self.assertFalse(self.counter.ready()) 127 | 128 | def test_ready_true(self): 129 | '''Ready returns true if enough time has elapsed''' 130 | with mock.patch('nsq.backoff.time') as mock_time: 131 | mock_time.time = mock.Mock(return_value=10) 132 | with mock.patch.object(self.counter, '_last_failed', 1): 133 | with mock.patch.object(self.counter, 'backoff', 
return_value=5): 134 | self.assertTrue(self.counter.ready()) 135 | 136 | def test_ready_never_failed(self): 137 | '''If it has never failed, then it returns True''' 138 | with mock.patch.object(self.counter, '_last_failed', None): 139 | self.assertTrue(self.counter.ready()) 140 | 141 | 142 | class TestResettingAttemptCounter(unittest.TestCase): 143 | '''Test the ResettingAttemptCounter''' 144 | def setUp(self): 145 | self.counter = backoff.ResettingAttemptCounter(None) 146 | 147 | def test_success(self): 148 | '''Success resets the attempts counter''' 149 | for _ in range(10): 150 | self.counter.failed() 151 | self.counter.success() 152 | self.assertEqual(self.counter.attempts, 0) 153 | 154 | 155 | class TestDecrementingAttemptCounter(unittest.TestCase): 156 | def setUp(self): 157 | self.counter = backoff.DecrementingAttemptCounter(None) 158 | 159 | def test_success(self): 160 | '''Success only decrements the attempts counter''' 161 | for _ in range(10): 162 | self.counter.failed() 163 | self.counter.success() 164 | self.assertEqual(self.counter.attempts, 9) 165 | 166 | def test_negative_attempts(self): 167 | '''Success never lets the attempts count drop below 0''' 168 | self.counter.success() 169 | self.assertEqual(self.counter.attempts, 0) 170 | -------------------------------------------------------------------------------- /test/test_checker.py: -------------------------------------------------------------------------------- 1 | '''Tests about our connection-checking thread''' 2 | 3 | import mock 4 | import unittest 5 | import threading 6 | 7 | from nsq.checker import PeriodicThread, ConnectionChecker 8 | 9 | 10 | class TestPeriodicThread(unittest.TestCase): 11 | '''Can stop a PeriodicThread''' 12 | def test_stop(self): 13 | '''Stop is effective for stopping a stoppable thread''' 14 | def callback(event): 15 | '''Trigger an event''' 16 | event.set() 17 | 18 | event = threading.Event() 19 | thread = PeriodicThread(0.01, callback, event) 20 | thread.start() 21 | event.wait() 22 | thread.stop() 23 | thread.join() 24 | self.assertFalse(thread.is_alive()) 25 | 26 | def test_repeats(self): 27 | '''Repeats the same callback several times''' 28 | def callback(event, counter): 29 | '''Trigger an event after accumulating enough''' 30 | counter['count'] += 1 31 | if counter['count'] > 10: 32 | event.set() 33 | 34 | event = threading.Event() 35 | counter = {'count': 0} 36 | thread = PeriodicThread(0.01, callback, event, counter) 37 | thread.start() 38 | event.wait() 39 | thread.stop() 40 | thread.join() 41 | self.assertGreaterEqual(counter['count'], 10) 42 | 43 | def test_survives_standard_error(self): 44 | '''The thread survives exceptions''' 45 | def callback(): 46 | '''Raise an exception''' 47 | raise Exception('foo') 48 | 49 | with mock.patch('nsq.checker.logger') as mock_logger: 50 | thread = PeriodicThread(0.01, callback) 51 | thread.start() 52 | thread.join(0.1) 53 | self.assertTrue(thread.is_alive()) 54 | thread.stop() 55 | thread.join() 56 | mock_logger.exception.assert_called_with('Callback failed') 57 | self.assertGreater(mock_logger.exception.call_count, 1) 58 | 59 | 60 | class TestConnectionChecker(unittest.TestCase): 61 | '''ConnectionChecker tests''' 62 | def test_callback(self): 63 | '''Provides the client's connection checking method''' 64 | with mock.patch.object(PeriodicThread, '__init__') as mock_init: 65 | mock_client = mock.Mock() 66 | checker = ConnectionChecker(mock_client, 10) 67 | mock_init.assert_called_with( 68 | checker, 10, mock_client.check_connections) 69 | 
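The tests above pin down the contract that matters in practice: a `PeriodicThread` keeps invoking its callback on a fixed interval (surviving exceptions) until `stop()` is called, and `ConnectionChecker` simply points that callback at a client's `check_connections` method. A minimal sketch of driving `PeriodicThread` directly, with a hypothetical callback that is not part of the library:

```python
from nsq.checker import PeriodicThread

def heartbeat():
    # Hypothetical callback for illustration; ConnectionChecker would call
    # client.check_connections here instead.
    print('still alive')

# Interval of 5 (assumed to be seconds, matching how the tests use it)
thread = PeriodicThread(5, heartbeat)
thread.start()
try:
    pass  # ... application work ...
finally:
    thread.stop()   # ask the loop to exit
    thread.join()   # and wait for the thread to finish
```

At the client level the same machinery is exposed as a context manager, `with client.connection_checker(): ...`, which test_client.py exercises below.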
-------------------------------------------------------------------------------- /test/test_client.py: -------------------------------------------------------------------------------- 1 | import mock 2 | 3 | from nsq import client 4 | from nsq import response 5 | from nsq import constants 6 | from nsq import exceptions 7 | from nsq.http import ClientException 8 | 9 | from common import HttpClientIntegrationTest, MockedConnectionTest 10 | from contextlib import contextmanager 11 | import errno 12 | import select 13 | import socket 14 | 15 | 16 | class TestClientNsqd(HttpClientIntegrationTest): 17 | '''Test our client class''' 18 | nsqd_ports = (14150,) 19 | 20 | def setUp(self): 21 | '''Return a new client''' 22 | HttpClientIntegrationTest.setUp(self) 23 | hosts = ['localhost:%s' % port for port in self.nsqd_ports] 24 | self.client = client.Client(nsqd_tcp_addresses=hosts) 25 | 26 | def test_connect_nsqd(self): 27 | '''Can successfully establish connections''' 28 | connections = self.client.connections() 29 | self.assertEqual(len(connections), 1) 30 | for connection in connections: 31 | self.assertTrue(connection.alive()) 32 | 33 | def test_add_added(self): 34 | '''Connect invokes self.connected''' 35 | connection = mock.Mock() 36 | with mock.patch.object(self.client, 'added'): 37 | self.client.add(connection) 38 | self.client.added.assert_called_with(connection) 39 | 40 | def test_add_existing(self): 41 | '''Adding an existing connection returns None''' 42 | connection = self.client.connections()[0] 43 | self.assertEqual(self.client.add(connection), None) 44 | 45 | def test_remove_exception(self): 46 | '''If closing a connection raises an exception, remove still works''' 47 | connection = self.client.connections()[0] 48 | with mock.patch.object(connection, 'close', side_effect=Exception): 49 | self.assertEqual(self.client.remove(connection), connection) 50 | 51 | def test_honors_identify_options(self): 52 | '''Sends along identify options to each connection as it's created''' 53 | with mock.patch('nsq.client.connection.Connection') as MockConnection: 54 | with mock.patch.object( 55 | self.client, '_identify_options', {'foo': 'bar'}): 56 | self.client.connect('foo', 'bar') 57 | MockConnection.assert_called_with('foo', 'bar', 58 | reconnection_backoff=None, auth_secret=None, foo='bar', timeout=None) 59 | 60 | def test_conection_checker(self): 61 | '''Spawns and starts a connection checker''' 62 | with self.client.connection_checker() as checker: 63 | self.assertTrue(checker.is_alive()) 64 | self.assertFalse(checker.is_alive()) 65 | 66 | def test_read_closed(self): 67 | '''Recovers from reading on a closed connection''' 68 | conn = self.client.connections()[0] 69 | with mock.patch.object(conn, 'alive', return_value=True): 70 | with mock.patch.object(conn, '_socket', None): 71 | # This test passes if no exception in raised 72 | self.client.read() 73 | 74 | def test_read_select_err(self): 75 | '''Recovers from select errors''' 76 | with mock.patch('nsq.client.select.select') as mock_select: 77 | mock_select.side_effect = select.error(errno.EBADF) 78 | # This test passes if no exception is raised 79 | self.client.read() 80 | 81 | 82 | class TestClientLookupd(HttpClientIntegrationTest): 83 | '''Test our client class''' 84 | def setUp(self): 85 | '''Return a new client''' 86 | HttpClientIntegrationTest.setUp(self) 87 | self.client = client.Client(topic=self.topic, lookupd_http_addresses=['http://localhost:14161']) 88 | 89 | def test_connected(self): 90 | '''Can successfully establish 
connections''' 91 | connections = self.client.connections() 92 | self.assertEqual(len(connections), 1) 93 | for connection in connections: 94 | self.assertTrue(connection.alive()) 95 | 96 | def test_asserts_topic(self): 97 | '''If nsqlookupd servers are provided, asserts a topic''' 98 | self.assertRaises( 99 | AssertionError, client.Client, lookupd_http_addresses=['foo']) 100 | 101 | def test_client_exception(self): 102 | '''Is OK when discovery fails''' 103 | with mock.patch('nsq.client.nsqlookupd.Client') as MockClass: 104 | instance = MockClass.return_value 105 | instance.lookup.side_effect = ClientException 106 | self.client = client.Client( 107 | lookupd_http_addresses=['http://localhost:1234'], 108 | topic='foo') 109 | 110 | def test_discover_connected(self): 111 | '''Doesn't freak out when rediscovering established connections''' 112 | before = self.client.connections() 113 | self.client.discover(self.topic) 114 | self.assertEqual(self.client.connections(), before) 115 | 116 | def test_discover_closed(self): 117 | '''Reconnects to discovered servers that have closed connections''' 118 | for conn in self.client.connections(): 119 | conn.close() 120 | state = [conn.alive() for conn in self.client.connections()] 121 | self.assertEqual(state, [False]) 122 | self.client.discover(self.topic) 123 | state = [conn.alive() for conn in self.client.connections()] 124 | self.assertEqual(state, [True]) 125 | 126 | def test_auth_secret(self): 127 | '''If an auth secret is provided, it passes it to nsqlookupd''' 128 | with mock.patch('nsq.client.nsqlookupd.Client') as MockClient: 129 | client.Client(topic=self.topic, lookupd_http_addresses=['foo'], auth_secret='hello') 130 | MockClient.assert_called_with('foo', access_token='hello') 131 | 132 | 133 | class TestClientMultiple(MockedConnectionTest): 134 | '''Tests for our client class''' 135 | @contextmanager 136 | def readable(self, connections): 137 | '''With all the connections readable''' 138 | value = (connections, [], []) 139 | with mock.patch('nsq.client.select.select', return_value=value): 140 | yield 141 | 142 | @contextmanager 143 | def writable(self, connections): 144 | '''With all the connections writable''' 145 | value = ([], connections, []) 146 | with mock.patch('nsq.client.select.select', return_value=value): 147 | yield 148 | 149 | @contextmanager 150 | def exceptable(self, connections): 151 | '''With all the connections exceptable''' 152 | value = ([], [], connections) 153 | with mock.patch('nsq.client.select.select', return_value=value): 154 | yield 155 | 156 | def test_multi_read(self): 157 | '''Can read from multiple sockets''' 158 | # With all the connections read-ready 159 | for conn in self.connections: 160 | conn.response('hello') 161 | with self.readable(self.connections): 162 | found = self.client.read() 163 | self.assertEqual(len(found), 2) 164 | for res in found: 165 | self.assertIsInstance(res, response.Response) 166 | self.assertEqual(res.data, 'hello') 167 | 168 | def test_heartbeat(self): 169 | '''Sends a nop on connections that have received a heartbeat''' 170 | for conn in self.connections: 171 | conn.response(constants.HEARTBEAT) 172 | with self.readable(self.connections): 173 | self.assertEqual(self.client.read(), []) 174 | for conn in self.connections: 175 | conn.nop.assert_called_with() 176 | 177 | def test_closes_on_fatal(self): 178 | '''All but a few errors are considered fatal''' 179 | self.connections[0].error(exceptions.InvalidException) 180 | with self.readable(self.connections): 181 | self.client.read() 
182 | self.assertFalse(self.connections[0].alive()) 183 | 184 | def test_nonfatal(self): 185 | '''Nonfatal errors keep the connection open''' 186 | self.connections[0].error(exceptions.FinFailedException) 187 | with self.readable(self.connections): 188 | self.client.read() 189 | self.assertTrue(self.connections[0].alive()) 190 | 191 | def test_passes_errors(self): 192 | '''The client's read method should now swallow Error responses''' 193 | self.connections[0].error(exceptions.InvalidException) 194 | with self.readable(self.connections): 195 | res = self.client.read() 196 | self.assertEqual(len(res), 1) 197 | self.assertIsInstance(res[0], response.Error) 198 | self.assertEqual(res[0].data, exceptions.InvalidException.name) 199 | 200 | def test_closes_on_exception(self): 201 | '''If a connection gets an exception, it closes it''' 202 | # Pick a connection to have throw an exception 203 | conn = self.connections[0] 204 | with mock.patch.object( 205 | conn, 'read', side_effect=exceptions.NSQException): 206 | with self.readable(self.connections): 207 | self.client.read() 208 | self.assertFalse(conn.alive()) 209 | 210 | def test_closes_on_read_socket_error(self): 211 | '''If a connection gets a socket error, it closes it''' 212 | # Pick a connection to have throw an exception 213 | conn = self.connections[0] 214 | with mock.patch.object( 215 | conn, 'read', side_effect=socket.error): 216 | with self.readable(self.connections): 217 | self.client.read() 218 | self.assertFalse(conn.alive()) 219 | 220 | def test_closes_on_flush_socket_error(self): 221 | '''If a connection fails to flush, it gets closed''' 222 | # Pick a connection to have throw an exception 223 | conn = self.connections[0] 224 | with mock.patch.object( 225 | conn, 'flush', side_effect=socket.error): 226 | with self.writable(self.connections): 227 | self.client.read() 228 | self.assertFalse(conn.alive()) 229 | 230 | def test_read_writable(self): 231 | '''Read flushes any writable connections''' 232 | with self.writable(self.connections): 233 | self.client.read() 234 | for conn in self.connections: 235 | conn.flush.assert_called_with() 236 | 237 | def test_read_exceptions(self): 238 | '''Closes connections with socket errors''' 239 | with self.exceptable(self.connections): 240 | self.client.read() 241 | for conn in self.connections: 242 | self.assertFalse(conn.alive()) 243 | 244 | def test_read_timeout(self): 245 | '''Logs a message when our read loop finds nothing because of timeout''' 246 | with self.readable([]): 247 | with mock.patch('nsq.client.logger') as MockLogger: 248 | self.client.read() 249 | MockLogger.debug.assert_called_with('Timed out...') 250 | 251 | def test_read_with_no_connections(self): 252 | '''Attempting to read with no connections''' 253 | with mock.patch.object(self.client, 'connections', return_value=[]): 254 | self.assertEqual(self.client.read(), []) 255 | 256 | def test_read_sleep_no_connections(self): 257 | '''Sleeps for timeout if no connections''' 258 | with mock.patch.object(self.client, '_timeout', 5): 259 | with mock.patch.object(self.client, 'connections', return_value=[]): 260 | with mock.patch('nsq.client.time.sleep') as mock_sleep: 261 | self.client.read() 262 | mock_sleep.assert_called_with(self.client._timeout) 263 | 264 | def test_random_connection(self): 265 | '''Yields a random client''' 266 | found = [] 267 | for _ in range(20): 268 | with self.client.random_connection() as conn: 269 | found.append(conn) 270 | self.assertEqual(set(found), set(self.client.connections())) 271 | 272 | def 
test_wait_response(self): 273 | '''Waits until a response is available''' 274 | with mock.patch.object( 275 | self.client, 'read', side_effect=[[], ['hello']]): 276 | self.assertEqual(self.client.wait_response(), ['hello']) 277 | 278 | def test_wait_write(self): 279 | '''Waits until a command has been sent''' 280 | connection = mock.Mock() 281 | with mock.patch.object(self.client, 'read'): 282 | connection.pending = mock.Mock(side_effect=[True, False]) 283 | self.client.wait_write(connection) 284 | self.assertTrue(connection.pending.called) 285 | 286 | def test_pub(self): 287 | '''Pub called on a random connection and waits for a response''' 288 | connection = mock.Mock() 289 | with mock.patch.object( 290 | self.client, 'connections', return_value=[connection]): 291 | with mock.patch.object( 292 | self.client, 'wait_response', return_value=['response']): 293 | self.assertEqual(self.client.pub('foo', 'bar'), ['response']) 294 | connection.pub.assert_called_with('foo', 'bar') 295 | 296 | def test_mpub(self): 297 | '''Mpub called on a random connection and waits for a response''' 298 | connection = mock.Mock() 299 | messages = ['hello', 'how', 'are', 'you'] 300 | with mock.patch.object( 301 | self.client, 'connections', return_value=[connection]): 302 | with mock.patch.object( 303 | self.client, 'wait_response', return_value=['response']): 304 | self.assertEqual( 305 | self.client.mpub('foo', *messages), ['response']) 306 | connection.mpub.assert_called_with('foo', *messages) 307 | 308 | def test_not_ready_to_reconnect(self): 309 | '''Does not try to reconnect connections that are not ready''' 310 | conn = self.connections[0] 311 | conn.close() 312 | conn.ready_to_reconnect.return_value = False 313 | self.client.check_connections() 314 | self.assertFalse(conn.connect.called) 315 | 316 | def test_ready_to_reconnect(self): 317 | '''Tries to reconnect when ready''' 318 | conn = self.connections[0] 319 | conn.close() 320 | conn.ready_to_reconnect.return_value = True 321 | self.client.check_connections() 322 | self.assertTrue(conn.connect.called) 323 | 324 | def test_set_blocking(self): 325 | '''Sets blocking to 0 when reconnecting''' 326 | conn = self.connections[0] 327 | conn.close() 328 | conn.ready_to_reconnect.return_value = True 329 | conn.connect.return_value = True 330 | self.client.check_connections() 331 | conn.setblocking.assert_called_with(0) 332 | 333 | def test_calls_reconnected(self): 334 | '''Calls the reconnected handler after a successful reconnect''' 335 | conn = self.connections[0] 336 | conn.close() 337 | conn.ready_to_reconnect.return_value = True 338 | conn.connect.return_value = True 339 | with mock.patch.object(self.client, 'reconnected'): 340 | self.client.check_connections() 341 | self.client.reconnected.assert_called_with(conn) 342 | 343 | 344 | class TestClientNsqdWithConnectTimeout(HttpClientIntegrationTest): 345 | '''Test our client class when a connection timeout is set''' 346 | nsqd_ports = (14150,) 347 | 348 | def test_connect_timeout(self): 349 | HttpClientIntegrationTest.setUp(self) 350 | hosts = ['localhost:%s' % port for port in self.nsqd_ports] 351 | connect_timeout = 2.0 352 | self.client = client.Client(nsqd_tcp_addresses=hosts, connect_timeout=connect_timeout) 353 | self.assertEqual(self.client._connect_timeout, connect_timeout) 354 | -------------------------------------------------------------------------------- /test/test_clients/test_clients.py: -------------------------------------------------------------------------------- 1 | import mock 2 | import unittest 3 | 4 | from
nsq import http 5 | 6 | 7 | class TestClients(unittest.TestCase): 8 | def setUp(self): 9 | self.result = mock.Mock() 10 | self.result.content = '{"foo": "bar"}' 11 | self.result.status_code = 200 12 | self.result.reason = 'OK' 13 | self.func = mock.Mock(return_value=self.result) 14 | self.client = http.BaseClient('http://foo:1') 15 | 16 | def function(*args, **kwargs): 17 | return self.func(*args, **kwargs) 18 | 19 | self.function = function 20 | 21 | def test_string(self): 22 | '''Can make a client with a string''' 23 | # This test passes if no exception is thrown 24 | http.BaseClient('http://foo.com:4161') 25 | 26 | def test_tuple(self): 27 | '''Can create a client with a tuple''' 28 | http.BaseClient(('foo.com', 4161)) 29 | 30 | def test_non_tuple_string(self): 31 | '''Raises an exception if it's neither a tuple nor a string''' 32 | self.assertRaises(TypeError, http.BaseClient, {}) 33 | 34 | def test_wrap_basic(self): 35 | '''Invokes a function with the same args and kwargs''' 36 | args = [1, 2, 3] 37 | kwargs = {'whiz': 'bang'} 38 | http.wrap(self.function)(*args, **kwargs) 39 | self.func.assert_called_with(*args, **kwargs) 40 | 41 | def test_wrap_non_200(self): 42 | '''Raises a client exception''' 43 | self.result.status_code = 500 44 | self.result.reason = 'Internal Server Error' 45 | self.assertRaisesRegexp(http.ClientException, 46 | 'Internal Server Error', http.wrap(self.function)) 47 | 48 | def test_wrap_exception(self): 49 | '''Wraps exceptions as ClientExceptions''' 50 | self.func.side_effect = TypeError 51 | self.assertRaises(http.ClientException, http.wrap(self.function)) 52 | 53 | def test_json_wrap_basic(self): 54 | '''Returns JSON-parsed content''' 55 | self.result.content = '{"data":"bar"}' 56 | self.assertEqual(http.json_wrap(self.function)(), 'bar') 57 | 58 | def test_json_wrap_exception(self): 59 | '''Raises a generalized exception for failed 200s''' 60 | # This is not JSON 61 | self.result.content = '{"' 62 | self.assertRaises(http.ClientException, 63 | http.json_wrap(self.function)) 64 | 65 | def test_ok_check(self): 66 | '''Passes through the OK response''' 67 | self.result.content = b'OK' 68 | self.assertEqual(b'OK', http.ok_check(self.function)()) 69 | 70 | def test_ok_check_raises_exception(self): 71 | '''Raises an exception if the response is not OK''' 72 | self.result.content = 'NOT OK' 73 | self.assertRaisesRegexp( 74 | http.ClientException, 'NOT OK', http.ok_check(self.function)) 75 | 76 | def test_get(self): 77 | '''Gets from the appropriate host with all the provided params''' 78 | with mock.patch('nsq.http.requests') as MockClass: 79 | args = [1, 2, 3] 80 | kwargs = {'whiz': 'bang'} 81 | MockClass.get.return_value = mock.Mock( 82 | status_code=200, content='{"foo": "bar"}') 83 | self.client.get('/path', *args, **kwargs) 84 | MockClass.get.assert_called_with( 85 | 'http://foo:1/path', params={}, *args, **kwargs) 86 | 87 | def test_post(self): 88 | '''Posts to the appropriate host with all the provided params''' 89 | with mock.patch('nsq.http.requests') as MockClass: 90 | args = [1, 2, 3] 91 | kwargs = {'whiz': 'bang'} 92 | MockClass.post.return_value = mock.Mock( 93 | status_code=200, content='{"foo": "bar"}') 94 | self.client.post('/path', *args, **kwargs) 95 | MockClass.post.assert_called_with( 96 | 'http://foo:1/path', params={}, *args, **kwargs) 97 | 98 | def test_prefix_get(self): 99 | '''Gets from the appropriately-relativized path''' 100 | client = http.BaseClient('http://foo.com:1/prefix/') 101 | with mock.patch('nsq.http.requests') as MockClass: 102
| MockClass.get.return_value = mock.Mock( 103 | status_code=200, content='{"foo": "bar"}') 104 | client.get('path') 105 | MockClass.get.assert_called_with( 106 | 'http://foo.com:1/prefix/path', params={}) 107 | 108 | def test_prefix_post(self): 109 | '''Posts to the appropriately-relativized path''' 110 | client = http.BaseClient('http://foo.com:1/prefix/') 111 | with mock.patch('nsq.http.requests') as MockClass: 112 | MockClass.post.return_value = mock.Mock( 113 | status_code=200, content='{"foo": "bar"}') 114 | client.post('path') 115 | MockClass.post.assert_called_with( 116 | 'http://foo.com:1/prefix/path', params={}) 117 | 118 | def test_params_get(self): 119 | '''Provides default parameters''' 120 | client = http.BaseClient('http://foo.com:1/', a='b') 121 | with mock.patch('nsq.http.requests') as MockClass: 122 | MockClass.get.return_value = mock.Mock( 123 | status_code=200, content='{"foo": "bar"}') 124 | client.get('path') 125 | MockClass.get.assert_called_with( 126 | 'http://foo.com:1/path', params={'a': 'b'}) 127 | 128 | def test_params_post(self): 129 | '''Provides default parameters''' 130 | client = http.BaseClient('http://foo.com:1/', a='b') 131 | with mock.patch('nsq.http.requests') as MockClass: 132 | MockClass.post.return_value = mock.Mock( 133 | status_code=200, content='{"foo": "bar"}') 134 | client.post('path') 135 | MockClass.post.assert_called_with( 136 | 'http://foo.com:1/path', params={'a': 'b'}) 137 | -------------------------------------------------------------------------------- /test/test_clients/test_nsqd.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import six 4 | from nsq.http import nsqd, ClientException 5 | from nsq.util import pack 6 | from common import HttpClientIntegrationTest, ClientTest 7 | 8 | 9 | class TestNsqdClient(ClientTest): 10 | '''Testing the nsqd client in isolation''' 11 | def setUp(self): 12 | self.client = nsqd.Client('http://foo:1') 13 | 14 | def test_mpub_ascii(self): 15 | '''Publishes ascii messages fine''' 16 | with self.patched_post() as post: 17 | post.return_value.content = b'OK' 18 | messages = [six.text_type(n).encode() for n in range(10)] 19 | self.client.mpub('topic', messages, binary=False) 20 | post.assert_called_with( 21 | '/mpub', params={'topic': 'topic'}, data=b'\n'.join(messages)) 22 | 23 | def test_mpub_binary(self): 24 | '''Publishes messages with binary fine''' 25 | with self.patched_post() as post: 26 | post.return_value.content = b'OK' 27 | messages = [six.text_type(n).encode() for n in range(10)] 28 | self.client.mpub('topic', messages) 29 | post.assert_called_with( 30 | 'mpub', 31 | params={'topic': 'topic', 'binary': True}, 32 | data=pack(messages)[4:]) 33 | 34 | def test_mpub_ascii_exception(self): 35 | '''Raises an exception when ascii-mpub-ing messages with newline''' 36 | messages = [b'hello\n', b'how\n', b'are\n', b'you\n'] 37 | self.assertRaises( 38 | ClientException, self.client.mpub, 'topic', messages, binary=False) 39 | 40 | 41 | class TestNsqdClientIntegration(HttpClientIntegrationTest): 42 | '''An integration test of the nsqd client''' 43 | def test_ping_ok(self): 44 | '''Make sure ping works in a basic way''' 45 | self.assertEqual(self.nsqd.ping(), b'OK') 46 | 47 | def test_info(self): 48 | '''Info works in a basic way''' 49 | self.assertIn('version', self.nsqd.info()) 50 | 51 | def test_pub(self): 52 | '''Publishing a message works as expected''' 53 | self.assertEqual(self.nsqd.pub(self.topic, 'message'), b'OK') 54 | topic = 
self.nsqd.clean_stats()['topics'][self.topic.decode()] 55 | self.assertEqual(topic['channels'][self.channel.decode()]['depth'], 1) 56 | 57 | def test_mpub_ascii(self): 58 | '''Publishing messages in ascii works as expected''' 59 | messages = [six.text_type(i).encode() for i in range(100)] 60 | self.assertTrue(self.nsqd.mpub(self.topic, messages, binary=False)) 61 | 62 | def test_mpub_binary(self): 63 | '''Publishing messages in binary works as expected''' 64 | messages = [six.text_type(i).encode() for i in range(100)] 65 | self.assertTrue(self.nsqd.mpub(self.topic, messages)) 66 | 67 | def test_create_topic(self): 68 | '''Topic creation should work''' 69 | topic = uuid.uuid4().hex 70 | with self.delete_topic(topic): 71 | # Ensure the topic doesn't exist beforehand 72 | self.assertNotIn(topic, self.nsqd.clean_stats()['topics']) 73 | self.assertTrue(self.nsqd.create_topic(topic)) 74 | # And now it exists afterwards 75 | self.assertIn(topic, self.nsqd.clean_stats()['topics']) 76 | 77 | def test_empty_topic(self): 78 | '''We can drain a topic''' 79 | topic = uuid.uuid4().hex 80 | with self.delete_topic(topic): 81 | self.nsqd.pub(topic, 'foo') 82 | self.nsqd.empty_topic(topic) 83 | depth = self.nsqd.clean_stats()['topics'][topic]['depth'] 84 | self.assertEqual(depth, 0) 85 | 86 | def test_delete_topic(self): 87 | '''We can delete a topic''' 88 | topic = uuid.uuid4().hex 89 | with self.delete_topic(topic): 90 | self.nsqd.create_topic(topic) 91 | self.assertTrue(self.nsqd.delete_topic(topic)) 92 | # Ensure the topic doesn't exist afterwards 93 | self.assertNotIn(topic, self.nsqd.clean_stats()['topics']) 94 | 95 | def test_pause_topic(self): 96 | '''We can pause a topic''' 97 | self.assertTrue(self.nsqd.pause_topic(self.topic)) 98 | 99 | def test_unpause_topic(self): 100 | '''We can unpause a topic''' 101 | self.nsqd.pause_topic(self.topic) 102 | self.assertTrue(self.nsqd.unpause_topic(self.topic)) 103 | 104 | def test_create_channel(self): 105 | '''We can create a channel''' 106 | topic = uuid.uuid4().hex 107 | channel = uuid.uuid4().hex 108 | with self.delete_topic(topic): 109 | self.nsqd.create_topic(topic) 110 | self.nsqd.create_channel(topic, channel) 111 | topic = self.nsqd.clean_stats()['topics'][topic] 112 | self.assertIn(channel, topic['channels']) 113 | 114 | def test_empty_channel(self): 115 | '''Can clear the messages out in a channel''' 116 | self.nsqd.pub(self.topic, self.channel) 117 | self.nsqd.empty_channel(self.topic, self.channel) 118 | topic = self.nsqd.clean_stats()['topics'][self.topic.decode()] 119 | channel = topic['channels'][self.channel.decode()] 120 | self.assertEqual(channel['depth'], 0) 121 | 122 | def test_delete_channel(self): 123 | '''Can delete a channel in a topic''' 124 | self.nsqd.delete_channel(self.topic, self.channel) 125 | topic = self.nsqd.clean_stats()['topics'][self.topic.decode()] 126 | self.assertNotIn(self.channel, topic['channels']) 127 | 128 | def test_pause_channel(self): 129 | '''Can pause a channel''' 130 | self.nsqd.pause_channel(self.topic, self.channel) 131 | topic = self.nsqd.clean_stats()['topics'][self.topic.decode()] 132 | channel = topic['channels'][self.channel.decode()] 133 | self.assertTrue(channel['paused']) 134 | 135 | def test_unpause_channel(self): 136 | '''Can unpause a channel''' 137 | self.nsqd.pause_channel(self.topic, self.channel) 138 | self.nsqd.unpause_channel(self.topic, self.channel) 139 | topic = self.nsqd.clean_stats()['topics'][self.topic.decode()] 140 | channel = topic['channels'][self.channel.decode()] 141 | 
self.assertFalse(channel['paused']) 142 | 143 | def test_clean_stats(self): 144 | '''Clean stats turns 'topics' and 'channels' into dictionaries''' 145 | stats = self.nsqd.clean_stats() 146 | self.assertIsInstance(stats['topics'], dict) 147 | self.assertIsInstance( 148 | stats['topics'][self.topic.decode()]['channels'], dict) 149 | -------------------------------------------------------------------------------- /test/test_clients/test_nsqlookupd.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from common import HttpClientIntegrationTest 4 | 5 | 6 | class TestNsqlookupdClient(HttpClientIntegrationTest): 7 | def test_ping(self): 8 | '''Ping the client''' 9 | self.assertTrue(self.nsqlookupd.ping()) 10 | 11 | def test_info(self): 12 | '''Get info about the client''' 13 | self.assertIn('version', self.nsqlookupd.info()) 14 | 15 | def test_lookup(self): 16 | '''Can look up nsqd instances for a topic''' 17 | self.assertIn('producers', self.nsqlookupd.lookup(self.topic)) 18 | 19 | def test_topics(self): 20 | '''Can get all the topics this instance knows about''' 21 | self.assertIn(self.topic.decode(), self.nsqlookupd.topics()['topics']) 22 | 23 | def test_channels(self): 24 | '''Can get all the channels in the provided topic''' 25 | self.assertIn(self.channel.decode(), 26 | self.nsqlookupd.channels(self.topic)['channels']) 27 | 28 | def test_nodes(self): 29 | '''Can get information about all the nodes''' 30 | self.assertIn('producers', self.nsqlookupd.nodes()) 31 | 32 | def test_delete_topic(self): 33 | '''Can delete topics''' 34 | self.nsqlookupd.delete_topic(self.topic) 35 | self.assertNotIn(self.topic, self.nsqlookupd.topics()['topics']) 36 | 37 | def test_delete_channel(self): 38 | '''Can delete a channel within a topic''' 39 | self.nsqlookupd.delete_channel(self.topic, self.channel) 40 | self.assertNotIn(self.channel, 41 | self.nsqlookupd.channels(self.topic)['channels']) 42 | 43 | def test_create_topics(self): 44 | '''Can create a topic''' 45 | topic = uuid.uuid4().hex 46 | with self.delete_topic(topic): 47 | self.nsqlookupd.create_topic(topic) 48 | self.assertIn(topic, self.nsqlookupd.topics()['topics']) 49 | 50 | def test_create_channel(self): 51 | '''Can create a channel within a topic''' 52 | channel = uuid.uuid4().hex 53 | self.nsqlookupd.create_channel(self.topic, channel) 54 | self.assertIn(channel, self.nsqlookupd.channels(self.topic)['channels']) 55 | 56 | def test_debug(self): 57 | '''Can access debug information''' 58 | key = 'channel:%s:%s' % (self.topic.decode(), self.channel.decode()) 59 | self.assertIn(key, self.nsqlookupd.debug()) 60 | -------------------------------------------------------------------------------- /test/test_connection.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | 3 | import mock 4 | 5 | import errno 6 | import socket 7 | import ssl 8 | from collections import deque 9 | 10 | from nsq import connection 11 | from nsq import constants 12 | from nsq import exceptions 13 | from nsq import response 14 | from nsq import util 15 | from nsq import json 16 | from common import MockedSocketTest, HttpClientIntegrationTest 17 | 18 | 19 | class TestConnection(MockedSocketTest): 20 | '''Tests about our connection''' 21 | def test_alive(self): 22 | self.assertTrue(self.connection.alive()) 23 | 24 | def test_close(self): 25 | '''Should mark the connection as closed''' 26 | self.connection.close() 27 | self.assertFalse(self.connection.alive()) 28 | 29 | def test_blocking(self): 30 | '''Sets blocking on the socket''' 31 | self.connection.setblocking(0) 32 | self.socket.setblocking.assert_called_with(0) 33 | 34 | def test_pending(self): 35 | '''Appends to pending''' 36 | self.connection.nop() 37 | self.assertEqual( 38 | list(self.connection.pending()), [constants.NOP + constants.NL]) 39 | 40 | def test_flush_partial(self): 41 | '''Keeps its place when flushing out partial messages''' 42 | # We'll tell the connection it has only sent one byte when flushing 43 | with mock.patch.object(self.socket, 'send'): 44 | self.socket.send.return_value = 1 45 | self.connection.nop() 46 | self.connection.flush() 47 | # We expect all but the first byte to remain 48 | message = constants.NOP + constants.NL 49 | self.assertEqual(list(self.connection.pending()), [message[1:]]) 50 | 51 | def test_flush_full(self): 52 | '''Pops off messages it has flushed completely''' 53 | # We'll tell the connection it has only sent one byte when flushing 54 | self.connection.nop() 55 | self.connection.flush() 56 | # The nop message was sent, so we expect it to be popped 57 | self.assertEqual(list(self.connection.pending()), []) 58 | 59 | def test_flush_count(self): 60 | '''Returns how many bytes were sent''' 61 | message = constants.NOP + constants.NL 62 | # Ensure this doesn't invoke our normal flush 63 | self.connection.nop() 64 | self.assertEqual(self.connection.flush(), len(message)) 65 | 66 | def test_flush_empty(self): 67 | '''Returns 0 if there are no pending messages''' 68 | self.assertEqual(self.connection.flush(), 0) 69 | 70 | def test_flush_multiple(self): 71 | '''Flushes as many messages as possible''' 72 | pending = deque([b'hello'] * 5) 73 | with mock.patch.object(self.connection, '_pending', pending): 74 | self.connection.flush() 75 | self.assertEqual(len(self.connection.pending()), 0) 76 | 77 | def test_flush_would_block(self): 78 | '''Honors EAGAIN / EWOULDBLOCK''' 79 | pending = deque([b'1', b'2', b'3']) 80 | with mock.patch.object(self.connection, '_socket') as mock_socket: 81 | with mock.patch.object(self.connection, '_pending', pending): 82 | mock_socket.send.side_effect = socket.error(errno.EAGAIN) 83 | self.assertEqual(self.connection.flush(), 0) 84 | 85 | def test_flush_would_block_ssl_write(self): 86 | '''Honors ssl.SSL_ERROR_WANT_WRITE''' 87 | pending = deque([b'1', b'2', b'3']) 88 | with mock.patch.object(self.connection, '_socket') as mock_socket: 89 | with mock.patch.object(self.connection, '_pending', pending): 90 | mock_socket.send.side_effect = ssl.SSLError( 91 | ssl.SSL_ERROR_WANT_WRITE) 92 | self.assertEqual(self.connection.flush(), 0) 93 | 94 | def test_flush_would_block_ssl_read(self): 95 | '''Honors ssl.SSL_ERROR_WANT_READ''' 96 | pending = deque([b'1', b'2', b'3']) 97 | with mock.patch.object(self.connection, '_socket') as mock_socket: 98 | 
with mock.patch.object(self.connection, '_pending', pending): 99 | mock_socket.send.side_effect = ssl.SSLError( 100 | ssl.SSL_ERROR_WANT_READ) 101 | self.assertEqual(self.connection.flush(), 0) 102 | 103 | def test_flush_would_block_ssl_write_buffer(self): 104 | '''ssl.SSL_ERROR_WANT_WRITE uses the same buffer on the next send''' 105 | pending = deque([b'1', b'2', b'3']) 106 | with mock.patch.object(self.connection, '_pending', pending): 107 | with mock.patch.object(self.connection, '_socket') as mock_socket: 108 | mock_socket.send.side_effect = ssl.SSLError( 109 | ssl.SSL_ERROR_WANT_WRITE) 110 | self.assertFalse(self.connection._out_buffer) 111 | self.connection.flush() 112 | self.assertEqual(self.connection._out_buffer, b'123') 113 | 114 | # With some more pending items, make sure we still only get '123' sent 115 | pending = deque([b'4', b'5', b'6']) 116 | with mock.patch.object(self.connection, '_pending', pending): 117 | with mock.patch.object(self.connection, '_socket') as mock_socket: 118 | mock_socket.send.return_value = 3 119 | # The first flush should see the existing buffer 120 | self.connection.flush() 121 | mock_socket.send.assert_called_with(b'123') 122 | # The second flush should see the pending requests 123 | self.connection.flush() 124 | mock_socket.send.assert_called_with(b'456') 125 | 126 | def test_flush_socket_error(self): 127 | '''Re-raises socket non-EAGAIN errors''' 128 | pending = deque([b'1', b'2', b'3']) 129 | with mock.patch.object(self.connection, '_socket') as mock_socket: 130 | with mock.patch.object(self.connection, '_pending', pending): 131 | mock_socket.send.side_effect = socket.error('foo') 132 | self.assertRaises(socket.error, self.connection.flush) 133 | 134 | def test_eager_flush(self): 135 | '''Sending on a non-blocking connection does not eagerly flush''' 136 | with mock.patch.object(self.connection, 'flush') as mock_flush: 137 | self.connection.send(b'foo') 138 | mock_flush.assert_not_called() 139 | 140 | def test_close_flush(self): 141 | '''Closing the connection flushes all remaining messages''' 142 | def fake_flush(): 143 | self.connection._pending = False 144 | self.connection._fake_flush_called = True 145 | 146 | with mock.patch.object(self.connection, 'flush', fake_flush): 147 | self.connection.send(b'foo') 148 | self.connection.close() 149 | self.assertTrue(self.connection._fake_flush_called) 150 | 151 | def test_magic(self): 152 | '''Sends the NSQ magic bytes''' 153 | self.assertTrue(self.socket.read().startswith(constants.MAGIC_V2)) 154 | 155 | def test_identify(self): 156 | '''The connection sends the identify command''' 157 | expected = b''.join([ 158 | constants.MAGIC_V2, 159 | constants.IDENTIFY, 160 | constants.NL, 161 | util.pack(json.dumps(self.connection._identify_options).encode())]) 162 | self.assertEqual(self.socket.read(), expected) 163 | 164 | def test_read_timeout(self): 165 | '''Returns no results after a socket timeout''' 166 | with mock.patch.object(self.connection, '_socket') as mock_socket: 167 | mock_socket.recv.side_effect = socket.timeout 168 | self.assertEqual(self.connection.read(), []) 169 | 170 | def test_read_socket_error(self): 171 | '''Re-raises non-errno socket errors''' 172 | with mock.patch.object(self.connection, '_socket') as mock_socket: 173 | mock_socket.recv.side_effect = socket.error('foo') 174 | self.assertRaises(socket.error, self.connection.read) 175 | 176 | def test_read_would_block(self): 177 | '''Returns no results if it would block''' 178 | with mock.patch.object(self.connection, '_socket')
as mock_socket: 179 | mock_socket.recv.side_effect = socket.error(errno.EAGAIN) 180 | self.assertEqual(self.connection.read(), []) 181 | 182 | def test_read_would_block_ssl_write(self): 183 | '''Returns no results if it would block on a SSL socket''' 184 | with mock.patch.object(self.connection, '_socket') as mock_socket: 185 | mock_socket.recv.side_effect = ssl.SSLError(ssl.SSL_ERROR_WANT_WRITE) 186 | self.assertEqual(self.connection.read(), []) 187 | 188 | def test_read_would_block_ssl_read(self): 189 | '''Returns no results if it would block on a SSL socket''' 190 | with mock.patch.object(self.connection, '_socket') as mock_socket: 191 | mock_socket.recv.side_effect = ssl.SSLError(ssl.SSL_ERROR_WANT_READ) 192 | self.assertEqual(self.connection.read(), []) 193 | 194 | def test_read_partial(self): 195 | '''Returns nothing if it has only read partial results''' 196 | self.socket.write(b'f') 197 | self.assertEqual(self.connection.read(), []) 198 | 199 | def test_read_size_partial(self): 200 | '''Returns one response size is complete, but content is partial''' 201 | self.socket.write(response.Response.pack(b'hello')[:-1]) 202 | self.assertEqual(self.connection.read(), []) 203 | 204 | def test_read_whole(self): 205 | '''Returns a single message if it has read a complete one''' 206 | self.socket.write(response.Response.pack(b'hello')) 207 | expected = response.Response( 208 | self.connection, constants.FRAME_TYPE_RESPONSE, b'hello') 209 | self.assertEqual(self.connection.read(), [expected]) 210 | 211 | def test_read_multiple(self): 212 | '''Returns multiple responses if available''' 213 | self.socket.write(response.Response.pack(b'hello') * 10) 214 | expected = response.Response( 215 | self.connection, constants.FRAME_TYPE_RESPONSE, b'hello') 216 | self.assertEqual(self.connection.read(), [expected] * 10) 217 | 218 | def test_fileno(self): 219 | '''Returns the connection's file descriptor appropriately''' 220 | self.assertEqual( 221 | self.connection.fileno(), self.socket.fileno()) 222 | 223 | def test_fileno_closed(self): 224 | '''Raises an exception if the connection's closed''' 225 | with mock.patch.object(self.connection, '_socket', None): 226 | self.assertRaises(exceptions.ConnectionClosedException, 227 | self.connection.fileno) 228 | 229 | def test_str_alive(self): 230 | '''Sane str representation for an alive connection''' 231 | with mock.patch.object(self.connection, 'alive', return_value=True): 232 | with mock.patch.object( 233 | self.connection, 'fileno', return_value=7): 234 | with mock.patch.object(self.connection, 'host', 'host'): 235 | with mock.patch.object(self.connection, 'port', 'port'): 236 | self.assertEqual(str(self.connection), 237 | '') 238 | 239 | def test_str_dead(self): 240 | '''Sane str representation for an alive connection''' 241 | with mock.patch.object(self.connection, 'alive', return_value=False): 242 | with mock.patch.object( 243 | self.connection, 'fileno', return_value=7): 244 | with mock.patch.object(self.connection, 'host', 'host'): 245 | with mock.patch.object(self.connection, 'port', 'port'): 246 | self.assertEqual(str(self.connection), 247 | '') 248 | 249 | def test_send_no_message(self): 250 | '''Appropriately sends packed data without message''' 251 | self.socket.read() 252 | self.connection.nop() 253 | self.connection.flush() 254 | expected = constants.NOP + constants.NL 255 | self.assertEqual(self.socket.read(), expected) 256 | 257 | def test_send_message(self): 258 | '''Appropriately sends packed data with message''' 259 | self.socket.read() 260 | 
self.connection.identify({}) 261 | self.connection.flush() 262 | expected = b''.join( 263 | (constants.IDENTIFY, constants.NL, util.pack(b'{}'))) 264 | self.assertEqual(self.socket.read(), expected) 265 | 266 | def assertSent(self, expected, function, *args, **kwargs): 267 | '''Assert that the connection sends the expected payload''' 268 | self.socket.read() 269 | function(*args, **kwargs) 270 | self.connection.flush() 271 | self.assertEqual(self.socket.read(), expected) 272 | 273 | def test_auth(self): 274 | '''Appropriately send auth''' 275 | expected = b''.join((constants.AUTH, constants.NL, util.pack(b'hello'))) 276 | self.assertSent(expected, self.connection.auth, b'hello') 277 | 278 | def test_sub(self): 279 | '''Appropriately sends sub''' 280 | expected = b''.join((constants.SUB, b' foo bar', constants.NL)) 281 | self.assertSent(expected, self.connection.sub, b'foo', b'bar') 282 | 283 | def test_pub(self): 284 | '''Appropriately sends pub''' 285 | expected = b''.join( 286 | (constants.PUB, b' foo', constants.NL, util.pack(b'hello'))) 287 | self.assertSent(expected, self.connection.pub, b'foo', b'hello') 288 | 289 | def test_mpub(self): 290 | '''Appropriately sends mpub''' 291 | expected = b''.join(( 292 | constants.MPUB, b' foo', constants.NL, 293 | util.pack([b'hello', b'howdy']))) 294 | self.assertSent(expected, self.connection.mpub, b'foo', b'hello', b'howdy') 295 | 296 | def test_ready(self): 297 | '''Appropriately sends ready''' 298 | expected = b''.join((constants.RDY, b' 5', constants.NL)) 299 | self.assertSent(expected, self.connection.rdy, 5) 300 | 301 | def test_fin(self): 302 | '''Appropriately sends fin''' 303 | expected = b''.join((constants.FIN, b' message_id', constants.NL)) 304 | self.assertSent(expected, self.connection.fin, b'message_id') 305 | 306 | def test_req(self): 307 | '''Appropriately sends req''' 308 | expected = b''.join((constants.REQ, b' message_id 10', constants.NL)) 309 | self.assertSent(expected, self.connection.req, b'message_id', 10) 310 | 311 | def test_touch(self): 312 | '''Appropriately sends touch''' 313 | expected = b''.join((constants.TOUCH, b' message_id', constants.NL)) 314 | self.assertSent(expected, self.connection.touch, b'message_id') 315 | 316 | def test_cls(self): 317 | '''Appropriately sends cls''' 318 | expected = b''.join((constants.CLS, constants.NL)) 319 | self.assertSent(expected, self.connection.cls) 320 | 321 | def test_nop(self): 322 | '''Appropriately sends nop''' 323 | expected = b''.join((constants.NOP, constants.NL)) 324 | self.assertSent(expected, self.connection.nop) 325 | 326 | # Some tests very closely aimed at identification 327 | def test_calls_identified(self): 328 | '''Upon getting an identification response, we call 'identified''' 329 | with mock.patch.object( 330 | connection.Connection, 'identified') as mock_identified: 331 | self.connect({'foo': 'bar'}) 332 | self.assertTrue(mock_identified.called) 333 | 334 | def test_identified_tolerates_ok(self): 335 | '''The identified handler tolerates OK responses''' 336 | res = mock.Mock(data='OK') 337 | self.assertEqual(self.connection.identified(res).data, 'OK') 338 | 339 | def test_identify_defaults(self): 340 | '''Identify provides default options''' 341 | self.assertEqual(self.connection._identify_options, { 342 | 'feature_negotiation': True, 343 | 'client_id': socket.getfqdn().split('.')[0], 344 | 'hostname': socket.gethostname(), 345 | 'user_agent': self.connection.USER_AGENT 346 | }) 347 | 348 | def test_identify_override_defaults(self): 349 | '''Identify allows 
us to override defaults''' 350 | with mock.patch('nsq.connection.Connection.connect'): 351 | conn = connection.Connection('host', 0, long_id='not-your-fqdn') 352 | self.assertEqual(conn._identify_options['long_id'], 'not-your-fqdn') 353 | 354 | def test_identify_tls_unsupported(self): 355 | '''Raises an exception about the lack of TLS support''' 356 | with mock.patch('nsq.connection.TLSSocket', None): 357 | self.assertRaises(exceptions.UnsupportedException, 358 | connection.Connection, 'host', 0, tls_v1=True) 359 | 360 | def test_identify_snappy_unsupported(self): 361 | '''Raises an exception about the lack of snappy support''' 362 | with mock.patch('nsq.connection.SnappySocket', None): 363 | self.assertRaises(exceptions.UnsupportedException, 364 | connection.Connection, 'host', 0, snappy=True) 365 | 366 | def test_identify_deflate_unsupported(self): 367 | '''Raises an exception about the lack of deflate support''' 368 | with mock.patch('nsq.connection.DeflateSocket', None): 369 | self.assertRaises(exceptions.UnsupportedException, 370 | connection.Connection, 'host', 0, deflate=True) 371 | 372 | def test_identify_no_deflate_level(self): 373 | '''Raises an exception about the lack of deflate_level support''' 374 | with mock.patch('nsq.connection.DeflateSocket', None): 375 | self.assertRaises(exceptions.UnsupportedException, 376 | connection.Connection, 'host', 0, deflate_level=True) 377 | 378 | def test_identify_no_snappy_and_deflate(self): 379 | '''We should yell early about incompatible snappy and deflate options''' 380 | self.assertRaises(exceptions.UnsupportedException, 381 | connection.Connection, 'host', 0, snappy=True, deflate=True) 382 | 383 | def test_identify_saves_identify_response(self): 384 | '''Saves the identify response from the server''' 385 | expected = {'foo': 'bar'} 386 | conn = self.connect(expected) 387 | self.assertEqual(conn._identify_response, expected) 388 | 389 | def test_identify_saves_max_rdy_count(self): 390 | '''Saves the max ready count if it's provided''' 391 | conn = self.connect({'max_rdy_count': 100}) 392 | self.assertEqual(conn.max_rdy_count, 100) 393 | 394 | def test_ready_to_reconnect(self): 395 | '''Alias for the reconnection attempt's ready method''' 396 | with mock.patch.object( 397 | self.connection, '_reconnnection_counter') as ctr: 398 | self.connection.ready_to_reconnect() 399 | ctr.ready.assert_called_with() 400 | 401 | def test_reconnect_living_socket(self): 402 | '''Don't reconnect a living connection''' 403 | before = self.connection._socket 404 | self.connection.connect() 405 | self.assertEqual(self.connection._socket, before) 406 | 407 | def test_connect_socket_error_return_value(self): 408 | '''Socket errors has connect return False''' 409 | self.connection.close() 410 | with mock.patch('nsq.connection.socket') as mock_socket: 411 | mock_socket.socket = mock.Mock(side_effect=socket.error) 412 | self.assertFalse(self.connection.connect()) 413 | 414 | def test_connect_socket_error_reset(self): 415 | '''Invokes reset if the socket raises an error''' 416 | self.connection.close() 417 | with mock.patch('nsq.connection.socket') as mock_socket: 418 | with mock.patch.object(self.connection, '_reset') as mock_reset: 419 | mock_socket.socket = mock.Mock(side_effect=socket.error) 420 | self.connection.connect() 421 | mock_reset.assert_called_with() 422 | 423 | def test_connect_timeout(self): 424 | '''Times out when connection instantiation is too slow''' 425 | socket = self.connection._socket 426 | self.connection.close() 427 | with 
mock.patch.object(self.connection, '_read', return_value=[]): 428 | with mock.patch.object(self.connection, '_timeout', 0.05): 429 | with mock.patch( 430 | 'nsq.connection.socket.socket', return_value=socket): 431 | self.assertFalse(self.connection.connect()) 432 | 433 | def test_connect_resets_state(self): 434 | '''Upon connection, makes a call to reset its state''' 435 | socket = self.connection._socket 436 | self.connection.close() 437 | with mock.patch.object(self.connection, '_read', return_value=[]): 438 | with mock.patch.object(self.connection, '_reset') as mock_reset: 439 | with mock.patch.object(self.connection, '_timeout', 0.05): 440 | with mock.patch( 441 | 'nsq.connection.socket.socket', return_value=socket): 442 | self.connection.connect() 443 | mock_reset.assert_called_with() 444 | 445 | def test_close_resets_state(self): 446 | '''On closing a connection, reset its state''' 447 | with mock.patch.object(self.connection, '_reset') as mock_reset: 448 | self.connection.close() 449 | mock_reset.assert_called_with() 450 | 451 | def test_reset_socket(self): 452 | '''Resets socket''' 453 | self.connection._socket = True 454 | self.connection._reset() 455 | self.assertEqual(self.connection._socket, None) 456 | 457 | def test_reset_pending(self): 458 | '''Resets pending''' 459 | self.connection._pending = True 460 | self.connection._reset() 461 | self.assertEqual(self.connection._pending, deque()) 462 | 463 | def test_reset_out_buffer(self): 464 | '''Resets the outbound buffer''' 465 | self.connection._out_buffer = True 466 | self.connection._reset() 467 | self.assertEqual(self.connection._out_buffer, b'') 468 | 469 | def test_reset_buffer(self): 470 | '''Resets buffer''' 471 | self.connection._buffer = True 472 | self.connection._reset() 473 | self.assertEqual(self.connection._buffer, b'') 474 | 475 | def test_reset_identify_response(self): 476 | '''Resets identify_response''' 477 | self.connection._identify_response = True 478 | self.connection._reset() 479 | self.assertEqual(self.connection._identify_response, {}) 480 | 481 | def test_reset_last_ready_sent(self): 482 | '''Resets last_ready_sent''' 483 | self.connection.last_ready_sent = True 484 | self.connection._reset() 485 | self.assertEqual(self.connection.last_ready_sent, 0) 486 | 487 | def test_reset_ready(self): 488 | '''Resets ready''' 489 | self.connection.ready = True 490 | self.connection._reset() 491 | self.assertEqual(self.connection.ready, 0) 492 | 493 | def test_ok_response(self): 494 | '''Sets our _identify_response to {} if 'OK' is provided''' 495 | res = response.Response( 496 | self.connection, response.Response.FRAME_TYPE, 'OK') 497 | self.connection.identified(res) 498 | self.assertEqual(self.connection._identify_response, {}) 499 | 500 | def test_tls_unsupported(self): 501 | '''Raises an exception if the server does not support TLS''' 502 | res = response.Response(self.connection, 503 | response.Response.FRAME_TYPE, json.dumps({'tls_v1': False})) 504 | options = {'tls_v1': True} 505 | with mock.patch.object(self.connection, '_identify_options', options): 506 | self.assertRaises(exceptions.UnsupportedException, 507 | self.connection.identified, res) 508 | 509 | def test_auth_required_not_provided(self): 510 | '''Raises an exception if auth is required but not provided''' 511 | res = response.Response(self.connection, response.Response.FRAME_TYPE, 512 | json.dumps({'auth_required': True})) 513 | self.assertRaises(exceptions.UnsupportedException, 514 | self.connection.identified, res) 515 | 516 | def 
test_auth_required_provided(self): 517 | '''Sends the auth message if required and provided''' 518 | res = response.Response(self.connection, response.Response.FRAME_TYPE, 519 | json.dumps({'auth_required': True})) 520 | with mock.patch.object(self.connection, 'auth') as mock_auth: 521 | with mock.patch.object(self.connection, '_auth_secret', 'hello'): 522 | self.connection.identified(res) 523 | mock_auth.assert_called_with('hello') 524 | 525 | def test_auth_provided_not_required(self): 526 | '''Logs a warning if you provide auth when none is required''' 527 | res = response.Response(self.connection, response.Response.FRAME_TYPE, 528 | json.dumps({'auth_required': False})) 529 | with mock.patch('nsq.connection.logger') as mock_logger: 530 | with mock.patch.object(self.connection, '_auth_secret', 'hello'): 531 | self.connection.identified(res) 532 | mock_logger.warning.assert_called_with( 533 | 'Authentication secret provided but not required') 534 | 535 | 536 | class TestTLSConnectionIntegration(HttpClientIntegrationTest): 537 | '''We can establish a connection with TLS''' 538 | def setUp(self): 539 | HttpClientIntegrationTest.setUp(self) 540 | self.connection = connection.Connection('localhost', 14150, tls_v1=True) 541 | self.connection.setblocking(0) 542 | 543 | def test_alive(self): 544 | '''The connection is alive''' 545 | self.assertTrue(self.connection.alive()) 546 | 547 | def test_basic(self): 548 | '''Can send and receive things''' 549 | self.connection.pub(b'foo', b'bar') 550 | self.connection.flush() 551 | responses = [] 552 | while not responses: 553 | responses = self.connection.read() 554 | self.assertEqual(len(responses), 1) 555 | self.assertEqual(responses[0].data, b'OK') 556 | -------------------------------------------------------------------------------- /test/test_reader.py: -------------------------------------------------------------------------------- 1 | import mock 2 | 3 | import uuid 4 | 5 | from nsq import reader 6 | from nsq import response 7 | 8 | from common import HttpClientIntegrationTest 9 | 10 | 11 | class TestReader(HttpClientIntegrationTest): 12 | '''Tests for our reader class''' 13 | nsqd_ports = (14150, 14152) 14 | 15 | def setUp(self): 16 | '''Return a connection''' 17 | HttpClientIntegrationTest.setUp(self) 18 | nsqd_tcp_addresses = ['localhost:%s' % port for port in self.nsqd_ports] 19 | self.client = reader.Reader( 20 | self.topic, self.channel, nsqd_tcp_addresses=nsqd_tcp_addresses) 21 | 22 | def test_it_subscribes(self): 23 | '''It subscribes for newly-established connections''' 24 | connection = mock.Mock() 25 | self.client.added(connection) 26 | connection.sub.assert_called_with(self.topic, self.channel) 27 | 28 | def test_new_connections_rdy(self): 29 | '''Calls rdy(1) when connections are added''' 30 | connection = mock.Mock() 31 | self.client.added(connection) 32 | connection.rdy.assert_called_with(1) 33 | 34 | def test_reconnected_rdy(self): 35 | '''Calls rdy(1) when connections are reestablished''' 36 | connection = mock.Mock() 37 | self.client.reconnected(connection) 38 | connection.rdy.assert_called_with(1) 39 | 40 | def test_added_dead(self): 41 | '''Does not call reconnected when adding dead connections''' 42 | conn = mock.Mock() 43 | conn.alive.return_value = False 44 | with mock.patch.object(self.client, 'reconnected') as mock_reconnected: 45 | self.client.added(conn) 46 | self.assertFalse(mock_reconnected.called) 47 | 48 | def test_it_checks_max_in_flight(self): 49 | '''Raises an exception if more connections than in-flight limit''' 
50 | with mock.patch.object(self.client, '_max_in_flight', 0): 51 | self.assertRaises(NotImplementedError, self.client.distribute_ready) 52 | 53 | def test_it_distributes_ready(self): 54 | '''It distributes RDY with util.distribute''' 55 | with mock.patch('nsq.reader.distribute') as mock_distribute: 56 | counts = range(10) 57 | connections = [mock.Mock(max_rdy_count=100) for _ in counts] 58 | mock_distribute.return_value = zip(counts, connections) 59 | self.client.distribute_ready() 60 | for count, connection in zip(counts, connections): 61 | connection.rdy.assert_called_with(count) 62 | 63 | def test_it_ignores_dead_connections(self): 64 | '''It does not distribute RDY state to dead connections''' 65 | dead = mock.Mock(max_rdy_count=100) 66 | dead.alive.return_value = False 67 | alive = mock.Mock(max_rdy_count=100) 68 | alive.alive.return_value = True 69 | with mock.patch.object( 70 | self.client, 'connections', return_value=[alive, dead]): 71 | self.client.distribute_ready() 72 | self.assertTrue(alive.rdy.called) 73 | self.assertFalse(dead.rdy.called) 74 | 75 | def test_zero_ready(self): 76 | '''When a connection has ready=0, distribute_ready is invoked''' 77 | connection = self.client.connections()[0] 78 | with mock.patch.object(connection, 'ready', 0): 79 | self.assertTrue(self.client.needs_distribute_ready()) 80 | 81 | def test_not_ready(self): 82 | '''When no connection has ready=0, distribute_ready is not invoked''' 83 | connection = self.client.connections()[0] 84 | with mock.patch.object(connection, 'ready', 10): 85 | self.assertFalse(self.client.needs_distribute_ready()) 86 | 87 | def test_negative_ready(self): 88 | '''If clients have negative RDY values, distribute_ready is invoked''' 89 | connection = self.client.connections()[0] 90 | with mock.patch.object(connection, 'ready', -1): 91 | self.assertTrue(self.client.needs_distribute_ready()) 92 | 93 | def test_low_ready(self): 94 | '''If clients have negative RDY values, distribute_ready is invoked''' 95 | connection = self.client.connections()[0] 96 | with mock.patch.object(connection, 'ready', 2): 97 | with mock.patch.object(connection, 'last_ready_sent', 10): 98 | self.assertTrue(self.client.needs_distribute_ready()) 99 | 100 | def test_none_alive(self): 101 | '''We don't need to redistribute RDY if there are none alive''' 102 | with mock.patch.object(self.client, 'connections', return_value=[]): 103 | self.assertFalse(self.client.needs_distribute_ready()) 104 | 105 | def test_read_distribute_ready(self): 106 | '''Read checks if we need to distribute ready''' 107 | with mock.patch('nsq.reader.Client'): 108 | with mock.patch.object( 109 | self.client, 'needs_distribute_ready', return_value=True): 110 | with mock.patch.object( 111 | self.client, 'distribute_ready') as mock_ready: 112 | self.client.read() 113 | mock_ready.assert_called_with() 114 | 115 | def test_read_not_distribute_ready(self): 116 | '''Does not redistribute ready if not needed''' 117 | with mock.patch('nsq.reader.Client'): 118 | with mock.patch.object( 119 | self.client, 'needs_distribute_ready', return_value=False): 120 | with mock.patch.object( 121 | self.client, 'distribute_ready') as mock_ready: 122 | self.client.read() 123 | self.assertFalse(mock_ready.called) 124 | 125 | def test_iter(self): 126 | '''The client can be used as an iterator''' 127 | iterator = iter(self.client) 128 | message_id = uuid.uuid4().hex[0:16].encode() 129 | packed = response.Message.pack(0, 0, message_id, b'hello') 130 | messages = [response.Message(None, None, packed) for _ in 
range(10)] 131 | with mock.patch.object(self.client, 'read', return_value=messages): 132 | found = [next(iterator) for _ in range(10)] 133 | self.assertEqual(messages, found) 134 | 135 | def test_iter_repeated_read(self): 136 | '''Repeatedly calls read in iterator mode''' 137 | iterator = iter(self.client) 138 | message_id = uuid.uuid4().hex[0:16].encode() 139 | packed = response.Message.pack(0, 0, message_id, b'hello') 140 | messages = [response.Message(None, None, packed) for _ in range(10)] 141 | for message in messages: 142 | with mock.patch.object(self.client, 'read', return_value=[message]): 143 | self.assertEqual(next(iterator), message) 144 | 145 | def test_skip_non_messages(self): 146 | '''Skips all non-messages''' 147 | iterator = iter(self.client) 148 | message_id = uuid.uuid4().hex[0:16].encode() 149 | packed = response.Message.pack(0, 0, message_id, b'hello') 150 | messages = [response.Message(None, None, packed) for _ in range(10)] 151 | packed = response.Response.pack(b'hello') 152 | responses = [ 153 | response.Response(None, None, packed) for _ in range(10)] + messages 154 | with mock.patch.object(self.client, 'read', return_value=responses): 155 | found = [next(iterator) for _ in range(10)] 156 | self.assertEqual(messages, found) 157 | 158 | def test_honors_max_rdy_count(self): 159 | '''Honors the max RDY count provided in an identify response''' 160 | for conn in self.client.connections(): 161 | conn.max_rdy_count = 10 162 | self.client.distribute_ready() 163 | self.assertEqual(self.client.connections()[0].ready, 10) 164 | 165 | def test_read(self): 166 | '''Can receive a message in a basic way''' 167 | self.nsqd.pub(self.topic, b'hello') 168 | message = next(iter(self.client)) 169 | self.assertEqual(message.body, b'hello') 170 | 171 | def test_close_redistribute(self): 172 | '''Redistributes rdy count when a connection is closed''' 173 | with mock.patch.object(self.client, 'distribute_ready') as mock_ready: 174 | self.client.close_connection(self.client.connections()[0]) 175 | mock_ready.assert_called_with() 176 | -------------------------------------------------------------------------------- /test/test_response.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import mock 4 | import uuid 5 | import six 6 | import socket 7 | import struct 8 | from nsq import response 9 | from nsq import constants 10 | from nsq import exceptions 11 | 12 | 13 | class TestResponse(unittest.TestCase): 14 | '''Test our response class''' 15 | def test_from_raw_response(self): 16 | '''Make sure we can construct a raw response''' 17 | raw = struct.pack('>l5s', constants.FRAME_TYPE_RESPONSE, b'hello') 18 | res = response.Response.from_raw(None, raw) 19 | self.assertEqual(res.__class__, response.Response) 20 | self.assertEqual(res.data, b'hello') 21 | 22 | def test_from_raw_unknown_frame(self): 23 | '''Raises an exception for unknown frame types''' 24 | raw = struct.pack('>l5s', 9042, b'hello') 25 | self.assertRaises(TypeError, response.Response.from_raw, None, raw) 26 | 27 | def test_str(self): 28 | '''Has a reasonable string value''' 29 | raw = struct.pack('>l5s', constants.FRAME_TYPE_RESPONSE, b'hello') 30 | res = response.Response.from_raw(None, raw) 31 | if six.PY2: 32 | self.assertEqual(str(res), 'Response - hello') 33 | else: 34 | self.assertEqual(str(res), "Response - b'hello'") 35 | 36 | def test_pack(self): 37 | '''Can pack itself up''' 38 | packed = response.Response.pack(b'hello')[4:] 39 | unpacked = 
response.Response.from_raw(None, packed) 40 | self.assertIsInstance(unpacked, response.Response) 41 | self.assertEqual(unpacked.data, b'hello') 42 | 43 | 44 | class TestError(unittest.TestCase): 45 | '''Test our error response class''' 46 | def test_from_raw_error(self): 47 | '''Can identify an error type''' 48 | raw = struct.pack('>l5s', constants.FRAME_TYPE_ERROR, b'hello') 49 | res = response.Response.from_raw(None, raw) 50 | self.assertEqual(res.__class__, response.Error) 51 | self.assertEqual(res.data, b'hello') 52 | 53 | def test_str(self): 54 | '''Has a reasonable string value''' 55 | raw = struct.pack('>l5s', constants.FRAME_TYPE_ERROR, b'hello') 56 | res = response.Response.from_raw(None, raw) 57 | if six.PY2: 58 | self.assertEqual(str(res), 'Error - hello') 59 | else: 60 | self.assertEqual(str(res), "Error - b'hello'") 61 | 62 | def test_find(self): 63 | '''Can correctly identify the appropriate exception''' 64 | expected = { 65 | b'E_INVALID': exceptions.InvalidException, 66 | b'E_BAD_BODY': exceptions.BadBodyException, 67 | b'E_BAD_TOPIC': exceptions.BadTopicException, 68 | b'E_BAD_CHANNEL': exceptions.BadChannelException, 69 | b'E_BAD_MESSAGE': exceptions.BadMessageException, 70 | b'E_PUB_FAILED': exceptions.PubFailedException, 71 | b'E_MPUB_FAILED': exceptions.MpubFailedException, 72 | b'E_FIN_FAILED': exceptions.FinFailedException, 73 | b'E_REQ_FAILED': exceptions.ReqFailedException, 74 | b'E_TOUCH_FAILED': exceptions.TouchFailedException 75 | } 76 | for key, klass in expected.items(): 77 | self.assertEqual(response.Error.find(key), klass) 78 | 79 | def test_find_missing(self): 80 | '''Raises an exception if it can't find the appropriate exception''' 81 | self.assertRaises(TypeError, response.Error.find, 'woruowijklf') 82 | 83 | def test_exception(self): 84 | '''Can correctly raise the appropriate exception''' 85 | raw = struct.pack('>l13s', constants.FRAME_TYPE_ERROR, b'E_INVALID foo') 86 | res = response.Response.from_raw(None, raw) 87 | exc = res.exception() 88 | self.assertIsInstance(exc, exceptions.InvalidException) 89 | if six.PY2: 90 | self.assertEqual(str(exc), 'foo') 91 | else: 92 | self.assertEqual(str(exc), "b'foo'") 93 | 94 | def test_pack(self): 95 | '''Can pack itself up''' 96 | packed = response.Error.pack(b'hello')[4:] 97 | unpacked = response.Response.from_raw(None, packed) 98 | self.assertIsInstance(unpacked, response.Error) 99 | self.assertEqual(unpacked.data, b'hello') 100 | 101 | 102 | class TestMessage(unittest.TestCase): 103 | '''Test our message case''' 104 | def setUp(self): 105 | self.id = uuid.uuid4().hex[:16].encode() 106 | self.timestamp = 0 107 | self.attempt = 1 108 | self.body = b'hello' 109 | self.packed = struct.pack('>qH16s5s', 0, 1, self.id, self.body) 110 | self.response = response.Response.from_raw(mock.Mock(), 111 | struct.pack('>l31s', constants.FRAME_TYPE_MESSAGE, self.packed)) 112 | 113 | def test_str(self): 114 | '''Has a reasonable string value''' 115 | if six.PY2: 116 | self.assertEqual(str(self.response), 'Message - 0 1 {} hello'.format(self.id)) 117 | else: 118 | self.assertEqual(str(self.response), "Message - 0 1 {!r} b'hello'".format(self.id)) 119 | 120 | def test_from_raw_message(self): 121 | '''Can identify a message type''' 122 | self.assertEqual(self.response.__class__, response.Message) 123 | 124 | def test_timestamp(self): 125 | '''Can identify the timestamp''' 126 | self.assertEqual(self.response.timestamp, self.timestamp) 127 | 128 | def test_attempts(self): 129 | '''Can identify the number of attempts''' 130 | 
self.assertEqual(self.response.attempts, self.attempt) 131 | 132 | def test_id(self): 133 | '''Can identify the ID of the message''' 134 | self.assertEqual(self.response.id, self.id) 135 | 136 | def test_message(self): 137 | '''Can properly detect the message''' 138 | self.assertEqual(self.response.body, self.body) 139 | 140 | def test_fin(self): 141 | '''Invokes the fin method''' 142 | self.response.fin() 143 | self.response.connection.fin.assert_called_with(self.id) 144 | 145 | def test_req(self): 146 | '''Invokes the req method''' 147 | self.response.req(1) 148 | self.response.connection.req.assert_called_with(self.id, 1) 149 | 150 | def test_touch(self): 151 | '''Invokes the touch method''' 152 | self.response.touch() 153 | self.response.connection.touch.assert_called_with(self.id) 154 | 155 | def test_pack(self): 156 | '''Can pack itself up''' 157 | packed = response.Message.pack( 158 | self.timestamp, self.attempt, self.id, self.body)[4:] 159 | unpacked = response.Response.from_raw(None, packed) 160 | self.assertIsInstance(unpacked, response.Message) 161 | self.assertEqual(unpacked.timestamp, self.timestamp) 162 | self.assertEqual(unpacked.attempts, self.attempt) 163 | self.assertEqual(unpacked.id, self.id) 164 | self.assertEqual(unpacked.body, self.body) 165 | 166 | def test_handle_yields(self): 167 | '''The handle method should yield the message''' 168 | with self.response.handle() as msg: 169 | self.assertEqual(msg, self.response) 170 | 171 | def test_handle_exception(self): 172 | '''Handles exceptions by requeueing''' 173 | try: 174 | with self.response.handle(): 175 | raise ValueError('foo') 176 | except ValueError: 177 | self.response.connection.req.assert_called_with( 178 | self.response.id, self.response.delay()) 179 | else: 180 | self.assertTrue(False, 'No exception was raised') 181 | 182 | def test_handle_success(self): 183 | '''Handles success by calling fin''' 184 | with self.response.handle(): 185 | pass 186 | self.response.connection.fin.assert_called_with(self.response.id) 187 | 188 | def test_handle_already_requeued(self): 189 | '''If we've already requeued a message, doesn't requeue it again''' 190 | try: 191 | with self.response.handle(): 192 | self.response.req(10) 193 | raise ValueError('foo') 194 | except ValueError: 195 | self.assertEqual(self.response.connection.req.call_count, 1) 196 | 197 | def test_handle_already_finish(self): 198 | '''If we've already finished a messages, doesn't finish it again''' 199 | with self.response.handle(): 200 | self.response.fin() 201 | self.assertEqual(self.response.connection.fin.call_count, 1) 202 | 203 | def test_handle_exception_socket_error(self): 204 | '''Handles socket errors when catching exceptions''' 205 | try: 206 | self.response.connection.req = mock.Mock(side_effect=socket.error) 207 | with self.response.handle(): 208 | raise ValueError('foo') 209 | except ValueError: 210 | # The connection should have been closed 211 | self.response.connection.close.assert_called_with() 212 | 213 | def test_handle_success_socket_error(self): 214 | '''Handles socket errors when trying to complete the message''' 215 | try: 216 | self.response.connection.fin = mock.Mock(side_effect=socket.error) 217 | with self.response.handle(): 218 | pass 219 | except ValueError: 220 | # The connection should have been closed 221 | self.response.connection.close.assert_called_with() 222 | -------------------------------------------------------------------------------- /test/test_sockets/test_base.py: 
-------------------------------------------------------------------------------- 1 | import mock 2 | import unittest 3 | 4 | from nsq.sockets.base import SocketWrapper 5 | 6 | 7 | class TestSocketWrapper(unittest.TestCase): 8 | '''Test the SocketWrapper class''' 9 | def setUp(self): 10 | self.socket = mock.Mock() 11 | self.wrapped = SocketWrapper.wrap_socket(self.socket) 12 | 13 | def test_wrap_socket(self): 14 | '''Passes through objects to the constructor''' 15 | with mock.patch.object(SocketWrapper, '__init__') as mock_init: 16 | mock_init.return_value = None 17 | SocketWrapper.wrap_socket(5, hello='foo') 18 | mock_init.assert_called_with(5, hello='foo') 19 | 20 | def test_method_pass_through(self): 21 | '''Passes through most methods directly to the underlying socket''' 22 | self.assertEqual(self.wrapped.accept, self.socket.accept) 23 | 24 | def test_send(self): 25 | '''SocketWrapper.send raises NotImplementedError''' 26 | self.assertRaises(NotImplementedError, self.wrapped.send, 'foo') 27 | 28 | def test_sendall(self): 29 | '''Repeatedly calls send until everything has been sent''' 30 | with mock.patch.object(self.wrapped, 'send') as mock_send: 31 | # Only sends one byte at a time 32 | mock_send.return_value = 1 33 | self.wrapped.sendall('hello') 34 | self.assertEqual(mock_send.call_count, 5) 35 | 36 | def test_recv(self): 37 | '''SocketWrapper.recv raises NotImplementedError''' 38 | self.assertRaises(NotImplementedError, self.wrapped.recv, 5) 39 | 40 | def test_recv_into(self): 41 | '''SocketWrapper.recv_into raises NotImplementedError''' 42 | self.assertRaises(NotImplementedError, self.wrapped.recv_into, 'foo', 5) 43 | 44 | def test_inheritance_overrides(self): 45 | '''Classes that inherit can override things like accept''' 46 | class Foo(SocketWrapper): 47 | def close(self): 48 | pass 49 | 50 | wrapped = Foo.wrap_socket(self.socket) 51 | self.assertNotEqual(wrapped.close, self.socket.close) 52 | -------------------------------------------------------------------------------- /test/test_sockets/test_tls.py: -------------------------------------------------------------------------------- 1 | import mock 2 | import unittest 3 | 4 | from contextlib import contextmanager 5 | import ssl 6 | 7 | from nsq.sockets import tls 8 | 9 | 10 | class TestTLSSocket(unittest.TestCase): 11 | '''Test the TLSSocket class''' 12 | @contextmanager 13 | def wrapped_ssl_socket(self): 14 | sock = mock.Mock() 15 | with mock.patch.object(tls.ssl, 'wrap_socket', return_value=sock): 16 | yield sock 17 | 18 | def test_needs_read(self): 19 | '''If the handshake needs reading, calls do_handshake again''' 20 | with self.wrapped_ssl_socket() as sock: 21 | effects = [ssl.SSLError(ssl.SSL_ERROR_WANT_READ), None] 22 | with mock.patch.object(sock, 'do_handshake', side_effect=effects): 23 | tls.TLSSocket.wrap_socket(sock) 24 | 25 | def test_needs_write(self): 26 | '''If the handshake needs writing, calls do_handshake again''' 27 | with self.wrapped_ssl_socket() as sock: 28 | effects = [ssl.SSLError(ssl.SSL_ERROR_WANT_WRITE), None] 29 | with mock.patch.object(sock, 'do_handshake', side_effect=effects): 30 | tls.TLSSocket.wrap_socket(sock) 31 | 32 | def test_raises_exceptions(self): 33 | '''Bubbles up non-EAGAIN-like exceptions''' 34 | with self.wrapped_ssl_socket() as sock: 35 | effects = [ssl.SSLError(ssl.SSL_ERROR_SSL), None] 36 | with mock.patch.object(sock, 'do_handshake', side_effect=effects): 37 | self.assertRaises(ssl.SSLError, tls.TLSSocket.wrap_socket, sock) 38 |
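The three `TestTLSSocket` cases above all exercise the same behavior: `do_handshake()` is retried while the non-blocking SSL socket reports `SSL_ERROR_WANT_READ` or `SSL_ERROR_WANT_WRITE`, and any other `SSLError` bubbles up. A minimal sketch of that retry loop (an illustration of the behavior under test, not the module's actual implementation) might look like:

```python
import select
import ssl


def complete_handshake(sock):
    '''Finish the TLS handshake on a non-blocking SSL socket.'''
    while True:
        try:
            sock.do_handshake()
            return sock
        except ssl.SSLError as exc:
            if exc.args[0] == ssl.SSL_ERROR_WANT_READ:
                # Wait until the socket is readable, then retry the handshake
                select.select([sock], [], [])
            elif exc.args[0] == ssl.SSL_ERROR_WANT_WRITE:
                # Wait until the socket is writable, then retry the handshake
                select.select([], [sock], [])
            else:
                # Anything other than the two 'try again' conditions is fatal
                raise
```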
--------------------------------------------------------------------------------
/test/test_stats.py:
--------------------------------------------------------------------------------
1 | '''Test our stats utility'''
2 | 
3 | import os
4 | 
5 | import simplejson as json
6 | import six
7 | 
8 | from nsq.stats import Nsqlookupd
9 | from nsq.http import nsqd
10 | from common import IntegrationTest
11 | 
12 | 
13 | class TestStats(IntegrationTest):
14 |     '''Test our stats utility.'''
15 | 
16 |     nsqd_ports = (14150, 14152)
17 | 
18 |     def assertFixture(self, path, actual):
19 |         '''Assert actual is equivalent to the JSON fixture provided.'''
20 |         if os.environ.get('RECORD', 'false').lower() == 'true':
21 |             with open(path, 'w') as fout:
22 |                 json.dump(actual,
23 |                     fout, sort_keys=True, indent=4, separators=(',', ': '))
24 |                 # Add a trailing newline
25 |                 fout.write('\n')
26 |         else:
27 |             with open(path) as fin:
28 |                 self.assertEqual(json.load(fin), actual)
29 | 
30 |     def test_stats(self):
31 |         '''Can effectively grab all the stats.'''
32 |         # Create some topics, channels, and messages on the first nsqd instance
33 |         client = nsqd.Client('http://localhost:14151')
34 |         topics = [
35 |             'topic-on-both-instances',
36 |             'topic-with-channels',
37 |             'topic-without-channels'
38 |         ]
39 |         for topic in topics:
40 |             client.create_topic(topic)
41 |             client.mpub(topic, [six.text_type(i).encode() for i in range(len(topic))])
42 |         client.create_channel('topic-with-channels', 'channel')
43 | 
44 |         # Create a topic and messages on the second nsqd instance
45 |         client = nsqd.Client('http://localhost:14153')
46 |         client.create_topic('topic-on-both-instances')
47 |         client.mpub(topic, [six.text_type(i).encode() for i in range(10)])
48 | 
49 |         # Check the stats
50 |         self.assertFixture(
51 |             'test/fixtures/test_stats/TestStats/stats',
52 |             [list(x) for x in Nsqlookupd('http://localhost:14161').stats])
53 | 
--------------------------------------------------------------------------------
/test/test_util.py:
--------------------------------------------------------------------------------
1 | '''Test all of our utility functions'''
2 | 
3 | import unittest
4 | 
5 | import struct
6 | from nsq import util
7 | 
8 | 
9 | class TestPack(unittest.TestCase):
10 |     '''Test our packing utility'''
11 |     def test_string(self):
12 |         '''Give it a low-ball test'''
13 |         message = b'hello'
14 |         self.assertEqual(util.pack(message), struct.pack('>l5s', 5, message))
15 | 
16 |     def test_iterable(self):
17 |         '''Make sure it handles iterables'''
18 |         messages = [b'hello'] * 10
19 |         packed = struct.pack('>l5s', 5, b'hello')
20 |         expected = struct.pack('>ll90s', 94, 10, packed * 10)
21 |         self.assertEqual(util.pack(messages), expected)
22 | 
23 |     def test_iterable_of_iterables(self):
24 |         '''Should complain in the event of nested iterables'''
25 |         messages = [[b'hello'] * 5] * 10
26 |         self.assertRaises(TypeError, util.pack, messages)
27 | 
28 | 
29 | class TestHexify(unittest.TestCase):
30 |     '''Test our hexification utility'''
31 |     def setUp(self):
32 |         self.message = '\x00hello\n\tFOO2'
33 | 
34 |     def test_identical(self):
35 |         '''Does not transform the value of the text'''
36 |         import ast
37 |         hexified = util.hexify(self.message)
38 |         self.assertEqual(self.message, ast.literal_eval("'%s'" % hexified))
39 | 
40 |     def test_meaningful(self):
41 |         '''The output it gives is meaningful'''
42 |         hexified = util.hexify(self.message)
43 |         self.assertEqual(hexified, '\\x00hello\\x0a\\x09FOO2')
44 | 
45 | 
46 | class TestDistribute(unittest.TestCase):
47 |     '''Test the distribute utility'''
48 |     def counts(self, total, objects):
49 |         '''Return a list of the counts returned by distribute'''
50 |         return tuple(zip(*util.distribute(total, objects)))[0]
51 | 
52 |     def count(self, total, objects):
53 |         '''Return the sum of the counts'''
54 |         return sum(self.counts(total, objects))
55 | 
56 |     def test_sum_evenly_divisible(self):
57 |         '''We get the expected total when total is evenly divisible'''
58 |         self.assertEqual(self.count(10, range(5)), 10)
59 | 
60 |     def test_sum_not_evenly_divisible(self):
61 |         '''We get the expected total when total is not evenly divisible'''
62 |         self.assertEqual(self.count(10, range(3)), 10)
63 | 
64 |     def test_min_max(self):
65 |         '''The minimum and maximum should be within 1'''
66 |         for num in range(1, 50):
67 |             objects = range(num)
68 |             for total in range(1, 50):
69 |                 counts = self.counts(total, objects)
70 |                 self.assertLessEqual(max(counts) - min(counts), 1)
71 | 
72 |     def test_distribute_types(self):
73 |         '''Distribute should always return integers'''
74 |         parts = tuple(util.distribute(1000, (1, 2, 3)))
75 |         self.assertEqual(parts, ((333, 1), (333, 2), (334, 3)))
76 | 
--------------------------------------------------------------------------------