├── tests ├── test2.pcm ├── test0.wav ├── test1.pcm ├── test1.wav ├── test1_1.pcm ├── test_tts.pcm ├── tts_test.pcm ├── tts_test.wav ├── __pycache__ │ └── test_utils.cpython-39.pyc ├── test_token.py ├── test_utils.py ├── test_stream_input_tts.py ├── test_sr.py ├── test_tts.py ├── test_realtime_meeting.py └── test_st.py ├── nls ├── websocket │ ├── tests │ │ ├── __init__.py │ │ ├── data │ │ │ ├── header02.txt │ │ │ ├── header01.txt │ │ │ └── header03.txt │ │ ├── echo-server.py │ │ ├── test_abnf.py │ │ ├── test_cookiejar.py │ │ ├── test_app.py │ │ ├── test_http.py │ │ └── test_url.py │ ├── __pycache__ │ │ ├── _abnf.cpython-39.pyc │ │ ├── _app.cpython-39.pyc │ │ ├── _core.cpython-39.pyc │ │ ├── _http.cpython-39.pyc │ │ ├── _url.cpython-39.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── _logging.cpython-39.pyc │ │ ├── _socket.cpython-39.pyc │ │ ├── _utils.cpython-39.pyc │ │ ├── _cookiejar.cpython-39.pyc │ │ ├── _handshake.cpython-39.pyc │ │ ├── _exceptions.cpython-39.pyc │ │ └── _ssl_compat.cpython-39.pyc │ ├── __init__.py │ ├── _ssl_compat.py │ ├── _logging.py │ ├── _cookiejar.py │ ├── _exceptions.py │ ├── _utils.py │ ├── _url.py │ ├── _socket.py │ ├── _handshake.py │ ├── _http.py │ ├── _abnf.py │ └── _app.py ├── version.py ├── __pycache__ │ ├── core.cpython-39.pyc │ ├── token.cpython-39.pyc │ ├── util.cpython-39.pyc │ ├── __init__.cpython-39.pyc │ ├── _logger.cpython-36.pyc │ ├── _logging.cpython-36.pyc │ ├── logging.cpython-39.pyc │ ├── version.cpython-39.pyc │ ├── exception.cpython-39.pyc │ ├── speech_recognizer.cpython-39.pyc │ ├── speech_synthesizer.cpython-39.pyc │ └── speech_transcriber.cpython-39.pyc ├── __init__.py ├── exception.py ├── util.py ├── token.py ├── logging.py ├── core.py ├── speech_synthesizer.py ├── speech_recognizer.py ├── realtime_meeting.py └── speech_transcriber.py ├── requirements.txt ├── .gitignore ├── README.md ├── LICENSE └── setup.py /tests/test2.pcm: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /nls/websocket/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | oss2 2 | aliyun-python-sdk-core>=2.13.3 3 | matplotlib>=3.3.4 4 | 5 | -------------------------------------------------------------------------------- /nls/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | __version__ = '1.0.0' -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | nls/__pycache__ 2 | nls/websocket/__pycache__ 3 | tests/__pycache__ 4 | nls.egg-info/ -------------------------------------------------------------------------------- /tests/test0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/tests/test0.wav -------------------------------------------------------------------------------- /tests/test1.pcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/tests/test1.pcm -------------------------------------------------------------------------------- /tests/test1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/tests/test1.wav -------------------------------------------------------------------------------- /tests/test1_1.pcm: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/tests/test1_1.pcm -------------------------------------------------------------------------------- /tests/test_tts.pcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/tests/test_tts.pcm -------------------------------------------------------------------------------- /tests/tts_test.pcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/tests/tts_test.pcm -------------------------------------------------------------------------------- /tests/tts_test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/tests/tts_test.wav -------------------------------------------------------------------------------- /nls/__pycache__/core.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/__pycache__/core.cpython-39.pyc -------------------------------------------------------------------------------- /nls/__pycache__/token.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/__pycache__/token.cpython-39.pyc -------------------------------------------------------------------------------- /nls/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /nls/__pycache__/__init__.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /nls/__pycache__/_logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/__pycache__/_logger.cpython-36.pyc -------------------------------------------------------------------------------- /nls/__pycache__/_logging.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/__pycache__/_logging.cpython-36.pyc -------------------------------------------------------------------------------- /nls/__pycache__/logging.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/__pycache__/logging.cpython-39.pyc -------------------------------------------------------------------------------- /nls/__pycache__/version.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/__pycache__/version.cpython-39.pyc -------------------------------------------------------------------------------- /nls/__pycache__/exception.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/__pycache__/exception.cpython-39.pyc -------------------------------------------------------------------------------- /tests/__pycache__/test_utils.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/tests/__pycache__/test_utils.cpython-39.pyc -------------------------------------------------------------------------------- /nls/websocket/__pycache__/_abnf.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/websocket/__pycache__/_abnf.cpython-39.pyc -------------------------------------------------------------------------------- /nls/websocket/__pycache__/_app.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/websocket/__pycache__/_app.cpython-39.pyc -------------------------------------------------------------------------------- /nls/websocket/__pycache__/_core.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/websocket/__pycache__/_core.cpython-39.pyc -------------------------------------------------------------------------------- /nls/websocket/__pycache__/_http.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/websocket/__pycache__/_http.cpython-39.pyc -------------------------------------------------------------------------------- /nls/websocket/__pycache__/_url.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/websocket/__pycache__/_url.cpython-39.pyc -------------------------------------------------------------------------------- /nls/__pycache__/speech_recognizer.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/__pycache__/speech_recognizer.cpython-39.pyc -------------------------------------------------------------------------------- /nls/__pycache__/speech_synthesizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/__pycache__/speech_synthesizer.cpython-39.pyc -------------------------------------------------------------------------------- /nls/__pycache__/speech_transcriber.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/__pycache__/speech_transcriber.cpython-39.pyc -------------------------------------------------------------------------------- /nls/websocket/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/websocket/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /nls/websocket/__pycache__/_logging.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/websocket/__pycache__/_logging.cpython-39.pyc -------------------------------------------------------------------------------- /nls/websocket/__pycache__/_socket.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/websocket/__pycache__/_socket.cpython-39.pyc -------------------------------------------------------------------------------- /nls/websocket/__pycache__/_utils.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/websocket/__pycache__/_utils.cpython-39.pyc -------------------------------------------------------------------------------- /nls/websocket/__pycache__/_cookiejar.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/websocket/__pycache__/_cookiejar.cpython-39.pyc -------------------------------------------------------------------------------- /nls/websocket/__pycache__/_handshake.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/websocket/__pycache__/_handshake.cpython-39.pyc -------------------------------------------------------------------------------- /nls/websocket/__pycache__/_exceptions.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/websocket/__pycache__/_exceptions.cpython-39.pyc -------------------------------------------------------------------------------- /nls/websocket/__pycache__/_ssl_compat.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/alibabacloud-nls-python-sdk/HEAD/nls/websocket/__pycache__/_ssl_compat.cpython-39.pyc -------------------------------------------------------------------------------- /tests/test_token.py: -------------------------------------------------------------------------------- 1 | from nls.token import getToken 2 | 3 | from tests.test_utils import TEST_ACCESS_AKID, TEST_ACCESS_AKKEY 4 | 5 | 6 | info = getToken(TEST_ACCESS_AKID, TEST_ACCESS_AKKEY) 7 | print(info) 8 | 
-------------------------------------------------------------------------------- /nls/websocket/tests/data/header02.txt: -------------------------------------------------------------------------------- 1 | HTTP/1.1 101 WebSocket Protocol Handshake 2 | Connection: Upgrade 3 | Upgrade WebSocket 4 | Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0= 5 | some_header: something 6 | 7 | -------------------------------------------------------------------------------- /nls/websocket/tests/data/header01.txt: -------------------------------------------------------------------------------- 1 | HTTP/1.1 101 WebSocket Protocol Handshake 2 | Connection: Upgrade 3 | Upgrade: WebSocket 4 | Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0= 5 | some_header: something 6 | 7 | -------------------------------------------------------------------------------- /nls/websocket/tests/data/header03.txt: -------------------------------------------------------------------------------- 1 | HTTP/1.1 101 WebSocket Protocol Handshake 2 | Connection: Upgrade, Keep-Alive 3 | Upgrade: WebSocket 4 | Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0= 5 | Set-Cookie: Token=ABCDE 6 | some_header: something 7 | 8 | -------------------------------------------------------------------------------- /nls/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | 3 | from .logging import * 4 | from .speech_recognizer import * 5 | from .speech_transcriber import * 6 | from .speech_synthesizer import * 7 | from .stream_input_tts import * 8 | from .realtime_meeting import * 9 | from .util import * 10 | from .version import __version__ 11 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 
2 | 3 | import os 4 | 5 | # only for test 6 | TEST_ACCESS_AKID = os.environ['TEST_ACCESS_AKID'] 7 | TEST_ACCESS_AKKEY = os.environ['TEST_ACCESS_AKKEY'] 8 | TEST_ACCESS_TOKEN = os.environ['TEST_ACCESS_TOKEN'] 9 | TEST_ACCESS_APPKEY = os.environ['TEST_ACCESS_APPKEY'] 10 | 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # alibaba-nls-python-sdk 2 | 3 | This is Python SDK for NLS. It supports 4 | SPEECH-RECOGNIZER/SPEECH-SYNTHESIZER/SPEECH-TRANSLATOR/COMMON-REQUESTS-PROTO. 5 | 6 | This module works on Python versions: 7 | > 3.6 and greater 8 | 9 | install requirements: 10 | > python -m pip install -r requirements.txt 11 | 12 | install package: 13 | > python -m pip install . 14 | -------------------------------------------------------------------------------- /nls/websocket/tests/echo-server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # From https://github.com/aaugustin/websockets/blob/main/example/echo.py 4 | 5 | import asyncio 6 | import websockets 7 | import os 8 | 9 | LOCAL_WS_SERVER_PORT = os.environ.get('LOCAL_WS_SERVER_PORT', '8765') 10 | 11 | 12 | async def echo(websocket, path): 13 | async for message in websocket: 14 | await websocket.send(message) 15 | 16 | 17 | async def main(): 18 | async with websockets.serve(echo, "localhost", LOCAL_WS_SERVER_PORT): 19 | await asyncio.Future() # run forever 20 | 21 | asyncio.run(main()) 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 1999-present Alibaba Group Holding Ltd. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /nls/exception.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | 3 | 4 | class InvalidParameter(Exception): 5 | pass 6 | 7 | # Token 8 | class GetTokenFailed(Exception): 9 | pass 10 | 11 | # Connection 12 | class ConnectionTimeout(Exception): 13 | pass 14 | 15 | class ConnectionUnavailable(Exception): 16 | pass 17 | 18 | class StartTimeoutException(Exception): 19 | pass 20 | 21 | class StopTimeoutException(Exception): 22 | pass 23 | 24 | class NotStartException(Exception): 25 | pass 26 | 27 | class CompleteTimeoutException(Exception): 28 | pass 29 | 30 | class WrongStateException(Exception): 31 | pass -------------------------------------------------------------------------------- /nls/websocket/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | __init__.py 3 | websocket - WebSocket client library for Python 4 | 5 | Copyright 2021 engn33r 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); 8 | you may not use this file except in compliance with the License. 9 | You may obtain a copy of the License at 10 | 11 | http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | Unless required by applicable law or agreed to in writing, software 14 | distributed under the License is distributed on an "AS IS" BASIS, 15 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
16 | See the License for the specific language governing permissions and 17 | limitations under the License. 18 | """ 19 | from ._abnf import * 20 | from ._app import WebSocketApp 21 | from ._core import * 22 | from ._exceptions import * 23 | from ._logging import * 24 | from ._socket import * 25 | 26 | __version__ = "1.2.1" 27 | -------------------------------------------------------------------------------- /nls/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | 3 | from struct import * 4 | 5 | __all__=['wav2pcm', 'GetDefaultContext'] 6 | 7 | def GetDefaultContext(): 8 | """ 9 | Return Default Context Object 10 | """ 11 | return { 12 | 'sdk': { 13 | 'name': 'nls-python-sdk', 14 | 'version': '1.1.0', 15 | 'language': 'python' 16 | } 17 | } 18 | 19 | 20 | def wav2pcm(wavfile, pcmfile): 21 | """ 22 | Turn wav into pcm 23 | 24 | Parameters 25 | ---------- 26 | wavfile: str 27 | wav file path 28 | pcmfile: str 29 | output pcm file path 30 | """ 31 | with open(wavfile, 'rb') as i, open(pcmfile, 'wb') as o: 32 | i.seek(0) 33 | _id = i.read(4) 34 | _id = unpack('>I', _id) 35 | _size = i.read(4) 36 | _size = unpack('I', _type) 39 | if _id[0] != 0x52494646 or _type[0] != 0x57415645: 40 | raise ValueError('not a wav!') 41 | i.read(32) 42 | result = i.read() 43 | o.write(result) 44 | 45 | -------------------------------------------------------------------------------- /nls/websocket/_ssl_compat.py: -------------------------------------------------------------------------------- 1 | """ 2 | _ssl_compat.py 3 | websocket - WebSocket client library for Python 4 | 5 | Copyright 2021 engn33r 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); 8 | you may not use this file except in compliance with the License. 
9 | You may obtain a copy of the License at 10 | 11 | http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | Unless required by applicable law or agreed to in writing, software 14 | distributed under the License is distributed on an "AS IS" BASIS, 15 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | See the License for the specific language governing permissions and 17 | limitations under the License. 18 | """ 19 | __all__ = ["HAVE_SSL", "ssl", "SSLError", "SSLWantReadError", "SSLWantWriteError"] 20 | 21 | try: 22 | import ssl 23 | from ssl import SSLError 24 | from ssl import SSLWantReadError 25 | from ssl import SSLWantWriteError 26 | HAVE_CONTEXT_CHECK_HOSTNAME = False 27 | if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'): 28 | HAVE_CONTEXT_CHECK_HOSTNAME = True 29 | 30 | __all__.append("HAVE_CONTEXT_CHECK_HOSTNAME") 31 | HAVE_SSL = True 32 | except ImportError: 33 | # dummy class of SSLError for environment without ssl support 34 | class SSLError(Exception): 35 | pass 36 | 37 | class SSLWantReadError(Exception): 38 | pass 39 | 40 | class SSLWantWriteError(Exception): 41 | pass 42 | 43 | ssl = None 44 | HAVE_SSL = False 45 | -------------------------------------------------------------------------------- /nls/token.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 
2 | 3 | from aliyunsdkcore.client import AcsClient 4 | from aliyunsdkcore.request import CommonRequest 5 | from .exception import GetTokenFailed 6 | 7 | import json 8 | 9 | __all__ = ['getToken'] 10 | 11 | def getToken(akid, aksecret, domain='cn-shanghai', 12 | version='2019-02-28', 13 | url='nls-meta.cn-shanghai.aliyuncs.com'): 14 | """ 15 | Help methods to get token from aliyun by giving access id and access secret 16 | key 17 | 18 | Parameters: 19 | ----------- 20 | akid: str 21 | access id from aliyun 22 | aksecret: str 23 | access secret key from aliyun 24 | domain: str: 25 | default is cn-shanghai 26 | version: str: 27 | default is 2019-02-28 28 | url: str 29 | full url for getting token, default is 30 | nls-meta.cn-shanghai.aliyuncs.com 31 | """ 32 | if akid is None or aksecret is None: 33 | raise GetTokenFailed('No akid or aksecret') 34 | client = AcsClient(akid, aksecret, domain) 35 | request = CommonRequest() 36 | request.set_method('POST') 37 | request.set_domain(url) 38 | request.set_version(version) 39 | request.set_action_name('CreateToken') 40 | response = client.do_action_with_exception(request) 41 | response_json = json.loads(response) 42 | if 'Token' in response_json: 43 | token = response_json['Token'] 44 | if 'Id' in token: 45 | return token['Id'] 46 | else: 47 | raise GetTokenFailed(f'Missing id field in token:{token}') 48 | else: 49 | raise GetTokenFailed(f'Token not in response:{response_json}') 50 | -------------------------------------------------------------------------------- /tests/test_stream_input_tts.py: -------------------------------------------------------------------------------- 1 | import nls 2 | import pyaudio 3 | import time 4 | from tests.test_utils import TEST_ACCESS_TOKEN, TEST_ACCESS_APPKEY 5 | 6 | 7 | test_text = [ 8 | "流式文本语音合成SDK,", 9 | "可以将输入的文本", 10 | "合成为语音二进制数据,", 11 | "相比于非流式语音合成,", 12 | "流式合成的优势在于实时性", 13 | "更强。用户在输入文本的同时", 14 | "可以听到接近同步的语音输出,", 15 | "极大地提升了交互体验,", 16 | "减少了用户等待时间。", 17 | "适用于调用大规模", 18 | 
"语言模型(LLM),以", 19 | "流式输入文本的方式", 20 | "进行语音合成的场景。", 21 | ] 22 | 23 | if __name__ == "__main__": 24 | player = pyaudio.PyAudio() 25 | stream = player.open(format=pyaudio.paInt16, channels=1, rate=24000, output=True) 26 | 27 | # 创建SDK实例 28 | # 配置回调函数 29 | def test_on_data(data, *args): 30 | stream.write(data) 31 | 32 | def test_on_message(message, *args): 33 | print('on message=>{}'.format(message)) 34 | 35 | def test_on_close(*args): 36 | print('on_close: args=>{}'.format(args)) 37 | 38 | def test_on_error(message, *args): 39 | print('on_error args=>{}, message=>{}'.format(args, message)) 40 | 41 | sdk = nls.NlsStreamInputTtsSynthesizer( 42 | token=TEST_ACCESS_TOKEN, 43 | appkey=TEST_ACCESS_APPKEY, 44 | on_data=test_on_data, 45 | on_sentence_begin=test_on_message, 46 | on_sentence_synthesis=test_on_message, 47 | on_sentence_end=test_on_message, 48 | on_completed=test_on_message, 49 | on_error=test_on_error, 50 | on_close=test_on_close, 51 | callback_args=[], 52 | ) 53 | 54 | # 发送文本消息 55 | sdk.startStreamInputTts() 56 | for text in test_text: 57 | sdk.sendStreamInputTts(text) 58 | time.sleep(0.05) 59 | sdk.stopStreamInputTts() 60 | 61 | stream.stop_stream() 62 | stream.close() 63 | player.terminate() 64 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | ''' 3 | Licensed to the Apache Software Foundation(ASF) under one 4 | or more contributor license agreements. See the NOTICE file 5 | distributed with this work for additional information 6 | regarding copyright ownership. The ASF licenses this file 7 | to you under the Apache License, Version 2.0 (the 8 | "License"); you may not use this file except in compliance 9 | with the License. 
You may obtain a copy of the License at 10 | http: // www.apache.org/licenses/LICENSE-2.0 11 | Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations under 16 | the License. 17 | ''' 18 | import os 19 | import setuptools 20 | 21 | with open("README.md", "r") as f: 22 | long_description = f.read() 23 | 24 | requires = [ 25 | "oss2", 26 | "aliyun-python-sdk-core>=2.13.3", 27 | "matplotlib>=3.3.4" 28 | ] 29 | 30 | setup_args = { 31 | 'version': "1.1.0", 32 | 'author': "jiaqi.sjq", 33 | 'author_email': "jiaqi.sjq@alibaba-inc.com", 34 | 'description': "python sdk for nls", 35 | 'license':"Apache License 2.0", 36 | 'long_description': long_description, 37 | 'long_description_content_type': "text/markdown", 38 | 'keywords': ["nls", "sdk"], 39 | 'url': "https://github.com/..", 40 | 'packages': ["nls", "nls/websocket"], 41 | 'install_requires': requires, 42 | 'classifiers': [ 43 | "Programming Language :: Python :: 3", 44 | "License :: OSI Approved :: Apache Software License", 45 | "Operating System :: OS Independent", 46 | ], 47 | } 48 | 49 | setuptools.setup(name='nls', **setup_args) 50 | -------------------------------------------------------------------------------- /nls/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 
2 | 3 | import logging 4 | 5 | _logger = logging.getLogger('nls') 6 | 7 | try: 8 | from logging import NullHandler 9 | except ImportError: 10 | class NullHandler(logging.Handler): 11 | def emit(self, record): 12 | pass 13 | 14 | _logger.addHandler(NullHandler()) 15 | _traceEnabled = False 16 | __LOG_FORMAT__ = '%(asctime)s - %(levelname)s - %(message)s' 17 | 18 | __all__=['enableTrace', 'dump', 'error', 'warning', 'debug', 'trace', 19 | 'isEnabledForError', 'isEnabledForDebug', 'isEnabledForTrace'] 20 | 21 | def enableTrace(traceable, handler=logging.StreamHandler()): 22 | """ 23 | enable log print 24 | 25 | Parameters 26 | ---------- 27 | traceable: bool 28 | whether enable log print, default log level is logging.DEBUG 29 | handler: Handler object 30 | handle how to print out log, default to stdio 31 | """ 32 | global _traceEnabled 33 | _traceEnabled = traceable 34 | if traceable: 35 | _logger.addHandler(handler) 36 | _logger.setLevel(logging.DEBUG) 37 | handler.setFormatter(logging.Formatter(__LOG_FORMAT__)) 38 | 39 | def dump(title, message): 40 | if _traceEnabled: 41 | _logger.debug('### ' + title + ' ###') 42 | _logger.debug(message) 43 | _logger.debug('########################################') 44 | 45 | def error(msg): 46 | _logger.error(msg) 47 | 48 | def warning(msg): 49 | _logger.warning(msg) 50 | 51 | def debug(msg): 52 | _logger.debug(msg) 53 | 54 | def trace(msg): 55 | if _traceEnabled: 56 | _logger.debug(msg) 57 | 58 | def isEnabledForError(): 59 | return _logger.isEnabledFor(logging.ERROR) 60 | 61 | def isEnabledForDebug(): 62 | return _logger.isEnabledFor(logging.Debug) 63 | 64 | def isEnabledForTrace(): 65 | return _traceEnabled 66 | -------------------------------------------------------------------------------- /nls/websocket/_logging.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | """ 4 | 5 | """ 6 | _logging.py 7 | websocket - WebSocket client library for Python 8 | 9 | Copyright 2021 engn33r 
10 | 11 | Licensed under the Apache License, Version 2.0 (the "License"); 12 | you may not use this file except in compliance with the License. 13 | You may obtain a copy of the License at 14 | 15 | http://www.apache.org/licenses/LICENSE-2.0 16 | 17 | Unless required by applicable law or agreed to in writing, software 18 | distributed under the License is distributed on an "AS IS" BASIS, 19 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | See the License for the specific language governing permissions and 21 | limitations under the License. 22 | """ 23 | import logging 24 | 25 | _logger = logging.getLogger('websocket') 26 | try: 27 | from logging import NullHandler 28 | except ImportError: 29 | class NullHandler(logging.Handler): 30 | def emit(self, record): 31 | pass 32 | 33 | _logger.addHandler(NullHandler()) 34 | 35 | _traceEnabled = False 36 | 37 | __all__ = ["enableTrace", "dump", "error", "warning", "debug", "trace", 38 | "isEnabledForError", "isEnabledForDebug", "isEnabledForTrace"] 39 | 40 | 41 | def enableTrace(traceable, handler=logging.StreamHandler()): 42 | """ 43 | Turn on/off the traceability. 44 | 45 | Parameters 46 | ---------- 47 | traceable: bool 48 | If set to True, traceability is enabled. 
49 | """ 50 | global _traceEnabled 51 | _traceEnabled = traceable 52 | if traceable: 53 | _logger.addHandler(handler) 54 | _logger.setLevel(logging.ERROR) 55 | 56 | 57 | def dump(title, message): 58 | if _traceEnabled: 59 | _logger.debug("--- " + title + " ---") 60 | _logger.debug(message) 61 | _logger.debug("-----------------------") 62 | 63 | 64 | def error(msg): 65 | _logger.error(msg) 66 | 67 | 68 | def warning(msg): 69 | _logger.warning(msg) 70 | 71 | 72 | def debug(msg): 73 | _logger.debug(msg) 74 | 75 | 76 | def trace(msg): 77 | if _traceEnabled: 78 | _logger.debug(msg) 79 | 80 | 81 | def isEnabledForError(): 82 | return _logger.isEnabledFor(logging.ERROR) 83 | 84 | 85 | def isEnabledForDebug(): 86 | return _logger.isEnabledFor(logging.DEBUG) 87 | 88 | 89 | def isEnabledForTrace(): 90 | return _traceEnabled 91 | -------------------------------------------------------------------------------- /nls/websocket/_cookiejar.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | """ 4 | 5 | """ 6 | _cookiejar.py 7 | websocket - WebSocket client library for Python 8 | 9 | Copyright 2021 engn33r 10 | 11 | Licensed under the Apache License, Version 2.0 (the "License"); 12 | you may not use this file except in compliance with the License. 13 | You may obtain a copy of the License at 14 | 15 | http://www.apache.org/licenses/LICENSE-2.0 16 | 17 | Unless required by applicable law or agreed to in writing, software 18 | distributed under the License is distributed on an "AS IS" BASIS, 19 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | See the License for the specific language governing permissions and 21 | limitations under the License. 
import http.cookies


