├── tests ├── __init__.py ├── ext │ ├── __init__.py │ ├── test_resources.py │ └── test_views.py ├── conftest.py ├── test_examples.py ├── test_traversal.py └── test_router.py ├── aiohttp_traversal ├── ext │ ├── __init__.py │ ├── views.py │ ├── static.py │ └── resources.py ├── __init__.py ├── abc.py ├── traversal.py └── router.py ├── doc └── img │ ├── request_lifetime.png │ └── traversal_algorithm.png ├── .bumpversion.cfg ├── .gitignore ├── .travis.yml ├── setup.py ├── README.rst └── examples ├── 1-hello.py └── 2-middleware.py /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/ext/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aiohttp_traversal/ext/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aiohttp_traversal/__init__.py: -------------------------------------------------------------------------------- 1 | from .router import TraversalRouter 2 | from .traversal import lineage, find_root 3 | -------------------------------------------------------------------------------- /doc/img/request_lifetime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzzsochi/aiohttp_traversal/HEAD/doc/img/request_lifetime.png -------------------------------------------------------------------------------- /doc/img/traversal_algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzzsochi/aiohttp_traversal/HEAD/doc/img/traversal_algorithm.png -------------------------------------------------------------------------------- /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | commit = True 3 | current_version = 0.11.0 4 | tag = True 5 | tag_name = {new_version} 6 | 7 | [bumpversion:file:setup.py] 8 | parse = version='(?P\d+)\.(?P\d+)\.(?P\d+)' 9 | 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/* 2 | /dist/ 3 | /.cache/ 4 | 5 | *.pyc 6 | *.pyo 7 | *.egg 8 | *.egg-info 9 | 10 | .python-version 11 | .coverage 12 | 13 | *.sublime-settings 14 | *.sublime-project 15 | *.sublime-workspace 16 | 17 | *.swo 18 | *.swp 19 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | os: linux 2 | dist: xenial 3 | 4 | language: python 5 | python: 6 | - "3.5" 7 | - "3.6" 8 | - "3.7" 9 | 10 | install: 11 | - travis_retry python3 setup.py install 12 | - travis_retry pip install pytest pytest-aiohttp pytest-cov coveralls 13 | 14 | script: 15 | - py.test tests -v --cov aiohttp_traversal --cov-report term-missing 16 | 17 | after_success: 18 | - coveralls 19 | -------------------------------------------------------------------------------- /aiohttp_traversal/abc.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod, abstractproperty 2 | 3 | 4 | class AbstractResource(metaclass=ABCMeta): 5 | @abstractproperty 6 | def __parent__(self): 7 | """ Parent resource or None for root """ 8 | 9 | @abstractmethod 10 | def __getitem__(self, name): 11 | """ Return traversal.Traverser instance 12 | 13 | In simple: 14 | 15 | return traversal.Traverser(self, [name]) 16 | """ 17 | 18 | @abstractmethod 19 | async def __getchild__(self, name): 20 | """ Return child resource or None, if not exists """ 21 | 22 | 23 | class AbstractView(metaclass=ABCMeta): 24 | @abstractmethod 25 | def __init__(self, resource, request): 26 | """ Receive current traversed resource """ 27 | 28 | @abstractmethod 29 | async def __call__(self): 30 | """ Return aiohttp.web.Response """ 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | 4 | setup( 5 | name='aiohttp_traversal', 6 | version='0.11.0', 7 | description='Traversal based router for aiohttp.web', 8 | classifiers=[ 9 | "License :: OSI Approved :: BSD License", 10 | "Operating System :: POSIX", 11 | "Programming Language :: Python :: 3.5", 12 | "Programming Language :: Python :: 3.6", 13 | "Programming Language :: Python :: 3.7", 14 | "Topic :: Internet :: WWW/HTTP", 15 | ], 16 | author='Alexander Zelenyak', 17 | author_email='zzz.sochi@gmail.com', 18 | license='BSD', 19 | url='https://github.com/zzzsochi/aiohttp_traversal', 20 | keywords=['asyncio', 'aiohttp', 'traversal', 'pyramid'], 21 | packages=['aiohttp_traversal', 'aiohttp_traversal.ext'], 22 | install_requires=[ 23 | 'aiohttp >=2.0', 24 | 'resolver_deco', 25 | 'zope.dottedname', 26 | ], 27 | tests_require=['pytest'] 28 | ), 29 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ====================================== 2 | Traversal based router for aiohttp.web 3 | ====================================== 4 | 5 | .. image:: https://api.travis-ci.org/zzzsochi/aiohttp_traversal.svg 6 | :target: https://secure.travis-ci.org/zzzsochi/aiohttp_traversal 7 | :align: center 8 | 9 | .. image:: https://coveralls.io/repos/zzzsochi/aiohttp_traversal/badge.svg 10 | :target: https://coveralls.io/r/zzzsochi/aiohttp_traversal 11 | :align: center 12 | 13 | 14 | ------- 15 | Schemes 16 | ------- 17 | 18 | **Request lifetime** 19 | 20 | .. image:: https://raw.githubusercontent.com/zzzsochi/aiohttp_traversal/master/doc/img/request_lifetime.png 21 | :alt: Request lifetime 22 | :align: center 23 | 24 | 25 | **Traversal algorithm** 26 | 27 | .. image:: https://raw.githubusercontent.com/zzzsochi/aiohttp_traversal/master/doc/img/traversal_algorithm.png 28 | :alt: Traversal algorithm 29 | :align: center 30 | 31 | ----- 32 | Tests 33 | ----- 34 | 35 | .. code:: shell 36 | 37 | $ pip install pytest 38 | $ py.test tests -v 39 | -------------------------------------------------------------------------------- /examples/1-hello.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hello World application. 3 | 4 | Start this: 5 | 6 | $ python3 1-hello.py 7 | 8 | Or use aiohttp_devtools: 9 | 10 | $ adev runserver 1-hello.py --app-factory create_app 11 | 12 | After start, check urls: 13 | 14 | * GET localhost:8000/ 15 | * GET localhost:8000/json 16 | """ 17 | import asyncio 18 | 19 | from aiohttp.web import Application, Response, run_app 20 | 21 | from aiohttp_traversal.router import TraversalRouter 22 | from aiohttp_traversal.ext.views import View, RESTView 23 | from aiohttp_traversal.ext.resources import Root 24 | 25 | 26 | class HelloView(View): 27 | @asyncio.coroutine 28 | def __call__(self): 29 | return Response(text="Hello World!") 30 | 31 | 32 | class HelloJSON(RESTView): 33 | methods = {'get'} 34 | 35 | @asyncio.coroutine 36 | def get(self): 37 | return dict(text="Hello World!") 38 | 39 | 40 | def create_app(): 41 | app = Application(router=TraversalRouter()) # create main application instance 42 | app.router.set_root_factory(lambda request, app=app: Root(app)) # set root factory 43 | app.router.bind_view(Root, HelloView) # add view for '/' 44 | app.router.bind_view(Root, HelloJSON, 'json') # add view for '/json' 45 | 46 | return app 47 | 48 | 49 | if __name__ == '__main__': 50 | app = create_app() 51 | run_app(app, port=8000) 52 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import pytest 4 | from aiohttp.web import Application 5 | 6 | from aiohttp_traversal.abc import AbstractResource 7 | from aiohttp_traversal.router import TraversalRouter 8 | from aiohttp_traversal.traversal import Traverser 9 | 10 | 11 | @pytest.yield_fixture 12 | def loop(): 13 | asyncio.set_event_loop(None) 14 | loop = asyncio.new_event_loop() 15 | yield loop 16 | loop.close() 17 | 18 | 19 | @pytest.fixture 20 | def root_factory(): 21 | return lambda app: Resource(parent=None, name='ROOT') 22 | 23 | 24 | class Resource(AbstractResource): 25 | __parent__ = None 26 | 27 | def __init__(self, parent, name): 28 | self.__parent__ = parent 29 | self.name = name 30 | 31 | def __getitem__(self, name): 32 | return Traverser(self, (name,)) 33 | 34 | async def __getchild__(self, name): 35 | if name == 'not': 36 | return None 37 | else: 38 | return Resource(self, name) 39 | 40 | def __repr__(self): 41 | return ''.format(self.name) 42 | 43 | 44 | @pytest.fixture 45 | def router(root_factory): 46 | return TraversalRouter(root_factory=root_factory) 47 | 48 | 49 | @pytest.fixture 50 | def app(router): 51 | return Application(router=router) 52 | 53 | 54 | @pytest.fixture 55 | def root(app, root_factory): 56 | return root_factory(app) 57 | -------------------------------------------------------------------------------- /tests/test_examples.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from importlib.machinery import SourceFileLoader 4 | 5 | 6 | def load_example(name): 7 | here = pathlib.Path(os.path.dirname(__file__)) 8 | path = here / '..' / 'examples' / (name + '.py') 9 | mod = SourceFileLoader(name, str(path)).load_module() 10 | return mod 11 | 12 | 13 | async def test_hello(aiohttp_client): 14 | mod = load_example('1-hello') 15 | app = mod.create_app() 16 | client = await aiohttp_client(app) 17 | 18 | # HelloView 19 | resp = await client.get('/') 20 | assert resp.status == 200 21 | 22 | body = await resp.text() 23 | assert body == 'Hello World!' 24 | 25 | # HelloJSON 26 | resp = await client.get('/json') 27 | assert resp.status == 200 28 | 29 | data = await resp.json() 30 | assert data == {'text': 'Hello World!'} 31 | 32 | 33 | async def test_middleware(aiohttp_client): 34 | mod = load_example('2-middleware') 35 | app = mod.create_app() 36 | client = await aiohttp_client(app) 37 | 38 | # Hello 39 | resp = await client.get('/') 40 | assert resp.status == 200 41 | 42 | data = await resp.json() 43 | assert data == {'counter': 1} 44 | 45 | # 404 46 | resp = await client.get('/thing') 47 | assert resp.status == 404 48 | 49 | data = await resp.json() 50 | assert data == {'error': 'not_found', 'reason': 'Not Found'} 51 | -------------------------------------------------------------------------------- /tests/test_traversal.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from aiohttp_traversal.traversal import traverse, lineage 4 | 5 | 6 | @pytest.fixture 7 | def res_c(loop, root): 8 | async def coro(): 9 | return await root['a']['b']['c'] 10 | 11 | return loop.run_until_complete(coro()) 12 | 13 | 14 | def test_traverse_await(loop, root): 15 | async def coro(): 16 | return await traverse(root, ('a', 'b', 'c')) 17 | 18 | res, tail = loop.run_until_complete(coro()) 19 | assert res.name == 'c' 20 | assert not tail 21 | assert len(list(lineage(res))) == 4 22 | 23 | 24 | def test_traverse_await_empty(loop, root): 25 | async def coro(): 26 | return await traverse(root, []) 27 | 28 | res, tail = loop.run_until_complete(coro()) 29 | assert res is root 30 | assert not tail 31 | 32 | 33 | def test_traverse_await_with_tail(loop, root): 34 | async def coro(): 35 | return await traverse(root, ('a', 'b', 'not', 'c')) 36 | 37 | res, tail = loop.run_until_complete(coro()) 38 | assert res.name == 'b' 39 | assert tail == ('not', 'c') 40 | assert len(list(lineage(res))) == 3 41 | 42 | 43 | def test_traverser_await_with_tail(loop, root): 44 | # import asyncio 45 | # @asyncio.coroutine 46 | async def coro(): 47 | with pytest.raises(KeyError): 48 | await root['a']['b']['not'] 49 | 50 | loop.run_until_complete(coro()) 51 | -------------------------------------------------------------------------------- /aiohttp_traversal/ext/views.py: -------------------------------------------------------------------------------- 1 | import json 2 | import asyncio 3 | 4 | from aiohttp.web import Response, StreamResponse 5 | from aiohttp.web import HTTPMethodNotAllowed 6 | 7 | from aiohttp_traversal.abc import AbstractView 8 | 9 | 10 | class View(AbstractView): 11 | def __init__(self, request, resource, tail): 12 | self.request = request 13 | self.resource = resource 14 | self.tail = tail 15 | 16 | async def __call__(self): 17 | raise NotImplementedError() 18 | 19 | 20 | class MethodsView(View): 21 | methods = frozenset() # {'get', 'post', 'put', 'patch', 'delete', 'option'} 22 | 23 | async def __call__(self): 24 | method = self.request.method.lower() 25 | 26 | if method in self.methods: 27 | return await getattr(self, method)() 28 | else: 29 | raise HTTPMethodNotAllowed(method, self.methods) 30 | 31 | async def get(self): 32 | raise NotImplementedError 33 | 34 | async def post(self): 35 | raise NotImplementedError 36 | 37 | async def put(self): 38 | raise NotImplementedError 39 | 40 | async def patch(self): 41 | raise NotImplementedError 42 | 43 | async def delete(self): 44 | raise NotImplementedError 45 | 46 | async def option(self): 47 | raise NotImplementedError 48 | 49 | 50 | class RESTView(MethodsView): 51 | def serialize(self, data): 52 | """ Serialize data to JSON. 53 | 54 | You can owerride this method if you data cant be serialized 55 | standart json.dumps routine. 56 | """ 57 | return json.dumps(data).encode('utf8') 58 | 59 | async def __call__(self): 60 | data = await super().__call__() 61 | 62 | if isinstance(data, StreamResponse): 63 | return data 64 | else: 65 | return Response( 66 | body=self.serialize(data), 67 | headers={'Content-Type': 'application/json; charset=utf-8'}, 68 | ) 69 | -------------------------------------------------------------------------------- /aiohttp_traversal/ext/static.py: -------------------------------------------------------------------------------- 1 | """ DO NOT USE IN PRODUCTION!!! 2 | """ 3 | 4 | import os 5 | import asyncio 6 | import mimetypes 7 | import warnings 8 | import logging 9 | from collections import namedtuple 10 | 11 | from aiohttp.web import Response, HTTPNotFound 12 | from resolver_deco import resolver 13 | 14 | from .resources import Resource, add_child 15 | from .views import View 16 | 17 | log = logging.getLogger(__name__) 18 | 19 | 20 | StaticInfo = namedtuple('StaticInfo', ('path', 'content_type', 'data')) 21 | 22 | 23 | class StaticResource(Resource): 24 | path = None 25 | 26 | def __init__(self, parent, name): 27 | super().__init__(parent, name) 28 | self.path = os.path.abspath(self.path) 29 | 30 | async def __getchild__(self, name): 31 | return None 32 | 33 | def get(self, path): 34 | path = os.path.join(self.path, path) 35 | 36 | if not os.path.isfile(path): 37 | raise HTTPNotFound() 38 | 39 | ext = os.path.splitext(path)[1] 40 | ct = mimetypes.types_map.get(ext, 'application/octet-stream') 41 | 42 | with open(path, 'rb') as f: 43 | return StaticInfo(path, ct, f.read()) 44 | 45 | 46 | class StaticView(View): 47 | 48 | async def __call__(self): 49 | if self.request.tail: 50 | path = os.path.join(*self.request.tail) 51 | else: 52 | path = '' 53 | 54 | info = self.resource.get(path) 55 | 56 | return Response( 57 | body=info.data, 58 | headers={'Content-Type': info.content_type}, 59 | ) 60 | 61 | 62 | def prepare_static_view(app): 63 | warnings.warn("Do not use this module in production!") 64 | app.router.bind_view(StaticResource, StaticView, tail='*') 65 | 66 | 67 | @resolver('parent', 'resource_class') 68 | def add_static(app, parent, name, path, resource_class=StaticResource): 69 | """ Add resource for serve static 70 | """ 71 | warnings.warn("Do not use this module in production!") 72 | SRes = type(resource_class.__name__, (resource_class,), {'path': path}) 73 | add_child(app, parent, name, SRes) 74 | -------------------------------------------------------------------------------- /aiohttp_traversal/traversal.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | log = logging.getLogger(__name__) 5 | 6 | 7 | async def traverse(root, path): 8 | """ Find resource for path. 9 | 10 | root: instance of Resource 11 | path: list or path parts 12 | return: tuple `(resource, tail)` 13 | """ 14 | if not path: 15 | return root, tuple(path) 16 | 17 | path = list(path) 18 | traverser = root[path.pop(0)] 19 | 20 | while path: 21 | traverser = traverser[path.pop(0)] 22 | 23 | return await traverser.traverse() 24 | 25 | 26 | class Traverser: 27 | 28 | def __init__(self, resource, path): 29 | self.resource = resource 30 | self.path = path 31 | 32 | def __getitem__(self, item): 33 | return Traverser(self.resource, self.path + (item,)) 34 | 35 | def __await__(self): 36 | return self.__anext__().__await__() 37 | 38 | async def __anext__(self): 39 | """ This object is coroutine. 40 | 41 | For this: 42 | 43 | await app.router.get_root()['a']['b']['c'] 44 | """ 45 | resource, tail = await self.traverse() 46 | 47 | if tail: 48 | raise KeyError(tail[0]) 49 | else: 50 | return resource 51 | 52 | async def traverse(self): 53 | """ Main traversal algorithm. 54 | 55 | Return tuple `(resource, tail)`. 56 | """ 57 | last, current = None, self.resource 58 | path = list(self.path) 59 | 60 | while path: 61 | item = path[0] 62 | last, current = current, await current.__getchild__(item) 63 | 64 | if current is None: 65 | return last, tuple(path) 66 | 67 | del path[0] 68 | 69 | return current, tuple(path) 70 | 71 | 72 | def lineage(resource): 73 | """ Return a generator representing the lineage 74 | of the resource object implied by the resource argument 75 | """ 76 | while resource is not None: 77 | yield resource 78 | resource = resource.__parent__ 79 | 80 | 81 | def find_root(resource): 82 | """ Find root resource 83 | """ 84 | return list(lineage(resource))[-1] 85 | -------------------------------------------------------------------------------- /examples/2-middleware.py: -------------------------------------------------------------------------------- 1 | """ 2 | REST-like requests counter application with JSON error responces. 3 | 4 | Start this: 5 | 6 | $ python3 2-middleware.py 7 | 8 | Or use aiohttp_devtools: 9 | 10 | $ adev runserver 2-middleware.py --app-factory create_app 11 | 12 | After start, check urls: 13 | 14 | * GET localhost:8000/ 15 | * GET localhost:8000/thing 16 | * POST localhost:8000/ 17 | """ 18 | import asyncio 19 | 20 | import aiohttp 21 | from aiohttp.web import Application, run_app 22 | 23 | from aiohttp_traversal.router import TraversalRouter 24 | from aiohttp_traversal.ext.views import RESTView 25 | from aiohttp_traversal.ext.resources import Root 26 | 27 | 28 | class Counter(RESTView): 29 | methods = {'get'} 30 | 31 | @asyncio.coroutine 32 | def get(self): 33 | self.request.app['counter'] += 1 34 | return dict(counter=self.request.app['counter']) 35 | 36 | 37 | async def json_error_middleware(app, handler): 38 | async def middleware_handler(request): 39 | try: 40 | resp = await handler(request) 41 | if isinstance(resp, aiohttp.web.HTTPException): 42 | raise resp 43 | except aiohttp.web.HTTPNoContent: 44 | raise 45 | except aiohttp.web.HTTPNotFound as exc: 46 | return error_response(404, 'not_found', exc.reason, exc.headers) 47 | except aiohttp.web.HTTPMethodNotAllowed as exc: 48 | return error_response(405, 'not_allowed', exc.reason, exc.headers) 49 | else: 50 | return resp 51 | return middleware_handler 52 | 53 | 54 | def error_response(status, error, reason, headers) -> aiohttp.web.Response: 55 | if headers is not None: 56 | headers.pop('Content-Type', None) 57 | return aiohttp.web.json_response( 58 | data={'error': error, 'reason': reason}, 59 | headers=headers, 60 | status=status, 61 | ) 62 | 63 | 64 | def create_app(): 65 | app = Application(router=TraversalRouter()) 66 | 67 | app.middlewares.append(json_error_middleware) 68 | 69 | app.router.set_root_factory(lambda request, app=app: Root(app)) 70 | app.router.bind_view(Root, Counter) 71 | 72 | app['counter'] = 0 73 | 74 | return app 75 | 76 | 77 | if __name__ == '__main__': 78 | app = create_app() 79 | run_app(app, port=8000) 80 | -------------------------------------------------------------------------------- /aiohttp_traversal/ext/resources.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from resolver_deco import resolver 5 | 6 | from aiohttp_traversal.abc import AbstractResource 7 | from aiohttp_traversal.traversal import Traverser 8 | 9 | log = logging.getLogger(__name__) 10 | 11 | 12 | class Resource(AbstractResource): 13 | """ Simple base resource. 14 | """ 15 | _parent = None 16 | name = None 17 | app = None 18 | setup = None 19 | 20 | def __init__(self, parent, name): 21 | self._parent = parent 22 | self.name = str(name) 23 | 24 | if parent is not None: 25 | self.app = parent.app 26 | self.setup = self.app.router.resources.get(self.__class__) 27 | 28 | @property 29 | def __parent__(self): 30 | return self._parent 31 | 32 | def __getitem__(self, name): 33 | return Traverser(self, (name,)) 34 | 35 | async def __getchild__(self, name): 36 | return None 37 | 38 | 39 | class InitCoroMixin: 40 | """ Mixin for create initialization coroutine. 41 | """ 42 | def __new__(cls, *args, **kwargs): # noqa 43 | """ This is magic! 44 | """ 45 | instance = super().__new__(cls) 46 | 47 | async def coro(): 48 | instance.__init__(*args, **kwargs) 49 | await instance.__ainit__() 50 | return instance 51 | 52 | return coro() 53 | 54 | async def __ainit__(self): 55 | raise NotImplementedError 56 | 57 | 58 | class DispatchMixin: 59 | 60 | async def __getchild__(self, name): 61 | if (self.setup is not None and 62 | 'children' in self.setup and 63 | name in self.setup['children']): 64 | 65 | res = self.setup['children'][name](self, name) 66 | 67 | if asyncio.iscoroutine(res): 68 | return await res 69 | else: 70 | return res 71 | else: 72 | return None 73 | 74 | 75 | class DispatchResource(DispatchMixin, Resource): 76 | pass 77 | 78 | 79 | class Root(DispatchResource): 80 | """ This root accept application instance, not request. 81 | 82 | Usage: 83 | 84 | app.router.set_root_factory(lambda request, app=app: Root(app)) 85 | """ 86 | def __init__(self, app, *args, **kwargs): 87 | super().__init__(parent=None, name=None) 88 | self.app = app 89 | self.args = args 90 | self.kwargs = kwargs 91 | self.setup = self.app.router.resources.get(self.__class__) 92 | 93 | 94 | @resolver('parent', 'child') 95 | def add_child(app, parent, name, child): 96 | """ Add child resource for dispatch-resources. 97 | """ 98 | if not issubclass(parent, DispatchMixin): 99 | raise ValueError("{!r} is not a DispatchMixin subclass" 100 | "".format(parent)) 101 | 102 | parent_setup = app.router.resources.setdefault(parent, {}) 103 | parent_setup.setdefault('children', {})[name] = child 104 | -------------------------------------------------------------------------------- /tests/ext/test_resources.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock, MagicMock 2 | import asyncio 3 | 4 | import pytest 5 | 6 | from aiohttp_traversal.traversal import Traverser 7 | from aiohttp_traversal.ext.resources import ( 8 | Resource, 9 | InitCoroMixin, 10 | Root, 11 | add_child, 12 | ) 13 | 14 | 15 | def test_Resource_init(): # noqa 16 | parent = Mock(name='parent', __parent__=None) 17 | app = MagicMock(name='app') 18 | app.router.resources = {} 19 | parent.request.app = parent.app = app 20 | name = 'name' 21 | 22 | res = Resource(parent, name) 23 | 24 | assert res.__parent__ is parent 25 | assert res.name == name 26 | assert res.app is parent.request.app 27 | assert res.setup is res.app.router.resources.get(Resource) 28 | 29 | 30 | def test_Resource_init__root(): # noqa 31 | name = 'root' 32 | 33 | res = Resource(None, name) 34 | 35 | assert res.__parent__ is None 36 | assert res.name == name 37 | assert res.app is None 38 | assert res.setup is None 39 | 40 | 41 | @pytest.fixture 42 | def res_simple(): 43 | parent = Mock(name='parent') 44 | parent.request.app = MagicMock(name='app') 45 | parent.app = parent.request.app 46 | name = 'name' 47 | 48 | return Resource(parent, name) 49 | 50 | 51 | def test_Resource_getitem(loop, res_simple): # noqa 52 | traverser = res_simple['a'] 53 | assert isinstance(traverser, Traverser) 54 | assert traverser.resource is res_simple 55 | assert traverser.path == ('a',) 56 | 57 | 58 | def test_Resource_getchild(loop, res_simple): # noqa 59 | assert loop.run_until_complete(res_simple.__getchild__('a')) is None 60 | 61 | 62 | def test_InitCoroMixin(loop): # noqa 63 | class Res(InitCoroMixin, Resource): 64 | calls_init = 0 65 | calls_ainit = 0 66 | 67 | def __init__(self, parent, name): 68 | super().__init__(parent, name) 69 | self.calls_init += 1 70 | 71 | @asyncio.coroutine 72 | def __ainit__(self): 73 | self.calls_ainit += 1 74 | 75 | coro = Res(None, 'name') 76 | 77 | assert asyncio.iscoroutine(coro) 78 | 79 | res = loop.run_until_complete(coro) 80 | 81 | assert isinstance(res, Res) 82 | assert res.calls_init == 1 83 | assert res.calls_ainit == 1 84 | 85 | 86 | def test_DispatchResource(loop, app): # noqa 87 | class Res(Resource): 88 | pass 89 | 90 | class CoroRes(InitCoroMixin, Resource): 91 | calls_ainit = 0 92 | 93 | async def __ainit__(self): 94 | self.calls_ainit += 1 95 | 96 | add_child(app, 'aiohttp_traversal.ext.resources.Root', 'simple', Res) 97 | add_child(app, Root, 'coro', CoroRes) 98 | 99 | request = MagicMock(name='request') 100 | request.app = app 101 | 102 | root = Root(app) 103 | 104 | res_simple = loop.run_until_complete(root['simple']) 105 | assert isinstance(res_simple, Res) 106 | assert res_simple.name == 'simple' 107 | assert res_simple.__parent__ is root 108 | 109 | res_coro = loop.run_until_complete(root['coro']) 110 | assert isinstance(res_coro, CoroRes) 111 | assert res_coro.name == 'coro' 112 | assert res_coro.__parent__ is root 113 | assert res_coro.calls_ainit == 1 114 | 115 | with pytest.raises(KeyError): 116 | loop.run_until_complete(root['not_exist']) 117 | -------------------------------------------------------------------------------- /tests/ext/test_views.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | import asyncio 3 | 4 | import pytest 5 | 6 | from aiohttp.web import Response 7 | from aiohttp.web import HTTPMethodNotAllowed 8 | 9 | from aiohttp_traversal.ext.views import ( 10 | View, 11 | MethodsView, 12 | RESTView, 13 | ) 14 | 15 | 16 | def test_View_init(): # noqa 17 | request = Mock(name='request') 18 | resource = Mock(name='resource') 19 | tail = ('ta', 'il') 20 | view = View(request, resource, tail) 21 | assert view.request is request 22 | assert view.resource is resource 23 | assert view.tail is tail 24 | 25 | 26 | @pytest.fixture 27 | def MVw(request): # noqa 28 | class MVw(MethodsView): 29 | methods = {'get', 'post'} 30 | 31 | @asyncio.coroutine 32 | def get(self): 33 | return 'data' 34 | 35 | return MVw 36 | 37 | 38 | def test_MethodsView_call(loop, MVw): # noqa 39 | request = Mock(name='request') 40 | request.method = 'GET' 41 | resource = Mock(name='resource') 42 | tail = ('ta', 'il') 43 | 44 | resp = loop.run_until_complete(MVw(request, resource, tail)()) 45 | assert resp == 'data' 46 | 47 | 48 | def test_MethodsView_call__not_implemented(loop, MVw): # noqa 49 | request = Mock(name='request') 50 | request.method = 'POST' 51 | resource = Mock(name='resource') 52 | tail = ('ta', 'il') 53 | 54 | with pytest.raises(NotImplementedError): 55 | loop.run_until_complete(MVw(request, resource, tail)()) 56 | 57 | 58 | def test_MethodsView_call__not_allowed(loop, MVw): # noqa 59 | request = Mock(name='request') 60 | request.method = 'DELETE' 61 | resource = Mock(name='resource') 62 | tail = ('ta', 'il') 63 | 64 | with pytest.raises(HTTPMethodNotAllowed): 65 | loop.run_until_complete(MVw(request, resource, tail)()) 66 | 67 | 68 | @pytest.fixture 69 | def RVw(): # noqa 70 | class RVw(RESTView): 71 | methods = {'get', 'post'} 72 | 73 | @asyncio.coroutine 74 | def get(self): 75 | return {'key': 'value'} 76 | 77 | @asyncio.coroutine 78 | def post(self): 79 | return Response() 80 | 81 | return RVw 82 | 83 | 84 | def test_RESTView__dict(loop, RVw): # noqa 85 | request = Mock(name='request') 86 | request.method = 'GET' 87 | resource = Mock(name='resource') 88 | tail = ('ta', 'il') 89 | 90 | resp = loop.run_until_complete(RVw(request, resource, tail)()) 91 | assert isinstance(resp, Response) 92 | 93 | 94 | def test_RESTView__response(loop, RVw): # noqa 95 | request = Mock(name='request') 96 | request.method = 'POST' 97 | resource = Mock(name='resource') 98 | tail = ('ta', 'il') 99 | 100 | resp = loop.run_until_complete(RVw(request, resource, tail)()) 101 | assert isinstance(resp, Response) 102 | 103 | 104 | @pytest.fixture 105 | def RVobj(): # noqa 106 | class RVobj(RESTView): 107 | methods = {'get', 'post'} 108 | 109 | def serialize(self, data): 110 | return data.upper().encode('utf8') 111 | 112 | @asyncio.coroutine 113 | def get(self): 114 | return "test" 115 | 116 | return RVobj 117 | 118 | 119 | def test_RESTView__object(loop, RVobj): # noqa 120 | request = Mock(name='request') 121 | request.method = 'GET' 122 | resource = Mock(name='resource') 123 | tail = ('ta', 'il') 124 | 125 | resp = loop.run_until_complete(RVobj(request, resource, tail)()) 126 | assert isinstance(resp, Response) 127 | assert resp.body == b'TEST' 128 | -------------------------------------------------------------------------------- /aiohttp_traversal/router.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import types 4 | from contextlib import contextmanager 5 | 6 | from aiohttp.abc import AbstractRouter, AbstractMatchInfo 7 | from aiohttp.web_exceptions import HTTPNotFound 8 | 9 | from resolver_deco import resolver 10 | from .traversal import traverse 11 | 12 | log = logging.getLogger(__name__) 13 | 14 | SIMPLE_VIEWS_TYPES = (types.FunctionType, types.CoroutineType) 15 | 16 | 17 | class ViewNotResolved(Exception): 18 | """ Raised from Application.resolve_view. 19 | """ 20 | def __init__(self, request, resource, tail): 21 | super().__init__(request, resource, tail) 22 | self.request = request 23 | self.resource = resource 24 | self.tail = tail 25 | 26 | 27 | class BaseMatchInfo(AbstractMatchInfo): 28 | route = None 29 | _current_app = None 30 | 31 | def __init__(self): 32 | self._apps = [] 33 | 34 | async def expect_handler(self, request): 35 | return None 36 | 37 | @property 38 | def http_exception(self): 39 | return None 40 | 41 | @property 42 | def apps(self): 43 | return self._apps 44 | 45 | def add_app(self, app): 46 | self._apps.append(app) 47 | 48 | def freeze(self): 49 | pass 50 | 51 | @property 52 | def current_app(self): 53 | return self._current_app 54 | 55 | @contextmanager 56 | def set_current_app(self, app): 57 | assert app in self._apps, ( 58 | "Expected one of the following apps {!r}, got {!r}" 59 | .format(self._apps, app)) 60 | prev = self._current_app 61 | self._current_app = app 62 | try: 63 | yield 64 | finally: 65 | self._current_app = prev 66 | 67 | 68 | class MatchInfo(BaseMatchInfo): 69 | def __init__(self, request, resource, tail, view): 70 | super().__init__() 71 | self.request = request 72 | self.resource = resource 73 | self.tail = tail 74 | self.view = view 75 | 76 | def handler(self, request): 77 | if isinstance(self.view, SIMPLE_VIEWS_TYPES): 78 | return self.view(self.request, self.resource, self.tail) 79 | else: 80 | return self.view() 81 | 82 | def get_info(self): 83 | return { 84 | 'request': self.request, 85 | 'resource': self.resource, 86 | 'tail': self.tail, 87 | 'view': self.view, 88 | } 89 | 90 | 91 | class TraversalExceptionMatchInfo(BaseMatchInfo): 92 | def __init__(self, request, exc): 93 | super().__init__() 94 | self.request = request 95 | self.exc = exc 96 | 97 | def handler(self, request): 98 | raise self.exc 99 | 100 | def get_info(self): 101 | return { 102 | 'request': self.request, 103 | 'exc': self.exc, 104 | } 105 | 106 | 107 | class TraversalRouter(AbstractRouter): 108 | _root_factory = None 109 | 110 | @resolver('root_factory') 111 | def __init__(self, root_factory=None): 112 | self.set_root_factory(root_factory) 113 | self.resources = {} 114 | self.exceptions = {} 115 | 116 | async def resolve(self, request): 117 | try: 118 | resource, tail = await self.traverse(request) 119 | exc = None 120 | except Exception as _exc: 121 | resource = None 122 | tail = None 123 | exc = _exc 124 | 125 | request.resource = resource 126 | request.tail = tail 127 | request.exc = exc 128 | 129 | if resource is not None: 130 | try: 131 | view = self.resolve_view(request, resource, tail) 132 | except ViewNotResolved: 133 | return TraversalExceptionMatchInfo(request, HTTPNotFound()) 134 | 135 | return MatchInfo(request, resource, tail, view) 136 | else: 137 | return TraversalExceptionMatchInfo(request, exc) 138 | 139 | async def traverse(self, request, *args, **kwargs): 140 | path = tuple(p for p in request.path.split('/') if p) 141 | root = self.get_root(request, *args, **kwargs) 142 | if path: 143 | return await traverse(root, path) 144 | else: 145 | return root, path 146 | 147 | @resolver('root_factory') 148 | def set_root_factory(self, root_factory): 149 | """ Set root resource class. 150 | 151 | Analogue of the "set_root_factory" method from pyramid framework. 152 | """ 153 | self._root_factory = root_factory 154 | 155 | def get_root(self, request, *args, **kwargs): 156 | """ Create new root resource instance. 157 | """ 158 | return self._root_factory(request, *args, **kwargs) 159 | 160 | @resolver('resource') 161 | def resolve_view(self, request, resource, tail=()): 162 | """ Resolve view for resource and tail. 163 | """ 164 | if isinstance(resource, type): 165 | resource_class = resource 166 | else: 167 | resource_class = resource.__class__ 168 | 169 | for rc in resource_class.__mro__[:-1]: 170 | if rc in self.resources: 171 | if 'views' not in self.resources[rc]: 172 | continue 173 | 174 | views = self.resources[rc]['views'] 175 | 176 | if tail in views: 177 | view = views[tail] 178 | break 179 | 180 | elif '*' in views: 181 | view = views['*'] 182 | break 183 | 184 | else: 185 | raise ViewNotResolved(request, resource, tail) 186 | 187 | if isinstance(view, SIMPLE_VIEWS_TYPES): 188 | return view 189 | else: 190 | return view(request, resource, tail) 191 | 192 | @resolver('resource', 'view') 193 | def bind_view(self, resource, view, tail=()): 194 | """ Bind view for resource. 195 | """ 196 | if isinstance(tail, str) and tail != '*': 197 | tail = tuple(i for i in tail.split('/') if i) 198 | 199 | setup = self.resources.setdefault(resource, {'views': {}}) 200 | setup.setdefault('views', {})[tail] = view 201 | 202 | def __repr__(self): 203 | return "<{}>".format(self.__class__.__name__) 204 | -------------------------------------------------------------------------------- /tests/test_router.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | 3 | import pytest 4 | from aiohttp.web_exceptions import HTTPNotFound 5 | 6 | from aiohttp_traversal.traversal import find_root, lineage 7 | from aiohttp_traversal.router import ( 8 | MatchInfo, 9 | ViewNotResolved, 10 | TraversalExceptionMatchInfo, 11 | ) 12 | 13 | 14 | def test_repr(router): 15 | repr(router) 16 | 17 | 18 | @pytest.fixture 19 | def request(app): 20 | request = Mock(name='request') 21 | request.app = app 22 | request.path = '/' 23 | return request 24 | 25 | 26 | def test_resolve(loop, router, request, root): 27 | request.path = '/a/b/c' 28 | 29 | async def traverse(request): 30 | return ('res', 'tail') 31 | 32 | def resolve_view(req, res, tail): 33 | async def view_call(): 34 | return 'view_result' 35 | 36 | mock = Mock(name='view') 37 | mock.name = 'view' 38 | mock.return_value = view_call() 39 | return mock 40 | 41 | router.traverse = traverse 42 | router.resolve_view = resolve_view 43 | mi = loop.run_until_complete(router.resolve(request)) 44 | 45 | assert isinstance(mi, MatchInfo) 46 | assert mi.view.name == 'view' 47 | assert mi.route is None 48 | 49 | result = loop.run_until_complete(mi.handler(request)) 50 | assert result == 'view_result' 51 | 52 | 53 | def test_resolve__not_found(loop, router, request, root): 54 | request.path = '/a/b/c' 55 | 56 | async def traverse(request): 57 | return ('res', 'tail') 58 | 59 | def resolve_view(req, res, tail): 60 | raise ViewNotResolved(req, res, tail) 61 | 62 | router.traverse = traverse 63 | router.resolve_view = resolve_view 64 | 65 | mi = loop.run_until_complete(router.resolve(request)) 66 | 67 | with pytest.raises(HTTPNotFound): 68 | mi.handler(request) 69 | 70 | 71 | def test_resolve__exception(loop, router, request, root): 72 | request.path = '/a/b/c' 73 | 74 | async def traverse(request): 75 | raise ValueError() 76 | 77 | router.traverse = traverse 78 | mi = loop.run_until_complete(router.resolve(request)) 79 | 80 | assert isinstance(mi, TraversalExceptionMatchInfo) 81 | assert isinstance(mi.exc, ValueError) 82 | assert mi.route is None 83 | 84 | with pytest.raises(ValueError): 85 | loop.run_until_complete(mi.handler(request)) 86 | 87 | 88 | def test_traverse(loop, router, request): 89 | request.path = '/a/b/c' 90 | 91 | res, tail = loop.run_until_complete(router.traverse(request)) 92 | 93 | assert res.name == 'c' 94 | assert not tail 95 | assert len(list(lineage(res))) == 4 96 | assert find_root(res).name == 'ROOT' 97 | 98 | 99 | def test_traverse_with_tail(loop, router, request, ): 100 | request.path = '/a/b/not/c' 101 | 102 | res, tail = loop.run_until_complete(router.traverse(request)) 103 | 104 | assert res.name == 'b' 105 | assert tail == ('not', 'c') 106 | assert len(list(lineage(res))) == 3 107 | assert find_root(res).name == 'ROOT' 108 | 109 | 110 | def test_traverse_root(loop, router, request): 111 | request.path = '/' 112 | 113 | res, tail = loop.run_until_complete(router.traverse(request)) 114 | 115 | assert tail == () 116 | assert len(list(lineage(res))) == 1 117 | assert find_root(res) is res 118 | assert res.name == 'ROOT' 119 | 120 | 121 | def test_traverse_root_with_tail(loop, router, request): 122 | request.path = '/not/c' 123 | 124 | res, tail = loop.run_until_complete(router.traverse(request)) 125 | 126 | assert tail == ('not', 'c') 127 | assert len(list(lineage(res))) == 1 128 | assert find_root(res).name == 'ROOT' 129 | 130 | 131 | def test_set_root_factory(router): 132 | assert router._root_factory 133 | new_root_class = Mock(name='root') 134 | router.set_root_factory(new_root_class) 135 | assert router._root_factory is new_root_class 136 | 137 | 138 | def test_get_root(router, app): 139 | assert router.get_root(app).name == 'ROOT' 140 | 141 | 142 | @pytest.fixture 143 | def Res(): # noqa 144 | return type('res', (), {}) 145 | 146 | 147 | @pytest.fixture 148 | def View(): # noqa 149 | class View: 150 | def __init__(self, request, resource, tail): 151 | self.request = request 152 | self.resource = resource 153 | self.tail = tail 154 | 155 | def __call__(self): 156 | return 'response' 157 | 158 | return View 159 | 160 | 161 | def test_resolve_view(router, Res, View): # noqa 162 | res = Res() 163 | tail = ('a', 'b') 164 | router.resources[Res] = {'views': {tail: View}} 165 | 166 | view = router.resolve_view(None, res, tail) 167 | 168 | assert isinstance(view, View) 169 | assert view.resource is res 170 | 171 | 172 | def test_resolve_view_simple(router, Res): # noqa 173 | 174 | async def view(): 175 | pass 176 | 177 | res = Res() 178 | tail = ('a', 'b') 179 | router.resources[Res] = {'views': {tail: view}} 180 | 181 | result_view = router.resolve_view(None, res, tail) 182 | 183 | assert result_view is view 184 | 185 | 186 | def test_resolve_view__asterisk(router, Res, View): # noqa 187 | res = Res() 188 | router.resources[Res] = {'views': {'*': View}} 189 | 190 | view = router.resolve_view(None, res, ('a', 'b')) 191 | 192 | assert isinstance(view, View) 193 | assert view.resource is res 194 | 195 | 196 | def test_resolve_view__mro(router, Res, View): # noqa 197 | class SubRes(Res): 198 | pass 199 | 200 | res = SubRes() 201 | router.resources[Res] = {'views': {'*': View}} 202 | 203 | view = router.resolve_view(None, res, '*') 204 | 205 | assert isinstance(view, View) 206 | assert view.resource is res 207 | 208 | 209 | def test_resolve_view__mro_invert(router, Res, View): # noqa 210 | class SubRes(Res): 211 | pass 212 | 213 | res = Res() 214 | router.resources[SubRes] = {'views': {'*': View}} 215 | 216 | with pytest.raises(ViewNotResolved): 217 | router.resolve_view(None, res, '*') 218 | 219 | 220 | def test_resolve_view__not_resolved(router): 221 | with pytest.raises(ViewNotResolved): 222 | router.resolve_view(None, str, ()) 223 | 224 | 225 | def test_bind_view(router, Res, View): # noqa 226 | router.bind_view(Res, View) 227 | assert router.resources[Res]['views'][()] is View 228 | 229 | 230 | def test_bind_view__tail_str(router, Res, View): # noqa 231 | router.bind_view(Res, View, '/a/b') 232 | assert router.resources[Res]['views'][('a', 'b')] is View 233 | 234 | 235 | def test_bind_view__tail_str_asterisk(router, Res, View): # noqa 236 | router.bind_view(Res, View, '*') 237 | assert router.resources[Res]['views']['*'] is View 238 | 239 | 240 | def test_match_info(router, View): # noqa 241 | view = View('request', 'resource', 'tail') 242 | mi = MatchInfo('request', 'resource', 'tail', view) 243 | response = mi.handler('request') 244 | assert response == 'response' 245 | 246 | 247 | def test_match_info__simple_view(router): 248 | def view(request, resource, tail): 249 | assert request == 'request' 250 | assert resource == 'resource' 251 | assert tail == 'tail' 252 | return 'response' 253 | 254 | mi = MatchInfo('request', 'resource', 'tail', view) 255 | response = mi.handler('request') 256 | assert response == 'response' 257 | --------------------------------------------------------------------------------