├── .github └── workflows │ ├── pythonapp.yml │ └── pythonpublish.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── example.py ├── graphql_client └── __init__.py ├── setup.py └── tests ├── test.py └── websocket_server.py /.github/workflows/pythonapp.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: py-graphql-client 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python 3.8 20 | uses: actions/setup-python@v1 21 | with: 22 | python-version: 3.8 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install websocket-client 27 | python setup.py install 28 | 29 | # - name: Lint with flake8 30 | # run: | 31 | # pip install flake8 32 | # # stop the build if there are Python syntax errors or undefined names 33 | # flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 34 | # # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 35 | # flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 36 | - name: Test with unittest 37 | run: | 38 | make test 39 | -------------------------------------------------------------------------------- /.github/workflows/pythonpublish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Publish Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v1 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/python 3 | # Edit at https://www.gitignore.io/?templates=python 4 | 5 | ### Python ### 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # 
PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # celery beat schedule file 91 | celerybeat-schedule 92 | 93 | # SageMath parsed files 94 | *.sage.py 95 | 96 | # Environments 97 | .env 98 | .venv 99 | env/ 100 | venv/ 101 | ENV/ 102 | env.bak/ 103 | venv.bak/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyproject 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | .dmypy.json 118 | dmypy.json 119 | 120 | # Pyre type checker 121 | .pyre/ 122 | 123 | ### Python Patch ### 124 | .venv/ 125 | 126 | # Editor configs 127 | .vscode/ 128 | 129 | # End of https://www.gitignore.io/api/python 130 | ws/ 131 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Unreleased 2 | 3 | ## Fixes 4 | 5 | ## Enhancements/Features 6 | - Added support for the context manager API. 7 | Now you can use the client as a context manager, like so: 8 |
```python 9 | with GraphQLClient("ws://localhost/graphql") as client: 10 | client.subscribe(...) 11 | ``` 12 | 13 | 14 | # 0.1.1 15 | 16 | ## Fixes 17 | - fixed: when stop_subscribe was called it would stop the receive thread, so further subscriptions would not work 18 | - fixed: if more than one operation was scheduled, a callback might not receive the data meant for it 19 | - refactor: use a separate thread all the time to continuously receive data from the server and put it on queues 20 | - refactor: use separate queues for each operation to properly keep track of incoming data and who it is meant for 21 | - other misc improvements 22 | - Removed sleep, as the `conn.recv` call is blocking 23 | - Added `graphql-ws` subprotocol header to help with some WSS connections 24 | 25 | ## Enhancements/Features 26 | - UUIDv4 for generating operation IDs (#16) 27 | - Added tests 28 | 29 | 30 | # 0.1.0 31 | - Basic support for GraphQL over the Websocket (Apollo) protocol 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Anon Ray (c) 2023 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above 12 | copyright notice, this list of conditions and the following 13 | disclaimer in the documentation and/or other materials provided 14 | with the distribution. 15 | 16 | * Neither the name of Anon Ray nor the names of other 17 | contributors may be used to endorse or promote products derived 18 | from this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | project_dir := $(shell pwd)
2 |
3 | # None of these targets creates a file named after itself. Without .PHONY,
4 | # `make build` is reported "up to date" (and skipped) as soon as the build
5 | # step has left a ./build/ directory behind (the one `clean` removes).
6 | .PHONY: test install build publish test_publish clean
7 |
8 | test:
9 | 	python -m unittest tests.test
10 |
11 | install:
12 | 	python setup.py install
13 |
14 | build:
15 | 	python setup.py sdist bdist_wheel
16 |
17 | # $(MAKE) (rather than a literal `make`) forwards the flags of the outer make
18 | publish: build
19 | 	twine upload dist/*
20 | 	$(MAKE) clean
21 |
22 | test_publish: build
23 | 	twine upload --repository-url https://test.pypi.org/legacy/ dist/*
24 | 	$(MAKE) clean
25 |
26 | clean:
27 | 	rm -rf build dist .egg py_graphql_client.egg-info
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # py-graphql-client
2 | Dead-simple to use GraphQL client over websocket. Using the
3 | [apollo-transport-ws](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md)
4 | protocol.
5 | 6 | ## Install 7 | 8 | ```bash 9 | pip install py-graphql-client 10 | ``` 11 | 12 | ## Examples 13 | 14 | ### Setup subscriptions super easily 15 | 16 | ```python 17 | from graphql_client import GraphQLClient 18 | 19 | query = """ 20 | subscription { 21 | notifications { 22 | id 23 | title 24 | content 25 | } 26 | } 27 | """ 28 | 29 | def callback(_id, data): 30 | print("got new data..") 31 | print(f"msg id: {_id}. data: {data}") 32 | 33 | with GraphQLClient('ws://localhost:8080/graphql') as client: 34 | sub_id = client.subscribe(query, callback=callback) 35 | # do other stuff 36 | # ... 37 | # later stop the subscription 38 | client.stop_subscribe(sub_id) 39 | ``` 40 | 41 | ### Variables can be passed 42 | 43 | ```python 44 | from graphql_client import GraphQLClient 45 | 46 | query = """ 47 | subscription ($limit: Int!) { 48 | notifications (order_by: {created: "desc"}, limit: $limit) { 49 | id 50 | title 51 | content 52 | } 53 | } 54 | """ 55 | 56 | def callback(_id, data): 57 | print("got new data..") 58 | print(f"msg id: {_id}. data: {data}") 59 | 60 | with GraphQLClient('ws://localhost:8080/graphql') as client: 61 | sub_id = client.subscribe(query, variables={'limit': 10}, callback=callback) 62 | # ... 63 | ``` 64 | 65 | ### Headers can be passed too 66 | 67 | ```python 68 | from graphql_client import GraphQLClient 69 | 70 | query = """ 71 | subscription ($limit: Int!) { 72 | notifications (order_by: {created: "desc"}, limit: $limit) { 73 | id 74 | title 75 | content 76 | } 77 | } 78 | """ 79 | 80 | def callback(_id, data): 81 | print("got new data..") 82 | print(f"msg id: {_id}. data: {data}") 83 | 84 | with GraphQLClient('ws://localhost:8080/graphql') as client: 85 | sub_id = client.subscribe(query, 86 | variables={'limit': 10}, 87 | headers={'Authorization': 'Bearer xxxx'}, 88 | callback=callback) 89 | ... 
90 | client.stop_subscribe(sub_id) 91 | ``` 92 | 93 | ### Normal queries and mutations work too 94 | 95 | ```python 96 | from graphql_client import GraphQLClient 97 | 98 | query = """ 99 | query ($limit: Int!) { 100 | notifications (order_by: {created: "desc"}, limit: $limit) { 101 | id 102 | title 103 | content 104 | } 105 | } 106 | """ 107 | 108 | with GraphQLClient('ws://localhost:8080/graphql') as client: 109 | res = client.query(query, variables={'limit': 10}, headers={'Authorization': 'Bearer xxxx'}) 110 | print(res) 111 | ``` 112 | 113 | ### Without the context manager API 114 | 115 | ```python 116 | from graphql_client import GraphQLClient 117 | 118 | query = """ 119 | query ($limit: Int!) { 120 | notifications (order_by: {created: "desc"}, limit: $limit) { 121 | id 122 | title 123 | content 124 | } 125 | } 126 | """ 127 | 128 | client = GraphQLClient('ws://localhost:8080/graphql') 129 | res = client.query(query, variables={'limit': 10}, headers={'Authorization': 'Bearer xxxx'}) 130 | print(res) 131 | client.close() 132 | ``` 133 | 134 | 135 | ## TODO 136 | - support http as well 137 | - should use asyncio websocket library? 138 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import time 2 | from graphql_client import GraphQLClient 3 | 4 | # some sample GraphQL server which supports websocket transport and subscription 5 | client = GraphQLClient('ws://localhost:9001') 6 | 7 | # Simple Query Example 8 | 9 | # query example with GraphQL variables 10 | query = """ 11 | query getUser($userId: Int!) 
{ 12 | users (id: $userId) { 13 | id 14 | username 15 | } 16 | } 17 | """ 18 | 19 | # This is a blocking call, you receive response in the `res` variable 20 | 21 | print('Making a query first') 22 | res = client.query(query, variables={'userId': 2}) 23 | print('query result', res) 24 | 25 | 26 | # Subscription Example 27 | 28 | subscription_query = """ 29 | subscription getUser { 30 | users (id: 2) { 31 | id 32 | username 33 | } 34 | } 35 | """ 36 | 37 | # Our callback function, which will be called and passed data everytime new data is available 38 | def my_callback(op_id, data): 39 | print(f"Got data for Operation ID: {op_id}. Data: {data}") 40 | 41 | print('Making a graphql subscription now...') 42 | sub_id = client.subscribe(subscription_query, callback=my_callback) 43 | print('Created subscription and waiting. Callback function is called whenever there is new data') 44 | 45 | # do some operation while the subscription is running... 46 | time.sleep(10) 47 | client.stop_subscribe(sub_id) 48 | client.close() 49 | -------------------------------------------------------------------------------- /graphql_client/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | A simple GraphQL client that works over Websocket as the transport 4 | protocol, instead of HTTP. 5 | This follows the Apollo protocol. 
6 | https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md 7 | """ 8 | 9 | import json 10 | import threading 11 | import uuid 12 | import queue 13 | import logging 14 | from typing import Callable 15 | 16 | import websocket 17 | 18 | 19 | GQL_WS_SUBPROTOCOL = "graphql-ws" 20 | 21 | # all the message types 22 | GQL_CONNECTION_INIT = 'connection_init' 23 | GQL_START = 'start' 24 | GQL_STOP = 'stop' 25 | GQL_CONNECTION_TERMINATE = 'connection_terminate' 26 | GQL_CONNECTION_ERROR = 'connection_error' 27 | GQL_CONNECTION_ACK = 'connection_ack' 28 | GQL_DATA = 'data' 29 | GQL_ERROR = 'error' 30 | GQL_COMPLETE = 'complete' 31 | GQL_CONNECTION_KEEP_ALIVE = 'ka' 32 | 33 | logger = logging.getLogger(__name__) 34 | logger.addHandler(logging.NullHandler()) 35 | 36 | 37 | class ConnectionException(Exception): 38 | """Exception thrown during connection errors to the GraphQL server""" 39 | 40 | class InvalidPayloadException(Exception): 41 | """Exception thrown if payload recived from server is mal-formed or cannot be parsed """ 42 | 43 | class GraphQLClient(): 44 | """ 45 | A simple GraphQL client that works over Websocket as the transport 46 | protocol, instead of HTTP. 47 | This follows the Apollo protocol. 48 | https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md 49 | """ 50 | def __init__(self, url): 51 | self.ws_url = url 52 | self._connection_init_done = False 53 | # cache of the headers for a session 54 | self._headers = None 55 | # map of subscriber id to a callback function 56 | self._subscriber_callbacks = {} 57 | # our general receive queue 58 | self._queue = queue.Queue() 59 | # map of queues for each subscriber 60 | self._subscriber_queues = {} 61 | self._shutdown_receiver = False 62 | self._subscriptions = [] 63 | self.connect() 64 | 65 | def connect(self) -> None: 66 | """ 67 | Initializes a connection with the server. 
68 | """ 69 | self._connection = websocket.create_connection(self.ws_url, 70 | subprotocols=[GQL_WS_SUBPROTOCOL]) 71 | # start the reciever thread 72 | self._recevier_thread = threading.Thread(target=self._receiver_task) 73 | self._recevier_thread.start() 74 | 75 | def _reconnect(self): 76 | subscriptions = self._subscriptions 77 | self.__init__(self.ws_url) 78 | 79 | for subscription in subscriptions: 80 | self.subscribe(query=subscription['query'], 81 | variables=subscription['variables'], 82 | headers=subscription['headers'], 83 | callback=subscription['callback']) 84 | 85 | def __dump_queues(self): 86 | logger.debug('[GQL_CLIENT] => Dump of all the internal queues') 87 | logger.debug('[GQL_CLIENT] => Global queue => \n %s', self._queue.queue) 88 | dumps = list(map(lambda q: (q[0], q[1].queue), self._subscriber_queues.items())) 89 | logger.debug('[GQL_CLIENT] => Operation queues: \n %s', dumps) 90 | 91 | # wait for any valid message, while ignoring GQL_CONNECTION_KEEP_ALIVE 92 | def _receiver_task(self): 93 | """the recieve function of the client. Which validates response from the 94 | server and queues data """ 95 | reconnected = False 96 | while not self._shutdown_receiver and not reconnected: 97 | self.__dump_queues() 98 | try: 99 | res = self._connection.recv() 100 | except websocket._exceptions.WebSocketConnectionClosedException as e: 101 | self._reconnect() 102 | reconnected = True 103 | continue 104 | 105 | try: 106 | msg = json.loads(res) 107 | except json.JSONDecodeError as err: 108 | logger.warning('Ignoring. Server sent invalid JSON data: %s \n %s', res, err) 109 | continue 110 | 111 | # ignore messages which are GQL_CONNECTION_KEEP_ALIVE 112 | if msg['type'] != GQL_CONNECTION_KEEP_ALIVE: 113 | 114 | # check all GQL_DATA and GQL_COMPLETE should have 'id'. 115 | # Otherwise, server is sending malformed responses, error out! 
116 | if msg['type'] in [GQL_DATA, GQL_COMPLETE] and 'id' not in msg: 117 | # TODO: main thread can't catch this exception; setup 118 | # exception queues. but this scenario will only happen with 119 | # servers having glaring holes in implementing the protocol 120 | # correctly, which is rare. hence this is not very urgent 121 | err = f'Protocol Violation.\nExpected "id" in {msg}, but could not find.' 122 | raise InvalidPayloadException(err) 123 | 124 | # if the message has an id, it is meant for a particular operation 125 | if 'id' in msg: 126 | op_id = msg['id'] 127 | 128 | # put it in the correct operation/subscriber queue 129 | if op_id in self._subscriber_queues: 130 | self._subscriber_queues[op_id].put(msg) 131 | else: 132 | self._subscriber_queues[op_id] = queue.Queue() 133 | self._subscriber_queues[op_id].put(msg) 134 | 135 | # if a callback fn exists with the id, call it 136 | if op_id in self._subscriber_callbacks: 137 | user_fn = self._subscriber_callbacks[op_id] 138 | user_fn(op_id, msg) 139 | 140 | # if it doesn't have an id, put in the global queue 141 | else: 142 | self._queue.put(msg) 143 | 144 | def _insert_subscriber(self, op_id, callback_fn): 145 | self._subscriber_callbacks[op_id] = callback_fn 146 | 147 | def _remove_subscriber(self, op_id): 148 | del self._subscriber_callbacks[op_id] 149 | 150 | def _create_operation_queue(self, op_id): 151 | self._subscriber_queues[op_id] = queue.Queue() 152 | 153 | def _remove_operation_queue(self, op_id): 154 | if op_id in self._subscriber_queues: 155 | del self._subscriber_queues[op_id] 156 | 157 | def _get_operation_result(self, op_id): 158 | return self._subscriber_queues[op_id].get() 159 | 160 | def _connection_init(self, headers=None): 161 | # if we have already initialized and the passed headers are same as 162 | # prev headers, then do nothing and return 163 | if self._connection_init_done and headers == self._headers: 164 | return 165 | 166 | self._headers = headers 167 | # send the 
`connection_init` message with the payload 168 | payload = {'type': GQL_CONNECTION_INIT, 'payload': {'headers': headers}} 169 | self._connection.send(json.dumps(payload)) 170 | 171 | res = self._queue.get() 172 | 173 | if res['type'] == GQL_CONNECTION_ERROR: 174 | err = res['payload'] if 'payload' in res else 'unknown error' 175 | raise ConnectionException(err) 176 | if res['type'] == GQL_CONNECTION_ACK: 177 | self._connection_init_done = True 178 | return 179 | 180 | err_msg = "Unknown message from server, this client did not understand. " + \ 181 | "Original message: " + res['type'] 182 | raise ConnectionException(err_msg) 183 | 184 | def _start(self, payload, callback=None): 185 | """ pass a callback function only if this is a subscription """ 186 | op_id = uuid.uuid4().hex 187 | frame = {'id': op_id, 'type': GQL_START, 'payload': payload} 188 | self._create_operation_queue(op_id) 189 | if callback: 190 | self._insert_subscriber(op_id, callback) 191 | self._connection.send(json.dumps(frame)) 192 | return op_id 193 | 194 | def _stop(self, op_id): 195 | payload = {'id': op_id, 'type': GQL_STOP} 196 | self._connection.send(json.dumps(payload)) 197 | 198 | def query(self, query: str, variables: dict = None, headers: dict = None) -> dict: 199 | """ 200 | Run a GraphQL query or mutation. The `query` argument is a GraphQL query 201 | string. You can pass optional variables and headers. 202 | 203 | PS: To run a subscription, see the `subscribe` method. 
204 | """ 205 | self._connection_init(headers) 206 | payload = {'headers': headers, 'query': query, 'variables': variables} 207 | op_id = self._start(payload) 208 | res = self._get_operation_result(op_id) 209 | self._stop(op_id) 210 | ack = self._get_operation_result(op_id) 211 | if ack['type'] != GQL_COMPLETE: 212 | logger.warning('Expected to recieve complete, but received: %s', ack) 213 | self._remove_operation_queue(op_id) 214 | return res 215 | 216 | def subscribe(self, query: str, variables: dict = None, headers: dict = None, 217 | callback: Callable[[str, dict], None] = None) -> str: 218 | """ 219 | Run a GraphQL subscription. 220 | 221 | Parameters: 222 | query (str): the GraphQL query string 223 | callback (function): a callback function. This is mandatory. 224 | This callback function is called, everytime there is new data from the 225 | subscription. 226 | variables (dict): (optional) GraphQL variables 227 | headers (dict): (optional) a dictionary of headers for the session 228 | 229 | Returns: 230 | op_id (str): The operation id (a UUIDv4) for this subscription operation 231 | """ 232 | 233 | # sanity check that the user passed a valid function 234 | if not callback or not callable(callback): 235 | raise TypeError('the argument `callback` is mandatory and it should be a function') 236 | 237 | self._connection_init(headers) 238 | payload = {'headers': headers, 'query': query, 'variables': variables} 239 | op_id = self._start(payload, callback) 240 | self._subscriptions.append({ 241 | 'query': query, 242 | 'variables': variables, 243 | 'headers': headers, 244 | 'callback': callback 245 | }) 246 | return op_id 247 | 248 | def stop_subscribe(self, op_id: str) -> None: 249 | """ 250 | Stop a subscription. Takes an operation ID (`op_id`) and stops the 251 | subscription. 
252 | """ 253 | self._stop(op_id) 254 | self._remove_subscriber(op_id) 255 | self._remove_operation_queue(op_id) 256 | 257 | def close(self) -> None: 258 | """ 259 | Close the connection with the server. To reconnect, use the `connect` 260 | method. 261 | """ 262 | self._shutdown_receiver = True 263 | self._recevier_thread.join() 264 | self._connection.close() 265 | 266 | def __enter__(self): 267 | """ enter method for context manager """ 268 | return self 269 | 270 | def __exit__(self, exc_type, exc_value, exc_traceback): 271 | """ exit method for context manager """ 272 | self.close() 273 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from setuptools import find_packages, setup 3 | 4 | 5 | __version__ = "0.1.2" 6 | __desc__ = "A dead-simple GraphQL client that supports subscriptions over websockets" 7 | 8 | with open('README.md') as readme_file: 9 | readme = readme_file.read() 10 | 11 | requirements = [ 12 | 'websocket-client==0.54.0' 13 | ] 14 | 15 | test_requirements = [] 16 | 17 | setup( 18 | name='py-graphql-client', 19 | version=__version__, 20 | description=__desc__, 21 | long_description=readme, 22 | long_description_content_type='text/markdown', 23 | author="Anon Ray", 24 | author_email='rayanon004@gmail.com', 25 | url='https://github.com/ecthiender/py-graphql-client', 26 | packages=find_packages(exclude=['tests', 'tests.*']), 27 | package_data={'': ['LICENSE']}, 28 | package_dir={'graphql_client': 'graphql_client'}, 29 | python_requires=">=3.4", 30 | include_package_data=True, 31 | install_requires=requirements, 32 | license="BSD3", 33 | zip_safe=False, 34 | keywords=['graphql', 'websocket', 'subscriptions', 'graphql-client'], 35 | classifiers=[ 36 | 'Development Status :: 3 - Alpha', 37 | 'Intended Audience :: Developers', 38 | 'License :: OSI Approved :: BSD License', 39 | 'Natural Language :: 
English', 40 | 'Programming Language :: Python :: 3', 41 | 'Programming Language :: Python :: 3.8', 42 | 'Environment :: Console', 43 | 'Environment :: Web Environment', 44 | 'Environment :: Other Environment', 45 | 'Topic :: Internet :: WWW/HTTP', 46 | 'Topic :: Software Development :: Libraries', 47 | ], 48 | test_suite='tests', 49 | tests_require=test_requirements 50 | ) 51 | -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import json 3 | import threading 4 | import unittest 5 | 6 | from .websocket_server import WebsocketServer 7 | from graphql_client import * 8 | 9 | # The protocol: 10 | # https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md 11 | 12 | 13 | # Called for every client connecting (after handshake) 14 | def new_client(client, server): 15 | pass 16 | # print("[TEST_SERVER] => New client connected and was given id %d" % client['id']) 17 | # server.send_message_to_all("Hey all, a new client has joined us") 18 | 19 | 20 | # Called for every client disconnecting 21 | def client_left(client, server): 22 | pass 23 | # print("[TEST_SERVER] => Client(%d) disconnected" % client['id']) 24 | 25 | 26 | # Called when a client sends a message 27 | def message_received(client, server, message): 28 | # print("[TEST_SERVER] => Client(%d) said: %s" % (client['id'], message)) 29 | frame = json.loads(message) 30 | response = mock_server(frame) 31 | response.send(client, server) 32 | 33 | 34 | class GQLResponse(): 35 | def __init__(self, payload, time_between=0.1): 36 | if isinstance(payload, list): 37 | self._list_response = True 38 | elif isinstance(payload, dict): 39 | self._list_response = False 40 | 41 | self.payload = payload 42 | self.time_between = time_between 43 | 44 | def get_next_payload(self): 45 | pass 46 | 47 | def send(self, client, server): 48 | if self._list_response: 49 | # 
print('[TEST_SERVER] => Sending a list of messages to the client') 50 | for response in self.payload: 51 | # print('[TEST_SERVER] => Sending message to client', response) 52 | server.send_message(client, json.dumps(response)) 53 | time.sleep(self.time_between) 54 | else: 55 | # print('[TEST_SERVER] => Sending message to client', self.payload) 56 | server.send_message(client, json.dumps(self.payload)) 57 | 58 | 59 | def mock_server(frame): 60 | if frame['type'] == GQL_CONNECTION_INIT: 61 | return GQLResponse({'type': GQL_CONNECTION_ACK}) 62 | 63 | elif frame['type'] == GQL_START: 64 | op_id = frame['id'] 65 | if frame['payload']['query'].strip().startswith('subscription'): 66 | return GQLResponse([ 67 | {'id': op_id, 'type': GQL_DATA, 'payload': {'data': {'msg': 'hello world'}}}, 68 | {'id': op_id, 'type': GQL_DATA, 'payload': {'data': {'msg': 'hello world'}}}, 69 | {'id': op_id, 'type': GQL_DATA, 'payload': {'data': {'msg': 'hello world'}}}, 70 | {'id': op_id, 'type': GQL_COMPLETE} 71 | ], time_between=0.5) 72 | 73 | return GQLResponse([ 74 | {'id': op_id, 'type': GQL_DATA, 'payload': {'data': {'msg': 'hello world'}}}, 75 | {'id': op_id, 'type': GQL_COMPLETE} 76 | ], time_between=0.5) 77 | 78 | elif frame['type'] == GQL_STOP: 79 | return GQLResponse({'id': frame['id'], 'type': GQL_COMPLETE}) 80 | # return GQLResponse({'type': GQL_CONNECTION_KEEP_ALIVE}) 81 | 82 | 83 | class ApolloProtocolServer(): 84 | def __init__(self, port=9001): 85 | self.port = port 86 | self.is_running = False 87 | self.server = None 88 | self.server_thread = None 89 | 90 | def _serve(self): 91 | try: 92 | self.server.run_forever() 93 | finally: 94 | self.server.server_close() 95 | 96 | def start_server(self): 97 | if not self.is_running: 98 | self.server = WebsocketServer(self.port) 99 | self.server.set_fn_new_client(new_client) 100 | self.server.set_fn_client_left(client_left) 101 | self.server.set_fn_message_received(message_received) 102 | 103 | self.server_thread = 
threading.Thread(target=self._serve, daemon=True) 104 | self.server_thread.start() 105 | 106 | self.is_running = True 107 | # print('[TEST_SERVER] => [DEBUG] => end of start server..') 108 | 109 | def stop_server(self): 110 | # print('[TEST_SERVER] => [DEBUG] => inside stop server..') 111 | if self.is_running: 112 | self.is_running = False 113 | # print('[TEST_SERVER] => [DEBUG] => calling server_close..') 114 | self.server.server_close() 115 | # print('[TEST_SERVER] => [DEBUG] => called server_close..') 116 | # print('[TEST_SERVER] => [DEBUG] => calling thread.join..') 117 | self.server_thread.join(timeout=2) 118 | # print('[TEST_SERVER] => [DEBUG] => called thread.join..') 119 | 120 | 121 | query = """ 122 | query getUser($userId: Int!) { 123 | user (id: $userId) { 124 | id 125 | username 126 | } 127 | } 128 | """ 129 | 130 | subscription = """ 131 | subscription getUser($userId: Int!) { 132 | user (id: $userId) { 133 | id 134 | username 135 | } 136 | } 137 | """ 138 | 139 | class TestClient(unittest.TestCase): 140 | 141 | def __init__(self, *args, **kwargs): 142 | super().__init__(*args, **kwargs) 143 | self.ws_server = ApolloProtocolServer() 144 | 145 | def setUp(self): 146 | # print('[TEST] => setUp()') 147 | self.ws_server.start_server() 148 | self.client = GraphQLClient('ws://localhost:9001') 149 | 150 | def test_query(self): 151 | res = self.client.query(query, variables={'userId': 2}) 152 | # print('[TEST] => Got response inside the test', res) 153 | self.assertTrue(res['type'] == GQL_DATA) 154 | 155 | def test_subscription(self): 156 | op_ids = [] 157 | all_datas = [] 158 | def my_callback(op_id, data): 159 | op_ids.append(op_id) 160 | all_datas.append(data) 161 | 162 | sub_id = self.client.subscribe(subscription, variables={'userId': 2}, callback=my_callback) 163 | # print('[TEST] => Got response inside the test', sub_id) 164 | # wait for 3 seconds to finish subscription 165 | time.sleep(3) 166 | self.client.stop_subscribe(sub_id) 167 | 168 | for op_id 
in op_ids: 169 | self.assertEqual(op_id, sub_id) 170 | 171 | for res in all_datas[:-1]: 172 | self.assertEqual(res['type'], GQL_DATA) 173 | 174 | self.assertEqual(all_datas[-1]['type'], GQL_COMPLETE) 175 | 176 | def test_multiple_queries(self): 177 | for _ in range(4): 178 | res = self.client.query(query, variables={'userId': 2}) 179 | # print('[TEST] => Got response inside the test', res) 180 | self.assertTrue(res['type'] == GQL_DATA) 181 | 182 | def test_multiple_subscriptions(self): 183 | op_ids1 = [] 184 | op_ids2 = [] 185 | all_datas1 = [] 186 | all_datas2 = [] 187 | 188 | def my_callback1(op_id, data): 189 | # print('[TEST] => inside callback: ', op_id, data) 190 | op_ids1.append(op_id) 191 | all_datas1.append(data) 192 | 193 | def my_callback2(op_id, data): 194 | # print('[TEST] => inside callback: ', op_id, data) 195 | op_ids2.append(op_id) 196 | all_datas2.append(data) 197 | 198 | sub_id1 = self.client.subscribe(subscription, variables={'userId': 2}, callback=my_callback1) 199 | sub_id2 = self.client.subscribe(subscription, variables={'userId': 2}, callback=my_callback2) 200 | 201 | # wait for 4 seconds to finish subscription 202 | time.sleep(4) 203 | self.client.stop_subscribe(sub_id1) 204 | self.client.stop_subscribe(sub_id2) 205 | 206 | # check invariants for sub_id1 207 | for op_id in op_ids1: 208 | self.assertEqual(op_id, sub_id1) 209 | 210 | for res in all_datas1[:-1]: 211 | self.assertEqual(res['type'], GQL_DATA) 212 | 213 | self.assertEqual(all_datas1[-1]['type'], GQL_COMPLETE) 214 | 215 | # check invariants for sub_id2 216 | for op_id in op_ids2: 217 | self.assertEqual(op_id, sub_id2) 218 | 219 | for res in all_datas2[:-1]: 220 | self.assertEqual(res['type'], GQL_DATA) 221 | 222 | self.assertEqual(all_datas2[-1]['type'], GQL_COMPLETE) 223 | 224 | # TODO: one more testcase with multiple queries and multiple subscriptions mixed 225 | 226 | def tearDown(self): 227 | # print('[TEST] => tearDown()') 228 | self.client.close() 229 | 
self.ws_server.stop_server() 230 | 231 | 232 | if __name__ == '__main__': 233 | unittest.main() 234 | -------------------------------------------------------------------------------- /tests/websocket_server.py: -------------------------------------------------------------------------------- 1 | # Author: Johan Hanssen Seferidis 2 | # License: MIT 3 | # Downloaded from: https://github.com/Pithikos/python-websocket-server/blob/master/websocket_server/websocket_server.py 4 | 5 | import sys 6 | import struct 7 | from base64 import b64encode 8 | from hashlib import sha1 9 | import logging 10 | from socket import error as SocketError 11 | import errno 12 | 13 | if sys.version_info[0] < 3: 14 | from SocketServer import ThreadingMixIn, TCPServer, StreamRequestHandler 15 | else: 16 | from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler 17 | 18 | logger = logging.getLogger(__name__) 19 | logging.basicConfig() 20 | 21 | ''' 22 | +-+-+-+-+-------+-+-------------+-------------------------------+ 23 | 0 1 2 3 24 | 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 25 | +-+-+-+-+-------+-+-------------+-------------------------------+ 26 | |F|R|R|R| opcode|M| Payload len | Extended payload length | 27 | |I|S|S|S| (4) |A| (7) | (16/64) | 28 | |N|V|V|V| |S| | (if payload len==126/127) | 29 | | |1|2|3| |K| | | 30 | +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + 31 | | Extended payload length continued, if payload len == 127 | 32 | + - - - - - - - - - - - - - - - +-------------------------------+ 33 | | Payload Data continued ... 
| 34 | +---------------------------------------------------------------+ 35 | ''' 36 | 37 | FIN = 0x80 38 | OPCODE = 0x0f 39 | MASKED = 0x80 40 | PAYLOAD_LEN = 0x7f 41 | PAYLOAD_LEN_EXT16 = 0x7e 42 | PAYLOAD_LEN_EXT64 = 0x7f 43 | 44 | OPCODE_CONTINUATION = 0x0 45 | OPCODE_TEXT = 0x1 46 | OPCODE_BINARY = 0x2 47 | OPCODE_CLOSE_CONN = 0x8 48 | OPCODE_PING = 0x9 49 | OPCODE_PONG = 0xA 50 | 51 | 52 | # -------------------------------- API --------------------------------- 53 | 54 | class API(): 55 | 56 | def run_forever(self): 57 | try: 58 | logger.info("Listening on port %d for clients.." % self.port) 59 | self.serve_forever() 60 | except KeyboardInterrupt: 61 | self.server_close() 62 | logger.info("Server terminated.") 63 | except Exception as e: 64 | logger.error(str(e), exc_info=True) 65 | exit(1) 66 | 67 | def new_client(self, client, server): 68 | pass 69 | 70 | def client_left(self, client, server): 71 | pass 72 | 73 | def message_received(self, client, server, message): 74 | pass 75 | 76 | def set_fn_new_client(self, fn): 77 | self.new_client = fn 78 | 79 | def set_fn_client_left(self, fn): 80 | self.client_left = fn 81 | 82 | def set_fn_message_received(self, fn): 83 | self.message_received = fn 84 | 85 | def send_message(self, client, msg): 86 | self._unicast_(client, msg) 87 | 88 | def send_message_to_all(self, msg): 89 | self._multicast_(msg) 90 | 91 | 92 | # ------------------------- Implementation ----------------------------- 93 | 94 | class WebsocketServer(ThreadingMixIn, TCPServer, API): 95 | """ 96 | A websocket server waiting for clients to connect. 97 | 98 | Args: 99 | port(int): Port to bind to 100 | host(str): Hostname or IP to listen for connections. By default 127.0.0.1 101 | is being used. To accept connections from any client, you should use 102 | 0.0.0.0. 103 | loglevel: Logging level from logging module to use for logging. By default 104 | warnings and errors are being logged. 
105 | 106 | Properties: 107 | clients(list): A list of connected clients. A client is a dictionary 108 | like below. 109 | { 110 | 'id' : id, 111 | 'handler' : handler, 112 | 'address' : (addr, port) 113 | } 114 | """ 115 | 116 | allow_reuse_address = True 117 | daemon_threads = True # comment to keep threads alive until finished 118 | 119 | clients = [] 120 | id_counter = 0 121 | 122 | def __init__(self, port, host='127.0.0.1', loglevel=logging.WARNING): 123 | logger.setLevel(loglevel) 124 | TCPServer.__init__(self, (host, port), WebSocketHandler) 125 | self.port = self.socket.getsockname()[1] 126 | 127 | def _message_received_(self, handler, msg): 128 | self.message_received(self.handler_to_client(handler), self, msg) 129 | 130 | def _ping_received_(self, handler, msg): 131 | handler.send_pong(msg) 132 | 133 | def _pong_received_(self, handler, msg): 134 | pass 135 | 136 | def _new_client_(self, handler): 137 | self.id_counter += 1 138 | client = { 139 | 'id': self.id_counter, 140 | 'handler': handler, 141 | 'address': handler.client_address 142 | } 143 | self.clients.append(client) 144 | self.new_client(client, self) 145 | 146 | def _client_left_(self, handler): 147 | client = self.handler_to_client(handler) 148 | self.client_left(client, self) 149 | if client in self.clients: 150 | self.clients.remove(client) 151 | 152 | def _unicast_(self, to_client, msg): 153 | to_client['handler'].send_message(msg) 154 | 155 | def _multicast_(self, msg): 156 | for client in self.clients: 157 | self._unicast_(client, msg) 158 | 159 | def handler_to_client(self, handler): 160 | for client in self.clients: 161 | if client['handler'] == handler: 162 | return client 163 | 164 | 165 | class WebSocketHandler(StreamRequestHandler): 166 | 167 | def __init__(self, socket, addr, server): 168 | self.server = server 169 | StreamRequestHandler.__init__(self, socket, addr, server) 170 | 171 | def setup(self): 172 | StreamRequestHandler.setup(self) 173 | self.keep_alive = True 174 | 
self.handshake_done = False 175 | self.valid_client = False 176 | 177 | def handle(self): 178 | while self.keep_alive: 179 | if not self.handshake_done: 180 | self.handshake() 181 | elif self.valid_client: 182 | self.read_next_message() 183 | 184 | def read_bytes(self, num): 185 | # python3 gives ordinal of byte directly 186 | bytes = self.rfile.read(num) 187 | if sys.version_info[0] < 3: 188 | return map(ord, bytes) 189 | else: 190 | return bytes 191 | 192 | def read_next_message(self): 193 | try: 194 | b1, b2 = self.read_bytes(2) 195 | except SocketError as e: # to be replaced with ConnectionResetError for py3 196 | if e.errno == errno.ECONNRESET: 197 | logger.info("Client closed connection.") 198 | self.keep_alive = 0 199 | return 200 | b1, b2 = 0, 0 201 | except ValueError as e: 202 | b1, b2 = 0, 0 203 | 204 | fin = b1 & FIN 205 | opcode = b1 & OPCODE 206 | masked = b2 & MASKED 207 | payload_length = b2 & PAYLOAD_LEN 208 | 209 | if opcode == OPCODE_CLOSE_CONN: 210 | logger.info("Client asked to close connection.") 211 | self.keep_alive = 0 212 | return 213 | if not masked: 214 | logger.warn("Client must always be masked.") 215 | self.keep_alive = 0 216 | return 217 | if opcode == OPCODE_CONTINUATION: 218 | logger.warn("Continuation frames are not supported.") 219 | return 220 | elif opcode == OPCODE_BINARY: 221 | logger.warn("Binary frames are not supported.") 222 | return 223 | elif opcode == OPCODE_TEXT: 224 | opcode_handler = self.server._message_received_ 225 | elif opcode == OPCODE_PING: 226 | opcode_handler = self.server._ping_received_ 227 | elif opcode == OPCODE_PONG: 228 | opcode_handler = self.server._pong_received_ 229 | else: 230 | logger.warn("Unknown opcode %#x." 
% opcode) 231 | self.keep_alive = 0 232 | return 233 | 234 | if payload_length == 126: 235 | payload_length = struct.unpack(">H", self.rfile.read(2))[0] 236 | elif payload_length == 127: 237 | payload_length = struct.unpack(">Q", self.rfile.read(8))[0] 238 | 239 | masks = self.read_bytes(4) 240 | message_bytes = bytearray() 241 | for message_byte in self.read_bytes(payload_length): 242 | message_byte ^= masks[len(message_bytes) % 4] 243 | message_bytes.append(message_byte) 244 | opcode_handler(self, message_bytes.decode('utf8')) 245 | 246 | def send_message(self, message): 247 | self.send_text(message) 248 | 249 | def send_pong(self, message): 250 | self.send_text(message, OPCODE_PONG) 251 | 252 | def send_text(self, message, opcode=OPCODE_TEXT): 253 | """ 254 | Important: Fragmented(=continuation) messages are not supported since 255 | their usage cases are limited - when we don't know the payload length. 256 | """ 257 | 258 | # Validate message 259 | if isinstance(message, bytes): 260 | message = try_decode_UTF8(message) # this is slower but ensures we have UTF-8 261 | if not message: 262 | logger.warning("Can\'t send message, message is not valid UTF-8") 263 | return False 264 | elif sys.version_info < (3,0) and (isinstance(message, str) or isinstance(message, unicode)): 265 | pass 266 | elif isinstance(message, str): 267 | pass 268 | else: 269 | logger.warning('Can\'t send message, message has to be a string or bytes. 
Given type is %s' % type(message)) 270 | return False 271 | 272 | header = bytearray() 273 | payload = encode_to_UTF8(message) 274 | payload_length = len(payload) 275 | 276 | # Normal payload 277 | if payload_length <= 125: 278 | header.append(FIN | opcode) 279 | header.append(payload_length) 280 | 281 | # Extended payload 282 | elif payload_length >= 126 and payload_length <= 65535: 283 | header.append(FIN | opcode) 284 | header.append(PAYLOAD_LEN_EXT16) 285 | header.extend(struct.pack(">H", payload_length)) 286 | 287 | # Huge extended payload 288 | elif payload_length < 18446744073709551616: 289 | header.append(FIN | opcode) 290 | header.append(PAYLOAD_LEN_EXT64) 291 | header.extend(struct.pack(">Q", payload_length)) 292 | 293 | else: 294 | raise Exception("Message is too big. Consider breaking it into chunks.") 295 | return 296 | 297 | self.request.send(header + payload) 298 | 299 | def read_http_headers(self): 300 | headers = {} 301 | # first line should be HTTP GET 302 | http_get = self.rfile.readline().decode().strip() 303 | assert http_get.upper().startswith('GET') 304 | # remaining should be headers 305 | while True: 306 | header = self.rfile.readline().decode().strip() 307 | if not header: 308 | break 309 | head, value = header.split(':', 1) 310 | headers[head.lower().strip()] = value.strip() 311 | return headers 312 | 313 | def handshake(self): 314 | headers = self.read_http_headers() 315 | 316 | try: 317 | assert headers['upgrade'].lower() == 'websocket' 318 | except AssertionError: 319 | self.keep_alive = False 320 | return 321 | 322 | try: 323 | key = headers['sec-websocket-key'] 324 | except KeyError: 325 | logger.warning("Client tried to connect but was missing a key") 326 | self.keep_alive = False 327 | return 328 | 329 | response = self.make_handshake_response(key) 330 | self.handshake_done = self.request.send(response.encode()) 331 | self.valid_client = True 332 | self.server._new_client_(self) 333 | 334 | @classmethod 335 | def 
make_handshake_response(cls, key): 336 | # mocking a hardcoded apollo-protcol graphql server over websockets 337 | return \ 338 | 'HTTP/1.1 101 Switching Protocols\r\n'\ 339 | 'Upgrade: websocket\r\n' \ 340 | 'Connection: Upgrade\r\n' \ 341 | 'Sec-WebSocket-Accept: %s\r\n' \ 342 | 'Sec-WebSocket-Protocol: graphql-ws\r\n'\ 343 | '\r\n' % cls.calculate_response_key(key) 344 | 345 | @classmethod 346 | def calculate_response_key(cls, key): 347 | GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' 348 | hash = sha1(key.encode() + GUID.encode()) 349 | response_key = b64encode(hash.digest()).strip() 350 | return response_key.decode('ASCII') 351 | 352 | def finish(self): 353 | self.server._client_left_(self) 354 | 355 | 356 | def encode_to_UTF8(data): 357 | try: 358 | return data.encode('UTF-8') 359 | except UnicodeEncodeError as e: 360 | logger.error("Could not encode data to UTF-8 -- %s" % e) 361 | return False 362 | except Exception as e: 363 | raise(e) 364 | return False 365 | 366 | 367 | def try_decode_UTF8(data): 368 | try: 369 | return data.decode('utf-8') 370 | except UnicodeDecodeError: 371 | return False 372 | except Exception as e: 373 | raise(e) 374 | --------------------------------------------------------------------------------