class SimpleCookieJar:
    """
    Minimal per-domain store for cookies parsed from Set-Cookie values.

    Entries are keyed by the cookie's domain attribute, lower-cased and
    prefixed with a dot. Cookies without a domain attribute are not stored.
    """

    def __init__(self):
        # ".domain" (lower case) -> http.cookies.SimpleCookie
        self.jar = dict()

    @staticmethod
    def _jar_key(domain):
        """Return the jar key for *domain*: a leading dot, lower-cased."""
        if not domain.startswith("."):
            domain = "." + domain
        return domain.lower()

    def add(self, set_cookie):
        """
        Merge the cookies parsed from *set_cookie* into the jar.

        For every parsed cookie that carries a domain attribute, all cookies
        parsed from *set_cookie* are merged into that domain's entry.
        """
        if not set_cookie:
            return
        simple_cookie = http.cookies.SimpleCookie(set_cookie)
        for morsel in simple_cookie.values():
            domain = morsel.get("domain")
            if not domain:
                continue
            key = self._jar_key(domain)
            # Look up with the same normalised key used for storage so a
            # mixed-case domain merges with the stored entry instead of
            # replacing it.
            cookie = self.jar.get(key) or http.cookies.SimpleCookie()
            cookie.update(simple_cookie)
            self.jar[key] = cookie

    def set(self, set_cookie):
        """Replace each named domain's entry with the cookies parsed from *set_cookie*."""
        if not set_cookie:
            return
        simple_cookie = http.cookies.SimpleCookie(set_cookie)
        for morsel in simple_cookie.values():
            domain = morsel.get("domain")
            if domain:
                self.jar[self._jar_key(domain)] = simple_cookie

    def get(self, host):
        """
        Return the cookies applicable to *host* as one "; "-joined string.

        A stored key ".example.com" matches "example.com" and any host ending
        in ".example.com". Pairs are formatted "name=value" and sorted.
        Returns "" when *host* is empty or nothing matches.
        """
        if not host:
            return ""

        host = host.lower()
        pairs = sorted(
            "%s=%s" % (name, morsel.value)
            for domain, cookie in self.jar.items()
            if cookie and (host.endswith(domain) or host == domain[1:])
            for name, morsel in cookie.items()
        )
        return "; ".join(filter(None, pairs))

# ------------------------------------------------------------------------------
# /tests/test_sr.py
# ------------------------------------------------------------------------------
import time
import threading
import sys

import nls
from tests.test_utils import (TEST_ACCESS_TOKEN, TEST_ACCESS_APPKEY)


class TestSr:
    """Run one NlsSpeechRecognizer session over a PCM file in a background thread."""

    def __init__(self, tid, test_file):
        self.__th = threading.Thread(target=self.__test_run)
        self.__id = tid
        self.__test_file = test_file

    def loadfile(self, filename):
        with open(filename, 'rb') as f:
            self.__data = f.read()

    def start(self):
        self.loadfile(self.__test_file)
        self.__th.start()

    def test_on_start(self, message, *args):
        print('test_on_start:{}'.format(message))

    def test_on_error(self, message, *args):
        # Include the error payload; args only carries callback_args.
        print('on_error message=>{} args=>{}'.format(message, args))

    def test_on_close(self, *args):
        print('on_close: args=>{}'.format(args))

    def test_on_result_chg(self, message, *args):
        print('test_on_chg:{}'.format(message))

    def test_on_completed(self, message, *args):
        print('on_completed:args=>{} message=>{}'.format(args, message))

    def __test_run(self):
        print('thread:{} start..'.format(self.__id))

        sr = nls.NlsSpeechRecognizer(
            token=TEST_ACCESS_TOKEN,
            appkey=TEST_ACCESS_APPKEY,
            on_start=self.test_on_start,
            on_result_changed=self.test_on_result_chg,
            on_completed=self.test_on_completed,
            on_error=self.test_on_error,
            on_close=self.test_on_close,
            callback_args=[self.__id]
        )
        print("{}: session start".format(self.__id))
        sr.start(aformat="pcm", ex={"hello": 123})

        # Send 640-byte frames. Slicing also sends the final short frame,
        # which zip(*[iter(data)] * 640) would drop.
        for offset in range(0, len(self.__data), 640):
            sr.send_audio(self.__data[offset:offset + 640])
            time.sleep(0.01)

        r = sr.stop()
        print("{}: sr stopped:{}".format(self.__id, r))
        time.sleep(1)


def multiruntest(num=500):
    """Start *num* concurrent recognition sessions on tests/test1.pcm."""
    for i in range(num):
        t = TestSr('thread' + str(i), 'tests/test1.pcm')
        t.start()


# Guarded so that importing this module (e.g. pytest collecting test_*.py)
# does not open network sessions.
if __name__ == "__main__":
    nls.enableTrace(True)
    multiruntest(1)

# ------------------------------------------------------------------------------
# /tests/test_tts.py
# ------------------------------------------------------------------------------
import time
import threading
import sys

import nls
from tests.test_utils import (TEST_ACCESS_TOKEN, TEST_ACCESS_APPKEY)


TEXT='大壮正想去摘取花瓣,谁知阿丽和阿强突然内讧,阿丽拿去手枪向树干边的阿强射击,两声枪响,阿强直接倒入水中'


class TestTts:
    """Run one NlsSpeechSynthesizer session in a background thread, writing audio to a file."""

    def __init__(self, tid, test_file):
        self.__th = threading.Thread(target=self.__test_run)
        self.__id = tid
        self.__test_file = test_file

    def start(self, text):
        self.__text = text
        # Closed in test_on_close.
        self.__f = open(self.__test_file, "wb")
        self.__th.start()

    def test_on_metainfo(self, message, *args):
        print("on_metainfo message=>{}".format(message))

    def test_on_error(self, message, *args):
        # Include the error payload; args only carries callback_args.
        print("on_error message=>{} args=>{}".format(message, args))

    def test_on_close(self, *args):
        print("on_close: args=>{}".format(args))
        try:
            self.__f.close()
        except Exception as e:
            print("close file failed since:", e)

    def test_on_data(self, data, *args):
        try:
            self.__f.write(data)
        except Exception as e:
            print("write data failed:", e)

    def test_on_completed(self, message, *args):
        print("on_completed:args=>{} message=>{}".format(args, message))

    def __test_run(self):
        print("thread:{} start..".format(self.__id))
        tts = nls.NlsSpeechSynthesizer(
            token=TEST_ACCESS_TOKEN,
            appkey=TEST_ACCESS_APPKEY,
            long_tts=False,
            on_metainfo=self.test_on_metainfo,
            on_data=self.test_on_data,
            on_completed=self.test_on_completed,
            on_error=self.test_on_error,
            on_close=self.test_on_close,
            callback_args=[self.__id]
        )

        print("{}: session start".format(self.__id))
        r = tts.start(self.__text, voice="ailun", ex={'enable_subtitle': True})
        print("{}: tts done with result:{}".format(self.__id, r))


def multiruntest(num=500):
    """Start *num* concurrent synthesis sessions writing to tests/test_tts.pcm."""
    for i in range(num):
        t = TestTts("thread" + str(i), "tests/test_tts.pcm")
        t.start(TEXT)


# Guarded so that importing this module does not open network sessions.
if __name__ == "__main__":
    nls.enableTrace(True)
    multiruntest(1)

# ------------------------------------------------------------------------------
# /nls/websocket/_exceptions.py
# ------------------------------------------------------------------------------
# Define WebSocket exceptions
#
# _exceptions.py
# websocket - WebSocket client
# library for Python
#
# Copyright 2021 engn33r
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class WebSocketException(Exception):
    """
    WebSocket exception class.
    """
    pass


class WebSocketProtocolException(WebSocketException):
    """
    If the WebSocket protocol is invalid, this exception will be raised.
    """
    pass


class WebSocketPayloadException(WebSocketException):
    """
    If the WebSocket payload is invalid, this exception will be raised.
    """
    pass


class WebSocketConnectionClosedException(WebSocketException):
    """
    If remote host closed the connection or some network error happened,
    this exception will be raised.
    """
    pass


class WebSocketTimeoutException(WebSocketException):
    """
    WebSocketTimeoutException will be raised at socket timeout during read/write data.
    """
    pass


class WebSocketProxyException(WebSocketException):
    """
    WebSocketProxyException will be raised when proxy error occurred.
    """
    pass


class WebSocketBadStatusException(WebSocketException):
    """
    WebSocketBadStatusException will be raised when we get bad handshake status code.

    Parameters
    ----------
    message: str
        %-format string with two placeholders; it is filled with
        (status_code, status_message) to build the exception text.
    status_code: int
        Status code of the handshake response.
    status_message: str
        Reason phrase of the response, if known. Kept as an attribute so
        handlers do not have to parse it back out of the message.
    resp_headers: dict
        Response headers, if available.
    """

    def __init__(self, message, status_code, status_message=None, resp_headers=None):
        msg = message % (status_code, status_message)
        super().__init__(msg)
        self.status_code = status_code
        self.status_message = status_message
        self.resp_headers = resp_headers


class WebSocketAddressException(WebSocketException):
    """
    If the websocket address info cannot be found, this exception will be raised.
    """
    pass

# ------------------------------------------------------------------------------
# /tests/test_realtime_meeting.py
# ------------------------------------------------------------------------------
import time
import threading
import sys
import nls


class TestRealtimeMeeting:
    """Run one NlsRealtimeMeeting session over a PCM file in a background thread."""

    def __init__(self, tid, test_file, url):
        self.__th = threading.Thread(target=self.__test_run)
        self.__id = tid
        self.__test_file = test_file
        self.__url = url

    def loadfile(self, filename):
        with open(filename, "rb") as f:
            self.__data = f.read()

    def start(self):
        self.loadfile(self.__test_file)
        self.__th.start()

    def test_on_sentence_begin(self, message, *args):
        print("test_on_sentence_begin:{}".format(message))

    def test_on_sentence_end(self, message, *args):
        print("test_on_sentence_end:{}".format(message))

    def test_on_start(self, message, *args):
        print("test_on_start:{}".format(message))

    def test_on_error(self, message, *args):
        print("on_error message=>{} args=>{}".format(message, args))

    def test_on_close(self, *args):
        print("on_close: args=>{}".format(args))

    def test_on_result_chg(self, message, *args):
        print("test_on_chg:{}".format(message))

    def test_on_result_translated(self, message, *args):
        print("test_on_translated:{}".format(message))

    def test_on_completed(self, message, *args):
        print("on_completed:args=>{} message=>{}".format(args, message))

    def __test_run(self):
        print("thread:{} start..".format(self.__id))
        rm = nls.NlsRealtimeMeeting(
            url=self.__url,
            on_sentence_begin=self.test_on_sentence_begin,
            on_sentence_end=self.test_on_sentence_end,
            on_start=self.test_on_start,
            on_result_changed=self.test_on_result_chg,
            on_result_translated=self.test_on_result_translated,
            on_completed=self.test_on_completed,
            on_error=self.test_on_error,
            on_close=self.test_on_close,
            callback_args=[self.__id]
        )

        print("{}: session start".format(self.__id))
        rm.start()

        # Send 640-byte frames. Slicing also sends the final short frame,
        # which zip(*[iter(data)] * 640) would drop.
        for offset in range(0, len(self.__data), 640):
            rm.send_audio(self.__data[offset:offset + 640])
            time.sleep(0.01)

        time.sleep(1)

        r = rm.stop()
        print("{}: rm stopped:{}".format(self.__id, r))
        time.sleep(5)


def multiruntest(num=1):
    """Start *num* concurrent meeting sessions on tests/test1.pcm."""
    for i in range(num):
        t = TestRealtimeMeeting("thread" + str(i), "tests/test1.pcm", "wss://tingwu-realtime-cn-hangzhou-pre.aliyuncs.com/api/ws/v1?")
        t.start()


# Guarded so that importing this module does not open network sessions.
if __name__ == "__main__":
    nls.enableTrace(True)
    multiruntest(1)

# ------------------------------------------------------------------------------
# /tests/test_st.py
# ------------------------------------------------------------------------------
import time
import threading
import sys

import nls
from tests.test_utils import (TEST_ACCESS_TOKEN, TEST_ACCESS_APPKEY)


class TestSt:
    """Run one NlsSpeechTranscriber session over a PCM file in a background thread."""

    def __init__(self, tid, test_file):
        self.__th = threading.Thread(target=self.__test_run)
        self.__id = tid
        self.__test_file = test_file

    def loadfile(self, filename):
        with open(filename, "rb") as f:
            self.__data = f.read()

    def start(self):
        self.loadfile(self.__test_file)
        self.__th.start()

    def test_on_sentence_begin(self, message, *args):
        print("test_on_sentence_begin:{}".format(message))

    def test_on_sentence_end(self, message, *args):
        print("test_on_sentence_end:{}".format(message))

    def test_on_start(self, message, *args):
        print("test_on_start:{}".format(message))

    def test_on_error(self, message, *args):
        # Include the error payload; args only carries callback_args.
        print("on_error message=>{} args=>{}".format(message, args))

    def test_on_close(self, *args):
        print("on_close: args=>{}".format(args))

    def test_on_result_chg(self, message, *args):
        print("test_on_chg:{}".format(message))

    def test_on_completed(self, message, *args):
        print("on_completed:args=>{} message=>{}".format(args, message))

    def __test_run(self):
        print("thread:{} start..".format(self.__id))
        sr = nls.NlsSpeechTranscriber(
            token=TEST_ACCESS_TOKEN,
            appkey=TEST_ACCESS_APPKEY,
            on_sentence_begin=self.test_on_sentence_begin,
            on_sentence_end=self.test_on_sentence_end,
            on_start=self.test_on_start,
            on_result_changed=self.test_on_result_chg,
            on_completed=self.test_on_completed,
            on_error=self.test_on_error,
            on_close=self.test_on_close,
            callback_args=[self.__id]
        )
        print("{}: session start".format(self.__id))
        sr.start(aformat="pcm",
                 enable_intermediate_result=True,
                 enable_punctuation_prediction=True,
                 enable_inverse_text_normalization=True)

        # Send 640-byte frames. Slicing also sends the final short frame,
        # which zip(*[iter(data)] * 640) would drop.
        for offset in range(0, len(self.__data), 640):
            sr.send_audio(self.__data[offset:offset + 640])
            time.sleep(0.01)

        sr.ctrl(ex={"test": "tttt"})
        time.sleep(1)

        r = sr.stop()
        print("{}: sr stopped:{}".format(self.__id, r))
        time.sleep(5)


def multiruntest(num=500):
    """Start *num* concurrent transcription sessions on tests/test1.pcm."""
    for i in range(num):
        t = TestSt("thread" + str(i), "tests/test1.pcm")
        t.start()


# Guarded so that importing this module does not open network sessions.
if __name__ == "__main__":
    nls.enableTrace(True)
    multiruntest(1)

# ------------------------------------------------------------------------------
# /nls/websocket/_utils.py
# ------------------------------------------------------------------------------
# _utils.py
# websocket - WebSocket client library for Python
#
# Copyright 2021 engn33r
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"]


class NoLock:
    """Context manager whose enter/exit do nothing; usable in place of a lock in a ``with``."""

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        pass


try:
    # If wsaccel is available we use compiled routines to validate UTF-8
    # strings.
    from wsaccel.utf8validator import Utf8Validator

    def _validate_utf8(utfbytes):
        # validate() returns (valid_so_far, ends_on_code_point, ...); a
        # complete UTF-8 string needs both.
        result = Utf8Validator().validate(utfbytes)
        return result[0] and result[1]

except ImportError:
    # UTF-8 validator
    # python implementation of http://bjoern.hoehrmann.de/utf-8/decoder/dfa/

    _UTF8_ACCEPT = 0
    _UTF8_REJECT = 12

    _UTF8D = [
        # The first part of the table maps bytes to character classes that
        # to reduce the size of the transition table and create bitmasks.
        0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
        0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
        0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
        0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
        1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,
        7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,
        8,8,2,2,2,2,2,2,2,2,2,2,2,2,2,2, 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
        10,3,3,3,3,3,3,3,3,3,3,3,3,4,3,3, 11,6,6,6,5,8,8,8,8,8,8,8,8,8,8,8,

        # The second part is a transition table that maps a combination
        # of a state of the automaton and a character class to a state.
        0,12,24,36,60,96,84,12,12,12,48,72, 12,12,12,12,12,12,12,12,12,12,12,12,
        12, 0,12,12,12,12,12, 0,12, 0,12,12, 12,24,12,12,12,12,12,24,12,24,12,12,
        12,12,12,12,12,12,12,24,12,12,12,12, 12,24,12,12,12,12,12,12,12,24,12,12,
        12,12,12,12,12,12,12,36,12,36,12,12, 12,36,12,12,12,12,12,36,12,36,12,12,
        12,36,12,12,12,12,12,12,12,12,12,12, ]

    def _decode(state, codep, ch):
        # Advance the DFA by one byte; returns the new (state, code point).
        tp = _UTF8D[ch]

        codep = (ch & 0x3f) | (codep << 6) if (
            state != _UTF8_ACCEPT) else (0xff >> tp) & ch
        state = _UTF8D[256 + state + tp]

        return state, codep

    def _validate_utf8(utfbytes):
        state = _UTF8_ACCEPT
        codep = 0
        for i in utfbytes:
            state, codep = _decode(state, codep, i)
            if state == _UTF8_REJECT:
                return False

        # Input that stops part-way through a multi-byte sequence is not
        # valid UTF-8, so only the ACCEPT state counts as success.
        return state == _UTF8_ACCEPT


def validate_utf8(utfbytes):
    """
    validate utf8 byte string.
    utfbytes: utf byte string to check.
    return value: True if utfbytes is a complete, valid UTF-8 byte string
    (a trailing truncated multi-byte sequence is rejected), otherwise False.
    """
    return _validate_utf8(utfbytes)


def extract_err_message(exception):
    """Return ``exception.args[0]``, or None if the exception has no args."""
    if exception.args:
        return exception.args[0]
    else:
        return None


def extract_error_code(exception):
    """
    Return the integer error code of *exception*, or None.

    Only exceptions with at least two args (such as OSError's
    ``(errno, strerror)``) whose first arg is an int yield a code.
    """
    if len(exception.args) > 1 and isinstance(exception.args[0], int):
        return exception.args[0]
    return None

# ------------------------------------------------------------------------------
# /nls/websocket/tests/test_abnf.py
# ------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
#
# test_abnf.py
# websocket - WebSocket client library for Python
#
# Copyright 2021 engn33r
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import websocket as ws
from websocket._abnf import *
import unittest


class ABNFTest(unittest.TestCase):
    """Unit tests for ABNF frame construction, validation, masking and formatting."""

    def testInit(self):
        a = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING)
        self.assertEqual(a.fin, 0)
        self.assertEqual(a.rsv1, 0)
        self.assertEqual(a.rsv2, 0)
        self.assertEqual(a.rsv3, 0)
        self.assertEqual(a.opcode, 9)
        self.assertEqual(a.data, '')
        a_bad = ABNF(0, 1, 0, 0, opcode=77)
        self.assertEqual(a_bad.rsv1, 1)
        self.assertEqual(a_bad.opcode, 77)

    def testValidate(self):
        a_invalid_ping = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING)
        self.assertRaises(ws._exceptions.WebSocketProtocolException, a_invalid_ping.validate, skip_utf8_validation=False)
        a_bad_rsv_value = ABNF(0, 1, 0, 0, opcode=ABNF.OPCODE_TEXT)
        self.assertRaises(ws._exceptions.WebSocketProtocolException, a_bad_rsv_value.validate, skip_utf8_validation=False)
        a_bad_opcode = ABNF(0, 0, 0, 0, opcode=77)
        self.assertRaises(ws._exceptions.WebSocketProtocolException, a_bad_opcode.validate, skip_utf8_validation=False)
        a_bad_close_frame = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b'\x01')
        self.assertRaises(ws._exceptions.WebSocketProtocolException, a_bad_close_frame.validate, skip_utf8_validation=False)
        a_bad_close_frame_2 = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b'\x01\x8a\xaa\xff\xdd')
        self.assertRaises(ws._exceptions.WebSocketProtocolException, a_bad_close_frame_2.validate, skip_utf8_validation=False)
        a_bad_close_frame_3 = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b'\x03\xe7')
        self.assertRaises(ws._exceptions.WebSocketProtocolException, a_bad_close_frame_3.validate, skip_utf8_validation=True)

    def testMask(self):
        abnf_none_data = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING, mask=1, data=None)
        bytes_val = bytes("aaaa", 'utf-8')
        self.assertEqual(abnf_none_data._get_masked(bytes_val), bytes_val)
        abnf_str_data = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING, mask=1, data="a")
        self.assertEqual(abnf_str_data._get_masked(bytes_val), b'aaaa\x00')

    def testFormat(self):
        abnf_bad_rsv_bits = ABNF(2, 0, 0, 0, opcode=ABNF.OPCODE_TEXT)
        self.assertRaises(ValueError, abnf_bad_rsv_bits.format)
        abnf_bad_opcode = ABNF(0, 0, 0, 0, opcode=5)
        self.assertRaises(ValueError, abnf_bad_opcode.format)
        abnf_length_10 = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_TEXT, data="abcdefghij")
        self.assertEqual(b'\x01', abnf_length_10.format()[0].to_bytes(1, 'big'))
        self.assertEqual(b'\x8a', abnf_length_10.format()[1].to_bytes(1, 'big'))
        self.assertEqual("fin=0 opcode=1 data=abcdefghij", abnf_length_10.__str__())
        abnf_length_20 = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_BINARY, data="abcdefghijabcdefghij")
        self.assertEqual(b'\x02', abnf_length_20.format()[0].to_bytes(1, 'big'))
        self.assertEqual(b'\x94', abnf_length_20.format()[1].to_bytes(1, 'big'))
        abnf_no_mask = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_TEXT, mask=0, data=b'\x01\x8a\xcc')
        self.assertEqual(b'\x01\x03\x01\x8a\xcc', abnf_no_mask.format())

    def testFrameBuffer(self):
        fb = frame_buffer(0, True)
        self.assertEqual(fb.recv, 0)
        self.assertEqual(fb.skip_utf8_validation, True)
        # Actually invoke clear(); a bare ``fb.clear`` only references the
        # method, so the assertions below never exercised it.
        fb.clear()
        self.assertEqual(fb.header, None)
        self.assertEqual(fb.length, None)
        self.assertEqual(fb.mask, None)
        self.assertEqual(fb.has_mask(), False)


if __name__ == "__main__":
    unittest.main()

# ------------------------------------------------------------------------------
# /nls/websocket/tests/test_cookiejar.py
# ------------------------------------------------------------------------------
#
# test_cookiejar.py
# websocket - WebSocket client library for Python
#
# Copyright 2021 engn33r
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from websocket._cookiejar import SimpleCookieJar


