├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── DEVELOPMENT.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── codecov.yml ├── examples ├── __init__.py ├── properties.py ├── resubscription.py ├── shared_subscriptions.py └── will_message.py ├── gmqtt ├── __init__.py ├── client.py ├── mqtt │ ├── __init__.py │ ├── connection.py │ ├── constants.py │ ├── handler.py │ ├── package.py │ ├── property.py │ ├── protocol.py │ └── utils.py └── storage.py ├── pytest.ini ├── requirements_test.txt ├── setup.py ├── static └── logo.png └── tests ├── __init__.py ├── test_mqtt5.py └── utils.py /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ "master" ] 9 | pull_request: 10 | branches: [ "master" ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] 20 | 21 | steps: 22 | - uses: actions/checkout@v3 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install flake8 pytest 31 | if [ -f requirements_test.txt ]; then pip install -r requirements_test.txt; fi 32 | - name: Test with pytest 33 | env: 34 | TOKEN: ${{ secrets.TOKEN }} 35 | run: | 36 | pytest --cov=gmqtt 37 | - name: Code coverage 38 | run: | 39 | codecov 40 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pyo 3 | .idea 4 | gmqtt.egg-info 5 | dist/ 6 | build/ 7 | # virtualenvs 8 | env/ 9 | venv/ 10 | pyenv/ 11 | 12 | # pytest 13 | .coverage 14 | .pytest_cache/ 15 | htmlcov/ 16 | -------------------------------------------------------------------------------- /DEVELOPMENT.md: -------------------------------------------------------------------------------- 1 | # Development 2 | 3 | If you want to contribute to `gmqtt` or work on the code, it is recommended 4 | that you use a Python virtual environment (`venv`). 5 | 6 | After you have created a fork of the original repository, clone it to your 7 | local system. 8 | 9 | ```bash 10 | git clone git@github.com:[YOUR_GITHUB_USERNAME]/gmqtt.git 11 | cd gmqtt 12 | python3 -m venv . 13 | source bin/activate 14 | python3 setup.py develop 15 | ``` 16 | 17 | Also, add the upstream repository and rebase from time to time to stay 18 | up-to-date with ongoing development. 19 | 20 | ```bash 21 | git remote add upstream git@github.com:wialon/gmqtt.git 22 | git pull --rebase upstream master 23 | ``` 24 | 25 | For a new feature or change, create a new branch locally and, once you are 26 | finished, open a Pull Request. 27 | 28 | ## Run tests 29 | 30 | First, install the additional dependencies which are required to run the tests. 31 | 32 | ```bash 33 | pip3 install .[test] 34 | ``` 35 | 36 | The unit tests require that you have a [flespi.io account](https://flespi.io/). 37 | You will need the token from [https://flespi.io/#/panel/list/tokens](https://flespi.io/#/panel/list/tokens), 38 | which must then be made available as an environment variable. 39 | 40 | ```bash 41 | export USERNAME=YOUR_FLESPI_IO_TOKEN 42 | ``` 43 | 44 | Now, you can run the tests locally.
45 | 46 | ```bash 47 | pytest-3 tests 48 | ``` 49 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2018 Gurtam 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md LICENSE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PyPI version](https://badge.fury.io/py/gmqtt.svg)](https://badge.fury.io/py/gmqtt) 2 | [![build status](https://github.com/wialon/gmqtt/actions/workflows/python-package.yml/badge.svg)](https://github.com/wialon/gmqtt/actions/workflows/python-package.yml) 3 | [![Python versions](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-brightgreen)](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-brightgreen) 4 | [![codecov](https://codecov.io/gh/wialon/gmqtt/branch/master/graph/badge.svg)](https://codecov.io/gh/wialon/gmqtt) 5 | 6 | 7 | ### gmqtt: Python async MQTT client implementation.
8 | 9 | ![](./static/logo.png) 10 | 11 | ### Installation 12 | 13 | The latest stable version is available in the Python Package Index (PyPI) and can be installed using: 14 | ```bash 15 | pip3 install gmqtt 16 | ``` 17 | 18 | 19 | ### Usage 20 | #### Getting Started 21 | 22 | Here is a very simple example that subscribes to the `TEST/#` topic on the broker and prints out the received messages: 23 | 24 | ```python 25 | import asyncio 26 | import os 27 | import signal 28 | import time 29 | 30 | from gmqtt import Client as MQTTClient 31 | 32 | # gmqtt is also compatible with uvloop 33 | import uvloop 34 | asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) 35 | 36 | 37 | STOP = asyncio.Event() 38 | 39 | 40 | def on_connect(client, flags, rc, properties): 41 | print('Connected') 42 | client.subscribe('TEST/#', qos=0) 43 | 44 | 45 | def on_message(client, topic, payload, qos, properties): 46 | print('RECV MSG:', payload) 47 | 48 | 49 | def on_disconnect(client, packet, exc=None): 50 | print('Disconnected') 51 | 52 | def on_subscribe(client, mid, qos, properties): 53 | print('SUBSCRIBED') 54 | 55 | def ask_exit(*args): 56 | STOP.set() 57 | 58 | async def main(broker_host, token): 59 | client = MQTTClient("client-id") 60 | 61 | client.on_connect = on_connect 62 | client.on_message = on_message 63 | client.on_disconnect = on_disconnect 64 | client.on_subscribe = on_subscribe 65 | 66 | client.set_auth_credentials(token, None) 67 | await client.connect(broker_host) 68 | 69 | client.publish('TEST/TIME', str(time.time()), qos=1) 70 | 71 | await STOP.wait() 72 | await client.disconnect() 73 | 74 | 75 | if __name__ == '__main__': 76 | loop = asyncio.get_event_loop() 77 | 78 | host = 'mqtt.flespi.io' 79 | token = os.environ.get('FLESPI_TOKEN') 80 | 81 | loop.add_signal_handler(signal.SIGINT, ask_exit) 82 | loop.add_signal_handler(signal.SIGTERM, ask_exit) 83 | 84 | loop.run_until_complete(main(host, token)) 85 | ``` 86 | 87 | ### MQTT Version 5.0 88 | gmqtt supports MQTT version 5.0
protocol 89 | 90 | #### Version setup 91 | Version 5.0 is used by default. If your broker does not support protocol version 5.0 and responds with the proper CONNACK reason code, the client will downgrade to 3.1 and reconnect automatically. Note that some brokers simply fail to parse the 5.0-format CONNECT packet, so first check manually whether your broker handles this properly. 92 | You can also force the version in the `connect` method: 93 | ```python 94 | from gmqtt.mqtt.constants import MQTTv311 95 | client = MQTTClient('clientid') 96 | client.set_auth_credentials(token, None) 97 | await client.connect(broker_host, 1883, keepalive=60, version=MQTTv311) 98 | ``` 99 | 100 | #### Properties 101 | The MQTT 5.0 protocol allows including custom properties in packets; here is an example of passing the response topic property in a published message: 102 | ```python 103 | 104 | TOPIC = 'testtopic/TOPIC' 105 | 106 | def on_connect(client, flags, rc, properties): 107 | client.subscribe(TOPIC, qos=1) 108 | print('Connected') 109 | 110 | def on_message(client, topic, payload, qos, properties): 111 | print('RECV MSG:', topic, payload.decode(), properties) 112 | 113 | async def main(broker_host, token): 114 | client = MQTTClient('asdfghjk') 115 | client.on_message = on_message 116 | client.on_connect = on_connect 117 | client.set_auth_credentials(token, None) 118 | await client.connect(broker_host, 1883, keepalive=60) 119 | client.publish(TOPIC, 'Message payload', response_topic='RESPONSE/TOPIC') 120 | 121 | await STOP.wait() 122 | await client.disconnect() 123 | ``` 124 | ##### Connect properties 125 | Connect properties are passed to the `Client` object as kwargs (they are later stored, together with the properties received from the broker, in the `client.properties` field). See the example below. 126 | * `session_expiry_interval` - `int` Session expiry interval in seconds. If the Session Expiry Interval is absent, the value 0 is used. If it is set to 0, or is absent, the Session ends when the Network Connection is closed.
If the Session Expiry Interval is 0xFFFFFFFF (max possible value), the Session does not expire. 127 | * `receive_maximum` - `int` The Client uses this value to limit the number of QoS 1 and QoS 2 publications that it is willing to process concurrently. 128 | * `user_property` - `tuple(str, str)` This property may be used to provide additional diagnostic or other information (key-value pairs). 129 | * `maximum_packet_size` - `int` The Client uses the Maximum Packet Size (in bytes) to inform the Server that it will not process packets exceeding this limit. 130 | 131 | Example: 132 | ```python 133 | client = gmqtt.Client("lenkaklient", receive_maximum=24000, session_expiry_interval=60, user_property=('myid', '12345')) 134 | ``` 135 | 136 | ##### Publish properties 137 | These properties will also be present in the publish packet sent by the broker; they are passed to the `on_message` callback. 138 | * `message_expiry_interval` - `int` If present, the value is the lifetime of the Application Message in seconds. 139 | * `content_type` - `str` UTF-8 Encoded String describing the content of the Application Message. The value of the Content Type is defined by the sending and receiving application. 140 | * `user_property` - `tuple(str, str)` 141 | * `subscription_identifier` - `int` (see subscribe properties) sent by the broker 142 | * `topic_alias` - `int` The client first publishes a message with the topic string and the `topic_alias` kwarg. After this initial message, the client can publish messages with an empty-string topic and the same `topic_alias` kwarg.
143 | 144 | Example: 145 | ```python 146 | def on_message(client, topic, payload, qos, properties): 147 | # properties example here: {'content_type': ['json'], 'user_property': [('timestamp', '1524235334.881058')], 'message_expiry_interval': [60], 'subscription_identifier': [42, 64]} 148 | print('RECV MSG:', topic, payload, properties) 149 | 150 | client.publish('TEST/TIME', str(time.time()), qos=1, retain=True, message_expiry_interval=60, content_type='json') 151 | ``` 152 | 153 | ##### Subscribe properties 154 | * `subscription_identifier` - `int` If the Client specified a Subscription Identifier for any of the overlapping subscriptions, the Server MUST send those Subscription Identifiers in the message which is published as the result of the subscriptions. 155 | 156 | ### Reconnects 157 | By default, a connected MQTT client will always try to reconnect when the connection is lost. The number of reconnect attempts is unlimited. 158 | If you want to change this behaviour, do the following: 159 | ```python 160 | client = MQTTClient("client-id") 161 | client.set_config({'reconnect_retries': 10, 'reconnect_delay': 60}) 162 | ``` 163 | The code above sets the number of reconnect attempts to 10 and the delay between reconnect attempts to 1 min (60 s). By default, `reconnect_delay=6` and `reconnect_retries=-1`, which stands for infinity. 164 | Note that manually calling `await client.disconnect()` will set `reconnect_retries` to 0, which stops automatic reconnection. 165 | 166 | ### Asynchronous on_message callback 167 | You can define an asynchronous `on_message` callback. 168 | Note that it must return a valid PUBACK code (`0` is the success code; see the full list in [constants](gmqtt/mqtt/constants.py#L69)) 169 | ```python 170 | async def on_message(client, topic, payload, qos, properties): 171 | pass 172 | return 0 173 | ``` 174 | 175 | ### Other examples 176 | Check the [examples directory](examples) for more use cases.
177 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: off 4 | patch: off 5 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wialon/gmqtt/4b84685ce1c079905da22eeb119f168f1350762c/examples/__init__.py -------------------------------------------------------------------------------- /examples/properties.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import signal 4 | import time 5 | 6 | import uvloop 7 | import asyncio 8 | import gmqtt 9 | 10 | asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) 11 | 12 | STOP = asyncio.Event() 13 | 14 | 15 | def on_connect(client, flags, rc, properties): 16 | logging.info('[CONNECTED {}]'.format(client._client_id)) 17 | 18 | 19 | def on_message(client, topic, payload, qos, properties): 20 | logging.info('[RECV MSG {}] TOPIC: {} PAYLOAD: {} QOS: {} PROPERTIES: {}' 21 | .format(client._client_id, topic, payload, qos, properties)) 22 | 23 | 24 | def on_disconnect(client, packet, exc=None): 25 | logging.info('[DISCONNECTED {}]'.format(client._client_id)) 26 | 27 | 28 | def on_subscribe(client, mid, qos, properties): 29 | logging.info('[SUBSCRIBED {}] QOS: {}'.format(client._client_id, qos)) 30 | 31 | 32 | def assign_callbacks_to_client(client): 33 | # helper function which sets up client's callbacks 34 | client.on_connect = on_connect 35 | client.on_message = on_message 36 | client.on_disconnect = on_disconnect 37 | client.on_subscribe = on_subscribe 38 | 39 | 40 | def ask_exit(*args): 41 | STOP.set() 42 | 43 | 44 | async def main(broker_host, broker_port, token): 45 | # create client instance, kwargs (session expiry interval 
and maximum packet size) 46 | # will be send as properties in connect packet 47 | sub_client = gmqtt.Client("clientgonnasub", session_expiry_interval=600, maximum_packet_size=65535) 48 | 49 | assign_callbacks_to_client(sub_client) 50 | sub_client.set_auth_credentials(token, None) 51 | await sub_client.connect(broker_host, broker_port) 52 | 53 | # two overlapping subscriptions with different subscription identifiers 54 | sub_client.subscribe('TEST/PROPS/#', qos=1, subscription_identifier=1) 55 | sub_client.subscribe([gmqtt.Subscription('TEST/+', qos=1), gmqtt.Subscription('TEST', qos=0)], 56 | subscription_identifier=2) 57 | 58 | pub_client = gmqtt.Client("clientgonnapub") 59 | 60 | assign_callbacks_to_client(pub_client) 61 | pub_client.set_auth_credentials(token, None) 62 | await pub_client.connect(broker_host, broker_port) 63 | 64 | # this message received by sub_client will have two subscription identifiers 65 | pub_client.publish('TEST/PROPS/42', '42 is the answer', qos=1, content_type='utf-8', 66 | message_expiry_interval=60, topic_alias=42, user_property=('time', str(time.time()))) 67 | 68 | pub_client.publish('TEST/42', 'Test 42', qos=1, content_type='utf-8', 69 | message_expiry_interval=60, topic_alias=1, user_property=('time', str(time.time()))) 70 | 71 | # just another way to publish same message 72 | msg = gmqtt.Message('', '42 is the answer again', qos=1, content_type='utf-8', 73 | message_expiry_interval=60, topic_alias=42, user_property=('time', str(time.time()))) 74 | pub_client.publish(msg) 75 | 76 | pub_client.publish('TEST/42', {42: 'is the answer'}, qos=1, content_type='json', 77 | message_expiry_interval=60, topic_alias=1, user_property=('time', str(time.time()))) 78 | 79 | await STOP.wait() 80 | await pub_client.disconnect() 81 | await sub_client.disconnect(session_expiry_interval=0) 82 | 83 | 84 | if __name__ == '__main__': 85 | loop = asyncio.get_event_loop() 86 | logging.basicConfig(level=logging.INFO) 87 | 88 | host = os.environ.get('HOST', 
'mqtt.flespi.io') 89 | port = 1883 90 | token = os.environ.get('TOKEN', 'fake token') 91 | 92 | loop.add_signal_handler(signal.SIGINT, ask_exit) 93 | loop.add_signal_handler(signal.SIGTERM, ask_exit) 94 | 95 | loop.run_until_complete(main(host, port, token)) 96 | -------------------------------------------------------------------------------- /examples/resubscription.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import signal 4 | import time 5 | 6 | import uvloop 7 | import asyncio 8 | import gmqtt 9 | 10 | asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) 11 | 12 | STOP = asyncio.Event() 13 | 14 | 15 | def on_connect(client, flags, rc, properties): 16 | logging.info('[CONNECTED {}]'.format(client._client_id)) 17 | 18 | 19 | def on_message(client, topic, payload, qos, properties): 20 | logging.info('[RECV MSG {}] TOPIC: {} PAYLOAD: {} QOS: {} PROPERTIES: {}' 21 | .format(client._client_id, topic, payload, qos, properties)) 22 | 23 | 24 | def on_disconnect(client, packet, exc=None): 25 | logging.info('[DISCONNECTED {}]'.format(client._client_id)) 26 | 27 | 28 | def on_subscribe(client, mid, qos, properties): 29 | # in order to check if all the subscriptions were successful, we should first get all subscriptions with this 30 | # particular mid (from one subscription request) 31 | subscriptions = client.get_subscriptions_by_mid(mid) 32 | for subscription, granted_qos in zip(subscriptions, qos): 33 | # in case of bad suback code, we can resend subscription 34 | if granted_qos >= gmqtt.constants.SubAckReasonCode.UNSPECIFIED_ERROR.value: 35 | logging.warning('[RETRYING SUB {}] mid {}, reason code: {}, properties {}'.format( 36 | client._client_id, mid, granted_qos, properties)) 37 | client.resubscribe(subscription) 38 | logging.info('[SUBSCRIBED {}] mid {}, QOS: {}, properties {}'.format( 39 | client._client_id, mid, granted_qos, properties)) 40 | 41 | 42 | def 
assign_callbacks_to_client(client): 43 | # helper function which sets up client's callbacks 44 | client.on_connect = on_connect 45 | client.on_message = on_message 46 | client.on_disconnect = on_disconnect 47 | client.on_subscribe = on_subscribe 48 | 49 | 50 | def ask_exit(*args): 51 | STOP.set() 52 | 53 | 54 | async def main(broker_host, broker_port, token): 55 | # create client instance, kwargs (session expiry interval and maximum packet size) 56 | # will be send as properties in connect packet 57 | sub_client = gmqtt.Client("clientgonnasub") 58 | 59 | assign_callbacks_to_client(sub_client) 60 | sub_client.set_auth_credentials(token, None) 61 | await sub_client.connect(broker_host, broker_port) 62 | 63 | # two overlapping subscriptions with different subscription identifiers 64 | subscriptions = [ 65 | gmqtt.Subscription('TEST/PROPS/#', qos=1), 66 | gmqtt.Subscription('TEST2/PROPS/#', qos=2), 67 | ] 68 | sub_client.subscribe(subscriptions, subscription_identifier=1) 69 | 70 | pub_client = gmqtt.Client("clientgonnapub") 71 | 72 | assign_callbacks_to_client(pub_client) 73 | pub_client.set_auth_credentials(token, None) 74 | await pub_client.connect(broker_host, broker_port) 75 | 76 | # this message received by sub_client will have two subscription identifiers 77 | pub_client.publish('TEST/PROPS/42', '42 is the answer', qos=1, content_type='utf-8', 78 | message_expiry_interval=60, user_property=('time', str(time.time()))) 79 | pub_client.publish('TEST2/PROPS/42', '42 is the answer', qos=1, content_type='utf-8', 80 | message_expiry_interval=60, user_property=('time', str(time.time()))) 81 | 82 | await STOP.wait() 83 | await pub_client.disconnect() 84 | await sub_client.disconnect(session_expiry_interval=0) 85 | 86 | 87 | if __name__ == '__main__': 88 | loop = asyncio.get_event_loop() 89 | logging.basicConfig(level=logging.INFO) 90 | 91 | host = os.environ.get('HOST', 'mqtt.flespi.io') 92 | port = 1883 93 | token = os.environ.get('TOKEN', 'fake token') 94 | 95 | 
loop.add_signal_handler(signal.SIGINT, ask_exit) 96 | loop.add_signal_handler(signal.SIGTERM, ask_exit) 97 | 98 | loop.run_until_complete(main(host, port, token)) 99 | -------------------------------------------------------------------------------- /examples/shared_subscriptions.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import signal 4 | import time 5 | 6 | import uvloop 7 | import asyncio 8 | import gmqtt 9 | 10 | asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) 11 | 12 | STOP = asyncio.Event() 13 | 14 | 15 | def on_connect(client, flags, rc, properties): 16 | logging.info('[CONNECTED {}]'.format(client._client_id)) 17 | 18 | 19 | def on_message(client, topic, payload, qos, properties): 20 | logging.info('[RECV MSG {}] TOPIC: {} PAYLOAD: {} QOS: {} PROPERTIES: {}' 21 | .format(client._client_id, topic, payload, qos, properties)) 22 | 23 | 24 | def on_disconnect(client, packet, exc=None): 25 | logging.info('[DISCONNECTED {}]'.format(client._client_id)) 26 | 27 | 28 | def on_subscribe(client, mid, qos, properties): 29 | logging.info('[SUBSCRIBED {}] QOS: {}'.format(client._client_id, qos)) 30 | 31 | 32 | def assign_callbacks_to_client(client): 33 | client.on_connect = on_connect 34 | client.on_message = on_message 35 | client.on_disconnect = on_disconnect 36 | client.on_subscribe = on_subscribe 37 | 38 | 39 | def ask_exit(*args): 40 | STOP.set() 41 | 42 | 43 | async def main(broker_host, broker_port, token): 44 | # initiate first client subscribed to TEST/SHARED/# in group mytestgroup 45 | sub_clientA = gmqtt.Client("clientgonnasubA") 46 | 47 | assign_callbacks_to_client(sub_clientA) 48 | sub_clientA.set_auth_credentials(token, None) 49 | await sub_clientA.connect(broker_host, broker_port) 50 | 51 | sub_clientA.subscribe('$share/mytestgroup/TEST/SHARED/#') 52 | 53 | # initiate second client subscribed to TEST/SHARED/# in group mytestgroup 54 | sub_clientB = 
gmqtt.Client("clientgonnasubB") 55 | 56 | assign_callbacks_to_client(sub_clientB) 57 | sub_clientB.set_auth_credentials(token, None) 58 | await sub_clientB.connect(broker_host, broker_port) 59 | 60 | sub_clientB.subscribe('$share/mytestgroup/TEST/SHARED/#') 61 | 62 | # this client will publish messages to TEST/SHARED/... topics 63 | pub_client = gmqtt.Client("clientgonnapub") 64 | 65 | assign_callbacks_to_client(pub_client) 66 | pub_client.set_auth_credentials(token, None) 67 | await pub_client.connect(broker_host, broker_port) 68 | 69 | # some of this messages will be received by client sub_clientA, 70 | # and another part by client sub_clientB, approximately 50/50 71 | for i in range(100): 72 | pub_client.publish('TEST/SHARED/{}'.format(i), i, user_property=('time', str(time.time()))) 73 | 74 | await STOP.wait() 75 | await pub_client.disconnect() 76 | await sub_clientA.disconnect(session_expiry_interval=0) 77 | await sub_clientB.disconnect(session_expiry_interval=0) 78 | 79 | 80 | if __name__ == '__main__': 81 | loop = asyncio.get_event_loop() 82 | logging.basicConfig(level=logging.INFO) 83 | 84 | host = os.environ.get('HOST', 'mqtt.flespi.io') 85 | port = 1883 86 | token = os.environ.get('TOKEN', 'fake token') 87 | 88 | loop.add_signal_handler(signal.SIGINT, ask_exit) 89 | loop.add_signal_handler(signal.SIGTERM, ask_exit) 90 | 91 | loop.run_until_complete(main(host, port, token)) 92 | -------------------------------------------------------------------------------- /examples/will_message.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import signal 4 | import uvloop 5 | import asyncio 6 | import gmqtt 7 | 8 | asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) 9 | 10 | STOP = asyncio.Event() 11 | 12 | 13 | def on_connect(client, flags, rc, properties): 14 | logging.info('[CONNECTED {}]'.format(client._client_id)) 15 | 16 | 17 | def on_message(client, topic, payload, qos, properties): 18 
| logging.info('[RECV MSG {}] TOPIC: {} PAYLOAD: {} QOS: {} PROPERTIES: {}' 19 | .format(client._client_id, topic, payload, qos, properties)) 20 | 21 | 22 | def on_disconnect(client, packet, exc=None): 23 | logging.info('[DISCONNECTED {}]'.format(client._client_id)) 24 | 25 | 26 | def on_subscribe(client, mid, qos, properties): 27 | logging.info('[SUBSCRIBED {}] QOS: {}'.format(client._client_id, qos)) 28 | 29 | 30 | def assign_callbacks_to_client(client): 31 | client.on_connect = on_connect 32 | client.on_message = on_message 33 | client.on_disconnect = on_disconnect 34 | client.on_subscribe = on_subscribe 35 | 36 | 37 | def ask_exit(*args): 38 | STOP.set() 39 | 40 | 41 | async def main(broker_host, broker_port, token): 42 | # this message will be published by broker after client disconnects with "bad" code after 10 sec 43 | will_message = gmqtt.Message('TEST/WILL/42', "I'm dead finally", will_delay_interval=10) 44 | will_client = gmqtt.Client("clientgonnadie", will_message=will_message) 45 | 46 | assign_callbacks_to_client(will_client) 47 | will_client.set_auth_credentials(token, None) 48 | await will_client.connect(broker_host, broker_port) 49 | 50 | another_client = gmqtt.Client("clientgonnalisten") 51 | 52 | assign_callbacks_to_client(another_client) 53 | another_client.set_auth_credentials(token, None) 54 | await another_client.connect(broker_host, broker_port) 55 | 56 | another_client.subscribe('TEST/#') 57 | 58 | # reason code 4 - Disconnect with Will Message 59 | await will_client.disconnect(reason_code=4, reason_string="Smth went wrong") 60 | 61 | await STOP.wait() 62 | await another_client.disconnect() 63 | 64 | 65 | if __name__ == '__main__': 66 | loop = asyncio.get_event_loop() 67 | logging.basicConfig(level=logging.INFO) 68 | 69 | host = os.environ.get('HOST', 'mqtt.flespi.io') 70 | port = 1883 71 | token = os.environ.get('TOKEN', 'fake token') 72 | 73 | loop.add_signal_handler(signal.SIGINT, ask_exit) 74 | loop.add_signal_handler(signal.SIGTERM, 
ask_exit) 75 | 76 | loop.run_until_complete(main(host, port, token)) 77 | -------------------------------------------------------------------------------- /gmqtt/__init__.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from .client import Client, Message, Subscription 4 | from .mqtt import constants 5 | from .mqtt.protocol import BaseMQTTProtocol 6 | from .mqtt.handler import MQTTConnectError 7 | 8 | __author__ = "Mikhail Turchunovich" 9 | __email__ = 'mitu@gurtam.com' 10 | __copyright__ = ("Copyright 2013-%d, Gurtam; " % datetime.datetime.now().year,) 11 | 12 | __credits__ = [ 13 | "Mikhail Turchunovich", 14 | "Elena Shylko" 15 | ] 16 | __version__ = "0.7.0" 17 | 18 | 19 | __all__ = [ 20 | 'Client', 21 | 'Message', 22 | 'Subscription', 23 | 'BaseMQTTProtocol', 24 | 'MQTTConnectError', 25 | 'constants' 26 | ] 27 | -------------------------------------------------------------------------------- /gmqtt/client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | 4 | import logging 5 | import uuid 6 | from copy import copy 7 | from typing import Union, Sequence 8 | 9 | from .mqtt.protocol import MQTTProtocol 10 | from .mqtt.connection import MQTTConnection 11 | from .mqtt.handler import MqttPackageHandler 12 | from .mqtt.constants import MQTTv50, UNLIMITED_RECONNECTS 13 | 14 | from .storage import HeapPersistentStorage 15 | 16 | 17 | class Message: 18 | def __init__(self, topic, payload, qos=0, retain=False, **kwargs): 19 | self.topic = topic.encode('utf-8', errors='replace') if isinstance(topic, str) else topic 20 | self.qos = qos 21 | self.retain = retain 22 | self.dup = False 23 | self.properties = kwargs 24 | 25 | if isinstance(payload, (list, tuple, dict)): 26 | payload = json.dumps(payload, ensure_ascii=False) 27 | 28 | if isinstance(payload, (int, float)): 29 | self.payload = str(payload).encode('ascii') 30 | elif 
isinstance(payload, str): 31 | self.payload = payload.encode('utf-8', errors='replace') 32 | elif payload is None: 33 | self.payload = b'' 34 | else: 35 | self.payload = payload 36 | 37 | self.payload_size = len(self.payload) 38 | 39 | if self.payload_size > 268435455: 40 | raise ValueError('Payload too large.') 41 | 42 | 43 | class Subscription: 44 | def __init__(self, topic, qos=0, no_local=False, retain_as_published=False, retain_handling_options=0, 45 | subscription_identifier=None): 46 | self.topic = topic 47 | self.qos = qos 48 | self.no_local = no_local 49 | self.retain_as_published = retain_as_published 50 | self.retain_handling_options = retain_handling_options 51 | 52 | self.mid = None 53 | self.acknowledged = False 54 | 55 | # this property can be used only in MQTT5.0 56 | self.subscription_identifier = subscription_identifier 57 | 58 | 59 | class SubscriptionsHandler: 60 | def __init__(self): 61 | self.subscriptions = [] 62 | 63 | def update_subscriptions_with_subscription_or_topic( 64 | self, subscription_or_topic, qos, no_local, retain_as_published, retain_handling_options, kwargs): 65 | 66 | sentinel = object() 67 | subscription_identifier = kwargs.get('subscription_identifier', sentinel) 68 | 69 | if isinstance(subscription_or_topic, Subscription): 70 | 71 | if subscription_identifier is not sentinel: 72 | subscription_or_topic.subscription_identifier = subscription_identifier 73 | 74 | subscriptions = [subscription_or_topic] 75 | elif isinstance(subscription_or_topic, (tuple, list)): 76 | 77 | if subscription_identifier is not sentinel: 78 | for sub in subscription_or_topic: 79 | sub.subscription_identifier = subscription_identifier 80 | 81 | subscriptions = subscription_or_topic 82 | elif isinstance(subscription_or_topic, str): 83 | 84 | if subscription_identifier is sentinel: 85 | subscription_identifier = None 86 | 87 | subscriptions = [Subscription(subscription_or_topic, qos=qos, no_local=no_local, 88 | retain_as_published=retain_as_published, 
89 | retain_handling_options=retain_handling_options, 90 | subscription_identifier=subscription_identifier)] 91 | else: 92 | raise ValueError('Bad subscription: must be string or Subscription or list of Subscriptions') 93 | self.subscriptions.extend(subscriptions) 94 | return subscriptions 95 | 96 | def _remove_subscriptions(self, topic: Union[str, Sequence[str]]): 97 | if isinstance(topic, str): 98 | self.subscriptions = [s for s in self.subscriptions if s.topic != topic] 99 | else: 100 | self.subscriptions = [s for s in self.subscriptions if s.topic not in topic] 101 | 102 | def subscribe(self, subscription_or_topic: Union[str, Subscription, Sequence[Subscription]], 103 | qos=0, no_local=False, retain_as_published=False, retain_handling_options=0, **kwargs): 104 | 105 | # Warn: if you will pass a few subscriptions objects, and each will be have different 106 | # subscription identifier - the only first will be used as identifier 107 | # if only you will not pass the identifier in kwargs 108 | 109 | subscriptions = self.update_subscriptions_with_subscription_or_topic( 110 | subscription_or_topic, qos, no_local, retain_as_published, retain_handling_options, kwargs) 111 | return self._connection.subscribe(subscriptions, **kwargs) 112 | 113 | def resubscribe(self, subscription: Subscription, **kwargs): 114 | # send subscribe packet for subscription,that's already in client's subscription list 115 | if 'subscription_identifier' in kwargs: 116 | subscription.subscription_identifier = kwargs['subscription_identifier'] 117 | elif subscription.subscription_identifier is not None: 118 | kwargs['subscription_identifier'] = subscription.subscription_identifier 119 | return self._connection.subscribe([subscription], **kwargs) 120 | 121 | def unsubscribe(self, topic: Union[str, Sequence[str]], **kwargs): 122 | self._remove_subscriptions(topic) 123 | return self._connection.unsubscribe(topic, **kwargs) 124 | 125 | 126 | class Client(MqttPackageHandler, SubscriptionsHandler): 
127 | def __init__(self, client_id, clean_session=True, optimistic_acknowledgement=True, 128 | will_message=None, logger=None, **kwargs): 129 | super(Client, self).__init__(optimistic_acknowledgement=optimistic_acknowledgement, logger=logger) 130 | self._client_id = client_id or uuid.uuid4().hex 131 | 132 | # in MQTT 5.0 this is clean start flag 133 | self._clean_session = clean_session 134 | 135 | self._connection = None 136 | self._keepalive = 60 137 | 138 | self._username = None 139 | self._password = None 140 | 141 | self._host = None 142 | self._port = None 143 | self._ssl = None 144 | 145 | self._connect_properties = kwargs 146 | self._connack_properties = {} 147 | 148 | self._will_message = will_message 149 | 150 | # TODO: this constant may be moved to config 151 | self._persistent_storage = kwargs.pop('persistent_storage', HeapPersistentStorage()) 152 | 153 | self._topic_alias_maximum = kwargs.get('topic_alias_maximum', 0) 154 | 155 | self._logger = logger or logging.getLogger(__name__) 156 | 157 | def get_subscription_by_identifier(self, subscription_identifier): 158 | return next((sub for sub in self.subscriptions if sub.subscription_identifier == subscription_identifier), None) 159 | 160 | def get_subscriptions_by_mid(self, mid): 161 | return [sub for sub in self.subscriptions if sub.mid == mid] 162 | 163 | def _remove_message_from_query(self, mid): 164 | self._logger.debug('[REMOVE MESSAGE] %s', mid) 165 | asyncio.ensure_future( 166 | self._persistent_storage.remove_message_by_mid(mid) 167 | ) 168 | 169 | @property 170 | def is_connected(self): 171 | # tells if connection is alive and CONNACK was received 172 | return self._connected.is_set() and not self._connection.is_closing() 173 | 174 | async def _resend_qos_messages(self): 175 | await self._connected.wait() 176 | 177 | if await self._persistent_storage.is_empty: 178 | self._logger.debug('[QoS query IS EMPTY]') 179 | return 180 | elif self._connection.is_closing(): 181 | self._logger.debug('[Some 
msg need to resend] Transport is closing') 182 | return 183 | else: 184 | msgs = copy(await self._persistent_storage.get_all()) 185 | self._logger.debug('[msgs need to resend] processing %s messages', len(msgs)) 186 | 187 | await self._persistent_storage.clear() 188 | 189 | for msg in msgs: 190 | (_, mid, package) = msg 191 | 192 | try: 193 | self._connection.send_package(package) 194 | except Exception as exc: 195 | self._logger.error('[ERROR WHILE RESENDING] mid: %s', mid, exc_info=exc) 196 | 197 | await self._persistent_storage.push_message(mid, package) 198 | 199 | async def _clear_resend_qos_queue(self): 200 | await self._persistent_storage.clear() 201 | 202 | 203 | @property 204 | def properties(self): 205 | # merge two dictionaries from connect and connack packages 206 | return {**self._connack_properties, **self._connect_properties} 207 | 208 | def set_auth_credentials(self, username, password=None): 209 | self._username = username.encode() 210 | self._password = password 211 | if isinstance(self._password, str): 212 | self._password = password.encode() 213 | 214 | async def connect(self, host, port=1883, ssl=False, keepalive=60, version=MQTTv50, raise_exc=True): 215 | # Init connection 216 | self._host = host 217 | self._port = port 218 | self._ssl = ssl 219 | self._keepalive = keepalive 220 | self._is_active = True 221 | 222 | MQTTProtocol.proto_ver = version 223 | 224 | self._connection = await self._create_connection( 225 | host, port=self._port, ssl=self._ssl, clean_session=self._clean_session, keepalive=keepalive) 226 | 227 | await self._connection.auth(self._client_id, self._username, self._password, will_message=self._will_message, 228 | **self._connect_properties) 229 | await self._connected.wait() 230 | 231 | await self._persistent_storage.wait_empty() 232 | 233 | if raise_exc and self._error: 234 | raise self._error 235 | 236 | async def _create_connection(self, host, port, ssl, clean_session, keepalive): 237 | # important for reconnects, make 
sure u know what u are doing if wanna change :( 238 | self._exit_reconnecting_state() 239 | self._clear_topics_aliases() 240 | connection = await MQTTConnection.create_connection(host, port, ssl, clean_session, keepalive, logger=self._logger) 241 | connection.set_handler(self) 242 | return connection 243 | 244 | def _allow_reconnect(self): 245 | if self._reconnecting_now or not self._is_active: 246 | return False 247 | if self._config['reconnect_retries'] == UNLIMITED_RECONNECTS: 248 | return True 249 | if self.failed_connections <= self._config['reconnect_retries']: 250 | return True 251 | self._logger.error('[DISCONNECTED] max number of failed connection attempts achieved') 252 | return False 253 | 254 | async def reconnect(self, delay=False): 255 | if not self._allow_reconnect(): 256 | return 257 | # stopping auto-reconnects during reconnect procedure is important, better do not touch :( 258 | self._temporatily_stop_reconnect() 259 | try: 260 | await self._disconnect() 261 | except: 262 | self._logger.info('[RECONNECT] ignored error while disconnecting, trying to reconnect anyway') 263 | if delay: 264 | await asyncio.sleep(self._config['reconnect_delay']) 265 | try: 266 | self._connection = await self._create_connection(self._host, self._port, ssl=self._ssl, 267 | clean_session=False, keepalive=self._keepalive) 268 | except OSError as exc: 269 | self.failed_connections += 1 270 | self._logger.warning("[CAN'T RECONNECT] %s", self.failed_connections) 271 | asyncio.ensure_future(self.reconnect(delay=True)) 272 | return 273 | await self._connection.auth(self._client_id, self._username, self._password, 274 | will_message=self._will_message, **self._connect_properties) 275 | 276 | async def disconnect(self, reason_code=0, **properties): 277 | self._is_active = False 278 | await self._disconnect(reason_code=reason_code, **properties) 279 | 280 | async def _disconnect(self, reason_code=0, **properties): 281 | self._clear_topics_aliases() 282 | 283 | 
self._connected.clear() 284 | if self._connection: 285 | self._connection.send_disconnect(reason_code=reason_code, **properties) 286 | await self._connection.close() 287 | 288 | def publish(self, message_or_topic, payload=None, qos=0, retain=False, **kwargs): 289 | if isinstance(message_or_topic, Message): 290 | message = message_or_topic 291 | else: 292 | message = Message(message_or_topic, payload, qos=qos, retain=retain, **kwargs) 293 | 294 | mid, package = self._connection.publish(message) 295 | 296 | if qos > 0: 297 | self._persistent_storage.push_message_nowait(mid, package) 298 | 299 | def _send_simple_command(self, cmd): 300 | self._connection.send_simple_command(cmd) 301 | 302 | def _send_command_with_mid(self, cmd, mid, dup, reason_code=0): 303 | self._connection.send_command_with_mid(cmd, mid, dup, reason_code=reason_code) 304 | 305 | @property 306 | def protocol_version(self): 307 | return self._connection._protocol.proto_ver \ 308 | if self._connection is not None else MQTTv50 309 | -------------------------------------------------------------------------------- /gmqtt/mqtt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wialon/gmqtt/4b84685ce1c079905da22eeb119f168f1350762c/gmqtt/mqtt/__init__.py -------------------------------------------------------------------------------- /gmqtt/mqtt/connection.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import time 4 | 5 | from .protocol import MQTTProtocol 6 | 7 | class MQTTConnection(object): 8 | def __init__(self, transport: asyncio.Transport, protocol: MQTTProtocol, clean_session: bool, keepalive: int, logger=None): 9 | self._transport = transport 10 | self._protocol = protocol 11 | self._protocol.set_connection(self) 12 | self._buff = asyncio.Queue() 13 | 14 | self._clean_session = clean_session 15 | self._keepalive = keepalive 16 | 17 | 
self._last_data_in = time.monotonic() 18 | self._last_data_out = time.monotonic() 19 | 20 | self._keep_connection_callback = asyncio.get_event_loop().call_later(self._keepalive / 2, self._keep_connection) 21 | 22 | self._logger = logger or logging.getLogger(__name__) 23 | 24 | @classmethod 25 | async def create_connection(cls, host, port, ssl, clean_session, keepalive, loop=None, logger=None): 26 | loop = loop or asyncio.get_event_loop() 27 | transport, protocol = await loop.create_connection(MQTTProtocol, host, port, ssl=ssl) 28 | return MQTTConnection(transport, protocol, clean_session, keepalive, logger=logger) 29 | 30 | def _keep_connection(self): 31 | if self.is_closing() or not self._keepalive: 32 | return 33 | 34 | time_ = time.monotonic() 35 | if time_ - self._last_data_in >= 2 * self._keepalive: 36 | self._logger.warning("[LOST HEARTBEAT FOR %s SECONDS, GOING TO CLOSE CONNECTION]", 2 * self._keepalive) 37 | asyncio.ensure_future(self.close()) 38 | return 39 | 40 | if time_ - self._last_data_out >= 0.8 * self._keepalive or \ 41 | time_ - self._last_data_in >= 0.8 * self._keepalive: 42 | self._send_ping_request() 43 | self._keep_connection_callback = asyncio.get_event_loop().call_later(self._keepalive / 2, self._keep_connection) 44 | 45 | def put_package(self, pkg): 46 | self._last_data_in = time.monotonic() 47 | self._handler(*pkg) 48 | 49 | def send_package(self, package): 50 | # This is not blocking operation, because transport place the data 51 | # to the buffer, and this buffer flushing async 52 | self._last_data_out = time.monotonic() 53 | if isinstance(package, (bytes, bytearray)): 54 | package = package 55 | else: 56 | package = package.encode() 57 | 58 | self._transport.write(package) 59 | 60 | async def auth(self, client_id, username, password, will_message=None, **kwargs): 61 | await self._protocol.send_auth_package(client_id, username, password, self._clean_session, 62 | self._keepalive, will_message=will_message, **kwargs) 63 | 64 | def 
publish(self, message): 65 | return self._protocol.send_publish(message) 66 | 67 | def send_disconnect(self, reason_code=0, **properties): 68 | self._protocol.send_disconnect(reason_code=reason_code, **properties) 69 | 70 | def subscribe(self, subscription, **kwargs): 71 | return self._protocol.send_subscribe_packet(subscription, **kwargs) 72 | 73 | def unsubscribe(self, topic, **kwargs): 74 | return self._protocol.send_unsubscribe_packet(topic, **kwargs) 75 | 76 | def send_simple_command(self, cmd): 77 | self._protocol.send_simple_command_packet(cmd) 78 | 79 | def send_command_with_mid(self, cmd, mid, dup, reason_code=0): 80 | self._protocol.send_command_with_mid(cmd, mid, dup, reason_code=reason_code) 81 | 82 | def _send_ping_request(self): 83 | self._protocol.send_ping_request() 84 | 85 | def set_handler(self, handler): 86 | self._handler = handler 87 | 88 | async def close(self): 89 | if self._keep_connection_callback: 90 | self._keep_connection_callback.cancel() 91 | self._transport.close() 92 | await self._protocol.closed 93 | 94 | def is_closing(self): 95 | return self._transport.is_closing() 96 | 97 | @property 98 | def keepalive(self): 99 | return self._keepalive 100 | 101 | @keepalive.setter 102 | def keepalive(self, value): 103 | if self._keepalive == value: 104 | return 105 | self._keepalive = value 106 | if self._keep_connection_callback: 107 | self._keep_connection_callback.cancel() 108 | self._keep_connection_callback = asyncio.get_event_loop().call_later(self._keepalive / 2, self._keep_connection) 109 | -------------------------------------------------------------------------------- /gmqtt/mqtt/constants.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | import logging 4 | 5 | # Message types 6 | 7 | class MQTTCommands(enum.IntEnum): 8 | CONNECT = 0x10 9 | CONNACK = 0x20 10 | PUBLISH = 0x30 11 | PUBACK = 0x40 12 | PUBREC = 0x50 13 | PUBREL = 0x60 14 | PUBCOMP = 0x70 15 | SUBSCRIBE = 0x80 16 | 
SUBACK = 0x90 17 | UNSUBSCRIBE = 0xA0 18 | UNSUBACK = 0xB0 19 | PINGREQ = 0xC0 20 | PINGRESP = 0xD0 21 | DISCONNECT = 0xE0 22 | 23 | # CONNACK codes 24 | CONNACK_ACCEPTED = 0 25 | CONNACK_REFUSED_PROTOCOL_VERSION = 1 26 | CONNACK_REFUSED_IDENTIFIER_REJECTED = 2 27 | CONNACK_REFUSED_SERVER_UNAVAILABLE = 3 28 | CONNACK_REFUSED_BAD_USERNAME_PASSWORD = 4 29 | CONNACK_REFUSED_NOT_AUTHORIZED = 5 30 | 31 | 32 | # Error values 33 | MQTT_ERR_AGAIN = -1 34 | MQTT_ERR_SUCCESS = 0 35 | MQTT_ERR_NOMEM = 1 36 | MQTT_ERR_PROTOCOL = 2 37 | MQTT_ERR_INVAL = 3 38 | MQTT_ERR_NO_CONN = 4 39 | MQTT_ERR_CONN_REFUSED = 5 40 | MQTT_ERR_NOT_FOUND = 6 41 | MQTT_ERR_CONN_LOST = 7 42 | MQTT_ERR_TLS = 8 43 | MQTT_ERR_PAYLOAD_SIZE = 9 44 | MQTT_ERR_NOT_SUPPORTED = 10 45 | MQTT_ERR_AUTH = 11 46 | MQTT_ERR_ACL_DENIED = 12 47 | MQTT_ERR_UNKNOWN = 13 48 | MQTT_ERR_ERRNO = 14 49 | MQTT_ERR_QUEUE_SIZE = 15 50 | 51 | 52 | # MQTT protocol versions 53 | MQTTv311 = 4 54 | MQTTv50 = 5 55 | 56 | 57 | class PubAckReasonCode(enum.IntEnum): 58 | SUCCESS = 0 59 | NO_MATCHING_SUBSCRIBERS = 16 60 | UNSPECIFIED_ERROR = 128 61 | IMPLEMENTATION_SPECIFIC_ERROR = 131 62 | NOT_AUTHORIZED = 135 63 | TOPIC_NAME_INVALID = 144 64 | PACKET_IDENTIFIER_IN_USE = 145 65 | QUOTA_EXCEEDED = 151 66 | PAYLOAD_FORMAT_INVALID = 153 67 | 68 | 69 | class PubRecReasonCode(enum.IntEnum): 70 | SUCCESS = 0 71 | NO_MATCHING_SUBSCRIBERS = 16 72 | UNSPECIFIED_ERROR = 128 73 | IMPLEMENTATION_SPECIFIC_ERROR = 131 74 | NOT_AUTHORIZED = 135 75 | TOPIC_NAME_INVALID = 144 76 | PACKET_IDENTIFIER_IN_USE = 145 77 | QUOTA_EXCEEDED = 151 78 | PAYLOAD_FORMAT_INVALID = 153 79 | 80 | 81 | class ConnAckReasonCode(enum.IntEnum): 82 | SUCCESS = 0 83 | UNSPECIFIED_ERROR = 128 84 | MALFORMED_PACKET = 129 85 | PROTOCOL_ERROR = 130 86 | IMPLEMENTATION_SPECIFIC_ERROR = 131 87 | UNSUPPORTED_PROTOCOL_VERSION = 132 88 | CLIENT_IDENTIFIER_NOT_VALID = 133 89 | BAD_USERNAME_OR_PASSWORD = 134 90 | NOT_AUTHORIZED = 135 91 | SERVER_UNAVAILABLE = 136 92 | SERVER_BUSY = 137 
93 | BANNED = 138 94 | BAD_AUTHENTICATION_METHOD = 140 95 | TOPIC_NAME_INVALID = 144 96 | PACKET_TOO_LARGE = 149 97 | PAYLOAD_FORMAT_INVALID = 153 98 | RETAIN_NOT_SUPPORTED = 154 99 | QOS_NOT_SUPPORTED = 155 100 | USE_ANOTHER_SERVER = 156 101 | SERVER_MOVED = 157 102 | CONNECTION_RATE_EXCEEDED = 159 103 | 104 | 105 | class SubAckReasonCode(enum.IntEnum): 106 | QOS0 = 0 107 | QOS1 = 1 108 | QOS2 = 2 109 | 110 | UNSPECIFIED_ERROR = 128 111 | IMPLEMENTATION_SPECIFIC_ERROR = 131 112 | NOT_AUTHORIZED = 135 113 | TOPIC_FILTER_INVALID = 143 114 | PACKET_IDENTIFIER_IN_USE = 145 115 | QUOTA_EXCEEDED = 151 116 | SHARED_SUBSCRIPTIONS_NOT_SUPPORTED = 158 117 | SUBSCRIPTION_IDENTIFIERS_NOT_SUPPORTED = 161 118 | WILDCARD_SUBSCRIPTIONS_NOT_SUPPORTED = 162 119 | 120 | 121 | UNLIMITED_RECONNECTS = -1 122 | 123 | DEFAULT_CONFIG = { 124 | 'reconnect_delay': 6, 125 | 'reconnect_retries': UNLIMITED_RECONNECTS, 126 | } 127 | -------------------------------------------------------------------------------- /gmqtt/mqtt/handler.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import struct 4 | import time 5 | from collections import defaultdict 6 | from copy import deepcopy 7 | from functools import partial 8 | 9 | from .utils import unpack_variable_byte_integer, IdGenerator, run_coroutine_or_function 10 | from .property import Property 11 | from .protocol import MQTTProtocol 12 | from .constants import MQTTCommands, PubRecReasonCode, DEFAULT_CONFIG 13 | from .constants import MQTTv311, MQTTv50 14 | 15 | 16 | def _empty_callback(*args, **kwargs): 17 | pass 18 | 19 | 20 | class MQTTError(Exception): 21 | pass 22 | 23 | 24 | class MQTTConnectError(MQTTError): 25 | __messages__ = { 26 | 1: "Connection Refused: unacceptable protocol version", 27 | 2: "Connection Refused: identifier rejected", 28 | 3: "Connection Refused: broker unavailable", 29 | 4: "Connection Refused: bad user name or password", 30 | 5: "Connection 
Refused: not authorised", 31 | 10: 'Cannot handle CONNACK package', 32 | 128: "Connection Refused: Unspecified error", 33 | 129: "Connection Refused: Malformed Packet", 34 | 130: "Connection Refused: Protocol Error", 35 | 131: "Connection Refused: Implementation specific error", 36 | 132: "Connection Refused: Unsupported Protocol Version", 37 | 133: "Connection Refused: Client Identifier not valid", 38 | 134: "Connection Refused: Bad User Name or Password", 39 | 135: "Connection Refused: Not authorized", 40 | 136: "Connection Refused: Server unavailable", 41 | 137: "Connection Refused: Server busy", 42 | 138: "Connection Refused: Banned", 43 | 140: "Connection Refused: Bad authentication method", 44 | 144: "Connection Refused: Topic Name invalid", 45 | 149: "Connection Refused: Packet too large", 46 | 151: "Connection Refused: Quota exceeded", 47 | 153: "Connection Refused: Payload format invalid", 48 | 154: "Connection Refused: Retain not supported", 49 | 155: "Connection Refused: QoS not supported", 50 | 156: "Connection Refused: Use another server", 51 | 157: "Connection Refused: Server moved", 52 | 159: "Connection Refused: Connection rate exceeded", 53 | } 54 | 55 | def __init__(self, code): 56 | self._code = code 57 | self.message = self.__messages__.get(code, 'Unknown error')\ 58 | 59 | 60 | def __str__(self): 61 | return "code {} ({})".format(self._code, self.message) 62 | 63 | 64 | class EventCallback(object): 65 | def __init__(self, *args, **kwargs): 66 | super(EventCallback, self).__init__() 67 | 68 | self._connected = asyncio.Event() 69 | 70 | self._on_connected_callback = _empty_callback 71 | self._on_disconnected_callback = _empty_callback 72 | self._on_message_callback = _empty_callback 73 | self._on_subscribe_callback = _empty_callback 74 | self._on_unsubscribe_callback = _empty_callback 75 | 76 | self._config = deepcopy(DEFAULT_CONFIG) 77 | self._reconnecting_now = False 78 | 79 | # this flag should be True after connect and False when disconnect 
was called 80 | self._is_active = False 81 | 82 | self.failed_connections = 0 83 | 84 | def _temporatily_stop_reconnect(self): 85 | self._reconnecting_now = True 86 | 87 | def _exit_reconnecting_state(self): 88 | self._reconnecting_now = False 89 | 90 | def stop_reconnect(self): 91 | self._config['reconnect_retries'] = 0 92 | 93 | def set_config(self, config): 94 | self._config.update(config) 95 | 96 | @property 97 | def reconnect_delay(self): 98 | return self._config['reconnect_delay'] 99 | 100 | @reconnect_delay.setter 101 | def reconnect_delay(self, value): 102 | self._config['reconnect_delay'] = value 103 | 104 | @property 105 | def reconnect_retries(self): 106 | return self._config['reconnect_retries'] 107 | 108 | @reconnect_retries.setter 109 | def reconnect_retries(self, value): 110 | self._config['reconnect_retries'] = value 111 | 112 | 113 | @property 114 | def on_subscribe(self): 115 | return self._on_subscribe_callback 116 | 117 | @on_subscribe.setter 118 | def on_subscribe(self, cb): 119 | if not callable(cb): 120 | raise ValueError 121 | self._on_subscribe_callback = cb 122 | 123 | @property 124 | def on_connect(self): 125 | return self._on_connected_callback 126 | 127 | @on_connect.setter 128 | def on_connect(self, cb): 129 | if not callable(cb): 130 | raise ValueError 131 | self._on_connected_callback = cb 132 | 133 | @property 134 | def on_message(self): 135 | return self._on_message_callback 136 | 137 | @on_message.setter 138 | def on_message(self, cb): 139 | if not callable(cb): 140 | raise ValueError 141 | self._on_message_callback = cb 142 | 143 | @property 144 | def on_disconnect(self): 145 | return self._on_disconnected_callback 146 | 147 | @on_disconnect.setter 148 | def on_disconnect(self, cb): 149 | if not callable(cb): 150 | raise ValueError 151 | self._on_disconnected_callback = cb 152 | 153 | @property 154 | def on_unsubscribe(self): 155 | return self._on_unsubscribe_callback 156 | 157 | @on_unsubscribe.setter 158 | def 
on_unsubscribe(self, cb): 159 | if not callable(cb): 160 | raise ValueError 161 | self._on_unsubscribe_callback = cb 162 | 163 | 164 | class MqttPackageHandler(EventCallback): 165 | def __init__(self, *args, **kwargs): 166 | super(MqttPackageHandler, self).__init__(*args, **kwargs) 167 | self._messages_in = {} 168 | self._handler_cache = {} 169 | self._error = None 170 | self._connection = None 171 | self._server_topics_aliases = {} 172 | 173 | self._id_generator = IdGenerator(max=kwargs.get('receive_maximum', 65535)) 174 | 175 | if self.protocol_version == MQTTv50: 176 | self._optimistic_acknowledgement = kwargs.get('optimistic_acknowledgement', True) 177 | else: 178 | self._optimistic_acknowledgement = True 179 | 180 | self._logger = kwargs.get('logger', logging.getLogger(__name__)) 181 | 182 | def _clear_topics_aliases(self): 183 | self._server_topics_aliases = {} 184 | 185 | def _send_command_with_mid(self, cmd, mid, dup, reason_code=0): 186 | raise NotImplementedError 187 | 188 | def _remove_message_from_query(self, mid): 189 | raise NotImplementedError 190 | 191 | def _send_puback(self, mid, reason_code=0): 192 | self._send_command_with_mid(MQTTCommands.PUBACK, mid, False, reason_code=reason_code) 193 | 194 | def _send_pubrec(self, mid, reason_code=0): 195 | self._send_command_with_mid(MQTTCommands.PUBREC, mid, False, reason_code=reason_code) 196 | 197 | def _send_pubrel(self, mid, dup, reason_code=0): 198 | self._send_command_with_mid(MQTTCommands.PUBREL | 2, mid, dup, reason_code=reason_code) 199 | 200 | def _send_pubcomp(self, mid, dup, reason_code=0): 201 | self._send_command_with_mid(MQTTCommands.PUBCOMP, mid, dup, reason_code=reason_code) 202 | 203 | def __get_handler__(self, cmd): 204 | cmd_type = cmd & 0xF0 205 | if cmd_type not in self._handler_cache: 206 | handler_name = '_handle_{}_packet'.format(MQTTCommands(cmd_type).name.lower()) 207 | self._handler_cache[cmd_type] = getattr(self, handler_name, self._default_handler) 208 | return 
self._handler_cache[cmd_type] 209 | 210 | def _handle_packet(self, cmd, packet): 211 | self._logger.debug('[CMD %s] %s', hex(cmd), packet) 212 | handler = self.__get_handler__(cmd) 213 | handler(cmd, packet) 214 | self._last_msg_in = time.monotonic() 215 | 216 | def _handle_exception_in_future(self, future): 217 | if future.exception(): 218 | self._logger.warning('[EXC OCCURED] in reconnect future %s', future.exception()) 219 | return 220 | 221 | def _default_handler(self, cmd, packet): 222 | self._logger.warning('[UNKNOWN CMD] %s %s', hex(cmd), packet) 223 | 224 | def _handle_disconnect_packet(self, cmd, packet): 225 | # reset server topics on disconnect 226 | self._clear_topics_aliases() 227 | 228 | future = asyncio.ensure_future(self.reconnect(delay=True)) 229 | future.add_done_callback(self._handle_exception_in_future) 230 | self.on_disconnect(self, packet) 231 | 232 | def _parse_properties(self, packet): 233 | if self.protocol_version < MQTTv50: 234 | # If protocol is version is less than 5.0, there is no properties in packet 235 | return {}, packet 236 | properties_len, left_packet = unpack_variable_byte_integer(packet) 237 | packet = left_packet[:properties_len] 238 | left_packet = left_packet[properties_len:] 239 | properties_dict = defaultdict(list) 240 | while packet: 241 | property_identifier, = struct.unpack("!B", packet[:1]) 242 | property_obj = Property.factory(id_=property_identifier) 243 | if property_obj is None: 244 | self._logger.critical('[PROPERTIES] received invalid property id {}, disconnecting'.format(property_identifier)) 245 | return None, None 246 | result, packet = property_obj.loads(packet[1:]) 247 | for k, v in result.items(): 248 | properties_dict[k].append(v) 249 | properties_dict = dict(properties_dict) 250 | return properties_dict, left_packet 251 | 252 | def _update_keepalive_if_needed(self): 253 | if not self._connack_properties.get('server_keep_alive'): 254 | return 255 | self._keepalive = 
self._connack_properties['server_keep_alive'][0] 256 | self._connection.keepalive = self._keepalive 257 | 258 | def _handle_connack_packet(self, cmd, packet): 259 | self._connected.set() 260 | 261 | (session_present, result) = struct.unpack("!BB", packet[:2]) 262 | if session_present: 263 | asyncio.ensure_future(self._resend_qos_messages()) 264 | else: 265 | asyncio.ensure_future(self._clear_resend_qos_queue()) 266 | 267 | if result != 0: 268 | self._logger.warning('[CONNACK] %s', hex(result)) 269 | self.failed_connections += 1 270 | if result == 1 and self.protocol_version == MQTTv50: 271 | self._logger.info('[CONNACK] Downgrading to MQTT 3.1 protocol version') 272 | MQTTProtocol.proto_ver = MQTTv311 273 | future = asyncio.ensure_future(self.reconnect(delay=True)) 274 | future.add_done_callback(self._handle_exception_in_future) 275 | return 276 | else: 277 | self._error = MQTTConnectError(result) 278 | asyncio.ensure_future(self.reconnect(delay=True)) 279 | return 280 | else: 281 | self.failed_connections = 0 282 | 283 | if len(packet) > 2: 284 | properties, _ = self._parse_properties(packet[2:]) 285 | if properties is None: 286 | self._error = MQTTConnectError(10) 287 | asyncio.ensure_future(self.disconnect()) 288 | self._connack_properties = properties 289 | self._update_keepalive_if_needed() 290 | 291 | # TODO: Implement checking for the flags and results 292 | # see 3.2.2.3 Connect Return code of the http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.pdf 293 | 294 | self._logger.debug('[CONNACK] session_present: %s, result: %s', hex(session_present), hex(result)) 295 | self.on_connect(self, session_present, result, self.properties) 296 | 297 | def _handle_publish_packet(self, cmd, raw_packet): 298 | header = cmd 299 | 300 | dup = (header & 0x08) >> 3 301 | qos = (header & 0x06) >> 1 302 | retain = header & 0x01 303 | 304 | pack_format = "!H" + str(len(raw_packet) - 2) + 's' 305 | (slen, packet) = struct.unpack(pack_format, raw_packet) 306 | 307 | 
pack_format = '!' + str(slen) + 's' + str(len(packet) - slen) + 's' 308 | (topic, packet) = struct.unpack(pack_format, packet) 309 | 310 | # we will change the packet ref, let's save origin 311 | payload = packet 312 | 313 | if qos > 0: 314 | pack_format = "!H" + str(len(packet) - 2) + 's' 315 | (mid, packet) = struct.unpack(pack_format, packet) 316 | else: 317 | mid = None 318 | 319 | properties, packet = self._parse_properties(packet) 320 | properties['dup'] = dup 321 | properties['retain'] = retain 322 | 323 | if packet is None: 324 | self._logger.critical('[INVALID MESSAGE] skipping: {}'.format(raw_packet)) 325 | return 326 | 327 | if 'topic_alias' in properties: 328 | # TODO: need to add validation (topic alias must be greater than 0 and less than topic_alias_maximum) 329 | topic_alias = properties['topic_alias'][0] 330 | if topic: 331 | self._server_topics_aliases[topic_alias] = topic 332 | else: 333 | topic = self._server_topics_aliases.get(topic_alias, None) 334 | 335 | if not topic: 336 | self._logger.warning('[MQTT ERR PROTO] topic name is empty (or server has send invalid topic alias)') 337 | return 338 | 339 | try: 340 | print_topic = topic.decode('utf-8') 341 | except UnicodeDecodeError as exc: 342 | self._logger.warning('[INVALID CHARACTER IN TOPIC] %s', topic, exc_info=exc) 343 | print_topic = topic 344 | 345 | self._logger.debug('[RECV %s with QoS: %s] %s', print_topic, qos, payload) 346 | 347 | if qos == 0: 348 | run_coroutine_or_function(self.on_message, self, print_topic, packet, qos, properties) 349 | elif qos == 1: 350 | self._handle_qos_1_publish_packet(mid, packet, print_topic, properties) 351 | elif qos == 2: 352 | self._handle_qos_2_publish_packet(mid, packet, print_topic, properties) 353 | self._id_generator.free_id(mid) 354 | 355 | def _handle_qos_2_publish_packet(self, mid, packet, print_topic, properties): 356 | if self._optimistic_acknowledgement: 357 | self._send_pubrec(mid) 358 | run_coroutine_or_function(self.on_message, self, 
print_topic, packet, 2, properties) 359 | else: 360 | run_coroutine_or_function(self.on_message, self, print_topic, packet, 2, properties, 361 | callback=partial(self.__handle_publish_callback, qos=2, mid=mid)) 362 | 363 | def __handle_publish_callback(self, f, qos=None, mid=None): 364 | reason_code = f.result() 365 | if reason_code not in (c.value for c in PubRecReasonCode): 366 | raise ValueError('Invalid PUBREC reason code {}'.format(reason_code)) 367 | if qos == 2: 368 | self._send_pubrec(mid, reason_code=reason_code) 369 | else: 370 | self._send_puback(mid, reason_code=reason_code) 371 | self._id_generator.free_id(mid) 372 | 373 | def _handle_qos_1_publish_packet(self, mid, packet, print_topic, properties): 374 | if self._optimistic_acknowledgement: 375 | self._send_puback(mid) 376 | run_coroutine_or_function(self.on_message, self, print_topic, packet, 1, properties) 377 | else: 378 | run_coroutine_or_function(self.on_message, self, print_topic, packet, 1, properties, 379 | callback=partial(self.__handle_publish_callback, qos=1, mid=mid)) 380 | 381 | def __call__(self, cmd, packet): 382 | try: 383 | result = self._handle_packet(cmd, packet) 384 | except Exception as exc: 385 | self._logger.error('[ERROR HANDLE PKG]', exc_info=exc) 386 | result = None 387 | return result 388 | 389 | def _handle_suback_packet(self, cmd, raw_packet): 390 | pack_format = "!H" + str(len(raw_packet) - 2) + 's' 391 | (mid, packet) = struct.unpack(pack_format, raw_packet) 392 | properties, packet = self._parse_properties(packet) 393 | 394 | pack_format = "!" 
+ "B" * len(packet) 395 | granted_qoses = struct.unpack(pack_format, packet) 396 | 397 | subs = self.get_subscriptions_by_mid(mid) 398 | for granted_qos, sub in zip(granted_qoses, subs): 399 | if granted_qos >= 128: 400 | # subscription was not acknowledged 401 | sub.acknowledged = False 402 | else: 403 | sub.acknowledged = True 404 | sub.qos = granted_qos 405 | 406 | self._logger.info('[SUBACK] %s %s', mid, granted_qoses) 407 | self.on_subscribe(self, mid, granted_qoses, properties) 408 | 409 | for sub in self.subscriptions: 410 | if sub.mid == mid: 411 | sub.mid = None 412 | self._id_generator.free_id(mid) 413 | 414 | def _handle_unsuback_packet(self, cmd, raw_packet): 415 | pack_format = "!H" + str(len(raw_packet) - 2) + 's' 416 | (mid, packet) = struct.unpack(pack_format, raw_packet) 417 | pack_format = "!" + "B" * len(packet) 418 | granted_qos = struct.unpack(pack_format, packet) 419 | 420 | self._logger.info('[UNSUBACK] %s %s', mid, granted_qos) 421 | 422 | self.on_unsubscribe(self, mid, granted_qos) 423 | self._id_generator.free_id(mid) 424 | 425 | def _handle_pingreq_packet(self, cmd, packet): 426 | self._logger.debug('[PING REQUEST] %s %s', hex(cmd), packet) 427 | pass 428 | 429 | def _handle_pingresp_packet(self, cmd, packet): 430 | self._logger.debug('[PONG REQUEST] %s %s', hex(cmd), packet) 431 | 432 | def _handle_puback_packet(self, cmd, packet): 433 | (mid, ) = struct.unpack("!H", packet[:2]) 434 | 435 | # TODO: For MQTT 5.0 parse reason code and properties 436 | 437 | self._logger.debug('[RECEIVED PUBACK FOR] %s', mid) 438 | 439 | self._id_generator.free_id(mid) 440 | self._remove_message_from_query(mid) 441 | 442 | def _handle_pubcomp_packet(self, cmd, packet): 443 | pass 444 | 445 | def _handle_pubrec_packet(self, cmd, packet): 446 | (mid,) = struct.unpack("!H", packet[:2]) 447 | self._logger.debug('[RECEIVED PUBREC FOR] %s', mid) 448 | self._id_generator.free_id(mid) 449 | self._remove_message_from_query(mid) 450 | self._send_pubrel(mid, 0) 451 | 
452 | def _handle_pubrel_packet(self, cmd, packet): 453 | (mid, ) = struct.unpack("!H", packet[:2]) 454 | self._logger.debug('[RECEIVED PUBREL FOR] %s', mid) 455 | self._send_pubcomp(mid, 0) 456 | 457 | self._id_generator.free_id(mid) 458 | -------------------------------------------------------------------------------- /gmqtt/mqtt/package.py: -------------------------------------------------------------------------------- 1 | import struct 2 | import logging 3 | from typing import Tuple 4 | 5 | from .constants import MQTTCommands, MQTTv50 6 | from .property import Property 7 | from .utils import pack_variable_byte_integer, IdGenerator 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | LAST_MID = 0 12 | USED_IDS = set() 13 | 14 | 15 | class Packet(object): 16 | __slots__ = ['cmd', 'data'] 17 | 18 | def __init__(self, cmd, data): 19 | self.cmd = cmd 20 | self.data = data 21 | 22 | 23 | class PackageFactory(object): 24 | id_generator = IdGenerator() 25 | 26 | @classmethod 27 | async def parse_package(cls, cmd, package): 28 | pass 29 | 30 | @classmethod 31 | def build_package(cls, *args, **kwargs) -> bytes: 32 | raise NotImplementedError 33 | 34 | @classmethod 35 | def _pack_str16(cls, packet, data): 36 | if isinstance(data, str): 37 | data = data.encode('utf-8') 38 | packet.extend(struct.pack("!H", len(data))) 39 | packet.extend(data) 40 | 41 | @classmethod 42 | def _build_properties_data(cls, properties_dict, protocol_version): 43 | if protocol_version < MQTTv50: 44 | return bytearray() 45 | data = bytearray() 46 | for property_name, property_value in properties_dict.items(): 47 | property = Property.factory(name=property_name) 48 | if property is None: 49 | logger.warning('[GMQTT] property {} is not supported, it was ignored'.format(property_name)) 50 | continue 51 | property_bytes = property.dumps(property_value) 52 | data.extend(property_bytes) 53 | result = pack_variable_byte_integer(len(data)) 54 | result.extend(data) 55 | return result 56 | 57 | 58 | class 
LoginPackageFactor(PackageFactory): 59 | @classmethod 60 | def build_package(cls, client_id, username, password, clean_session, keepalive, protocol, will_message=None, **kwargs): 61 | remaining_length = 2 + len(protocol.proto_name) + 1 + 1 + 2 + 2 + len(client_id) 62 | 63 | connect_flags = 0 64 | if clean_session: 65 | connect_flags |= 0x02 66 | 67 | if will_message: 68 | will_prop_bytes = cls._build_properties_data(will_message.properties, protocol.proto_ver) 69 | remaining_length += 2 + len(will_message.topic) + 2 + len(will_message.payload) + len(will_prop_bytes) 70 | connect_flags |= 0x04 | ((will_message.qos & 0x03) << 3) | ((will_message.retain & 0x01) << 5) 71 | 72 | if username is not None: 73 | remaining_length += 2 + len(username) 74 | connect_flags |= 0x80 75 | if password is not None: 76 | connect_flags |= 0x40 77 | remaining_length += 2 + len(password) 78 | 79 | command = MQTTCommands.CONNECT 80 | packet = bytearray() 81 | packet.append(command) 82 | 83 | prop_bytes = cls._build_properties_data(kwargs, protocol.proto_ver) 84 | remaining_length += len(prop_bytes) 85 | 86 | packet.extend(pack_variable_byte_integer(remaining_length)) 87 | packet.extend(struct.pack("!H" + str(len(protocol.proto_name)) + "sBBH", 88 | len(protocol.proto_name), 89 | protocol.proto_name, 90 | protocol.proto_ver, 91 | connect_flags, 92 | keepalive)) 93 | 94 | packet.extend(prop_bytes) 95 | 96 | cls._pack_str16(packet, client_id) 97 | 98 | if will_message: 99 | packet += will_prop_bytes 100 | cls._pack_str16(packet, will_message.topic) 101 | cls._pack_str16(packet, will_message.payload) 102 | 103 | if username is not None: 104 | cls._pack_str16(packet, username) 105 | 106 | if password is not None: 107 | cls._pack_str16(packet, password) 108 | 109 | return packet 110 | 111 | 112 | class UnsubscribePacket(PackageFactory): 113 | @classmethod 114 | def build_package(cls, topic, protocol, **kwargs) -> Tuple[int, bytes]: 115 | remaining_length = 2 116 | if not isinstance(topic, 
(list, tuple)): 117 | topics = [topic] 118 | else: 119 | topics = topic 120 | 121 | for t in topics: 122 | remaining_length += 2 + len(t) 123 | 124 | properties = cls._build_properties_data(kwargs, protocol.proto_ver) 125 | remaining_length += len(properties) 126 | 127 | command = MQTTCommands.UNSUBSCRIBE | 0x2 128 | packet = bytearray() 129 | packet.append(command) 130 | packet.extend(pack_variable_byte_integer(remaining_length)) 131 | local_mid = cls.id_generator.next_id() 132 | packet.extend(struct.pack("!H", local_mid)) 133 | packet.extend(properties) 134 | for t in topics: 135 | cls._pack_str16(packet, t) 136 | 137 | logger.info('[SEND UNSUB] %s', topics) 138 | 139 | return local_mid, packet 140 | 141 | 142 | class SubscribePacket(PackageFactory): 143 | sentinel = object() 144 | 145 | @classmethod 146 | def build_package(cls, subscriptions, protocol, **kwargs) -> Tuple[int, bytes]: 147 | remaining_length = 2 148 | 149 | topics = [] 150 | subscription_identifier = kwargs.get('subscription_identifier', cls.sentinel) 151 | 152 | for s in subscriptions: 153 | topic = s.topic 154 | if isinstance(topic, str): 155 | topic = topic.encode() 156 | 157 | remaining_length += 2 + len(topic) + 1 158 | topics.append(topic) 159 | 160 | # if subscription_identifier hasn't been passed in kwargs, 161 | # we will use the first identifier for all subscriptions; 162 | if subscription_identifier is cls.sentinel: 163 | subscription_identifier = s.subscription_identifier 164 | 165 | if subscription_identifier is not cls.sentinel: 166 | kwargs['subscription_identifier'] = subscription_identifier 167 | 168 | if subscription_identifier is None: 169 | kwargs.pop('subscription_identifier', None) 170 | 171 | properties = cls._build_properties_data(kwargs, protocol.proto_ver) 172 | remaining_length += len(properties) 173 | 174 | command = MQTTCommands.SUBSCRIBE | (False << 3) | 0x2 175 | packet = bytearray() 176 | packet.append(command) 177 | 
packet.extend(pack_variable_byte_integer(remaining_length)) 178 | local_mid = cls.id_generator.next_id() 179 | packet.extend(struct.pack("!H", local_mid)) 180 | packet.extend(properties) 181 | for s in subscriptions: 182 | cls._pack_str16(packet, s.topic) 183 | subscribe_options = s.retain_handling_options << 4 | s.retain_as_published << 3 | s.no_local << 2 | s.qos 184 | packet.append(subscribe_options) 185 | 186 | logger.info('[SEND SUB] %s %s', local_mid, topics) 187 | 188 | return local_mid, packet 189 | 190 | 191 | class SimpleCommandPacket(PackageFactory): 192 | @classmethod 193 | def build_package(cls, command) -> bytes: 194 | return struct.pack('!BB', command, 0) 195 | 196 | 197 | class PublishPacket(PackageFactory): 198 | @classmethod 199 | def build_package(cls, message, protocol) -> Tuple[int, bytes]: 200 | command = MQTTCommands.PUBLISH | ((message.dup & 0x1) << 3) | (message.qos << 1) | (message.retain & 0x1) 201 | 202 | packet = bytearray() 203 | packet.append(command) 204 | 205 | remaining_length = 2 + len(message.topic) + message.payload_size 206 | prop_bytes = cls._build_properties_data(message.properties, protocol_version=protocol.proto_ver) 207 | remaining_length += len(prop_bytes) 208 | 209 | if message.payload_size == 0: 210 | logger.debug("Sending PUBLISH (q%d), '%s' (NULL payload)", message.qos, message.topic) 211 | else: 212 | logger.debug("Sending PUBLISH (q%d), '%s', ... 
(%d bytes)", message.qos, message.topic, message.payload_size) 213 | 214 | if message.qos > 0: 215 | # For message id 216 | remaining_length += 2 217 | 218 | packet.extend(pack_variable_byte_integer(remaining_length)) 219 | cls._pack_str16(packet, message.topic) 220 | 221 | if message.qos > 0: 222 | # For message id 223 | mid = cls.id_generator.next_id() 224 | packet.extend(struct.pack("!H", mid)) 225 | else: 226 | mid = None 227 | packet.extend(prop_bytes) 228 | 229 | packet.extend(message.payload) 230 | 231 | return mid, packet 232 | 233 | 234 | class DisconnectPacket(PackageFactory): 235 | @classmethod 236 | def build_package(cls, protocol, reason_code=0, **properties): 237 | if protocol.proto_ver == MQTTv50: 238 | prop_bytes = cls._build_properties_data(properties, protocol_version=protocol.proto_ver) 239 | remaining_length = 1 + len(prop_bytes) 240 | return struct.pack('!BBB', MQTTCommands.DISCONNECT.value, remaining_length, reason_code) + prop_bytes 241 | else: 242 | return struct.pack('!BB', MQTTCommands.DISCONNECT.value, 0) 243 | 244 | 245 | class CommandWithMidPacket(PackageFactory): 246 | 247 | @classmethod 248 | def build_package(cls, cmd, mid, dup, reason_code=0, proto_ver=MQTTv50) -> bytes: 249 | if dup: 250 | cmd |= 0x8 251 | if proto_ver == MQTTv50: 252 | remaining_length = 4 253 | packet = struct.pack('!BBHBB', cmd, remaining_length, mid, reason_code, 0) 254 | else: 255 | remaining_length = 2 256 | packet = struct.pack('!BBH', cmd, remaining_length, mid) 257 | return packet 258 | -------------------------------------------------------------------------------- /gmqtt/mqtt/property.py: -------------------------------------------------------------------------------- 1 | import struct 2 | 3 | from .utils import unpack_variable_byte_integer, pack_variable_byte_integer, pack_utf8, unpack_utf8 4 | 5 | 6 | class Property: 7 | def __init__(self, id_, bytes_struct, name, allowed_packages): 8 | self.id = id_ 9 | self.bytes_struct = bytes_struct 10 | self.name 
= name 11 | self.allowed_packages = allowed_packages 12 | 13 | def loads(self, bytes_array): 14 | # returns dict with property name-value and remaining bytes which do not belong to this property 15 | if self.bytes_struct == 'u8': 16 | # First two bytes in UTF-8 encoded properties correspond to unicode string length 17 | value, left_str = unpack_utf8(bytes_array) 18 | elif self.bytes_struct == 'u8x2': 19 | value1, left_str = unpack_utf8(bytes_array) 20 | value2, left_str = unpack_utf8(left_str) 21 | value = (value1, value2) 22 | elif self.bytes_struct == 'b': 23 | str_len, = struct.unpack('!H', bytes_array[:2]) 24 | value = bytes_array[2:2 + str_len] 25 | left_str = bytes_array[2 + str_len:] 26 | elif self.bytes_struct == 'vbi': 27 | value, left_str = unpack_variable_byte_integer(bytes_array) 28 | else: 29 | value, left_str = self.unpack_helper(self.bytes_struct, bytes_array) 30 | return {self.name: value}, left_str 31 | 32 | def unpack_helper(self, fmt, data): 33 | # unpacks property value according to format, returns value and remaining bytes 34 | size = struct.calcsize(fmt) 35 | value = struct.unpack(fmt, data[:size]) 36 | left_str = data[size:] 37 | if len(value) == 1: 38 | value = value[0] 39 | return value, left_str 40 | 41 | def _dump_user_property(self, data, packet): 42 | packet.extend(struct.pack('!B', self.id)) 43 | data1, data2 = data 44 | packet.extend(pack_utf8(data1)) 45 | packet.extend(pack_utf8(data2)) 46 | 47 | def dumps(self, data): 48 | # packs property value into byte array 49 | packet = bytearray() 50 | if self.bytes_struct == 'u8': 51 | packet.extend(struct.pack('!B', self.id)) 52 | packet.extend(pack_utf8(data)) 53 | return packet 54 | elif self.bytes_struct == 'u8x2': 55 | if isinstance(data[0], str): 56 | self._dump_user_property(data, packet) 57 | else: 58 | for kv_pair in data: 59 | self._dump_user_property(kv_pair, packet) 60 | return packet 61 | elif self.bytes_struct == 'b': 62 | packet.extend(struct.pack('!B', self.id)) 63 | 
packet.extend(struct.pack('!H', len(data))) 64 | packet.extend(data) 65 | return packet 66 | elif self.bytes_struct == 'vbi': 67 | packet.extend(struct.pack('!B', self.id)) 68 | packet.extend(pack_variable_byte_integer(data)) 69 | return packet 70 | packet.extend(struct.pack('!B', self.id)) 71 | packet.extend(struct.pack(self.bytes_struct, data)) 72 | return packet 73 | 74 | @classmethod 75 | def factory(cls, id_=None, name=None): 76 | if (name is None and id_ is None) or (name is not None and id_ is not None): 77 | raise ValueError('Either id or name should be not None') 78 | if name is not None: 79 | return PROPERTIES_BY_NAME.get(name) 80 | else: 81 | return PROPERTIES_BY_ID.get(id_) 82 | 83 | 84 | PROPERTIES = [ 85 | Property(1, '!B', 'payload_format_id', ['PUBLISH', ]), 86 | Property(2, '!L', 'message_expiry_interval', ['PUBLISH', ]), 87 | Property(3, 'u8', 'content_type', ['PUBLISH']), 88 | Property(8, 'u8', 'response_topic', ['PUBLISH', ]), 89 | Property(9, 'b', 'correlation_data', ['PUBLISH']), 90 | Property(11, 'vbi', 'subscription_identifier', ['PUBLISH', 'SUBSCRIBE']), 91 | Property(17, '!L', 'session_expiry_interval', ['CONNECT', ]), 92 | Property(18, 'u8', 'assigned_client_identifier', ['CONNACK', ]), 93 | Property(19, '!H', 'server_keep_alive', ['CONNACK']), 94 | Property(21, 'u8', 'auth_method', ['CONNECT', 'CONNACK', 'AUTH']), 95 | Property(23, '!B', 'request_problem_info', ['CONNECT']), 96 | Property(24, '!L', 'will_delay_interval', ['CONNECT', ]), 97 | Property(25, '!B', 'request_response_info', ['CONNECT']), 98 | Property(26, 'u8', 'response_info', ['CONNACK']), 99 | Property(28, 'u8', 'server_reference', ['CONNACK', 'DISCONNECT']), 100 | Property(31, 'u8', 'reason_string', ['CONNACK', 'PUBACK', 'PUBREC', 'PUBREL', 'PUBCOMP', 'SUBACK', 'UNSUBACK', 101 | 'DISCONNECT', 'AUTH']), 102 | Property(33, '!H', 'receive_maximum', ['CONNECT', 'CONNACK']), 103 | Property(34, '!H', 'topic_alias_maximum', ['CONNECT', 'CONNACK']), 104 | Property(35, '!H', 
'topic_alias', ['PUBLISH']), 105 | Property(36, '!B', 'max_qos', ['CONNACK', ]), 106 | Property(37, '!B', 'retain_available', ['CONNACK', ]), 107 | Property(38, 'u8x2', 'user_property', ['CONNECT', 'CONNACK', 'PUBLISH', 'PUBACK', 'PUBREC', 'PUBREL', 108 | 'PUBCOMP', 'SUBACK', 'UNSUBACK', 'DISCONNECT', 'AUTH']), 109 | Property(39, '!L', 'maximum_packet_size', ['CONNECT', 'CONNACK']), 110 | Property(40, '!B', 'wildcard_subscription_available', ['CONNACK']), 111 | Property(41, '!B', 'sub_id_available', ['CONNACK', ]), 112 | Property(42, '!B', 'shared_subscription_available', ['CONNACK']), 113 | ] 114 | 115 | PROPERTIES_BY_ID = {pr.id: pr for pr in PROPERTIES} 116 | PROPERTIES_BY_NAME = {pr.name: pr for pr in PROPERTIES} 117 | 118 | 119 | -------------------------------------------------------------------------------- /gmqtt/mqtt/protocol.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import time 4 | 5 | import sys 6 | 7 | from . 
import package 8 | from .constants import MQTTv50, MQTTCommands 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class _StreamReaderProtocolCompatibilityMixin: 14 | def __init__(self, *args, **kwargs): 15 | if sys.version_info < (3, 7): 16 | self._closed = asyncio.get_event_loop().create_future() 17 | super(_StreamReaderProtocolCompatibilityMixin, self).__init__(*args, **kwargs) 18 | 19 | def connection_lost(self, exc): 20 | super(_StreamReaderProtocolCompatibilityMixin, self).connection_lost(exc) 21 | 22 | if sys.version_info[:2] >= (3, 7): 23 | return 24 | 25 | if not self._closed.done(): 26 | if exc is None: 27 | self._closed.set_result(None) 28 | else: 29 | self._closed.set_exception(exc) 30 | 31 | 32 | class BaseMQTTProtocol(_StreamReaderProtocolCompatibilityMixin, asyncio.StreamReaderProtocol): 33 | def __init__(self, buffer_size=2**16, loop=None): 34 | if not loop: 35 | loop = asyncio.get_event_loop() 36 | 37 | self._connection = None 38 | self._transport = None 39 | 40 | self._connected = asyncio.Event() 41 | 42 | reader = asyncio.StreamReader(limit=buffer_size, loop=loop) 43 | # this is bad hack for python 3.8 44 | # TODO: get rid of StreamReader dependency (deprecated) 45 | self._hard_reader = reader 46 | super(BaseMQTTProtocol, self).__init__(reader, loop=loop) 47 | 48 | def set_connection(self, conn): 49 | self._connection = conn 50 | 51 | @property 52 | def closed(self): 53 | return self._closed 54 | 55 | def _parse_packet(self): 56 | raise NotImplementedError 57 | 58 | def connection_made(self, transport: asyncio.Transport): 59 | super(BaseMQTTProtocol, self).connection_made(transport) 60 | 61 | logger.info('[CONNECTION MADE]') 62 | self._transport = transport 63 | 64 | self._connected.set() 65 | 66 | def data_received(self, data): 67 | super(BaseMQTTProtocol, self).data_received(data) 68 | 69 | def write_data(self, data: bytes): 70 | self._connection._last_data_out = time.monotonic() 71 | if self._transport and not 
self._transport.is_closing(): 72 | self._transport.write(data) 73 | else: 74 | logger.warning('[TRYING WRITE TO CLOSED SOCKET]') 75 | 76 | def connection_lost(self, exc): 77 | self._connected.clear() 78 | super(BaseMQTTProtocol, self).connection_lost(exc) 79 | if exc: 80 | logger.warning('[EXC: CONN LOST]', exc_info=exc) 81 | else: 82 | logger.info('[CONN CLOSE NORMALLY]') 83 | 84 | async def read(self, n=-1): 85 | bs = await self._stream_reader.read(n=n) 86 | 87 | # so we don't receive anything but connection is not closed - 88 | # let's close it manually 89 | if not bs and not self._transport.is_closing(): 90 | self._transport.close() 91 | # self.connection_lost(ConnectionResetError()) 92 | raise ConnectionResetError("Reset connection manually.") 93 | return bs 94 | 95 | 96 | class MQTTProtocol(BaseMQTTProtocol): 97 | proto_name = b'MQTT' 98 | proto_ver = MQTTv50 99 | 100 | def __init__(self, *args, **kwargs): 101 | super(MQTTProtocol, self).__init__(*args, **kwargs) 102 | self._queue = asyncio.Queue() 103 | 104 | self._disconnect = asyncio.Event() 105 | 106 | self._read_loop_future = None 107 | 108 | def connection_made(self, transport: asyncio.Transport): 109 | super().connection_made(transport) 110 | self._read_loop_future = asyncio.ensure_future(self._read_loop()) 111 | 112 | async def send_auth_package(self, client_id, username, password, clean_session, keepalive, 113 | will_message=None, **kwargs): 114 | pkg = package.LoginPackageFactor.build_package(client_id, username, password, clean_session, 115 | keepalive, self, will_message=will_message, **kwargs) 116 | self.write_data(pkg) 117 | 118 | def send_subscribe_packet(self, subscriptions, **kwargs): 119 | mid, pkg = package.SubscribePacket.build_package(subscriptions, self, **kwargs) 120 | for sub in subscriptions: 121 | sub.mid = mid 122 | self.write_data(pkg) 123 | return mid 124 | 125 | def send_unsubscribe_packet(self, topic, **kwargs): 126 | mid, pkg = package.UnsubscribePacket.build_package(topic, 
self, **kwargs) 127 | self.write_data(pkg) 128 | return mid 129 | 130 | def send_simple_command_packet(self, cmd): 131 | pkg = package.SimpleCommandPacket.build_package(cmd) 132 | self.write_data(pkg) 133 | 134 | def send_ping_request(self): 135 | self.send_simple_command_packet(MQTTCommands.PINGREQ) 136 | 137 | def send_publish(self, message): 138 | mid, pkg = package.PublishPacket.build_package(message, self) 139 | self.write_data(pkg) 140 | 141 | return mid, pkg 142 | 143 | def send_disconnect(self, reason_code=0, **properties): 144 | pkg = package.DisconnectPacket.build_package(self, reason_code=reason_code, **properties) 145 | 146 | self.write_data(pkg) 147 | 148 | return pkg 149 | 150 | def send_command_with_mid(self, cmd, mid, dup, reason_code=0): 151 | pkg = package.CommandWithMidPacket.build_package(cmd, mid, dup, reason_code=reason_code, 152 | proto_ver=self.proto_ver) 153 | self.write_data(pkg) 154 | 155 | def _read_packet(self, data): 156 | parsed_size = 0 157 | raw_size = len(data) 158 | data_size = raw_size 159 | 160 | while True: 161 | # try to extract packet data, minimum expected packet size is 2 162 | if data_size < 2: 163 | break 164 | 165 | # extract payload size 166 | header_size = 1 167 | mult = 1 168 | payload_size = 0 169 | 170 | while True: 171 | if parsed_size + header_size >= raw_size: 172 | # not full header 173 | return parsed_size 174 | payload_byte = data[parsed_size + header_size] 175 | payload_size += (payload_byte & 0x7F) * mult 176 | if mult > 2097152: # 128 * 128 * 128 177 | return -1 178 | mult *= 128 179 | header_size += 1 180 | if header_size + payload_size > data_size: 181 | # not enough data 182 | break 183 | if payload_byte & 128 == 0: 184 | break 185 | 186 | # check size once more 187 | if header_size + payload_size > data_size: 188 | # not enough data 189 | break 190 | 191 | # determine packet type 192 | command = data[parsed_size] 193 | start = parsed_size + header_size 194 | end = start + payload_size 195 | packet = 
data[start:end] 196 | 197 | data_size -= header_size + payload_size 198 | parsed_size += header_size + payload_size 199 | 200 | self._connection.put_package((command, packet)) 201 | 202 | return parsed_size 203 | 204 | async def _read_loop(self): 205 | await self._connected.wait() 206 | 207 | buf = b'' 208 | max_buff_size = 65536 # 64 * 1024 209 | while self._connected.is_set(): 210 | try: 211 | buf += await self.read(max_buff_size) 212 | parsed_size = self._read_packet(buf) 213 | if parsed_size == -1 or self._transport.is_closing(): 214 | logger.debug("[RECV EMPTY] Connection will be reset automatically.") 215 | break 216 | buf = buf[parsed_size:] 217 | except ConnectionResetError as exc: 218 | # This connection will be closed, because we received the empty data. 219 | # So we can safely break the while 220 | logger.debug("[RECV EMPTY] Connection will be reset automatically.") 221 | break 222 | 223 | def connection_lost(self, exc): 224 | super(MQTTProtocol, self).connection_lost(exc) 225 | self._connection.put_package((MQTTCommands.DISCONNECT, b'')) 226 | 227 | if self._read_loop_future is not None: 228 | self._read_loop_future.cancel() 229 | self._read_loop_future = None 230 | 231 | self._queue = asyncio.Queue() 232 | -------------------------------------------------------------------------------- /gmqtt/mqtt/utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import struct 3 | import logging 4 | 5 | from functools import partial 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class Singleton(type): 12 | _instances = {} 13 | 14 | def __call__(cls, *args, **kwargs): 15 | if cls not in cls._instances: 16 | cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) 17 | return cls._instances[cls] 18 | 19 | 20 | class IdGenerator(object, metaclass=Singleton): 21 | def __init__(self, max=65536): 22 | self._max = max 23 | self._used_ids = set() 24 | self._last_used_id = 0 25 | 26 | 
def _mid_generate(self): 27 | done = False 28 | 29 | while not done: 30 | if len(self._used_ids) >= self._max - 1: 31 | raise OverflowError("All ids has already used. May be your QoS query is full.") 32 | 33 | self._last_used_id += 1 34 | 35 | if self._last_used_id in self._used_ids: 36 | continue 37 | 38 | if self._last_used_id == self._max: 39 | self._last_used_id = 0 40 | continue 41 | 42 | done = True 43 | 44 | self._used_ids.add(self._last_used_id) 45 | return self._last_used_id 46 | 47 | def free_id(self, id): 48 | logger.debug('FREE MID: %s', id) 49 | if id not in self._used_ids: 50 | return 51 | 52 | self._used_ids.remove(id) 53 | 54 | def next_id(self): 55 | id = self._mid_generate() 56 | 57 | logger.debug("NEW ID: %s", id) 58 | return id 59 | 60 | 61 | def pack_variable_byte_integer(value): 62 | remaining_bytes = bytearray() 63 | while True: 64 | value, b = divmod(value, 128) 65 | if value > 0: 66 | b |= 0x80 67 | remaining_bytes.extend(struct.pack('!B', b)) 68 | if value <= 0: 69 | break 70 | return remaining_bytes 71 | 72 | 73 | def unpack_variable_byte_integer(bts): 74 | multiplier = 1 75 | value = 0 76 | i = 0 77 | while i < 4: 78 | b = bts[i] 79 | value += (b & 0x7F) * multiplier 80 | if multiplier > 2097152: # 128 * 128 * 128 81 | raise ValueError('Malformed Variable Byte Integer') 82 | multiplier *= 128 83 | if b & 0x80 == 0: 84 | break 85 | i += 1 86 | return value, bts[i + 1:] 87 | 88 | 89 | def unpack_utf8(bytes_array): 90 | str_len, = struct.unpack('!H', bytes_array[:2]) 91 | value = bytes_array[2:2 + str_len].decode('utf-8') 92 | left_str = bytes_array[2 + str_len:] 93 | return value, left_str 94 | 95 | 96 | def pack_utf8(data): 97 | packet = bytearray() 98 | if isinstance(data, str): 99 | data = data.encode('utf-8') 100 | packet.extend(struct.pack("!H", len(data))) 101 | packet.extend(data) 102 | return packet 103 | 104 | 105 | def iscoroutinefunction_or_partial(object): 106 | if isinstance(object, partial): 107 | object = object.func 108 | 
return asyncio.iscoroutinefunction(object) 109 | 110 | 111 | def run_coroutine_or_function(func, *args, callback=None, **kwargs): 112 | if iscoroutinefunction_or_partial(func): 113 | f = asyncio.ensure_future(func(*args, **kwargs)) 114 | if callback is not None: 115 | f.add_done_callback(callback) 116 | else: 117 | func(*args, **kwargs) 118 | -------------------------------------------------------------------------------- /gmqtt/storage.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Callable, Tuple, Set 3 | 4 | import heapq 5 | 6 | 7 | class BasePersistentStorage(object): 8 | async def push_message(self, mid, raw_package): 9 | raise NotImplementedError 10 | 11 | def push_message_nowait(self, mid, raw_package) -> asyncio.Future: 12 | return asyncio.ensure_future(self.push_message(mid, raw_package)) 13 | 14 | async def pop_message(self) -> Tuple[int, bytes]: 15 | raise NotImplementedError 16 | 17 | async def remove_message_by_mid(self, mid): 18 | raise NotImplementedError 19 | 20 | @property 21 | async def is_empty(self) -> bool: 22 | raise NotImplementedError 23 | 24 | async def wait_empty(self) -> None: 25 | # Note that when some kinda really persistent storage is used (like Redis or smth), 26 | # this method must implement an atomic transaction from async network exchange perspective. 
27 | raise NotImplementedError 28 | 29 | async def clear(self): 30 | raise NotImplementedError 31 | 32 | async def get_all(self): 33 | raise NotImplementedError 34 | 35 | 36 | class HeapPersistentStorage(BasePersistentStorage): 37 | def __init__(self): 38 | self._queue = [] 39 | self._empty_waiters: Set[asyncio.Future] = set() 40 | 41 | def _notify_waiters(self, waiters: Set[asyncio.Future], notify: Callable[[asyncio.Future], None]) -> None: 42 | while waiters: 43 | notify(waiters.pop()) 44 | 45 | def _check_empty(self): 46 | if not self._queue: 47 | self._notify_waiters(self._empty_waiters, lambda waiter: waiter.set_result(None)) 48 | 49 | async def push_message(self, mid, raw_package): 50 | tm = asyncio.get_event_loop().time() 51 | heapq.heappush(self._queue, (tm, mid, raw_package)) 52 | 53 | async def pop_message(self): 54 | (tm, mid, raw_package) = heapq.heappop(self._queue) 55 | 56 | self._check_empty() 57 | return mid, raw_package 58 | 59 | async def remove_message_by_mid(self, mid): 60 | message = next(filter(lambda x: x[1] == mid, self._queue), None) 61 | if message: 62 | self._queue.remove(message) 63 | self._check_empty() 64 | heapq.heapify(self._queue) 65 | 66 | @property 67 | async def is_empty(self): 68 | return not bool(self._queue) 69 | 70 | async def wait_empty(self) -> None: 71 | if self._queue: 72 | waiter = asyncio.get_running_loop().create_future() 73 | self._empty_waiters.add(waiter) 74 | await waiter 75 | 76 | async def clear(self): 77 | self._queue = [] 78 | self._notify_waiters(self._empty_waiters, lambda waiter: waiter.set_result(None)) 79 | 80 | async def get_all(self): 81 | return self._queue -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | norecursedirs = env pyenv 3 | 4 | -------------------------------------------------------------------------------- /requirements_test.txt: 
-------------------------------------------------------------------------------- 1 | # gmqtt test requirements 2 | .[test] 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import sys 3 | from os import path 4 | 5 | from setuptools import find_packages, setup 6 | 7 | import gmqtt 8 | 9 | this_directory = path.abspath(path.dirname(__file__)) 10 | with open(path.join(this_directory, "README.md"), encoding="utf-8") as readme: 11 | long_description = readme.read() 12 | 13 | extra = {} 14 | if sys.version_info >= (3, 4): 15 | extra["use_2to3"] = False 16 | extra["convert_2to3_doctests"] = ["README.md"] 17 | 18 | CLASSIFIERS = [ 19 | "Development Status :: 5 - Production/Stable", 20 | "Intended Audience :: Developers", 21 | "License :: OSI Approved :: MIT License", 22 | "Natural Language :: English", 23 | "Operating System :: OS Independent", 24 | "Programming Language :: Python", 25 | "Topic :: Software Development :: Libraries :: Python Modules", 26 | ] 27 | 28 | KEYWORDS = "Gurtam MQTT client." 
29 | 30 | TESTS_REQUIRE = [ 31 | "atomicwrites>=1.3.0", 32 | "attrs>=19.1.0", 33 | "codecov>=2.0.15", 34 | "coverage>=4.5.3", 35 | "more-itertools>=7.0.0", 36 | "pluggy>=0.11.0", 37 | "py>=1.8.0", 38 | "pytest-asyncio>=0.12.0", 39 | "pytest-cov>=2.7.1", 40 | "pytest>=5.4.0", 41 | "six>=1.12.0", 42 | "uvloop>=0.14.0", 43 | ] 44 | 45 | # Allow you to run pip install .[test] to get test dependencies included 46 | EXTRAS_REQUIRE = {"test": TESTS_REQUIRE} 47 | 48 | setup( 49 | name="gmqtt", 50 | version=gmqtt.__version__, 51 | description="Client for MQTT protocol", 52 | long_description=long_description, 53 | long_description_content_type="text/markdown", 54 | author=gmqtt.__author__, 55 | author_email=gmqtt.__email__, 56 | license='MIT', 57 | url="https://github.com/wialon/gmqtt", 58 | packages=find_packages(exclude=['examples', 'tests']), 59 | download_url="https://github.com/wialon/gmqtt", 60 | classifiers=CLASSIFIERS, 61 | keywords=KEYWORDS, 62 | zip_safe=True, 63 | test_suite="tests", 64 | install_requires=[], 65 | tests_require=TESTS_REQUIRE, 66 | extras_require=EXTRAS_REQUIRE, 67 | python_requires='>=3.5', 68 | ) 69 | -------------------------------------------------------------------------------- /static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wialon/gmqtt/4b84685ce1c079905da22eeb119f168f1350762c/static/logo.png -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wialon/gmqtt/4b84685ce1c079905da22eeb119f168f1350762c/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_mqtt5.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import time 4 | from unittest import mock 5 | 6 | import pytest 7 | 
import pytest_asyncio

import gmqtt
from tests.utils import Callbacks, cleanup, clean_retained

# Broker selection. CI exports TOKEN and runs against flespi.io, where the
# token is the MQTT username and no password is used. Otherwise the settings
# come from HOST / USERNAME / PASSWORD / PORT, defaulting to a local broker.
if os.getenv('TOKEN'):
    host = 'mqtt.flespi.io'
    username = os.getenv("TOKEN")
    password = None
    port = 1883
else:
    host = os.getenv("HOST", "127.0.0.1")
    username = os.getenv("USERNAME", "")
    password = os.getenv("PASSWORD", None)
    # Environment values are strings; normalise so `port` is always an int,
    # like the literal used on the flespi branch.
    port = int(os.getenv("PORT", 1883))

# Timestamped prefix for every topic and client id, so runs do not see each
# other's messages or sessions on a shared broker.
PREFIX = 'GMQTT/' + str(time.time()) + '/'

TOPICS = (PREFIX + "TopicA",
          PREFIX + "TopicA/B",
          PREFIX + "TopicA/C",
          PREFIX + "TopicA/D",
          PREFIX + "/TopicA")
WILDTOPICS = (PREFIX + "TopicA/+",
              PREFIX + "+/C",
              PREFIX + "#",
              PREFIX + "/#",
              PREFIX + "/+",
              PREFIX + "+/+",
              PREFIX + "TopicA/#")
NOSUBSCRIBE_TOPICS = (PREFIX + "test/nosubscribe",)


@pytest_asyncio.fixture
async def init_clients():
    """Provide two fresh, not-yet-connected clients with callback recorders.

    Broker state for PREFIX (client sessions and retained messages) is
    cleaned first. Yields ``(aclient, callback, bclient, callback2)``;
    both clients are disconnected on teardown.
    """
    await cleanup(host, port, username, password=password, prefix=PREFIX)

    aclient = gmqtt.Client(PREFIX + "myclientid", clean_session=True)
    # Pass the password too: previously PASSWORD was read but never used.
    aclient.set_auth_credentials(username, password)
    callback = Callbacks()
    callback.register_for_client(aclient)

    bclient = gmqtt.Client(PREFIX + "myclientid2", clean_session=True)
    bclient.set_auth_credentials(username, password)
    callback2 = Callbacks()
    callback2.register_for_client(bclient)

    yield aclient, callback, bclient, callback2

    await aclient.disconnect()
    await bclient.disconnect()


@pytest.mark.asyncio
async def test_basic(init_clients):
    """MQTT 3.1.1 (version=4) round trip: QoS 0, 1 and 2 publishes all reach bclient."""
    aclient, callback, bclient, callback2 = init_clients

    await aclient.connect(host=host, port=port, version=4)
    await bclient.connect(host=host, port=port, version=4)
    bclient.subscribe(TOPICS[0])
    await asyncio.sleep(1)

    aclient.publish(TOPICS[0], b"qos 0")
    aclient.publish(TOPICS[0], b"qos 1", 1)
    aclient.publish(TOPICS[0], b"qos 2", 2)
    await asyncio.sleep(1)
    assert len(callback2.messages) == 3
@pytest.mark.asyncio
async def test_basic_subscriptions(init_clients):
    """Subscribe with ``gmqtt.Subscription`` objects carrying MQTT 5 properties.

    bclient subscribes to TOPICS[1] and TOPICS[2] in one call (a list) and
    to TOPICS[3] in another (a single Subscription). Each call carries a user
    property and a subscription identifier. One publish per topic must arrive.
    """
    aclient, callback, bclient, callback2 = init_clients

    await aclient.connect(host=host, port=port)
    await bclient.connect(host=host, port=port)

    subscriptions = [
        gmqtt.Subscription(TOPICS[1], qos=1),
        gmqtt.Subscription(TOPICS[2], qos=2),
    ]
    bclient.subscribe(subscriptions, user_property=('key', 'value'), subscription_identifier=1)

    bclient.subscribe(gmqtt.Subscription(TOPICS[3], qos=1), user_property=('key', 'value'),
                      subscription_identifier=2)
    await asyncio.sleep(1)

    aclient.publish(TOPICS[3], b"qos 0")
    aclient.publish(TOPICS[1], b"qos 1", 1)
    aclient.publish(TOPICS[2], b"qos 2", 2)
    await asyncio.sleep(1)
    assert len(callback2.messages) == 3


@pytest.mark.asyncio
async def test_retained_message(init_clients):
    """Retained messages reach a late subscriber with the retain flag set.

    Three retained messages are published before bclient subscribes to
    WILDTOPICS[0], and all three must arrive flagged as retained. A publish
    made while the subscription is live must arrive with retain 0.
    """
    aclient, callback, bclient, callback2 = init_clients

    await aclient.connect(host=host, port=port)
    aclient.publish(TOPICS[1], b"ret qos 0", 0, retain=True, user_property=("a", "2"))
    aclient.publish(TOPICS[2], b"ret qos 1", 1, retain=True, user_property=("c", "3"))
    aclient.publish(TOPICS[3], b"ret qos 2", 2, retain=True, user_property=(("a", "2"), ("c", "3")))

    await asyncio.sleep(2)

    await bclient.connect(host=host, port=port)
    bclient.subscribe(WILDTOPICS[0], qos=2)
    await asyncio.sleep(1)

    assert len(callback2.messages) == 3
    for msg in callback2.messages:
        # `== 1` accepts exactly what `== True` would (True == 1), avoids the
        # comparison-to-True lint, and mirrors the `== 0` check below.
        assert msg[3]['retain'] == 1

    aclient.publish(TOPICS[2], b"ret qos 1", 1, retain=True, user_property=("c", "3"))
    await asyncio.sleep(1)

    assert len(callback2.messages) == 4
    assert callback2.messages[3][3]['retain'] == 0

    await clean_retained(host, port, username, password=password, prefix=PREFIX)


@pytest.mark.asyncio
async def test_will_message(init_clients):
    """Disconnecting with reason code 4 must trigger the will message.

    Reason code 0x04 is MQTT 5's "Disconnect with Will Message". bclient is
    subscribed to the will topic and must receive exactly one message.
    """
    aclient, callback, bclient, callback2 = init_clients

    # re-initialize aclient with will message
    will_message = gmqtt.Message(TOPICS[2], "I'm dead finally")
    aclient = gmqtt.Client(PREFIX + "myclientid3", clean_session=True, will_message=will_message)
    aclient.set_auth_credentials(username, password)

    await aclient.connect(host, port=port)

    await bclient.connect(host=host, port=port)
    bclient.subscribe(TOPICS[2])

    await asyncio.sleep(1)
    await aclient.disconnect(reason_code=4)
    await asyncio.sleep(1)
    assert len(callback2.messages) == 1


@pytest.mark.asyncio
async def test_no_will_message_on_gentle_disconnect(init_clients):
    """A normal disconnect (reason code 0) must not publish the will message."""
    aclient, callback, bclient, callback2 = init_clients

    # re-initialize aclient with will message
    will_message = gmqtt.Message(TOPICS[2], "I'm dead finally")
    aclient = gmqtt.Client(PREFIX + "myclientid3", clean_session=True, will_message=will_message)
    aclient.set_auth_credentials(username, password)

    await aclient.connect(host, port=port)

    await bclient.connect(host=host, port=port)
    bclient.subscribe(TOPICS[2])

    await asyncio.sleep(1)
    await aclient.disconnect(reason_code=0)
    await asyncio.sleep(1)
    assert len(callback2.messages) == 0


@pytest.mark.asyncio
async def test_shared_subscriptions(init_clients):
    """Plain subscriptions fan out; a ``$share`` group delivers each message once.

    Both clients subscribe to TOPICS[0] and to the same shared group. Every
    message on TOPICS[0] must reach both clients. Each message on the shared
    topic must reach exactly one of them, and both must receive some messages.
    """
    aclient, callback, bclient, callback2 = init_clients

    shared_sub_topic = '$share/sharename/{}x'.format(PREFIX)
    shared_pub_topic = PREFIX + 'x'

    await aclient.connect(host=host, port=port)
    aclient.subscribe(shared_sub_topic)
    aclient.subscribe(TOPICS[0])

    await bclient.connect(host=host, port=port)
    bclient.subscribe(shared_sub_topic)
    bclient.subscribe(TOPICS[0])

    pubclient = gmqtt.Client(PREFIX + "myclient3", clean_session=True)
    pubclient.set_auth_credentials(username, password)
    await pubclient.connect(host, port)
    # pubclient is not managed by the init_clients fixture, so make sure it is
    # disconnected even when one of the assertions below fails.
    try:
        count = 10
        for i in range(count):
            pubclient.publish(TOPICS[0], "message " + str(i), 0)
        j = 0
        # Poll for up to ~20 s until every copy has arrived.
        while len(callback.messages) + len(callback2.messages) < 2 * count and j < 20:
            await asyncio.sleep(1)
            j += 1
        await asyncio.sleep(1)
        assert len(callback.messages) == count
        assert len(callback2.messages) == count

        callback.clear()
        callback2.clear()

        count = 10
        for i in range(count):
            pubclient.publish(shared_pub_topic, "message " + str(i), 0)
        j = 0
        while len(callback.messages) + len(callback2.messages) < count and j < 20:
            await asyncio.sleep(1)
            j += 1
        await asyncio.sleep(1)
        # Each message should only be received once
        assert len(callback.messages) + len(callback2.messages) == count
        assert len(callback.messages) > 0
        assert len(callback2.messages) > 0
    finally:
        await pubclient.disconnect()


@pytest.mark.asyncio
async def test_assigned_clientid():
    """Connecting with an empty client id yields a broker-assigned identifier in CONNACK."""
    noidclient = gmqtt.Client("", clean_session=True)
    noidclient.set_auth_credentials(username, password)
    callback = Callbacks()
    callback.register_for_client(noidclient)
    await noidclient.connect(host=host, port=port)
    await noidclient.disconnect()
    # connack is (flags, rc, properties); property values are lists.
    assert callback.connack[2]['assigned_client_identifier'][0] != ""


@pytest.mark.asyncio
async def test_unsubscribe(init_clients):
    """After unsubscribing from TOPICS[1], only the two remaining topics deliver."""
    aclient, callback, bclient, callback2 = init_clients
    await bclient.connect(host=host, port=port)
    await aclient.connect(host=host, port=port)

    bclient.subscribe(TOPICS[1])
    bclient.subscribe(TOPICS[2])
    bclient.subscribe(TOPICS[3])
    await asyncio.sleep(1)
    assert len(bclient.subscriptions) == 3

    aclient.publish(TOPICS[1], b"topic 0 - subscribed", 1, retain=False)
    aclient.publish(TOPICS[2], b"topic 1", 1, retain=False)
    aclient.publish(TOPICS[3], b"topic 2", 1, retain=False)
    await asyncio.sleep(1)
    assert len(callback2.messages) == 3
    callback2.clear()
    # Unsubscribe from one topic; the client's own subscription list shrinks
    # immediately, before the broker has replied.
    bclient.unsubscribe(TOPICS[1])
    assert len(bclient.subscriptions) == 2
    await asyncio.sleep(3)

    aclient.publish(TOPICS[1], b"topic 0 - unsubscribed", 1, retain=False)
    aclient.publish(TOPICS[2], b"topic 1", 1, retain=False)
    aclient.publish(TOPICS[3], b"topic 2", 1, retain=False)
    await asyncio.sleep(2)

    assert len(callback2.messages) == 2


@pytest.mark.asyncio
async def test_overlapping_subscriptions(init_clients):
    """A message matching two subscriptions arrives once or twice, both allowed.

    aclient subscribes to TOPICS[3] (identifier 21) and to the overlapping
    wildcard WILDTOPICS[6] (identifier 42), both at qos=1.
    """
    aclient, callback, bclient, callback2 = init_clients
    await bclient.connect(host=host, port=port)
    await aclient.connect(host=host, port=port)

    aclient.subscribe(TOPICS[3], qos=1, subscription_identifier=21)
    aclient.subscribe(WILDTOPICS[6], qos=1, subscription_identifier=42)
    await asyncio.sleep(1)
    bclient.publish(TOPICS[3], b"overlapping topic filters", 2)
    await asyncio.sleep(1)
    assert len(callback.messages) in [1, 2]
    if len(callback.messages) == 1:
        # A single copy carrying both subscription identifiers.
        assert callback.messages[0][2] == 1
        assert set(callback.messages[0][3]['subscription_identifier']) == {42, 21}
    else:
        # One copy per matching subscription. The delivered QoS is the minimum
        # of the publish QoS (2) and the subscription QoS (1 for both here),
        # so each copy arrives at QoS 1. A QoS 2 copy cannot occur.
        assert [msg[2] for msg in callback.messages] == [1, 1]


@pytest.mark.asyncio
async def test_redelivery_on_reconnect(init_clients):
    """Incomplete QoS 1/2 exchanges are redelivered after reconnecting.

    When a QoS 1 or 2 exchange has not been completed, the server should
    retry the appropriate MQTT packets. ``disconnect_client`` keeps a
    persistent session (clean_session=False, long session expiry), turns off
    optimistic acknowledgement and returns reason code 131 from on_message.
    Presumably this leaves the first deliveries unacknowledged. After
    reconnect() both messages must arrive again.
    """
    messages = []

    def on_message(client, topic, payload, qos, properties):
        print('MSG', (topic, payload, qos, properties))
        messages.append((topic, payload, qos, properties))
        return 131

    aclient, callback, bclient, callback2 = init_clients

    disconnect_client = gmqtt.Client(PREFIX + 'myclientid3', optimistic_acknowledgement=False,
                                     clean_session=False, session_expiry_interval=99999)
    disconnect_client.set_config({'reconnect_retries': 0})
    disconnect_client.on_message = on_message
    disconnect_client.set_auth_credentials(username, password)

    await disconnect_client.connect(host=host, port=port)
    # Not managed by the fixture: always disconnect, even on a failed assert.
    try:
        disconnect_client.subscribe(WILDTOPICS[6], 2)

        await asyncio.sleep(1)
        await aclient.connect(host, port)
        await asyncio.sleep(1)

        aclient.publish(TOPICS[1], b"", 1, retain=False)
        aclient.publish(TOPICS[3], b"", 2, retain=False)
        await asyncio.sleep(1)
        # Rebinding is fine: on_message closes over the name `messages`, so
        # it appends to this new list from here on.
        messages = []
        await disconnect_client.reconnect()

        await asyncio.sleep(2)
        assert len(messages) == 2
    finally:
        await disconnect_client.disconnect()


@pytest.mark.asyncio
async def xtest_async_on_message(init_clients):
    """Redelivery scenario with a coroutine ``on_message`` handler (QoS 1 only).

    NOTE(review): the ``x`` prefix keeps pytest's default ``test*`` collection
    from running this, which presumably disables it on purpose. Confirm
    against pytest.ini before re-enabling.
    """
    messages = []

    async def on_message(client, topic, payload, qos, properties):
        print('MSG', (topic, payload, qos, properties))
        await asyncio.sleep(0.5)
        messages.append((topic, payload, qos, properties))
        return 131

    aclient, callback, bclient, callback2 = init_clients

    disconnect_client = gmqtt.Client(PREFIX + 'myclientid3', optimistic_acknowledgement=False,
                                     clean_session=False, session_expiry_interval=99999)
    disconnect_client.set_config({'reconnect_retries': 0})
    disconnect_client.on_message = on_message
    disconnect_client.set_auth_credentials(username, password)

    await disconnect_client.connect(host=host, port=port)
    try:
        disconnect_client.subscribe(WILDTOPICS[6], 1)

        await asyncio.sleep(1)
        await aclient.connect(host, port)
        await asyncio.sleep(1)

        aclient.publish(TOPICS[1], b"", 1, retain=False)
        aclient.publish(TOPICS[3], b"", 1, retain=False)
        await asyncio.sleep(3)
        messages = []
        await disconnect_client.reconnect()

        await asyncio.sleep(3)
        assert len(messages) == 2
    finally:
        await disconnect_client.disconnect()


@pytest.mark.asyncio
async def test_request_response(init_clients):
    """MQTT 5 request/response: response_topic and correlation_data round-trip."""
    aclient, callback, bclient, callback2 = init_clients

    await aclient.connect(host=host, port=port)
    await bclient.connect(host=host, port=port)

    aclient.subscribe(WILDTOPICS[0], 2)

    bclient.subscribe(WILDTOPICS[0], 2)

    await asyncio.sleep(1)
    # client a is the requester
    aclient.publish(TOPICS[1], b"request", 1, response_topic=TOPICS[2], correlation_data=b'334')

    await asyncio.sleep(1)

    # client b is the responder
    assert len(callback2.messages) == 1

    assert callback2.messages[0][3]['response_topic'] == [TOPICS[2], ]
    assert callback2.messages[0][3]['correlation_data'] == [b"334", ]

    bclient.publish(callback2.messages[0][3]['response_topic'][0], b"response", 1,
                    correlation_data=callback2.messages[0][3]['correlation_data'][0])

    await asyncio.sleep(1)
    # aclient's wildcard also matches its own request, hence 2 messages.
    assert len(callback.messages) == 2


@pytest.mark.asyncio
async def test_subscribe_no_local(init_clients):
    """With no_local=True aclient does not get its own publish; bclient gets both."""
    aclient, callback, bclient, callback2 = init_clients

    await aclient.connect(host=host, port=port)
    await bclient.connect(host=host, port=port)

    aclient.subscribe(WILDTOPICS[0], 2, no_local=True)

    bclient.subscribe(WILDTOPICS[0], 2)

    await asyncio.sleep(1)

    aclient.publish(TOPICS[1], b"aclient msg", 1)

    bclient.publish(TOPICS[1], b"bclient msg", 1)

    await asyncio.sleep(1)

    assert len(callback.messages) == 1
    assert len(callback2.messages) == 2


@pytest.mark.asyncio
async def test_subscribe_retain_01_handling_flag(init_clients):
    """Retain Handling 0 resends retained messages on every subscribe.

    Retain Handling 1 sends them only when the subscription is new, so
    re-subscribing with option 1 delivers nothing.
    """
    aclient, callback, bclient, callback2 = init_clients

    await aclient.connect(host=host, port=port)
    await bclient.connect(host=host, port=port)

    aclient.publish(TOPICS[1], b"ret qos 1", 1, retain=True)

    await asyncio.sleep(1)

    bclient.subscribe(WILDTOPICS[0], qos=2, retain_handling_options=0)

    await asyncio.sleep(1)

    assert len(callback2.messages) == 1

    bclient.subscribe(WILDTOPICS[0], qos=2, retain_handling_options=0)

    await asyncio.sleep(1)

    assert len(callback2.messages) == 2

    bclient.subscribe(WILDTOPICS[0], qos=2, retain_handling_options=1)

    await asyncio.sleep(1)

    assert len(callback2.messages) == 2


@pytest.mark.asyncio
async def test_subscribe_retain_2_handling_flag(init_clients):
    """Retain Handling 2 never sends retained messages at subscribe time."""
    aclient, callback, bclient, callback2 = init_clients

    await aclient.connect(host=host, port=port)
    await bclient.connect(host=host, port=port)

    aclient.publish(TOPICS[1], b"ret qos 1", 1, retain=True)

    await asyncio.sleep(1)

    bclient.subscribe(WILDTOPICS[0], qos=2, retain_handling_options=2)

    await asyncio.sleep(1)

    assert len(callback2.messages) == 0


@pytest.mark.asyncio
async def test_basic_ssl(init_clients):
    """QoS 0/1/2 round trip over TLS on port 8883, mixing protocol versions 4 and 5."""
    aclient, callback, bclient, callback2 = init_clients
    ssl_port = 8883

    await aclient.connect(host=host, port=ssl_port, ssl=True, version=4)
    await bclient.connect(host=host, port=ssl_port, ssl=True, version=5)
    bclient.subscribe(TOPICS[0])
    await asyncio.sleep(1)

    aclient.publish(TOPICS[0], b"qos 0")
    aclient.publish(TOPICS[0], b"qos 1", 1)
    aclient.publish(TOPICS[0], b"qos 2", 2)
    await asyncio.sleep(1)
    assert len(callback2.messages) == 3


@pytest.mark.asyncio
async def test_reconnection_with_failure(init_clients):
    """reconnect() must leave a working client even if disconnecting fails.

    ``_disconnect`` is patched to raise ConnectionAbortedError during
    reconnect(). reconnect_retries=-1 presumably means unlimited retries, and
    reconnect_delay=0 removes the delay. Afterwards aclient must still
    publish successfully.
    """
    aclient, callback, bclient, callback2 = init_clients
    aclient.set_config({'reconnect_retries': -1, 'reconnect_delay': 0})
    await aclient.connect(host=host, port=port)
    await bclient.connect(host=host, port=port)

    bclient.subscribe(TOPICS[0])

    with mock.patch.object(aclient, '_disconnect') as disconnect_mock:
        disconnect_mock.side_effect = ConnectionAbortedError("error")
        await aclient.reconnect()

    await asyncio.sleep(3)

    # Check aclient is still working after reconnection
    aclient.publish(TOPICS[0], b"test")
    await asyncio.sleep(3)
    assert len(callback2.messages) == 1

# --------------------------------------------------------------------------------
# /tests/utils.py:
# --------------------------------------------------------------------------------
import time

import asyncio

import gmqtt
import logging


class Callbacks:
    """Record what a gmqtt client reports through its callbacks.

    ``register_for_client`` installs the ``on_*`` methods on a client. Tests
    then inspect ``messages`` (one ``(topic, payload, qos, properties)`` tuple
    per received message), ``subscribeds`` (the ``mid`` of each subscribe
    callback), ``connack`` and the ``connected`` / ``disconnected`` flags.
    """

    def __init__(self):
        self.messages = []
        self.publisheds = []  # not appended to anywhere in this class
        self.subscribeds = []
        self.connack = None  # (flags, rc, properties) from the last on_connect

        self.disconnected = False
        self.connected = False

    def __str__(self):
        # Only attributes initialised in __init__ are included; referencing
        # never-set attributes here would make str()/print() raise
        # AttributeError.
        return str(self.messages) + str(self.publisheds) + str(self.subscribeds)

    def clear(self):
        """Reset all recorded state by re-running ``__init__``."""
        self.__init__()

    def on_disconnect(self, client, packet):
        """Disconnect callback: log it and set ``disconnected``."""
        logging.info('[DISCONNECTED {}]'.format(client._client_id))
        self.disconnected = True

    def on_message(self, client, topic, payload, qos, properties):
        """Message callback: log it and append ``(topic, payload, qos, properties)``.

        Returns None (no explicit reason code).
        """
        logging.info('[RECV MSG {}] TOPIC: {} PAYLOAD: {} QOS: {} PROPERTIES: {}'
                     .format(client._client_id, topic, payload, qos, properties))
        self.messages.append((topic, payload, qos, properties))

    def on_subscribe(self, client, mid, qos, properties):
        """Subscribe callback: log it and remember the message id."""
        logging.info('[SUBSCRIBED {}] QOS: {}, properties: {}'.format(client._client_id, qos, properties))
        self.subscribeds.append(mid)

    def on_connect(self, client, flags, rc, properties):
        """Connect callback: set ``connected`` and store ``(flags, rc, properties)``."""
        logging.info('[CONNECTED {}]'.format(client._client_id))
        self.connected = True
        self.connack = (flags, rc, properties)

    def register_for_client(self, client):
        """Install this recorder's callbacks on ``client``."""
        client.on_disconnect = self.on_disconnect
        client.on_message = self.on_message
        client.on_connect = self.on_connect
        client.on_subscribe = self.on_subscribe


async def clean_retained(host, port, username, password=None, prefix=None):
    """Clear retained messages under ``prefix`` by overwriting them with empty payloads.

    Subscribes to ``prefix + '#'`` (or ``'#'`` when ``prefix`` is empty). For
    every message received, it publishes an empty retained payload to the same
    topic, which clears the retained message. It waits a fixed 10 seconds.

    NOTE: ``prefix`` must not be None in practice. The client id is built as
    ``prefix + "cleanretained"``, which raises TypeError for None.
    """
    def on_message(client, topic, payload, qos, properties):
        curclient.publish(topic, b"", qos=0, retain=True)

    curclient = gmqtt.Client(prefix + "cleanretained", clean_session=True)

    curclient.set_auth_credentials(username, password)
    curclient.on_message = on_message
    await curclient.connect(host=host, port=port)
    topic = '#' if not prefix else prefix + '#'
    curclient.subscribe(topic)
    await asyncio.sleep(10)  # wait for all retained messages to arrive
    await curclient.disconnect()
    # Non-blocking pause: time.sleep() here would stall the event loop.
    await asyncio.sleep(.1)


async def cleanup(host, port=1883, username=None, password=None, client_ids=None, prefix=None):
    """Reset broker state used by the tests.

    Connects and disconnects once with ``clean_session=True`` for each client
    id, which discards any stored session. The default ids are ``prefix`` +
    myclientid / myclientid2 / myclientid3. Then it clears retained messages
    under ``prefix``.
    """
    print("clean up starting")
    client_ids = client_ids or (prefix + "myclientid", prefix + "myclientid2", prefix + "myclientid3")

    for clientid in client_ids:
        curclient = gmqtt.Client(clientid.encode("utf-8"), clean_session=True)
        curclient.set_auth_credentials(username=username, password=password)
        await curclient.connect(host=host, port=port)
        # Yield to the event loop instead of blocking it with time.sleep().
        await asyncio.sleep(.1)
        await curclient.disconnect()
        await asyncio.sleep(.1)

    # clean retained messages
    await clean_retained(host, port, username, password=password, prefix=prefix)
    print("clean up finished")