├── .gitignore ├── .pyup.yml ├── .travis.yml ├── CHANGES.rst ├── LICENSE ├── MANIFEST.in ├── README.rst ├── aioelasticsearch ├── __init__.py ├── connection.py ├── exceptions.py ├── helpers.py ├── pool.py └── transport.py ├── pytest.ini ├── requirements.txt ├── setup.cfg ├── setup.py ├── tests ├── conftest.py ├── test_aioelasticsearch.py ├── test_connection.py ├── test_pool.py ├── test_scan.py └── test_transport.py └── tox.ini /.gitignore: -------------------------------------------------------------------------------- 1 | # python specific 2 | env* 3 | .cache/ 4 | *.pyc 5 | *.so 6 | *.pyd 7 | build/* 8 | dist/* 9 | MANIFEST 10 | __pycache__/ 11 | *.egg-info/ 12 | .coverage 13 | .python-version 14 | htmlcov 15 | 16 | # generic files to ignore 17 | *~ 18 | *.lock 19 | *.DS_Store 20 | *.swp 21 | *.out 22 | 23 | .tox/ 24 | deps/ 25 | docs/_build/ 26 | .idea/ 27 | .pytest_cache -------------------------------------------------------------------------------- /.pyup.yml: -------------------------------------------------------------------------------- 1 | # autogenerated pyup.io config file 2 | # see https://pyup.io/docs/configuration/ for all available options 3 | 4 | schedule: every week 5 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | conditions: v1 2 | 3 | dist: xenial 4 | language: python 5 | python: 6 | - "3.5" 7 | - "3.6" 8 | - "3.7" 9 | install: 10 | - pip install -U tox 11 | script: 12 | - tox -- --es_tag=${ES_TAG} 13 | after_success: 14 | - tox -e coverage 15 | cache: pip 16 | env: 17 | TOXENV: py 18 | ES_TAG: 7.3.1 19 | 20 | jobs: 21 | fast_finish: true 22 | include: 23 | - stage: Deploy to PyPI 24 | if: tag IS present 25 | python: "3.7" 26 | install: skip 27 | script: skip 28 | after_success: [] 29 | deploy: 30 | provider: pypi 31 | user: aio-libs-bot 32 | password: 33 | secure: "cKvNQXCPwKAZHB02Me462NOyd4za1r5TJYUgCvrAkJTSf6qHdlLnOvH1bZtDyfs2GEpvYlx8nNZJHkbVHSiz3UEkaiuLj6xfaJdVUpJ9qK5w4YQvrMeb4cD6amIvMttrgMoppGq4GSjdcggBIwNcmgk2g5PrpozheK2GWkA1B8rlsnud4JRPhenzchH5yvv1VBXokVFvv6SyS+EIL8DEKRCZQ/Bug3N5QVPXWQn52JINY0c6v0UgLJzc0F82VCJKYdKkVBXLFGakwHlufjcn5TB9myia5hBeObdyXlJGk+NiGrZGchluAfE+QWfdtTn9AAyUHdNV7jjQWQtfv5gj025Zo2IXNRvHkeIG6CiHyZ3YiXorA20kZgaHd95PG+WkJwbRUBJdCgrZ1r+saxrA8D7tT3aljEwc10jz/YcQdazQmFXqV2NP0PDAiqNS0zgIPPZQn2lo+KUBpiUA+h/Wp/Jv3XlmhyyvAeCfGVMDdBRhIq5SDHdaqVJ+UbGR79qKcA9/bKYimyZ83MRUBUsZPMh/VElLKImN3LqjCjsoEeFCeh5pZvEmrydNRv4aeE124nnYX0DY6Md6+NN6wjnBagg+Ws4cR0UzsQrO35L/JECEEklHQv+nGn4sTmMoHwf9IR0XeNwol3dnY/3HVuipgNcFN84Ir7F4gcUtshJu0Vo=" 34 | distributions: sdist bdist_wheel 35 | on: 36 | tags: true 37 | all_branches: true 38 | -------------------------------------------------------------------------------- /CHANGES.rst: -------------------------------------------------------------------------------- 1 | Changes 2 | ======= 3 | 4 | 0.7.0 (2019-11-07) 5 | ------------------ 6 | 7 | - Support ``session_factory`` in ``AIOHttpConnection`` (#211) 8 | 9 | 0.6.0 (2019-11-07) 10 | ------------------ 11 | 12 | - Support elasticsearch 7.x (#194) 13 | 14 | 0.5.1 (2018-03-12) 15 | ------------------ 16 | 17 | - Don't use deprecated ``verify_ssl`` parameter 18 | 19 | 0.5.0 (2018-02-14) 20 | ------------------ 21 | 22 | - Fix compatibility with aiohttp 3.0 23 | 24 | 25 | 0.4.0 (2017-12-05) 26 | ------------------ 27 | 28 | - First public release 29 | -------------------------------------------------------------------------------- /LICENSE:
-------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2018 aio-libs team https://github.com/aio-libs/ 4 | Copyright (c) 2017-2018 Ocean S.A. https://ocean.io/ 5 | Copyright (c) 2017 WikiBusiness Corporation. http://wikibusiness.org/ 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in 15 | all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | THE SOFTWARE. 24 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.rst 2 | include LICENSE 3 | recursive-exclude * __pycache__ 4 | recursive-exclude * *.py[co] 5 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | aioelasticsearch 2 | ================ 3 | 4 | :info: elasticsearch-py wrapper for asyncio 5 | 6 | .. image:: https://img.shields.io/travis/aio-libs/aioelasticsearch.svg 7 | :target: https://travis-ci.org/aio-libs/aioelasticsearch 8 | 9 | .. image:: https://img.shields.io/pypi/v/aioelasticsearch.svg 10 | :target: https://pypi.python.org/pypi/aioelasticsearch 11 | 12 | .. image:: https://codecov.io/gh/aio-libs/aioelasticsearch/branch/master/graph/badge.svg 13 | :target: https://codecov.io/gh/aio-libs/aioelasticsearch 14 | 15 | Installation 16 | ------------ 17 | 18 | .. code-block:: shell 19 | 20 | pip install aioelasticsearch 21 | 22 | Usage 23 | ----- 24 | 25 | .. code-block:: python 26 | 27 | import asyncio 28 | 29 | from aioelasticsearch import Elasticsearch 30 | 31 | async def go(): 32 | es = Elasticsearch() 33 | 34 | print(await es.search()) 35 | 36 | await es.close() 37 | 38 | loop = asyncio.get_event_loop() 39 | loop.run_until_complete(go()) 40 | loop.close() 41 | 42 | Features 43 | -------- 44 | 45 | Asynchronous `scroll `_ 46 | 47 | .. code-block:: python 48 | 49 | import asyncio 50 | 51 | from aioelasticsearch import Elasticsearch 52 | from aioelasticsearch.helpers import Scan 53 | 54 | async def go(): 55 | async with Elasticsearch() as es: 56 | async with Scan( 57 | es, 58 | index='index', 59 | doc_type='doc_type', 60 | query={}, 61 | ) as scan: 62 | print(scan.total) 63 | 64 | async for doc in scan: 65 | print(doc['_source']) 66 | 67 | loop = asyncio.get_event_loop() 68 | loop.run_until_complete(go()) 69 | loop.close() 70 | 71 | Thanks 72 | ------ 73 | 74 | The library was donated by `Ocean S.A. 
`_ 75 | 76 | Thanks to the company for the contribution. 77 | -------------------------------------------------------------------------------- /aioelasticsearch/__init__.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from elasticsearch import Elasticsearch as _Elasticsearch # noqa # isort:skip 4 | from elasticsearch.connection_pool import (ConnectionSelector, # noqa # isort:skip 5 | RoundRobinSelector) 6 | from elasticsearch.serializer import JSONSerializer # noqa # isort:skip 7 | 8 | from .exceptions import * # noqa # isort:skip 9 | from .pool import AIOHttpConnectionPool # noqa # isort:skip 10 | from .transport import AIOHttpTransport # noqa # isort:skip 11 | 12 | 13 | __version__ = '0.7.0' 14 | 15 | 16 | class Elasticsearch(_Elasticsearch): 17 | 18 | def __init__( 19 | self, 20 | hosts=None, 21 | transport_class=AIOHttpTransport, 22 | *, 23 | loop=None, 24 | **kwargs 25 | ): 26 | if loop is None: 27 | loop = asyncio.get_event_loop() 28 | 29 | self.loop = loop 30 | 31 | kwargs['loop'] = self.loop 32 | 33 | super().__init__(hosts, transport_class=transport_class, **kwargs) 34 | 35 | async def close(self): 36 | await self.transport.close() 37 | 38 | async def __aenter__(self): # noqa 39 | return self 40 | 41 | async def __aexit__(self, *exc_info): # noqa 42 | await self.close() 43 | -------------------------------------------------------------------------------- /aioelasticsearch/connection.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import aiohttp 4 | 5 | from .exceptions import ConnectionError, ConnectionTimeout, SSLError # noqa # isort:skip 6 | 7 | from elasticsearch.connection import Connection # noqa # isort:skip 8 | from yarl import URL # noqa # isort:skip 9 | 10 | 11 | def session_factory(**kwargs): 12 | connector = aiohttp.TCPConnector( 13 | loop=kwargs.get('loop'), 14 | limit=kwargs.get('limit', 10), 15 | use_dns_cache=kwargs.get('use_dns_cache', False), 16 | ssl=kwargs.get('ssl', False), 17 | ) 18 | 19 | return aiohttp.ClientSession( 20 | auth=kwargs.get('auth'), 21 | connector=connector, 22 | ) 23 | 24 | 25 | class AIOHttpConnection(Connection): 26 | 27 | def __init__( 28 | self, 29 | host='localhost', 30 | port=9200, 31 | http_auth=None, 32 | use_ssl=False, 33 | ssl_context=None, 34 | verify_certs=False, 35 | maxsize=10, 36 | headers=None, 37 | *, 38 | loop, 39 | **kwargs 40 | ): 41 | assert not ( 42 | kwargs.get('session') and 43 | kwargs.get('session_factory') 44 | ), 'Provide `session` or `session_factory`, not both.'
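# NOTE: a custom ``session_factory`` (the hook added in 0.7.0) is called with
# the same keyword arguments as the module-level default above (``auth``,
# ``loop``, ``ssl``, ``limit``, ``use_dns_cache``) and must return an
# ``aiohttp.ClientSession``. A minimal sketch, using a hypothetical name and
# kept as a comment so the original control flow is unchanged:
#
#     def pooled_session_factory(**kwargs):
#         connector = aiohttp.TCPConnector(loop=kwargs.get('loop'),
#                                          limit=100,
#                                          ssl=kwargs.get('ssl', False))
#         return aiohttp.ClientSession(auth=kwargs.get('auth'),
#                                      connector=connector)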
45 | 46 | super().__init__(host=host, port=port, use_ssl=use_ssl, **kwargs) 47 | 48 | if headers is None: 49 | headers = {} 50 | self.headers = headers 51 | self.headers.setdefault('Content-Type', 'application/json') 52 | 53 | self.loop = loop 54 | 55 | if http_auth is not None: 56 | if isinstance(http_auth, aiohttp.BasicAuth): 57 | pass 58 | elif isinstance(http_auth, str): 59 | http_auth = aiohttp.BasicAuth(*http_auth.split(':', 1)) 60 | elif isinstance(http_auth, (tuple, list)): 61 | http_auth = aiohttp.BasicAuth(*http_auth) 62 | else: 63 | raise TypeError("Expected str, list, tuple or " 64 | "aiohttp.BasicAuth as http_auth parameter, " 65 | "got {!r}".format(http_auth)) 66 | 67 | self.http_auth = http_auth 68 | 69 | self.verify_certs = verify_certs 70 | 71 | self.base_url = URL.build(scheme='https' if self.use_ssl else 'http', 72 | host=host, 73 | port=port, 74 | path=self.url_prefix) 75 | 76 | self.session = kwargs.get('session') 77 | self.close_session = False 78 | 79 | if self.session is None: 80 | 81 | self._session_factory = kwargs.get( 82 | 'session_factory', 83 | session_factory, 84 | ) 85 | 86 | self.session = self._session_factory( 87 | auth=self.http_auth, 88 | loop=self.loop, 89 | ssl=ssl_context if self.verify_certs else False, 90 | limit=maxsize, 91 | use_dns_cache=kwargs.get('use_dns_cache', False), 92 | ) 93 | 94 | self.close_session = True 95 | 96 | async def close(self): 97 | if self.close_session: 98 | await self.session.close() 99 | 100 | async def perform_request( 101 | self, 102 | method, 103 | url, 104 | params=None, 105 | body=None, 106 | headers=None, 107 | timeout=None, 108 | ignore=() 109 | ): 110 | url_path = url 111 | 112 | url = (self.base_url / url.lstrip('/')).with_query(params) 113 | 114 | start = self.loop.time() 115 | try: 116 | async with self.session.request( 117 | method, 118 | url, 119 | data=body, 120 | headers=self._build_headers(headers), 121 | timeout=timeout or self.timeout) as response: 122 | raw_data = await response.text() 123 | 124 | duration = self.loop.time() - start 125 | 126 | except aiohttp.ClientSSLError as exc: 127 | self.log_request_fail( 128 | method, 129 | url, 130 | url_path, 131 | body, 132 | self.loop.time() - start, 133 | exception=exc, 134 | ) 135 | raise SSLError('N/A', str(exc), exc) 136 | 137 | except asyncio.TimeoutError as exc: 138 | self.log_request_fail( 139 | method, 140 | url, 141 | url_path, 142 | body, 143 | self.loop.time() - start, 144 | exception=exc, 145 | ) 146 | raise ConnectionTimeout('TIMEOUT', str(exc), exc) 147 | 148 | except aiohttp.ClientError as exc: 149 | self.log_request_fail( 150 | method, 151 | url, 152 | url_path, 153 | body, 154 | self.loop.time() - start, 155 | exception=exc, 156 | ) 157 | 158 | raise ConnectionError('N/A', str(exc), exc) 159 | 160 | # raise errors based on http status codes 161 | # let the client handle those if needed 162 | if ( 163 | not (200 <= response.status < 300) and 164 | response.status not in ignore 165 | ): 166 | self.log_request_fail( 167 | method, 168 | url, 169 | url_path, 170 | body, 171 | duration, 172 | response.status, 173 | raw_data, 174 | ) 175 | self._raise_error(response.status, raw_data) 176 | 177 | self.log_request_success( 178 | method, 179 | url, 180 | url_path, 181 | body, 182 | response.status, 183 | raw_data, 184 | duration, 185 | ) 186 | 187 | return response.status, response.headers, raw_data 188 | 189 | def _build_headers(self, headers): 190 | if headers: 191 | final_headers = self.headers.copy() 192 | final_headers.update(headers) 193 | else:
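# no per-request headers were passed: hand back the shared dict as-is
# rather than copying it on every request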
194 | final_headers = self.headers 195 | return final_headers 196 | -------------------------------------------------------------------------------- /aioelasticsearch/exceptions.py: -------------------------------------------------------------------------------- 1 | from elasticsearch.exceptions import * # noqa # isort:skip 2 | from elasticsearch.exceptions import (AuthenticationException, # noqa # isort:skip 3 | AuthorizationException) 4 | -------------------------------------------------------------------------------- /aioelasticsearch/helpers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from elasticsearch.helpers import ScanError 4 | 5 | from aioelasticsearch import NotFoundError 6 | 7 | __all__ = ('Scan', 'ScanError') 8 | 9 | 10 | logger = logging.getLogger('elasticsearch') 11 | 12 | 13 | class Scan: 14 | 15 | def __init__( 16 | self, 17 | es, 18 | query=None, 19 | scroll='5m', 20 | raise_on_error=True, 21 | preserve_order=False, 22 | size=1000, 23 | clear_scroll=True, 24 | scroll_kwargs=None, 25 | **kwargs 26 | ): 27 | self._es = es 28 | 29 | if not preserve_order: 30 | query = query.copy() if query else {} 31 | query['sort'] = '_doc' 32 | self._query = query 33 | self._scroll = scroll 34 | self._raise_on_error = raise_on_error 35 | self._size = size 36 | self._clear_scroll = clear_scroll 37 | self._kwargs = kwargs 38 | self._scroll_kwargs = scroll_kwargs or {} 39 | 40 | self._scroll_id = None 41 | 42 | self._total = 0 43 | 44 | self._initial = True 45 | self._done = False 46 | self._hits = [] 47 | self._hits_idx = 0 48 | self._successful_shards = 0 49 | self._total_shards = 0 50 | 51 | async def __aenter__(self): # noqa 52 | await self._do_search() 53 | return self 54 | 55 | async def __aexit__(self, *exc_info): # noqa 56 | await self._do_clear_scroll() 57 | 58 | def __aiter__(self): 59 | if self._initial: 60 | raise RuntimeError("Scan operations should be done " 61 | "inside async context manager") 62 | return self 63 | 64 | async def __anext__(self): # noqa 65 | if self._done: 66 | raise StopAsyncIteration 67 | 68 | if self._hits_idx >= len(self._hits): 69 | if self._successful_shards < self._total_shards: 70 | logger.warning( 71 | 'Scroll request has only succeeded on %d shards out of %d.', # noqa 72 | self._successful_shards, self._total_shards 73 | ) 74 | if self._raise_on_error: 75 | raise ScanError( 76 | self._scroll_id, 77 | 'Scroll request has only succeeded on {} shards out of {}.' 
# noqa 78 | .format(self._successful_shards, self._total_shards) 79 | ) 80 | 81 | await self._do_scroll() 82 | ret = self._hits[self._hits_idx] 83 | self._hits_idx += 1 84 | return ret 85 | 86 | @property 87 | def scroll_id(self): 88 | if self._initial: 89 | raise RuntimeError("Scan operations should be done " 90 | "inside async context manager") 91 | 92 | return self._scroll_id 93 | 94 | @property 95 | def total(self): 96 | if self._initial: 97 | raise RuntimeError("Scan operations should be done " 98 | "inside async context manager") 99 | return self._total 100 | 101 | async def _do_search(self): 102 | self._initial = False 103 | 104 | try: 105 | resp = await self._es.search( 106 | body=self._query, 107 | scroll=self._scroll, 108 | size=self._size, 109 | **self._kwargs 110 | ) 111 | except NotFoundError: 112 | self._done = True 113 | return 114 | else: 115 | self._total = resp['hits']['total'] 116 | self._update_state(resp) 117 | 118 | async def _do_scroll(self): 119 | resp = await self._es.scroll( 120 | scroll_id=self._scroll_id, 121 | scroll=self._scroll, 122 | **self._scroll_kwargs, 123 | ) 124 | self._update_state(resp) 125 | 126 | if self._done: 127 | raise StopAsyncIteration 128 | 129 | async def _do_clear_scroll(self): 130 | if self._scroll_id is not None and self._clear_scroll: 131 | await self._es.clear_scroll( 132 | body={'scroll_id': [self._scroll_id]}, 133 | ignore=404, 134 | ) 135 | 136 | def _update_state(self, resp): 137 | self._hits = resp['hits']['hits'] 138 | self._hits_idx = 0 139 | self._scroll_id = resp.get('_scroll_id') 140 | self._successful_shards = resp['_shards']['successful'] 141 | self._total_shards = resp['_shards']['total'] 142 | self._done = not self._hits or self._scroll_id is None 143 | -------------------------------------------------------------------------------- /aioelasticsearch/pool.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import collections 3 | import logging 4 | import random 5 | 6 | from elasticsearch.connection_pool import RoundRobinSelector 7 | 8 | from .exceptions import ImproperlyConfigured 9 | 10 | logger = logging.getLogger('elasticsearch') 11 | 12 | 13 | class AIOHttpConnectionPool: 14 | 15 | def __init__( 16 | self, 17 | connections, 18 | dead_timeout=60, 19 | timeout_cutoff=5, 20 | selector_class=RoundRobinSelector, 21 | randomize_hosts=True, 22 | *, 23 | loop, 24 | **kwargs 25 | ): 26 | self._dead_timeout = dead_timeout 27 | self.timeout_cutoff = timeout_cutoff 28 | self.connection_opts = connections 29 | self.connections = [c for (c, _) in connections] 30 | self.orig_connections = set(self.connections) 31 | self.dead = asyncio.PriorityQueue(len(self.connections), loop=loop) 32 | self.dead_count = collections.Counter() 33 | 34 | self.loop = loop 35 | 36 | if randomize_hosts: 37 | random.shuffle(self.connections) 38 | 39 | self.selector = selector_class(dict(connections)) 40 | 41 | def dead_timeout(self, dead_count): 42 | exponent = min(dead_count - 1, self.timeout_cutoff) 43 | return self._dead_timeout * 2 ** exponent 44 | 45 | def mark_dead(self, connection): 46 | now = self.loop.time() 47 | 48 | try: 49 | self.connections.remove(connection) 50 | except ValueError: 51 | # connection not alive or marked already, ignore 52 | return 53 | else: 54 | self.dead_count[connection] += 1 55 | dead_count = self.dead_count[connection] 56 | 57 | timeout = self.dead_timeout(dead_count) 58 | 59 | # it is impossible to raise QueueFull here 60 | self.dead.put_nowait((now + timeout, 
connection)) 61 | 62 | logger.warning( 63 | 'Connection %r has failed for %i times in a row, ' 64 | 'putting on %i second timeout.', 65 | connection, dead_count, timeout, 66 | ) 67 | 68 | def mark_live(self, connection): 69 | del self.dead_count[connection] 70 | 71 | def resurrect(self, force=False): 72 | if self.dead.empty(): 73 | if force: 74 | # list here is ok, it's a very rare case 75 | return random.choice(list(self.orig_connections)) 76 | return 77 | 78 | timestamp, connection = self.dead.get_nowait() 79 | 80 | if not force and timestamp > self.loop.time(): 81 | # return it back if not eligible and not forced 82 | self.dead.put_nowait((timestamp, connection)) 83 | return 84 | 85 | # either we were forced or the connection is eligible to be retried 86 | self.connections.append(connection) 87 | 88 | logger.info( 89 | 'Resurrecting connection %r (force=%s).', 90 | connection, force, 91 | ) 92 | 93 | return connection 94 | 95 | def get_connection(self): 96 | self.resurrect() 97 | 98 | if not self.connections: 99 | conn = self.resurrect(force=True) 100 | assert conn is not None 101 | return conn 102 | 103 | if len(self.connections) > 1: 104 | return self.selector.select(self.connections) 105 | 106 | return self.connections[0] 107 | 108 | async def close(self, *, skip=frozenset()): 109 | coros = [ 110 | connection.close() for connection in 111 | self.orig_connections - skip 112 | ] 113 | 114 | await asyncio.gather(*coros, loop=self.loop) 115 | 116 | 117 | class DummyConnectionPool(AIOHttpConnectionPool): 118 | 119 | def __init__(self, connections, *, loop, **kwargs): 120 | if len(connections) != 1: 121 | raise ImproperlyConfigured( 122 | 'DummyConnectionPool needs exactly one connection defined.', 123 | ) 124 | 125 | self.loop = loop 126 | 127 | self.connection_opts = connections 128 | self.connection = connections[0][0] 129 | self.connections = [self.connection] 130 | self.orig_connections = set(self.connections) 131 | 132 | def get_connection(self): 133 | return self.connection 134 | 135 | async def close(self, *, skip=frozenset()): 136 | if self.connection in skip: 137 | return 138 | await self.connection.close() 139 | 140 | def mark_live(self, connection): 141 | pass 142 | 143 | def mark_dead(self, connection): 144 | pass 145 | 146 | def resurrect(self, force=False): 147 | pass 148 | -------------------------------------------------------------------------------- /aioelasticsearch/transport.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from itertools import chain, count 4 | 5 | from elasticsearch.serializer import (DEFAULT_SERIALIZERS, Deserializer, 6 | JSONSerializer) 7 | from elasticsearch.transport import Transport, get_host_info 8 | 9 | from .connection import AIOHttpConnection 10 | from .exceptions import (ConnectionError, ConnectionTimeout, 11 | SerializationError, TransportError) 12 | from .pool import AIOHttpConnectionPool, DummyConnectionPool 13 | 14 | logger = logging.getLogger('elasticsearch') 15 | 16 | 17 | class AIOHttpTransport(Transport): 18 | 19 | def __init__( 20 | self, 21 | hosts, 22 | connection_class=AIOHttpConnection, 23 | connection_pool_class=AIOHttpConnectionPool, 24 | host_info_callback=get_host_info, 25 | serializer=JSONSerializer(), 26 | serializers=None, 27 | sniff_on_start=False, 28 | sniffer_timeout=None, 29 | sniff_timeout=.1, 30 | sniff_on_connection_fail=False, 31 | default_mimetype='application/json', 32 | max_retries=3, 33 | retry_on_status=(502, 503, 504, ), 34 |
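# (retry semantics, applied in _perform_request below: a ConnectionError
# always retries, a ConnectionTimeout retries only when retry_on_timeout
# is set, and the HTTP statuses listed above retry as well, up to
# max_retries attempts in total)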
retry_on_timeout=False, 35 | send_get_body_as='GET', 36 | *, 37 | loop, 38 | **kwargs 39 | ): 40 | self.loop = loop 41 | self._closed = False 42 | 43 | _serializers = DEFAULT_SERIALIZERS.copy() 44 | # if a serializer has been specified, 45 | # use it for deserialization as well 46 | _serializers[serializer.mimetype] = serializer 47 | # if custom serializers map has been supplied, 48 | # override the defaults with it 49 | if serializers: 50 | _serializers.update(serializers) 51 | # create a deserializer with our config 52 | self.deserializer = Deserializer(_serializers, default_mimetype) 53 | 54 | self.max_retries = max_retries 55 | self.retry_on_timeout = retry_on_timeout 56 | self.retry_on_status = retry_on_status 57 | self.send_get_body_as = send_get_body_as 58 | 59 | # data serializer 60 | self.serializer = serializer 61 | 62 | # sniffing data 63 | self.sniffer_timeout = sniffer_timeout 64 | self.sniff_on_connection_fail = sniff_on_connection_fail 65 | self.last_sniff = self.loop.time() 66 | self.sniff_timeout = sniff_timeout 67 | 68 | # callback to construct host dict from data in /_cluster/nodes 69 | self.host_info_callback = host_info_callback 70 | 71 | # store all strategies... 72 | self.connection_pool_class = connection_pool_class 73 | self.connection_class = connection_class 74 | self._connection_pool_lock = asyncio.Lock(loop=self.loop) 75 | 76 | # ...save kwargs to be passed to the connections 77 | self.kwargs = kwargs 78 | self.hosts = hosts 79 | 80 | # ...and instantiate them 81 | self.set_connections(hosts) 82 | # retain the original connection instances for sniffing 83 | self.seed_connections = set(self.connection_pool.connections) 84 | 85 | self.seed_connection_opts = self.connection_pool.connection_opts 86 | 87 | self.initial_sniff_task = None 88 | 89 | if sniff_on_start: 90 | def _initial_sniff_reset(fut): 91 | self.initial_sniff_task = None 92 | 93 | task = self.sniff_hosts(initial=True) 94 | 95 | self.initial_sniff_task = asyncio.ensure_future(task, 96 | loop=self.loop) 97 | self.initial_sniff_task.add_done_callback(_initial_sniff_reset) 98 | 99 | def set_connections(self, hosts): 100 | if self._closed: 101 | raise RuntimeError("Transport is closed") 102 | 103 | def _create_connection(host): 104 | # if this is not the initial setup look at the existing connection 105 | # options and identify connections that haven't changed and can be 106 | # kept around. 
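# (hosts are plain dicts compared by value, so a host whose options did
# not change keeps its existing connection object and open session)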
107 | if hasattr(self, 'connection_pool'): 108 | existing_connections = (self.connection_pool.connection_opts + 109 | self.seed_connection_opts) 110 | 111 | for (connection, old_host) in existing_connections: 112 | if old_host == host: 113 | return connection 114 | 115 | kwargs = self.kwargs.copy() 116 | kwargs.update(host) 117 | kwargs['loop'] = self.loop 118 | 119 | return self.connection_class(**kwargs) 120 | 121 | connections = map(_create_connection, hosts) 122 | 123 | connections = list(zip(connections, hosts)) 124 | 125 | if len(connections) == 1: 126 | self.connection_pool = DummyConnectionPool( 127 | connections, 128 | loop=self.loop, 129 | **self.kwargs 130 | ) 131 | else: 132 | self.connection_pool = self.connection_pool_class( 133 | connections, 134 | loop=self.loop, 135 | **self.kwargs 136 | ) 137 | 138 | async def _get_sniff_data(self, initial=False): 139 | previous_sniff = self.last_sniff 140 | 141 | tried = set() 142 | 143 | try: 144 | # reset last_sniff timestamp 145 | self.last_sniff = self.loop.time() 146 | for connection in chain( 147 | self.connection_pool.connections, 148 | self.seed_connections, 149 | ): 150 | if connection in tried: 151 | continue 152 | 153 | tried.add(connection) 154 | 155 | try: 156 | # use small timeout for the sniffing request, 157 | # should be a fast api call 158 | _, headers, node_info = await connection.perform_request( 159 | 'GET', 160 | '/_nodes/_all/http', 161 | timeout=self.sniff_timeout if not initial else None, 162 | ) 163 | 164 | node_info = self.deserializer.loads( 165 | node_info, headers.get('content-type'), 166 | ) 167 | break 168 | except (ConnectionError, SerializationError): 169 | pass 170 | else: 171 | raise TransportError('N/A', 'Unable to sniff hosts.') 172 | except: # noqa 173 | # keep the previous value on error 174 | self.last_sniff = previous_sniff 175 | raise 176 | 177 | return list(node_info['nodes'].values()) 178 | 179 | async def sniff_hosts(self, initial=False): 180 | if self._closed: 181 | raise RuntimeError("Transport is closed") 182 | async with self._connection_pool_lock: 183 | node_info = await self._get_sniff_data(initial) 184 | hosts = (self._get_host_info(n) for n in node_info) 185 | hosts = [host for host in hosts if host is not None] 186 | # we weren't able to get any nodes, maybe using an incompatible 187 | # transport_schema or host_info_callback blocked all - raise error. 
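# (the host_info_callback receives (node_info, host) and may return None
# to veto a node; vetoing every node lands here, as exercised by
# tests/test_transport.py::test_sniff_hosts_no_hosts)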
188 | if not hosts: 189 | raise TransportError( 190 | 'N/A', 'Unable to sniff hosts - no viable hosts found.', 191 | ) 192 | 193 | old_connection_pool = self.connection_pool 194 | 195 | self.set_connections(hosts) 196 | 197 | skip = (self.seed_connections | 198 | self.connection_pool.orig_connections) 199 | 200 | await old_connection_pool.close(skip=skip) 201 | 202 | async def close(self): 203 | if self._closed: 204 | return 205 | seeds = self.seed_connections - self.connection_pool.orig_connections 206 | 207 | coros = [connection.close() for connection in seeds] 208 | 209 | if self.initial_sniff_task is not None: 210 | self.initial_sniff_task.cancel() 211 | 212 | async def _initial_sniff_wrapper(): 213 | try: 214 | await self.initial_sniff_task 215 | except asyncio.CancelledError: 216 | return 217 | 218 | coros.append(_initial_sniff_wrapper()) 219 | 220 | coros.append(self.connection_pool.close()) 221 | 222 | await asyncio.gather(*coros, loop=self.loop) 223 | self._closed = True 224 | 225 | async def get_connection(self): 226 | if self._closed: 227 | raise RuntimeError("Transport is closed") 228 | if self.initial_sniff_task is not None: 229 | await self.initial_sniff_task 230 | 231 | if self.sniffer_timeout: 232 | if self.loop.time() >= self.last_sniff + self.sniffer_timeout: 233 | await self.sniff_hosts() 234 | 235 | async with self._connection_pool_lock: 236 | return self.connection_pool.get_connection() 237 | 238 | async def mark_dead(self, connection): 239 | if self._closed: 240 | raise RuntimeError("Transport is closed") 241 | self.connection_pool.mark_dead(connection) 242 | 243 | if self.sniff_on_connection_fail: 244 | await self.sniff_hosts() 245 | 246 | async def _perform_request( 247 | self, 248 | method, url, params, body, 249 | ignore=(), timeout=None, headers=None, 250 | ): 251 | for attempt in count(1): # pragma: no branch 252 | connection = await self.get_connection() 253 | 254 | try: 255 | 256 | status, headers, data = await connection.perform_request( 257 | method, url, params, body, 258 | ignore=ignore, timeout=timeout, headers=headers, 259 | ) 260 | except TransportError as e: 261 | if method == 'HEAD' and e.status_code == 404: 262 | return False 263 | 264 | retry = False 265 | if isinstance(e, ConnectionTimeout): 266 | retry = self.retry_on_timeout 267 | elif isinstance(e, ConnectionError): 268 | retry = True 269 | elif e.status_code in self.retry_on_status: 270 | retry = True 271 | 272 | if retry: 273 | await self.mark_dead(connection) 274 | 275 | if attempt == self.max_retries: 276 | raise 277 | else: 278 | raise 279 | 280 | else: 281 | self.connection_pool.mark_live(connection) 282 | 283 | if method == 'HEAD': 284 | return 200 <= status < 300 285 | 286 | if data: 287 | data = self.deserializer.loads( 288 | data, headers.get('content-type'), 289 | ) 290 | 291 | return data 292 | 293 | async def perform_request(self, method, url, headers=None, params=None, body=None): # noqa 294 | if self._closed: 295 | raise RuntimeError("Transport is closed") 296 | # yarl fix for https://github.com/elastic/elasticsearch-py/blob/d4efb81b0695f3d9f64784a35891b732823a9c32/elasticsearch/client/utils.py#L29 # noqa 297 | if params is not None: 298 | to_replace = {} 299 | for k, v in params.items(): 300 | if isinstance(v, bytes): 301 | to_replace[k] = v.decode('utf-8') 302 | for k, v in to_replace.items(): 303 | params[k] = v 304 | 305 | if body is not None: 306 | body = self.serializer.dumps(body) 307 | 308 | # some clients or environments don't support sending GET with body 309 | if method 
in ('HEAD', 'GET') and self.send_get_body_as != 'GET': 310 | # send it as post instead 311 | if self.send_get_body_as == 'POST': 312 | method = 'POST' 313 | 314 | # or as source parameter 315 | elif self.send_get_body_as == 'source': 316 | if params is None: 317 | params = {} 318 | params['source'] = body 319 | params['source_content_type'] = self.serializer.mimetype 320 | body = None 321 | 322 | if body is not None: 323 | try: 324 | body = body.encode('utf-8', 'surrogatepass') 325 | except (UnicodeDecodeError, AttributeError): 326 | # bytes/str - no need to re-encode 327 | pass 328 | 329 | ignore = () 330 | timeout = None 331 | if params: 332 | timeout = params.pop('request_timeout', None) 333 | ignore = params.pop('ignore', ()) 334 | if isinstance(ignore, int): 335 | ignore = (ignore, ) 336 | 337 | return await self._perform_request( 338 | method, url, params, body, 339 | ignore=ignore, timeout=timeout, headers=headers, 340 | ) 341 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts= --keep-duplicates --cache-clear --no-cov-on-fail --cov=aioelasticsearch --cov-report=term --cov-report=html 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.7.4 2 | attrs==19.2.0 3 | flake8==3.7.8 4 | ipdb==0.12.2 5 | ipython==7.16.3 6 | pytest==5.2.1 7 | pytest-cov==2.8.1 8 | pytest-mock==1.11.1 9 | tox==3.14.0 10 | docker==4.1.0 11 | isort==4.3.21 12 | -e . 13 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [coverage:run] 2 | branch = True 3 | omit = site-packages 4 | source = aioelasticsearch, tests 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import re 4 | 5 | from setuptools import setup 6 | 7 | 8 | def get_version(): 9 | regex = r"__version__\s=\s\'(?P<version>[\d\.]+?)\'" 10 | 11 | path = ('aioelasticsearch', '__init__.py') 12 | 13 | return re.search(regex, read(*path)).group('version') 14 | 15 | 16 | def read(*parts): 17 | filename = os.path.join(os.path.abspath(os.path.dirname(__file__)), *parts) 18 | 19 | with io.open(filename, encoding='utf-8', mode='rt') as fp: 20 | return fp.read() 21 | 22 | 23 | setup( 24 | name='aioelasticsearch', 25 | version=get_version(), 26 | author='wikibusiness', 27 | author_email='osf@wikibusiness.org', 28 | url='https://github.com/aio-libs/aioelasticsearch', 29 | description='elasticsearch-py wrapper for asyncio', 30 | long_description=read('README.rst'), 31 | install_requires=[ 32 | 'elasticsearch>=7.0.0', 33 | 'aiohttp>=3.5.0,<4.0.0', 34 | ], 35 | python_requires='>=3.5.3', 36 | packages=['aioelasticsearch'], 37 | include_package_data=True, 38 | zip_safe=False, 39 | classifiers=[ 40 | 'Development Status :: 5 - Production/Stable', 41 | 'Intended Audience :: Developers', 42 | 'License :: OSI Approved :: MIT License', 43 | 'Operating System :: POSIX', 44 | 'Operating System :: MacOS :: MacOS X', 45 | 'Operating System :: Microsoft :: Windows', 46 | 'Programming Language :: Python', 47 | 'Programming Language :: Python :: 3', 48 | 'Programming Language :: Python :: 3.5', 49 | 'Programming Language :: Python :: 3.6', 50 | 'Programming Language :: Python :: 3.7', 51 | ], 52 | keywords=['elasticsearch', 'asyncio', 'aiohttp'], 53 | ) 54 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import gc 3 | import time 4 | import uuid 5 | 6 | import elasticsearch 7 | import pytest 8 | from aiohttp.test_utils import unused_port 9 | from docker import from_env as docker_from_env 10 | 11 | import aioelasticsearch 12 | 13 | 14 | @pytest.fixture 15 | def loop(request): 16 | asyncio.set_event_loop(None) 17 | 18 | loop = asyncio.new_event_loop() 19 | 20 | yield loop 21 | 22 | if not loop._closed: 23 | loop.call_soon(loop.stop) 24 | loop.run_forever() 25 | loop.close() 26 | 27 | gc.collect() 28 | asyncio.set_event_loop(None) 29 | 30 | 31 | @pytest.fixture(scope='session') 32 | def session_id(): 33 | '''Unique session identifier, random string.''' 34 | return str(uuid.uuid4()) 35 | 36 | 37 | @pytest.fixture(scope='session') 38 | def docker(): 39 | client = docker_from_env(version='auto') 40 | return client 41 | 42 | 43 | def pytest_addoption(parser): 44 | parser.addoption("--es_tag", action="append", default=[], 45 | help=("Elasticsearch server versions. " 46 | "May be used several times. " 47 | "6.0.0 by default")) 48 | parser.addoption("--no-pull", action="store_true", default=False, 49 | help="Don't perform docker images pulling") 50 | parser.addoption("--local-docker", action="store_true", default=False, 51 | help="Use 0.0.0.0 as docker host, useful for macOS") 52 | 53 | 54 | def pytest_generate_tests(metafunc): 55 | if 'es_tag' in metafunc.fixturenames: 56 | tags = set(metafunc.config.option.es_tag) 57 | if not tags: 58 | tags = ['6.0.0'] 59 | else: 60 | tags = list(tags) 61 | metafunc.parametrize("es_tag", tags, scope='session') 62 | 63 | 64 | @pytest.fixture(scope='session') 65 | def es_container(docker, session_id, es_tag, request): 66 | image = 'docker.elastic.co/elasticsearch/elasticsearch:{}'.format(es_tag) 67 | 68 | if not request.config.option.no_pull: 69 | docker.images.pull(image) 70 | 71 | es_auth = ('elastic', 'changeme') 72 | 73 | if request.config.option.local_docker: 74 | es_port_9200 = es_access_port = unused_port() 75 | es_port_9300 = unused_port() 76 | else: 77 | es_port_9200 = es_port_9300 = None 78 | es_access_port = 9200 79 | 80 | container = docker.containers.run( 81 | image=image, 82 | detach=True, 83 | name='aioelasticsearch-' + session_id, 84 | ports={ 85 | '9200/tcp': es_port_9200, 86 | '9300/tcp': es_port_9300, 87 | }, 88 | environment={ 89 | 'http.host': '0.0.0.0', 90 | 'transport.host': '127.0.0.1', 91 | }, 92 | ) 93 | 94 | if request.config.option.local_docker: 95 | docker_host = '0.0.0.0' 96 | else: 97 | inspection = docker.api.inspect_container(container.id) 98 | docker_host = inspection['NetworkSettings']['IPAddress'] 99 | 100 | delay = 0.1 101 | for i in range(20): 102 | es = elasticsearch.Elasticsearch( 103 | [{ 104 | 'host': docker_host, 105 | 'port': es_access_port, 106 | }], 107 | http_auth=es_auth, 108 | ) 109 | 110 | try: 111 | es.transport.perform_request('GET', '/_nodes/_all/http') 112 | except elasticsearch.TransportError: 113 | time.sleep(delay) 114 | delay *= 2 115 | else: 116 | break 117 | finally: 118 | es.transport.close() 119 | else: 120 | pytest.fail("Cannot start elastic server") 121 | 122 | yield { 123 | 'host': docker_host, 124 | 'port': es_access_port, 125 | 'auth': es_auth, 126 |
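# (everything below the yield is teardown: once the test session ends the
# Elasticsearch container is killed and force-removed)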
} 127 | 128 | container.kill(signal=9) 129 | container.remove(force=True) 130 | 131 | 132 | @pytest.fixture 133 | def es_clean(es_container): 134 | def do(): 135 | es = elasticsearch.Elasticsearch( 136 | hosts=[{ 137 | 'host': es_container['host'], 138 | 'port': es_container['port'], 139 | }], 140 | http_auth=es_container['auth'], 141 | ) 142 | 143 | try: 144 | es.transport.perform_request('DELETE', '/_template/*') 145 | es.transport.perform_request('DELETE', '/_all') 146 | finally: 147 | es.transport.close() 148 | 149 | return do 150 | 151 | 152 | @pytest.fixture 153 | def es_server(es_clean, es_container): 154 | es_clean() 155 | 156 | return es_container 157 | 158 | 159 | @pytest.fixture 160 | def es(es_server, auto_close, loop): 161 | es = aioelasticsearch.Elasticsearch( 162 | hosts=[{ 163 | 'host': es_server['host'], 164 | 'port': es_server['port'], 165 | }], 166 | http_auth=es_server['auth'], 167 | loop=loop, 168 | ) 169 | 170 | return auto_close(es) 171 | 172 | 173 | @pytest.fixture 174 | def auto_close(loop): 175 | close_list = [] 176 | 177 | def f(arg): 178 | close_list.append(arg) 179 | return arg 180 | 181 | yield f 182 | 183 | for arg in close_list: 184 | loop.run_until_complete(arg.close()) 185 | 186 | 187 | @pytest.mark.tryfirst 188 | def pytest_pycollect_makeitem(collector, name, obj): 189 | if collector.funcnamefilter(name): 190 | item = pytest.Function(name, parent=collector) 191 | 192 | if 'run_loop' in item.keywords: 193 | return list(collector._genfunctions(name, obj)) 194 | 195 | 196 | @pytest.mark.tryfirst 197 | def pytest_pyfunc_call(pyfuncitem): 198 | if 'run_loop' in pyfuncitem.keywords: 199 | funcargs = pyfuncitem.funcargs 200 | 201 | loop = funcargs['loop'] 202 | 203 | testargs = { 204 | arg: funcargs[arg] 205 | for arg in pyfuncitem._fixtureinfo.argnames 206 | } 207 | 208 | assert asyncio.iscoroutinefunction(pyfuncitem.obj) 209 | 210 | loop.run_until_complete(pyfuncitem.obj(**testargs)) 211 | 212 | return True 213 | 214 | 215 | @pytest.fixture 216 | def populate(es, loop): 217 | async def do(index, n, body): 218 | coros = [] 219 | 220 | await es.indices.create(index) 221 | 222 | for i in range(n): 223 | coros.append( 224 | es.index( 225 | index=index, 226 | id=str(i), 227 | body=body, 228 | ), 229 | ) 230 | 231 | await asyncio.gather(*coros, loop=loop) 232 | await es.indices.refresh() 233 | return do 234 | -------------------------------------------------------------------------------- /tests/test_aioelasticsearch.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import pytest 4 | 5 | from aioelasticsearch import Elasticsearch 6 | 7 | 8 | @pytest.mark.run_loop 9 | async def test_ping(es): 10 | ping = await es.ping() 11 | 12 | assert ping 13 | 14 | 15 | @pytest.mark.run_loop 16 | @asyncio.coroutine 17 | def test_ping2(es): 18 | ping = yield from es.ping() 19 | 20 | assert ping 21 | 22 | 23 | def test_elastic_default_loop(auto_close, loop): 24 | asyncio.set_event_loop(loop) 25 | 26 | es = Elasticsearch() 27 | 28 | auto_close(es) 29 | 30 | assert es.loop is loop 31 | -------------------------------------------------------------------------------- /tests/test_connection.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import ssl 3 | from unittest import mock 4 | 5 | import aiohttp 6 | import pytest 7 | from elasticsearch import ConnectionTimeout 8 | 9 | from aioelasticsearch.connection import (AIOHttpConnection, ConnectionError, 10 | SSLError) 11 | 
12 | 13 | @pytest.mark.run_loop 14 | async def test_default_headers(auto_close, loop): 15 | conn = auto_close(AIOHttpConnection(loop=loop)) 16 | assert conn.headers == {'Content-Type': 'application/json'} 17 | 18 | 19 | @pytest.mark.run_loop 20 | async def test_custom_headers(auto_close, loop): 21 | conn = auto_close(AIOHttpConnection(headers={'X-Custom': 'value'}, 22 | loop=loop)) 23 | assert conn.headers == {'Content-Type': 'application/json', 24 | 'X-Custom': 'value'} 25 | 26 | 27 | @pytest.mark.run_loop 28 | async def test_auth_no_auth(auto_close, loop): 29 | conn = auto_close(AIOHttpConnection(loop=loop)) 30 | assert conn.http_auth is None 31 | 32 | 33 | @pytest.mark.run_loop 34 | async def test_ssl_context(auto_close, loop): 35 | context = ssl.create_default_context() 36 | conn = auto_close( 37 | AIOHttpConnection(loop=loop, verify_certs=True, ssl_context=context) 38 | ) 39 | assert conn.session.connector._ssl is context 40 | 41 | 42 | @pytest.mark.run_loop 43 | async def test_auth_str(auto_close, loop): 44 | auth = aiohttp.BasicAuth('user', 'pass') 45 | conn = auto_close(AIOHttpConnection(http_auth='user:pass', loop=loop)) 46 | assert conn.http_auth == auth 47 | 48 | 49 | @pytest.mark.run_loop 50 | async def test_auth_tuple(auto_close, loop): 51 | auth = aiohttp.BasicAuth('user', 'pass') 52 | conn = auto_close(AIOHttpConnection(http_auth=('user', 'pass'), loop=loop)) 53 | assert conn.http_auth == auth 54 | 55 | 56 | @pytest.mark.run_loop 57 | async def test_auth_basicauth(auto_close, loop): 58 | auth = aiohttp.BasicAuth('user', 'pass') 59 | conn = auto_close(AIOHttpConnection(http_auth=auth, loop=loop)) 60 | assert conn.http_auth == auth 61 | 62 | 63 | @pytest.mark.run_loop 64 | async def test_auth_invalid(loop): 65 | with pytest.raises(TypeError): 66 | AIOHttpConnection(http_auth=object(), loop=loop) 67 | 68 | 69 | @pytest.mark.run_loop 70 | async def test_explicit_session(auto_close, loop): 71 | session = aiohttp.ClientSession(loop=loop) 72 | conn = auto_close(AIOHttpConnection(session=session, loop=loop)) 73 | assert conn.session is session 74 | 75 | 76 | @pytest.mark.run_loop 77 | async def test_explicit_session_not_closed(loop): 78 | session = aiohttp.ClientSession(loop=loop) 79 | conn = AIOHttpConnection(session=session, loop=loop) 80 | await conn.close() 81 | assert not conn.session.closed and not session.closed 82 | 83 | 84 | @pytest.mark.run_loop 85 | async def test_default_session(auto_close, loop): 86 | conn = auto_close(AIOHttpConnection(loop=loop)) 87 | assert isinstance(conn.session, aiohttp.ClientSession) 88 | 89 | 90 | @pytest.mark.run_loop 91 | async def test_session_closed(loop): 92 | conn = AIOHttpConnection(loop=loop) 93 | await conn.close() 94 | assert conn.session.closed 95 | 96 | 97 | @pytest.mark.run_loop 98 | async def test_perform_request_ssl_error(auto_close, loop): 99 | for exc, expected in [ 100 | (aiohttp.ClientConnectorCertificateError(mock.Mock(), mock.Mock()), SSLError), # noqa 101 | (aiohttp.ClientConnectorSSLError(mock.Mock(), mock.Mock()), SSLError), 102 | (aiohttp.ClientSSLError(mock.Mock(), mock.Mock()), SSLError), 103 | (aiohttp.ClientError('Other'), ConnectionError), 104 | (asyncio.TimeoutError, ConnectionTimeout), 105 | ]: 106 | session = aiohttp.ClientSession(loop=loop) 107 | 108 | async def coro(*args, **kwargs): 109 | raise exc 110 | 111 | session._request = coro 112 | 113 | conn = auto_close(AIOHttpConnection(session=session, loop=loop, 114 | use_ssl=True)) 115 | with pytest.raises(expected): 116 | await conn.perform_request('HEAD', '/')
117 | -------------------------------------------------------------------------------- /tests/test_pool.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from aioelasticsearch import (AIOHttpConnectionPool, Elasticsearch, 4 | ImproperlyConfigured) 5 | from aioelasticsearch.pool import DummyConnectionPool 6 | 7 | 8 | @pytest.mark.run_loop 9 | async def test_mark_dead_removed_connection(auto_close, es_server, loop): 10 | es = auto_close(Elasticsearch(hosts=[{'host': es_server['host'], 11 | 'port': es_server['port']}, 12 | {'host': 'unknown_host', 13 | 'port': 9200}], 14 | http_auth=es_server['auth'], 15 | loop=loop)) 16 | conn = await es.transport.get_connection() 17 | pool = es.transport.connection_pool 18 | pool.mark_dead(conn) 19 | assert conn in pool.dead_count 20 | # second call should succeed 21 | pool.mark_dead(conn) 22 | assert conn in pool.dead_count 23 | 24 | 25 | @pytest.mark.run_loop 26 | async def test_mark_live(auto_close, es_server, loop): 27 | es = auto_close(Elasticsearch(hosts=[{'host': es_server['host'], 28 | 'port': es_server['port']}, 29 | {'host': 'unknown_host', 30 | 'port': 9200}], 31 | http_auth=es_server['auth'], 32 | loop=loop)) 33 | conn = await es.transport.get_connection() 34 | pool = es.transport.connection_pool 35 | pool.mark_dead(conn) 36 | assert conn in pool.dead_count 37 | 38 | pool.mark_live(conn) 39 | assert conn not in pool.dead_count 40 | 41 | 42 | @pytest.mark.run_loop 43 | async def test_mark_live_not_dead(auto_close, es_server, loop): 44 | es = auto_close(Elasticsearch(hosts=[{'host': es_server['host'], 45 | 'port': es_server['port']}, 46 | {'host': 'unknown_host', 47 | 'port': 9200}], 48 | http_auth=es_server['auth'], 49 | loop=loop)) 50 | conn = await es.transport.get_connection() 51 | pool = es.transport.connection_pool 52 | pool.mark_live(conn) 53 | assert conn not in pool.dead_count 54 | 55 | 56 | @pytest.mark.run_loop 57 | async def test_resurrect_empty(loop): 58 | conn1 = object() 59 | conn2 = object() 60 | conns = [(conn1, object()), (conn2, object())] 61 | pool = AIOHttpConnectionPool(connections=conns, 62 | randomize_hosts=False, loop=loop) 63 | pool.resurrect() 64 | assert pool.connections == [conn1, conn2] 65 | 66 | 67 | @pytest.mark.run_loop 68 | async def test_resurrect_empty_force(loop): 69 | conn1 = object() 70 | conn2 = object() 71 | conns = [(conn1, object()), (conn2, object())] 72 | pool = AIOHttpConnectionPool(connections=conns, 73 | randomize_hosts=False, loop=loop) 74 | assert pool.resurrect(force=True) in (conn1, conn2) 75 | 76 | 77 | @pytest.mark.run_loop 78 | async def test_resurrect_from_dead_not_ready_connection(loop): 79 | conn1 = object() 80 | conn2 = object() 81 | conns = [(conn1, object()), (conn2, object())] 82 | pool = AIOHttpConnectionPool(connections=conns, 83 | randomize_hosts=False, loop=loop) 84 | pool.mark_dead(conn1) 85 | pool.resurrect() 86 | assert pool.connections == [conn2] 87 | 88 | 89 | @pytest.mark.run_loop 90 | async def test_resurrect_from_dead_ready_connection(loop): 91 | conn1 = object() 92 | conn2 = object() 93 | conns = [(conn1, object()), (conn2, object())] 94 | pool = AIOHttpConnectionPool(connections=conns, 95 | randomize_hosts=False, loop=loop) 96 | pool.dead_timeout = lambda t: 0 97 | pool.mark_dead(conn1) 98 | pool.resurrect() 99 | assert pool.connections == [conn2, conn1] 100 | 101 | 102 | @pytest.mark.run_loop 103 | async def test_get_connections_only_one_conn(loop): 104 | conn1 = object() 105 | conn2 = object() 106 | conns = 
[(conn1, object()), (conn2, object())] 107 | pool = AIOHttpConnectionPool(connections=conns, 108 | randomize_hosts=False, loop=loop) 109 | pool.mark_dead(conn1) 110 | conn = pool.get_connection() 111 | assert conn is conn2 112 | 113 | 114 | @pytest.mark.run_loop 115 | async def test_get_connections_no_conns(loop): 116 | conn1 = object() 117 | conn2 = object() 118 | conns = [(conn1, object()), (conn2, object())] 119 | pool = AIOHttpConnectionPool(connections=conns, 120 | randomize_hosts=False, loop=loop) 121 | pool.mark_dead(conn1) 122 | pool.mark_dead(conn2) 123 | conn = pool.get_connection() 124 | assert conn in (conn1, conn2) 125 | 126 | 127 | @pytest.mark.run_loop 128 | async def test_dummy_improperly_configured(loop): 129 | conn1 = object() 130 | conn2 = object() 131 | conns = [(conn1, object()), (conn2, object())] 132 | with pytest.raises(ImproperlyConfigured): 133 | DummyConnectionPool(connections=conns, loop=loop) 134 | 135 | 136 | @pytest.mark.run_loop 137 | async def test_dummy_mark_dead_and_live(loop): 138 | conn1 = object() 139 | conns = [(conn1, object())] 140 | 141 | pool = DummyConnectionPool(connections=conns, loop=loop) 142 | pool.mark_dead(conn1) 143 | assert pool.connections == [conn1] 144 | 145 | pool.mark_live(conn1) 146 | assert pool.connections == [conn1] 147 | 148 | pool.resurrect() 149 | assert pool.connections == [conn1] 150 | -------------------------------------------------------------------------------- /tests/test_scan.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from unittest import mock 3 | 4 | import pytest 5 | 6 | from aioelasticsearch import NotFoundError 7 | from aioelasticsearch.helpers import Scan, ScanError 8 | 9 | logger = logging.getLogger('elasticsearch') 10 | 11 | 12 | def test_scan_total_without_context_manager(es): 13 | scan = Scan(es) 14 | 15 | with pytest.raises(RuntimeError): 16 | scan.total 17 | 18 | 19 | @pytest.mark.run_loop 20 | async def test_scan_async_for_without_context_manager(es): 21 | scan = Scan(es) 22 | 23 | with pytest.raises(RuntimeError): 24 | async for doc in scan: 25 | doc 26 | 27 | 28 | def test_scan_scroll_id_without_context_manager(es): 29 | scan = Scan(es) 30 | 31 | with pytest.raises(RuntimeError): 32 | scan.scroll_id 33 | 34 | 35 | @pytest.mark.run_loop 36 | async def test_scan_simple(es, populate): 37 | index = 'test_aioes' 38 | scroll_size = 3 39 | n = 10 40 | 41 | body = {'foo': 1} 42 | await populate(index, n, body) 43 | ids = set() 44 | 45 | async with Scan( 46 | es, 47 | index=index, 48 | size=scroll_size, 49 | ) as scan: 50 | assert isinstance(scan.scroll_id, str) 51 | assert scan.total['value'] == 10 52 | async for doc in scan: 53 | ids.add(doc['_id']) 54 | assert doc == {'_id': mock.ANY, 55 | '_index': 'test_aioes', 56 | '_score': None, 57 | '_source': {'foo': 1}, 58 | '_type': '_doc', 59 | 'sort': mock.ANY} 60 | 61 | assert ids == {str(i) for i in range(10)} 62 | 63 | 64 | @pytest.mark.run_loop 65 | async def test_scan_equal_chunks_for_loop(es, es_clean, populate): 66 | for n, scroll_size in [ 67 | (0, 1), # no results 68 | (6, 6), # 1 scroll 69 | (6, 8), # 1 scroll 70 | (6, 3), # 2 scrolls 71 | (6, 4), # 2 scrolls 72 | (6, 2), # 3 scrolls 73 | (6, 1), # 6 scrolls 74 | ]: 75 | es_clean() 76 | 77 | index = 'test_aioes' 78 | body = {'foo': 1} 79 | 80 | await populate(index, n, body) 81 | 82 | ids = set() 83 | 84 | async with Scan( 85 | es, 86 | index=index, 87 | size=scroll_size, 88 | ) as scan: 89 | 90 | async for doc in scan: 91 | 
ids.add(doc['_id']) 92 | 93 | # check number of unique doc ids 94 | assert len(ids) == n == scan.total['value'] 95 | 96 | 97 | @pytest.mark.run_loop 98 | async def test_scan_no_mask_index(es): 99 | index = 'undefined-*' 100 | scroll_size = 3 101 | 102 | async with Scan( 103 | es, 104 | index=index, 105 | size=scroll_size, 106 | ) as scan: 107 | assert scan.scroll_id is None 108 | assert scan.total['value'] == 0 109 | cnt = 0 110 | async for doc in scan: # noqa 111 | cnt += 1 112 | assert cnt == 0 113 | 114 | 115 | @pytest.mark.run_loop 116 | async def test_scan_no_scroll(es, loop, populate): 117 | index = 'test_aioes' 118 | n = 10 119 | scroll_size = 1 120 | body = {'foo': 1} 121 | 122 | await populate(index, n, body) 123 | 124 | async with Scan( 125 | es, 126 | size=scroll_size, 127 | ) as scan: 128 | # the same happens after the search context expires 129 | await scan._do_clear_scroll() 130 | 131 | with pytest.raises(NotFoundError): 132 | async for doc in scan: 133 | doc 134 | 135 | 136 | @pytest.mark.run_loop 137 | async def test_scan_no_index(es): 138 | index = 'undefined' 139 | scroll_size = 3 140 | 141 | async with Scan( 142 | es, 143 | index=index, 144 | size=scroll_size, 145 | ) as scan: 146 | assert scan.scroll_id is None 147 | assert scan.total == 0 148 | cnt = 0 149 | async for doc in scan: # noqa 150 | cnt += 1 151 | assert cnt == 0 152 | 153 | 154 | @pytest.mark.run_loop 155 | async def test_scan_warning_on_failed_shards(es, populate, mocker): 156 | index = 'test_aioes' 157 | scroll_size = 3 158 | n = 10 159 | 160 | body = {'foo': 1} 161 | await populate(index, n, body) 162 | 163 | mocker.spy(logger, 'warning') 164 | 165 | async with Scan( 166 | es, 167 | index=index, 168 | size=scroll_size, 169 | raise_on_error=False, 170 | ) as scan: 171 | i = 0 172 | async for doc in scan: # noqa 173 | if i == 3: 174 | # once after first scroll 175 | scan._successful_shards = 4 176 | scan._total_shards = 5 177 | i += 1 178 | 179 | logger.warning.assert_called_once_with( 180 | 'Scroll request has only succeeded on %d shards out of %d.', 4, 5) 181 | 182 | 183 | @pytest.mark.run_loop 184 | async def test_scan_exception_on_failed_shards(es, populate, mocker): 185 | index = 'test_aioes' 186 | scroll_size = 3 187 | n = 10 188 | 189 | body = {'foo': 1} 190 | await populate(index, n, body) 191 | 192 | mocker.spy(logger, 'warning') 193 | 194 | i = 0 195 | async with Scan( 196 | es, 197 | index=index, 198 | size=scroll_size, 199 | ) as scan: 200 | with pytest.raises(ScanError) as cm: 201 | async for doc in scan: # noqa 202 | if i == 3: 203 | # once after first scroll 204 | scan._successful_shards = 4 205 | scan._total_shards = 5 206 | i += 1 207 | 208 | assert (str(cm.value) == 209 | 'Scroll request has only succeeded on 4 shards out of 5.') 210 | 211 | assert i == 6 212 | logger.warning.assert_called_once_with( 213 | 'Scroll request has only succeeded on %d shards out of %d.', 4, 5) 214 | -------------------------------------------------------------------------------- /tests/test_transport.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from aioelasticsearch import (AIOHttpTransport, ConnectionError, 4 | ConnectionTimeout, Elasticsearch, TransportError) 5 | from aioelasticsearch.connection import AIOHttpConnection 6 | 7 | 8 | class DummyConnection(AIOHttpConnection): 9 | def __init__(self, **kwargs): 10 | self.exception = kwargs.pop('exception', None) 11 | self.status = kwargs.pop('status', 200) 12 | self.data = kwargs.pop('data', '{}') 13 |
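# canned response parts (status/headers/data) are what this fake
# connection returns; self.calls below records every request so tests
# can assert on retry counts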
self.headers = kwargs.pop('headers', {}) 14 | self.calls = [] 15 | super().__init__(**kwargs) 16 | 17 | async def perform_request(self, *args, **kwargs): 18 | self.calls.append((args, kwargs)) 19 | if self.exception: 20 | raise self.exception 21 | return self.status, self.headers, self.data 22 | 23 | 24 | @pytest.mark.run_loop 25 | async def test_custom_serializers(auto_close, loop): 26 | serializer = object() 27 | t = auto_close(AIOHttpTransport([{}], 28 | serializers={'test': serializer}, 29 | loop=loop)) 30 | assert 'test' in t.deserializer.serializers 31 | assert t.deserializer.serializers['test'] is serializer 32 | 33 | 34 | @pytest.mark.run_loop 35 | async def test_no_sniff_on_start(auto_close, loop): 36 | t = auto_close(AIOHttpTransport([{}], sniff_on_start=False, loop=loop)) 37 | assert t.initial_sniff_task is None 38 | 39 | 40 | @pytest.mark.run_loop 41 | async def test_sniff_on_start(auto_close, loop, es_server): 42 | t = auto_close(AIOHttpTransport([{'host': 'unknown_host', 43 | 'port': 9200}, 44 | {'host': es_server['host'], 45 | 'port': es_server['port']}], 46 | http_auth=es_server['auth'], 47 | sniff_on_start=True, loop=loop)) 48 | assert t.initial_sniff_task is not None 49 | await t.initial_sniff_task 50 | assert t.initial_sniff_task is None 51 | assert len(t.connection_pool.connections) == 1 52 | 53 | 54 | @pytest.mark.run_loop 55 | async def test_close_with_sniff_on_start(loop, es_server): 56 | t = AIOHttpTransport([{'host': es_server['host'], 57 | 'port': es_server['port']}], 58 | http_auth=es_server['auth'], 59 | sniff_on_start=True, loop=loop) 60 | assert t.initial_sniff_task is not None 61 | await t.close() 62 | assert t.initial_sniff_task is None 63 | assert t._closed 64 | 65 | 66 | @pytest.mark.run_loop 67 | async def test_get_connection_with_sniff_on_start(auto_close, loop, es_server): 68 | t = auto_close(AIOHttpTransport([{'host': es_server['host'], 69 | 'port': es_server['port']}], 70 | http_auth=es_server['auth'], 71 | sniff_on_start=True, loop=loop)) 72 | conn = await t.get_connection() 73 | assert conn is not None 74 | assert t.initial_sniff_task is None 75 | 76 | 77 | @pytest.mark.run_loop 78 | async def test_get_connection_with_sniffer_timeout(auto_close, 79 | loop, es_server): 80 | t = auto_close(AIOHttpTransport([{'host': 'unknown_host', 81 | 'port': 9200}, 82 | {'host': es_server['host'], 83 | 'port': es_server['port']}], 84 | http_auth=es_server['auth'], 85 | sniffer_timeout=10, loop=loop)) 86 | assert t.initial_sniff_task is None 87 | t.last_sniff -= 15 88 | conn = await t.get_connection() 89 | assert conn is not None 90 | assert t.initial_sniff_task is None 91 | assert len(t.connection_pool.connections) == 1 92 | 93 | 94 | @pytest.mark.run_loop 95 | async def test_get_connection_without_sniffer_timeout(auto_close, 96 | loop, es_server): 97 | t = auto_close(AIOHttpTransport([{'host': 'unknown_host', 98 | 'port': 9200}, 99 | {'host': es_server['host'], 100 | 'port': es_server['port']}], 101 | http_auth=es_server['auth'], 102 | sniffer_timeout=1e12, loop=loop)) 103 | conn = await t.get_connection() 104 | assert conn is not None 105 | assert t.initial_sniff_task is None 106 | assert len(t.connection_pool.connections) == 2 107 | 108 | 109 | @pytest.mark.run_loop 110 | async def test_sniff_hosts_error(auto_close, loop, es_server): 111 | t = auto_close(AIOHttpTransport([{'host': 'unknown_host', 112 | 'port': 9200}], 113 | loop=loop)) 114 | with pytest.raises(TransportError): 115 | await t.sniff_hosts() 116 | 117 | 118 | @pytest.mark.run_loop 119 | async def 
119 | async def test_sniff_hosts_no_hosts(auto_close, loop, es_server):
120 |     t = auto_close(AIOHttpTransport([{'host': es_server['host'],
121 |                                       'port': es_server['port']}],
122 |                                     http_auth=es_server['auth'],
123 |                                     loop=loop))
124 |     t.host_info_callback = lambda host_info, host: None
125 |     with pytest.raises(TransportError):
126 |         await t.sniff_hosts()
127 | 
128 | 
129 | @pytest.mark.run_loop
130 | async def test_mark_dead(auto_close, loop, es_server):
131 |     t = auto_close(AIOHttpTransport([{'host': 'unknown_host',
132 |                                       'port': 9200},
133 |                                      {'host': es_server['host'],
134 |                                       'port': es_server['port']}],
135 |                                     http_auth=es_server['auth'],
136 |                                     randomize_hosts=False,
137 |                                     loop=loop))
138 |     conn = t.connection_pool.connections[0]
139 |     assert conn is not None
140 |     assert conn.host == 'http://unknown_host:9200'
141 |     await t.mark_dead(conn)
142 |     assert len(t.connection_pool.connections) == 1
143 | 
144 | 
145 | @pytest.mark.run_loop
146 | async def test_mark_dead_with_sniff(auto_close, loop, es_server):
147 |     t = auto_close(AIOHttpTransport([{'host': 'unknown_host',
148 |                                       'port': 9200},
149 |                                      {'host': 'unknown_host2',
150 |                                       'port': 9200},
151 |                                      {'host': es_server['host'],
152 |                                       'port': es_server['port']}],
153 |                                     http_auth=es_server['auth'],
154 |                                     sniff_on_connection_fail=True,
155 |                                     randomize_hosts=False,
156 |                                     loop=loop))
157 |     conn = t.connection_pool.connections[0]
158 |     assert conn is not None
159 |     assert conn.host == 'http://unknown_host:9200'
160 |     await t.mark_dead(conn)
161 |     assert len(t.connection_pool.connections) == 1
162 | 
163 | 
164 | @pytest.mark.run_loop
165 | async def test_send_get_body_as_post(es_server, auto_close, loop):
166 |     cl = auto_close(Elasticsearch([{'host': es_server['host'],
167 |                                     'port': es_server['port']}],
168 |                                   send_get_body_as='POST',
169 |                                   http_auth=es_server['auth'],
170 |                                   loop=loop))
171 |     await cl.create('test', '1', {'val': '1'})
172 |     await cl.create('test', '2', {'val': '2'})
173 |     ret = await cl.mget(
174 |         {"docs": [
175 |             {"_id": "1"},
176 |             {"_id": "2"}
177 |         ]},
178 |         index='test',
179 |     )
180 |     assert ret == {'docs': [{'_id': '1',
181 |                              '_index': 'test',
182 |                              '_source': {'val': '1'},
183 |                              '_type': '_doc',
184 |                              '_version': 1,
185 |                              '_primary_term': 1,
186 |                              '_seq_no': 0,
187 |                              'found': True},
188 |                             {'_id': '2',
189 |                              '_index': 'test',
190 |                              '_source': {'val': '2'},
191 |                              '_type': '_doc',
192 |                              '_version': 1,
193 |                              '_primary_term': 1,
194 |                              '_seq_no': 1,
195 |                              'found': True}]}
196 | 
197 | 
198 | @pytest.mark.run_loop
199 | async def test_send_get_body_as_source(es_server, auto_close, loop):
200 |     cl = auto_close(Elasticsearch([{'host': es_server['host'],
201 |                                     'port': es_server['port']}],
202 |                                   send_get_body_as='source',
203 |                                   http_auth=es_server['auth'],
204 |                                   loop=loop))
205 |     await cl.create('test', '1', {'val': '1'})
206 |     await cl.create('test', '2', {'val': '2'})
207 |     ret = await cl.mget(
208 |         {"docs": [
209 |             {"_id": "1"},
210 |             {"_id": "2"}
211 |         ]},
212 |         index='test',
213 |     )
214 |     assert ret == {'docs': [{'_id': '1',
215 |                              '_index': 'test',
216 |                              '_source': {'val': '1'},
217 |                              '_type': '_doc',
218 |                              '_version': 1,
219 |                              '_primary_term': 1,
220 |                              '_seq_no': 0,
221 |                              'found': True},
222 |                             {'_id': '2',
223 |                              '_index': 'test',
224 |                              '_source': {'val': '2'},
225 |                              '_type': '_doc',
226 |                              '_version': 1,
227 |                              '_primary_term': 1,
228 |                              '_seq_no': 1,
229 |                              'found': True}]}
230 | 
231 | 
232 | @pytest.mark.run_loop
233 | async def test_send_get_body_as_get(es_server, auto_close, loop):
234 |     cl = auto_close(Elasticsearch([{'host': es_server['host'],
235 |                                     'port': es_server['port']}],
236 |                                   http_auth=es_server['auth'],
237 |                                   loop=loop))
238 |     await cl.create('test', '1', {'val': '1'})
239 |     await cl.create('test', '2', {'val': '2'})
240 |     ret = await cl.mget(
241 |         {"docs": [
242 |             {"_id": "1"},
243 |             {"_id": "2"}
244 |         ]},
245 |         index='test',
246 |     )
247 |     assert ret == {'docs': [{'_id': '1',
248 |                              '_index': 'test',
249 |                              '_source': {'val': '1'},
250 |                              '_type': '_doc',
251 |                              '_version': 1,
252 |                              '_primary_term': 1,
253 |                              '_seq_no': 0,
254 |                              'found': True},
255 |                             {'_id': '2',
256 |                              '_index': 'test',
257 |                              '_source': {'val': '2'},
258 |                              '_type': '_doc',
259 |                              '_version': 1,
260 |                              '_primary_term': 1,
261 |                              '_seq_no': 1,
262 |                              'found': True}]}
263 | 
264 | 
265 | @pytest.mark.run_loop
266 | async def test_send_get_body_as_source_none_params(es_server,
267 |                                                    auto_close, loop):
268 |     cl = auto_close(Elasticsearch([{'host': es_server['host'],
269 |                                     'port': es_server['port']}],
270 |                                   send_get_body_as='source',
271 |                                   http_auth=es_server['auth'],
272 |                                   loop=loop))
273 |     await cl.create('test', '1', {'val': '1'})
274 |     await cl.create('test', '2', {'val': '2'})
275 |     ret = await cl.transport.perform_request(
276 |         'GET', 'test/_mget',
277 |         body={"docs": [
278 |             {"_id": "1"},
279 |             {"_id": "2"}
280 |         ]})
281 |     assert ret == {'docs': [{'_id': '1',
282 |                              '_index': 'test',
283 |                              '_source': {'val': '1'},
284 |                              '_type': '_doc',
285 |                              '_version': 1,
286 |                              '_primary_term': 1,
287 |                              '_seq_no': 0,
288 |                              'found': True},
289 |                             {'_id': '2',
290 |                              '_index': 'test',
291 |                              '_source': {'val': '2'},
292 |                              '_type': '_doc',
293 |                              '_version': 1,
294 |                              '_primary_term': 1,
295 |                              '_seq_no': 1,
296 |                              'found': True}]}
297 | 
298 | 
299 | @pytest.mark.run_loop
300 | async def test_set_connections_closed(es):
301 |     await es.close()
302 |     with pytest.raises(RuntimeError):
303 |         es.transport.set_connections(['host1', 'host2'])
304 | 
305 | 
306 | @pytest.mark.run_loop
307 | async def test_sniff_hosts_closed(es):
308 |     await es.close()
309 |     with pytest.raises(RuntimeError):
310 |         await es.transport.sniff_hosts()
311 | 
312 | 
313 | @pytest.mark.run_loop
314 | async def test_close_closed(es):
315 |     await es.close()
316 |     await es.close()
317 | 
318 | 
319 | @pytest.mark.run_loop
320 | async def test_get_connection_closed(es):
321 |     await es.close()
322 |     with pytest.raises(RuntimeError):
323 |         await es.transport.get_connection()
324 | 
325 | 
326 | @pytest.mark.run_loop
327 | async def test_mark_dead_closed(es):
328 |     await es.close()
329 |     conn = object()
330 |     with pytest.raises(RuntimeError):
331 |         await es.transport.mark_dead(conn)
332 | 
333 | 
334 | @pytest.mark.run_loop
335 | async def test_perform_request_closed(es):
336 |     await es.close()
337 |     with pytest.raises(RuntimeError):
338 |         await es.transport.perform_request('GET', '/')
339 | 
340 | 
341 | @pytest.mark.run_loop
342 | async def test_request_error_404_on_head(loop, auto_close):
343 |     exc = TransportError(404)
344 |     t = AIOHttpTransport([{}], connection_class=DummyConnection, loop=loop,
345 |                          exception=exc)
346 |     auto_close(t)
347 | 
348 |     ret = await t.perform_request('HEAD', '/')
349 |     assert not ret
350 | 
351 | 
352 | @pytest.mark.run_loop
353 | async def test_request_connection_error(loop, auto_close):
354 |     exc = ConnectionError()
355 |     t = AIOHttpTransport([{}], connection_class=DummyConnection, loop=loop,
356 |                          exception=exc)
357 |     auto_close(t)
358 | 
359 |     with pytest.raises(ConnectionError):
360 |         await t.perform_request('GET', '/')
361 | 
362 |     conn = await t.get_connection()
363 |     assert len(conn.calls) == 3
364 | 
365 | 
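# A note on the retry semantics pinned down by the surrounding tests (this
# reads the assertions back; it is not a documented API guarantee of the
# library): a plain ConnectionError marks the connection dead and the request
# is retried up to the transport's retry limit, hence the three recorded
# calls above; a ConnectionTimeout is retried only when retry_on_timeout=True
# (compare the next two tests); and an error status raises immediately unless
# it is listed in retry_on_status.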
366 | @pytest.mark.run_loop
367 | async def test_request_connection_timeout(loop, auto_close):
368 |     exc = ConnectionTimeout()
369 |     t = AIOHttpTransport([{}], connection_class=DummyConnection, loop=loop,
370 |                          exception=exc)
371 |     auto_close(t)
372 | 
373 |     with pytest.raises(ConnectionTimeout):
374 |         await t.perform_request('GET', '/')
375 | 
376 |     conn = await t.get_connection()
377 |     assert len(conn.calls) == 1
378 | 
379 | 
380 | @pytest.mark.run_loop
381 | async def test_request_connection_timeout_with_retry(loop, auto_close):
382 |     exc = ConnectionTimeout()
383 |     t = AIOHttpTransport([{}], connection_class=DummyConnection, loop=loop,
384 |                          exception=exc, retry_on_timeout=True)
385 |     auto_close(t)
386 | 
387 |     with pytest.raises(ConnectionTimeout):
388 |         await t.perform_request('GET', '/')
389 | 
390 |     conn = await t.get_connection()
391 |     assert len(conn.calls) == 3
392 | 
393 | 
394 | @pytest.mark.run_loop
395 | async def test_request_retry_on_status(loop, auto_close):
396 |     exc = TransportError(500)
397 |     t = AIOHttpTransport([{}], connection_class=DummyConnection, loop=loop,
398 |                          exception=exc, retry_on_status=(500,))
399 |     auto_close(t)
400 | 
401 |     with pytest.raises(TransportError):
402 |         await t.perform_request('GET', '/')
403 | 
404 |     conn = await t.get_connection()
405 |     assert len(conn.calls) == 3
406 | 
407 | 
408 | @pytest.mark.run_loop
409 | async def test_request_without_data(loop, auto_close):
410 |     t = AIOHttpTransport([{}], connection_class=DummyConnection, loop=loop,
411 |                          data='')
412 |     auto_close(t)
413 | 
414 |     ret = await t.perform_request('GET', '/')
415 |     assert ret == ''
416 | 
417 | 
418 | @pytest.mark.run_loop
419 | async def test_request_headers(loop, auto_close, es_server, mocker):
420 |     t = auto_close(AIOHttpTransport(
421 |         [{'host': es_server['host'],
422 |           'port': es_server['port']}],
423 |         http_auth=es_server['auth'],
424 |         loop=loop,
425 |         headers={'H1': 'V1', 'H2': 'V2'},
426 |     ))
427 | 
428 |     for conn in t.connection_pool.connections:
429 |         mocker.spy(conn.session, 'request')
430 | 
431 |     await t.perform_request('GET', '/', headers={'H1': 'VV1', 'H3': 'V3'})
432 | 
433 |     session = (await t.get_connection()).session
434 |     _, kwargs = session.request.call_args
435 |     assert kwargs['headers'] == {
436 |         'H1': 'VV1',
437 |         'H2': 'V2',
438 |         'H3': 'V3',
439 |         'Content-Type': 'application/json',
440 |     }
441 | 
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist = py3{5,6,7}-{debug,release},isort,flake8
3 | skip_missing_interpreters = True
4 | 
5 | [testenv]
6 | deps =
7 |     -r{toxinidir}/requirements.txt
8 | commands = pytest tests {posargs}
9 | setenv =
10 |     debug: PYTHONASYNCIODEBUG=x
11 |     release: PYTHONASYNCIODEBUG=
12 | 
13 | [testenv:coverage]
14 | passenv = CI TRAVIS TRAVIS_*
15 | deps = codecov
16 | commands = codecov
17 | 
18 | [testenv:flake8]
19 | skipsdist = True
20 | skip_install = True
21 | deps = flake8
22 | commands = flake8 --show-source aioelasticsearch tests setup.py
23 | 
24 | [testenv:isort]
25 | skipsdist = True
26 | skip_install = True
27 | deps =
28 |     -r{toxinidir}/requirements.txt
29 | commands =
30 |     isort --check-only -rc aioelasticsearch --diff
31 |     isort --check-only setup.py --diff
32 |     isort --check-only -rc tests --diff
--------------------------------------------------------------------------------
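A minimal, self-contained sketch of the stubbed-connection pattern used in
tests/test_transport.py above: the transport forwards unrecognized keyword
arguments (exception, status, data) to its connection_class, which is how
DummyConnection is parametrized in the tests. The import of DummyConnection
from the test module and the asyncio boilerplate are illustrative assumptions,
not part of the repository:

import asyncio

from aioelasticsearch import AIOHttpTransport, ConnectionError

# stub connection defined at the top of tests/test_transport.py
from tests.test_transport import DummyConnection


async def main():
    # Every request on this transport fails with ConnectionError; the
    # transport marks the connection dead and retries before giving up.
    t = AIOHttpTransport([{}], connection_class=DummyConnection,
                         exception=ConnectionError())
    try:
        await t.perform_request('GET', '/')
    except ConnectionError:
        conn = await t.get_connection()
        # three attempts recorded, matching test_request_connection_error
        print('attempts:', len(conn.calls))
    await t.close()


loop = asyncio.get_event_loop()
loop.run_until_complete(main())
loop.close()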