class CookieJarTest(unittest.TestCase):
    """Unit tests for SimpleCookieJar add/set/get domain handling."""

    def testAdd(self):
        cookie_jar = SimpleCookieJar()
        cookie_jar.add("")
        self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")

        cookie_jar = SimpleCookieJar()
        cookie_jar.add("a=b")
        self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")

        cookie_jar = SimpleCookieJar()
        cookie_jar.add("a=b; domain=.abc")
        self.assertIn(".abc", cookie_jar.jar)

        cookie_jar = SimpleCookieJar()
        cookie_jar.add("a=b; domain=abc")
        self.assertIn(".abc", cookie_jar.jar)
        self.assertNotIn("abc", cookie_jar.jar)

        cookie_jar = SimpleCookieJar()
        cookie_jar.add("a=b; c=d; domain=abc")
        self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
        self.assertEqual(cookie_jar.get(None), "")

        cookie_jar = SimpleCookieJar()
        cookie_jar.add("a=b; c=d; domain=abc")
        cookie_jar.add("e=f; domain=abc")
        self.assertEqual(cookie_jar.get("abc"), "a=b; c=d; e=f")

        cookie_jar = SimpleCookieJar()
        cookie_jar.add("a=b; c=d; domain=abc")
        cookie_jar.add("e=f; domain=.abc")
        self.assertEqual(cookie_jar.get("abc"), "a=b; c=d; e=f")

        cookie_jar = SimpleCookieJar()
        cookie_jar.add("a=b; c=d; domain=abc")
        cookie_jar.add("e=f; domain=xyz")
        self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
        self.assertEqual(cookie_jar.get("xyz"), "e=f")
        self.assertEqual(cookie_jar.get("something"), "")

    def testSet(self):
        cookie_jar = SimpleCookieJar()
        cookie_jar.set("a=b")
        self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")

        cookie_jar = SimpleCookieJar()
        cookie_jar.set("a=b; domain=.abc")
        self.assertIn(".abc", cookie_jar.jar)

        cookie_jar = SimpleCookieJar()
        cookie_jar.set("a=b; domain=abc")
        self.assertIn(".abc", cookie_jar.jar)
        self.assertNotIn("abc", cookie_jar.jar)

        cookie_jar = SimpleCookieJar()
        cookie_jar.set("a=b; c=d; domain=abc")
        self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")

        cookie_jar = SimpleCookieJar()
        cookie_jar.set("a=b; c=d; domain=abc")
        cookie_jar.set("e=f; domain=abc")
        self.assertEqual(cookie_jar.get("abc"), "e=f")

        cookie_jar = SimpleCookieJar()
        cookie_jar.set("a=b; c=d; domain=abc")
        cookie_jar.set("e=f; domain=.abc")
        self.assertEqual(cookie_jar.get("abc"), "e=f")

        cookie_jar = SimpleCookieJar()
        cookie_jar.set("a=b; c=d; domain=abc")
        cookie_jar.set("e=f; domain=xyz")
        self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
        self.assertEqual(cookie_jar.get("xyz"), "e=f")
        self.assertEqual(cookie_jar.get("something"), "")

    def testGet(self):
        cookie_jar = SimpleCookieJar()
        cookie_jar.set("a=b; c=d; domain=abc.com")
        self.assertEqual(cookie_jar.get("abc.com"), "a=b; c=d")
        self.assertEqual(cookie_jar.get("x.abc.com"), "a=b; c=d")
        self.assertEqual(cookie_jar.get("abc.com.es"), "")
        self.assertEqual(cookie_jar.get("xabc.com"), "")

        cookie_jar.set("a=b; c=d; domain=.abc.com")
        self.assertEqual(cookie_jar.get("abc.com"), "a=b; c=d")
        self.assertEqual(cookie_jar.get("x.abc.com"), "a=b; c=d")
        self.assertEqual(cookie_jar.get("abc.com.es"), "")
        self.assertEqual(cookie_jar.get("xabc.com"), "")


if __name__ == "__main__":
    unittest.main()

# ------------------------------------------------------------------------------
# /nls/websocket/_url.py
# ------------------------------------------------------------------------------
#
# _url.py
# websocket - WebSocket client library for Python
#
# Copyright 2021 engn33r
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import socket
import struct

from urllib.parse import unquote, urlparse


__all__ = ["parse_url", "get_proxy_info"]


def parse_url(url):
    """
    parse url and the result is tuple of
    (hostname, port, resource path and the flag of secure mode)

    Parameters
    ----------
    url: str
        url string. The scheme must be "ws" or "wss" (case-insensitive).

    Raises
    ------
    ValueError
        If the url has no scheme separator, no hostname, or an unsupported scheme.
    """
    if ":" not in url:
        raise ValueError("url is invalid")

    scheme, url = url.split(":", 1)

    parsed = urlparse(url, scheme="http")
    if parsed.hostname:
        hostname = parsed.hostname
    else:
        raise ValueError("hostname is invalid")
    port = parsed.port or 0

    # URI schemes are case-insensitive (RFC 3986 section 3.1).
    normalized_scheme = scheme.lower()
    is_secure = False
    if normalized_scheme == "ws":
        if not port:
            port = 80
    elif normalized_scheme == "wss":
        is_secure = True
        if not port:
            port = 443
    else:
        raise ValueError("scheme %s is invalid" % scheme)

    resource = parsed.path or "/"
    if parsed.query:
        resource += "?" + parsed.query

    return hostname, port, resource, is_secure


DEFAULT_NO_PROXY_HOST = ["localhost", "127.0.0.1"]


def _is_ip_address(addr):
    """Return True if socket.inet_aton accepts *addr* as an IPv4 address."""
    try:
        socket.inet_aton(addr)
    except socket.error:
        return False
    else:
        return True


def _is_subnet_address(hostname):
    """Return True for "a.b.c.d/n" with a valid IPv4 address and 0 <= n <= 32."""
    try:
        addr, netmask = hostname.split("/")
        # /32 (a single host) is a valid prefix length.
        return _is_ip_address(addr) and 0 <= int(netmask) <= 32
    except ValueError:
        return False


def _is_address_in_network(ip, net):
    """Return True if IPv4 *ip* lies inside the CIDR network *net*."""
    ipaddr = struct.unpack('!I', socket.inet_aton(ip))[0]
    netaddr, netmask = net.split('/')
    netaddr = struct.unpack('!I', socket.inet_aton(netaddr))[0]

    netmask = (0xFFFFFFFF << (32 - int(netmask))) & 0xFFFFFFFF
    # Mask both sides so a non-canonical spec such as "10.0.0.5/8" still
    # matches every 10.x.x.x address.
    return ipaddr & netmask == netaddr & netmask


def _is_no_proxy_host(hostname, no_proxy):
    """
    Decide whether *hostname* should bypass the proxy.

    *no_proxy* defaults to the no_proxy / NO_PROXY environment variable, then
    to DEFAULT_NO_PROXY_HOST. Entries may be "*", exact hosts, IPv4 CIDR
    subnets (checked only for IP hostnames), or ".suffix" domains.
    """
    if not no_proxy:
        v = os.environ.get("no_proxy", os.environ.get("NO_PROXY", "")).replace(" ", "")
        if v:
            no_proxy = v.split(",")
    if not no_proxy:
        no_proxy = DEFAULT_NO_PROXY_HOST

    if '*' in no_proxy:
        return True
    if hostname in no_proxy:
        return True
    if _is_ip_address(hostname):
        return any(_is_address_in_network(hostname, subnet)
                   for subnet in no_proxy if _is_subnet_address(subnet))
    return any(hostname.endswith(domain)
               for domain in no_proxy if domain.startswith('.'))


def get_proxy_info(
        hostname, is_secure, proxy_host=None, proxy_port=0, proxy_auth=None,
        no_proxy=None, proxy_type='http'):
    """
    Try to retrieve proxy host and port from environment
    if not provided in options.
    Result is (proxy_host, proxy_port, proxy_auth).
    proxy_auth is tuple of username and password
    of proxy authentication information.

    Parameters
    ----------
    hostname: str
        Websocket server name.
    is_secure: bool
        Is the connection secure? (wss) looks for "https_proxy" in env
        before falling back to "http_proxy"
    proxy_host: str
        http proxy host name.
    proxy_port: str or int
        http proxy port.
    proxy_auth: tuple
        HTTP proxy auth information. Tuple of username and password. Default is None.
    no_proxy: list
        Whitelisted host names that don't use the proxy.
    proxy_type: str
        Specify the proxy protocol (http, socks4, socks4a, socks5, socks5h). Default is "http".
        Use socks4a or socks5h if you want to send DNS requests through the proxy.
        Not consulted by this function.

    Returns
    -------
    tuple
        (proxy_host, proxy_port, proxy_auth). (None, 0, None) when no proxy
        applies. For a proxy taken from the environment, proxy_port is None
        if the URL has no explicit port, and the password is "" if the URL
        has a user name but no password.
    """
    if _is_no_proxy_host(hostname, no_proxy):
        return None, 0, None

    if proxy_host:
        return proxy_host, proxy_port, proxy_auth

    env_keys = ["http_proxy"]
    if is_secure:
        env_keys.insert(0, "https_proxy")

    for key in env_keys:
        value = os.environ.get(key, os.environ.get(key.upper(), "")).replace(" ", "")
        if value:
            proxy = urlparse(value)
            if proxy.username:
                # unquote(None) raises TypeError, so a URL like
                # "http://user@host:3128" must not pass a None password.
                auth = (unquote(proxy.username), unquote(proxy.password or ""))
            else:
                auth = None
            return proxy.hostname, proxy.port, auth

    return None, 0, None

