├── tests ├── __init__.py ├── integration │ ├── __init__.py │ ├── test_service.py │ └── test_registry.py ├── factories.py ├── conftest.py └── test_registry.py ├── requirements ├── __init__.py ├── test.txt └── base.txt ├── trellio ├── utils │ ├── __init__.py │ ├── async_utils.py │ ├── jsonencoder.py │ ├── ordered_class_member.py │ ├── helpers.py │ ├── log_handlers.py │ ├── decorators.py │ ├── stats.py │ └── log.py ├── conf_manager │ ├── __init__.py │ └── conf_client.py ├── management │ ├── commands │ │ ├── __init__.py │ │ ├── base.py │ │ └── start.py │ ├── exceptions.py │ ├── __init__.py │ └── core.py ├── protocol_factory.py ├── manage.py ├── wrappers.py ├── sendqueue.py ├── signals.py ├── exceptions.py ├── __init__.py ├── views.py ├── jsonprotocol.py ├── pinger.py ├── pubsub.py ├── packet.py ├── registry_client.py ├── host.py ├── bus.py ├── registry.py └── services.py ├── requirements.txt ├── README.md ├── Dockerfile ├── LICENSE ├── .gitignore ├── .travis.yml └── setup.py /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /trellio/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /trellio/conf_manager/__init__.py: -------------------------------------------------------------------------------- 1 | from .conf_client import * 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/base.txt 2 | -r requirements/test.txt -------------------------------------------------------------------------------- /trellio/management/commands/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .start import * 3 | -------------------------------------------------------------------------------- /trellio/management/exceptions.py: -------------------------------------------------------------------------------- 1 | class InvalidCMDArguments(Exception): 2 | pass 3 | -------------------------------------------------------------------------------- /trellio/management/__init__.py: -------------------------------------------------------------------------------- 1 | from .commands import * 2 | from .core import * 3 | from .exceptions import * 4 | -------------------------------------------------------------------------------- /trellio/utils/async_utils.py: -------------------------------------------------------------------------------- 1 | def run_safe_thread(): 2 | pass 3 | 4 | 5 | def run_in_unsafe_thread(): 6 | pass 7 | -------------------------------------------------------------------------------- /trellio/protocol_factory.py: -------------------------------------------------------------------------------- 1 | from .jsonprotocol import TrellioProtocol 2 | 3 | 4 | def get_trellio_protocol(handler): 5 | return TrellioProtocol(handler) 6 | 
-------------------------------------------------------------------------------- /trellio/manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import sys 3 | 4 | from .management.core import execute_from_command_line 5 | 6 | if __name__ == '__main__': 7 | execute_from_command_line(sys.argv) 8 | -------------------------------------------------------------------------------- /requirements/test.txt: -------------------------------------------------------------------------------- 1 | factory-boy==2.5.2 2 | fake-factory==0.5.2 3 | flake8==3.3.0 4 | py==1.4.33 5 | pycodestyle==2.3.1 6 | pyflakes==1.5.0 7 | pytest==3.0.7 8 | pytest-asyncio==0.5.0 9 | requests==2.13.0 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/artificilabs/trellio.svg?branch=master)](https://travis-ci.org/artificilabs/trellio) 2 | 3 | # trellio 4 | 5 | Python 3 asyncio based micro-framework for micro-service architecture 6 | -------------------------------------------------------------------------------- /trellio/wrappers.py: -------------------------------------------------------------------------------- 1 | from aiohttp.web import Request as Req, Response as Res 2 | 3 | 4 | class Request(Req): 5 | """ 6 | Wraps the aiohttp request object to hide it from user 7 | """ 8 | pass 9 | 10 | 11 | class Response(Res): 12 | """ 13 | Wraps the aiohttp response object to hide it from user 14 | """ 15 | pass 16 | -------------------------------------------------------------------------------- /requirements/base.txt: -------------------------------------------------------------------------------- 1 | again==1.2.21 2 | aiohttp==2.0.2 3 | appdirs==1.4.3 4 | async-retrial==0.7 5 | async-timeout==1.2.0 6 | asyncio-redis==0.14.2 7 | cchardet==1.1.3 8 | jsonstreamer==1.3.6 9 | mccabe==0.6.1 10 | multidict==2.1.4 11 | packaging==16.8 12 | pep8==1.7.0 13 | pyparsing==2.2.0 14 | python-json-logger==0.1.7 15 | PyYAML==3.12 16 | setproctitle==1.1.10 17 | six==1.10.0 18 | uvloop==0.8.0 19 | yarl==0.10.0 20 | -------------------------------------------------------------------------------- /tests/integration/test_service.py: -------------------------------------------------------------------------------- 1 | from trellio import TCPService, api, Host 2 | 3 | 4 | class ServiceTest(TCPService): 5 | def __init__(self): 6 | super().__init__('test_service', '1') 7 | 8 | @api 9 | async def echo(self): 10 | return "echo" 11 | 12 | 13 | if __name__ == "__main__": 14 | test_service = ServiceTest() 15 | Host.configure(tcp_port=8001) 16 | 17 | Host.attach_tcp_service(test_service) 18 | Host.run() 19 | -------------------------------------------------------------------------------- /trellio/management/core.py: -------------------------------------------------------------------------------- 1 | class ManagementCommandNotFound(Exception): 2 | pass 3 | 4 | 5 | def execute_from_command_line(args): # sys args 6 | command_name = args[1] 7 | command_args = args[2:] 8 | from trellio.management.commands.base import ManagementRegistry 9 | if not ManagementRegistry.get(command_name): 10 | raise ManagementCommandNotFound 11 | command_class = ManagementRegistry.get(command_name) 12 | command_class(command_args).run() 13 | -------------------------------------------------------------------------------- /trellio/utils/jsonencoder.py: 
-------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | from time import mktime 4 | 5 | 6 | class TrellioEncoder(json.JSONEncoder): 7 | """ 8 | json dump encoder class 9 | """ 10 | 11 | def default(self, obj): 12 | """ 13 | convert datetime instances to integer epoch timestamps 14 | """ 15 | if isinstance(obj, datetime.datetime): 16 | return int(mktime(obj.timetuple())) 17 | return json.JSONEncoder.default(self, obj) 18 | -------------------------------------------------------------------------------- /trellio/management/commands/base.py: -------------------------------------------------------------------------------- 1 | class ManagementCommand: 2 | name = '' 3 | 4 | def run(self): 5 | raise NotImplementedError 6 | 7 | 8 | class ManagementRegistry: 9 | _management_reg = {} 10 | 11 | @classmethod 12 | def register(cls, command_cls): 13 | if issubclass(command_cls, ManagementCommand): 14 | cls._management_reg[command_cls.name] = command_cls 15 | 16 | @classmethod 17 | def get(cls, name): 18 | return cls._management_reg.get(name) 19 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.6 2 | 3 | RUN apt-get -y update && apt-get install -y redis-server cmake git 4 | 5 | RUN export LD_LIBRARY_PATH=$HOME/.local/lib/:$LD_LIBRARY_PATH 6 | RUN git clone --depth=1 https://github.com/lloyd/yajl.git 7 | WORKDIR /yajl/ 8 | RUN ./configure --prefix=$HOME/.local/ 9 | RUN cmake -DCMAKE_INSTALL_PREFIX=$HOME/.local/ && make && make install 10 | RUN ln -s /yajl/yajl-2.1.1/lib/libyajl.so.2.1.1 /usr/lib/x86_64-linux-gnu/libyajl.so 11 | 12 | EXPOSE 4500 13 | 14 | RUN pip install trellio 15 | CMD ["python","-m","trellio.registry"] -------------------------------------------------------------------------------- /trellio/sendqueue.py: -------------------------------------------------------------------------------- 1 | class SendQueue: 2 | """ 3 | Queues packets to send when transport can send 4 | """ 5 | 6 | def __init__(self, transport, can_send_func=lambda: True, pre_process_func=lambda x: x): 7 | self._q = [] 8 | self._transport = transport 9 | self._can_send = can_send_func 10 | self._pre_process = pre_process_func 11 | 12 | def send(self, packet=None): 13 | if packet: 14 | self._q.append(packet) 15 | if self._can_send(): 16 | for each in self._q: 17 | each = self._pre_process(each) 18 | self._transport.write(each) 19 | self._q.clear() 20 | -------------------------------------------------------------------------------- /tests/factories.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import factory 4 | 5 | logging.getLogger("factory").setLevel(logging.WARN) 6 | 7 | 8 | class ServiceFactory(factory.DictFactory): 9 | name = factory.Sequence(lambda n: "service_%d" % n) 10 | version = "1.0.0" 11 | dependencies = factory.List([]) 12 | events = factory.List([]) 13 | host = factory.Sequence(lambda n: "192.168.0.%d" % n) 14 | port = factory.Sequence(lambda n: 4000 + n) 15 | node_id = factory.Sequence(lambda n: "node_%d" % n) 16 | type = 'tcp' 17 | 18 | 19 | class EndpointFactory(factory.DictFactory): 20 | endpoint = factory.Sequence(lambda n: "endpoint_%d" % n) 21 | strategy = factory.Iterator(['DESIGNATION', 'RANDOM']) 22 | -------------------------------------------------------------------------------- /trellio/utils/ordered_class_member.py:
-------------------------------------------------------------------------------- 1 | import collections 2 | 3 | 4 | class OrderedClassMembers(type): 5 | @classmethod 6 | def __prepare__(self, name, bases): 7 | return collections.OrderedDict() 8 | 9 | def __new__(self, name, bases, classdict): 10 | classdict['__ordered__'] = [key for key in classdict.keys() if # __classcell__ passed to __new__ 11 | key not in ('__module__', '__qualname__', '__classcell__')] 12 | 13 | for each in bases: # add keys from base classes also 14 | odict = getattr(each, '__ordered__', each.__dict__) 15 | 16 | classdict['__ordered__'] = [key for key in odict if key not in ('__module__', '__qualname__')] + \ 17 | classdict['__ordered__'] 18 | 19 | return type.__new__(self, name, bases, classdict) 20 | -------------------------------------------------------------------------------- /trellio/signals.py: -------------------------------------------------------------------------------- 1 | class InvalidSignalType(Exception): 2 | pass 3 | 4 | 5 | class BaseSignal: 6 | _registry_list = [] 7 | 8 | @classmethod 9 | def register(cls, to_register, soft=False): 10 | cls._registry_list.append([to_register, soft]) 11 | 12 | @classmethod 13 | async def _run(cls, *args, **kwargs): 14 | print('service ready run called') 15 | for i in cls._registry_list: 16 | try: 17 | await i[0](*args, **kwargs) 18 | except Exception as e: 19 | if not i[1]: 20 | raise e 21 | 22 | 23 | class ServiceReady(BaseSignal): 24 | _registry_list = [] 25 | 26 | 27 | def register(signal_type, soft=False): 28 | def decorator(receiver): 29 | if not issubclass(signal_type, BaseSignal): 30 | raise InvalidSignalType 31 | signal_type.register(receiver, soft) 32 | return receiver 33 | 34 | return decorator 35 | -------------------------------------------------------------------------------- /trellio/exceptions.py: -------------------------------------------------------------------------------- 1 | class TrellioServiceException(Exception): 2 | """ 3 | To be subclassed by service level exceptions and indicate exceptions that 4 | are to be handled at the service level itself. 5 | These exceptions shall not be counted as errors at the macroscopic level. 6 | eg: record not found, invalid parameter etc. 7 | """ 8 | 9 | 10 | class TrellioServiceError(Exception): 11 | """ 12 | Unlike TrellioServiceExceptions these will be counted as errors and must only 13 | be used when a service encounters an error it couldn't handle at its level. 14 | eg: client not found, database disconnected. 
15 | """ 16 | 17 | 18 | class TrellioException(Exception): 19 | pass 20 | 21 | 22 | class RequestException(Exception): 23 | pass 24 | 25 | 26 | class ClientException(Exception): 27 | pass 28 | 29 | 30 | class ClientNotFoundError(ClientException): 31 | pass 32 | 33 | 34 | class ClientDisconnected(ClientException): 35 | pass 36 | 37 | 38 | class AlreadyRegistered(Exception): 39 | pass 40 | -------------------------------------------------------------------------------- /trellio/utils/helpers.py: -------------------------------------------------------------------------------- 1 | from ..wrappers import Response 2 | 3 | 4 | class Borg(object): 5 | __shared_state = dict() 6 | 7 | def __init__(self): 8 | self.__dict__ = self.__shared_state 9 | 10 | 11 | class Singleton(object): 12 | _instance = None 13 | _init_ran = False 14 | 15 | def has_inited(self): 16 | return self._init_ran 17 | 18 | def init_done(self): 19 | self._init_ran = True 20 | 21 | def __new__(cls, *args, **kwargs): 22 | if cls._instance is None: 23 | cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs) 24 | return cls._instance 25 | 26 | 27 | def default_preflight_response(request): 28 | headers = {'Access-Control-Allow-Origin': '*', 29 | 'Access-Control-Allow-Methods': 'GET,POST,PUT,DELETE', 30 | 'Access-Control-Allow-Headers': 'Access-Control-Allow-Headers, Origin,Accept, X-Requested-With, Content-Type, Authorization, Access-Control-Request-Method, Access-Control-Request-Headers', 31 | 'Access-Control-Allow-Credentials': 'true'} 32 | return Response(status=204, headers=headers) 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 technomaniac 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # PyInstaller 3 | # Usually these files are written by a python script from a template 4 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
5 | *.manifest 6 | *.spec 7 | 8 | # Installer logs 9 | pip-log.txt 10 | pip-delete-this-directory.txt 11 | 12 | # Unit test / coverage reports 13 | htmlcov/ 14 | .tox/ 15 | .coverage 16 | .coverage.* 17 | .cache 18 | nosetests.xml 19 | coverage.xml 20 | *,cover 21 | 22 | ======= 23 | .hypothesis/ 24 | 25 | # Translations 26 | *.mo 27 | *.pot 28 | 29 | # Django stuff: 30 | 31 | local_settings.py 32 | 33 | # Flask stuff: 34 | instance/ 35 | .webassets-cache 36 | 37 | # Scrapy stuff: 38 | .scrapy 39 | 40 | # Sphinx documentation 41 | docs/_build/ 42 | 43 | # PyBuilder 44 | target/ 45 | 46 | .idea/ 47 | # IPython Notebook 48 | .ipynb_checkpoints 49 | 50 | # pyenv 51 | .python-version 52 | 53 | # celery beat schedule file 54 | celerybeat-schedule 55 | 56 | # dotenv 57 | .env 58 | 59 | # virtualenv 60 | venv/ 61 | ENV/ 62 | 63 | *.orig 64 | # Spyder project settings 65 | .spyderproject 66 | 67 | # Rope project settings 68 | .ropeproject 69 | 70 | dist/ 71 | trellio.egg-info/ 72 | __pycache__/ 73 | trellio/__pycache__/ 74 | trellio/utils/__pycache__/ 75 | *.pyc 76 | 77 | build/ -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | sudo: false 3 | 4 | python: 5 | - '3.5' 6 | - '3.6' 7 | 8 | services: 9 | - redis-server 10 | 11 | before_install: 12 | - export LD_LIBRARY_PATH=$HOME/.local/lib/:$LD_LIBRARY_PATH 13 | - git clone --depth=1 https://github.com/lloyd/yajl.git 14 | - cd yajl && ./configure --prefix=$HOME/.local/ 15 | - cmake -DCMAKE_INSTALL_PREFIX=$HOME/.local/ && make && make install 16 | - cd .. 17 | 18 | install: pip install -r requirements.txt --upgrade 19 | script: 20 | - py.test --ignore tests/integration --assert=plain 21 | - py.test tests/integration --assert=plain 22 | 23 | deploy: 24 | provider: pypi 25 | user: technomaniac 26 | password: 27 | secure: I8I7k7Idzl2z7rbtmO8ITtG3AX5nmdTTe2EqRixEmZQBxLXKLG6kF9EynwojycaFDfHfmrhoUubuAhFz0Q4P1SRqtqaZWRJdqjMLAL5iqpjTnrKIfshNQYzWkEJ7u4V7prQ5lJbG/sPgGiYCBeZjAT6J7XShbsZ9tK3IiJUHH5B7/Lsu83sbuSOVdlhyqumMWkHflselisEfYA9ngR8IkKffr8uPIgg+O90t8Fa3uz4Z9CXSViZxe2t7lZXVntIjHRGJbwTUzCJqbWWljetl6HJQcgCwIN7ORpSxZ9K6m0VOKkYafD2u2XCo+p3BDnQc4Ovs3BAyadFKRG+WKi4fsHK9lldD2E+e3jaR8/EbjHJb5qcFWtnvLC0J72pCRXuKRn5BlWjo6zMQ4UPDu/kh93k2Y21YKqhZ6ch3fVAwEs7j0RLYVhS7TBPKoRcX+lfsjojCO/NKG63UC/xzKhUSc2sR4DwGlBp1Kzo3qTUvalmy+ry2xNSC1fFs2bYgVcjSP4n1aR/vJ40q/cCOaDQDpuHL+JEaSxiHar0IvHVF5OuUt9hsxhhGU/LlmW8cZaplNCrsRlYQns8qsBEEyJ3I+xEdrVRmOWrNolxyGjkD1MFr9BoXr4MK9pQjrVBHRZEm43YFAkgkODtdC/hmExyTrJg4brG9JKxDKwTmW5bHaak= 28 | on: 29 | tags: true 30 | branch: master 31 | -------------------------------------------------------------------------------- /trellio/utils/log_handlers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.handlers 3 | import smtplib 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class BufferingSMTPHandler(logging.handlers.BufferingHandler): 9 | def __init__(self, mailhost, mailport, fromaddr, toaddrs, subject, capacity, password): 10 | logging.handlers.BufferingHandler.__init__(self, capacity) 11 | self.mailhost = mailhost 12 | self.mailport = mailport 13 | self.fromaddr = fromaddr 14 | self.subject = subject 15 | self.toaddrs = toaddrs 16 | self.setFormatter(logging.Formatter("%(asctime)s %(levelname)-5s %(message)s")) 17 | self.smtp = smtplib.SMTP(self.mailhost, self.mailport) 18 | self.smtp.ehlo() 19 | self.smtp.starttls() 20 | 
self.smtp.login(fromaddr, password) 21 | 22 | def handleError(self, record): 23 | print(record) # overriding original and not doing anything 24 | 25 | def flush(self): 26 | if len(self.buffer) > 0: 27 | try: 28 | msg = "From: %s\r\nTo: %s\r\nSubject: %s\r\n\r\n" % ( 29 | self.fromaddr, ",".join(self.toaddrs), self.subject) 30 | for record in self.buffer: 31 | s = self.format(record) 32 | msg = msg + s + "\r\n" 33 | self.smtp.sendmail(self.fromaddr, self.toaddrs, msg) 34 | except Exception as e: 35 | self.handleError(e) # no particular record 36 | self.buffer = [] 37 | -------------------------------------------------------------------------------- /trellio/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['Host', 'TCPServiceClient', 'TCPService', 'HTTPService', 'HTTPServiceClient', 'api', 'request', 'subscribe', 2 | 'publish', 'xsubscribe', 'get', 'post', 'head', 'put', 'patch', 'delete', 'options', 'trace', 3 | 'RequestException', 'Response', 'Request', 'log', 'setup_logging', 'apideprecated', 4 | 'TrellioServiceException', 'TrellioServiceError', 'ConfigHandler', 'ManagementCommand', 'HTTPView', 5 | 'TCPView', 'ManagementRegistry', 'InvalidCMDArguments', 'execute_from_command_line', 'Publisher', 6 | 'Subscriber'] 7 | 8 | import logging 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | try: 13 | import asyncio 14 | import uvloop 15 | 16 | asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) 17 | except ImportError: 18 | logger.warning('uvloop is not installed, event loop will be set to the default asyncio loop') 19 | 20 | from .conf_manager import * 21 | from .exceptions import RequestException, TrellioServiceError, TrellioServiceException # noqa 22 | from .host import Host # noqa 23 | from .management import * 24 | from .pubsub import Publisher, Subscriber 25 | from .services import (TCPService, HTTPService, HTTPServiceClient, TCPServiceClient) # noqa 26 | from .services import (api, request, subscribe, publish, xsubscribe, apideprecated) # noqa 27 | from .services import (get, post, head, put, patch, delete, options, trace) # noqa 28 | from .utils import log # noqa 29 | from .utils.log import setup_logging # noqa 30 | from .views import HTTPView, TCPView 31 | from .wrappers import Response, Request # noqa 32 | 33 | __version__ = '1.1.34' 34 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ast 3 | import re 4 | from os import getcwd, path 5 | 6 | from pip.download import PipSession 7 | from pip.req import parse_requirements 8 | from setuptools import setup, find_packages 9 | 10 | if not path.dirname(__file__): # setup.py without /path/to/ 11 | _dirname = getcwd() # /path/to/ 12 | else: 13 | _dirname = path.dirname(path.dirname(__file__)) 14 | 15 | 16 | def read(name, default=None, debug=True): 17 | try: 18 | filename = path.join(_dirname, name) 19 | with open(filename) as f: 20 | return f.read() 21 | except Exception as e: 22 | err = "%s: %s" % (type(e), str(e)) 23 | if debug: 24 | print(err) 25 | return default 26 | 27 | 28 | def lines(name): 29 | txt = read(name) 30 | return map( 31 | lambda l: l.lstrip().rstrip(), 32 | filter(lambda t: not t.startswith('#'), txt.splitlines() if txt else []) 33 | ) 34 | 35 | 36 | install_reqs = parse_requirements("./requirements/base.txt", session=PipSession()) 37 | install_requires = [str(ir.req).split('==')[0] for ir in
install_reqs] 38 | 39 | with open('trellio/__init__.py', 'rb') as i: 40 | version = str(ast.literal_eval(re.compile(r'__version__\s+=\s+(.*)').search(i.read().decode('utf-8')).group(1))) 41 | 42 | setup( 43 | name='trellio', 44 | packages=find_packages(exclude=['examples', 'tests']), 45 | version=version, 46 | description='Python3 asyncio based micro-framework for micro-service architecture', 47 | author='Abhishek Verma, Nirmal Singh', 48 | author_email='ashuverma1989@gmail.com, nirmal.singh.cer08@itbhu.ac.in', 49 | url='https://github.com/artificilabs/trellio.git', 50 | keywords=['asyncio', 'microservice', 'microframework', 'aiohttp'], 51 | package_data={'requirements': ['*.txt']}, 52 | install_requires=install_requires 53 | ) 54 | -------------------------------------------------------------------------------- /trellio/management/commands/start.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from ...conf_manager.conf_client import ConfigHandler 4 | from ...management.commands.base import ManagementCommand, ManagementRegistry 5 | from ...management.exceptions import InvalidCMDArguments 6 | 7 | 8 | class TrellioHostCommand(ManagementCommand): 9 | ''' 10 | usage:python trellio.py start_service <(optional)service_file_path> 11 | ''' 12 | 13 | name = 'runserver' 14 | 15 | def parse_args(self, args): 16 | new_args = {} 17 | for ind, arg in enumerate(args): 18 | if '=' in arg: 19 | broken = arg.split('=') 20 | new_args[broken[0]] = broken[1] 21 | elif not new_args.get('config') and ind == 0: 22 | new_args['config'] = arg 23 | if not new_args.get('config'): 24 | new_args['config'] = os.path.abspath('./config.json') 25 | return new_args 26 | 27 | def __init__(self, args): 28 | from trellio.host import Host 29 | self.args = self.parse_args(args) 30 | self.host_class = Host 31 | self.config_manager = ConfigHandler(self.host_class) 32 | self.setup_config() 33 | 34 | def setup_config(self): 35 | try: 36 | self.config_manager.set_config(config_path=self.args['config']) 37 | except KeyError: 38 | raise InvalidCMDArguments('Please give config.json path!') 39 | 40 | def setup(self): 41 | self.setup_config() 42 | self.setup_environment_variables() 43 | self.config_manager.setup_host() 44 | 45 | def setup_environment_variables(self): 46 | pass # todo not needed right now 47 | 48 | def run(self): 49 | self.setup() 50 | self.host_class._smpt_handler = self.config_manager.get_smtp_logging_handler() 51 | self.host_class.run() 52 | 53 | 54 | ManagementRegistry.register(TrellioHostCommand) 55 | -------------------------------------------------------------------------------- /trellio/utils/decorators.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from asyncio.coroutines import iscoroutine, coroutine 3 | from asyncio.tasks import sleep 4 | from functools import wraps 5 | 6 | 7 | def deprecated(func): 8 | """ 9 | Generates a deprecation warning 10 | """ 11 | 12 | @wraps(func) 13 | def wrapper(*args, **kwargs): 14 | msg = "'{}' is deprecated".format(func.__name__) 15 | warnings.warn(msg, category=DeprecationWarning, stacklevel=2) 16 | return func(*args, **kwargs) 17 | 18 | return wrapper 19 | 20 | 21 | def retry(exceptions, tries=5, delay=1, backoff=2, logger=None): 22 | """ 23 | Retry calling the decorated function using an exponential backoff. 24 | 25 | Args: 26 | exceptions: The exception to check. may be a tuple of 27 | exceptions to check. 
28 | tries: Number of times to try (not retry) before giving up. 29 | delay: Initial delay between retries in seconds. 30 | backoff: Backoff multiplier (e.g. value of 2 will double the delay 31 | each retry). 32 | logger: Logger to use. If None, retries are not logged. 33 | """ 34 | 35 | def deco_retry(func): 36 | @wraps(func) 37 | async def f_retry(self, *args, **kwargs): 38 | if not iscoroutine(func): 39 | f = coroutine(func) 40 | else: 41 | f = func 42 | 43 | mtries, mdelay = tries, delay 44 | while mtries > 1: 45 | try: 46 | return await f(self, *args, **kwargs) 47 | except exceptions: 48 | if logger: 49 | logger.info('Retrying %s after %s seconds', f.__name__, mdelay) 50 | await sleep(mdelay) 51 | mtries -= 1 52 | mdelay *= backoff 53 | return await f(self, *args, **kwargs) 54 | 55 | return f_retry 56 | 57 | return deco_retry 58 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from trellio.registry import Registry, Repository 4 | from .factories import ServiceFactory, EndpointFactory 5 | 6 | 7 | @pytest.fixture 8 | def registry(): 9 | r = Registry(ip='192.168.1.1', port=4001, repository=Repository()) 10 | return r 11 | 12 | 13 | @pytest.fixture 14 | def service(*args, **kwargs): 15 | return ServiceFactory(*args, **kwargs) 16 | 17 | 18 | @pytest.fixture 19 | def dependencies_from_services(*services): 20 | return [{'name': service['name'], 'version': service['version']} for service in services] 21 | 22 | 23 | def endpoints_for_service(service, n): 24 | endpoints = [] 25 | for _ in range(n): 26 | endpoint = EndpointFactory() 27 | endpoint['name'] = service['name'] 28 | endpoint['version'] = service['version'] 29 | endpoints.append(endpoint) 30 | return endpoints 31 | 32 | 33 | @pytest.fixture 34 | def service_a1(): 35 | return ServiceFactory() 36 | 37 | 38 | @pytest.fixture 39 | def service_a2(service_a1): 40 | return ServiceFactory(service=service_a1['name'], version='1.0.1') 41 | 42 | 43 | @pytest.fixture 44 | def service_b1(service_a1): 45 | return ServiceFactory(dependencies=dependencies_from_services(service_a1)) 46 | 47 | 48 | @pytest.fixture 49 | def service_b2(service_b1): 50 | return ServiceFactory( 51 | service=service_b1['name'], version='1.0.1', dependencies=dependencies_from_services(service_b1)) 52 | 53 | 54 | @pytest.fixture 55 | def service_c1(service_a1, service_b1): 56 | return ServiceFactory(dependencies=dependencies_from_services(service_a1, service_b1)) 57 | 58 | 59 | @pytest.fixture 60 | def service_c2(service_a2, service_b2, service_c1): 61 | return ServiceFactory(service=service_c1['name'], version='1.0.1', 62 | dependencies=dependencies_from_services(service_a2, service_b2)) 63 | 64 | 65 | @pytest.fixture 66 | def service_c3(service_a2, service_b2, service_c2): 67 | return ServiceFactory(service=service_c2['name'], version='1.0.1', 68 | dependencies=dependencies_from_services(service_a2, service_b2)) 69 | 70 | 71 | @pytest.fixture 72 | def service_d1(service_a1): 73 | return ServiceFactory(events=endpoints_for_service(service_a1, 1)) 74 | -------------------------------------------------------------------------------- /trellio/views.py: -------------------------------------------------------------------------------- 1 | __all__ = ['HTTPView', 'TCPView'] 2 | 3 | from again.utils import unique_hex 4 | 5 | from .utils.helpers import default_preflight_response 6 | from .utils.ordered_class_member import OrderedClassMembers 7 | 8 |
9 | class BaseView: 10 | '''base class for views''' 11 | _host = None 12 | 13 | @property 14 | def host(self): 15 | return self._host 16 | 17 | @host.setter 18 | def host(self, host): 19 | self._host = host 20 | 21 | 22 | class BaseHTTPView(BaseView, metaclass=OrderedClassMembers): 23 | '''base class for HTTP views''' 24 | middlewares = [] 25 | 26 | def __init__(self): 27 | super(BaseHTTPView, self).__init__() 28 | 29 | 30 | class HTTPView(BaseHTTPView): 31 | def __init__(self, allow_cross_domain=True, 32 | preflight_response=default_preflight_response): 33 | super(HTTPView, self).__init__() 34 | self._allow_cross_domain = allow_cross_domain 35 | self._preflight_response = preflight_response 36 | 37 | @property 38 | def cross_domain_allowed(self): 39 | return self._allow_cross_domain 40 | 41 | @property 42 | def preflight_response(self): 43 | return self._preflight_response 44 | 45 | 46 | class BaseTCPView(BaseView): 47 | '''base class for TCP views''' 48 | 49 | def __init__(self): 50 | super(BaseTCPView, self).__init__() 51 | 52 | 53 | class TCPView(BaseTCPView): 54 | def __init__(self): 55 | super(TCPView, self).__init__() 56 | 57 | @staticmethod 58 | def _make_response_packet(request_id: str, from_id: str, entity: str, result: object, error: object, 59 | failed: bool, old_api=None, replacement_api=None): 60 | from .services import _Service 61 | if failed: 62 | payload = {'request_id': request_id, 'error': error, 'failed': failed} 63 | else: 64 | payload = {'request_id': request_id, 'result': result} 65 | if old_api: 66 | payload['old_api'] = old_api 67 | if replacement_api: 68 | payload['replacement_api'] = replacement_api 69 | packet = {'pid': unique_hex(), 70 | 'to': from_id, 71 | 'entity': entity, 72 | 'type': _Service._RES_PKT_STR, 73 | 'payload': payload} 74 | return packet 75 | -------------------------------------------------------------------------------- /tests/integration/test_registry.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | 3 | import requests 4 | from aiohttp.web_response import json_response 5 | 6 | from trellio import * 7 | from trellio.registry import Repository, Registry 8 | 9 | processes = [] 10 | 11 | 12 | class ServiceA(TCPService): 13 | def __init__(self): 14 | super().__init__("service_a", "1", host_port=8001, host_ip='0.0.0.0') 15 | 16 | @api 17 | async def echo(self, data): 18 | return data 19 | 20 | 21 | class ServiceClientA(TCPServiceClient): 22 | def __init__(self): 23 | super(ServiceClientA, self).__init__("raven", "1") 24 | 25 | @request 26 | def echo(self, data, test=True): 27 | return locals() 28 | 29 | 30 | class ServiceB(HTTPService): 31 | def __init__(self): 32 | super().__init__("service_b", "1", host_port=8000, host_ip='0.0.0.0') 33 | self._client_a = ServiceClientA() 34 | 35 | @get(path="/{data}/") 36 | async def get_echo(self, request): 37 | data = request.match_info.get('data') 38 | d = await self._client_a.echo(data) 39 | return json_response(text=d) 40 | 41 | 42 | class ServiceBTCP(TCPService): 43 | def __init__(self): 44 | super().__init__("service_b", "1", host_port=8003, host_ip='0.0.0.0') 45 | self._client_a = ServiceClientA() 46 | 47 | @api 48 | async def get_echo(self, data): 49 | data = request.match_info.get('data') 50 | d = await self._client_a.echo(data) 51 | return d 52 | 53 | 54 | def start_registry(): 55 | repository = Repository() 56 | registry = Registry('0.0.0.0', 4500, repository) 57 | registry.start() 58 | 59 | 60 | def start_service_a(): 61 | 
Host.configure(http_port=8002, tcp_port=8001, tcp_host='0.0.0.0', service_name="service_a", service_version="1") 62 | Host.attach_tcp_service(ServiceA()) 63 | Host.run() 64 | 65 | 66 | def start_service_b(): 67 | client_a = ServiceClientA() 68 | service_b = ServiceB() 69 | tcp_service = ServiceBTCP() 70 | 71 | Host.configure(http_port=8000, tcp_port=8003, http_host='0.0.0.0', service_name="service_b", service_version="1", 72 | registry_host='0.0.0.0', registry_port=4500) 73 | 74 | service_b.clients = [client_a] 75 | tcp_service.clients = [client_a] 76 | Host.attach_http_service(service_b) 77 | Host.attach_tcp_service(tcp_service) 78 | Host.run() 79 | 80 | 81 | def setup_module(): 82 | global processes 83 | for target in [start_registry, start_service_a, start_service_b]: 84 | p = multiprocessing.Process(target=target) 85 | p.start() 86 | processes.append(p) 87 | 88 | # allow the subsystems to start up. 89 | # sleep for awhile 90 | import time 91 | time.sleep(1) 92 | 93 | 94 | def restart_service_a(): 95 | processes[1].terminate() 96 | p = multiprocessing.Process(target=start_service_a) 97 | p.start() 98 | processes[1] = p 99 | import time 100 | time.sleep(1) 101 | 102 | 103 | def teardown_module(): 104 | for p in processes: 105 | p.terminate() 106 | 107 | 108 | def test_service_b(): 109 | url = 'http://0.0.0.0:8000/blah/' 110 | r = requests.get(url) 111 | assert r.text == 'blah' 112 | assert r.status_code == 200 113 | 114 | 115 | if __name__ == "__main__": 116 | setup_module() 117 | # restart_service_a() 118 | test_service_b() 119 | -------------------------------------------------------------------------------- /tests/test_registry.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from unittest import mock 3 | 4 | 5 | def service_registered_successfully(registry, *services): 6 | for service in services: 7 | service_entry = ( 8 | service['host'], service['port'], service['node_id'], service['type']) 9 | try: 10 | entry = registry._repository._registered_services[ 11 | service['name']][service['version']] 12 | assert service_entry in entry 13 | except KeyError: 14 | raise 15 | return True 16 | 17 | 18 | def no_pending_services(registry): 19 | return len(registry._repository.get_pending_services()) == 0 20 | 21 | 22 | def instance_returned_successfully(response, service): 23 | instance = ( 24 | service['host'], service['port'], service['node_id'], service['type']) 25 | for returned_instance in response['params']['instances']: 26 | t = ( 27 | returned_instance['host'], returned_instance['port'], returned_instance['node'], returned_instance['type']) 28 | if instance == t: 29 | return True 30 | 31 | return False 32 | 33 | 34 | def subscriber_returned_successfully(response, service): 35 | service_t = (service['host'], service['port'], service['node_id'], service['name'], service['version']) 36 | for s in response['params']['subscribers']: 37 | subscriber_t = (s['host'], s['port'], s['node_id'], s['name'], s['version']) 38 | 39 | if service_t == subscriber_t: 40 | return True 41 | return False 42 | 43 | 44 | def test_register_independent_service(registry, service_a1): 45 | registry.register_service( 46 | packet={'params': service_a1}, registry_protocol=mock.Mock()) 47 | 48 | assert service_registered_successfully(registry, service_a1) 49 | assert no_pending_services(registry) 50 | 51 | 52 | def test_register_dependent_service(registry, service_a1, service_b1): 53 | registry.register_service( 54 | packet={'params': service_b1}, 
registry_protocol=mock.Mock()) 55 | assert not no_pending_services(registry) 56 | 57 | registry.register_service( 58 | packet={'params': service_a1}, registry_protocol=mock.Mock()) 59 | assert no_pending_services(registry) 60 | 61 | assert service_registered_successfully(registry, service_a1, service_b1) 62 | 63 | 64 | def test_deregister_dependent_service(service_a1, service_b1, registry): 65 | registry.register_service( 66 | packet={'params': service_b1}, registry_protocol=mock.Mock()) 67 | registry.register_service( 68 | packet={'params': service_a1}, registry_protocol=mock.Mock()) 69 | 70 | assert no_pending_services(registry) 71 | 72 | registry.deregister_service(service_a1['host'], service_a1['port'], service_a1['node_id']) 73 | assert not no_pending_services(registry) 74 | 75 | 76 | def test_get_instances(service_a1, registry): 77 | registry.register_service( 78 | packet={'params': service_a1}, registry_protocol=mock.Mock()) 79 | 80 | protocol = mock.Mock() 81 | registry.get_service_instances( 82 | packet={'params': service_a1, 'request_id': str(uuid.uuid4())}, registry_protocol=protocol) 83 | 84 | assert instance_returned_successfully( 85 | protocol.send.call_args_list[0][0][0], service_a1) 86 | 87 | 88 | def test_xsubscribe(service_a1, service_d1, registry): 89 | # assert service_d1 == {} 90 | registry.register_service( 91 | packet={'params': service_a1}, registry_protocol=mock.Mock()) 92 | registry.register_service( 93 | packet={'params': service_d1}, registry_protocol=mock.Mock()) 94 | registry._xsubscribe(packet={'params': service_d1}) 95 | 96 | protocol = mock.Mock() 97 | params = { 98 | 'name': service_a1['name'], 99 | 'version': service_a1['version'], 100 | 'endpoint': service_d1['events'][0]['endpoint'] 101 | } 102 | registry.get_subscribers(packet={'params': params, 'request_id': str(uuid.uuid4())}, protocol=protocol) 103 | assert subscriber_returned_successfully(protocol.send.call_args_list[0][0][0], service_d1) 104 | -------------------------------------------------------------------------------- /trellio/jsonprotocol.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | 5 | from jsonstreamer import ObjectStreamer 6 | 7 | from .sendqueue import SendQueue 8 | from .utils.jsonencoder import TrellioEncoder 9 | 10 | 11 | class JSONProtocol(asyncio.Protocol): 12 | logger = logging.getLogger(__name__) 13 | 14 | def __init__(self): 15 | self._send_q = None 16 | self._connected = False 17 | self._transport = None 18 | self._obj_streamer = None 19 | self._pending_data = [] 20 | 21 | @staticmethod 22 | def _make_frame(packet): 23 | string = json.dumps(packet, cls=TrellioEncoder) + ',' 24 | return string.encode() 25 | 26 | def is_connected(self): 27 | return self._connected 28 | 29 | def _write_pending_data(self): 30 | for packet in self._pending_data: 31 | frame = self._make_frame(packet) 32 | self._transport.write(frame) 33 | self._pending_data.clear() 34 | 35 | def connection_made(self, transport): 36 | self._connected = True 37 | self._transport = transport 38 | try: 39 | self._transport.send = self._transport.write 40 | except: 41 | pass 42 | self._send_q = SendQueue(transport, self.is_connected) 43 | self.set_streamer() 44 | self._send_q.send() 45 | 46 | def set_streamer(self): 47 | self._obj_streamer = ObjectStreamer() 48 | self._obj_streamer.auto_listen(self, prefix='on_') 49 | self._obj_streamer.consume('[') 50 | 51 | def connection_lost(self, exc): 52 | self._connected = False 53 | 
self.logger.info('Peer closed %s', self._transport.get_extra_info('peername')) 54 | 55 | def send(self, packet): 56 | frame = self._make_frame(packet) 57 | self._send_q.send(frame) 58 | self.logger.debug('Data sent: %s', frame.decode()) 59 | 60 | def close(self): 61 | self._transport.write(']'.encode()) # end the json array 62 | self._transport.close() 63 | 64 | def data_received(self, byte_data): 65 | string_data = byte_data.decode() 66 | self.logger.debug('Data received: %s', string_data) 67 | try: 68 | self._obj_streamer.consume(string_data) 69 | except: 70 | # recover from invalid data 71 | self.logger.exception('Invalid data received') 72 | self.set_streamer() 73 | 74 | def on_object_stream_start(self): 75 | raise RuntimeError('Incorrect JSON Streaming Format: expect a JSON Array to start at root, got object') 76 | 77 | def on_object_stream_end(self): 78 | del self._obj_streamer 79 | raise RuntimeError('Incorrect JSON Streaming Format: expect a JSON Array to end at root, got object') 80 | 81 | def on_array_stream_start(self): 82 | self.logger.debug('Array Stream started') 83 | 84 | def on_array_stream_end(self): 85 | del self._obj_streamer 86 | self.logger.debug('Array Stream ended') 87 | 88 | def on_pair(self, pair): 89 | self.logger.debug('Pair {}'.format(pair)) 90 | raise RuntimeError('Received a key-value pair object - expected elements only') 91 | 92 | 93 | class TrellioProtocol(JSONProtocol): 94 | def __init__(self, handler): 95 | super(TrellioProtocol, self).__init__() 96 | self._handler = handler 97 | 98 | def connection_made(self, transport): 99 | peer_name = transport.get_extra_info('peername') 100 | self.logger.info('Connection from %s', peer_name) 101 | super(TrellioProtocol, self).connection_made(transport) 102 | 103 | def connection_lost(self, exc): 104 | super(TrellioProtocol, self).connection_lost(exc) 105 | try: 106 | self._handler._handle_connection_lost() 107 | except: 108 | pass 109 | # self.logger.exception(str(e)) 110 | 111 | def on_element(self, element): 112 | try: 113 | self._handler.receive(packet=element, protocol=self, transport=self._transport) 114 | except: 115 | # ignore any unhandled errors raised by handler 116 | self.logger.exception('api request exception') 117 | pass 118 | -------------------------------------------------------------------------------- /trellio/utils/stats.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import socket 4 | from collections import defaultdict, deque 5 | 6 | import setproctitle 7 | 8 | 9 | class Stats: 10 | name = None 11 | hostname = socket.gethostbyname(socket.gethostname()) 12 | service_name = '_'.join(setproctitle.getproctitle().split('_')[1:-1]) 13 | http_stats = {'total_requests': 0, 'total_responses': 0, 'timedout': 0, 'total_errors': 0} 14 | tcp_stats = {'total_requests': 0, 'total_responses': 0, 'timedout': 0, 'total_errors': 0} 15 | 16 | @classmethod 17 | def periodic_stats_logger(cls): 18 | logd = defaultdict(lambda: 0) 19 | logd['hostname'] = cls.hostname 20 | logd['service_name'] = cls.name 21 | 22 | for key, value in cls.http_stats.items(): 23 | logd[key] += value 24 | logd['http_' + key] = value 25 | 26 | for key, value in cls.tcp_stats.items(): 27 | logd[key] += value 28 | logd['tcp_' + key] = value 29 | 30 | _logger = logging.getLogger('stats') 31 | _logger.info(dict(logd)) 32 | 33 | asyncio.get_event_loop().call_later(120, cls.periodic_stats_logger) 34 | 35 | 36 | class StatUnit: 37 | MAXSIZE = 10 38 | 39 | def 
__init__(self, key=None): 40 | self.key = key 41 | self.average = 0 42 | self.values = deque() 43 | self.count = 0 44 | self.success_count = 0 45 | self.sub = dict() 46 | 47 | def update(self, val, success): 48 | self.values.append(val) 49 | if len(self.values) > self.MAXSIZE: 50 | self.values.popleft() 51 | 52 | self.average = sum(self.values) / len(self.values) 53 | self.count += 1 54 | if success: 55 | self.success_count += 1 56 | 57 | def to_dict(self): 58 | 59 | d = dict({'count': self.count, 'average': self.average, 'success_count': self.success_count, 'sub': dict()}) 60 | for k, v in self.sub.items(): 61 | d['sub'][k] = v.to_dict() 62 | return d 63 | 64 | def __str__(self): 65 | return "{} {} {} {}".format(self.key, self.sum, self.average, self.count) 66 | 67 | 68 | class Aggregator: 69 | _stats = StatUnit(key='total') 70 | _service_name = None 71 | 72 | @classmethod 73 | def recursive_update(cls, d, new_val, keys, success): 74 | if len(keys) == 0: 75 | return 76 | 77 | try: 78 | key = keys.pop() 79 | value = d[key] 80 | 81 | except KeyError: 82 | value = StatUnit(key=key) 83 | d[key] = value 84 | 85 | finally: 86 | value.update(new_val, success) 87 | cls.recursive_update(value.sub, new_val, keys, success) 88 | 89 | @classmethod 90 | def update_stats(cls, endpoint, status, time_taken, server_type, success=True): 91 | 92 | cls._stats.update(val=time_taken, success=success) 93 | cls.recursive_update(cls._stats.sub, time_taken, keys=[status, endpoint, server_type], success=success) 94 | 95 | @classmethod 96 | def dump_stats(cls): 97 | return cls._stats.to_dict() 98 | 99 | @classmethod 100 | def periodic_aggregated_stats_logger(cls): 101 | hostname = socket.gethostbyname(socket.gethostname()) 102 | 103 | logd = cls._stats.to_dict() 104 | logs = [] 105 | for server_type in ['http', 'tcp']: 106 | try: 107 | server_type_d = logd['sub'][server_type]['sub'] 108 | except KeyError: 109 | continue 110 | for k, v in server_type_d.items(): 111 | d = dict({ 112 | 'method': k, 113 | 'server_type': server_type, 114 | 'hostname': hostname, 115 | 'service_name': cls._service_name, 116 | 'average_response_time': v['average'], 117 | 'total_request_count': v['count'], 118 | 'success_count': v['success_count'] 119 | }) 120 | for k2, v2 in v['sub'].items(): 121 | d['CODE_{}'.format(k2)] = v2['count'] 122 | logs.append(d) 123 | 124 | _logger = logging.getLogger('stats') 125 | for logd in logs: 126 | _logger.info(dict(logd)) 127 | 128 | asyncio.get_event_loop().call_later(300, cls.periodic_aggregated_stats_logger) 129 | -------------------------------------------------------------------------------- /trellio/pinger.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import functools 3 | import logging 4 | 5 | from aiohttp import request 6 | 7 | from .packet import ControlPacket 8 | 9 | PING_TIMEOUT = 10 10 | PING_INTERVAL = 5 11 | 12 | 13 | class Pinger: 14 | """ 15 | Pinger to send ping packets to an endpoint and inform if the timeout has occurred 16 | """ 17 | 18 | def __init__(self, handler, interval, timeout, loop=None, max_failures=5): 19 | """ 20 | Aysncio based pinger 21 | :param handler: Pinger uses it to send a ping and inform when timeout occurs. 
22 | Must implement send_ping() and on_timeout() methods 23 | :param int interval: time interval between ping after a pong 24 | :param loop: Optional event loop 25 | """ 26 | 27 | self._handler = handler 28 | self._interval = interval 29 | self._timeout = timeout 30 | self._loop = loop or asyncio.get_event_loop() 31 | self._timer = None 32 | self._failures = 0 33 | 34 | self._max_failures = max_failures 35 | self.logger = logging.getLogger() 36 | 37 | @asyncio.coroutine 38 | def send_ping(self, payload=None): 39 | """ 40 | Sends the ping after the interval specified when initializing 41 | """ 42 | yield from asyncio.sleep(self._interval) 43 | self._handler.send_ping(payload=payload) 44 | self._start_timer(payload=payload) 45 | 46 | def pong_received(self, payload=None): 47 | """ 48 | Called when a pong is received. So the timer is cancelled 49 | """ 50 | if self._timer is not None: 51 | self._timer.cancel() 52 | self._failures = 0 53 | asyncio.async(self.send_ping(payload=payload)) 54 | 55 | def _start_timer(self, payload=None): 56 | self._timer = self._loop.call_later(self._timeout, functools.partial(self._on_timeout, payload=payload)) 57 | 58 | def stop(self): 59 | if self._timer is not None: 60 | self._timer.cancel() 61 | 62 | def _on_timeout(self, payload=None): 63 | if self._failures < self._max_failures: 64 | self._failures += 1 65 | asyncio.ensure_future(self.send_ping(payload=payload)) 66 | else: 67 | self._handler.on_timeout() 68 | 69 | 70 | class TCPPinger: 71 | logger = logging.getLogger(__name__) 72 | 73 | def __init__(self, host, port, node_id, protocol, handler): 74 | self._host = host 75 | self._port = port 76 | self._pinger = Pinger(self, PING_INTERVAL, PING_TIMEOUT) 77 | self._node_id = node_id 78 | self._protocol = protocol 79 | self._handler = handler 80 | 81 | def ping(self, payload=None): 82 | asyncio.ensure_future(self._pinger.send_ping(payload=payload)) 83 | 84 | def send_ping(self, payload=None): 85 | self._protocol.send(ControlPacket.ping(self._node_id, payload=payload)) 86 | 87 | def on_timeout(self): 88 | self.logger.debug('%s timed out', self._node_id) 89 | # Dummy packet to cleanly close transport 90 | self._protocol._transport.write( 91 | '{"closed":"true", "type":"closed", "service":"none", "version":"none"}'.encode()) 92 | self._protocol.close() 93 | self._handler.on_timeout(self._host, self._port, self._node_id) 94 | 95 | def stop(self): 96 | self._pinger.stop() 97 | 98 | def pong_received(self, payload=None): 99 | self._pinger.pong_received(payload=payload) 100 | 101 | 102 | class HTTPPinger: 103 | def __init__(self, host, port, node_id, handler): 104 | self._host = host 105 | self._port = port 106 | self._pinger = Pinger(self, PING_INTERVAL, PING_TIMEOUT) 107 | self._node_id = node_id 108 | self._handler = handler 109 | self._url = 'http://{}:{}/ping'.format(host, port) 110 | self.logger = logging.getLogger(__name__) 111 | 112 | def ping(self, payload=None): 113 | asyncio.ensure_future(self._pinger.send_ping(payload=payload)) 114 | 115 | def send_ping(self, payload=None): 116 | asyncio.ensure_future(self.ping_coroutine(payload=payload)) 117 | 118 | def ping_coroutine(self, payload=None): 119 | try: 120 | res = yield from request('get', self._url) 121 | if res.status == 200: 122 | self.pong_received(payload=payload) 123 | res.close() 124 | except Exception: 125 | self.logger.exception('Error while ping') 126 | 127 | def stop(self): 128 | self._pinger.stop() 129 | 130 | def on_timeout(self): 131 | self.logger.warn('%s timed out', self._node_id) 132 | 
self._handler.on_timeout(self._host, self._port, self._node_id) 133 | 134 | def pong_received(self, payload=None): 135 | self._pinger.pong_received(payload=payload) 136 | -------------------------------------------------------------------------------- /trellio/pubsub.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | 5 | import asyncio_redis as redis 6 | 7 | from .utils.helpers import Borg 8 | from .utils.jsonencoder import TrellioEncoder 9 | 10 | 11 | class PubSub: 12 | """ 13 | Pub sub handler which uses redis. 14 | Can be used to publish an event or subscribe to a list of endpoints 15 | """ 16 | 17 | def __init__(self, redis_host, redis_port): 18 | """ 19 | Create in instance of Pub Sub handler 20 | :param str redis_host: Redis Host address 21 | :param redis_port: Redis port number 22 | """ 23 | self._redis_host = redis_host 24 | self._redis_port = redis_port 25 | self._conn = None 26 | self._logger = logging.getLogger(__name__) 27 | 28 | async def connect(self): 29 | """ 30 | Connect to the redis server and return the connection 31 | :return: 32 | """ 33 | self._conn = await self._get_conn() 34 | return self._conn 35 | 36 | async def publish(self, endpoint: str, payload: str): 37 | """ 38 | Publish to an endpoint. 39 | :param str endpoint: Key by which the endpoint is recognised. 40 | Subscribers will use this key to listen to events 41 | :param str payload: Payload to publish with the event 42 | :return: A boolean indicating if the publish was successful 43 | """ 44 | if self._conn is not None: 45 | try: 46 | await self._conn.publish(endpoint, payload) 47 | return True 48 | except redis.Error as e: 49 | self._logger.error('Publish failed with error %s', repr(e)) 50 | return False 51 | 52 | async def subscribe(self, endpoints: list, handler): 53 | """ 54 | Subscribe to a list of endpoints 55 | :param endpoints: List of endpoints the subscribers is interested to subscribe to 56 | :type endpoints: list 57 | :param handler: The callback to call when a particular event is published. 58 | Must take two arguments, a channel to which the event was published 59 | and the payload. 
60 | :return: 61 | """ 62 | connection = await self._get_conn() 63 | subscriber = await connection.start_subscribe() 64 | await subscriber.subscribe(endpoints) 65 | while True: 66 | payload = await subscriber.next_published() 67 | handler(payload.channel, payload.value) 68 | 69 | async def _get_conn(self): 70 | return await redis.Connection.create(self._redis_host, self._redis_port, auto_reconnect=True) 71 | 72 | 73 | class Publisher(Borg): 74 | def __init__(self, service_name, service_version, pubsub_host, pubsub_port): 75 | super(Publisher, self).__init__() 76 | self._service_name = service_name 77 | self._service_version = service_version 78 | self._host = pubsub_host 79 | self._port = pubsub_port 80 | self._pubsub_handler = None 81 | 82 | @property 83 | def service_name(self): 84 | return self._service_name 85 | 86 | @property 87 | def service_version(self): 88 | return self._service_version 89 | 90 | async def create_pubsub_handler(self): 91 | self._pubsub_handler = PubSub(self._host, self._port) 92 | await self._pubsub_handler.connect() 93 | 94 | def _publish(self, endpoint, payload): 95 | channel = self._get_pubsub_channel(endpoint) 96 | asyncio.async(self._pubsub_handler.publish(channel, json.dumps(payload, cls=TrellioEncoder))) 97 | 98 | def _get_pubsub_channel(self, endpoint): 99 | return '/'.join((self.service_name, str(self.service_version), endpoint)) 100 | 101 | 102 | class Subscriber: 103 | def __init__(self, service_name, service_version, pubsub_host=None, pubsub_port=None): 104 | self._service_name = service_name 105 | self._service_version = service_version 106 | self._pubsub_host = pubsub_host 107 | self._pubsub_port = pubsub_port 108 | self._pubsub_handler = None 109 | 110 | @property 111 | def service_name(self): 112 | return self._service_name 113 | 114 | @property 115 | def service_version(self): 116 | return self._service_version 117 | 118 | @property 119 | def pubsub_host(self): 120 | return self._pubsub_host 121 | 122 | @pubsub_host.setter 123 | def pubsub_host(self, pubsub_host): 124 | self._pubsub_host = pubsub_host 125 | 126 | @property 127 | def pubsub_port(self): 128 | return self._pubsub_port 129 | 130 | @pubsub_port.setter 131 | def pubsub_port(self, pubsub_port): 132 | self._pubsub_port = pubsub_port 133 | 134 | async def create_pubsub_handler(self): 135 | self._pubsub_handler = PubSub(self.pubsub_host, self.pubsub_port) 136 | await self._pubsub_handler.connect() 137 | 138 | def _get_pubsub_channel(self, endpoint): 139 | return '/'.join((self.service_name, str(self.service_version), endpoint)) 140 | 141 | async def register_for_subscription(self): 142 | subscription_list = [] 143 | for each in dir(self.__class__): 144 | fn = getattr(self.__class__, each) 145 | if callable(fn) and getattr(fn, 'is_subscribe', False): 146 | subscription_list.append(self._get_pubsub_channel(fn.__name__)) 147 | await self._pubsub_handler.subscribe(subscription_list, handler=self.subscription_handler) 148 | 149 | def subscription_handler(self, channel, payload): 150 | service, version, endpoint = channel.split('/') 151 | func = getattr(self, endpoint) 152 | asyncio.async(func(**json.loads(payload))) 153 | -------------------------------------------------------------------------------- /trellio/packet.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from uuid import uuid4 3 | 4 | 5 | class _Packet: 6 | _pid = 0 7 | 8 | @classmethod 9 | def _next_pid(cls): 10 | from uuid import uuid4 11 | 12 | return 
str(uuid4()) 13 | 14 | @classmethod 15 | def ack(cls, request_id): 16 | return {'pid': cls._next_pid(), 'type': 'ack', 'request_id': request_id} 17 | 18 | @classmethod 19 | def pong(cls, node_id, payload=None): 20 | if payload: 21 | return cls._get_ping_pong(node_id, 'pong', payload=payload) 22 | return cls._get_ping_pong(node_id, 'pong') 23 | 24 | @classmethod 25 | def ping(cls, node_id, payload=None): 26 | if payload: 27 | return cls._get_ping_pong(node_id, 'ping', payload=payload) 28 | return cls._get_ping_pong(node_id, 'ping') 29 | 30 | @classmethod 31 | def _get_ping_pong(cls, node_id, packet_type, payload=None): 32 | if payload: 33 | return {'pid': cls._next_pid(), 'type': packet_type, 'node_id': node_id, 'payload': payload} 34 | return {'pid': cls._next_pid(), 'type': packet_type, 'node_id': node_id} 35 | 36 | 37 | class ControlPacket(_Packet): 38 | @classmethod 39 | def registration(cls, ip: str, port: int, node_id, name: str, version: str, dependencies, service_type: str): 40 | v = [{'name': dependency.name, 'version': dependency.version} for dependency in dependencies] 41 | 42 | params = {'name': name, 43 | 'version': version, 44 | 'host': ip, 45 | 'port': port, 46 | 'node_id': node_id, 47 | 'dependencies': v, 48 | 'type': service_type} 49 | 50 | packet = {'pid': cls._next_pid(), 'type': 'register', 'params': params} 51 | return packet 52 | 53 | @classmethod 54 | def get_instances(cls, name, version): 55 | params = {'name': name, 'version': version} 56 | packet = {'pid': cls._next_pid(), 57 | 'type': 'get_instances', 58 | 'name': name, 59 | 'version': version, 60 | 'params': params, 61 | 'request_id': str(uuid4())} 62 | 63 | return packet 64 | 65 | @classmethod 66 | def get_subscribers(cls, name, version, endpoint): 67 | params = {'name': name, 'version': version, 'endpoint': endpoint} 68 | packet = {'pid': cls._next_pid(), 69 | 'type': 'get_subscribers', 70 | 'params': params, 71 | 'request_id': str(uuid4())} 72 | return packet 73 | 74 | @classmethod 75 | def send_instances(cls, name, version, request_id, instances): 76 | instance_packet = [{'host': host, 'port': port, 'node': node, 'type': service_type} for 77 | host, port, node, service_type in instances] 78 | instance_packet_params = {'name': name, 'version': version, 'instances': instance_packet} 79 | return {'pid': cls._next_pid(), 'type': 'instances', 'params': instance_packet_params, 'request_id': request_id} 80 | 81 | @classmethod 82 | # TODO : fix parsing on client side 83 | def deregister(cls, name, version, node_id): 84 | params = {'node_id': node_id, 'name': name, 'version': version} 85 | packet = {'pid': cls._next_pid(), 'type': 'deregister', 'params': params} 86 | return packet 87 | 88 | @classmethod 89 | def activated(cls, instances): 90 | dependencies = [] 91 | for k, v in instances.items(): 92 | dependency = defaultdict(list) 93 | dependency['name'] = k[0] 94 | dependency['version'] = k[1] 95 | for host, port, node, service_type in v: 96 | dependency_node_packet = { 97 | 'host': host, 98 | 'port': port, 99 | 'node_id': node, 100 | 'type': service_type 101 | } 102 | dependency['addresses'].append(dependency_node_packet) 103 | dependencies.append(dependency) 104 | params = { 105 | 'dependencies': dependencies 106 | } 107 | packet = {'pid': cls._next_pid(), 108 | 'type': 'registered', 109 | 'params': params} 110 | return packet 111 | 112 | @classmethod 113 | def xsubscribe(cls, name, version, host, port, node_id, endpoints): 114 | params = {'name': name, 'version': version, 'host': host, 'port': port, 'node_id': 
node_id} 115 | events = [{'name': _name, 'version': _version, 'endpoint': endpoint, 'strategy': strategy} for 116 | _name, _version, endpoint, strategy in endpoints] 117 | params['events'] = events 118 | packet = {'pid': cls._next_pid(), 119 | 'type': 'xsubscribe', 120 | 'params': params} 121 | return packet 122 | 123 | @classmethod 124 | def subscribers(cls, name, version, endpoint, request_id, subscribers): 125 | params = {'name': name, 'version': version, 'endpoint': endpoint} 126 | subscribers = [{'name': _name, 'version': _version, 'host': host, 'port': port, 'node_id': node_id, 127 | 'strategy': strategy} for _name, _version, host, port, node_id, strategy in subscribers] 128 | params['subscribers'] = subscribers 129 | packet = {'pid': cls._next_pid(), 130 | 'request_id': request_id, 131 | 'type': 'subscribers', 132 | 'params': params} 133 | return packet 134 | 135 | @classmethod 136 | def uptime(cls, uptimes): 137 | packet = {'pid': cls._next_pid(), 138 | 'type': 'uptime_report', 139 | 'params': dict(uptimes)} 140 | return packet 141 | 142 | @classmethod 143 | def new_instance(cls, name, version, host, port, node_id, service_type): 144 | params = {'name': name, 'version': version, 'host': host, 'port': port, 'node_id': node_id, 145 | 'service_type': service_type} 146 | return {'pid': cls._next_pid(), 147 | 'type': 'new_instance', 148 | 'params': params} 149 | 150 | 151 | class MessagePacket(_Packet): 152 | @classmethod 153 | def request(cls, name, version, app_name, packet_type, endpoint, params, entity): 154 | return {'pid': cls._next_pid(), 155 | 'app': app_name, 156 | 'name': name, 157 | 'version': version, 158 | 'entity': entity, 159 | 'endpoint': endpoint, 160 | 'type': packet_type, 161 | 'payload': params} 162 | 163 | @classmethod 164 | def publish(cls, publish_id, name, version, endpoint, payload): 165 | return {'pid': cls._next_pid(), 166 | 'type': 'publish', 167 | 'name': name, 168 | 'version': version, 169 | 'endpoint': endpoint, 170 | 'payload': payload, 171 | 'publish_id': publish_id} 172 | -------------------------------------------------------------------------------- /trellio/utils/log.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import datetime 3 | import logging 4 | import logging.config 5 | import sys 6 | from functools import partial, wraps 7 | from logging import Handler 8 | from queue import Queue 9 | from threading import Thread 10 | 11 | import yaml 12 | from pythonjsonlogger import jsonlogger 13 | 14 | RED = '\033[91m' 15 | BLUE = '\033[94m' 16 | BOLD = '\033[1m' 17 | END = '\033[0m' 18 | 19 | 20 | class CustomTimeLoggingFormatter(logging.Formatter): 21 | def formatTime(self, record, datefmt=None): # noqa 22 | """ 23 | Overrides formatTime method to use datetime module instead of time module 24 | to display time in microseconds. Time module by default does not resolve 25 | time to microseconds. 
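        For example, with ``datefmt='%Y-%m-%d %H:%M:%S,%f'`` (the format used by the default
        config later in this module) a record is rendered like ``2017-04-01 12:00:00,123456``,
        whereas ``time.strftime`` has no ``%f`` directive and therefore cannot emit the
        microsecond field.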
26 | """ 27 | if datefmt: 28 | s = datetime.datetime.now().strftime(datefmt) 29 | else: 30 | t = datetime.datetime.now().strftime(self.default_time_format) 31 | s = self.default_msec_format % (t, record.msecs) 32 | return s 33 | 34 | 35 | class CustomJsonFormatter(jsonlogger.JsonFormatter): 36 | def __init__(self, *args, **kwargs): 37 | self.extrad = kwargs.pop('extrad', {}) 38 | super().__init__(*args, **kwargs) 39 | 40 | def add_fields(self, log_record, record, message_dict): 41 | message_dict.update(self.extrad) 42 | super().add_fields(log_record, record, message_dict) 43 | 44 | 45 | def patch_async_emit(handler: Handler): 46 | base_emit = handler.emit 47 | queue = Queue() 48 | 49 | def loop(): 50 | while True: 51 | record = queue.get() 52 | try: 53 | base_emit(record) 54 | except: 55 | print(sys.exc_info()) 56 | 57 | def async_emit(record): 58 | queue.put(record) 59 | 60 | thread = Thread(target=loop) 61 | thread.daemon = True 62 | thread.start() 63 | handler.emit = async_emit 64 | return handler 65 | 66 | 67 | def patch_add_handler(logger): 68 | base_add_handler = logger.addHandler 69 | 70 | def async_add_handler(handler): 71 | async_handler = patch_async_emit(handler) 72 | base_add_handler(async_handler) 73 | 74 | return async_add_handler 75 | 76 | 77 | DEFAULT_CONFIG_YAML = """ 78 | # logging config 79 | 80 | version: 1 81 | disable_existing_loggers: False 82 | handlers: 83 | stream: 84 | class: logging.StreamHandler 85 | level: INFO 86 | formatter: ctf 87 | stream: ext://sys.stdout 88 | 89 | stats: 90 | class: logging.StreamHandler 91 | level: INFO 92 | formatter: cjf 93 | stream: ext://sys.stdout 94 | 95 | service: 96 | class: logging.StreamHandler 97 | level: INFO 98 | formatter: ctf 99 | stream: ext://sys.stdout 100 | 101 | formatters: 102 | ctf: 103 | (): trellio.utils.log.CustomTimeLoggingFormatter 104 | format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' 105 | datefmt: '%Y-%m-%d %H:%M:%S,%f' 106 | 107 | cjf: 108 | (): trellio.utils.log.CustomJsonFormatter 109 | format: '{ "timestamp":"%(asctime)s", "message":"%(message)s"}' 110 | datefmt: '%Y-%m-%d %H:%M:%S,%f' 111 | 112 | root: 113 | handlers: [stream] 114 | level: INFO 115 | 116 | loggers: 117 | registry: 118 | handlers: [stream] 119 | level: INFO 120 | 121 | stats: 122 | handlers: [stats] 123 | level: INFO 124 | 125 | """ 126 | 127 | 128 | def setup_logging(_): 129 | try: 130 | with open('config_log.json', 'r') as f: 131 | config_dict = yaml.load(f.read()) 132 | except: 133 | config_dict = yaml.load(DEFAULT_CONFIG_YAML) 134 | 135 | logging.getLogger('asyncio').setLevel(logging.DEBUG) 136 | logger = logging.getLogger() 137 | logger.propagate = False 138 | logger.handlers = [] 139 | logger.addHandler = patch_add_handler(logger) 140 | logging.config.dictConfig(config_dict) 141 | 142 | 143 | def log(fn=None, logger=logging.getLogger(), debug_level=logging.DEBUG): 144 | """ 145 | logs parameters and result - takes no arguments 146 | """ 147 | if fn is None: 148 | return partial(log, logger=logger, debug_level=debug_level) 149 | 150 | @wraps(fn) 151 | def func(*args, **kwargs): 152 | arg_string = "" 153 | for i in range(0, len(args)): 154 | var_name = fn.__code__.co_varnames[i] 155 | if var_name not in ['self', 'cls']: 156 | arg_string += var_name + ":" + str(args[i]) + "," 157 | arg_string = arg_string[0:len(arg_string) - 1] 158 | string = (RED + BOLD + '>> ' + END + 'Calling {0}({1})'.format(fn.__name__, arg_string)) 159 | if len(kwargs): 160 | string = ( 161 | RED + BOLD + '>> ' + END + 'Calling {0} with args {1} 
and kwargs {2}'.format(fn.__name__, arg_string, 162 | kwargs)) 163 | logger.log(debug_level, string) 164 | wrapped_fn = fn 165 | if not asyncio.iscoroutine(fn): 166 | wrapped_fn = asyncio.coroutine(fn) 167 | try: 168 | result = yield from wrapped_fn(*args, **kwargs) 169 | string = BLUE + BOLD + '<< ' + END + 'Return {0} with result :{1}'.format(fn.__name__, result) 170 | logger.log(debug_level, string) 171 | return result 172 | except Exception as e: 173 | string = (RED + BOLD + '>> ' + END + '{0} raised exception :{1}'.format(fn.__name__, str(e))) 174 | logger.log(debug_level, string) 175 | raise e 176 | 177 | return func 178 | 179 | 180 | def logx(supress_args=[], supress_all_args=False, supress_result=False, logger=logging.getLogger(), 181 | debug_level=logging.DEBUG): 182 | """ 183 | logs parameters and result 184 | takes arguments 185 | supress_args - list of parameter names to supress 186 | supress_all_args - boolean to supress all arguments 187 | supress_result - boolean to supress result 188 | receiver - custom logging function which takes a string as input; defaults to logging on stdout 189 | """ 190 | 191 | def decorator(fn): 192 | def func(*args, **kwargs): 193 | if not supress_all_args: 194 | arg_string = "" 195 | for i in range(0, len(args)): 196 | var_name = fn.__code__.co_varnames[i] 197 | if var_name != "self" and var_name not in supress_args: 198 | arg_string += var_name + ":" + str(args[i]) + "," 199 | arg_string = arg_string[0:len(arg_string) - 1] 200 | string = (RED + BOLD + '>> ' + END + 'Calling {0}({1})'.format(fn.__name__, arg_string)) 201 | if len(kwargs): 202 | string = ( 203 | RED + BOLD + '>> ' + END + 'Calling {0} with args {1} and kwargs {2}'.format( 204 | fn.__name__, 205 | arg_string, kwargs)) 206 | logger.log(debug_level, string) 207 | 208 | wrapped_fn = fn 209 | if not asyncio.iscoroutine(fn): 210 | wrapped_fn = asyncio.coroutine(fn) 211 | result = yield from wrapped_fn(*args, **kwargs) 212 | 213 | if not supress_result: 214 | string = BLUE + BOLD + '<< ' + END + 'Return {0} with result : {1}'.format(fn.__name__, result) 215 | logger.log(debug_level, string) 216 | return result 217 | 218 | return func 219 | 220 | return decorator 221 | -------------------------------------------------------------------------------- /trellio/registry_client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import random 4 | from collections import defaultdict 5 | from functools import partial 6 | 7 | from retrial.retrial import retry 8 | 9 | from .packet import ControlPacket 10 | from .pinger import TCPPinger 11 | from .protocol_factory import get_trellio_protocol 12 | 13 | try: 14 | from uvloop.loop import TCPTransport as Transport 15 | except ImportError: 16 | from asyncio.transports import Transport 17 | 18 | 19 | def _retry_for_result(result): 20 | if isinstance(result, tuple): 21 | return not isinstance(result[0], Transport) or not isinstance(result[1], asyncio.Protocol) 22 | return True 23 | 24 | 25 | def _retry_for_exception(_): 26 | return True 27 | 28 | 29 | class RegistryClient: 30 | def __init__(self, loop, host, port, ssl_context=None): 31 | self._loop = loop 32 | self._port = port 33 | self._host = host 34 | self.bus = None 35 | self._transport = None 36 | self._protocol = None 37 | self._service = None 38 | self._version = None 39 | self._node_ids = {} 40 | self._pinger = None 41 | self._conn_handler = None 42 | self._pending_requests = {} 43 | self._available_services = 
defaultdict(list) 44 | self._assigned_services = defaultdict(lambda: defaultdict(list)) 45 | self._ssl_context = ssl_context 46 | self.logger = logging.getLogger(__name__) 47 | 48 | @property 49 | def conn_handler(self): 50 | return self._conn_handler 51 | 52 | @conn_handler.setter 53 | def conn_handler(self, handler): 54 | self._conn_handler = handler 55 | 56 | def register(self, ip, port, service, version, node_id, vendors, service_type): # here vendors are tcp/http_clients 57 | self._service = service 58 | self._version = version 59 | self._node_ids[service_type] = node_id 60 | packet = ControlPacket.registration(ip, port, node_id, service, version, vendors, service_type) 61 | self._protocol.send(packet) 62 | 63 | def get_instances(self, name, version): 64 | packet = ControlPacket.get_instances(name, version) 65 | future = asyncio.Future() 66 | self._protocol.send(packet) 67 | self._pending_requests[packet['request_id']] = future 68 | return future 69 | 70 | def get_subscribers(self, service, version, endpoint): 71 | packet = ControlPacket.get_subscribers(service, version, endpoint) 72 | # TODO : remove duplication in get_instances and get_subscribers 73 | future = asyncio.Future() 74 | self._protocol.send(packet) 75 | self._pending_requests[packet['request_id']] = future 76 | return future 77 | 78 | def x_subscribe(self, host, port, node_id, endpoints): 79 | packet = ControlPacket.xsubscribe(self._service, self._version, host, port, node_id, 80 | endpoints) 81 | self._protocol.send(packet) 82 | 83 | @retry(should_retry_for_result=_retry_for_result, should_retry_for_exception=_retry_for_exception, 84 | strategy=[0, 2, 4, 8, 16, 32]) 85 | def connect(self): 86 | self._transport, self._protocol = yield from self._loop.create_connection(partial(get_trellio_protocol, self), 87 | self._host, self._port, 88 | ssl=self._ssl_context) 89 | self.conn_handler.handle_connected() 90 | self._pinger = TCPPinger(self._host, self._port, 'registry', self._protocol, self) 91 | self._pinger.ping(payload=self._node_ids) 92 | return self._transport, self._protocol 93 | 94 | def on_timeout(self, host, port, node_id): 95 | asyncio.ensure_future(self.connect()) 96 | 97 | def receive(self, packet: dict, protocol, transport): 98 | if packet['type'] == 'registered': 99 | self.cache_vendors(packet['params']['dependencies']) 100 | self.bus.registration_complete() 101 | elif packet['type'] == 'new_instance': 102 | # TODO : once method for both vendors and new instance 103 | self.cache_instance(**packet['params']) 104 | self._handle_new_instance(**packet['params']) 105 | elif packet['type'] == 'deregister': 106 | self._handle_deregistration(packet) 107 | elif packet['type'] == 'subscribers': 108 | self._handle_subscriber_packet(packet) 109 | elif packet['type'] == 'pong': 110 | self._pinger.pong_received(payload=self._node_ids) 111 | elif packet['type'] == 'instances': 112 | self._handle_get_instances(packet) 113 | 114 | def get_all_addresses(self, name, version): 115 | return self._available_services.get( 116 | self._get_full_service_name(name, version)) 117 | 118 | def get_for_node(self, node_id): 119 | for services in self._available_services.values(): 120 | for host, port, node, service_type in services: 121 | if node == node_id: 122 | return host, port, node, service_type 123 | return None 124 | 125 | def get_random_service(self, service_name, service_type): 126 | services = self._available_services[service_name] 127 | services = [service for service in services if service[3] == service_type] 128 | if 
len(services): 129 | return random.choice(services) 130 | else: 131 | return None 132 | 133 | def resolve(self, service: str, version: str, entity: str, service_type: str): 134 | service_name = self._get_full_service_name(service, version) 135 | if entity is not None: 136 | entity_map = self._assigned_services.get(service_name) 137 | if entity_map is None: 138 | self._assigned_services[service_name] = {} 139 | entity_map = self._assigned_services.get(service_name) 140 | if entity in entity_map: 141 | return entity_map[entity] 142 | else: 143 | host, port, node_id, service_type = self.get_random_service(service_name, service_type) 144 | if node_id is not None: 145 | entity_map[entity] = host, port, node_id, service_type 146 | return host, port, node_id, service_type 147 | else: 148 | return self.get_random_service(service_name, service_type) 149 | 150 | @staticmethod 151 | def _get_full_service_name(service, version): 152 | return "{}/{}".format(service, version) 153 | 154 | def cache_vendors(self, dependencies): 155 | for dependency in dependencies: 156 | vendor_name = self._get_full_service_name(dependency['name'], dependency['version']) 157 | for address in dependency['addresses']: 158 | self._available_services[vendor_name].append( 159 | (address['host'], address['port'], address['node_id'], address['type'])) 160 | self.logger.debug('Connection cache after registration is %s', self._available_services) 161 | 162 | def cache_instance(self, name, version, host, port, node_id, service_type): 163 | vendor = self._get_full_service_name(name, version) 164 | self._available_services[vendor].append((host, port, node_id, service_type)) 165 | self.logger.debug('Connection cache on getting new instance is %s', self._available_services) 166 | 167 | def _handle_deregistration(self, packet): 168 | params = packet['params'] 169 | vendor = self._get_full_service_name(params['name'], params['version']) 170 | node = params['node_id'] 171 | for each in self._available_services[vendor]: 172 | if each[2] == node: 173 | self._available_services[vendor].remove(each) 174 | entity_map = self._assigned_services.get(vendor) 175 | if entity_map is not None: 176 | stale_entities = [] 177 | for entity, node_id in entity_map.items(): 178 | if node == node_id: 179 | stale_entities.append(entity) 180 | for entity in stale_entities: 181 | entity_map.pop(entity) 182 | self.logger.debug('Connection cache after deregister is %s', self._available_services) 183 | 184 | def _handle_subscriber_packet(self, packet): 185 | request_id = packet['request_id'] 186 | future = self._pending_requests.pop(request_id, None) 187 | future.set_result(packet['params']['subscribers']) 188 | 189 | def _handle_get_instances(self, packet): 190 | future = self._pending_requests.pop(packet['request_id'], None) 191 | future.set_result(packet['params']['instances']) 192 | 193 | def _handle_new_instance(self, name, version, host, port, node_id, service_type): 194 | self.bus.new_instance(name, version, host, port, node_id, service_type) 195 | -------------------------------------------------------------------------------- /trellio/conf_manager/conf_client.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import importlib 3 | import json 4 | import logging 5 | import os 6 | 7 | from trellio.services import TCPService, HTTPService 8 | from ..utils.log_handlers import BufferingSMTPHandler 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | GLOBAL_CONFIG = { 13 | "RONIN": False, 14 | "HOST_NAME": 
"", 15 | "ADMIN_EMAILS": [], 16 | "SERVICE_NAME": "", 17 | "SERVICE_VERSION": "", 18 | "REGISTRY_HOST": "", 19 | "REGISTRY_PORT": "", 20 | "REDIS_HOST": "", 21 | "REDIS_PORT": "", 22 | "HTTP_HOST": "", 23 | "TCP_HOST": "", 24 | "HTTP_PORT": "", 25 | "TCP_PORT": "", 26 | "SIGNALS": {}, 27 | "MIDDLEWARES": [], 28 | "APPS": [], 29 | "DATABASE_SETTINGS": { 30 | "database": "", 31 | "user": "", 32 | "password": "", 33 | "host": "", 34 | "port": "" 35 | }, 36 | "SMTP_SETTINGS": {} 37 | } 38 | 39 | 40 | class InvalidConfigurationError(Exception): 41 | pass 42 | 43 | 44 | class ConfigHandler: 45 | smtp_host = 'SMTP_HOST' 46 | smtp_user = 'SMTP_USER' 47 | smtp_port = 'SMTP_PORT' 48 | smtp_password = 'SMTP_PASSWORD' 49 | admin_emails = 'ADMIN_EMAILS' 50 | middleware_key = 'MIDDLEWARES' 51 | signal_key = 'SIGNALS' 52 | service_name_key = 'SERVICE_NAME' 53 | host_name_key = 'HOST_NAME' 54 | service_version_key = 'SERVICE_VERSION' 55 | reg_host_key = "REGISTRY_HOST" 56 | reg_port_key = "REGISTRY_PORT" 57 | redis_host_key = "REDIS_HOST" 58 | redis_port_key = "REDIS_PORT" 59 | http_host_key = "HTTP_HOST" 60 | tcp_host_key = "TCP_HOST" 61 | http_port_key = "HTTP_PORT" 62 | tcp_port_key = "TCP_PORT" 63 | database_key = 'DATABASE_SETTINGS' 64 | ronin_key = "RONIN" 65 | smtp_key = 'SMTP_SETTINGS' 66 | apps_key = 'APPS' 67 | 68 | # service_path_key = "SERVICE_PATH" 69 | 70 | def __init__(self, host_class): 71 | self.settings = None 72 | self.host = host_class 73 | 74 | @property 75 | def service_name(self): 76 | return self.settings[self.service_name_key] 77 | 78 | def get_tcp_clients(self): 79 | from trellio.services import TCPServiceClient 80 | tcp_clients = self.inheritors(TCPServiceClient) 81 | return tcp_clients 82 | 83 | def get_http_clients(self): 84 | from trellio.services import HTTPServiceClient 85 | http_clients = self.inheritors(HTTPServiceClient) 86 | return http_clients 87 | 88 | def get_subscribers(self): 89 | from trellio.pubsub import Subscriber 90 | subscriber_classes = self.inheritors(Subscriber) 91 | subscribers = [] 92 | for subs in subscriber_classes: 93 | s = subs() 94 | s.pubsub_host = self.settings[self.redis_host_key] 95 | s.pubsub_port = self.settings[self.redis_port_key] 96 | subscribers.append(s) 97 | return subscribers 98 | 99 | def configure_host(self, host): 100 | host.configure( 101 | host_name=self.settings[self.host_name_key], 102 | service_name=self.settings[self.service_name_key], 103 | service_version=self.settings[self.service_version_key], 104 | http_host=self.settings[self.http_host_key], 105 | http_port=self.settings[self.http_port_key], 106 | tcp_host=self.settings[self.tcp_host_key], 107 | tcp_port=self.settings[self.tcp_port_key], 108 | registry_host=self.settings[self.reg_host_key], 109 | registry_port=self.settings[self.reg_port_key], 110 | pubsub_host=self.settings[self.redis_host_key], 111 | pubsub_port=self.settings[self.reg_port_key], 112 | ronin=self.settings[self.ronin_key] 113 | ) 114 | 115 | def setup_host(self): 116 | host = self.host 117 | self.configure_host(host) 118 | publisher = self.get_publisher() 119 | subscribers = self.get_subscribers() 120 | if publisher: 121 | host.attach_publisher(publisher) 122 | if subscribers: 123 | host.attach_subscribers(subscribers) 124 | 125 | http_service = self.get_http_service() 126 | tcp_service = self.get_tcp_service() 127 | tcp_clients = self.get_tcp_clients() 128 | http_clients = self.get_http_clients() 129 | http_views = self.get_http_views() 130 | tcp_views = self.get_tcp_views() 131 | 132 | if not 
http_service: 133 | http_service = HTTPService(host.service_name, host.service_version, host.http_host, host.http_port) 134 | 135 | if not tcp_service: 136 | tcp_service = TCPService(host.service_name, host.service_version, host.tcp_host, host.tcp_port) 137 | 138 | self.enable_signals() 139 | self.enable_middlewares(http_service=http_service, http_views=http_views) 140 | 141 | if http_service: 142 | # self.register_http_views(http_service) 143 | host.attach_service(http_service) 144 | http_service.clients = [i() for i in http_clients + tcp_clients] 145 | # self.register_tcp_views(tcp_service) 146 | 147 | host.attach_service(tcp_service) 148 | 149 | if http_service: 150 | tcp_service.clients = http_service.clients 151 | 152 | if http_views: 153 | host.attach_http_views(http_views) 154 | for view_inst in host.get_tcp_views(): 155 | pass 156 | 157 | if tcp_views: 158 | host.attach_tcp_views(tcp_views) 159 | _tcp_service = host.get_tcp_service() 160 | _tcp_service.tcp_views = host._tcp_views 161 | 162 | host._smtp_handler = self.get_smtp_logging_handler() 163 | 164 | def get_database_settings(self): 165 | return self.settings[self.database_key] 166 | 167 | def set_config(self, config_path): 168 | settings = None 169 | with open(config_path) as f: 170 | settings = json.load(f) 171 | new_settings = copy.deepcopy(GLOBAL_CONFIG) 172 | new_settings.update(settings) 173 | self.settings = new_settings 174 | parent_dir = os.getcwd().split('/')[-1] 175 | client_path = parent_dir + '.clients' 176 | service_path = parent_dir + '.service' 177 | 178 | try: 179 | importlib.import_module(client_path) 180 | except: 181 | logger.warning('No clients found') 182 | 183 | service_imported = True 184 | service_exception = None 185 | try: 186 | importlib.import_module(service_path) 187 | except Exception as e: 188 | service_imported = False 189 | service_exception = e.__traceback__ 190 | 191 | if self.settings.get(self.apps_key): 192 | apps = self.settings[self.apps_key] 193 | for app in apps: 194 | views_path = parent_dir + '.{}.views'.format(app) 195 | try: 196 | importlib.import_module(views_path) 197 | except Exception as e: 198 | print(e.__traceback__.__str__()) 199 | else: 200 | if not service_imported: 201 | print(service_exception.__str__()) 202 | 203 | def get_smtp_logging_handler(self): 204 | if self.settings.get(self.smtp_key): 205 | keys = ["smtp_host", "smtp_port", "smtp_user", "smtp_password"] 206 | setting_keys = self.settings[self.smtp_key].keys() 207 | missing_keys = list(filter(lambda x: x not in setting_keys, keys)) 208 | if not missing_keys: 209 | handler = BufferingSMTPHandler(mailhost=self.settings[self.smtp_key]['smtp_host'], 210 | mailport=self.settings[self.smtp_key]['smtp_port'], 211 | fromaddr=self.settings[self.smtp_key]['smtp_user'], 212 | toaddrs=self.settings[self.admin_emails], 213 | subject='Error {} {}:{}'.format(self.settings[self.host_name_key], 214 | self.settings[ 215 | self.service_name_key].upper(), 216 | self.settings[self.service_version_key]), 217 | capacity=1, 218 | password=self.settings[self.smtp_key]['smtp_password']) 219 | handler.setLevel(logging.ERROR) 220 | if not self.settings[self.ronin_key]: 221 | return handler 222 | 223 | def get_http_service(self): 224 | from trellio.services import HTTPService 225 | http_service = None 226 | if HTTPService.__subclasses__(): 227 | service_sub_class = HTTPService.__subclasses__()[0] 228 | 229 | http_service = service_sub_class(self.settings[self.service_name_key], 230 | self.settings[self.service_version_key], 231 | 
self.settings[self.http_host_key], 232 | self.settings[self.http_port_key]) 233 | return http_service 234 | 235 | def get_tcp_service(self): 236 | from trellio.services import TCPService 237 | tcp_service = None 238 | if TCPService.__subclasses__(): 239 | service_sub_class = TCPService.__subclasses__()[0] 240 | tcp_service = service_sub_class(self.settings[self.service_name_key], 241 | self.settings[self.service_version_key], 242 | self.settings[self.tcp_host_key], 243 | self.settings[self.tcp_port_key]) 244 | return tcp_service 245 | 246 | def get_publisher(self): 247 | from trellio.pubsub import Publisher 248 | publisher = None 249 | if Publisher.__subclasses__(): 250 | publisher_sub_class = Publisher.__subclasses__()[0] 251 | publisher = publisher_sub_class(self.settings[self.service_name_key], 252 | self.settings[self.service_version_key], 253 | self.settings[self.redis_host_key], 254 | self.settings[self.redis_port_key]) 255 | return publisher 256 | 257 | def get_http_views(self): 258 | from trellio.views import HTTPView 259 | return self.inheritors(HTTPView) 260 | 261 | def get_tcp_views(self): 262 | from trellio.views import TCPView 263 | return self.inheritors(TCPView) 264 | 265 | def import_class_from_path(self, path): 266 | broken = path.split('.') 267 | class_name = broken[-1] 268 | module_name = '.'.join(broken[:-1]) 269 | module = importlib.import_module(module_name) 270 | class_value = getattr(module, class_name) 271 | return module, class_value 272 | 273 | def enable_middlewares(self, http_service=None, http_views=()): 274 | middlewares = self.settings[self.middleware_key] or [] 275 | middle_cls = [] 276 | for i in middlewares: 277 | module, class_value = self.import_class_from_path(i) 278 | if not class_value: 279 | raise InvalidConfigurationError 280 | else: 281 | middle_cls.append(class_value()) 282 | 283 | if http_service: 284 | http_service.middlewares = middle_cls 285 | for view in http_views: 286 | view.middlewares = middle_cls 287 | 288 | def enable_signals(self): 289 | ''' 290 | e.g signal_dict = {signal_path:signal_receiver_path_list, ....} 291 | :return: 292 | ''' 293 | signal_dict = self.settings[self.signal_key] or {} 294 | for i in signal_dict.keys(): 295 | sig_module, signal_class = self.import_class_from_path(i) 296 | for j in signal_dict[i]: 297 | recv_module, recv_coro = self.import_class_from_path(j) 298 | signal_class.register(recv_coro) # registering reciever 299 | 300 | @staticmethod 301 | def inheritors(klass): 302 | subclasses = set() 303 | work = [klass] 304 | while work: 305 | parent = work.pop() 306 | for child in parent.__subclasses__(): 307 | if child not in subclasses: 308 | subclasses.add(child) 309 | work.append(child) 310 | return list(subclasses) 311 | -------------------------------------------------------------------------------- /trellio/host.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import os 4 | import signal 5 | import warnings 6 | from functools import partial 7 | 8 | from aiohttp.web import Application 9 | 10 | from .bus import TCPBus 11 | from .protocol_factory import get_trellio_protocol 12 | from .pubsub import Publisher, Subscriber 13 | from .registry_client import RegistryClient 14 | from .services import HTTPService, TCPService 15 | from .signals import ServiceReady 16 | from .utils.decorators import deprecated 17 | from .utils.log import setup_logging 18 | from .utils.stats import Stats, Aggregator 19 | 20 | 21 | class Host: 22 | """Serves as 
a static entry point and provides the boilerplate required to host and run a trellio Service. 23 | 24 | Example:: 25 | 26 | Host.configure('SampleService') 27 | Host.attachService(SampleHTTPService()) 28 | Host.run() 29 | 30 | """ 31 | registry_host = None 32 | registry_port = None 33 | pubsub_host = None 34 | pubsub_port = None 35 | host_name = None 36 | service_name = None 37 | http_host = None 38 | http_port = None 39 | tcp_host = None 40 | tcp_port = None 41 | ssl_context = None 42 | ronin = False # If true, the trellio service runs solo without a registry 43 | 44 | _host_id = None 45 | _tcp_service = None 46 | _http_service = None 47 | _publisher = None 48 | _subscribers = [] 49 | _tcp_views = [] 50 | _http_views = [] 51 | _logger = logging.getLogger(__name__) 52 | _smtp_handler = None 53 | 54 | @classmethod 55 | def configure(cls, host_name: str = '', service_name: str = '', service_version='', 56 | http_host: str = '127.0.0.1', http_port: int = 8000, 57 | tcp_host: str = '127.0.0.1', tcp_port: int = 8001, ssl_context=None, 58 | registry_host: str = "0.0.0.0", registry_port: int = 4500, 59 | pubsub_host: str = "0.0.0.0", pubsub_port: int = 6379, ronin: bool = False): 60 | """ A convenience method for providing registry and pubsub(redis) endpoints 61 | 62 | :param host_name: Used for process name 63 | :param registry_host: IP Address for trellio-registry; default = 0.0.0.0 64 | :param registry_port: Port for trellio-registry; default = 4500 65 | :param pubsub_host: IP Address for pubsub component, usually redis; default = 0.0.0.0 66 | :param pubsub_port: Port for pubsub component; default= 6379 67 | :return: None 68 | """ 69 | Host.host_name = host_name 70 | Host.service_name = service_name 71 | Host.service_version = str(service_version) 72 | Host.http_host = http_host 73 | Host.http_port = http_port 74 | Host.tcp_host = tcp_host 75 | Host.tcp_port = tcp_port 76 | Host.registry_host = registry_host 77 | Host.registry_port = registry_port 78 | Host.pubsub_host = pubsub_host 79 | Host.pubsub_port = pubsub_port 80 | Host.ssl_context = ssl_context 81 | Host.ronin = ronin 82 | 83 | @classmethod 84 | def get_http_service(cls): 85 | return cls._http_service 86 | 87 | @classmethod 88 | def get_tcp_service(cls): 89 | return cls._tcp_service 90 | 91 | @classmethod 92 | def get_tcp_clients(cls): 93 | tcp_service = cls.get_tcp_service() 94 | if tcp_service: 95 | return tcp_service.clients 96 | 97 | @classmethod 98 | def get_publisher(cls): 99 | return cls._publisher 100 | 101 | @classmethod 102 | def get_subscribers(cls): 103 | return cls._subscribers 104 | 105 | @classmethod 106 | def get_tcp_views(cls): 107 | return cls._tcp_views 108 | 109 | @classmethod 110 | def get_http_views(cls): 111 | return cls._http_views 112 | 113 | @classmethod 114 | @deprecated 115 | def attach_service(cls, service): 116 | """ Allows you to attach one TCP and one HTTP service 117 | 118 | deprecated:: 2.1.73 use http and tcp specific methods 119 | :param service: A trellio TCP or HTTP service that needs to be hosted 120 | """ 121 | if isinstance(service, HTTPService): 122 | cls._http_service = service 123 | elif isinstance(service, TCPService): 124 | cls._tcp_service = service 125 | else: 126 | cls._logger.error('Invalid argument attached as service') 127 | cls._set_bus(service) 128 | 129 | @classmethod 130 | def attach_http_service(cls, http_service: HTTPService): 131 | """ Attaches a service for hosting 132 | :param http_service: A HTTPService instance 133 | """ 134 | if cls._http_service is None: 135 | 
cls._http_service = http_service 136 | cls._set_bus(http_service) 137 | else: 138 | warnings.warn('HTTP service is already attached') 139 | 140 | @classmethod 141 | def attach_tcp_service(cls, tcp_service: TCPService): 142 | """ Attaches a service for hosting 143 | :param tcp_service: A TCPService instance 144 | """ 145 | if cls._tcp_service is None: 146 | cls._tcp_service = tcp_service 147 | cls._set_bus(tcp_service) 148 | else: 149 | warnings.warn('TCP service is already attached') 150 | 151 | @classmethod 152 | def attach_http_views(cls, http_views: list): 153 | views_instances = [] 154 | for view_class in http_views: 155 | instance = view_class() 156 | instance.host = Host 157 | views_instances.append(instance) 158 | cls._http_views.extend(views_instances) 159 | 160 | @classmethod 161 | def attach_tcp_views(cls, tcp_views: list): 162 | views_instances = [] 163 | for view_class in tcp_views: 164 | instance = view_class() 165 | instance.host = Host 166 | views_instances.append(instance) 167 | cls._tcp_views.extend(views_instances) 168 | 169 | @classmethod 170 | def attach_publisher(cls, publisher: Publisher): 171 | if cls._publisher is None: 172 | cls._publisher = publisher 173 | else: 174 | warnings.warn('Publisher is already attached') 175 | 176 | @classmethod 177 | def attach_subscribers(cls, subscribers: list): 178 | if all([isinstance(subscriber, Subscriber) for subscriber in subscribers]): 179 | if not cls._subscribers: 180 | cls._subscribers = subscribers 181 | else: 182 | warnings.warn('Subscribers are already attached') 183 | 184 | @classmethod 185 | def run(cls): 186 | """ Fires up the event loop and starts serving attached services 187 | """ 188 | if cls._tcp_service or cls._http_service or cls._http_views or cls._tcp_views: 189 | cls._set_host_id() 190 | cls._setup_logging() 191 | cls._set_process_name() 192 | cls._set_signal_handlers() 193 | cls._start_pubsub() 194 | cls._start_server() 195 | else: 196 | cls._logger.error('No services to host') 197 | 198 | @classmethod 199 | def _set_process_name(cls): 200 | from setproctitle import setproctitle 201 | setproctitle('trellio_{}_{}'.format(cls.host_name, cls._host_id)) 202 | 203 | @classmethod 204 | def _stop(cls, signame: str): 205 | cls._logger.info('\ngot signal {} - exiting'.format(signame)) 206 | asyncio.get_event_loop().stop() 207 | 208 | @classmethod 209 | def _set_signal_handlers(cls): 210 | asyncio.get_event_loop().add_signal_handler(getattr(signal, 'SIGINT'), partial(cls._stop, 'SIGINT')) 211 | asyncio.get_event_loop().add_signal_handler(getattr(signal, 'SIGTERM'), partial(cls._stop, 'SIGTERM')) 212 | 213 | @classmethod 214 | def _create_tcp_server(cls): 215 | if cls._tcp_service: 216 | ssl_context = cls._tcp_service.ssl_context 217 | host_ip, host_port = cls._tcp_service.socket_address 218 | task = asyncio.get_event_loop().create_server(partial(get_trellio_protocol, cls._tcp_service.tcp_bus), 219 | host_ip, host_port, ssl=ssl_context) 220 | result = asyncio.get_event_loop().run_until_complete(task) 221 | return result 222 | 223 | @classmethod 224 | def _create_http_server(cls): 225 | if cls._http_service or cls._http_views: 226 | host_ip, host_port = cls.http_host, cls.http_port 227 | ssl_context = cls.ssl_context 228 | handler = cls._make_aiohttp_handler() 229 | task = asyncio.get_event_loop().create_server(handler, host_ip, host_port, ssl=ssl_context) 230 | return asyncio.get_event_loop().run_until_complete(task) 231 | 232 | @classmethod 233 | def _make_aiohttp_handler(cls): 234 | app = 
Application(loop=asyncio.get_event_loop()) 235 | 236 | if cls._http_service: 237 | for each in cls._http_service.__ordered__: 238 | # iterate all attributes in the service looking for http endpoints and add them 239 | fn = getattr(cls._http_service, each) 240 | if callable(fn) and getattr(fn, 'is_http_method', False): 241 | for path in fn.paths: 242 | app.router.add_route(fn.method, path, fn) 243 | if cls._http_service.cross_domain_allowed: 244 | # add an 'options' for this specific path to make it CORS friendly 245 | app.router.add_route('options', path, cls._http_service.preflight_response) 246 | 247 | for view in cls._http_views: 248 | for each in view.__ordered__: 249 | fn = getattr(view, each) 250 | if callable(fn) and getattr(fn, 'is_http_method', False): 251 | for path in fn.paths: 252 | app.router.add_route(fn.method, path, fn) 253 | if view.cross_domain_allowed: 254 | # add an 'options' for this specific path to make it CORS friendly 255 | app.router.add_route('options', path, view.preflight_response) 256 | 257 | handler = app.make_handler(access_log=cls._logger) 258 | return handler 259 | 260 | @classmethod 261 | def _set_host_id(cls): 262 | from uuid import uuid4 263 | cls._host_id = uuid4() 264 | 265 | @classmethod 266 | def _start_server(cls): 267 | tcp_server = cls._create_tcp_server() 268 | http_server = cls._create_http_server() 269 | if not cls.ronin: 270 | if cls._tcp_service: 271 | asyncio.get_event_loop().run_until_complete(cls._tcp_service.tcp_bus.connect()) 272 | # if cls._http_service: 273 | # asyncio.get_event_loop().run_until_complete(cls._http_service.tcp_bus.connect()) 274 | if tcp_server: 275 | cls._logger.info('Serving TCP on {}'.format(tcp_server.sockets[0].getsockname())) 276 | if http_server: 277 | cls._logger.info('Serving HTTP on {}'.format(http_server.sockets[0].getsockname())) 278 | cls._logger.info("Event loop running forever, press CTRL+C to interrupt.") 279 | cls._logger.info("pid %s: send SIGINT or SIGTERM to exit." 
% os.getpid()) 280 | cls._logger.info("Triggering ServiceReady signal") 281 | asyncio.get_event_loop().run_until_complete(ServiceReady._run()) 282 | try: 283 | asyncio.get_event_loop().run_forever() 284 | except Exception as e: 285 | print(e) 286 | finally: 287 | if tcp_server: 288 | tcp_server.close() 289 | asyncio.get_event_loop().run_until_complete(tcp_server.wait_closed()) 290 | 291 | if http_server: 292 | http_server.close() 293 | asyncio.get_event_loop().run_until_complete(http_server.wait_closed()) 294 | 295 | asyncio.get_event_loop().close() 296 | 297 | @classmethod 298 | def _start_pubsub(cls): 299 | if not cls.ronin: 300 | if cls._publisher: 301 | asyncio.get_event_loop().run_until_complete(cls._publisher.create_pubsub_handler()) 302 | 303 | for subscriber in cls._subscribers: 304 | asyncio.get_event_loop().run_until_complete(subscriber.create_pubsub_handler()) 305 | asyncio.async(subscriber.register_for_subscription()) 306 | 307 | @classmethod 308 | def _set_bus(cls, service): 309 | registry_client = RegistryClient(asyncio.get_event_loop(), cls.registry_host, cls.registry_port) 310 | tcp_bus = TCPBus(registry_client) 311 | registry_client.conn_handler = tcp_bus 312 | # pubsub_bus = PubSubBus(cls.pubsub_host, cls.pubsub_port, registry_client) # , cls._tcp_service._ssl_context) 313 | registry_client.bus = tcp_bus 314 | if isinstance(service, TCPService): 315 | tcp_bus.tcp_host = service 316 | if isinstance(service, HTTPService): 317 | tcp_bus.http_host = service 318 | service.tcp_bus = tcp_bus 319 | # service.pubsub_bus = pubsub_bus 320 | 321 | @classmethod 322 | def _setup_logging(cls): 323 | identifier = '{}'.format(cls.service_name) 324 | setup_logging(identifier) 325 | if cls._smtp_handler: 326 | logger = logging.getLogger() 327 | logger.addHandler(cls._smtp_handler) 328 | Stats.service_name = cls.service_name 329 | Aggregator._service_name = cls.service_name 330 | Aggregator.periodic_aggregated_stats_logger() 331 | -------------------------------------------------------------------------------- /trellio/bus.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from asyncio.coroutines import iscoroutine, coroutine 4 | from functools import partial 5 | 6 | import aiohttp 7 | from again.utils import unique_hex 8 | from retrial.retrial.retry import retry 9 | 10 | from .exceptions import ClientNotFoundError, ClientDisconnected 11 | from .packet import ControlPacket, MessagePacket 12 | from .protocol_factory import get_trellio_protocol 13 | from .services import TCPServiceClient, HTTPServiceClient 14 | 15 | HTTP = 'http' 16 | TCP = 'tcp' 17 | 18 | 19 | def _retry_for_pub(result): 20 | return not result 21 | 22 | 23 | def _retry_for_exception(_): 24 | return True 25 | 26 | 27 | class HTTPBus: 28 | def __init__(self, registry_client): 29 | self._registry_client = registry_client 30 | 31 | def send_http_request(self, app: str, service: str, version: str, method: str, entity: str, params: dict): 32 | """ 33 | A convenience method that allows you to send a well formatted http request to another service 34 | """ 35 | host, port, node_id, service_type = self._registry_client.resolve(service, version, entity, HTTP) 36 | 37 | url = 'http://{}:{}{}'.format(host, port, params.pop('path')) 38 | 39 | http_keys = ['data', 'headers', 'cookies', 'auth', 'allow_redirects', 'compress', 'chunked'] 40 | kwargs = {k: params[k] for k in http_keys if k in params} 41 | 42 | query_params = params.pop('params', {}) 43 | 44 | if app is 
not None: 45 | query_params['app'] = app 46 | 47 | query_params['version'] = version 48 | query_params['service'] = service 49 | 50 | response = yield from aiohttp.request(method, url, params=query_params, **kwargs) 51 | return response 52 | 53 | 54 | class TCPBus: 55 | def __init__(self, registry_client): 56 | registry_client.conn_handler = self 57 | self._registry_client = registry_client 58 | self._client_protocols = {} 59 | self._pingers = {} 60 | self._node_clients = {} 61 | self._service_clients = [] 62 | self.tcp_host = None 63 | self.http_host = None 64 | self._host_id = unique_hex() 65 | self._ronin = False 66 | self._registered = False 67 | self._logger = logging.getLogger(__name__) 68 | 69 | def _create_service_clients(self): 70 | futures = [] 71 | for sc in self._service_clients: 72 | for host, port, node_id, service_type in self._registry_client.get_all_addresses(*sc.properties): 73 | if service_type == 'tcp': 74 | self._node_clients[node_id] = sc 75 | future = self._connect_to_client(host, node_id, port, service_type, sc) 76 | futures.append(future) 77 | return asyncio.gather(*futures, return_exceptions=False) 78 | 79 | def connect(self): 80 | clients = self.tcp_host.clients if self.tcp_host else self.http_host.clients 81 | for client in clients: 82 | if isinstance(client, (TCPServiceClient, HTTPServiceClient)): 83 | client.tcp_bus = self 84 | self._service_clients = clients 85 | yield from self._registry_client.connect() 86 | 87 | def register(self): 88 | if self.tcp_host: 89 | self._registry_client.register(self.tcp_host.host, self.tcp_host.port, self.tcp_host.name, 90 | self.tcp_host.version, self.tcp_host.node_id, self.tcp_host.clients, 'tcp') 91 | if self.http_host: 92 | self._registry_client.register(self.http_host.host, self.http_host.port, self.http_host.name, 93 | self.http_host.version, self.http_host.node_id, self.http_host.clients, 94 | 'http') 95 | 96 | def registration_complete(self): 97 | if not self._registered: 98 | self._create_service_clients() 99 | self._registered = True 100 | 101 | def new_instance(self, service, version, host, port, node_id, type): 102 | sc = next(sc for sc in self._service_clients if sc.name == service and sc.version == version) 103 | if type == 'tcp': 104 | self._node_clients[node_id] = sc 105 | asyncio.ensure_future(self._connect_to_client(host, node_id, port, type, sc)) 106 | 107 | def send(self, packet: dict): 108 | packet['from'] = self._host_id 109 | func = getattr(self, '_' + packet['type'] + '_sender') 110 | wrapper_func = func 111 | if not iscoroutine(func): 112 | wrapper_func = coroutine(func) 113 | asyncio.ensure_future(wrapper_func(packet)) 114 | 115 | # @retry((ClientDisconnected, ClientNotFoundError)) 116 | @retry(should_retry_for_result=lambda x: not x, should_retry_for_exception=lambda x: True, timeout=None, 117 | max_attempts=5, multiplier=2) 118 | def _request_sender(self, packet: dict): 119 | """ 120 | Sends a request to a server from a ServiceClient 121 | auto dispatch method called from self.send() 122 | """ 123 | node_id = self._get_node_id_for_packet(packet) 124 | client_protocol = self._client_protocols.get(node_id) 125 | if node_id and client_protocol: 126 | if client_protocol.is_connected(): 127 | packet['to'] = node_id 128 | client_protocol.send(packet) 129 | return True 130 | else: 131 | self._logger.error('Client protocol is not connected for packet %s', packet) 132 | raise ClientDisconnected() 133 | else: 134 | # No node found to send request 135 | self._logger.error('Out of %s, Client Not found for packet 
%s, restarting server...', 136 | self._client_protocols.keys(), packet) 137 | raise ClientNotFoundError() 138 | 139 | def _connect_to_client(self, host, node_id, port, service_type, service_client): 140 | future = asyncio.ensure_future( 141 | asyncio.get_event_loop().create_connection(partial(get_trellio_protocol, service_client), host, port, 142 | ssl=service_client._ssl_context)) 143 | future.add_done_callback( 144 | partial(self._service_client_connection_callback, self._node_clients[node_id], node_id, service_type)) 145 | return future 146 | 147 | def _service_client_connection_callback(self, sc, node_id, service_type, future): 148 | _, protocol = future.result() 149 | # TODO : handle pinging 150 | # if service_type == TCP: 151 | # pinger = Pinger(self, asyncio.get_event_loop()) 152 | # self._pingers[node_id] = pinger 153 | # pinger.register_tcp_service(protocol, node_id) 154 | # asyncio.ensure_future(pinger.start_ping()) 155 | self._client_protocols[node_id] = protocol # stores connection(sockets) 156 | 157 | @staticmethod 158 | def _create_json_service_name(app, service, version): 159 | return {'app': app, 'name': service, 'version': version} 160 | 161 | @staticmethod 162 | def _handle_ping(packet, protocol): 163 | protocol.send(ControlPacket.pong(packet['node_id'])) 164 | 165 | def _handle_pong(self, node_id, count): 166 | pinger = self._pingers[node_id] 167 | asyncio.ensure_future(pinger.pong_received(count)) 168 | 169 | def _get_node_id_for_packet(self, packet): 170 | service, version, entity = packet['name'], packet['version'], packet['entity'] 171 | node = self._registry_client.resolve(service, version, entity, TCP) 172 | return node[2] if node else None 173 | 174 | def handle_ping_timeout(self, node_id): 175 | self._logger.info("Service client connection timed out {}".format(node_id)) 176 | self._pingers.pop(node_id, None) 177 | service_props = self._registry_client.get_for_node(node_id) 178 | self._logger.info('service client props {}'.format(service_props)) 179 | if service_props is not None: 180 | host, port, _node_id, _type = service_props 181 | asyncio.ensure_future(self._connect_to_client(host, _node_id, port, _type)) 182 | 183 | def receive(self, packet: dict, protocol, transport): 184 | if packet['type'] == 'ping': 185 | self._handle_ping(packet, protocol) 186 | elif packet['type'] == 'pong': 187 | self._handle_pong(packet['node_id'], packet['count']) 188 | elif packet['type'] == 'publish': 189 | self._handle_publish(packet, protocol) 190 | else: 191 | if self.tcp_host.is_for_me(packet['name'], packet['version']): 192 | func = getattr(self, '_' + packet['type'] + '_receiver') 193 | func(packet, protocol) 194 | else: 195 | self._logger.warning('wrongly routed packet: ', packet) 196 | 197 | def _request_receiver(self, packet, protocol): 198 | api_fn = None 199 | try: 200 | api_fn = getattr(self.tcp_host, packet['endpoint']) 201 | except AttributeError: 202 | pass 203 | if not api_fn: 204 | for view in self.tcp_host.tcp_views: 205 | _api_fn = None 206 | try: 207 | _api_fn = getattr(view, packet['endpoint']) 208 | except AttributeError: 209 | pass 210 | if _api_fn: 211 | api_fn = _api_fn 212 | break 213 | if api_fn.is_api: 214 | from_node_id = packet['from'] 215 | entity = packet['entity'] 216 | future = asyncio.ensure_future(api_fn(from_id=from_node_id, entity=entity, **packet['payload'])) 217 | 218 | def send_result(f): 219 | result_packet = f.result() 220 | protocol.send(result_packet) 221 | 222 | future.add_done_callback(send_result) 223 | else: 224 | print('no api 
found for packet: ', packet) 225 | 226 | def _handle_publish(self, packet, protocol): 227 | service, version, endpoint, payload, publish_id = (packet['name'], packet['version'], packet['endpoint'], 228 | packet['payload'], packet['publish_id']) 229 | for client in self._service_clients: 230 | if client.name == service and client.version == version: 231 | fun = getattr(client, endpoint) 232 | asyncio.ensure_future(fun(payload)) 233 | protocol.send(MessagePacket.ack(publish_id)) 234 | 235 | def handle_connected(self): 236 | if self.tcp_host: 237 | self.tcp_host.initiate() 238 | if self.http_host: 239 | self.http_host.initiate() 240 | 241 | # class PubSubBus: 242 | # PUBSUB_DELAY = 5 243 | # 244 | # def __init__(self, pubsub_host, pubsub_port, registry_client, ssl_context=None): 245 | # self._host = pubsub_host 246 | # self._port = pubsub_port 247 | # self._pubsub_handler = None 248 | # self._registry_client = registry_client 249 | # self._clients = None 250 | # self._pending_publishes = {} 251 | # self._ssl_context = ssl_context 252 | # 253 | # def create_pubsub_handler(self): 254 | # self._pubsub_handler = PubSub(self._host, self._port) 255 | # yield from self._pubsub_handler.connect() 256 | # 257 | # def register_for_subscription(self, host, port, node_id, clients): 258 | # self._clients = clients 259 | # subscription_list = [] 260 | # xsubscription_list = [] 261 | # for client in clients: 262 | # if isinstance(client, TCPServiceClient): 263 | # for each in dir(client): 264 | # fn = getattr(client, each) 265 | # if callable(fn) and getattr(fn, 'is_subscribe', False): 266 | # subscription_list.append(self._get_pubsub_key(client.name, client.version, fn.__name__)) 267 | # elif callable(fn) and getattr(fn, 'is_xsubscribe', False): 268 | # xsubscription_list.append((client.name, client.version, fn.__name__, getattr(fn, 'strategy'))) 269 | # self._registry_client.x_subscribe(host, port, node_id, xsubscription_list) 270 | # yield from self._pubsub_handler.subscribe(subscription_list, handler=self.subscription_handler) 271 | # 272 | # def publish(self, service, version, endpoint, payload): 273 | # endpoint_key = self._get_pubsub_key(service, version, endpoint) 274 | # asyncio.ensure_future(self._pubsub_handler.publish(endpoint_key, json.dumps(payload, cls=TrellioEncoder))) 275 | # asyncio.ensure_future(self.xpublish(service, version, endpoint, payload)) 276 | # 277 | # def xpublish(self, service, version, endpoint, payload): 278 | # subscribers = yield from self._registry_client.get_subscribers(service, version, endpoint) 279 | # strategies = defaultdict(list) 280 | # for subscriber in subscribers: 281 | # strategies[(subscriber['name'], subscriber['version'])].append( 282 | # (subscriber['host'], subscriber['port'], subscriber['node_id'], subscriber['strategy'])) 283 | # for key, value in strategies.items(): 284 | # publish_id = str(uuid.uuid4()) 285 | # future = asyncio.ensure_future( 286 | # self._connect_and_publish(publish_id, service, version, endpoint, value, payload)) 287 | # self._pending_publishes[publish_id] = future 288 | # 289 | # def receive(self, packet, transport, protocol): 290 | # if packet['type'] == 'ack': 291 | # future = self._pending_publishes.pop(packet['request_id'], None) 292 | # if future: 293 | # future.cancel() 294 | # transport.close() 295 | # 296 | # def subscription_handler(self, endpoint, payload): 297 | # service, version, endpoint = endpoint.split('/') 298 | # client = [sc for sc in self._clients if (sc.name == service and sc.version == version)][0] 299 | # func 
= getattr(client, endpoint) 300 | # asyncio.ensure_future(func(**json.loads(payload))) 301 | # 302 | # @staticmethod 303 | # def _get_pubsub_key(service, version, endpoint): 304 | # return '/'.join((service, str(version), endpoint)) 305 | # 306 | # def _connect_and_publish(self, publish_id, service, version, endpoint, subscribers, payload): 307 | # if subscribers[0][3] == 'LEADER': 308 | # host, port = subscribers[0][0], subscribers[0][1] 309 | # else: 310 | # random_metadata = random.choice(subscribers) 311 | # host, port = random_metadata[0], random_metadata[1] 312 | # transport, protocol = yield from asyncio.get_event_loop().create_connection( 313 | # partial(get_trellio_protocol, self), host, port) 314 | # packet = MessagePacket.publish(publish_id, service, version, endpoint, payload) 315 | # protocol.send(packet) 316 | # yield from asyncio.sleep(self.PUBSUB_DELAY) 317 | -------------------------------------------------------------------------------- /trellio/registry.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | import signal 5 | import ssl 6 | import time 7 | from collections import defaultdict, namedtuple 8 | from functools import partial 9 | 10 | from again.utils import natural_sort 11 | from aiohttp import web 12 | 13 | from .packet import ControlPacket 14 | from .pinger import TCPPinger 15 | from .protocol_factory import get_trellio_protocol 16 | from .utils.log import setup_logging 17 | import os 18 | 19 | Service = namedtuple('Service', ['name', 'version', 'dependencies', 'host', 'port', 'node_id', 'type']) 20 | 21 | 22 | def tree(): 23 | return defaultdict(tree) 24 | 25 | 26 | def json_file_to_dict(_file: str) -> dict: 27 | config = None 28 | with open(_file) as config_file: 29 | config = json.load(config_file) 30 | 31 | return config 32 | 33 | 34 | class Repository: 35 | def __init__(self): 36 | self._registered_services = defaultdict(lambda: defaultdict(list)) 37 | self._pending_services = defaultdict(list) 38 | self._service_dependencies = {} 39 | self._subscribe_list = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) 40 | self._uptimes = tree() 41 | self.logger = logging.getLogger(__name__) 42 | 43 | def register_service(self, service: Service): 44 | service_name = self._get_full_service_name(service.name, service.version) 45 | service_entry = (service.host, service.port, service.node_id, service.type) 46 | self._registered_services[service.name][service.version].append(service_entry) 47 | # in future there can be multiple nodes for same service, for load balancing purposes 48 | self._pending_services[service_name].append(service.node_id) 49 | self._uptimes[service_name][service.host] = { 50 | 'uptime': int(time.time()), 51 | 'node_id': service.node_id 52 | } 53 | 54 | if len(service.dependencies): 55 | if not self._service_dependencies.get(service.name): 56 | self._service_dependencies[service_name] = service.dependencies 57 | 58 | def is_pending(self, name, version): 59 | return self._get_full_service_name(name, version) in self._pending_services 60 | 61 | def add_pending_service(self, name, version, node_id): 62 | self._pending_services[self._get_full_service_name(name, version)].append(node_id) 63 | 64 | def get_pending_services(self): 65 | return [self._split_key(k) for k in self._pending_services.keys()] 66 | 67 | def get_pending_instances(self, name, version): 68 | return self._pending_services.get(self._get_full_service_name(name, version), []) 69 | 70 | def 
remove_pending_instance(self, name, version, node_id): 71 | self.get_pending_instances(name, version).remove(node_id) 72 | if not len(self.get_pending_instances(name, version)): 73 | self._pending_services.pop(self._get_full_service_name(name, version)) 74 | 75 | def get_instances(self, name, version): 76 | return self._registered_services[name][version] 77 | 78 | def get_versioned_instances(self, name, version): 79 | version = self._get_non_breaking_version(version, list(self._registered_services[name].keys())) 80 | return self._registered_services[name][version] 81 | 82 | def get_consumers(self, name, service_version): 83 | consumers = set() 84 | for _name, dependencies in self._service_dependencies.items(): 85 | for dependency in dependencies: 86 | if dependency['name'] == name and dependency['version'] == service_version: 87 | consumers.add(self._split_key(_name)) 88 | return consumers 89 | 90 | def get_dependencies(self, name, version): 91 | return self._service_dependencies.get(self._get_full_service_name(name, version), []) 92 | 93 | def get_node(self, node_id): 94 | for name, versions in self._registered_services.items(): 95 | for version, instances in versions.items(): 96 | for host, port, node, service_type in instances: 97 | if node_id == node: 98 | return Service(name, version, [], host, port, node, service_type) 99 | return None 100 | 101 | def remove_node(self, node_id): 102 | thehost = None 103 | for name, versions in self._registered_services.items(): 104 | for version, instances in versions.items(): 105 | for instance in instances: 106 | host, port, node, service_type = instance 107 | if node_id == node: 108 | thehost = host 109 | instances.remove(instance) 110 | break 111 | for name, nodes in self._uptimes.items(): 112 | for host, uptimes in nodes.items(): 113 | if host == thehost and uptimes['node_id'] == node_id: 114 | uptimes['downtime'] = int(time.time()) 115 | self.log_uptimes() 116 | return None 117 | 118 | def get_uptimes(self): 119 | return self._uptimes 120 | 121 | def log_uptimes(self): 122 | for name, nodes in self._uptimes.items(): 123 | for host, d in nodes.items(): 124 | now = int(time.time()) 125 | live = d.get('downtime', 0) < d['uptime'] 126 | uptime = now - d['uptime'] if live else 0 127 | logd = {'service_name': name.split('/')[0], 'hostname': host, 'status': live, 128 | 'uptime': int(uptime)} 129 | logging.getLogger('stats').info(logd) 130 | 131 | def xsubscribe(self, name, version, host, port, node_id, endpoints): 132 | entry = (name, version, host, port, node_id) 133 | for endpoint in endpoints: 134 | self._subscribe_list[endpoint['name']][endpoint['version']][endpoint['endpoint']].append( 135 | entry + (endpoint['strategy'],)) 136 | 137 | def get_subscribers(self, name, version, endpoint): 138 | return self._subscribe_list[name][version][endpoint] 139 | 140 | def _get_non_breaking_version(self, version, versions): 141 | if version in versions: 142 | return version 143 | versions.sort(key=natural_sort, reverse=True) 144 | for v in versions: 145 | if self._is_non_breaking(v, version): 146 | return v 147 | return version 148 | 149 | @staticmethod 150 | def _is_non_breaking(v, version): 151 | return version.split('.')[0] == v.split('.')[0] 152 | 153 | @staticmethod 154 | def _get_full_service_name(name: str, version): 155 | return '{}/{}'.format(name, version) 156 | 157 | @staticmethod 158 | def _split_key(key: str): 159 | return tuple(key.split('/')) 160 | 161 | 162 | class Registry: 163 | def __init__(self, ip, port, repository: Repository): 164 | 
self._ip = ip 165 | self._port = port 166 | self._loop = asyncio.get_event_loop() 167 | self._client_protocols = {} 168 | self._service_protocols = {} 169 | self._repository = repository 170 | self._tcp_pingers = {} 171 | self._http_pingers = {} 172 | self.logger = logging.getLogger() 173 | try: 174 | config = json_file_to_dict('./config.json') 175 | self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) 176 | self._ssl_context.load_cert_chain(config['SSL_CERTIFICATE'], config['SSL_KEY']) 177 | except: 178 | self._ssl_context = None 179 | 180 | def _create_http_app(self): 181 | app = web.Application() 182 | registry_dump_handle.registry = self 183 | app.router.add_get('/registry/', registry_dump_handle) 184 | handler = app.make_handler(access_log=self.logger) 185 | task = asyncio.get_event_loop().create_server(handler, self._ip, os.environ.get('TRELLIO_HTTP_PORT', 4501)) 186 | http_server = asyncio.get_event_loop().run_until_complete(task) 187 | return http_server 188 | 189 | def start(self): 190 | setup_logging("registry") 191 | self._loop.add_signal_handler(getattr(signal, 'SIGINT'), partial(self._stop, 'SIGINT')) 192 | self._loop.add_signal_handler(getattr(signal, 'SIGTERM'), partial(self._stop, 'SIGTERM')) 193 | registry_coroutine = self._loop.create_server( 194 | partial(get_trellio_protocol, self), self._ip, self._port, ssl=self._ssl_context) 195 | server = self._loop.run_until_complete(registry_coroutine) 196 | http_server = self._create_http_app() 197 | try: 198 | self._loop.run_forever() 199 | except Exception as e: 200 | print(e) 201 | finally: 202 | server.close() 203 | http_server.close() 204 | self._loop.run_until_complete(server.wait_closed()) 205 | self._loop.close() 206 | 207 | def _stop(self, signame: str): 208 | print('\ngot signal {} - exiting'.format(signame)) 209 | self._loop.stop() 210 | 211 | def receive(self, packet: dict, protocol, transport): 212 | request_type = packet['type'] 213 | if request_type in ['register', 'get_instances', 'xsubscribe', 'get_subscribers']: 214 | for_log = {} 215 | params = packet['params'] 216 | for_log["caller_name"] = params['name'] + '/' + params['version'] 217 | for_log["caller_address"] = transport.get_extra_info("peername")[0] 218 | for_log["request_type"] = request_type 219 | self.logger.debug(for_log) 220 | if request_type == 'register': 221 | packet['params']['host'] = transport.get_extra_info("peername")[0] 222 | self.register_service(packet, protocol) 223 | elif request_type == 'get_instances': 224 | self.get_service_instances(packet, protocol) 225 | elif request_type == 'xsubscribe': 226 | self._xsubscribe(packet) 227 | elif request_type == 'get_subscribers': 228 | self.get_subscribers(packet, protocol) 229 | elif request_type == 'pong': 230 | self._ping(packet) 231 | elif request_type == 'ping': 232 | self._handle_ping(packet, protocol) 233 | elif request_type == 'uptime_report': 234 | self._get_uptime_report(packet, protocol) 235 | 236 | def deregister_service(self, host, port, node_id): 237 | service = self._repository.get_node(node_id) 238 | self._tcp_pingers.pop(node_id, None) 239 | self._http_pingers.pop((host, port), None) 240 | if service: 241 | for_log = {"caller_name": service.name + '/' + service.version, "caller_address": service.host, 242 | "request_type": 'deregister'} 243 | self.logger.debug(for_log) 244 | self._repository.remove_node(node_id) 245 | if service is not None: 246 | self._service_protocols.pop(node_id, None) 247 | self._client_protocols.pop(node_id, None) 248 | self._notify_consumers(service.name, 
service.version, node_id) 249 | if not len(self._repository.get_instances(service.name, service.version)): 250 | consumers = self._repository.get_consumers(service.name, service.version) 251 | for consumer_name, consumer_version in consumers: 252 | for _, _, node_id, _ in self._repository.get_instances(consumer_name, consumer_version): 253 | self._repository.add_pending_service(consumer_name, consumer_version, node_id) 254 | 255 | def register_service(self, packet: dict, registry_protocol): 256 | params = packet['params'] 257 | service = Service(params['name'], params['version'], params['dependencies'], params['host'], params['port'], 258 | params['node_id'], params['type']) 259 | self._repository.register_service(service) 260 | self._client_protocols[params['node_id']] = registry_protocol 261 | if params['node_id'] not in self._service_protocols.keys(): 262 | self._connect_to_service(params['host'], params['port'], params['node_id'], params['type']) 263 | self._handle_pending_registrations() 264 | self._inform_consumers(service) 265 | 266 | def _inform_consumers(self, service: Service): 267 | consumers = self._repository.get_consumers(service.name, service.version) 268 | for service_name, service_version in consumers: 269 | if not self._repository.is_pending(service_name, service_version): 270 | instances = self._repository.get_instances(service_name, service_version) 271 | for host, port, node, type in instances: 272 | protocol = self._client_protocols[node] 273 | protocol.send(ControlPacket.new_instance( 274 | service.name, service.version, service.host, service.port, service.node_id, service.type)) 275 | 276 | def _send_activated_packet(self, name, version, node): 277 | protocol = self._client_protocols.get(node, None) 278 | if protocol: 279 | packet = self._make_activated_packet(name, version) 280 | protocol.send(packet) 281 | 282 | def _handle_pending_registrations(self): 283 | for name, version in self._repository.get_pending_services(): 284 | dependencies = self._repository.get_dependencies(name, version) # list 285 | should_activate = True 286 | for dependency in dependencies: 287 | instances = self._repository.get_versioned_instances(dependency['name'], dependency['version']) # list 288 | tcp_instances = [instance for instance in instances if instance[3] == 'tcp'] 289 | if not len( 290 | tcp_instances): # means the dependency doesn't have an activated tcp service, so registration 291 | # pending 292 | should_activate = False 293 | break 294 | for node in self._repository.get_pending_instances(name, version): # node is node id 295 | if should_activate: 296 | self._send_activated_packet(name, version, node) 297 | self._repository.remove_pending_instance(name, version, node) 298 | self.logger.info('%s activated', (name, version)) 299 | else: 300 | self.logger.info('%s can\'t register because it depends on %s', (name, version), dependency) 301 | 302 | def _make_activated_packet(self, name, version): 303 | dependencies = self._repository.get_dependencies(name, version) 304 | instances = { 305 | (dependency['name'], dependency['version']): self._repository.get_versioned_instances(dependency['name'], 306 | dependency['version']) 307 | for dependency in dependencies} 308 | return ControlPacket.activated(instances) 309 | 310 | def _connect_to_service(self, host, port, node_id, service_type): 311 | if service_type == 'tcp': 312 | if node_id not in self._service_protocols: 313 | coroutine = self._loop.create_connection(partial(get_trellio_protocol, self), host, port) 314 | future = 
asyncio.ensure_future(coroutine) 315 | future.add_done_callback(partial(self._handle_service_connection, node_id, host, port)) 316 | elif service_type == 'http': 317 | pass 318 | # if not (host, port) in self._http_pingers: 319 | # pinger = HTTPPinger(host, port, node_id, self) 320 | # self._http_pingers[(host, port)] = pinger 321 | # pinger.ping() 322 | 323 | def _handle_service_connection(self, node_id, host, port, future): 324 | transport, protocol = future.result() 325 | self._service_protocols[node_id] = protocol 326 | pinger = TCPPinger(host, port, node_id, protocol, self) 327 | self._tcp_pingers[node_id] = pinger 328 | pinger.ping() 329 | 330 | def _notify_consumers(self, name, version, node_id): 331 | packet = ControlPacket.deregister(name, version, node_id) 332 | for consumer_name, consumer_version in self._repository.get_consumers(name, version): 333 | for host, port, node, service_type in self._repository.get_instances(consumer_name, consumer_version): 334 | protocol = self._client_protocols[node] 335 | protocol.send(packet) 336 | 337 | def get_service_instances(self, packet, registry_protocol): 338 | params = packet['params'] 339 | name, version = params['name'].lower(), params['version'] 340 | instances = self._repository.get_instances(name, version) 341 | instance_packet = ControlPacket.send_instances(name, version, packet['request_id'], instances) 342 | registry_protocol.send(instance_packet) 343 | 344 | def get_subscribers(self, packet, protocol): 345 | params = packet['params'] 346 | request_id = packet['request_id'] 347 | name, version, endpoint = params['name'].lower(), params['version'], params['endpoint'] 348 | subscribers = self._repository.get_subscribers(name, version, endpoint) 349 | packet = ControlPacket.subscribers(name, version, endpoint, request_id, subscribers) 350 | protocol.send(packet) 351 | 352 | def on_timeout(self, host, port, node_id): 353 | service = self._repository.get_node(node_id) 354 | self.logger.debug('%s timed out', service) 355 | self.deregister_service(host, port, node_id) 356 | 357 | def _ping(self, packet): 358 | pinger = self._tcp_pingers[packet['node_id']] 359 | pinger.pong_received() 360 | 361 | def _pong(self, packet, protocol): 362 | protocol.send(ControlPacket.pong(packet['node_id'])) 363 | 364 | def _xsubscribe(self, packet): 365 | params = packet['params'] 366 | name, version, host, port, node_id = ( 367 | params['name'], params['version'], params['host'], params['port'], params['node_id']) 368 | endpoints = params['events'] 369 | self._repository.xsubscribe(name, version, host, port, node_id, endpoints) 370 | 371 | def _get_uptime_report(self, packet, protocol): 372 | uptimes = self._repository.get_uptimes() 373 | protocol.send(ControlPacket.uptime(uptimes)) 374 | 375 | def periodic_uptime_logger(self): 376 | self._repository.log_uptimes() 377 | asyncio.get_event_loop().call_later(300, self.periodic_uptime_logger) 378 | 379 | def _handle_ping(self, packet, protocol): 380 | """ Responds to pings from registry_client only if the node_ids present in the ping payload are registered 381 | 382 | :param packet: The 'ping' packet received 383 | :param protocol: The protocol on which the pong should be sent 384 | """ 385 | if 'payload' in packet: 386 | is_valid_node = True 387 | node_ids = list(packet['payload'].values()) 388 | for node_id in node_ids: 389 | if self._repository.get_node(node_id) is None: 390 | is_valid_node = False 391 | break 392 | if is_valid_node: 393 | self._pong(packet, protocol) 394 | else: 395 | self._pong(packet, 
protocol) 396 | 397 | 398 | async def registry_dump_handle(request): 399 | ''' 400 | Read-only dump of the current registry state. 401 | :param request: aiohttp request object 402 | :return: JSON response with registered services, uptimes and service dependencies 403 | ''' 404 | registry = registry_dump_handle.registry 405 | response_dict = {} 406 | repo = registry._repository 407 | response_dict['registered_services'] = repo._registered_services 408 | response_dict['uptimes'] = repo._uptimes 409 | response_dict['service_dependencies'] = repo._service_dependencies 410 | return web.Response(status=200, content_type='application/json', body=json.dumps(response_dict).encode()) 411 | 412 | 413 | if __name__ == '__main__': 414 | from setproctitle import setproctitle 415 | 416 | setproctitle("trellio-registry") 417 | REGISTRY_HOST = None 418 | REGISTRY_PORT = 4500 419 | registry = Registry(REGISTRY_HOST, REGISTRY_PORT, Repository()) 420 | registry.periodic_uptime_logger() 421 | registry.start() 422 | -------------------------------------------------------------------------------- /trellio/services.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | import socket 5 | import time 6 | from asyncio import iscoroutine, coroutine, wait_for, TimeoutError, Future, get_event_loop, async 7 | from functools import wraps, partial 8 | 9 | import setproctitle 10 | from again.utils import unique_hex 11 | from aiohttp.web import Response 12 | from retrial.retrial.retry import retry 13 | 14 | from trellio.packet import ControlPacket 15 | from .exceptions import RequestException, ClientException, TrellioServiceException 16 | from .packet import MessagePacket 17 | from .utils.helpers import Singleton  # we need non singleton subclasses 18 | from .utils.helpers import default_preflight_response 19 | from .utils.ordered_class_member import OrderedClassMembers 20 | from .utils.stats import Aggregator, Stats 21 | from .views import HTTPView 22 | 23 | API_TIMEOUT = 60 * 10 24 | 25 | 26 | def publish(func): 27 | """ 28 | publish the return value of this function as a message from this endpoint 29 | """ 30 | 31 | @wraps(func) 32 | def wrapper(self, *args, **kwargs):  # outgoing 33 | payload = func(self, *args, **kwargs) 34 | payload.pop('self', None) 35 | self._publish(func.__name__, payload) 36 | return None 37 | 38 | wrapper.is_publish = True 39 | 40 | return wrapper 41 | 42 | 43 | def subscribe(func): 44 | """ 45 | use to listen for publications from a specific endpoint of a service, 46 | this method receives a publication from a remote service 47 | """ 48 | wrapper = _get_subscribe_decorator(func) 49 | wrapper.is_subscribe = True 50 | return wrapper 51 | 52 | 53 | def xsubscribe(func=None, strategy='DESIGNATION'): 54 | """ 55 | Used to listen for publications from a specific endpoint of a service. If multiple instances 56 | subscribe to an endpoint, only one of them receives the event. And the publish event is retried till 57 | an acknowledgment is received from the other end. 58 | :param func: the function to decorate with. The name of the function is the event subscribers will subscribe to. 59 | :param strategy: The strategy of delivery. Can be 'RANDOM' or 'LEADER'. If 'RANDOM', then the event will be randomly 60 | passed to any one of the interested parties. If 'LEADER' then it is passed to the first instance alive 61 | which registered for that endpoint. 
62 | """ 63 | if func is None: 64 | return partial(xsubscribe, strategy=strategy) 65 | else: 66 | wrapper = _get_subscribe_decorator(func) 67 | wrapper.is_xsubscribe = True 68 | wrapper.strategy = strategy 69 | return wrapper 70 | 71 | 72 | def _get_subscribe_decorator(func): 73 | @wraps(func) 74 | def wrapper(*args, **kwargs): 75 | coroutine_func = func 76 | if not iscoroutine(func): 77 | coroutine_func = coroutine(func) 78 | return (async(coroutine_func(*args, **kwargs))) 79 | 80 | return wrapper 81 | 82 | 83 | def request(func=None, timeout=600): 84 | """ 85 | use to request an api call from a specific endpoint 86 | """ 87 | if func is None: 88 | return partial(request, timeout=timeout) 89 | 90 | @wraps(func) 91 | def wrapper(self, *args, **kwargs): 92 | params = func(self, *args, **kwargs) 93 | self = params.pop('self', None) 94 | entity = params.pop('entity', None) 95 | app_name = params.pop('app_name', None) 96 | request_id = unique_hex() 97 | params['request_id'] = request_id 98 | future = self._send_request(app_name, endpoint=func.__name__, entity=entity, params=params, timeout=timeout) 99 | return future 100 | 101 | wrapper.is_request = True 102 | return wrapper 103 | 104 | 105 | def api(func=None, timeout=API_TIMEOUT): # incoming 106 | """ 107 | provide a request/response api 108 | receives any requests here and return value is the response 109 | all functions must have the following signature 110 | - request_id 111 | - entity (partition/routing key) 112 | followed by kwargs 113 | """ 114 | if func is None: 115 | return partial(api, timeout=timeout) 116 | else: 117 | wrapper = _get_api_decorator(func=func, timeout=timeout) 118 | return wrapper 119 | 120 | 121 | def apideprecated(func=None, replacement_api=None): 122 | if func is None: 123 | return partial(apideprecated, replacement_api=replacement_api) 124 | else: 125 | wrapper = _get_api_decorator(func=func, old_api=func.__name__, replacement_api=replacement_api) 126 | return wrapper 127 | 128 | 129 | def _get_api_decorator(func=None, old_api=None, replacement_api=None, timeout=API_TIMEOUT): 130 | @coroutine 131 | @wraps(func) 132 | def wrapper(*args, **kwargs): 133 | _logger = logging.getLogger(__name__) 134 | start_time = int(time.time() * 1000) 135 | self = args[0] 136 | rid = kwargs.pop('request_id') 137 | entity = kwargs.pop('entity') 138 | from_id = kwargs.pop('from_id') 139 | wrapped_func = func 140 | result = None 141 | error = None 142 | failed = False 143 | 144 | status = 'successful' 145 | success = True 146 | if not iscoroutine(func): 147 | wrapped_func = coroutine(func) 148 | 149 | Stats.tcp_stats['total_requests'] += 1 150 | 151 | try: 152 | result = yield from wait_for(wrapped_func(self, **kwargs), timeout) 153 | 154 | except TimeoutError as e: 155 | Stats.tcp_stats['timedout'] += 1 156 | error = str(e) 157 | status = 'timeout' 158 | success = False 159 | failed = True 160 | _logger.exception("TCP request had a timeout for method %s", func.__name__) 161 | 162 | except TrellioServiceException as e: 163 | Stats.tcp_stats['total_responses'] += 1 164 | error = str(e) 165 | status = 'handled_error' 166 | _logger.exception('Handled exception %s for method %s ', e.__class__.__name__, func.__name__) 167 | 168 | except Exception as e: 169 | Stats.tcp_stats['total_errors'] += 1 170 | error = str(e) 171 | status = 'unhandled_error' 172 | success = False 173 | failed = True 174 | _logger.exception('Unhandled exception %s for method %s ', e.__class__.__name__, func.__name__) 175 | else: 176 | Stats.tcp_stats['total_responses'] 
+= 1 177 | end_time = int(time.time() * 1000) 178 | 179 | hostname = socket.gethostname() 180 | service_name = '_'.join(setproctitle.getproctitle().split('_')[1:-1]) 181 | 182 | logd = { 183 | 'endpoint': func.__name__, 184 | 'time_taken': end_time - start_time, 185 | 'hostname': hostname, 'service_name': service_name 186 | } 187 | _logger.debug('Time taken for %s is %d milliseconds', func.__name__, end_time - start_time) 188 | 189 | # call to update aggregator, designed to replace the stats module. 190 | Aggregator.update_stats(endpoint=func.__name__, status=status, success=success, 191 | server_type='tcp', time_taken=end_time - start_time) 192 | 193 | if not old_api: 194 | return self._make_response_packet(request_id=rid, from_id=from_id, entity=entity, result=result, 195 | error=error, failed=failed) 196 | else: 197 | return self._make_response_packet(request_id=rid, from_id=from_id, entity=entity, result=result, 198 | error=error, failed=failed, old_api=old_api, 199 | replacement_api=replacement_api) 200 | 201 | wrapper.is_api = True 202 | return wrapper 203 | 204 | 205 | def make_request(func, self, args, kwargs, method): 206 | params = func(self, *args, **kwargs) 207 | entity = params.pop('entity', None) 208 | app_name = params.pop('app_name', None) 209 | self = params.pop('self') 210 | response = yield from self._send_http_request(app_name, method, entity, params) 211 | return response 212 | 213 | 214 | def _enable_http_middleware(func): # pre and post http, processing 215 | @wraps(func) 216 | async def f(self, *args, **kwargs): 217 | if hasattr(self, 'middlewares'): 218 | for i in self.middlewares: 219 | if hasattr(i, 'pre_request'): 220 | pre_request = getattr(i, 'pre_request') 221 | if callable(pre_request): 222 | try: 223 | res = await pre_request(self, *args, **kwargs) # passing service as first argument 224 | if res: 225 | return res 226 | except Exception as e: 227 | return Response(status=400, content_type='application/json', 228 | body=json.dumps( 229 | {'error': str(e), 'sector': getattr(i, 'middleware_info')}).encode()) 230 | _func = coroutine(func) # func is a generator object 231 | result = await _func(self, *args, **kwargs) 232 | if hasattr(self, 'middlewares'): 233 | for i in self.middlewares: 234 | if hasattr(i, 'post_request'): 235 | post_request = getattr(i, 'post_request') 236 | if callable(post_request): 237 | try: 238 | res = await post_request(self, result, *args, **kwargs) 239 | if res: 240 | return res 241 | except Exception as e: 242 | return Response(status=400, content_type='application/json', 243 | body=json.dumps( 244 | {'error': str(e), 'sector': getattr(i, 'middleware_info')}).encode()) 245 | 246 | return result 247 | 248 | return f 249 | 250 | 251 | def get_decorated_fun(method, path, required_params, timeout): 252 | def decorator(func): 253 | @wraps(func) 254 | @_enable_http_middleware 255 | def f(self, *args, **kwargs): 256 | if isinstance(self, HTTPServiceClient): 257 | return (yield from make_request(func, self, args, kwargs, method)) 258 | elif isinstance(self, HTTPService) or isinstance(self, HTTPView): 259 | Stats.http_stats['total_requests'] += 1 260 | if required_params is not None: 261 | req = args[0] 262 | query_params = req.GET 263 | params = required_params 264 | if not isinstance(required_params, list): 265 | params = [required_params] 266 | missing_params = list(filter(lambda x: x not in query_params, params)) 267 | if len(missing_params) > 0: 268 | res_d = {'error': 'Required params {} not found'.format(','.join(missing_params))} 269 | 
Stats.http_stats['total_responses'] += 1 270 | Aggregator.update_stats(endpoint=func.__name__, status=400, success=False, 271 | server_type='http', time_taken=0) 272 | return Response(status=400, content_type='application/json', body=json.dumps(res_d).encode()) 273 | 274 | t1 = time.time() 275 | wrapped_func = func 276 | success = True 277 | _logger = logging.getLogger() 278 | 279 | if not iscoroutine(func): 280 | wrapped_func = coroutine(func) 281 | try: 282 | result = yield from wait_for(wrapped_func(self, *args, **kwargs), timeout) 283 | 284 | except TimeoutError as e: 285 | Stats.http_stats['timedout'] += 1 286 | status = 'timeout' 287 | success = False 288 | _logger.exception("HTTP request had a timeout for method %s", func.__name__) 289 | return Response(status=408, body='Request Timeout'.encode()) 290 | 291 | except TrellioServiceException as e: 292 | Stats.http_stats['total_responses'] += 1 293 | status = 'handled_exception' 294 | _logger.exception('Handled exception %s for method %s ', e.__class__.__name__, func.__name__) 295 | raise e 296 | 297 | except Exception as e: 298 | Stats.http_stats['total_errors'] += 1 299 | status = 'unhandled_exception' 300 | success = False 301 | _logger.exception('Unhandled exception %s for method %s ', e.__class__.__name__, func.__name__) 302 | raise e 303 | 304 | else: 305 | t2 = time.time() 306 | hostname = socket.gethostname() 307 | service_name = '_'.join(setproctitle.getproctitle().split('_')[1:-1]) 308 | status = result.status 309 | 310 | logd = { 311 | 'status': result.status, 312 | 'time_taken': int((t2 - t1) * 1000), 313 | 'type': 'http', 314 | 'hostname': hostname, 'service_name': service_name 315 | } 316 | logging.getLogger('stats').debug(logd) 317 | Stats.http_stats['total_responses'] += 1 318 | return result 319 | 320 | finally: 321 | t2 = time.time() 322 | Aggregator.update_stats(endpoint=func.__name__, status=status, success=success, 323 | server_type='http', time_taken=int((t2 - t1) * 1000)) 324 | 325 | f.is_http_method = True 326 | f.method = method 327 | f.paths = path 328 | if not isinstance(path, list): 329 | f.paths = [path] 330 | 331 | return f 332 | 333 | return decorator 334 | 335 | 336 | def get(path=None, required_params=None, timeout=API_TIMEOUT): 337 | return get_decorated_fun('get', path, required_params, timeout) 338 | 339 | 340 | def head(path=None, required_params=None, timeout=API_TIMEOUT): 341 | return get_decorated_fun('head', path, required_params, timeout) 342 | 343 | 344 | def options(path=None, required_params=None, timeout=API_TIMEOUT): 345 | return get_decorated_fun('options', path, required_params, timeout) 346 | 347 | 348 | def patch(path=None, required_params=None, timeout=API_TIMEOUT): 349 | return get_decorated_fun('patch', path, required_params, timeout) 350 | 351 | 352 | def post(path=None, required_params=None, timeout=API_TIMEOUT): 353 | return get_decorated_fun('post', path, required_params, timeout) 354 | 355 | 356 | def put(path=None, required_params=None, timeout=API_TIMEOUT): 357 | return get_decorated_fun('put', path, required_params, timeout) 358 | 359 | 360 | def trace(path=None, required_params=None, timeout=API_TIMEOUT): 361 | return get_decorated_fun('trace', path, required_params, timeout) 362 | 363 | 364 | def delete(path=None, required_params=None, timeout=API_TIMEOUT): 365 | return get_decorated_fun('delete', path, required_params, timeout) 366 | 367 | 368 | class _Service: 369 | _PUB_PKT_STR = 'publish' 370 | _REQ_PKT_STR = 'request' 371 | _RES_PKT_STR = 'response' 372 | 373 | def 
__init__(self, service_name, service_version): 374 | self._service_name = service_name.lower() 375 | self._service_version = str(service_version) 376 | self._tcp_bus = None 377 | self._pubsub_bus = None 378 | self._http_bus = None 379 | 380 | @property 381 | def name(self): 382 | return self._service_name 383 | 384 | @property 385 | def version(self): 386 | return self._service_version 387 | 388 | @property 389 | def properties(self): 390 | return self.name, self.version 391 | 392 | @staticmethod 393 | def time_future(future: Future, timeout: int): 394 | def timer_callback(f): 395 | if not f.done() and not f.cancelled(): 396 | f.set_exception(TimeoutError()) 397 | 398 | get_event_loop().call_later(timeout, timer_callback, future) 399 | 400 | 401 | class TCPServiceClient(Singleton, _Service): 402 | def __init__(self, service_name, service_version, ssl_context=None): 403 | if not self.has_inited(): # to maintain singleton behaviour 404 | super(TCPServiceClient, self).__init__(service_name, service_version) 405 | self._pending_requests = {} 406 | self.tcp_bus = None 407 | self._ssl_context = ssl_context 408 | self.init_done() 409 | 410 | @property 411 | def ssl_context(self): 412 | return self._ssl_context 413 | 414 | def _send_request(self, app_name, endpoint, entity, params, timeout): 415 | packet = MessagePacket.request(self.name, self.version, app_name, _Service._REQ_PKT_STR, endpoint, params, 416 | entity) 417 | future = Future() 418 | request_id = params['request_id'] 419 | self._pending_requests[request_id] = future 420 | try: 421 | self.tcp_bus.send(packet) 422 | except ClientException: 423 | if not future.done() and not future.cancelled(): 424 | error = 'Client not found' 425 | exception = ClientException(error) 426 | exception.error = error 427 | future.set_exception(exception) 428 | _Service.time_future(future, timeout) 429 | return future 430 | 431 | def receive(self, packet: dict, protocol, transport): 432 | if packet['type'] == 'ping': 433 | pass 434 | else: 435 | self._process_response(packet) 436 | 437 | def process_packet(self, packet): 438 | if packet['type'] == _Service._RES_PKT_STR: 439 | self._process_response(packet) 440 | elif packet['type'] == _Service._PUB_PKT_STR: 441 | self._process_publication(packet) 442 | else: 443 | print('Invalid packet', packet) 444 | 445 | def _process_response(self, packet): 446 | payload = packet['payload'] 447 | request_id = payload['request_id'] 448 | has_result = 'result' in payload 449 | has_error = 'error' in payload 450 | if 'old_api' in payload: 451 | warning = 'Deprecated API: ' + payload['old_api'] 452 | if 'replacement_api' in payload: 453 | warning += ', New API: ' + payload['replacement_api'] 454 | logging.getLogger().warning(warning) 455 | future = self._pending_requests.pop(request_id) 456 | if has_result: 457 | if not future.done() and not future.cancelled(): 458 | future.set_result(payload['result']) 459 | elif has_error: 460 | if payload.get('failed', False): 461 | if not future.done() and not future.cancelled(): 462 | future.set_exception(Exception(payload['error'])) 463 | else: 464 | exception = RequestException() 465 | exception.error = payload['error'] 466 | if not future.done() and not future.cancelled(): 467 | future.set_exception(exception) 468 | else: 469 | print('Invalid response to request:', packet) 470 | 471 | def _process_publication(self, packet): 472 | endpoint = packet['endpoint'] 473 | func = getattr(self, endpoint) 474 | func(**packet['payload']) 475 | 476 | def _handle_connection_lost(self): 477 | vendor 
= self.tcp_bus._registry_client._get_full_service_name(self.name, self.version) 478 | for host, port, node_id, service_type in self.tcp_bus._registry_client._available_services[vendor]: 479 | packet = ControlPacket.deregister(self.name, self.version, node_id) 480 | self.tcp_bus._registry_client._handle_deregistration(packet) 481 | 482 | 483 | class _ServiceHost(_Service): 484 | def __init__(self, service_name, service_version, host_ip, host_port): 485 | super(_ServiceHost, self).__init__(service_name, service_version) 486 | self._node_id = unique_hex() 487 | self._ip = host_ip 488 | self._port = host_port 489 | self._clients = [] 490 | 491 | def is_for_me(self, service, version): 492 | return service == self.name and version == self.version 493 | 494 | @property 495 | def node_id(self): 496 | return self._node_id 497 | 498 | @property 499 | def tcp_bus(self): 500 | return self._tcp_bus 501 | 502 | @tcp_bus.setter 503 | def tcp_bus(self, bus): 504 | for client in self._clients: 505 | if isinstance(client, TCPServiceClient): 506 | client.tcp_bus = bus 507 | self._tcp_bus = bus 508 | 509 | @property 510 | def http_bus(self): 511 | return self._http_bus 512 | 513 | @http_bus.setter 514 | def http_bus(self, bus): 515 | for client in self._clients: 516 | if isinstance(client, HTTPServiceClient): 517 | client._http_bus = bus  # hand the newly assigned bus to each HTTP client, mirroring the tcp_bus setter 518 | self._http_bus = bus 519 | 520 | # @property 521 | # def pubsub_bus(self): 522 | # return self._pubsub_bus 523 | # 524 | # @pubsub_bus.setter 525 | # def pubsub_bus(self, bus): 526 | # self._pubsub_bus = bus 527 | 528 | @property 529 | def clients(self): 530 | return self._clients 531 | 532 | @clients.setter 533 | def clients(self, clients): 534 | self._clients = clients 535 | 536 | @property 537 | def socket_address(self): 538 | return self._ip, self._port 539 | 540 | @property 541 | def host(self): 542 | return self._ip 543 | 544 | @property 545 | def port(self): 546 | return self._port 547 | 548 | def initiate(self): 549 | self.tcp_bus.register() 550 | # yield from self.pubsub_bus.create_pubsub_handler() 551 | # async(self.pubsub_bus.register_for_subscription(self.host, self.port, self.node_id, self.clients)) 552 | 553 | 554 | class TCPService(_ServiceHost): 555 | def __init__(self, service_name, service_version, host_ip=None, host_port=None, ssl_context=None): 556 | super(TCPService, self).__init__(service_name, service_version, host_ip, host_port) 557 | self._ssl_context = ssl_context 558 | 559 | @property 560 | def ssl_context(self): 561 | return self._ssl_context 562 | 563 | # def _publish(self, endpoint, payload): 564 | # self._pubsub_bus.publish(self.name, self.version, endpoint, payload) 565 | # 566 | # def _xpublish(self, endpoint, payload, strategy): 567 | # self._pubsub_bus.xpublish(self.name, self.version, endpoint, payload, strategy) 568 | 569 | @staticmethod 570 | def _make_response_packet(request_id: str, from_id: str, entity: str, result: object, error: object, 571 | failed: bool, old_api=None, replacement_api=None): 572 | if failed: 573 | payload = {'request_id': request_id, 'error': error, 'failed': failed} 574 | else: 575 | payload = {'request_id': request_id, 'result': result} 576 | if old_api: 577 | payload['old_api'] = old_api 578 | if replacement_api: 579 | payload['replacement_api'] = replacement_api 580 | packet = {'pid': unique_hex(), 581 | 'to': from_id, 582 | 'entity': entity, 583 | 'type': _Service._RES_PKT_STR, 584 | 'payload': payload} 585 | return packet 586 | 587 | 588 | class HTTPService(_ServiceHost, 
metaclass=OrderedClassMembers): 589 | def __init__(self, service_name, service_version, host_ip=None, host_port=None, ssl_context=None, 590 | allow_cross_domain=True, 591 | preflight_response=default_preflight_response): 592 | super(HTTPService, self).__init__(service_name, service_version, host_ip, host_port) 593 | self._ssl_context = ssl_context 594 | self._allow_cross_domain = allow_cross_domain 595 | self._preflight_response = preflight_response 596 | 597 | @property 598 | def ssl_context(self): 599 | return self._ssl_context 600 | 601 | @property 602 | def cross_domain_allowed(self): 603 | return self._allow_cross_domain 604 | 605 | @property 606 | def preflight_response(self): 607 | return self._preflight_response 608 | 609 | @get('/ping') 610 | def pong(self, _): 611 | return Response() 612 | 613 | @get('/_stats') 614 | def stats(self, _): 615 | res_d = Aggregator.dump_stats() 616 | return Response(status=200, content_type='application/json', body=json.dumps(res_d).encode()) 617 | 618 | 619 | class HTTPServiceClient(Singleton, _Service): 620 | def __init__(self, service_name, service_version): 621 | if not self.has_inited(): 622 | super(HTTPServiceClient, self).__init__(service_name, service_version) 623 | self.init_done() 624 | 625 | def _send_http_request(self, app_name, method, entity, params): 626 | response = yield from self._http_bus.send_http_request(app_name, self.name, self.version, method, entity, 627 | params) 628 | return response 629 | --------------------------------------------------------------------------------
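Usage sketch (illustrative addition, not a file from the repository): the HTTP decorators and the HTTPService base class defined in trellio/services.py above can be combined as shown below. The service name 'hello_service', the bind address/port and the '/hello' route are hypothetical placeholders, and wiring the service into a running trellio Host is omitted.

import json

from aiohttp.web import Response

from trellio.services import HTTPService, get


class HelloService(HTTPService):
    def __init__(self):
        # service name, version and bind address are placeholders
        super().__init__('hello_service', '1', host_ip='127.0.0.1', host_port=8000)

    @get(path='/hello', required_params='name')
    def hello(self, request):
        # required_params makes the decorator wrapper reply with HTTP 400
        # before this body runs if 'name' is missing from the query string
        name = request.GET['name']
        body = json.dumps({'greeting': 'Hello, {}'.format(name)}).encode()
        return Response(status=200, content_type='application/json', body=body)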