# ------------------------------------------------------------------------------
# /nls/websocket/_socket.py
# ------------------------------------------------------------------------------
#
# _socket.py
# websocket - WebSocket client library for Python
#
# Copyright 2021 engn33r
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
22 | """ 23 | import errno 24 | import selectors 25 | import socket 26 | 27 | from ._exceptions import * 28 | from ._ssl_compat import * 29 | from ._utils import * 30 | 31 | DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)] 32 | #if hasattr(socket, "SO_KEEPALIVE"): 33 | # DEFAULT_SOCKET_OPTION.append((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)) 34 | #if hasattr(socket, "TCP_KEEPIDLE"): 35 | # DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPIDLE, 30)) 36 | #if hasattr(socket, "TCP_KEEPINTVL"): 37 | # DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPINTVL, 10)) 38 | #if hasattr(socket, "TCP_KEEPCNT"): 39 | # DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPCNT, 3)) 40 | 41 | _default_timeout = None 42 | 43 | __all__ = ["DEFAULT_SOCKET_OPTION", "sock_opt", "setdefaulttimeout", "getdefaulttimeout", 44 | "recv", "recv_line", "send"] 45 | 46 | 47 | class sock_opt: 48 | 49 | def __init__(self, sockopt, sslopt): 50 | if sockopt is None: 51 | sockopt = [] 52 | if sslopt is None: 53 | sslopt = {} 54 | self.sockopt = sockopt 55 | self.sslopt = sslopt 56 | self.timeout = None 57 | 58 | 59 | def setdefaulttimeout(timeout): 60 | """ 61 | Set the global timeout setting to connect. 62 | 63 | Parameters 64 | ---------- 65 | timeout: int or float 66 | default socket timeout time (in seconds) 67 | """ 68 | global _default_timeout 69 | _default_timeout = timeout 70 | 71 | 72 | def getdefaulttimeout(): 73 | """ 74 | Get default timeout 75 | 76 | Returns 77 | ---------- 78 | _default_timeout: int or float 79 | Return the global timeout setting (in seconds) to connect. 
80 | """ 81 | return _default_timeout 82 | 83 | 84 | def recv(sock, bufsize): 85 | if not sock: 86 | raise WebSocketConnectionClosedException("socket is already closed.") 87 | 88 | def _recv(): 89 | try: 90 | return sock.recv(bufsize) 91 | except SSLWantReadError: 92 | pass 93 | except socket.error as exc: 94 | error_code = extract_error_code(exc) 95 | if error_code is None: 96 | raise 97 | if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK: 98 | raise 99 | 100 | sel = selectors.DefaultSelector() 101 | sel.register(sock, selectors.EVENT_READ) 102 | 103 | r = sel.select(sock.gettimeout()) 104 | sel.close() 105 | 106 | if r: 107 | return sock.recv(bufsize) 108 | 109 | try: 110 | if sock.gettimeout() == 0: 111 | bytes_ = sock.recv(bufsize) 112 | else: 113 | bytes_ = _recv() 114 | except socket.timeout as e: 115 | message = extract_err_message(e) 116 | raise WebSocketTimeoutException(message) 117 | except SSLError as e: 118 | message = extract_err_message(e) 119 | if isinstance(message, str) and 'timed out' in message: 120 | raise WebSocketTimeoutException(message) 121 | else: 122 | raise 123 | 124 | if not bytes_: 125 | raise WebSocketConnectionClosedException( 126 | "Connection to remote host was lost.") 127 | 128 | return bytes_ 129 | 130 | 131 | def recv_line(sock): 132 | line = [] 133 | while True: 134 | c = recv(sock, 1) 135 | line.append(c) 136 | if c == b'\n': 137 | break 138 | return b''.join(line) 139 | 140 | 141 | def send(sock, data): 142 | if isinstance(data, str): 143 | data = data.encode('utf-8') 144 | 145 | if not sock: 146 | raise WebSocketConnectionClosedException("socket is already closed.") 147 | 148 | def _send(): 149 | try: 150 | return sock.send(data) 151 | except SSLWantWriteError: 152 | pass 153 | except socket.error as exc: 154 | error_code = extract_error_code(exc) 155 | if error_code is None: 156 | raise 157 | if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK: 158 | raise 159 | 160 | sel = 
selectors.DefaultSelector() 161 | sel.register(sock, selectors.EVENT_WRITE) 162 | 163 | w = sel.select(sock.gettimeout()) 164 | sel.close() 165 | 166 | if w: 167 | return sock.send(data) 168 | 169 | try: 170 | if sock.gettimeout() == 0: 171 | return sock.send(data) 172 | else: 173 | return _send() 174 | except socket.timeout as e: 175 | message = extract_err_message(e) 176 | raise WebSocketTimeoutException(message) 177 | except Exception as e: 178 | message = extract_err_message(e) 179 | if isinstance(message, str) and "timed out" in message: 180 | raise WebSocketTimeoutException(message) 181 | else: 182 | raise 183 | -------------------------------------------------------------------------------- /nls/websocket/_handshake.py: -------------------------------------------------------------------------------- 1 | """ 2 | _handshake.py 3 | websocket - WebSocket client library for Python 4 | 5 | Copyright 2021 engn33r 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); 8 | you may not use this file except in compliance with the License. 9 | You may obtain a copy of the License at 10 | 11 | http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | Unless required by applicable law or agreed to in writing, software 14 | distributed under the License is distributed on an "AS IS" BASIS, 15 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | See the License for the specific language governing permissions and 17 | limitations under the License. 18 | """ 19 | import hashlib 20 | import hmac 21 | import os 22 | from base64 import encodebytes as base64encode 23 | from http import client as HTTPStatus 24 | from ._cookiejar import SimpleCookieJar 25 | from ._exceptions import * 26 | from ._http import * 27 | from ._logging import * 28 | from ._socket import * 29 | 30 | __all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"] 31 | 32 | # websocket supported version. 
33 | VERSION = 13 34 | 35 | SUPPORTED_REDIRECT_STATUSES = (HTTPStatus.MOVED_PERMANENTLY, HTTPStatus.FOUND, HTTPStatus.SEE_OTHER,) 36 | SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,) 37 | 38 | CookieJar = SimpleCookieJar() 39 | 40 | 41 | class handshake_response: 42 | 43 | def __init__(self, status, headers, subprotocol): 44 | self.status = status 45 | self.headers = headers 46 | self.subprotocol = subprotocol 47 | CookieJar.add(headers.get("set-cookie")) 48 | 49 | 50 | def handshake(sock, hostname, port, resource, **options): 51 | headers, key = _get_handshake_headers(resource, hostname, port, options) 52 | 53 | header_str = "\r\n".join(headers) 54 | send(sock, header_str) 55 | dump("request header", header_str) 56 | #print("request header:", header_str) 57 | 58 | status, resp = _get_resp_headers(sock) 59 | if status in SUPPORTED_REDIRECT_STATUSES: 60 | return handshake_response(status, resp, None) 61 | success, subproto = _validate(resp, key, options.get("subprotocols")) 62 | if not success: 63 | raise WebSocketException("Invalid WebSocket Header") 64 | 65 | return handshake_response(status, resp, subproto) 66 | 67 | 68 | def _pack_hostname(hostname): 69 | # IPv6 address 70 | if ':' in hostname: 71 | return '[' + hostname + ']' 72 | 73 | return hostname 74 | 75 | 76 | def _get_handshake_headers(resource, host, port, options): 77 | headers = [ 78 | "GET %s HTTP/1.1" % resource, 79 | "Upgrade: websocket" 80 | ] 81 | if port == 80 or port == 443: 82 | hostport = _pack_hostname(host) 83 | else: 84 | hostport = "%s:%d" % (_pack_hostname(host), port) 85 | if "host" in options and options["host"] is not None: 86 | headers.append("Host: %s" % options["host"]) 87 | else: 88 | headers.append("Host: %s" % hostport) 89 | 90 | if "suppress_origin" not in options or not options["suppress_origin"]: 91 | if "origin" in options and options["origin"] is not None: 92 | headers.append("Origin: %s" % options["origin"]) 93 | else: 94 | 
headers.append("Origin: http://%s" % hostport) 95 | 96 | key = _create_sec_websocket_key() 97 | 98 | # Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified 99 | if 'header' not in options or 'Sec-WebSocket-Key' not in options['header']: 100 | key = _create_sec_websocket_key() 101 | headers.append("Sec-WebSocket-Key: %s" % key) 102 | else: 103 | key = options['header']['Sec-WebSocket-Key'] 104 | 105 | if 'header' not in options or 'Sec-WebSocket-Version' not in options['header']: 106 | headers.append("Sec-WebSocket-Version: %s" % VERSION) 107 | 108 | if 'connection' not in options or options['connection'] is None: 109 | headers.append('Connection: Upgrade') 110 | else: 111 | headers.append(options['connection']) 112 | 113 | subprotocols = options.get("subprotocols") 114 | if subprotocols: 115 | headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols)) 116 | 117 | if "header" in options: 118 | header = options["header"] 119 | if isinstance(header, dict): 120 | header = [ 121 | ": ".join([k, v]) 122 | for k, v in header.items() 123 | if v is not None 124 | ] 125 | headers.extend(header) 126 | 127 | server_cookie = CookieJar.get(host) 128 | client_cookie = options.get("cookie", None) 129 | 130 | cookie = "; ".join(filter(None, [server_cookie, client_cookie])) 131 | 132 | if cookie: 133 | headers.append("Cookie: %s" % cookie) 134 | 135 | headers.append("") 136 | headers.append("") 137 | 138 | return headers, key 139 | 140 | 141 | def _get_resp_headers(sock, success_statuses=SUCCESS_STATUSES): 142 | status, resp_headers, status_message = read_headers(sock) 143 | if status not in success_statuses: 144 | raise WebSocketBadStatusException("Handshake status %d %s", status, status_message, resp_headers) 145 | return status, resp_headers 146 | 147 | 148 | _HEADERS_TO_CHECK = { 149 | "upgrade": "websocket", 150 | "connection": "upgrade", 151 | } 152 | 153 | 154 | def _validate(headers, key, subprotocols): 155 | subproto = None 156 | for k, v in 
_HEADERS_TO_CHECK.items(): 157 | r = headers.get(k, None) 158 | if not r: 159 | return False, None 160 | r = [x.strip().lower() for x in r.split(',')] 161 | if v not in r: 162 | return False, None 163 | 164 | if subprotocols: 165 | subproto = headers.get("sec-websocket-protocol", None) 166 | if not subproto or subproto.lower() not in [s.lower() for s in subprotocols]: 167 | error("Invalid subprotocol: " + str(subprotocols)) 168 | return False, None 169 | subproto = subproto.lower() 170 | 171 | result = headers.get("sec-websocket-accept", None) 172 | if not result: 173 | return False, None 174 | result = result.lower() 175 | 176 | if isinstance(result, str): 177 | result = result.encode('utf-8') 178 | 179 | value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8') 180 | hashed = base64encode(hashlib.sha1(value).digest()).strip().lower() 181 | success = hmac.compare_digest(hashed, result) 182 | 183 | if success: 184 | return True, subproto 185 | else: 186 | return False, None 187 | 188 | 189 | def _create_sec_websocket_key(): 190 | randomness = os.urandom(16) 191 | return base64encode(randomness).decode('utf-8').strip() 192 | -------------------------------------------------------------------------------- /nls/core.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | 3 | import logging 4 | import threading 5 | 6 | from enum import Enum, unique 7 | from queue import Queue 8 | 9 | from . 
import logging, token, websocket 10 | from .exception import InvalidParameter, ConnectionTimeout, ConnectionUnavailable 11 | 12 | __URL__ = 'wss://nls-gateway.cn-shanghai.aliyuncs.com/ws/v1' 13 | __HEADER__ = [ 14 | 'Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==', 15 | 'Sec-WebSocket-Version: 13', 16 | ] 17 | 18 | __FORMAT__ = '%(asctime)s - %(levelname)s - %(message)s' 19 | #__all__ = ['NlsCore'] 20 | 21 | def core_on_msg(ws, message, args): 22 | logging.debug('core_on_msg:{}'.format(message)) 23 | if not args: 24 | logging.error('callback core_on_msg with null args') 25 | return 26 | nls = args[0] 27 | nls._NlsCore__issue_callback('on_message', [message]) 28 | 29 | def core_on_error(ws, message, args): 30 | logging.debug('core_on_error:{}'.format(message)) 31 | if not args: 32 | logging.error('callback core_on_error with null args') 33 | return 34 | nls = args[0] 35 | nls._NlsCore__issue_callback('on_error', [message]) 36 | 37 | def core_on_close(ws, close_status_code, close_msg, args): 38 | logging.debug('core_on_close') 39 | if not args: 40 | logging.error('callback core_on_close with null args') 41 | return 42 | nls = args[0] 43 | nls._NlsCore__issue_callback('on_close') 44 | 45 | def core_on_open(ws, args): 46 | logging.debug('core_on_open:{}'.format(args)) 47 | if not args: 48 | logging.debug('callback with null args') 49 | ws.close() 50 | elif len(args) != 2: 51 | logging.debug('callback args not 2') 52 | ws.close() 53 | nls = args[0] 54 | nls._NlsCore__notify_on_open() 55 | nls.start(args[1], nls._NlsCore__ping_interval, nls._NlsCore__ping_timeout) 56 | nls._NlsCore__issue_callback('on_open') 57 | 58 | def core_on_data(ws, data, opcode, flag, args): 59 | logging.debug('core_on_data opcode={}'.format(opcode)) 60 | if not args: 61 | logging.error('callback core_on_data with null args') 62 | return 63 | nls = args[0] 64 | nls._NlsCore__issue_callback('on_data', [data, opcode, flag]) 65 | 66 | @unique 67 | class NlsConnectionStatus(Enum): 68 | Disconnected = 0 
69 | Connected = 1 70 | 71 | 72 | class NlsCore: 73 | """ 74 | NlsCore 75 | """ 76 | def __init__(self, 77 | url=__URL__, 78 | token=None, 79 | on_open=None, on_message=None, on_close=None, 80 | on_error=None, on_data=None, asynch=False, callback_args=[]): 81 | self.__url = url 82 | self.__async = asynch 83 | if not token: 84 | raise InvalidParameter('Must provide a valid token!') 85 | else: 86 | self.__token = token 87 | self.__callbacks = {} 88 | if on_open: 89 | self.__callbacks['on_open'] = on_open 90 | if on_message: 91 | self.__callbacks['on_message'] = on_message 92 | if on_close: 93 | self.__callbacks['on_close'] = on_close 94 | if on_error: 95 | self.__callbacks['on_error'] = on_error 96 | if on_data: 97 | self.__callbacks['on_data'] = on_data 98 | if not on_open and not on_message and not on_close and not on_error: 99 | raise InvalidParameter('Must provide at least one callback') 100 | logging.debug('callback args:{}'.format(callback_args)) 101 | self.__callback_args = callback_args 102 | self.__header = __HEADER__ + ['X-NLS-Token: {}'.format(self.__token)] 103 | websocket.enableTrace(True) 104 | self.__ws = websocket.WebSocketApp(self.__url, 105 | self.__header, 106 | on_message=core_on_msg, 107 | on_data=core_on_data, 108 | on_error=core_on_error, 109 | on_close=core_on_close, 110 | callback_args=[self]) 111 | self.__ws.on_open = core_on_open 112 | self.__lock = threading.Lock() 113 | self.__cond = threading.Condition() 114 | self.__connection_status = NlsConnectionStatus.Disconnected 115 | 116 | def start(self, msg, ping_interval, ping_timeout): 117 | self.__lock.acquire() 118 | self.__ping_interval = ping_interval 119 | self.__ping_timeout = ping_timeout 120 | if self.__connection_status == NlsConnectionStatus.Disconnected: 121 | self.__ws.update_args(self, msg) 122 | self.__lock.release() 123 | self.__connect_before_start(ping_interval, ping_timeout) 124 | else: 125 | self.__lock.release() 126 | self.__ws.send(msg) 127 | 128 | def 
__notify_on_open(self): 129 | logging.debug('notify on open') 130 | with self.__cond: 131 | self.__connection_status = NlsConnectionStatus.Connected 132 | self.__cond.notify() 133 | 134 | def __issue_callback(self, which, exargs=[]): 135 | if which not in self.__callbacks: 136 | logging.error('no such callback:{}'.format(which)) 137 | return 138 | if which is 'on_close': 139 | with self.__cond: 140 | self.__connection_status = NlsConnectionStatus.Disconnected 141 | self.__cond.notify() 142 | args = exargs+self.__callback_args 143 | self.__callbacks[which](*args) 144 | 145 | def send(self, msg, binary): 146 | self.__lock.acquire() 147 | if self.__connection_status == NlsConnectionStatus.Disconnected: 148 | self.__lock.release() 149 | logging.error('start before send') 150 | raise ConnectionUnavailable('Must call start before send!') 151 | else: 152 | self.__lock.release() 153 | if binary: 154 | self.__ws.send(msg, opcode=websocket.ABNF.OPCODE_BINARY) 155 | else: 156 | logging.debug('send {}'.format(msg)) 157 | self.__ws.send(msg) 158 | 159 | def shutdown(self): 160 | self.__ws.close() 161 | 162 | def __run(self, ping_interval, ping_timeout): 163 | logging.debug('ws run...') 164 | self.__ws.run_forever(ping_interval=ping_interval, 165 | ping_timeout=ping_timeout) 166 | with self.__lock: 167 | self.__connection_status = NlsConnectionStatus.Disconnected 168 | logging.debug('ws exit...') 169 | 170 | def __connect_before_start(self, ping_interval, ping_timeout): 171 | with self.__cond: 172 | self.__th = threading.Thread(target=self.__run, 173 | args=[ping_interval, ping_timeout]) 174 | self.__th.start() 175 | if self.__connection_status == NlsConnectionStatus.Disconnected: 176 | logging.debug('wait cond wakeup') 177 | if not self.__async: 178 | if self.__cond.wait(timeout=10): 179 | logging.debug('wakeup without timeout') 180 | return self.__connection_status == NlsConnectionStatus.Connected 181 | else: 182 | logging.debug('wakeup with timeout') 183 | raise 
ConnectionTimeout('Wait response timeout! Please check local network!') 184 | -------------------------------------------------------------------------------- /nls/websocket/tests/test_app.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | """ 4 | test_app.py 5 | websocket - WebSocket client library for Python 6 | 7 | Copyright 2021 engn33r 8 | 9 | Licensed under the Apache License, Version 2.0 (the "License"); 10 | you may not use this file except in compliance with the License. 11 | You may obtain a copy of the License at 12 | 13 | http://www.apache.org/licenses/LICENSE-2.0 14 | 15 | Unless required by applicable law or agreed to in writing, software 16 | distributed under the License is distributed on an "AS IS" BASIS, 17 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | See the License for the specific language governing permissions and 19 | limitations under the License. 20 | """ 21 | 22 | import os 23 | import os.path 24 | import websocket as ws 25 | import ssl 26 | import unittest 27 | 28 | # Skip test to access the internet unless TEST_WITH_INTERNET == 1 29 | TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1' 30 | # Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1 31 | LOCAL_WS_SERVER_PORT = os.environ.get('LOCAL_WS_SERVER_PORT', '-1') 32 | TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != '-1' 33 | TRACEABLE = True 34 | 35 | 36 | class WebSocketAppTest(unittest.TestCase): 37 | 38 | class NotSetYet: 39 | """ A marker class for signalling that a value hasn't been set yet. 
40 | """ 41 | 42 | def setUp(self): 43 | ws.enableTrace(TRACEABLE) 44 | 45 | WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet() 46 | WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet() 47 | WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet() 48 | 49 | def tearDown(self): 50 | WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet() 51 | WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet() 52 | WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet() 53 | 54 | @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") 55 | def testKeepRunning(self): 56 | """ A WebSocketApp should keep running as long as its self.keep_running 57 | is not False (in the boolean context). 58 | """ 59 | 60 | def on_open(self, *args, **kwargs): 61 | """ Set the keep_running flag for later inspection and immediately 62 | close the connection. 63 | """ 64 | self.send("hello!") 65 | WebSocketAppTest.keep_running_open = self.keep_running 66 | self.keep_running = False 67 | 68 | def on_message(wsapp, message): 69 | print(message) 70 | self.close() 71 | 72 | def on_close(self, *args, **kwargs): 73 | """ Set the keep_running flag for the test to use. 74 | """ 75 | WebSocketAppTest.keep_running_close = self.keep_running 76 | 77 | app = ws.WebSocketApp('ws://127.0.0.1:' + LOCAL_WS_SERVER_PORT, on_open=on_open, on_close=on_close, on_message=on_message) 78 | app.run_forever() 79 | 80 | @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") 81 | def testSockMaskKey(self): 82 | """ A WebSocketApp should forward the received mask_key function down 83 | to the actual socket. 
84 | """ 85 | 86 | def my_mask_key_func(): 87 | return "\x00\x00\x00\x00" 88 | 89 | app = ws.WebSocketApp('wss://stream.meetup.com/2/rsvps', get_mask_key=my_mask_key_func) 90 | 91 | # if numpy is installed, this assertion fail 92 | # Note: We can't use 'is' for comparing the functions directly, need to use 'id'. 93 | self.assertEqual(id(app.get_mask_key), id(my_mask_key_func)) 94 | 95 | @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") 96 | def testInvalidPingIntervalPingTimeout(self): 97 | """ Test exception handling if ping_interval < ping_timeout 98 | """ 99 | 100 | def on_ping(app, msg): 101 | print("Got a ping!") 102 | app.close() 103 | 104 | def on_pong(app, msg): 105 | print("Got a pong! No need to respond") 106 | app.close() 107 | 108 | app = ws.WebSocketApp('wss://api-pub.bitfinex.com/ws/1', on_ping=on_ping, on_pong=on_pong) 109 | self.assertRaises(ws.WebSocketException, app.run_forever, ping_interval=1, ping_timeout=2, sslopt={"cert_reqs": ssl.CERT_NONE}) 110 | 111 | @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") 112 | def testPingInterval(self): 113 | """ Test WebSocketApp proper ping functionality 114 | """ 115 | 116 | def on_ping(app, msg): 117 | print("Got a ping!") 118 | app.close() 119 | 120 | def on_pong(app, msg): 121 | print("Got a pong! 
No need to respond") 122 | app.close() 123 | 124 | app = ws.WebSocketApp('wss://api-pub.bitfinex.com/ws/1', on_ping=on_ping, on_pong=on_pong) 125 | app.run_forever(ping_interval=2, ping_timeout=1, sslopt={"cert_reqs": ssl.CERT_NONE}) 126 | 127 | @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") 128 | def testOpcodeClose(self): 129 | """ Test WebSocketApp close opcode 130 | """ 131 | 132 | app = ws.WebSocketApp('wss://tsock.us1.twilio.com/v3/wsconnect') 133 | app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload") 134 | 135 | @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") 136 | def testOpcodeBinary(self): 137 | """ Test WebSocketApp binary opcode 138 | """ 139 | 140 | app = ws.WebSocketApp('streaming.vn.teslamotors.com/streaming/') 141 | app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload") 142 | 143 | @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") 144 | def testBadPingInterval(self): 145 | """ A WebSocketApp handling of negative ping_interval 146 | """ 147 | app = ws.WebSocketApp('wss://api-pub.bitfinex.com/ws/1') 148 | self.assertRaises(ws.WebSocketException, app.run_forever, ping_interval=-5, sslopt={"cert_reqs": ssl.CERT_NONE}) 149 | 150 | @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") 151 | def testBadPingTimeout(self): 152 | """ A WebSocketApp handling of negative ping_timeout 153 | """ 154 | app = ws.WebSocketApp('wss://api-pub.bitfinex.com/ws/1') 155 | self.assertRaises(ws.WebSocketException, app.run_forever, ping_timeout=-3, sslopt={"cert_reqs": ssl.CERT_NONE}) 156 | 157 | @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") 158 | def testCloseStatusCode(self): 159 | """ Test extraction of close frame status code and close reason in WebSocketApp 160 | """ 161 | def on_close(wsapp, close_status_code, close_msg): 162 | print("on_close reached") 163 | 
164 | app = ws.WebSocketApp('wss://tsock.us1.twilio.com/v3/wsconnect', on_close=on_close) 165 | closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b'\x03\xe8no-init-from-client') 166 | self.assertEqual([1000, 'no-init-from-client'], app._get_close_args(closeframe)) 167 | 168 | closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b'') 169 | self.assertEqual([None, None], app._get_close_args(closeframe)) 170 | 171 | app2 = ws.WebSocketApp('wss://tsock.us1.twilio.com/v3/wsconnect') 172 | closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b'') 173 | self.assertEqual([None, None], app2._get_close_args(closeframe)) 174 | 175 | self.assertRaises(ws.WebSocketConnectionClosedException, app.send, data="test if connection is closed") 176 | 177 | 178 | if __name__ == "__main__": 179 | unittest.main() 180 | -------------------------------------------------------------------------------- /nls/websocket/tests/test_http.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | """ 4 | test_http.py 5 | websocket - WebSocket client library for Python 6 | 7 | Copyright 2021 engn33r 8 | 9 | Licensed under the Apache License, Version 2.0 (the "License"); 10 | you may not use this file except in compliance with the License. 11 | You may obtain a copy of the License at 12 | 13 | http://www.apache.org/licenses/LICENSE-2.0 14 | 15 | Unless required by applicable law or agreed to in writing, software 16 | distributed under the License is distributed on an "AS IS" BASIS, 17 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | See the License for the specific language governing permissions and 19 | limitations under the License. 
20 | """ 21 | 22 | import os 23 | import os.path 24 | import websocket as ws 25 | from websocket._http import proxy_info, read_headers, _start_proxied_socket, _tunnel, _get_addrinfo_list, connect 26 | import unittest 27 | import ssl 28 | import websocket 29 | import socket 30 | 31 | try: 32 | from python_socks._errors import ProxyError, ProxyTimeoutError, ProxyConnectionError 33 | except: 34 | from websocket._http import ProxyError, ProxyTimeoutError, ProxyConnectionError 35 | 36 | # Skip test to access the internet unless TEST_WITH_INTERNET == 1 37 | TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1' 38 | TEST_WITH_PROXY = os.environ.get('TEST_WITH_PROXY', '0') == '1' 39 | # Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1 40 | LOCAL_WS_SERVER_PORT = os.environ.get('LOCAL_WS_SERVER_PORT', '-1') 41 | TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != '-1' 42 | 43 | 44 | class SockMock: 45 | def __init__(self): 46 | self.data = [] 47 | self.sent = [] 48 | 49 | def add_packet(self, data): 50 | self.data.append(data) 51 | 52 | def gettimeout(self): 53 | return None 54 | 55 | def recv(self, bufsize): 56 | if self.data: 57 | e = self.data.pop(0) 58 | if isinstance(e, Exception): 59 | raise e 60 | if len(e) > bufsize: 61 | self.data.insert(0, e[bufsize:]) 62 | return e[:bufsize] 63 | 64 | def send(self, data): 65 | self.sent.append(data) 66 | return len(data) 67 | 68 | def close(self): 69 | pass 70 | 71 | 72 | class HeaderSockMock(SockMock): 73 | 74 | def __init__(self, fname): 75 | SockMock.__init__(self) 76 | path = os.path.join(os.path.dirname(__file__), fname) 77 | with open(path, "rb") as f: 78 | self.add_packet(f.read()) 79 | 80 | 81 | class OptsList(): 82 | 83 | def __init__(self): 84 | self.timeout = 1 85 | self.sockopt = [] 86 | self.sslopt = {"cert_reqs": ssl.CERT_NONE} 87 | 88 | 89 | class HttpTest(unittest.TestCase): 90 | 91 | def testReadHeader(self): 92 | status, header, status_message = 
read_headers(HeaderSockMock("data/header01.txt")) 93 | self.assertEqual(status, 101) 94 | self.assertEqual(header["connection"], "Upgrade") 95 | # header02.txt is intentionally malformed 96 | self.assertRaises(ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt")) 97 | 98 | def testTunnel(self): 99 | self.assertRaises(ws.WebSocketProxyException, _tunnel, HeaderSockMock("data/header01.txt"), "example.com", 80, ("username", "password")) 100 | self.assertRaises(ws.WebSocketProxyException, _tunnel, HeaderSockMock("data/header02.txt"), "example.com", 80, ("username", "password")) 101 | 102 | @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") 103 | def testConnect(self): 104 | # Not currently testing an actual proxy connection, so just check whether proxy errors are raised. This requires internet for a DNS lookup 105 | if ws._http.HAVE_PYTHON_SOCKS: 106 | # Need this check, otherwise case where python_socks is not installed triggers 107 | # websocket._exceptions.WebSocketException: Python Socks is needed for SOCKS proxying but is not available 108 | self.assertRaises(ProxyTimeoutError, _start_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks4", timeout=1)) 109 | self.assertRaises(ProxyTimeoutError, _start_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks4a", timeout=1)) 110 | self.assertRaises(ProxyTimeoutError, _start_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks5", timeout=1)) 111 | self.assertRaises(ProxyTimeoutError, _start_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks5h", timeout=1)) 112 | self.assertRaises(ProxyConnectionError, connect, "wss://example.com", OptsList(), 
proxy_info(http_proxy_host="127.0.0.1", http_proxy_port=9999, proxy_type="socks4", timeout=1), None) 113 | 114 | self.assertRaises(TypeError, _get_addrinfo_list, None, 80, True, proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="9999", proxy_type="http")) 115 | self.assertRaises(TypeError, _get_addrinfo_list, None, 80, True, proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="9999", proxy_type="http")) 116 | self.assertRaises(socket.timeout, connect, "wss://google.com", OptsList(), proxy_info(http_proxy_host="8.8.8.8", http_proxy_port=9999, proxy_type="http", timeout=1), None) 117 | self.assertEqual( 118 | connect("wss://google.com", OptsList(), proxy_info(http_proxy_host="8.8.8.8", http_proxy_port=8080, proxy_type="http"), True), 119 | (True, ("google.com", 443, "/"))) 120 | # The following test fails on Mac OS with a gaierror, not an OverflowError 121 | # self.assertRaises(OverflowError, connect, "wss://example.com", OptsList(), proxy_info(http_proxy_host="127.0.0.1", http_proxy_port=99999, proxy_type="socks4", timeout=2), False) 122 | 123 | @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") 124 | @unittest.skipUnless(TEST_WITH_PROXY, "This test requires a HTTP proxy to be running on port 8899") 125 | @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") 126 | def testProxyConnect(self): 127 | ws = websocket.WebSocket() 128 | ws.connect("ws://127.0.0.1:" + LOCAL_WS_SERVER_PORT, http_proxy_host="127.0.0.1", http_proxy_port="8899", proxy_type="http") 129 | ws.send("Hello, Server") 130 | server_response = ws.recv() 131 | self.assertEqual(server_response, "Hello, Server") 132 | # self.assertEqual(_start_proxied_socket("wss://api.bitfinex.com/ws/2", OptsList(), proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8899", proxy_type="http"))[1], ("api.bitfinex.com", 443, '/ws/2')) 133 | self.assertEqual(_get_addrinfo_list("api.bitfinex.com", 443, True, 
proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8899", proxy_type="http")), 134 | (socket.getaddrinfo("127.0.0.1", 8899, 0, socket.SOCK_STREAM, socket.SOL_TCP), True, None)) 135 | self.assertEqual(connect("wss://api.bitfinex.com/ws/2", OptsList(), proxy_info(http_proxy_host="127.0.0.1", http_proxy_port=8899, proxy_type="http"), None)[1], ("api.bitfinex.com", 443, '/ws/2')) 136 | # TODO: Test SOCKS4 and SOCK5 proxies with unit tests 137 | 138 | @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") 139 | def testSSLopt(self): 140 | ssloptions = { 141 | "cert_reqs": ssl.CERT_NONE, 142 | "check_hostname": False, 143 | "server_hostname": "ServerName", 144 | "ssl_version": ssl.PROTOCOL_TLS, 145 | "ciphers": "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256:\ 146 | TLS_AES_128_GCM_SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:\ 147 | ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:\ 148 | ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:\ 149 | DHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:\ 150 | ECDHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES128-GCM-SHA256:\ 151 | ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:\ 152 | DHE-RSA-AES256-SHA256:ECDHE-ECDSA-AES128-SHA256:\ 153 | ECDHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA256:\ 154 | ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA", 155 | "ecdh_curve": "prime256v1" 156 | } 157 | ws_ssl1 = websocket.WebSocket(sslopt=ssloptions) 158 | ws_ssl1.connect("wss://api.bitfinex.com/ws/2") 159 | ws_ssl1.send("Hello") 160 | ws_ssl1.close() 161 | 162 | ws_ssl2 = websocket.WebSocket(sslopt={"check_hostname": True}) 163 | ws_ssl2.connect("wss://api.bitfinex.com/ws/2") 164 | ws_ssl2.close()  # was a bare attribute access that never closed the socket 165 | 166 | def testProxyInfo(self): 167 | self.assertEqual(proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http").proxy_protocol, "http") 168 | self.assertRaises(ProxyError, proxy_info, http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="badval") 169 |
self.assertEqual(proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="http").proxy_host, "example.com") 170 | self.assertEqual(proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http").proxy_port, "8080") 171 | self.assertEqual(proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http").auth, None) 172 | self.assertEqual(proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http", http_proxy_auth=("my_username123", "my_pass321")).auth[0], "my_username123") 173 | self.assertEqual(proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http", http_proxy_auth=("my_username123", "my_pass321")).auth[1], "my_pass321") 174 | 175 | 176 | if __name__ == "__main__": 177 | unittest.main() 178 | -------------------------------------------------------------------------------- /nls/speech_synthesizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | 3 | import logging 4 | from re import I 5 | import uuid 6 | import json 7 | import threading 8 | 9 | from nls.core import NlsCore 10 | from . import logging 11 | from . 
import util 12 | from .exception import (StartTimeoutException, 13 | CompleteTimeoutException, 14 | InvalidParameter) 15 | 16 | __SPEECH_SYNTHESIZER_NAMESPACE__ = 'SpeechSynthesizer' 17 | __SPEECH_LONG_SYNTHESIZER_NAMESPACE__ = 'SpeechLongSynthesizer' 18 | 19 | __SPEECH_SYNTHESIZER_REQUEST_CMD__ = { 20 | 'start': 'StartSynthesis' 21 | } 22 | 23 | __URL__ = 'wss://nls-gateway.cn-shanghai.aliyuncs.com/ws/v1' 24 | 25 | __all__ = ['NlsSpeechSynthesizer'] 26 | 27 | 28 | class NlsSpeechSynthesizer: 29 | """ 30 | Api for text-to-speech 31 | """ 32 | def __init__(self, 33 | url=__URL__, 34 | token=None, 35 | appkey=None, 36 | long_tts=False, 37 | on_metainfo=None, 38 | on_data=None, 39 | on_completed=None, 40 | on_error=None, 41 | on_close=None, 42 | callback_args=[]): 43 | """ 44 | NlsSpeechSynthesizer initialization 45 | 46 | Parameters: 47 | ----------- 48 | url: str 49 | websocket url. 50 | akid: str 51 | access id from aliyun. if you provide a token, ignore this argument. 52 | appkey: str 53 | appkey from aliyun 54 | long_tts: bool 55 | whether using long-text synthesis support, default is False. long-text synthesis 56 | can support longer text but more expensive. 57 | on_metainfo: function 58 | Callback object which is called when recognition started. 59 | on_start has two arguments. 60 | The 1st argument is message which is a json format string. 61 | The 2nd argument is *args which is callback_args. 62 | on_data: function 63 | Callback object which is called when partial synthesis result arrived 64 | arrived. 65 | on_result_changed has two arguments. 66 | The 1st argument is binary data corresponding to aformat in start 67 | method. 68 | The 2nd argument is *args which is callback_args. 69 | on_completed: function 70 | Callback object which is called when recognition is completed. 71 | on_completed has two arguments. 72 | The 1st argument is message which is a json format string. 73 | The 2nd argument is *args which is callback_args. 
74 | on_error: function 75 | Callback object which is called when any error occurs. 76 | on_error has two arguments. 77 | The 1st argument is message which is a json format string. 78 | The 2nd argument is *args which is callback_args. 79 | on_close: function 80 | Callback object which is called when connection closed. 81 | on_close has one arguments. 82 | The 1st argument is *args which is callback_args. 83 | callback_args: list 84 | callback_args will return in callbacks above for *args. 85 | """ 86 | if not token or not appkey: 87 | raise InvalidParameter('Must provide token and appkey') 88 | self.__response_handler__ = { 89 | 'MetaInfo': self.__metainfo, 90 | 'SynthesisCompleted': self.__synthesis_completed, 91 | 'TaskFailed': self.__task_failed 92 | } 93 | self.__callback_args = callback_args 94 | self.__url = url 95 | self.__appkey = appkey 96 | self.__token = token 97 | self.__long_tts = long_tts 98 | self.__start_cond = threading.Condition() 99 | self.__start_flag = False 100 | self.__on_metainfo = on_metainfo 101 | self.__on_data = on_data 102 | self.__on_completed = on_completed 103 | self.__on_error = on_error 104 | self.__on_close = on_close 105 | self.__allow_aformat = ( 106 | 'pcm', 'wav', 'mp3' 107 | ) 108 | self.__allow_sample_rate = ( 109 | 8000, 11025, 16000, 22050, 110 | 24000, 32000, 44100, 48000 111 | ) 112 | 113 | def __handle_message(self, message): 114 | logging.debug('__handle_message') 115 | try: 116 | __result = json.loads(message) 117 | if __result['header']['name'] in self.__response_handler__: 118 | __handler = self.__response_handler__[__result['header']['name']] 119 | __handler(message) 120 | else: 121 | logging.error('cannot handle cmd{}'.format( 122 | __result['header']['name'])) 123 | return 124 | except json.JSONDecodeError: 125 | logging.error('cannot parse message:{}'.format(message)) 126 | return 127 | 128 | def __syn_core_on_open(self): 129 | logging.debug('__syn_core_on_open') 130 | with self.__start_cond: 131 | 
self.__start_flag = True 132 | self.__start_cond.notify() 133 | 134 | def __syn_core_on_data(self, data, opcode, flag): 135 | logging.debug('__syn_core_on_data') 136 | if self.__on_data: 137 | self.__on_data(data, *self.__callback_args) 138 | 139 | def __syn_core_on_msg(self, msg, *args): 140 | logging.debug('__syn_core_on_msg:msg={} args={}'.format(msg, args)) 141 | self.__handle_message(msg) 142 | 143 | def __syn_core_on_error(self, msg, *args): 144 | logging.debug('__sr_core_on_error:msg={} args={}'.format(msg, args)) 145 | 146 | def __syn_core_on_close(self): 147 | logging.debug('__sr_core_on_close') 148 | if self.__on_close: 149 | self.__on_close(*self.__callback_args) 150 | with self.__start_cond: 151 | self.__start_flag = False 152 | self.__start_cond.notify() 153 | 154 | def __metainfo(self, message): 155 | logging.debug('__metainfo') 156 | if self.__on_metainfo: 157 | self.__on_metainfo(message, *self.__callback_args) 158 | 159 | def __synthesis_completed(self, message): 160 | logging.debug('__synthesis_completed') 161 | self.__nls.shutdown() 162 | logging.debug('__synthesis_completed shutdown done') 163 | if self.__on_completed: 164 | self.__on_completed(message, *self.__callback_args) 165 | with self.__start_cond: 166 | self.__start_flag = False 167 | self.__start_cond.notify() 168 | 169 | def __task_failed(self, message): 170 | logging.debug('__task_failed') 171 | with self.__start_cond: 172 | self.__start_flag = False 173 | self.__start_cond.notify() 174 | if self.__on_error: 175 | self.__on_error(message, *self.__callback_args) 176 | 177 | def start(self, 178 | text=None, 179 | voice='xiaoyun', 180 | aformat='pcm', 181 | sample_rate=16000, 182 | volume=50, 183 | speech_rate=0, 184 | pitch_rate=0, 185 | wait_complete=True, 186 | start_timeout=10, 187 | completed_timeout=60, 188 | ex:dict=None): 189 | """ 190 | Synthesis start 191 | 192 | Parameters: 193 | ----------- 194 | text: str 195 | utf-8 text 196 | voice: str 197 | voice for text-to-speech, 
default is xiaoyun 198 | aformat: str 199 | audio binary format, support: 'pcm', 'wav', 'mp3', default is 'pcm' 200 | sample_rate: int 201 | audio sample rate, default is 16000, support:8000, 11025, 16000, 22050, 202 | 24000, 32000, 44100, 48000 203 | volume: int 204 | audio volume, from 0~100, default is 50 205 | speech_rate: int 206 | speech rate from -500~500, default is 0 207 | pitch_rate: int 208 | pitch for voice from -500~500, default is 0 209 | wait_complete: bool 210 | whether block until syntheis completed or timeout for completed timeout 211 | start_timeout: int 212 | timeout for connection established 213 | completed_timeout: int 214 | timeout for waiting synthesis completed from connection established 215 | ex: dict 216 | dict which will merge into 'payload' field in request 217 | """ 218 | if text is None: 219 | raise InvalidParameter('Text cannot be None') 220 | 221 | self.__nls = NlsCore( 222 | url=self.__url, 223 | token=self.__token, 224 | on_open=self.__syn_core_on_open, 225 | on_message=self.__syn_core_on_msg, 226 | on_data=self.__syn_core_on_data, 227 | on_close=self.__syn_core_on_close, 228 | on_error=self.__syn_core_on_error, 229 | callback_args=[]) 230 | 231 | if aformat not in self.__allow_aformat: 232 | raise InvalidParameter('format {} not support'.format(aformat)) 233 | if sample_rate not in self.__allow_sample_rate: 234 | raise InvalidParameter('samplerate {} not support'.format(sample_rate)) 235 | if volume < 0 or volume > 100: 236 | raise InvalidParameter('volume {} not support'.format(volume)) 237 | if speech_rate < -500 or speech_rate > 500: 238 | raise InvalidParameter('speech_rate {} not support'.format(speech_rate)) 239 | if pitch_rate < -500 or pitch_rate > 500: 240 | raise InvalidParameter('pitch rate {} not support'.format(pitch_rate)) 241 | 242 | __id4 = uuid.uuid4().hex 243 | self.__task_id = uuid.uuid4().hex 244 | __namespace = __SPEECH_SYNTHESIZER_NAMESPACE__ 245 | if self.__long_tts: 246 | __namespace = 
__SPEECH_LONG_SYNTHESIZER_NAMESPACE__ 247 | __header = { 248 | 'message_id': __id4, 249 | 'task_id': self.__task_id, 250 | 'namespace': __namespace, 251 | 'name': __SPEECH_SYNTHESIZER_REQUEST_CMD__['start'], 252 | 'appkey': self.__appkey 253 | } 254 | __payload = { 255 | 'text': text, 256 | 'voice': voice, 257 | 'format': aformat, 258 | 'sample_rate': sample_rate, 259 | 'volume': volume, 260 | 'speech_rate': speech_rate, 261 | 'pitch_rate': pitch_rate 262 | } 263 | if ex: 264 | __payload.update(ex) 265 | __msg = { 266 | 'header': __header, 267 | 'payload': __payload, 268 | 'context': util.GetDefaultContext() 269 | } 270 | __jmsg = json.dumps(__msg) 271 | with self.__start_cond: 272 | if self.__start_flag: 273 | logging.debug('already start...') 274 | return 275 | self.__nls.start(__jmsg, ping_interval=0, ping_timeout=None) 276 | if self.__start_flag == False: 277 | if not self.__start_cond.wait(start_timeout): 278 | logging.debug('syn start timeout') 279 | raise StartTimeoutException(f'Waiting Start over {start_timeout}s') 280 | if self.__start_flag and wait_complete: 281 | if not self.__start_cond.wait(completed_timeout): 282 | raise CompleteTimeoutException(f'Waiting Complete over {completed_timeout}s') 283 | 284 | def shutdown(self): 285 | """ 286 | Shutdown connection immediately 287 | """ 288 | self.__nls.shutdown() 289 | -------------------------------------------------------------------------------- /nls/speech_recognizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | 3 | import logging 4 | import uuid 5 | import json 6 | import threading 7 | 8 | 9 | from nls.core import NlsCore 10 | from . import logging 11 | from . 
import util 12 | from .exception import (StartTimeoutException, 13 | StopTimeoutException, 14 | NotStartException, 15 | InvalidParameter) 16 | 17 | __SPEECH_RECOGNIZER_NAMESPACE__ = 'SpeechRecognizer' 18 | 19 | __SPEECH_RECOGNIZER_REQUEST_CMD__ = { 20 | 'start': 'StartRecognition', 21 | 'stop': 'StopRecognition' 22 | } 23 | 24 | __URL__ = 'wss://nls-gateway.cn-shanghai.aliyuncs.com/ws/v1' 25 | 26 | __all__ = ['NlsSpeechRecognizer'] 27 | 28 | 29 | class NlsSpeechRecognizer: 30 | """ 31 | Api for short sentence speech recognition 32 | """ 33 | def __init__(self, 34 | url=__URL__, 35 | token=None, 36 | appkey=None, 37 | on_start=None, 38 | on_result_changed=None, 39 | on_completed=None, 40 | on_error=None, on_close=None, 41 | callback_args=[]): 42 | """ 43 | NlsSpeechRecognizer initialization 44 | 45 | Parameters: 46 | ----------- 47 | url: str 48 | websocket url. 49 | token: str 50 | access token. if you do not have a token, provide access id and key 51 | secret from your aliyun account. 52 | appkey: str 53 | appkey from aliyun 54 | on_start: function 55 | Callback object which is called when recognition started. 56 | on_start has two arguments. 57 | The 1st argument is message which is a json format string. 58 | The 2nd argument is *args which is callback_args. 59 | on_result_changed: function 60 | Callback object which is called when partial recognition result 61 | arrived. 62 | on_result_changed has two arguments. 63 | The 1st argument is message which is a json format string. 64 | The 2nd argument is *args which is callback_args. 65 | on_completed: function 66 | Callback object which is called when recognition is completed. 67 | on_completed has two arguments. 68 | The 1st argument is message which is a json format string. 69 | The 2nd argument is *args which is callback_args. 70 | on_error: function 71 | Callback object which is called when any error occurs. 72 | on_error has two arguments. 73 | The 1st argument is message which is a json format string. 
74 | The 2nd argument is *args which is callback_args. 75 | on_close: function 76 | Callback object which is called when connection closed. 77 | on_close has one arguments. 78 | The 1st argument is *args which is callback_args. 79 | callback_args: list 80 | callback_args will return in callbacks above for *args. 81 | """ 82 | if not token or not appkey: 83 | raise InvalidParameter('Must provide token and appkey') 84 | self.__response_handler__ = { 85 | 'RecognitionStarted': self.__recognition_started, 86 | 'RecognitionResultChanged': self.__recognition_result_changed, 87 | 'RecognitionCompleted': self.__recognition_completed, 88 | 'TaskFailed': self.__task_failed 89 | } 90 | self.__callback_args = callback_args 91 | self.__appkey = appkey 92 | self.__url = url 93 | self.__token = token 94 | self.__start_cond = threading.Condition() 95 | self.__start_flag = False 96 | self.__on_start = on_start 97 | self.__on_result_changed = on_result_changed 98 | self.__on_completed = on_completed 99 | self.__on_error = on_error 100 | self.__on_close = on_close 101 | self.__allow_aformat = ( 102 | 'pcm', 'opus', 'opu', 'wav', 'mp3', 'speex', 'aac', 'amr' 103 | ) 104 | 105 | def __handle_message(self, message): 106 | logging.debug('__handle_message') 107 | try: 108 | __result = json.loads(message) 109 | if __result['header']['name'] in self.__response_handler__: 110 | __handler = self.__response_handler__[ 111 | __result['header']['name']] 112 | __handler(message) 113 | else: 114 | logging.error('cannot handle cmd{}'.format( 115 | __result['header']['name'])) 116 | return 117 | except json.JSONDecodeError: 118 | logging.error('cannot parse message:{}'.format(message)) 119 | return 120 | 121 | def __sr_core_on_open(self): 122 | logging.debug('__sr_core_on_open') 123 | 124 | def __sr_core_on_msg(self, msg, *args): 125 | logging.debug('__sr_core_on_msg:msg={} args={}'.format(msg, args)) 126 | self.__handle_message(msg) 127 | 128 | def __sr_core_on_error(self, msg, *args): 129 | 
logging.debug('__sr_core_on_error:msg={} args={}'.format(msg, args)) 130 | 131 | def __sr_core_on_close(self): 132 | logging.debug('__sr_core_on_close') 133 | if self.__on_close: 134 | self.__on_close(*self.__callback_args) 135 | with self.__start_cond: 136 | self.__start_flag = False 137 | self.__start_cond.notify() 138 | 139 | def __recognition_started(self, message): 140 | logging.debug('__recognition_started') 141 | if self.__on_start: 142 | self.__on_start(message, *self.__callback_args) 143 | with self.__start_cond: 144 | self.__start_flag = True 145 | self.__start_cond.notify() 146 | 147 | def __recognition_result_changed(self, message): 148 | logging.debug('__recognition_result_changed') 149 | if self.__on_result_changed: 150 | self.__on_result_changed(message, *self.__callback_args) 151 | 152 | def __recognition_completed(self, message): 153 | logging.debug('__recognition_completed') 154 | self.__nls.shutdown() 155 | logging.debug('__recognition_completed shutdown done') 156 | if self.__on_completed: 157 | self.__on_completed(message, *self.__callback_args) 158 | with self.__start_cond: 159 | self.__start_flag = False 160 | self.__start_cond.notify() 161 | 162 | def __task_failed(self, message): 163 | logging.debug('__task_failed') 164 | with self.__start_cond: 165 | self.__start_flag = False 166 | self.__start_cond.notify() 167 | if self.__on_error: 168 | self.__on_error(message, *self.__callback_args) 169 | 170 | def start(self, aformat='pcm', sample_rate=16000, ch=1, 171 | enable_intermediate_result=False, 172 | enable_punctuation_prediction=False, 173 | enable_inverse_text_normalization=False, 174 | timeout=10, 175 | ping_interval=8, 176 | ping_timeout=None, 177 | ex:dict=None): 178 | """ 179 | Recognition start 180 | 181 | Parameters: 182 | ----------- 183 | aformat: str 184 | audio binary format, support: 'pcm', 'opu', 'opus', default is 'pcm' 185 | sample_rate: int 186 | audio sample rate, default is 16000 187 | ch: int 188 | audio channels, only 
support mono which is 1 189 | enable_intermediate_result: bool 190 | whether enable return intermediate recognition result, default is False 191 | enable_punctuation_prediction: bool 192 | whether enable punctuation prediction, default is False 193 | enable_inverse_text_normalization: bool 194 | whether enable ITN, default is False 195 | timeout: int 196 | wait timeout for connection setup 197 | ping_interval: int 198 | send ping interval, 0 for disable ping send, default is 8 199 | ping_timeout: int 200 | timeout after send ping and recive pong, set None for disable timeout check and default is None 201 | ex: dict 202 | dict which will merge into 'payload' field in request 203 | """ 204 | self.__nls = NlsCore( 205 | url=self.__url, 206 | token=self.__token, 207 | on_open=self.__sr_core_on_open, 208 | on_message=self.__sr_core_on_msg, 209 | on_close=self.__sr_core_on_close, 210 | on_error=self.__sr_core_on_error, 211 | callback_args=[]) 212 | 213 | if ch != 1: 214 | raise InvalidParameter(f'Not support channel {ch}') 215 | if aformat not in self.__allow_aformat: 216 | raise InvalidParameter(f'Format {aformat} not support') 217 | 218 | __id4 = uuid.uuid4().hex 219 | self.__task_id = uuid.uuid4().hex 220 | __header = { 221 | 'message_id': __id4, 222 | 'task_id': self.__task_id, 223 | 'namespace': __SPEECH_RECOGNIZER_NAMESPACE__, 224 | 'name': __SPEECH_RECOGNIZER_REQUEST_CMD__['start'], 225 | 'appkey': self.__appkey 226 | } 227 | __payload = { 228 | 'format': aformat, 229 | 'sample_rate': sample_rate, 230 | 'enable_intermediate_result': enable_intermediate_result, 231 | 'enable_punctuation_prediction': enable_punctuation_prediction, 232 | 'enable_inverse_text_normalization': enable_inverse_text_normalization 233 | } 234 | 235 | if ex: 236 | __payload.update(ex) 237 | 238 | __msg = { 239 | 'header': __header, 240 | 'payload': __payload, 241 | 'context': util.GetDefaultContext() 242 | } 243 | __jmsg = json.dumps(__msg) 244 | with self.__start_cond: 245 | if 
self.__start_flag: 246 | logging.debug('already start...') 247 | return 248 | self.__nls.start(__jmsg, ping_interval, ping_timeout) 249 | if self.__start_flag == False: 250 | if self.__start_cond.wait(timeout=timeout): 251 | return 252 | else: 253 | raise StartTimeoutException(f'Waiting Start over {timeout}s') 254 | 255 | def stop(self, timeout=10): 256 | """ 257 | Stop recognition and mark session finished 258 | 259 | Parameters: 260 | ----------- 261 | timeout: int 262 | timeout for waiting completed message from cloud 263 | """ 264 | __id4 = uuid.uuid4().hex 265 | __header = { 266 | 'message_id': __id4, 267 | 'task_id': self.__task_id, 268 | 'namespace': __SPEECH_RECOGNIZER_NAMESPACE__, 269 | 'name': __SPEECH_RECOGNIZER_REQUEST_CMD__['stop'], 270 | 'appkey': self.__appkey 271 | } 272 | __msg = { 273 | 'header': __header, 274 | 'context': util.GetDefaultContext() 275 | } 276 | __jmsg = json.dumps(__msg) 277 | with self.__start_cond: 278 | if not self.__start_flag: 279 | logging.debug('not start yet...') 280 | return 281 | self.__nls.send(__jmsg, False) 282 | if self.__start_flag == True: 283 | logging.debug('stop wait..') 284 | if self.__start_cond.wait(timeout): 285 | return 286 | else: 287 | raise StopTimeoutException(f'Waiting stop over {timeout}s') 288 | def shutdown(self): 289 | """ 290 | Shutdown connection immediately 291 | """ 292 | self.__nls.shutdown() 293 | 294 | def send_audio(self, pcm_data): 295 | """ 296 | Send audio binary, audio size prefer 20ms length 297 | 298 | Parameters: 299 | ----------- 300 | pcm_data: bytes 301 | audio binary which format is 'aformat' in start method 302 | """ 303 | if not pcm_data: 304 | raise InvalidParameter('data empty!') 305 | __data = pcm_data 306 | with self.__start_cond: 307 | if not self.__start_flag: 308 | raise NotStartException('Need start before send!') 309 | try: 310 | self.__nls.send(__data, True) 311 | except ConnectionResetError as __e: 312 | logging.error('connection reset') 313 | self.__start_flag = 
False 314 | self.__nls.shutdown() 315 | raise __e 316 | -------------------------------------------------------------------------------- /nls/websocket/_http.py: -------------------------------------------------------------------------------- 1 | """ 2 | _http.py 3 | websocket - WebSocket client library for Python 4 | 5 | Copyright 2021 engn33r 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); 8 | you may not use this file except in compliance with the License. 9 | You may obtain a copy of the License at 10 | 11 | http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | Unless required by applicable law or agreed to in writing, software 14 | distributed under the License is distributed on an "AS IS" BASIS, 15 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | See the License for the specific language governing permissions and 17 | limitations under the License. 18 | """ 19 | import errno 20 | import os 21 | import socket 22 | import sys 23 | 24 | from ._exceptions import * 25 | from ._logging import * 26 | from ._socket import* 27 | from ._ssl_compat import * 28 | from ._url import * 29 | 30 | from base64 import encodebytes as base64encode 31 | 32 | __all__ = ["proxy_info", "connect", "read_headers"] 33 | 34 | try: 35 | from python_socks.sync import Proxy 36 | from python_socks._errors import * 37 | from python_socks._types import ProxyType 38 | HAVE_PYTHON_SOCKS = True 39 | except: 40 | HAVE_PYTHON_SOCKS = False 41 | 42 | class ProxyError(Exception): 43 | pass 44 | 45 | class ProxyTimeoutError(Exception): 46 | pass 47 | 48 | class ProxyConnectionError(Exception): 49 | pass 50 | 51 | 52 | class proxy_info: 53 | 54 | def __init__(self, **options): 55 | self.proxy_host = options.get("http_proxy_host", None) 56 | if self.proxy_host: 57 | self.proxy_port = options.get("http_proxy_port", 0) 58 | self.auth = options.get("http_proxy_auth", None) 59 | self.no_proxy = options.get("http_no_proxy", None) 60 | self.proxy_protocol 
= options.get("proxy_type", "http") 61 | # Note: If timeout not specified, default python-socks timeout is 60 seconds 62 | self.proxy_timeout = options.get("timeout", None) 63 | if self.proxy_protocol not in ['http', 'socks4', 'socks4a', 'socks5', 'socks5h']: 64 | raise ProxyError("Only http, socks4, socks5 proxy protocols are supported") 65 | else: 66 | self.proxy_port = 0 67 | self.auth = None 68 | self.no_proxy = None 69 | self.proxy_protocol = "http" 70 | 71 | 72 | def _start_proxied_socket(url, options, proxy): 73 | if not HAVE_PYTHON_SOCKS: 74 | raise WebSocketException("Python Socks is needed for SOCKS proxying but is not available") 75 | 76 | hostname, port, resource, is_secure = parse_url(url) 77 | 78 | if proxy.proxy_protocol == "socks5": 79 | rdns = False 80 | proxy_type = ProxyType.SOCKS5 81 | if proxy.proxy_protocol == "socks4": 82 | rdns = False 83 | proxy_type = ProxyType.SOCKS4 84 | # socks5h and socks4a send DNS through proxy 85 | if proxy.proxy_protocol == "socks5h": 86 | rdns = True 87 | proxy_type = ProxyType.SOCKS5 88 | if proxy.proxy_protocol == "socks4a": 89 | rdns = True 90 | proxy_type = ProxyType.SOCKS4 91 | 92 | ws_proxy = Proxy.create( 93 | proxy_type=proxy_type, 94 | host=proxy.proxy_host, 95 | port=int(proxy.proxy_port), 96 | username=proxy.auth[0] if proxy.auth else None, 97 | password=proxy.auth[1] if proxy.auth else None, 98 | rdns=rdns) 99 | 100 | sock = ws_proxy.connect(hostname, port, timeout=proxy.proxy_timeout) 101 | 102 | if is_secure and HAVE_SSL: 103 | sock = _ssl_socket(sock, options.sslopt, hostname) 104 | elif is_secure: 105 | raise WebSocketException("SSL not available.") 106 | 107 | return sock, (hostname, port, resource) 108 | 109 | 110 | def connect(url, options, proxy, socket): 111 | # Use _start_proxied_socket() only for socks4 or socks5 proxy 112 | # Use _tunnel() for http proxy 113 | # TODO: Use python-socks for http protocol also, to standardize flow 114 | if proxy.proxy_host and not socket and not 
(proxy.proxy_protocol == "http"): 115 | return _start_proxied_socket(url, options, proxy) 116 | 117 | hostname, port, resource, is_secure = parse_url(url) 118 | 119 | if socket: 120 | return socket, (hostname, port, resource) 121 | 122 | addrinfo_list, need_tunnel, auth = _get_addrinfo_list( 123 | hostname, port, is_secure, proxy) 124 | if not addrinfo_list: 125 | raise WebSocketException( 126 | "Host not found.: " + hostname + ":" + str(port)) 127 | 128 | sock = None 129 | try: 130 | sock = _open_socket(addrinfo_list, options.sockopt, options.timeout) 131 | if need_tunnel: 132 | sock = _tunnel(sock, hostname, port, auth) 133 | 134 | if is_secure: 135 | if HAVE_SSL: 136 | sock = _ssl_socket(sock, options.sslopt, hostname) 137 | else: 138 | raise WebSocketException("SSL not available.") 139 | 140 | return sock, (hostname, port, resource) 141 | except: 142 | if sock: 143 | sock.close() 144 | raise 145 | 146 | 147 | def _get_addrinfo_list(hostname, port, is_secure, proxy): 148 | phost, pport, pauth = get_proxy_info( 149 | hostname, is_secure, proxy.proxy_host, proxy.proxy_port, proxy.auth, proxy.no_proxy) 150 | try: 151 | # when running on windows 10, getaddrinfo without socktype returns a socktype 0. 152 | # This generates an error exception: `_on_error: exception Socket type must be stream or datagram, not 0` 153 | # or `OSError: [Errno 22] Invalid argument` when creating socket. Force the socket type to SOCK_STREAM. 154 | if not phost: 155 | addrinfo_list = socket.getaddrinfo( 156 | hostname, port, 0, socket.SOCK_STREAM, socket.SOL_TCP) 157 | return addrinfo_list, False, None 158 | else: 159 | pport = pport and pport or 80 160 | # when running on windows 10, the getaddrinfo used above 161 | # returns a socktype 0. 
This generates an error exception: 162 | # _on_error: exception Socket type must be stream or datagram, not 0 163 | # Force the socket type to SOCK_STREAM 164 | addrinfo_list = socket.getaddrinfo(phost, pport, 0, socket.SOCK_STREAM, socket.SOL_TCP) 165 | return addrinfo_list, True, pauth 166 | except socket.gaierror as e: 167 | raise WebSocketAddressException(e) 168 | 169 | 170 | def _open_socket(addrinfo_list, sockopt, timeout): 171 | err = None 172 | for addrinfo in addrinfo_list: 173 | family, socktype, proto = addrinfo[:3] 174 | sock = socket.socket(family, socktype, proto) 175 | sock.settimeout(timeout) 176 | for opts in DEFAULT_SOCKET_OPTION: 177 | sock.setsockopt(*opts) 178 | for opts in sockopt: 179 | sock.setsockopt(*opts) 180 | 181 | address = addrinfo[4] 182 | err = None 183 | while not err: 184 | try: 185 | sock.connect(address) 186 | except socket.error as error: 187 | error.remote_ip = str(address[0]) 188 | try: 189 | eConnRefused = (errno.ECONNREFUSED, errno.WSAECONNREFUSED) 190 | except: 191 | eConnRefused = (errno.ECONNREFUSED, ) 192 | if error.errno == errno.EINTR: 193 | continue 194 | elif error.errno in eConnRefused: 195 | err = error 196 | continue 197 | else: 198 | if sock: 199 | sock.close() 200 | raise error 201 | else: 202 | break 203 | else: 204 | continue 205 | break 206 | else: 207 | if err: 208 | raise err 209 | 210 | return sock 211 | 212 | 213 | def _wrap_sni_socket(sock, sslopt, hostname, check_hostname): 214 | context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_TLS)) 215 | 216 | if sslopt.get('cert_reqs', ssl.CERT_NONE) != ssl.CERT_NONE: 217 | cafile = sslopt.get('ca_certs', None) 218 | capath = sslopt.get('ca_cert_path', None) 219 | if cafile or capath: 220 | context.load_verify_locations(cafile=cafile, capath=capath) 221 | elif hasattr(context, 'load_default_certs'): 222 | context.load_default_certs(ssl.Purpose.SERVER_AUTH) 223 | if sslopt.get('certfile', None): 224 | context.load_cert_chain( 225 | sslopt['certfile'], 
226 | sslopt.get('keyfile', None), 227 | sslopt.get('password', None), 228 | ) 229 | # see 230 | # https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153 231 | context.verify_mode = sslopt['cert_reqs'] 232 | if HAVE_CONTEXT_CHECK_HOSTNAME: 233 | context.check_hostname = check_hostname 234 | if 'ciphers' in sslopt: 235 | context.set_ciphers(sslopt['ciphers']) 236 | if 'cert_chain' in sslopt: 237 | certfile, keyfile, password = sslopt['cert_chain'] 238 | context.load_cert_chain(certfile, keyfile, password) 239 | if 'ecdh_curve' in sslopt: 240 | context.set_ecdh_curve(sslopt['ecdh_curve']) 241 | 242 | return context.wrap_socket( 243 | sock, 244 | do_handshake_on_connect=sslopt.get('do_handshake_on_connect', True), 245 | suppress_ragged_eofs=sslopt.get('suppress_ragged_eofs', True), 246 | server_hostname=hostname, 247 | ) 248 | 249 | 250 | def _ssl_socket(sock, user_sslopt, hostname): 251 | sslopt = dict(cert_reqs=ssl.CERT_REQUIRED) 252 | sslopt.update(user_sslopt) 253 | 254 | certPath = os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE') 255 | if certPath and os.path.isfile(certPath) \ 256 | and user_sslopt.get('ca_certs', None) is None: 257 | sslopt['ca_certs'] = certPath 258 | elif certPath and os.path.isdir(certPath) \ 259 | and user_sslopt.get('ca_cert_path', None) is None: 260 | sslopt['ca_cert_path'] = certPath 261 | 262 | if sslopt.get('server_hostname', None): 263 | hostname = sslopt['server_hostname'] 264 | 265 | check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop( 266 | 'check_hostname', True) 267 | sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname) 268 | 269 | if not HAVE_CONTEXT_CHECK_HOSTNAME and check_hostname: 270 | match_hostname(sock.getpeercert(), hostname) 271 | 272 | return sock 273 | 274 | 275 | def _tunnel(sock, host, port, auth): 276 | debug("Connecting proxy...") 277 | connect_header = "CONNECT %s:%d HTTP/1.1\r\n" % (host, port) 278 | connect_header += "Host: 
%s:%d\r\n" % (host, port) 279 | 280 | # TODO: support digest auth. 281 | if auth and auth[0]: 282 | auth_str = auth[0] 283 | if auth[1]: 284 | auth_str += ":" + auth[1] 285 | encoded_str = base64encode(auth_str.encode()).strip().decode().replace('\n', '') 286 | connect_header += "Proxy-Authorization: Basic %s\r\n" % encoded_str 287 | connect_header += "\r\n" 288 | dump("request header", connect_header) 289 | 290 | send(sock, connect_header) 291 | 292 | try: 293 | status, resp_headers, status_message = read_headers(sock) 294 | except Exception as e: 295 | raise WebSocketProxyException(str(e)) 296 | 297 | if status != 200: 298 | raise WebSocketProxyException( 299 | "failed CONNECT via proxy status: %r" % status) 300 | 301 | return sock 302 | 303 | 304 | def read_headers(sock): 305 | status = None 306 | status_message = None 307 | headers = {} 308 | trace("--- response header ---") 309 | 310 | while True: 311 | line = recv_line(sock) 312 | line = line.decode('utf-8').strip() 313 | if not line: 314 | break 315 | trace(line) 316 | if not status: 317 | 318 | status_info = line.split(" ", 2) 319 | status = int(status_info[1]) 320 | if len(status_info) > 2: 321 | status_message = status_info[2] 322 | else: 323 | kv = line.split(":", 1) 324 | if len(kv) == 2: 325 | key, value = kv 326 | if key.lower() == "set-cookie" and headers.get("set-cookie"): 327 | headers["set-cookie"] = headers.get("set-cookie") + "; " + value.strip() 328 | else: 329 | headers[key.lower()] = value.strip() 330 | else: 331 | raise WebSocketException("Invalid header") 332 | 333 | trace("-----------------------") 334 | 335 | return status, headers, status_message 336 | -------------------------------------------------------------------------------- /nls/realtime_meeting.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 
2 | 3 | import logging 4 | import uuid 5 | import json 6 | import threading 7 | 8 | from nls.core import NlsCore 9 | from . import logging 10 | from . import util 11 | from nls.exception import (StartTimeoutException, 12 | StopTimeoutException, 13 | InvalidParameter) 14 | 15 | __REALTIME_MEETING_NAMESPACE__ = 'SpeechTranscriber' 16 | 17 | __REALTIME_MEETING_REQUEST_CMD__ = { 18 | 'start': 'StartTranscription', 19 | 'stop': 'StopTranscription' 20 | } 21 | 22 | __all__ = ['NlsRealtimeMeeting'] 23 | 24 | 25 | class NlsRealtimeMeeting: 26 | """ 27 | Api for realtime meeting 28 | """ 29 | 30 | def __init__(self, 31 | url=None, 32 | on_start=None, 33 | on_sentence_begin=None, 34 | on_sentence_end=None, 35 | on_result_changed=None, 36 | on_result_translated=None, 37 | on_completed=None, 38 | on_error=None, 39 | on_close=None, 40 | callback_args=[]): 41 | ''' 42 | NlsRealtimeMeeting initialization 43 | 44 | Parameters: 45 | ----------- 46 | url: str 47 | meeting join url. 48 | on_start: function 49 | Callback object which is called when recognition started. 50 | on_start has two arguments. 51 | The 1st argument is message which is a json format string. 52 | The 2nd argument is *args which is callback_args. 53 | on_sentence_begin: function 54 | Callback object which is called when one sentence started. 55 | on_sentence_begin has two arguments. 56 | The 1st argument is message which is a json format string. 57 | The 2nd argument is *args which is callback_args. 58 | on_sentence_end: function 59 | Callback object which is called when sentence is end. 60 | on_sentence_end has two arguments. 61 | The 1st argument is message which is a json format string. 62 | The 2nd argument is *args which is callback_args. 63 | on_result_changed: function 64 | Callback object which is called when partial recognition result 65 | arrived. 66 | on_result_changed has two arguments. 67 | The 1st argument is message which is a json format string. 
68 | The 2nd argument is *args which is callback_args. 69 | on_result_translated: function 70 | Callback object which is called when partial translation result 71 | arrived. 72 | on_result_translated has two arguments. 73 | The 1st argument is message which is a json format string. 74 | The 2nd argument is *args which is callback_args. 75 | on_completed: function 76 | Callback object which is called when recognition is completed. 77 | on_completed has two arguments. 78 | The 1st argument is message which is a json format string. 79 | The 2nd argument is *args which is callback_args. 80 | on_error: function 81 | Callback object which is called when any error occurs. 82 | on_error has two arguments. 83 | The 1st argument is message which is a json format string. 84 | The 2nd argument is *args which is callback_args. 85 | on_close: function 86 | Callback object which is called when connection closed. 87 | on_close has one arguments. 88 | The 1st argument is *args which is callback_args. 89 | callback_args: list 90 | callback_args will return in callbacks above for *args. 
91 | ''' 92 | if not url: 93 | raise InvalidParameter('Must provide url') 94 | self.__response_handler__ = { 95 | 'SentenceBegin': self.__sentence_begin, 96 | 'SentenceEnd': self.__sentence_end, 97 | 'TranscriptionStarted': self.__transcription_started, 98 | 'TranscriptionResultChanged': self.__transcription_result_changed, 99 | 'ResultTranslated': self.__transcription_result_translated, 100 | 'TranscriptionCompleted': self.__transcription_completed, 101 | 'TaskFailed': self.__task_failed 102 | } 103 | self.__callback_args = callback_args 104 | self.__url = url 105 | self.__start_cond = threading.Condition() 106 | self.__start_flag = False 107 | self.__on_start = on_start 108 | self.__on_sentence_begin = on_sentence_begin 109 | self.__on_sentence_end = on_sentence_end 110 | self.__on_result_changed = on_result_changed 111 | self.__on_result_translated = on_result_translated 112 | self.__on_completed = on_completed 113 | self.__on_error = on_error 114 | self.__on_close = on_close 115 | 116 | def __handle_message(self, message): 117 | logging.debug('__handle_message {}'.format(message)) 118 | try: 119 | __result = json.loads(message) 120 | if __result['header']['name'] in self.__response_handler__: 121 | __handler = self.__response_handler__[ 122 | __result['header']['name']] 123 | __handler(message) 124 | else: 125 | logging.error('cannot handle cmd{}'.format( 126 | __result['header']['name'])) 127 | return 128 | except json.JSONDecodeError: 129 | logging.error('cannot parse message:{}'.format(message)) 130 | return 131 | 132 | def __tr_core_on_open(self): 133 | logging.debug('__tr_core_on_open') 134 | 135 | def __tr_core_on_msg(self, msg, *args): 136 | logging.debug('__tr_core_on_msg:msg={} args={}'.format(msg, args)) 137 | self.__handle_message(msg) 138 | 139 | def __tr_core_on_error(self, msg, *args): 140 | logging.debug('__tr_core_on_error:msg={} args={}'.format(msg, args)) 141 | with self.__start_cond: 142 | self.__start_flag = False 143 | 
self.__start_cond.notify() 144 | if self.__on_error: 145 | self.__on_error(msg, *self.__callback_args) 146 | 147 | def __tr_core_on_close(self): 148 | logging.debug('__tr_core_on_close') 149 | if self.__on_close: 150 | self.__on_close(*self.__callback_args) 151 | with self.__start_cond: 152 | self.__start_flag = False 153 | self.__start_cond.notify() 154 | 155 | def __sentence_begin(self, message): 156 | logging.debug('__sentence_begin') 157 | if self.__on_sentence_begin: 158 | self.__on_sentence_begin(message, *self.__callback_args) 159 | 160 | def __sentence_end(self, message): 161 | logging.debug('__sentence_end') 162 | if self.__on_sentence_end: 163 | self.__on_sentence_end(message, *self.__callback_args) 164 | 165 | def __transcription_started(self, message): 166 | logging.debug('__transcription_started') 167 | if self.__on_start: 168 | self.__on_start(message, *self.__callback_args) 169 | with self.__start_cond: 170 | self.__start_flag = True 171 | self.__start_cond.notify() 172 | 173 | def __transcription_result_changed(self, message): 174 | logging.debug('__transcription_result_changed') 175 | if self.__on_result_changed: 176 | self.__on_result_changed(message, *self.__callback_args) 177 | 178 | def __transcription_result_translated(self, message): 179 | logging.debug('__transcription_result_translated') 180 | if self.__on_result_translated: 181 | self.__on_result_translated(message, *self.__callback_args) 182 | 183 | def __transcription_completed(self, message): 184 | logging.debug('__transcription_completed') 185 | self.__nls.shutdown() 186 | logging.debug('__transcription_completed shutdown done') 187 | if self.__on_completed: 188 | self.__on_completed(message, *self.__callback_args) 189 | with self.__start_cond: 190 | self.__start_flag = False 191 | self.__start_cond.notify() 192 | 193 | def __task_failed(self, message): 194 | logging.debug('__task_failed') 195 | with self.__start_cond: 196 | self.__start_flag = False 197 | self.__start_cond.notify() 
198 | if self.__on_error: 199 | self.__on_error(message, *self.__callback_args) 200 | 201 | def start(self, 202 | timeout=10, 203 | ping_interval=8, 204 | ping_timeout=None, 205 | ex:dict=None): 206 | """ 207 | Realtime meeting start 208 | 209 | Parameters: 210 | ----------- 211 | timeout: int 212 | wait timeout for connection setup 213 | ping_interval: int 214 | send ping interval, 0 for disable ping send, default is 8 215 | ping_timeout: int 216 | timeout after send ping and recive pong, set None for disable timeout check and default is None 217 | ex: dict 218 | dict which will merge into 'payload' field in request 219 | """ 220 | self.__nls = NlsCore( 221 | url=self.__url, 222 | token='default', 223 | on_open=self.__tr_core_on_open, 224 | on_message=self.__tr_core_on_msg, 225 | on_close=self.__tr_core_on_close, 226 | on_error=self.__tr_core_on_error, 227 | callback_args=[]) 228 | 229 | __id4 = uuid.uuid4().hex 230 | self.__task_id = uuid.uuid4().hex 231 | __header = { 232 | 'message_id': __id4, 233 | 'task_id': self.__task_id, 234 | 'namespace': __REALTIME_MEETING_NAMESPACE__, 235 | 'name': __REALTIME_MEETING_REQUEST_CMD__['start'], 236 | 'appkey': 'default' 237 | } 238 | __payload = { 239 | } 240 | 241 | if ex: 242 | __payload.update(ex) 243 | 244 | __msg = { 245 | 'header': __header, 246 | 'payload': __payload, 247 | 'context': util.GetDefaultContext() 248 | } 249 | __jmsg = json.dumps(__msg) 250 | with self.__start_cond: 251 | if self.__start_flag: 252 | logging.debug('already start...') 253 | return 254 | self.__nls.start(__jmsg, ping_interval, ping_timeout) 255 | if self.__start_flag == False: 256 | if self.__start_cond.wait(timeout): 257 | return 258 | else: 259 | raise StartTimeoutException(f'Waiting Start over {timeout}s') 260 | 261 | def stop(self, timeout=10): 262 | """ 263 | Stop realtime meeting and mark session finished 264 | 265 | Parameters: 266 | ----------- 267 | timeout: int 268 | timeout for waiting completed message from cloud 269 | """ 270 | 
__id4 = uuid.uuid4().hex 271 | __header = { 272 | 'message_id': __id4, 273 | 'task_id': self.__task_id, 274 | 'namespace': __REALTIME_MEETING_NAMESPACE__, 275 | 'name': __REALTIME_MEETING_REQUEST_CMD__['stop'], 276 | 'appkey': 'default' 277 | } 278 | __msg = { 279 | 'header': __header, 280 | 'context': util.GetDefaultContext() 281 | } 282 | __jmsg = json.dumps(__msg) 283 | with self.__start_cond: 284 | if not self.__start_flag: 285 | logging.debug('not start yet...') 286 | return 287 | self.__nls.send(__jmsg, False) 288 | if self.__start_flag == True: 289 | logging.debug('stop wait..') 290 | if self.__start_cond.wait(timeout): 291 | return 292 | else: 293 | raise StopTimeoutException(f'Waiting stop over {timeout}s') 294 | 295 | def shutdown(self): 296 | """ 297 | Shutdown connection immediately 298 | """ 299 | self.__nls.shutdown() 300 | 301 | def send_audio(self, pcm_data): 302 | """ 303 | Send audio binary, audio size prefer 20ms length 304 | 305 | Parameters: 306 | ----------- 307 | pcm_data: bytes 308 | audio binary which format created by CreateTask 309 | """ 310 | 311 | __data = pcm_data 312 | with self.__start_cond: 313 | if not self.__start_flag: 314 | return 315 | try: 316 | self.__nls.send(__data, True) 317 | except ConnectionResetError as __e: 318 | logging.error('connection reset') 319 | self.__start_flag = False 320 | self.__nls.shutdown() 321 | raise __e 322 | -------------------------------------------------------------------------------- /nls/websocket/_abnf.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | """ 4 | 5 | """ 6 | _abnf.py 7 | websocket - WebSocket client library for Python 8 | 9 | Copyright 2021 engn33r 10 | 11 | Licensed under the Apache License, Version 2.0 (the "License"); 12 | you may not use this file except in compliance with the License. 
13 | You may obtain a copy of the License at 14 | 15 | http://www.apache.org/licenses/LICENSE-2.0 16 | 17 | Unless required by applicable law or agreed to in writing, software 18 | distributed under the License is distributed on an "AS IS" BASIS, 19 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | See the License for the specific language governing permissions and 21 | limitations under the License. 22 | """ 23 | import array 24 | import os 25 | import struct 26 | import sys 27 | 28 | from ._exceptions import * 29 | from ._utils import validate_utf8 30 | from threading import Lock 31 | 32 | try: 33 | # If wsaccel is available, use compiled routines to mask data. 34 | # wsaccel only provides around a 10% speed boost compared 35 | # to the websocket-client _mask() implementation. 36 | # Note that wsaccel is unmaintained. 37 | from wsaccel.xormask import XorMaskerSimple 38 | 39 | def _mask(_m, _d): 40 | return XorMaskerSimple(_m).process(_d) 41 | 42 | except ImportError: 43 | # wsaccel is not available, use websocket-client _mask() 44 | native_byteorder = sys.byteorder 45 | 46 | def _mask(mask_value, data_value): 47 | datalen = len(data_value) 48 | data_value = int.from_bytes(data_value, native_byteorder) 49 | mask_value = int.from_bytes(mask_value * (datalen // 4) + mask_value[: datalen % 4], native_byteorder) 50 | return (data_value ^ mask_value).to_bytes(datalen, native_byteorder) 51 | 52 | 53 | __all__ = [ 54 | 'ABNF', 'continuous_frame', 'frame_buffer', 55 | 'STATUS_NORMAL', 56 | 'STATUS_GOING_AWAY', 57 | 'STATUS_PROTOCOL_ERROR', 58 | 'STATUS_UNSUPPORTED_DATA_TYPE', 59 | 'STATUS_STATUS_NOT_AVAILABLE', 60 | 'STATUS_ABNORMAL_CLOSED', 61 | 'STATUS_INVALID_PAYLOAD', 62 | 'STATUS_POLICY_VIOLATION', 63 | 'STATUS_MESSAGE_TOO_BIG', 64 | 'STATUS_INVALID_EXTENSION', 65 | 'STATUS_UNEXPECTED_CONDITION', 66 | 'STATUS_BAD_GATEWAY', 67 | 'STATUS_TLS_HANDSHAKE_ERROR', 68 | ] 69 | 70 | # closing frame status codes. 
71 | STATUS_NORMAL = 1000 72 | STATUS_GOING_AWAY = 1001 73 | STATUS_PROTOCOL_ERROR = 1002 74 | STATUS_UNSUPPORTED_DATA_TYPE = 1003 75 | STATUS_STATUS_NOT_AVAILABLE = 1005 76 | STATUS_ABNORMAL_CLOSED = 1006 77 | STATUS_INVALID_PAYLOAD = 1007 78 | STATUS_POLICY_VIOLATION = 1008 79 | STATUS_MESSAGE_TOO_BIG = 1009 80 | STATUS_INVALID_EXTENSION = 1010 81 | STATUS_UNEXPECTED_CONDITION = 1011 82 | STATUS_BAD_GATEWAY = 1014 83 | STATUS_TLS_HANDSHAKE_ERROR = 1015 84 | 85 | VALID_CLOSE_STATUS = ( 86 | STATUS_NORMAL, 87 | STATUS_GOING_AWAY, 88 | STATUS_PROTOCOL_ERROR, 89 | STATUS_UNSUPPORTED_DATA_TYPE, 90 | STATUS_INVALID_PAYLOAD, 91 | STATUS_POLICY_VIOLATION, 92 | STATUS_MESSAGE_TOO_BIG, 93 | STATUS_INVALID_EXTENSION, 94 | STATUS_UNEXPECTED_CONDITION, 95 | STATUS_BAD_GATEWAY, 96 | ) 97 | 98 | 99 | class ABNF: 100 | """ 101 | ABNF frame class. 102 | See http://tools.ietf.org/html/rfc5234 103 | and http://tools.ietf.org/html/rfc6455#section-5.2 104 | """ 105 | 106 | # operation code values. 107 | OPCODE_CONT = 0x0 108 | OPCODE_TEXT = 0x1 109 | OPCODE_BINARY = 0x2 110 | OPCODE_CLOSE = 0x8 111 | OPCODE_PING = 0x9 112 | OPCODE_PONG = 0xa 113 | 114 | # available operation code value tuple 115 | OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE, 116 | OPCODE_PING, OPCODE_PONG) 117 | 118 | # opcode human readable string 119 | OPCODE_MAP = { 120 | OPCODE_CONT: "cont", 121 | OPCODE_TEXT: "text", 122 | OPCODE_BINARY: "binary", 123 | OPCODE_CLOSE: "close", 124 | OPCODE_PING: "ping", 125 | OPCODE_PONG: "pong" 126 | } 127 | 128 | # data length threshold. 129 | LENGTH_7 = 0x7e 130 | LENGTH_16 = 1 << 16 131 | LENGTH_63 = 1 << 63 132 | 133 | def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0, 134 | opcode=OPCODE_TEXT, mask=1, data=""): 135 | """ 136 | Constructor for ABNF. Please check RFC for arguments. 
137 | """ 138 | self.fin = fin 139 | self.rsv1 = rsv1 140 | self.rsv2 = rsv2 141 | self.rsv3 = rsv3 142 | self.opcode = opcode 143 | self.mask = mask 144 | if data is None: 145 | data = "" 146 | self.data = data 147 | self.get_mask_key = os.urandom 148 | 149 | def validate(self, skip_utf8_validation=False): 150 | """ 151 | Validate the ABNF frame. 152 | 153 | Parameters 154 | ---------- 155 | skip_utf8_validation: skip utf8 validation. 156 | """ 157 | if self.rsv1 or self.rsv2 or self.rsv3: 158 | raise WebSocketProtocolException("rsv is not implemented, yet") 159 | 160 | if self.opcode not in ABNF.OPCODES: 161 | raise WebSocketProtocolException("Invalid opcode %r", self.opcode) 162 | 163 | if self.opcode == ABNF.OPCODE_PING and not self.fin: 164 | raise WebSocketProtocolException("Invalid ping frame.") 165 | 166 | if self.opcode == ABNF.OPCODE_CLOSE: 167 | l = len(self.data) 168 | if not l: 169 | return 170 | if l == 1 or l >= 126: 171 | raise WebSocketProtocolException("Invalid close frame.") 172 | if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]): 173 | raise WebSocketProtocolException("Invalid close frame.") 174 | 175 | code = 256 * self.data[0] + self.data[1] 176 | if not self._is_valid_close_status(code): 177 | raise WebSocketProtocolException("Invalid close opcode.") 178 | 179 | @staticmethod 180 | def _is_valid_close_status(code): 181 | return code in VALID_CLOSE_STATUS or (3000 <= code < 5000) 182 | 183 | def __str__(self): 184 | return "fin=" + str(self.fin) \ 185 | + " opcode=" + str(self.opcode) \ 186 | + " data=" + str(self.data) 187 | 188 | @staticmethod 189 | def create_frame(data, opcode, fin=1): 190 | """ 191 | Create frame to send text, binary and other data. 192 | 193 | Parameters 194 | ---------- 195 | data: 196 | data to send. This is string value(byte array). 197 | If opcode is OPCODE_TEXT and this value is unicode, 198 | data value is converted into unicode string, automatically. 199 | opcode: 200 | operation code. 
please see OPCODE_XXX. 201 | fin: 202 | fin flag. if set to 0, create continue fragmentation. 203 | """ 204 | if opcode == ABNF.OPCODE_TEXT and isinstance(data, str): 205 | data = data.encode("utf-8") 206 | # mask must be set if send data from client 207 | return ABNF(fin, 0, 0, 0, opcode, 1, data) 208 | 209 | def format(self): 210 | """ 211 | Format this object to string(byte array) to send data to server. 212 | """ 213 | if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]): 214 | raise ValueError("not 0 or 1") 215 | if self.opcode not in ABNF.OPCODES: 216 | raise ValueError("Invalid OPCODE") 217 | length = len(self.data) 218 | if length >= ABNF.LENGTH_63: 219 | raise ValueError("data is too long") 220 | 221 | frame_header = chr(self.fin << 7 | 222 | self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4 | 223 | self.opcode).encode('latin-1') 224 | if length < ABNF.LENGTH_7: 225 | frame_header += chr(self.mask << 7 | length).encode('latin-1') 226 | elif length < ABNF.LENGTH_16: 227 | frame_header += chr(self.mask << 7 | 0x7e).encode('latin-1') 228 | frame_header += struct.pack("!H", length) 229 | else: 230 | frame_header += chr(self.mask << 7 | 0x7f).encode('latin-1') 231 | frame_header += struct.pack("!Q", length) 232 | 233 | if not self.mask: 234 | return frame_header + self.data 235 | else: 236 | mask_key = self.get_mask_key(4) 237 | return frame_header + self._get_masked(mask_key) 238 | 239 | def _get_masked(self, mask_key): 240 | s = ABNF.mask(mask_key, self.data) 241 | 242 | if isinstance(mask_key, str): 243 | mask_key = mask_key.encode('utf-8') 244 | 245 | return mask_key + s 246 | 247 | @staticmethod 248 | def mask(mask_key, data): 249 | """ 250 | Mask or unmask data. Just do xor for each byte 251 | 252 | Parameters 253 | ---------- 254 | mask_key: 255 | 4 byte string. 256 | data: 257 | data to mask/unmask. 
258 | """ 259 | if data is None: 260 | data = "" 261 | 262 | if isinstance(mask_key, str): 263 | mask_key = mask_key.encode('latin-1') 264 | 265 | if isinstance(data, str): 266 | data = data.encode('latin-1') 267 | 268 | return _mask(array.array("B", mask_key), array.array("B", data)) 269 | 270 | 271 | class frame_buffer: 272 | _HEADER_MASK_INDEX = 5 273 | _HEADER_LENGTH_INDEX = 6 274 | 275 | def __init__(self, recv_fn, skip_utf8_validation): 276 | self.recv = recv_fn 277 | self.skip_utf8_validation = skip_utf8_validation 278 | # Buffers over the packets from the layer beneath until desired amount 279 | # bytes of bytes are received. 280 | self.recv_buffer = [] 281 | self.clear() 282 | self.lock = Lock() 283 | 284 | def clear(self): 285 | self.header = None 286 | self.length = None 287 | self.mask = None 288 | 289 | def has_received_header(self): 290 | return self.header is None 291 | 292 | def recv_header(self): 293 | header = self.recv_strict(2) 294 | b1 = header[0] 295 | fin = b1 >> 7 & 1 296 | rsv1 = b1 >> 6 & 1 297 | rsv2 = b1 >> 5 & 1 298 | rsv3 = b1 >> 4 & 1 299 | opcode = b1 & 0xf 300 | b2 = header[1] 301 | has_mask = b2 >> 7 & 1 302 | length_bits = b2 & 0x7f 303 | 304 | self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits) 305 | 306 | def has_mask(self): 307 | if not self.header: 308 | return False 309 | return self.header[frame_buffer._HEADER_MASK_INDEX] 310 | 311 | def has_received_length(self): 312 | return self.length is None 313 | 314 | def recv_length(self): 315 | bits = self.header[frame_buffer._HEADER_LENGTH_INDEX] 316 | length_bits = bits & 0x7f 317 | if length_bits == 0x7e: 318 | v = self.recv_strict(2) 319 | self.length = struct.unpack("!H", v)[0] 320 | elif length_bits == 0x7f: 321 | v = self.recv_strict(8) 322 | self.length = struct.unpack("!Q", v)[0] 323 | else: 324 | self.length = length_bits 325 | 326 | def has_received_mask(self): 327 | return self.mask is None 328 | 329 | def recv_mask(self): 330 | self.mask = 
self.recv_strict(4) if self.has_mask() else "" 331 | 332 | def recv_frame(self): 333 | 334 | with self.lock: 335 | # Header 336 | if self.has_received_header(): 337 | self.recv_header() 338 | (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header 339 | 340 | # Frame length 341 | if self.has_received_length(): 342 | self.recv_length() 343 | length = self.length 344 | 345 | # Mask 346 | if self.has_received_mask(): 347 | self.recv_mask() 348 | mask = self.mask 349 | 350 | # Payload 351 | payload = self.recv_strict(length) 352 | if has_mask: 353 | payload = ABNF.mask(mask, payload) 354 | 355 | # Reset for next frame 356 | self.clear() 357 | 358 | frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload) 359 | frame.validate(self.skip_utf8_validation) 360 | 361 | return frame 362 | 363 | def recv_strict(self, bufsize): 364 | shortage = bufsize - sum(map(len, self.recv_buffer)) 365 | while shortage > 0: 366 | # Limit buffer size that we pass to socket.recv() to avoid 367 | # fragmenting the heap -- the number of bytes recv() actually 368 | # reads is limited by socket buffer and is relatively small, 369 | # yet passing large numbers repeatedly causes lots of large 370 | # buffers allocated and then shrunk, which results in 371 | # fragmentation. 
372 | bytes_ = self.recv(min(16384, shortage)) 373 | self.recv_buffer.append(bytes_) 374 | shortage -= len(bytes_) 375 | 376 | unified = bytes("", 'utf-8').join(self.recv_buffer) 377 | 378 | if shortage == 0: 379 | self.recv_buffer = [] 380 | return unified 381 | else: 382 | self.recv_buffer = [unified[bufsize:]] 383 | return unified[:bufsize] 384 | 385 | 386 | class continuous_frame: 387 | 388 | def __init__(self, fire_cont_frame, skip_utf8_validation): 389 | self.fire_cont_frame = fire_cont_frame 390 | self.skip_utf8_validation = skip_utf8_validation 391 | self.cont_data = None 392 | self.recving_frames = None 393 | 394 | def validate(self, frame): 395 | if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT: 396 | raise WebSocketProtocolException("Illegal frame") 397 | if self.recving_frames and \ 398 | frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): 399 | raise WebSocketProtocolException("Illegal frame") 400 | 401 | def add(self, frame): 402 | if self.cont_data: 403 | self.cont_data[1] += frame.data 404 | else: 405 | if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): 406 | self.recving_frames = frame.opcode 407 | self.cont_data = [frame.opcode, frame.data] 408 | 409 | if frame.fin: 410 | self.recving_frames = None 411 | 412 | def is_fire(self, frame): 413 | return frame.fin or self.fire_cont_frame 414 | 415 | def extract(self, frame): 416 | data = self.cont_data 417 | self.cont_data = None 418 | frame.data = data[1] 419 | if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data): 420 | raise WebSocketPayloadException( 421 | "cannot decode: " + repr(frame.data)) 422 | 423 | return [data[0], frame] 424 | -------------------------------------------------------------------------------- /nls/speech_transcriber.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 
2 | 3 | import logging 4 | import uuid 5 | import json 6 | import threading 7 | 8 | from nls.core import NlsCore 9 | from . import logging 10 | from . import util 11 | from nls.exception import (StartTimeoutException, 12 | StopTimeoutException, 13 | NotStartException, 14 | InvalidParameter) 15 | 16 | __SPEECH_TRANSCRIBER_NAMESPACE__ = 'SpeechTranscriber' 17 | 18 | __SPEECH_TRANSCRIBER_REQUEST_CMD__ = { 19 | 'start': 'StartTranscription', 20 | 'stop': 'StopTranscription', 21 | 'control': 'ControlTranscriber' 22 | } 23 | 24 | __URL__ = 'wss://nls-gateway.cn-shanghai.aliyuncs.com/ws/v1' 25 | __all__ = ['NlsSpeechTranscriber'] 26 | 27 | 28 | class NlsSpeechTranscriber: 29 | """ 30 | Api for realtime speech transcription 31 | """ 32 | 33 | def __init__(self, 34 | url=__URL__, 35 | token=None, 36 | appkey=None, 37 | on_start=None, 38 | on_sentence_begin=None, 39 | on_sentence_end=None, 40 | on_result_changed=None, 41 | on_completed=None, 42 | on_error=None, 43 | on_close=None, 44 | callback_args=[]): 45 | ''' 46 | NlsSpeechTranscriber initialization 47 | 48 | Parameters: 49 | ----------- 50 | url: str 51 | websocket url. 52 | token: str 53 | access token. if you do not have a token, provide access id and key 54 | secret from your aliyun account. 55 | appkey: str 56 | appkey from aliyun 57 | on_start: function 58 | Callback object which is called when recognition started. 59 | on_start has two arguments. 60 | The 1st argument is message which is a json format string. 61 | The 2nd argument is *args which is callback_args. 62 | on_sentence_begin: function 63 | Callback object which is called when one sentence started. 64 | on_sentence_begin has two arguments. 65 | The 1st argument is message which is a json format string. 66 | The 2nd argument is *args which is callback_args. 67 | on_sentence_end: function 68 | Callback object which is called when sentence is end. 69 | on_sentence_end has two arguments. 70 | The 1st argument is message which is a json format string. 
71 | The 2nd argument is *args which is callback_args. 72 | on_result_changed: function 73 | Callback object which is called when partial recognition result 74 | arrived. 75 | on_result_changed has two arguments. 76 | The 1st argument is message which is a json format string. 77 | The 2nd argument is *args which is callback_args. 78 | on_completed: function 79 | Callback object which is called when recognition is completed. 80 | on_completed has two arguments. 81 | The 1st argument is message which is a json format string. 82 | The 2nd argument is *args which is callback_args. 83 | on_error: function 84 | Callback object which is called when any error occurs. 85 | on_error has two arguments. 86 | The 1st argument is message which is a json format string. 87 | The 2nd argument is *args which is callback_args. 88 | on_close: function 89 | Callback object which is called when connection closed. 90 | on_close has one arguments. 91 | The 1st argument is *args which is callback_args. 92 | callback_args: list 93 | callback_args will return in callbacks above for *args. 
94 | ''' 95 | if not token or not appkey: 96 | raise InvalidParameter('Must provide token and appkey') 97 | self.__response_handler__ = { 98 | 'SentenceBegin': self.__sentence_begin, 99 | 'SentenceEnd': self.__sentence_end, 100 | 'TranscriptionStarted': self.__transcription_started, 101 | 'TranscriptionResultChanged': self.__transcription_result_changed, 102 | 'TranscriptionCompleted': self.__transcription_completed, 103 | 'TaskFailed': self.__task_failed 104 | } 105 | self.__callback_args = callback_args 106 | self.__url = url 107 | self.__appkey = appkey 108 | self.__token = token 109 | self.__start_cond = threading.Condition() 110 | self.__start_flag = False 111 | self.__on_start = on_start 112 | self.__on_sentence_begin = on_sentence_begin 113 | self.__on_sentence_end = on_sentence_end 114 | self.__on_result_changed = on_result_changed 115 | self.__on_completed = on_completed 116 | self.__on_error = on_error 117 | self.__on_close = on_close 118 | self.__allow_aformat = ( 119 | 'pcm', 'opus', 'opu', 'wav', 'amr', 'speex', 'mp3', 'aac' 120 | ) 121 | 122 | def __handle_message(self, message): 123 | logging.debug('__handle_message') 124 | try: 125 | __result = json.loads(message) 126 | if __result['header']['name'] in self.__response_handler__: 127 | __handler = self.__response_handler__[ 128 | __result['header']['name']] 129 | __handler(message) 130 | else: 131 | logging.error('cannot handle cmd{}'.format( 132 | __result['header']['name'])) 133 | return 134 | except json.JSONDecodeError: 135 | logging.error('cannot parse message:{}'.format(message)) 136 | return 137 | 138 | def __tr_core_on_open(self): 139 | logging.debug('__tr_core_on_open') 140 | 141 | def __tr_core_on_msg(self, msg, *args): 142 | logging.debug('__tr_core_on_msg:msg={} args={}'.format(msg, args)) 143 | self.__handle_message(msg) 144 | 145 | def __tr_core_on_error(self, msg, *args): 146 | logging.debug('__tr_core_on_error:msg={} args={}'.format(msg, args)) 147 | 148 | def 
__tr_core_on_close(self): 149 | logging.debug('__tr_core_on_close') 150 | if self.__on_close: 151 | self.__on_close(*self.__callback_args) 152 | with self.__start_cond: 153 | self.__start_flag = False 154 | self.__start_cond.notify() 155 | 156 | def __sentence_begin(self, message): 157 | logging.debug('__sentence_begin') 158 | if self.__on_sentence_begin: 159 | self.__on_sentence_begin(message, *self.__callback_args) 160 | 161 | def __sentence_end(self, message): 162 | logging.debug('__sentence_end') 163 | if self.__on_sentence_end: 164 | self.__on_sentence_end(message, *self.__callback_args) 165 | 166 | def __transcription_started(self, message): 167 | logging.debug('__transcription_started') 168 | if self.__on_start: 169 | self.__on_start(message, *self.__callback_args) 170 | with self.__start_cond: 171 | self.__start_flag = True 172 | self.__start_cond.notify() 173 | 174 | def __transcription_result_changed(self, message): 175 | logging.debug('__transcription_result_changed') 176 | if self.__on_result_changed: 177 | self.__on_result_changed(message, *self.__callback_args) 178 | 179 | def __transcription_completed(self, message): 180 | logging.debug('__transcription_completed') 181 | self.__nls.shutdown() 182 | logging.debug('__transcription_completed shutdown done') 183 | if self.__on_completed: 184 | self.__on_completed(message, *self.__callback_args) 185 | with self.__start_cond: 186 | self.__start_flag = False 187 | self.__start_cond.notify() 188 | 189 | def __task_failed(self, message): 190 | logging.debug('__task_failed') 191 | with self.__start_cond: 192 | self.__start_flag = False 193 | self.__start_cond.notify() 194 | if self.__on_error: 195 | self.__on_error(message, *self.__callback_args) 196 | 197 | def start(self, aformat='pcm', sample_rate=16000, ch=1, 198 | enable_intermediate_result=False, 199 | enable_punctuation_prediction=False, 200 | enable_inverse_text_normalization=False, 201 | timeout=10, 202 | ping_interval=8, 203 | ping_timeout=None, 204 
| ex:dict=None): 205 | """ 206 | Transcription start 207 | 208 | Parameters: 209 | ----------- 210 | aformat: str 211 | audio binary format, support: 'pcm', 'opu', 'opus', default is 'pcm' 212 | sample_rate: int 213 | audio sample rate, default is 16000 214 | ch: int 215 | audio channels, only support mono which is 1 216 | enable_intermediate_result: bool 217 | whether enable return intermediate recognition result, default is False 218 | enable_punctuation_prediction: bool 219 | whether enable punctuation prediction, default is False 220 | enable_inverse_text_normalization: bool 221 | whether enable ITN, default is False 222 | timeout: int 223 | wait timeout for connection setup 224 | ping_interval: int 225 | send ping interval, 0 for disable ping send, default is 8 226 | ping_timeout: int 227 | timeout after send ping and recive pong, set None for disable timeout check and default is None 228 | ex: dict 229 | dict which will merge into 'payload' field in request 230 | """ 231 | self.__nls = NlsCore( 232 | url=self.__url, 233 | token=self.__token, 234 | on_open=self.__tr_core_on_open, 235 | on_message=self.__tr_core_on_msg, 236 | on_close=self.__tr_core_on_close, 237 | on_error=self.__tr_core_on_error, 238 | callback_args=[]) 239 | 240 | if ch != 1: 241 | raise ValueError('not support channel: {}'.format(ch)) 242 | if aformat not in self.__allow_aformat: 243 | raise ValueError('format {} not support'.format(aformat)) 244 | __id4 = uuid.uuid4().hex 245 | self.__task_id = uuid.uuid4().hex 246 | __header = { 247 | 'message_id': __id4, 248 | 'task_id': self.__task_id, 249 | 'namespace': __SPEECH_TRANSCRIBER_NAMESPACE__, 250 | 'name': __SPEECH_TRANSCRIBER_REQUEST_CMD__['start'], 251 | 'appkey': self.__appkey 252 | } 253 | __payload = { 254 | 'format': aformat, 255 | 'sample_rate': sample_rate, 256 | 'enable_intermediate_result': enable_intermediate_result, 257 | 'enable_punctuation_prediction': enable_punctuation_prediction, 258 | 'enable_inverse_text_normalization': 
enable_inverse_text_normalization 259 | } 260 | 261 | if ex: 262 | __payload.update(ex) 263 | 264 | __msg = { 265 | 'header': __header, 266 | 'payload': __payload, 267 | 'context': util.GetDefaultContext() 268 | } 269 | __jmsg = json.dumps(__msg) 270 | with self.__start_cond: 271 | if self.__start_flag: 272 | logging.debug('already start...') 273 | return 274 | self.__nls.start(__jmsg, ping_interval, ping_timeout) 275 | if self.__start_flag == False: 276 | if self.__start_cond.wait(timeout): 277 | return 278 | else: 279 | raise StartTimeoutException(f'Waiting Start over {timeout}s') 280 | 281 | def stop(self, timeout=10): 282 | """ 283 | Stop transcription and mark session finished 284 | 285 | Parameters: 286 | ----------- 287 | timeout: int 288 | timeout for waiting completed message from cloud 289 | """ 290 | __id4 = uuid.uuid4().hex 291 | __header = { 292 | 'message_id': __id4, 293 | 'task_id': self.__task_id, 294 | 'namespace': __SPEECH_TRANSCRIBER_NAMESPACE__, 295 | 'name': __SPEECH_TRANSCRIBER_REQUEST_CMD__['stop'], 296 | 'appkey': self.__appkey 297 | } 298 | __msg = { 299 | 'header': __header, 300 | 'context': util.GetDefaultContext() 301 | } 302 | __jmsg = json.dumps(__msg) 303 | with self.__start_cond: 304 | if not self.__start_flag: 305 | logging.debug('not start yet...') 306 | return 307 | self.__nls.send(__jmsg, False) 308 | if self.__start_flag == True: 309 | logging.debug('stop wait..') 310 | if self.__start_cond.wait(timeout): 311 | return 312 | else: 313 | raise StopTimeoutException(f'Waiting stop over {timeout}s') 314 | 315 | def ctrl(self, **kwargs): 316 | """ 317 | Send control message to cloud 318 | 319 | Parameters: 320 | ----------- 321 | kwargs: dict 322 | dict which will merge into 'payload' field in request 323 | """ 324 | if not kwargs: 325 | raise InvalidParameter('Empty kwargs not allowed!') 326 | __id4 = uuid.uuid4().hex 327 | __header = { 328 | 'message_id': __id4, 329 | 'task_id': self.__task_id, 330 | 'namespace': 
__SPEECH_TRANSCRIBER_NAMESPACE__, 331 | 'name': __SPEECH_TRANSCRIBER_REQUEST_CMD__['control'], 332 | 'appkey': self.__appkey 333 | } 334 | payload = {} 335 | payload.update(kwargs) 336 | __msg = { 337 | 'header': __header, 338 | 'payload': payload, 339 | 'context': util.GetDefaultContext() 340 | } 341 | __jmsg = json.dumps(__msg) 342 | with self.__start_cond: 343 | if not self.__start_flag: 344 | logging.debug('not start yet...') 345 | return 346 | self.__nls.send(__jmsg, False) 347 | 348 | def shutdown(self): 349 | """ 350 | Shutdown connection immediately 351 | """ 352 | self.__nls.shutdown() 353 | 354 | def send_audio(self, pcm_data): 355 | """ 356 | Send audio binary, audio size prefer 20ms length 357 | 358 | Parameters: 359 | ----------- 360 | pcm_data: bytes 361 | audio binary which format is 'aformat' in start method 362 | """ 363 | 364 | __data = pcm_data 365 | with self.__start_cond: 366 | if not self.__start_flag: 367 | return 368 | try: 369 | self.__nls.send(__data, True) 370 | except ConnectionResetError as __e: 371 | logging.error('connection reset') 372 | self.__start_flag = False 373 | self.__nls.shutdown() 374 | raise __e 375 | -------------------------------------------------------------------------------- /nls/websocket/tests/test_url.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | """ 4 | test_url.py 5 | websocket - WebSocket client library for Python 6 | 7 | Copyright 2021 engn33r 8 | 9 | Licensed under the Apache License, Version 2.0 (the "License"); 10 | you may not use this file except in compliance with the License. 11 | You may obtain a copy of the License at 12 | 13 | http://www.apache.org/licenses/LICENSE-2.0 14 | 15 | Unless required by applicable law or agreed to in writing, software 16 | distributed under the License is distributed on an "AS IS" BASIS, 17 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
18 | See the License for the specific language governing permissions and 19 | limitations under the License. 20 | """ 21 | 22 | import os 23 | import unittest 24 | from websocket._url import get_proxy_info, parse_url, _is_address_in_network, _is_no_proxy_host 25 | 26 | 27 | class UrlTest(unittest.TestCase): 28 | 29 | def test_address_in_network(self): 30 | self.assertTrue(_is_address_in_network('127.0.0.1', '127.0.0.0/8')) 31 | self.assertTrue(_is_address_in_network('127.1.0.1', '127.0.0.0/8')) 32 | self.assertFalse(_is_address_in_network('127.1.0.1', '127.0.0.0/24')) 33 | 34 | def testParseUrl(self): 35 | p = parse_url("ws://www.example.com/r") 36 | self.assertEqual(p[0], "www.example.com") 37 | self.assertEqual(p[1], 80) 38 | self.assertEqual(p[2], "/r") 39 | self.assertEqual(p[3], False) 40 | 41 | p = parse_url("ws://www.example.com/r/") 42 | self.assertEqual(p[0], "www.example.com") 43 | self.assertEqual(p[1], 80) 44 | self.assertEqual(p[2], "/r/") 45 | self.assertEqual(p[3], False) 46 | 47 | p = parse_url("ws://www.example.com/") 48 | self.assertEqual(p[0], "www.example.com") 49 | self.assertEqual(p[1], 80) 50 | self.assertEqual(p[2], "/") 51 | self.assertEqual(p[3], False) 52 | 53 | p = parse_url("ws://www.example.com") 54 | self.assertEqual(p[0], "www.example.com") 55 | self.assertEqual(p[1], 80) 56 | self.assertEqual(p[2], "/") 57 | self.assertEqual(p[3], False) 58 | 59 | p = parse_url("ws://www.example.com:8080/r") 60 | self.assertEqual(p[0], "www.example.com") 61 | self.assertEqual(p[1], 8080) 62 | self.assertEqual(p[2], "/r") 63 | self.assertEqual(p[3], False) 64 | 65 | p = parse_url("ws://www.example.com:8080/") 66 | self.assertEqual(p[0], "www.example.com") 67 | self.assertEqual(p[1], 8080) 68 | self.assertEqual(p[2], "/") 69 | self.assertEqual(p[3], False) 70 | 71 | p = parse_url("ws://www.example.com:8080") 72 | self.assertEqual(p[0], "www.example.com") 73 | self.assertEqual(p[1], 8080) 74 | self.assertEqual(p[2], "/") 75 | self.assertEqual(p[3], 
False) 76 | 77 | p = parse_url("wss://www.example.com:8080/r") 78 | self.assertEqual(p[0], "www.example.com") 79 | self.assertEqual(p[1], 8080) 80 | self.assertEqual(p[2], "/r") 81 | self.assertEqual(p[3], True) 82 | 83 | p = parse_url("wss://www.example.com:8080/r?key=value") 84 | self.assertEqual(p[0], "www.example.com") 85 | self.assertEqual(p[1], 8080) 86 | self.assertEqual(p[2], "/r?key=value") 87 | self.assertEqual(p[3], True) 88 | 89 | self.assertRaises(ValueError, parse_url, "http://www.example.com/r") 90 | 91 | p = parse_url("ws://[2a03:4000:123:83::3]/r") 92 | self.assertEqual(p[0], "2a03:4000:123:83::3") 93 | self.assertEqual(p[1], 80) 94 | self.assertEqual(p[2], "/r") 95 | self.assertEqual(p[3], False) 96 | 97 | p = parse_url("ws://[2a03:4000:123:83::3]:8080/r") 98 | self.assertEqual(p[0], "2a03:4000:123:83::3") 99 | self.assertEqual(p[1], 8080) 100 | self.assertEqual(p[2], "/r") 101 | self.assertEqual(p[3], False) 102 | 103 | p = parse_url("wss://[2a03:4000:123:83::3]/r") 104 | self.assertEqual(p[0], "2a03:4000:123:83::3") 105 | self.assertEqual(p[1], 443) 106 | self.assertEqual(p[2], "/r") 107 | self.assertEqual(p[3], True) 108 | 109 | p = parse_url("wss://[2a03:4000:123:83::3]:8080/r") 110 | self.assertEqual(p[0], "2a03:4000:123:83::3") 111 | self.assertEqual(p[1], 8080) 112 | self.assertEqual(p[2], "/r") 113 | self.assertEqual(p[3], True) 114 | 115 | 116 | class IsNoProxyHostTest(unittest.TestCase): 117 | def setUp(self): 118 | self.no_proxy = os.environ.get("no_proxy", None) 119 | if "no_proxy" in os.environ: 120 | del os.environ["no_proxy"] 121 | 122 | def tearDown(self): 123 | if self.no_proxy: 124 | os.environ["no_proxy"] = self.no_proxy 125 | elif "no_proxy" in os.environ: 126 | del os.environ["no_proxy"] 127 | 128 | def testMatchAll(self): 129 | self.assertTrue(_is_no_proxy_host("any.websocket.org", ['*'])) 130 | self.assertTrue(_is_no_proxy_host("192.168.0.1", ['*'])) 131 | self.assertTrue(_is_no_proxy_host("any.websocket.org", 
['other.websocket.org', '*'])) 132 | os.environ['no_proxy'] = '*' 133 | self.assertTrue(_is_no_proxy_host("any.websocket.org", None)) 134 | self.assertTrue(_is_no_proxy_host("192.168.0.1", None)) 135 | os.environ['no_proxy'] = 'other.websocket.org, *' 136 | self.assertTrue(_is_no_proxy_host("any.websocket.org", None)) 137 | 138 | def testIpAddress(self): 139 | self.assertTrue(_is_no_proxy_host("127.0.0.1", ['127.0.0.1'])) 140 | self.assertFalse(_is_no_proxy_host("127.0.0.2", ['127.0.0.1'])) 141 | self.assertTrue(_is_no_proxy_host("127.0.0.1", ['other.websocket.org', '127.0.0.1'])) 142 | self.assertFalse(_is_no_proxy_host("127.0.0.2", ['other.websocket.org', '127.0.0.1'])) 143 | os.environ['no_proxy'] = '127.0.0.1' 144 | self.assertTrue(_is_no_proxy_host("127.0.0.1", None)) 145 | self.assertFalse(_is_no_proxy_host("127.0.0.2", None)) 146 | os.environ['no_proxy'] = 'other.websocket.org, 127.0.0.1' 147 | self.assertTrue(_is_no_proxy_host("127.0.0.1", None)) 148 | self.assertFalse(_is_no_proxy_host("127.0.0.2", None)) 149 | 150 | def testIpAddressInRange(self): 151 | self.assertTrue(_is_no_proxy_host("127.0.0.1", ['127.0.0.0/8'])) 152 | self.assertTrue(_is_no_proxy_host("127.0.0.2", ['127.0.0.0/8'])) 153 | self.assertFalse(_is_no_proxy_host("127.1.0.1", ['127.0.0.0/24'])) 154 | os.environ['no_proxy'] = '127.0.0.0/8' 155 | self.assertTrue(_is_no_proxy_host("127.0.0.1", None)) 156 | self.assertTrue(_is_no_proxy_host("127.0.0.2", None)) 157 | os.environ['no_proxy'] = '127.0.0.0/24' 158 | self.assertFalse(_is_no_proxy_host("127.1.0.1", None)) 159 | 160 | def testHostnameMatch(self): 161 | self.assertTrue(_is_no_proxy_host("my.websocket.org", ['my.websocket.org'])) 162 | self.assertTrue(_is_no_proxy_host("my.websocket.org", ['other.websocket.org', 'my.websocket.org'])) 163 | self.assertFalse(_is_no_proxy_host("my.websocket.org", ['other.websocket.org'])) 164 | os.environ['no_proxy'] = 'my.websocket.org' 165 | self.assertTrue(_is_no_proxy_host("my.websocket.org", None)) 166 
| self.assertFalse(_is_no_proxy_host("other.websocket.org", None)) 167 | os.environ['no_proxy'] = 'other.websocket.org, my.websocket.org' 168 | self.assertTrue(_is_no_proxy_host("my.websocket.org", None)) 169 | 170 | def testHostnameMatchDomain(self): 171 | self.assertTrue(_is_no_proxy_host("any.websocket.org", ['.websocket.org'])) 172 | self.assertTrue(_is_no_proxy_host("my.other.websocket.org", ['.websocket.org'])) 173 | self.assertTrue(_is_no_proxy_host("any.websocket.org", ['my.websocket.org', '.websocket.org'])) 174 | self.assertFalse(_is_no_proxy_host("any.websocket.com", ['.websocket.org'])) 175 | os.environ['no_proxy'] = '.websocket.org' 176 | self.assertTrue(_is_no_proxy_host("any.websocket.org", None)) 177 | self.assertTrue(_is_no_proxy_host("my.other.websocket.org", None)) 178 | self.assertFalse(_is_no_proxy_host("any.websocket.com", None)) 179 | os.environ['no_proxy'] = 'my.websocket.org, .websocket.org' 180 | self.assertTrue(_is_no_proxy_host("any.websocket.org", None)) 181 | 182 | 183 | class ProxyInfoTest(unittest.TestCase): 184 | def setUp(self): 185 | self.http_proxy = os.environ.get("http_proxy", None) 186 | self.https_proxy = os.environ.get("https_proxy", None) 187 | self.no_proxy = os.environ.get("no_proxy", None) 188 | if "http_proxy" in os.environ: 189 | del os.environ["http_proxy"] 190 | if "https_proxy" in os.environ: 191 | del os.environ["https_proxy"] 192 | if "no_proxy" in os.environ: 193 | del os.environ["no_proxy"] 194 | 195 | def tearDown(self): 196 | if self.http_proxy: 197 | os.environ["http_proxy"] = self.http_proxy 198 | elif "http_proxy" in os.environ: 199 | del os.environ["http_proxy"] 200 | 201 | if self.https_proxy: 202 | os.environ["https_proxy"] = self.https_proxy 203 | elif "https_proxy" in os.environ: 204 | del os.environ["https_proxy"] 205 | 206 | if self.no_proxy: 207 | os.environ["no_proxy"] = self.no_proxy 208 | elif "no_proxy" in os.environ: 209 | del os.environ["no_proxy"] 210 | 211 | def testProxyFromArgs(self): 212 
| self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost"), ("localhost", 0, None)) 213 | self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_port=3128), 214 | ("localhost", 3128, None)) 215 | self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost"), ("localhost", 0, None)) 216 | self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128), 217 | ("localhost", 3128, None)) 218 | 219 | self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_auth=("a", "b")), 220 | ("localhost", 0, ("a", "b"))) 221 | self.assertEqual( 222 | get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_port=3128, proxy_auth=("a", "b")), 223 | ("localhost", 3128, ("a", "b"))) 224 | self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_auth=("a", "b")), 225 | ("localhost", 0, ("a", "b"))) 226 | self.assertEqual( 227 | get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128, proxy_auth=("a", "b")), 228 | ("localhost", 3128, ("a", "b"))) 229 | 230 | self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128, 231 | no_proxy=["example.com"], proxy_auth=("a", "b")), 232 | ("localhost", 3128, ("a", "b"))) 233 | self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128, 234 | no_proxy=["echo.websocket.org"], proxy_auth=("a", "b")), 235 | (None, 0, None)) 236 | 237 | def testProxyFromEnv(self): 238 | os.environ["http_proxy"] = "http://localhost/" 239 | self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, None)) 240 | os.environ["http_proxy"] = "http://localhost:3128/" 241 | self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, None)) 242 | 243 | os.environ["http_proxy"] = "http://localhost/" 244 | 
os.environ["https_proxy"] = "http://localhost2/" 245 | self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, None)) 246 | os.environ["http_proxy"] = "http://localhost:3128/" 247 | os.environ["https_proxy"] = "http://localhost2:3128/" 248 | self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, None)) 249 | 250 | os.environ["http_proxy"] = "http://localhost/" 251 | os.environ["https_proxy"] = "http://localhost2/" 252 | self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", None, None)) 253 | os.environ["http_proxy"] = "http://localhost:3128/" 254 | os.environ["https_proxy"] = "http://localhost2:3128/" 255 | self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", 3128, None)) 256 | 257 | os.environ["http_proxy"] = "http://a:b@localhost/" 258 | self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, ("a", "b"))) 259 | os.environ["http_proxy"] = "http://a:b@localhost:3128/" 260 | self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, ("a", "b"))) 261 | 262 | os.environ["http_proxy"] = "http://a:b@localhost/" 263 | os.environ["https_proxy"] = "http://a:b@localhost2/" 264 | self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, ("a", "b"))) 265 | os.environ["http_proxy"] = "http://a:b@localhost:3128/" 266 | os.environ["https_proxy"] = "http://a:b@localhost2:3128/" 267 | self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, ("a", "b"))) 268 | 269 | os.environ["http_proxy"] = "http://a:b@localhost/" 270 | os.environ["https_proxy"] = "http://a:b@localhost2/" 271 | self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", None, ("a", "b"))) 272 | os.environ["http_proxy"] = "http://a:b@localhost:3128/" 273 | os.environ["https_proxy"] = "http://a:b@localhost2:3128/" 274 | self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", 3128, ("a", 
"b"))) 275 | 276 | os.environ["http_proxy"] = "http://john%40example.com:P%40SSWORD@localhost:3128/" 277 | os.environ["https_proxy"] = "http://john%40example.com:P%40SSWORD@localhost2:3128/" 278 | self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", 3128, ("john@example.com", "P@SSWORD"))) 279 | 280 | os.environ["http_proxy"] = "http://a:b@localhost/" 281 | os.environ["https_proxy"] = "http://a:b@localhost2/" 282 | os.environ["no_proxy"] = "example1.com,example2.com" 283 | self.assertEqual(get_proxy_info("example.1.com", True), ("localhost2", None, ("a", "b"))) 284 | os.environ["http_proxy"] = "http://a:b@localhost:3128/" 285 | os.environ["https_proxy"] = "http://a:b@localhost2:3128/" 286 | os.environ["no_proxy"] = "example1.com,example2.com, echo.websocket.org" 287 | self.assertEqual(get_proxy_info("echo.websocket.org", True), (None, 0, None)) 288 | os.environ["http_proxy"] = "http://a:b@localhost:3128/" 289 | os.environ["https_proxy"] = "http://a:b@localhost2:3128/" 290 | os.environ["no_proxy"] = "example1.com,example2.com, .websocket.org" 291 | self.assertEqual(get_proxy_info("echo.websocket.org", True), (None, 0, None)) 292 | 293 | os.environ["http_proxy"] = "http://a:b@localhost:3128/" 294 | os.environ["https_proxy"] = "http://a:b@localhost2:3128/" 295 | os.environ["no_proxy"] = "127.0.0.0/8, 192.168.0.0/16" 296 | self.assertEqual(get_proxy_info("127.0.0.1", False), (None, 0, None)) 297 | self.assertEqual(get_proxy_info("192.168.1.1", False), (None, 0, None)) 298 | 299 | 300 | if __name__ == "__main__": 301 | unittest.main() 302 | -------------------------------------------------------------------------------- /nls/websocket/_app.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | """ 4 | 5 | """ 6 | _app.py 7 | websocket - WebSocket client library for Python 8 | 9 | Copyright 2021 engn33r 10 | 11 | Licensed under the Apache License, Version 2.0 (the "License"); 12 | you may not use this 
file except in compliance with the License. 13 | You may obtain a copy of the License at 14 | 15 | http://www.apache.org/licenses/LICENSE-2.0 16 | 17 | Unless required by applicable law or agreed to in writing, software 18 | distributed under the License is distributed on an "AS IS" BASIS, 19 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | See the License for the specific language governing permissions and 21 | limitations under the License. 22 | """ 23 | import selectors 24 | import sys 25 | import threading 26 | import time 27 | import traceback 28 | from ._abnf import ABNF 29 | from ._core import WebSocket, getdefaulttimeout 30 | from ._exceptions import * 31 | from . import _logging 32 | 33 | 34 | __all__ = ["WebSocketApp"] 35 | 36 | 37 | class Dispatcher: 38 | """ 39 | Dispatcher 40 | """ 41 | def __init__(self, app, ping_timeout): 42 | self.app = app 43 | self.ping_timeout = ping_timeout 44 | 45 | def read(self, sock, read_callback, check_callback): 46 | while self.app.keep_running: 47 | sel = selectors.DefaultSelector() 48 | sel.register(self.app.sock.sock, selectors.EVENT_READ) 49 | 50 | r = sel.select(self.ping_timeout) 51 | if r: 52 | if not read_callback(): 53 | break 54 | check_callback() 55 | sel.close() 56 | 57 | 58 | class SSLDispatcher: 59 | """ 60 | SSLDispatcher 61 | """ 62 | def __init__(self, app, ping_timeout): 63 | self.app = app 64 | self.ping_timeout = ping_timeout 65 | 66 | def read(self, sock, read_callback, check_callback): 67 | while self.app.keep_running: 68 | r = self.select() 69 | if r: 70 | if not read_callback(): 71 | break 72 | check_callback() 73 | 74 | def select(self): 75 | sock = self.app.sock.sock 76 | if sock.pending(): 77 | return [sock,] 78 | 79 | sel = selectors.DefaultSelector() 80 | sel.register(sock, selectors.EVENT_READ) 81 | 82 | r = sel.select(self.ping_timeout) 83 | sel.close() 84 | 85 | if len(r) > 0: 86 | return r[0][0] 87 | 88 | 89 | class WebSocketApp: 90 | """ 91 | Higher level 
of APIs are provided. The interface is like JavaScript WebSocket object. 92 | """ 93 | 94 | def __init__(self, url, header=None, 95 | on_open=None, on_message=None, on_error=None, 96 | on_close=None, on_ping=None, on_pong=None, 97 | on_cont_message=None, 98 | keep_running=True, get_mask_key=None, cookie=None, 99 | subprotocols=None, 100 | on_data=None, callback_args=[]): 101 | """ 102 | WebSocketApp initialization 103 | 104 | Parameters 105 | ---------- 106 | url: str 107 | Websocket url. 108 | header: list or dict 109 | Custom header for websocket handshake. 110 | on_open: function 111 | Callback object which is called at opening websocket. 112 | on_open has one argument. 113 | The 1st argument is this class object. 114 | on_message: function 115 | Callback object which is called when received data. 116 | on_message has 2 arguments. 117 | The 1st argument is this class object. 118 | The 2nd argument is utf-8 data received from the server. 119 | on_error: function 120 | Callback object which is called when we get error. 121 | on_error has 2 arguments. 122 | The 1st argument is this class object. 123 | The 2nd argument is exception object. 124 | on_close: function 125 | Callback object which is called when connection is closed. 126 | on_close has 3 arguments. 127 | The 1st argument is this class object. 128 | The 2nd argument is close_status_code. 129 | The 3rd argument is close_msg. 130 | on_cont_message: function 131 | Callback object which is called when a continuation 132 | frame is received. 133 | on_cont_message has 3 arguments. 134 | The 1st argument is this class object. 135 | The 2nd argument is utf-8 string which we get from the server. 136 | The 3rd argument is continue flag. if 0, the data continue 137 | to next frame data 138 | on_data: function 139 | Callback object which is called when a message received. 140 | This is called before on_message or on_cont_message, 141 | and then on_message or on_cont_message is called. 142 | on_data has 4 argument. 
143 | The 1st argument is this class object. 144 | The 2nd argument is utf-8 string which we get from the server. 145 | The 3rd argument is data type. ABNF.OPCODE_TEXT or ABNF.OPCODE_BINARY will be came. 146 | The 4th argument is continue flag. If 0, the data continue 147 | keep_running: bool 148 | This parameter is obsolete and ignored. 149 | get_mask_key: function 150 | A callable function to get new mask keys, see the 151 | WebSocket.set_mask_key's docstring for more information. 152 | cookie: str 153 | Cookie value. 154 | subprotocols: list 155 | List of available sub protocols. Default is None. 156 | """ 157 | self.url = url 158 | self.header = header if header is not None else [] 159 | self.cookie = cookie 160 | 161 | self.on_open = on_open 162 | self.on_message = on_message 163 | self.on_data = on_data 164 | self.on_error = on_error 165 | self.on_close = on_close 166 | self.on_ping = on_ping 167 | self.on_pong = on_pong 168 | self.on_cont_message = on_cont_message 169 | self.keep_running = False 170 | self.get_mask_key = get_mask_key 171 | self.sock = None 172 | self.last_ping_tm = 0 173 | self.last_pong_tm = 0 174 | self.subprotocols = subprotocols 175 | self.callback_args = callback_args 176 | self.has_teardown = False 177 | 178 | def update_args(self, *args): 179 | self.callback_args = args 180 | #print(self.callback_args) 181 | 182 | def send(self, data, opcode=ABNF.OPCODE_TEXT): 183 | """ 184 | send message 185 | 186 | Parameters 187 | ---------- 188 | data: str 189 | Message to send. If you set opcode to OPCODE_TEXT, 190 | data must be utf-8 string or unicode. 191 | opcode: int 192 | Operation code of data. Default is OPCODE_TEXT. 193 | """ 194 | 195 | if not self.sock or self.sock.send(data, opcode) == 0: 196 | raise WebSocketConnectionClosedException( 197 | "Connection is already closed.") 198 | 199 | def close(self, **kwargs): 200 | """ 201 | Close websocket connection. 
202 | """ 203 | self.keep_running = False 204 | if self.sock: 205 | self.sock.close(**kwargs) 206 | self.sock = None 207 | 208 | def _send_ping(self, interval, event, payload): 209 | while not event.wait(interval): 210 | self.last_ping_tm = time.time() 211 | if self.sock: 212 | try: 213 | self.sock.ping(payload) 214 | except Exception as ex: 215 | _logging.warning("send_ping routine terminated: {}".format(ex)) 216 | break 217 | 218 | def run_forever(self, sockopt=None, sslopt=None, 219 | ping_interval=0, ping_timeout=None, 220 | ping_payload="", 221 | http_proxy_host=None, http_proxy_port=None, 222 | http_no_proxy=None, http_proxy_auth=None, 223 | skip_utf8_validation=False, 224 | host=None, origin=None, dispatcher=None, 225 | suppress_origin=False, proxy_type=None): 226 | """ 227 | Run event loop for WebSocket framework. 228 | 229 | This loop is an infinite loop and is alive while websocket is available. 230 | 231 | Parameters 232 | ---------- 233 | sockopt: tuple 234 | Values for socket.setsockopt. 235 | sockopt must be tuple 236 | and each element is argument of sock.setsockopt. 237 | sslopt: dict 238 | Optional dict object for ssl socket option. 239 | ping_interval: int or float 240 | Automatically send "ping" command 241 | every specified period (in seconds). 242 | If set to 0, no ping is sent periodically. 243 | ping_timeout: int or float 244 | Timeout (in seconds) if the pong message is not received. 245 | ping_payload: str 246 | Payload message to send with each ping. 247 | http_proxy_host: str 248 | HTTP proxy host name. 249 | http_proxy_port: int or str 250 | HTTP proxy port. If not set, set to 80. 251 | http_no_proxy: list 252 | Whitelisted host names that don't use the proxy. 253 | skip_utf8_validation: bool 254 | skip utf8 validation. 255 | host: str 256 | update host header. 257 | origin: str 258 | update origin header. 259 | dispatcher: Dispatcher object 260 | customize reading data from socket. 
261 | suppress_origin: bool 262 | suppress outputting origin header. 263 | 264 | Returns 265 | ------- 266 | teardown: bool 267 | False if caught KeyboardInterrupt, True if other exception was raised during a loop 268 | """ 269 | 270 | if ping_timeout is not None and ping_timeout <= 0: 271 | raise WebSocketException("Ensure ping_timeout > 0") 272 | if ping_interval is not None and ping_interval < 0: 273 | raise WebSocketException("Ensure ping_interval >= 0") 274 | if ping_timeout and ping_interval and ping_interval <= ping_timeout: 275 | raise WebSocketException("Ensure ping_interval > ping_timeout") 276 | if not sockopt: 277 | sockopt = [] 278 | if not sslopt: 279 | sslopt = {} 280 | if self.sock: 281 | raise WebSocketException("socket is already opened") 282 | thread = None 283 | self.keep_running = True 284 | self.last_ping_tm = 0 285 | self.last_pong_tm = 0 286 | 287 | def teardown(close_frame=None): 288 | """ 289 | Tears down the connection. 290 | 291 | Parameters 292 | ---------- 293 | close_frame: ABNF frame 294 | If close_frame is set, the on_close handler is invoked 295 | with the statusCode and reason from the provided frame. 
296 | """ 297 | if self.has_teardown: 298 | return 299 | self.has_teardown = True 300 | if thread and thread.is_alive(): 301 | event.set() 302 | thread.join() 303 | self.keep_running = False 304 | if self.sock: 305 | self.sock.close() 306 | close_status_code, close_reason = self._get_close_args( 307 | close_frame if close_frame else None) 308 | self.sock = None 309 | 310 | # Finally call the callback AFTER all teardown is complete 311 | self._callback(self.on_close, close_status_code, close_reason, 312 | self.callback_args) 313 | 314 | try: 315 | self.sock = WebSocket( 316 | self.get_mask_key, sockopt=sockopt, sslopt=sslopt, 317 | fire_cont_frame=self.on_cont_message is not None, 318 | skip_utf8_validation=skip_utf8_validation, 319 | enable_multithread=True) 320 | self.sock.settimeout(getdefaulttimeout()) 321 | self.sock.connect( 322 | self.url, header=self.header, cookie=self.cookie, 323 | http_proxy_host=http_proxy_host, 324 | http_proxy_port=http_proxy_port, http_no_proxy=http_no_proxy, 325 | http_proxy_auth=http_proxy_auth, subprotocols=self.subprotocols, 326 | host=host, origin=origin, suppress_origin=suppress_origin, 327 | proxy_type=proxy_type) 328 | if not dispatcher: 329 | dispatcher = self.create_dispatcher(ping_timeout) 330 | 331 | self._callback(self.on_open, self.callback_args) 332 | 333 | if ping_interval: 334 | event = threading.Event() 335 | thread = threading.Thread( 336 | target=self._send_ping, args=(ping_interval, event, ping_payload)) 337 | thread.daemon = True 338 | thread.start() 339 | 340 | def read(): 341 | if not self.keep_running: 342 | return teardown() 343 | 344 | op_code, frame = self.sock.recv_data_frame(True) 345 | if op_code == ABNF.OPCODE_CLOSE: 346 | return teardown(frame) 347 | elif op_code == ABNF.OPCODE_PING: 348 | self._callback(self.on_ping, frame.data, self.callback_args) 349 | elif op_code == ABNF.OPCODE_PONG: 350 | self.last_pong_tm = time.time() 351 | self._callback(self.on_pong, frame.data, self.callback_args) 352 | elif 
op_code == ABNF.OPCODE_CONT and self.on_cont_message: 353 | self._callback(self.on_data, frame.data, 354 | frame.opcode, frame.fin, self.callback_args) 355 | self._callback(self.on_cont_message, 356 | frame.data, frame.fin, self.callback_args) 357 | else: 358 | data = frame.data 359 | if op_code == ABNF.OPCODE_TEXT: 360 | data = data.decode("utf-8") 361 | self._callback(self.on_message, data, self.callback_args) 362 | else: 363 | self._callback(self.on_data, data, frame.opcode, True, 364 | self.callback_args) 365 | 366 | return True 367 | 368 | def check(): 369 | if (ping_timeout): 370 | has_timeout_expired = time.time() - self.last_ping_tm > ping_timeout 371 | has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0 372 | has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > ping_timeout 373 | 374 | if (self.last_ping_tm and 375 | has_timeout_expired and 376 | (has_pong_not_arrived_after_last_ping or has_pong_arrived_too_late)): 377 | raise WebSocketTimeoutException("ping/pong timed out") 378 | return True 379 | 380 | dispatcher.read(self.sock.sock, read, check) 381 | except (Exception, KeyboardInterrupt, SystemExit) as e: 382 | self._callback(self.on_error, e, self.callback_args) 383 | if isinstance(e, SystemExit): 384 | # propagate SystemExit further 385 | raise 386 | teardown() 387 | return not isinstance(e, KeyboardInterrupt) 388 | else: 389 | teardown() 390 | return True 391 | 392 | def create_dispatcher(self, ping_timeout): 393 | timeout = ping_timeout or 10 394 | if self.sock.is_ssl(): 395 | return SSLDispatcher(self, timeout) 396 | 397 | return Dispatcher(self, timeout) 398 | 399 | def _get_close_args(self, close_frame): 400 | """ 401 | _get_close_args extracts the close code and reason from the close body 402 | if it exists (RFC6455 says WebSocket Connection Close Code is optional) 403 | """ 404 | # Need to catch the case where close_frame is None 405 | # Otherwise the following if statement causes an error 406 | if 
not self.on_close or not close_frame: 407 | return [None, None] 408 | 409 | # Extract close frame status code 410 | if close_frame.data and len(close_frame.data) >= 2: 411 | close_status_code = 256 * close_frame.data[0] + close_frame.data[1] 412 | reason = close_frame.data[2:].decode('utf-8') 413 | return [close_status_code, reason] 414 | else: 415 | # Most likely reached this because len(close_frame_data.data) < 2 416 | return [None, None] 417 | 418 | def _callback(self, callback, *args): 419 | if callback: 420 | try: 421 | callback(self, *args) 422 | 423 | except Exception as e: 424 | _logging.error("error from callback {}: {}".format(callback, e)) 425 | if self.on_error: 426 | self.on_error(self, e) 427 | --------------------------------------------------------------------------------